bmt_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. // Copyright 2017 The go-ethereum Authors
  2. // This file is part of the go-ethereum library.
  3. //
  4. // The go-ethereum library is free software: you can redistribute it and/or modify
  5. // it under the terms of the GNU Lesser General Public License as published by
  6. // the Free Software Foundation, either version 3 of the License, or
  7. // (at your option) any later version.
  8. //
  9. // The go-ethereum library is distributed in the hope that it will be useful,
  10. // but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. // GNU Lesser General Public License for more details.
  13. //
  14. // You should have received a copy of the GNU Lesser General Public License
  15. // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
  16. package bmt
  17. import (
  18. "bytes"
  19. crand "crypto/rand"
  20. "fmt"
  21. "hash"
  22. "io"
  23. "math/rand"
  24. "sync"
  25. "sync/atomic"
  26. "testing"
  27. "time"
  28. "github.com/ethereum/go-ethereum/crypto/sha3"
  29. )
  30. const (
  31. maxproccnt = 8
  32. )
  33. // TestRefHasher tests that the RefHasher computes the expected BMT hash for
  34. // all data lengths between 0 and 256 bytes
  35. func TestRefHasher(t *testing.T) {
  36. hashFunc := sha3.NewKeccak256
  37. sha3 := func(data ...[]byte) []byte {
  38. h := hashFunc()
  39. for _, v := range data {
  40. h.Write(v)
  41. }
  42. return h.Sum(nil)
  43. }
  44. // the test struct is used to specify the expected BMT hash for data
  45. // lengths between "from" and "to"
  46. type test struct {
  47. from int64
  48. to int64
  49. expected func([]byte) []byte
  50. }
  51. var tests []*test
  52. // all lengths in [0,64] should be:
  53. //
  54. // sha3(data)
  55. //
  56. tests = append(tests, &test{
  57. from: 0,
  58. to: 64,
  59. expected: func(data []byte) []byte {
  60. return sha3(data)
  61. },
  62. })
  63. // all lengths in [65,96] should be:
  64. //
  65. // sha3(
  66. // sha3(data[:64])
  67. // data[64:]
  68. // )
  69. //
  70. tests = append(tests, &test{
  71. from: 65,
  72. to: 96,
  73. expected: func(data []byte) []byte {
  74. return sha3(sha3(data[:64]), data[64:])
  75. },
  76. })
  77. // all lengths in [97,128] should be:
  78. //
  79. // sha3(
  80. // sha3(data[:64])
  81. // sha3(data[64:])
  82. // )
  83. //
  84. tests = append(tests, &test{
  85. from: 97,
  86. to: 128,
  87. expected: func(data []byte) []byte {
  88. return sha3(sha3(data[:64]), sha3(data[64:]))
  89. },
  90. })
  91. // all lengths in [129,160] should be:
  92. //
  93. // sha3(
  94. // sha3(
  95. // sha3(data[:64])
  96. // sha3(data[64:128])
  97. // )
  98. // data[128:]
  99. // )
  100. //
  101. tests = append(tests, &test{
  102. from: 129,
  103. to: 160,
  104. expected: func(data []byte) []byte {
  105. return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), data[128:])
  106. },
  107. })
  108. // all lengths in [161,192] should be:
  109. //
  110. // sha3(
  111. // sha3(
  112. // sha3(data[:64])
  113. // sha3(data[64:128])
  114. // )
  115. // sha3(data[128:])
  116. // )
  117. //
  118. tests = append(tests, &test{
  119. from: 161,
  120. to: 192,
  121. expected: func(data []byte) []byte {
  122. return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(data[128:]))
  123. },
  124. })
  125. // all lengths in [193,224] should be:
  126. //
  127. // sha3(
  128. // sha3(
  129. // sha3(data[:64])
  130. // sha3(data[64:128])
  131. // )
  132. // sha3(
  133. // sha3(data[128:192])
  134. // data[192:]
  135. // )
  136. // )
  137. //
  138. tests = append(tests, &test{
  139. from: 193,
  140. to: 224,
  141. expected: func(data []byte) []byte {
  142. return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), data[192:]))
  143. },
  144. })
  145. // all lengths in [225,256] should be:
  146. //
  147. // sha3(
  148. // sha3(
  149. // sha3(data[:64])
  150. // sha3(data[64:128])
  151. // )
  152. // sha3(
  153. // sha3(data[128:192])
  154. // sha3(data[192:])
  155. // )
  156. // )
  157. //
  158. tests = append(tests, &test{
  159. from: 225,
  160. to: 256,
  161. expected: func(data []byte) []byte {
  162. return sha3(sha3(sha3(data[:64]), sha3(data[64:128])), sha3(sha3(data[128:192]), sha3(data[192:])))
  163. },
  164. })
  165. // run the tests
  166. for _, x := range tests {
  167. for length := x.from; length <= x.to; length++ {
  168. t.Run(fmt.Sprintf("%d_bytes", length), func(t *testing.T) {
  169. data := make([]byte, length)
  170. if _, err := io.ReadFull(crand.Reader, data); err != nil && err != io.EOF {
  171. t.Fatal(err)
  172. }
  173. expected := x.expected(data)
  174. actual := NewRefHasher(hashFunc, 128).Hash(data)
  175. if !bytes.Equal(actual, expected) {
  176. t.Fatalf("expected %x, got %x", expected, actual)
  177. }
  178. })
  179. }
  180. }
  181. }
  182. func testDataReader(l int) (r io.Reader) {
  183. return io.LimitReader(crand.Reader, int64(l))
  184. }
  185. func TestHasherCorrectness(t *testing.T) {
  186. err := testHasher(testBaseHasher)
  187. if err != nil {
  188. t.Fatal(err)
  189. }
  190. }
  191. func testHasher(f func(BaseHasher, []byte, int, int) error) error {
  192. tdata := testDataReader(4128)
  193. data := make([]byte, 4128)
  194. tdata.Read(data)
  195. hasher := sha3.NewKeccak256
  196. size := hasher().Size()
  197. counts := []int{1, 2, 3, 4, 5, 8, 16, 32, 64, 128}
  198. var err error
  199. for _, count := range counts {
  200. max := count * size
  201. incr := 1
  202. for n := 0; n <= max+incr; n += incr {
  203. err = f(hasher, data, n, count)
  204. if err != nil {
  205. return err
  206. }
  207. }
  208. }
  209. return nil
  210. }
  211. func TestHasherReuseWithoutRelease(t *testing.T) {
  212. testHasherReuse(1, t)
  213. }
  214. func TestHasherReuseWithRelease(t *testing.T) {
  215. testHasherReuse(maxproccnt, t)
  216. }
  217. func testHasherReuse(i int, t *testing.T) {
  218. hasher := sha3.NewKeccak256
  219. pool := NewTreePool(hasher, 128, i)
  220. defer pool.Drain(0)
  221. bmt := New(pool)
  222. for i := 0; i < 500; i++ {
  223. n := rand.Intn(4096)
  224. tdata := testDataReader(n)
  225. data := make([]byte, n)
  226. tdata.Read(data)
  227. err := testHasherCorrectness(bmt, hasher, data, n, 128)
  228. if err != nil {
  229. t.Fatal(err)
  230. }
  231. }
  232. }
  233. func TestHasherConcurrency(t *testing.T) {
  234. hasher := sha3.NewKeccak256
  235. pool := NewTreePool(hasher, 128, maxproccnt)
  236. defer pool.Drain(0)
  237. wg := sync.WaitGroup{}
  238. cycles := 100
  239. wg.Add(maxproccnt * cycles)
  240. errc := make(chan error)
  241. for p := 0; p < maxproccnt; p++ {
  242. for i := 0; i < cycles; i++ {
  243. go func() {
  244. bmt := New(pool)
  245. n := rand.Intn(4096)
  246. tdata := testDataReader(n)
  247. data := make([]byte, n)
  248. tdata.Read(data)
  249. err := testHasherCorrectness(bmt, hasher, data, n, 128)
  250. wg.Done()
  251. if err != nil {
  252. errc <- err
  253. }
  254. }()
  255. }
  256. }
  257. go func() {
  258. wg.Wait()
  259. close(errc)
  260. }()
  261. var err error
  262. select {
  263. case <-time.NewTimer(5 * time.Second).C:
  264. err = fmt.Errorf("timed out")
  265. case err = <-errc:
  266. }
  267. if err != nil {
  268. t.Fatal(err)
  269. }
  270. }
  271. func testBaseHasher(hasher BaseHasher, d []byte, n, count int) error {
  272. pool := NewTreePool(hasher, count, 1)
  273. defer pool.Drain(0)
  274. bmt := New(pool)
  275. return testHasherCorrectness(bmt, hasher, d, n, count)
  276. }
  277. func testHasherCorrectness(bmt hash.Hash, hasher BaseHasher, d []byte, n, count int) (err error) {
  278. data := d[:n]
  279. rbmt := NewRefHasher(hasher, count)
  280. exp := rbmt.Hash(data)
  281. timeout := time.NewTimer(time.Second)
  282. c := make(chan error)
  283. go func() {
  284. bmt.Reset()
  285. bmt.Write(data)
  286. got := bmt.Sum(nil)
  287. if !bytes.Equal(got, exp) {
  288. c <- fmt.Errorf("wrong hash: expected %x, got %x", exp, got)
  289. }
  290. close(c)
  291. }()
  292. select {
  293. case <-timeout.C:
  294. err = fmt.Errorf("BMT hash calculation timed out")
  295. case err = <-c:
  296. }
  297. return err
  298. }
  299. func BenchmarkSHA3_4k(t *testing.B) { benchmarkSHA3(4096, t) }
  300. func BenchmarkSHA3_2k(t *testing.B) { benchmarkSHA3(4096/2, t) }
  301. func BenchmarkSHA3_1k(t *testing.B) { benchmarkSHA3(4096/4, t) }
  302. func BenchmarkSHA3_512b(t *testing.B) { benchmarkSHA3(4096/8, t) }
  303. func BenchmarkSHA3_256b(t *testing.B) { benchmarkSHA3(4096/16, t) }
  304. func BenchmarkSHA3_128b(t *testing.B) { benchmarkSHA3(4096/32, t) }
  305. func BenchmarkBMTBaseline_4k(t *testing.B) { benchmarkBMTBaseline(4096, t) }
  306. func BenchmarkBMTBaseline_2k(t *testing.B) { benchmarkBMTBaseline(4096/2, t) }
  307. func BenchmarkBMTBaseline_1k(t *testing.B) { benchmarkBMTBaseline(4096/4, t) }
  308. func BenchmarkBMTBaseline_512b(t *testing.B) { benchmarkBMTBaseline(4096/8, t) }
  309. func BenchmarkBMTBaseline_256b(t *testing.B) { benchmarkBMTBaseline(4096/16, t) }
  310. func BenchmarkBMTBaseline_128b(t *testing.B) { benchmarkBMTBaseline(4096/32, t) }
  311. func BenchmarkRefHasher_4k(t *testing.B) { benchmarkRefHasher(4096, t) }
  312. func BenchmarkRefHasher_2k(t *testing.B) { benchmarkRefHasher(4096/2, t) }
  313. func BenchmarkRefHasher_1k(t *testing.B) { benchmarkRefHasher(4096/4, t) }
  314. func BenchmarkRefHasher_512b(t *testing.B) { benchmarkRefHasher(4096/8, t) }
  315. func BenchmarkRefHasher_256b(t *testing.B) { benchmarkRefHasher(4096/16, t) }
  316. func BenchmarkRefHasher_128b(t *testing.B) { benchmarkRefHasher(4096/32, t) }
  317. func BenchmarkHasher_4k(t *testing.B) { benchmarkHasher(4096, t) }
  318. func BenchmarkHasher_2k(t *testing.B) { benchmarkHasher(4096/2, t) }
  319. func BenchmarkHasher_1k(t *testing.B) { benchmarkHasher(4096/4, t) }
  320. func BenchmarkHasher_512b(t *testing.B) { benchmarkHasher(4096/8, t) }
  321. func BenchmarkHasher_256b(t *testing.B) { benchmarkHasher(4096/16, t) }
  322. func BenchmarkHasher_128b(t *testing.B) { benchmarkHasher(4096/32, t) }
  323. func BenchmarkHasherNoReuse_4k(t *testing.B) { benchmarkHasherReuse(1, 4096, t) }
  324. func BenchmarkHasherNoReuse_2k(t *testing.B) { benchmarkHasherReuse(1, 4096/2, t) }
  325. func BenchmarkHasherNoReuse_1k(t *testing.B) { benchmarkHasherReuse(1, 4096/4, t) }
  326. func BenchmarkHasherNoReuse_512b(t *testing.B) { benchmarkHasherReuse(1, 4096/8, t) }
  327. func BenchmarkHasherNoReuse_256b(t *testing.B) { benchmarkHasherReuse(1, 4096/16, t) }
  328. func BenchmarkHasherNoReuse_128b(t *testing.B) { benchmarkHasherReuse(1, 4096/32, t) }
  329. func BenchmarkHasherReuse_4k(t *testing.B) { benchmarkHasherReuse(16, 4096, t) }
  330. func BenchmarkHasherReuse_2k(t *testing.B) { benchmarkHasherReuse(16, 4096/2, t) }
  331. func BenchmarkHasherReuse_1k(t *testing.B) { benchmarkHasherReuse(16, 4096/4, t) }
  332. func BenchmarkHasherReuse_512b(t *testing.B) { benchmarkHasherReuse(16, 4096/8, t) }
  333. func BenchmarkHasherReuse_256b(t *testing.B) { benchmarkHasherReuse(16, 4096/16, t) }
  334. func BenchmarkHasherReuse_128b(t *testing.B) { benchmarkHasherReuse(16, 4096/32, t) }
  335. // benchmarks the minimum hashing time for a balanced (for simplicity) BMT
  336. // by doing count/segmentsize parallel hashings of 2*segmentsize bytes
  337. // doing it on n maxproccnt each reusing the base hasher
  338. // the premise is that this is the minimum computation needed for a BMT
  339. // therefore this serves as a theoretical optimum for concurrent implementations
  340. func benchmarkBMTBaseline(n int, t *testing.B) {
  341. tdata := testDataReader(64)
  342. data := make([]byte, 64)
  343. tdata.Read(data)
  344. hasher := sha3.NewKeccak256
  345. t.ReportAllocs()
  346. t.ResetTimer()
  347. for i := 0; i < t.N; i++ {
  348. count := int32((n-1)/hasher().Size() + 1)
  349. wg := sync.WaitGroup{}
  350. wg.Add(maxproccnt)
  351. var i int32
  352. for j := 0; j < maxproccnt; j++ {
  353. go func() {
  354. defer wg.Done()
  355. h := hasher()
  356. for atomic.AddInt32(&i, 1) < count {
  357. h.Reset()
  358. h.Write(data)
  359. h.Sum(nil)
  360. }
  361. }()
  362. }
  363. wg.Wait()
  364. }
  365. }
  366. func benchmarkHasher(n int, t *testing.B) {
  367. tdata := testDataReader(n)
  368. data := make([]byte, n)
  369. tdata.Read(data)
  370. size := 1
  371. hasher := sha3.NewKeccak256
  372. segmentCount := 128
  373. pool := NewTreePool(hasher, segmentCount, size)
  374. bmt := New(pool)
  375. t.ReportAllocs()
  376. t.ResetTimer()
  377. for i := 0; i < t.N; i++ {
  378. bmt.Reset()
  379. bmt.Write(data)
  380. bmt.Sum(nil)
  381. }
  382. }
  383. func benchmarkHasherReuse(poolsize, n int, t *testing.B) {
  384. tdata := testDataReader(n)
  385. data := make([]byte, n)
  386. tdata.Read(data)
  387. hasher := sha3.NewKeccak256
  388. segmentCount := 128
  389. pool := NewTreePool(hasher, segmentCount, poolsize)
  390. cycles := 200
  391. t.ReportAllocs()
  392. t.ResetTimer()
  393. for i := 0; i < t.N; i++ {
  394. wg := sync.WaitGroup{}
  395. wg.Add(cycles)
  396. for j := 0; j < cycles; j++ {
  397. bmt := New(pool)
  398. go func() {
  399. defer wg.Done()
  400. bmt.Reset()
  401. bmt.Write(data)
  402. bmt.Sum(nil)
  403. }()
  404. }
  405. wg.Wait()
  406. }
  407. }
  408. func benchmarkSHA3(n int, t *testing.B) {
  409. data := make([]byte, n)
  410. tdata := testDataReader(n)
  411. tdata.Read(data)
  412. hasher := sha3.NewKeccak256
  413. h := hasher()
  414. t.ReportAllocs()
  415. t.ResetTimer()
  416. for i := 0; i < t.N; i++ {
  417. h.Reset()
  418. h.Write(data)
  419. h.Sum(nil)
  420. }
  421. }
  422. func benchmarkRefHasher(n int, t *testing.B) {
  423. data := make([]byte, n)
  424. tdata := testDataReader(n)
  425. tdata.Read(data)
  426. hasher := sha3.NewKeccak256
  427. rbmt := NewRefHasher(hasher, 128)
  428. t.ReportAllocs()
  429. t.ResetTimer()
  430. for i := 0; i < t.N; i++ {
  431. rbmt.Hash(data)
  432. }
  433. }