shard_tracker_test.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package amp_test
  2. import (
  3. "crypto/rand"
  4. "testing"
  5. "github.com/lightningnetwork/lnd/amp"
  6. "github.com/lightningnetwork/lnd/lnwire"
  7. "github.com/lightningnetwork/lnd/routing/shards"
  8. "github.com/stretchr/testify/require"
  9. )
  10. // TestAMPShardTracker tests that we can derive and cancel shards at will using
  11. // the AMP shard tracker.
  12. func TestAMPShardTracker(t *testing.T) {
  13. var root, setID, payAddr [32]byte
  14. _, err := rand.Read(root[:])
  15. require.NoError(t, err)
  16. _, err = rand.Read(setID[:])
  17. require.NoError(t, err)
  18. _, err = rand.Read(payAddr[:])
  19. require.NoError(t, err)
  20. var totalAmt lnwire.MilliSatoshi = 1000
  21. // Create an AMP shard tracker using the random data we just generated.
  22. tracker := amp.NewShardTracker(root, setID, payAddr, totalAmt)
  23. // Trying to retrieve a hash for id 0 should result in an error.
  24. _, err = tracker.GetHash(0)
  25. require.Error(t, err)
  26. // We start by creating 20 shards.
  27. const numShards = 20
  28. var shards []shards.PaymentShard
  29. for i := uint64(0); i < numShards; i++ {
  30. s, err := tracker.NewShard(i, i == numShards-1)
  31. require.NoError(t, err)
  32. // Check that the shards have their payloads set as expected.
  33. require.Equal(t, setID, s.AMP().SetID())
  34. require.Equal(t, totalAmt, s.MPP().TotalMsat())
  35. require.Equal(t, payAddr, s.MPP().PaymentAddr())
  36. shards = append(shards, s)
  37. }
  38. // Make sure we can retrieve the hash for all of them.
  39. for i := uint64(0); i < numShards; i++ {
  40. hash, err := tracker.GetHash(i)
  41. require.NoError(t, err)
  42. require.Equal(t, shards[i].Hash(), hash)
  43. }
  44. // Now cancel half of the shards.
  45. j := 0
  46. for i := uint64(0); i < numShards; i++ {
  47. if i%2 == 0 {
  48. err := tracker.CancelShard(i)
  49. require.NoError(t, err)
  50. continue
  51. }
  52. // Keep shard.
  53. shards[j] = shards[i]
  54. j++
  55. }
  56. shards = shards[:j]
  57. // Get a new last shard.
  58. s, err := tracker.NewShard(numShards, true)
  59. require.NoError(t, err)
  60. shards = append(shards, s)
  61. // Finally make sure these shards together can be used to reconstruct
  62. // the children.
  63. childDescs := make([]amp.ChildDesc, len(shards))
  64. for i, s := range shards {
  65. childDescs[i] = amp.ChildDesc{
  66. Share: s.AMP().RootShare(),
  67. Index: s.AMP().ChildIndex(),
  68. }
  69. }
  70. // Using the child descriptors, reconstruct the children.
  71. children := amp.ReconstructChildren(childDescs...)
  72. // Validate that the derived child preimages match the hash of each shard.
  73. for i, child := range children {
  74. require.Equal(t, shards[i].Hash(), child.Hash)
  75. }
  76. }