update_add_htlc_test.go 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. package lnwire
  2. import (
  3. "bytes"
  4. "fmt"
  5. "testing"
  6. "github.com/lightningnetwork/lnd/tlv"
  7. "github.com/stretchr/testify/require"
  8. )
  9. // testCase is a test case for the UpdateAddHTLC message.
  10. type testCase struct {
  11. // Msg is the message to be encoded and decoded.
  12. Msg UpdateAddHTLC
  13. // ExpectEncodeError is a flag that indicates whether we expect the
  14. // encoding of the message to fail.
  15. ExpectEncodeError bool
  16. }
  17. // generateTestCases generates a set of UpdateAddHTLC message test cases.
  18. func generateTestCases(t *testing.T) []testCase {
  19. // Firstly, we'll set basic values for the message fields.
  20. //
  21. // Generate random channel ID.
  22. chanIDBytes, err := generateRandomBytes(32)
  23. require.NoError(t, err)
  24. var chanID ChannelID
  25. copy(chanID[:], chanIDBytes)
  26. // Generate random payment hash.
  27. paymentHashBytes, err := generateRandomBytes(32)
  28. require.NoError(t, err)
  29. var paymentHash [32]byte
  30. copy(paymentHash[:], paymentHashBytes)
  31. // Generate random onion blob.
  32. onionBlobBytes, err := generateRandomBytes(OnionPacketSize)
  33. require.NoError(t, err)
  34. var onionBlob [OnionPacketSize]byte
  35. copy(onionBlob[:], onionBlobBytes)
  36. // Define the blinding point.
  37. blinding, err := pubkeyFromHex(
  38. "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
  39. "8236c39",
  40. )
  41. require.NoError(t, err)
  42. blindingPoint := tlv.SomeRecordT(
  43. tlv.NewPrimitiveRecord[BlindingPointTlvType](blinding),
  44. )
  45. // Define custom records.
  46. recordKey1 := uint64(MinCustomRecordsTlvType + 1)
  47. recordValue1, err := generateRandomBytes(10)
  48. require.NoError(t, err)
  49. recordKey2 := uint64(MinCustomRecordsTlvType + 2)
  50. recordValue2, err := generateRandomBytes(10)
  51. require.NoError(t, err)
  52. customRecords := CustomRecords{
  53. recordKey1: recordValue1,
  54. recordKey2: recordValue2,
  55. }
  56. // Construct an instance of extra data that contains records with TLV
  57. // types below the minimum custom records threshold and that lack
  58. // corresponding fields in the message struct. Content should persist in
  59. // the extra data field after encoding and decoding.
  60. var (
  61. recordBytes45 = []byte("recordBytes45")
  62. tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
  63. recordBytes45,
  64. )
  65. recordBytes55 = []byte("recordBytes55")
  66. tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
  67. recordBytes55,
  68. )
  69. )
  70. var extraData ExtraOpaqueData
  71. err = extraData.PackRecords(
  72. []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
  73. )
  74. require.NoError(t, err)
  75. invalidCustomRecords := CustomRecords{
  76. MinCustomRecordsTlvType - 1: recordValue1,
  77. }
  78. return []testCase{
  79. {
  80. Msg: UpdateAddHTLC{
  81. ChanID: chanID,
  82. ID: 42,
  83. Amount: MilliSatoshi(1000),
  84. PaymentHash: paymentHash,
  85. Expiry: 43,
  86. OnionBlob: onionBlob,
  87. BlindingPoint: blindingPoint,
  88. CustomRecords: customRecords,
  89. ExtraData: extraData,
  90. },
  91. },
  92. // Add a test case where the blinding point field is not
  93. // populated.
  94. {
  95. Msg: UpdateAddHTLC{
  96. ChanID: chanID,
  97. ID: 42,
  98. Amount: MilliSatoshi(1000),
  99. PaymentHash: paymentHash,
  100. Expiry: 43,
  101. OnionBlob: onionBlob,
  102. CustomRecords: customRecords,
  103. },
  104. },
  105. // Add a test case where the custom records field is not
  106. // populated.
  107. {
  108. Msg: UpdateAddHTLC{
  109. ChanID: chanID,
  110. ID: 42,
  111. Amount: MilliSatoshi(1000),
  112. PaymentHash: paymentHash,
  113. Expiry: 43,
  114. OnionBlob: onionBlob,
  115. BlindingPoint: blindingPoint,
  116. },
  117. },
  118. // Add a case where the custom records are invalid.
  119. {
  120. Msg: UpdateAddHTLC{
  121. ChanID: chanID,
  122. ID: 42,
  123. Amount: MilliSatoshi(1000),
  124. PaymentHash: paymentHash,
  125. Expiry: 43,
  126. OnionBlob: onionBlob,
  127. BlindingPoint: blindingPoint,
  128. CustomRecords: invalidCustomRecords,
  129. },
  130. ExpectEncodeError: true,
  131. },
  132. }
  133. }
  134. // TestUpdateAddHtlcEncodeDecode tests UpdateAddHTLC message encoding and
  135. // decoding for all supported field values.
  136. func TestUpdateAddHtlcEncodeDecode(t *testing.T) {
  137. t.Parallel()
  138. // Generate test cases.
  139. testCases := generateTestCases(t)
  140. // Execute test cases.
  141. for tcIdx, tc := range testCases {
  142. t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
  143. // Encode test case message.
  144. var buf bytes.Buffer
  145. err := tc.Msg.Encode(&buf, 0)
  146. // Check if we expect an encoding error.
  147. if tc.ExpectEncodeError {
  148. require.Error(t, err)
  149. return
  150. }
  151. require.NoError(t, err)
  152. // Decode the encoded message bytes message.
  153. var actualMsg UpdateAddHTLC
  154. decodeReader := bytes.NewReader(buf.Bytes())
  155. err = actualMsg.Decode(decodeReader, 0)
  156. require.NoError(t, err)
  157. // Compare the two messages to ensure equality.
  158. require.Equal(t, tc.Msg, actualMsg)
  159. })
  160. }
  161. }