matmult.lua 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  --Stephen Stengel
  --A program that shows the effect of caching on matrix multiplication.
  --Doing operations on large matrices can cause many cache misses. By
  --working on one small square block of a matrix at a time, all of that
  --block's data can be loaded into cache, which reduces the number of misses.
  --Flat (1-D, row-major) versions of the functions are also included; the
  --author expected them to be roughly another order of magnitude faster.
  -- Block sizes timed for the nested-table (array-of-rows) matrices.
  local NESTED_BLOCK_SIZES = { 8, 10, 20, 25, 40, 100 }
  -- Block sizes timed for the flat (1-D, row-major) matrices.
  local FLAT_BLOCK_SIZES = { 50, 100, 200, 250, 500 }

  -- Announce and run timer(A, B, blockSize) once for every entry of sizes.
  -- styleName is spliced into the announcement, e.g. "block style".
  local function runBlockSuite(timer, A, B, sizes, styleName)
    for _, blockSize in ipairs(sizes) do
      fwrite("Now timing the %s with block size: %d\n", styleName, blockSize)
      timer(A, B, blockSize)
    end
  end

  -- Benchmark driver: builds random and identity matrices with side length
  -- squareSize (first command-line argument, default 10). It then times the
  -- blocked and unblocked products for both the nested-table layout and the
  -- flat layout. The multiplication routines reject block sizes that do not
  -- divide squareSize ("BAD BLOCK SIZE!").
  function main()
    local start = os.clock()
    math.randomseed(os.time())
    -- `arg` only exists when run by the standalone interpreter. A missing,
    -- non-numeric, non-finite or < 1 value falls back to 10. Fractional
    -- values are truncated so that the "%d" format below cannot fail.
    local requested = arg and tonumber(arg[1])
    local squareSize = 10
    if requested and requested >= 1 and requested < math.huge then
      squareSize = math.floor(requested)
    end
    io.write( string.format("squareSize: %d\n", squareSize) )

    -- Nested-table matrices (locals, so nothing leaks into the global scope).
    local myArray = createRandMat(squareSize)
    local idMat = createIDMat(squareSize)
    print("Now running the block style tests for myArray*idMat...")
    runBlockSuite(timeBlockMultTwo, myArray, idMat, NESTED_BLOCK_SIZES, "block style")
    print("\nNow running the block style tests for myArray*myArray...")
    runBlockSuite(timeBlockMultTwo, myArray, myArray, NESTED_BLOCK_SIZES, "block style")
    print("\nNow running unoptimized style...")
    fwrite("Multiplying myArray and idMat...\n")
    timeMultTwo(myArray, idMat)
    fwrite("Now timing myArray*myArray...\n")
    timeMultTwo(myArray, myArray)

    -- Flat row-major matrices.
    local fID = createFlatIDMatrix(squareSize)
    local fRM = createFlatRandMatrix(squareSize)
    print("\nNow running the Flat block style tests for fRM*fID...")
    runBlockSuite(timeFlatBlockMultTwo, fRM, fID, FLAT_BLOCK_SIZES, "Flat block style")
    print("\nNow running the Flat block style tests for fRM*fRM...")
    runBlockSuite(timeFlatBlockMultTwo, fRM, fRM, FLAT_BLOCK_SIZES, "Flat block style")
    print("\nNow running FLAT unoptimized style...")
    -- Labels name the flat matrices that are actually multiplied here.
    fwrite("Multiplying fRM and fID...\n")
    timeFlatMultTwo(fRM, fID)
    fwrite("Now timing fRM*fRM...\n")
    timeFlatMultTwo(fRM, fRM)
    fwrite("Overall time: %f\n", os.clock() - start)
  end
  -- Compare two nested-table matrices (arrays of row tables).
  -- Prints "CORRECT!" and returns true when A and B have the same shape and
  -- identical entries. Otherwise prints "NOT THE SAME!" and returns false.
  -- Shapes are compared first, so a B with extra or missing rows/columns is
  -- reported as different instead of passing or raising an index error.
  function checkIfMatEqual(A, B)
    if #A ~= #B then
      print("NOT THE SAME!\n")
      return false
    end
    for i = 1, #A do
      local rowA, rowB = A[i], B[i]
      if #rowA ~= #rowB then
        print("NOT THE SAME!\n")
        return false
      end
      -- Walk every column of the row; rows need not match the row count.
      for j = 1, #rowA do
        if rowA[j] ~= rowB[j] then
          print("NOT THE SAME!\n")
          return false
        end
      end
    end
    print("CORRECT!\n")
    return true
  end
  -- Compare two flat (1-D) matrices element by element.
  -- Prints "CORRECT!" and returns true when both have the same length and
  -- identical entries. Otherwise prints "NOT THE SAME!" and returns false.
  function checkIfFlatMatEqual(A, B)
    -- Different lengths mean different shapes, so the matrices cannot match.
    if #A ~= #B then
      print("NOT THE SAME!\n")
      return false
    end
    for i = 1, #A do
      if A[i] ~= B[i] then
        print("NOT THE SAME!\n")
        return false
      end
    end
    print("CORRECT!\n")
    return true
  end
  -- Time one blocked multiplication A*B of nested-table matrices (via
  -- multSqBlockStyle) using os.clock, and print the elapsed time.
  -- If multSqBlockStyle rejects blockSize (returns -1), it prints a "skipped"
  -- line instead of a meaningless timing.
  -- Returns the elapsed seconds and, as a second value, the product (or -1).
  function timeBlockMultTwo(A, B, blockSize)
    local start = os.clock()
    -- Local, so the product does not leak into a global variable.
    local product = multSqBlockStyle(A, B, blockSize)
    local total = os.clock() - start
    if product == -1 then
      fwrite("skipped: block size %s rejected\n", tostring(blockSize))
    else
      fwrite("done!\tTook %.2f seconds!\n", total)
    end
    return total, product
  end
  -- Time one blocked multiplication A*B of flat (1-D) matrices (via
  -- multFlatBlockMats) using os.clock, and print the elapsed time.
  -- If multFlatBlockMats rejects the input (returns -1), it prints a "skipped"
  -- line instead of a meaningless timing.
  -- Returns the elapsed seconds and, as a second value, the product (or -1).
  function timeFlatBlockMultTwo(A, B, blockSize)
    local start = os.clock()
    -- Local, so the product does not leak into a global variable.
    local product = multFlatBlockMats(A, B, blockSize)
    local total = os.clock() - start
    if product == -1 then
      fwrite("skipped: block size %s rejected\n", tostring(blockSize))
    else
      fwrite("done!\tTook %.2f seconds!\n", total)
    end
    return total, product
  end
  -- Time one unblocked multiplication A*B of flat (1-D) matrices (via
  -- multFlatMats) using os.clock, and print the elapsed time.
  -- Returns the elapsed seconds and, as a second value, the product.
  function timeFlatMultTwo(A, B)
    local start = os.clock()
    -- Local, so the product does not leak into a global variable.
    local product = multFlatMats(A, B)
    local total = os.clock() - start
    fwrite("done!\tTook %.2f seconds!\n", total)
    return total, product
  end
  -- Build a sideSize x sideSize matrix stored flat in row-major order:
  -- entry (i, j) lives at index (i - 1) * sideSize + j.
  -- Entries are math.random(100) values, i.e. integers in 1..100.
  function createFlatRandMatrix(sideSize)
    local flatLen = sideSize * sideSize
    local out = {}
    -- Row-major order visits indices 1..flatLen sequentially. A single loop
    -- therefore draws the random numbers in the same order as a nested i/j loop.
    for idx = 1, flatLen do
      out[idx] = math.random(100)
    end
    return out
  end
  -- Build a size x size identity matrix stored flat in row-major order:
  -- arr[(i - 1) * size + j] plays the role of arr[i][j].
  function createFlatIDMatrix(size)
    local arr = {}
    for i = 1, size do
      -- Offset of row i; the + j below selects column j.
      local rowStart = (i - 1) * size
      for j = 1, size do
        arr[rowStart + j] = (i == j) and 1 or 0
      end
    end
    return arr
  end
  -- Return a flat table holding size*size zeros: a zeroed square matrix in
  -- flat form, used as the accumulator for the flat products.
  function createFlatOutMatrix(size)
    local total = size * size
    local zeros = {}
    local idx = 1
    while idx <= total do
      zeros[idx] = 0
      idx = idx + 1
    end
    return zeros
  end
  -- Time one unblocked multiplication A*B of nested-table matrices (via
  -- multSqMats) using os.clock, print the elapsed time, and return it.
  function timeMultTwo(A, B)
    local t0 = os.clock()
    local _product = multSqMats(A, B) -- kept only for the duration of the call
    local elapsed = os.clock() - t0
    fwrite("done!\tTook %.2f seconds!\n", elapsed)
    return elapsed
  end
  -- Multiply nested-table matrices A*B with loop blocking (tiling). The i, j
  -- and k ranges are walked in blockSize-wide tiles, so one tile's data can
  -- stay in cache while it is used.
  -- Returns the product as a new nested table, built with createOutMat(#A).
  -- When blockSize is not a positive integer dividing #A, it prints
  -- "BAD BLOCK SIZE!" and returns -1.
  -- NOTE(review): assumes A and B are both #A x #A; this is not checked.
  function multSqBlockStyle(A, B, blockSize)
    local len = #A
    -- A zero block size would raise an error in `%`, a negative one would skip
    -- every loop, and a fractional one would produce fractional indices.
    if blockSize < 1 or blockSize % 1 ~= 0 or len % blockSize ~= 0 then
      print("BAD BLOCK SIZE!\n")
      return -1
    end
    local C = createOutMat(len)
    -- Choose the outer tile; each tile's last index is computed once per tile.
    for i = 1, len, blockSize do
      local iLast = i + blockSize - 1
      for j = 1, len, blockSize do
        local jLast = j + blockSize - 1
        for k = 1, len, blockSize do
          local kLast = k + blockSize - 1
          -- Multiply the tile: c[i,j] = c[i,j] + a[i,k] * b[k,j]
          -- (https://www.netlib.org/utk/papers/autoblock/node2.html).
          -- The inner statement matches multSqMats so the benchmark compares
          -- only the loop order.
          for i1 = i, iLast do
            for j1 = j, jLast do
              for k1 = k, kLast do
                C[i1][j1] = C[i1][j1] + ( A[i1][k1] * B[k1][j1] )
              end
            end
          end
        end
      end
    end
    return C
  end
  -- Multiply two square matrices stored flat in row-major order (entry (i, j)
  -- at index (i - 1) * len + j) with loop blocking. It walks blockSize-wide
  -- tiles of the i, j and k ranges.
  -- Returns a new flat product built with createFlatOutMatrix(len).
  -- Returns -1 after printing:
  --   "NOT A SQUARE MATRIX!" when #A is not a perfect square;
  --   "BAD BLOCK SIZE!" when blockSize is not a positive integer dividing len.
  -- NOTE(review): assumes #B == #A; this is not checked.
  function multFlatBlockMats(A, B, blockSize)
    local totalLen = #A
    -- math.sqrt returns a float. Round it to an integer side length and
    -- verify it, so that non-square inputs cannot yield fractional indices.
    local len = math.floor(math.sqrt(totalLen) + 0.5)
    if len * len ~= totalLen then
      print("NOT A SQUARE MATRIX!\n")
      return -1
    end
    if blockSize < 1 or blockSize % 1 ~= 0 or len % blockSize ~= 0 then
      print("BAD BLOCK SIZE!\n")
      return -1
    end
    local C = createFlatOutMatrix(len)
    -- Choose the outer tile; each tile's last index is computed once per tile.
    for i = 1, len, blockSize do
      local iLast = i + blockSize - 1
      for j = 1, len, blockSize do
        local jLast = j + blockSize - 1
        for k = 1, len, blockSize do
          local kLast = k + blockSize - 1
          -- Multiply the tile: c[i,j] = c[i,j] + a[i,k] * b[k,j]
          -- (https://www.netlib.org/utk/papers/autoblock/node2.html).
          -- The inner statement matches multFlatMats so the benchmark
          -- compares only the loop order.
          for i1 = i, iLast do
            for j1 = j, jLast do
              for k1 = k, kLast do
                C[(i1 - 1)*len + j1] = C[(i1 - 1)*len + j1] + ( A[(i1-1)*len + k1] * B[(k1-1)*len + j1] )
              end
            end
          end
        end
      end
    end
    return C
  end
  -- Allocate a len x len nested table (an array of row tables) with every
  -- entry set to 0; it serves as the output of the nested-table products.
  function createOutMat(len)
    local rows = {}
    for r = 1, len do
      local row = {}
      for c = 1, len do
        row[c] = 0
      end
      rows[r] = row
    end
    return rows
  end
  -- Plain triple-loop product A*B of nested-table matrices; this is the
  -- unblocked baseline. Returns a new #A x #A table built with createOutMat.
  function multSqMats(A, B)
    local n = #A
    local product = createOutMat(n)
    for row = 1, n do
      for col = 1, n do
        for inner = 1, n do
          -- output[i][j] += original[i][k] * intermediate[k][j]
          product[row][col] = product[row][col] + (A[row][inner] * B[inner][col])
        end
      end
    end
    return product
  end
  -- Plain triple-loop product A*B of two square matrices stored flat in
  -- row-major order (entry (i, j) at index (i - 1) * len + j); this is the
  -- unblocked baseline for the flat layout.
  -- Returns a new flat product built with createFlatOutMatrix(len).
  -- Raises an error when #A is not a perfect square.
  -- NOTE(review): assumes #B == #A; this is not checked.
  function multFlatMats(A, B)
    local totalLen = #A
    -- math.sqrt returns a float. Round it to an integer side length and
    -- verify it, so that non-square inputs cannot yield fractional indices.
    local len = math.floor(math.sqrt(totalLen) + 0.5)
    if len * len ~= totalLen then
      error("multFlatMats: #A (" .. totalLen .. ") is not a perfect square", 2)
    end
    local C = createFlatOutMatrix(len)
    for i = 1, len do
      for j = 1, len do
        for k = 1, len do
          -- Flat form of C[i][j] = C[i][j] + A[i][k] * B[k][j].
          C[(i - 1)* len + j] = C[(i - 1)* len + j] + ( A[(i - 1)* len + k] * B[(k - 1)*len + j] )
        end
      end
    end
    return C
  end
  -- Print a flat row-major square matrix, one row per line. Each entry is
  -- written with the "%d " format, so entries must be integer-valued.
  function printFlatMatrix(mat)
    local side = math.sqrt(#mat)
    for r = 1, side do
      local base = (r - 1) * side
      for c = 1, side do
        fwrite("%d ", mat[base + c])
      end
      fwrite("\n")
    end
  end
  -- Print a nested-table matrix, one row per line. Each entry is written
  -- with the "%d " format, so entries must be integer-valued.
  function printMatrix(myMat)
    for r = 1, #myMat do
      local row = myMat[r]
      for c = 1, #row do
        fwrite("%d ", row[c])
      end
      fwrite("\n")
    end
  end
  -- Build a squareSize x squareSize identity matrix as a nested table:
  -- 1 on the diagonal, 0 everywhere else.
  function createIDMat(squareSize)
    local rows = {}
    for r = 1, squareSize do
      local row = {}
      for c = 1, squareSize do
        row[c] = (r == c) and 1 or 0
      end
      rows[r] = row
    end
    return rows
  end
  -- Build a squareSize x squareSize nested-table matrix of math.random(100)
  -- values, filled row by row.
  function createRandMat(squareSize)
    local rows = {}
    for r = 1, squareSize do
      local row = {}
      rows[r] = row
      for c = 1, squareSize do
        row[c] = math.random(100)
      end
    end
    return rows
  end
  -- printf-like helper (similar to C printf): formats the arguments with
  -- string.format and writes the text with io.write, returning its result.
  fwrite = function(fmt, ...)
    local text = string.format(fmt, ...)
    return io.write(text)
  end
  -- Script entry point: run the benchmark after every function above is defined.
  main()