123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- package lnwire
- import (
- "bytes"
- "fmt"
- "testing"
- "github.com/lightningnetwork/lnd/tlv"
- "github.com/stretchr/testify/require"
- )
- // testCase is a test case for the UpdateAddHTLC message.
- type testCase struct {
- // Msg is the message to be encoded and decoded.
- Msg UpdateAddHTLC
- // ExpectEncodeError is a flag that indicates whether we expect the
- // encoding of the message to fail.
- ExpectEncodeError bool
- }
- // generateTestCases generates a set of UpdateAddHTLC message test cases.
- func generateTestCases(t *testing.T) []testCase {
- // Firstly, we'll set basic values for the message fields.
- //
- // Generate random channel ID.
- chanIDBytes, err := generateRandomBytes(32)
- require.NoError(t, err)
- var chanID ChannelID
- copy(chanID[:], chanIDBytes)
- // Generate random payment hash.
- paymentHashBytes, err := generateRandomBytes(32)
- require.NoError(t, err)
- var paymentHash [32]byte
- copy(paymentHash[:], paymentHashBytes)
- // Generate random onion blob.
- onionBlobBytes, err := generateRandomBytes(OnionPacketSize)
- require.NoError(t, err)
- var onionBlob [OnionPacketSize]byte
- copy(onionBlob[:], onionBlobBytes)
- // Define the blinding point.
- blinding, err := pubkeyFromHex(
- "0228f2af0abe322403480fb3ee172f7f1601e67d1da6cad40b54c4468d4" +
- "8236c39",
- )
- require.NoError(t, err)
- blindingPoint := tlv.SomeRecordT(
- tlv.NewPrimitiveRecord[BlindingPointTlvType](blinding),
- )
- // Define custom records.
- recordKey1 := uint64(MinCustomRecordsTlvType + 1)
- recordValue1, err := generateRandomBytes(10)
- require.NoError(t, err)
- recordKey2 := uint64(MinCustomRecordsTlvType + 2)
- recordValue2, err := generateRandomBytes(10)
- require.NoError(t, err)
- customRecords := CustomRecords{
- recordKey1: recordValue1,
- recordKey2: recordValue2,
- }
- // Construct an instance of extra data that contains records with TLV
- // types below the minimum custom records threshold and that lack
- // corresponding fields in the message struct. Content should persist in
- // the extra data field after encoding and decoding.
- var (
- recordBytes45 = []byte("recordBytes45")
- tlvRecord45 = tlv.NewPrimitiveRecord[tlv.TlvType45](
- recordBytes45,
- )
- recordBytes55 = []byte("recordBytes55")
- tlvRecord55 = tlv.NewPrimitiveRecord[tlv.TlvType55](
- recordBytes55,
- )
- )
- var extraData ExtraOpaqueData
- err = extraData.PackRecords(
- []tlv.RecordProducer{&tlvRecord45, &tlvRecord55}...,
- )
- require.NoError(t, err)
- invalidCustomRecords := CustomRecords{
- MinCustomRecordsTlvType - 1: recordValue1,
- }
- return []testCase{
- {
- Msg: UpdateAddHTLC{
- ChanID: chanID,
- ID: 42,
- Amount: MilliSatoshi(1000),
- PaymentHash: paymentHash,
- Expiry: 43,
- OnionBlob: onionBlob,
- BlindingPoint: blindingPoint,
- CustomRecords: customRecords,
- ExtraData: extraData,
- },
- },
- // Add a test case where the blinding point field is not
- // populated.
- {
- Msg: UpdateAddHTLC{
- ChanID: chanID,
- ID: 42,
- Amount: MilliSatoshi(1000),
- PaymentHash: paymentHash,
- Expiry: 43,
- OnionBlob: onionBlob,
- CustomRecords: customRecords,
- },
- },
- // Add a test case where the custom records field is not
- // populated.
- {
- Msg: UpdateAddHTLC{
- ChanID: chanID,
- ID: 42,
- Amount: MilliSatoshi(1000),
- PaymentHash: paymentHash,
- Expiry: 43,
- OnionBlob: onionBlob,
- BlindingPoint: blindingPoint,
- },
- },
- // Add a case where the custom records are invalid.
- {
- Msg: UpdateAddHTLC{
- ChanID: chanID,
- ID: 42,
- Amount: MilliSatoshi(1000),
- PaymentHash: paymentHash,
- Expiry: 43,
- OnionBlob: onionBlob,
- BlindingPoint: blindingPoint,
- CustomRecords: invalidCustomRecords,
- },
- ExpectEncodeError: true,
- },
- }
- }
- // TestUpdateAddHtlcEncodeDecode tests UpdateAddHTLC message encoding and
- // decoding for all supported field values.
- func TestUpdateAddHtlcEncodeDecode(t *testing.T) {
- t.Parallel()
- // Generate test cases.
- testCases := generateTestCases(t)
- // Execute test cases.
- for tcIdx, tc := range testCases {
- t.Run(fmt.Sprintf("testcase-%d", tcIdx), func(t *testing.T) {
- // Encode test case message.
- var buf bytes.Buffer
- err := tc.Msg.Encode(&buf, 0)
- // Check if we expect an encoding error.
- if tc.ExpectEncodeError {
- require.Error(t, err)
- return
- }
- require.NoError(t, err)
- // Decode the encoded message bytes message.
- var actualMsg UpdateAddHTLC
- decodeReader := bytes.NewReader(buf.Bytes())
- err = actualMsg.Decode(decodeReader, 0)
- require.NoError(t, err)
- // Compare the two messages to ensure equality.
- require.Equal(t, tc.Msg, actualMsg)
- })
- }
- }
|