channel_test.go

package channeldb

import (
	"bytes"
	"math/rand"
	"net"
	"reflect"
	"runtime"
	"sync/atomic"
	"testing"

	"github.com/btcsuite/btcd/btcec/v2"
	"github.com/btcsuite/btcd/btcutil"
	"github.com/btcsuite/btcd/chaincfg/chainhash"
	"github.com/btcsuite/btcd/wire"
	_ "github.com/btcsuite/btcwallet/walletdb/bdb"
	"github.com/davecgh/go-spew/spew"
	"github.com/lightningnetwork/lnd/channeldb/models"
	"github.com/lightningnetwork/lnd/clock"
	"github.com/lightningnetwork/lnd/keychain"
	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/lnmock"
	"github.com/lightningnetwork/lnd/lntest/channels"
	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/shachain"
	"github.com/lightningnetwork/lnd/tlv"
	"github.com/stretchr/testify/require"
)

var (
	key = [chainhash.HashSize]byte{
		0x81, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
		0x68, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
		0xd, 0xe7, 0x93, 0xe4, 0xb7, 0x25, 0xb8, 0x4d,
		0x1e, 0xb, 0x4c, 0xf9, 0x9e, 0xc5, 0x8c, 0xe9,
	}
	rev = [chainhash.HashSize]byte{
		0x51, 0xb6, 0x37, 0xd8, 0xfc, 0xd2, 0xc6, 0xda,
		0x48, 0x59, 0xe6, 0x96, 0x31, 0x13, 0xa1, 0x17,
		0x2d, 0xe7, 0x93, 0xe4,
	}
	privKey, pubKey = btcec.PrivKeyFromBytes(key[:])

	wireSig, _ = lnwire.NewSigFromSignature(testSig)

	testClock = clock.NewTestClock(testNow)

	// defaultPendingHeight is the default height at which we set
	// channels to pending.
	defaultPendingHeight = 100

	// defaultAddr is the default address that we mark test channels
	// pending with.
	defaultAddr = &net.TCPAddr{
		IP:   net.ParseIP("127.0.0.1"),
		Port: 18555,
	}

	// keyLocIndex is the KeyLocator Index we use for
	// TestKeyLocatorEncoding.
	keyLocIndex = uint32(2049)

	// dummyLocalOutputIndex specifies a default value for our output index
	// in this test.
	dummyLocalOutputIndex = uint32(0)

	// dummyRemoteOutIndex specifies a default value for their output index
	// in this test.
	dummyRemoteOutIndex = uint32(1)

	// uniqueOutputIndex is used to create a unique funding outpoint.
	//
	// NOTE: must be incremented when used.
	uniqueOutputIndex = atomic.Uint32{}
)
// testChannelParams is a struct which details the specifics of how a channel
// should be created.
type testChannelParams struct {
	// channel is the channel that will be written to disk.
	channel *OpenChannel

	// addr is the address that the channel will be synced pending with.
	addr *net.TCPAddr

	// pendingHeight is the height that the channel should be recorded as
	// pending.
	pendingHeight uint32

	// openChannel is set to true if the channel should be fully marked as
	// open. If this is false, the channel will be left in a pending state.
	openChannel bool
}

// testChannelOption is a functional option which can be used to alter the
// default channel that is created for testing.
type testChannelOption func(params *testChannelParams)

// pendingHeightOption is an option which can be used to set the height the
// channel is marked as pending at.
func pendingHeightOption(height uint32) testChannelOption {
	return func(params *testChannelParams) {
		params.pendingHeight = height
	}
}

// openChannelOption is an option which can be used to create a test channel
// that is open.
func openChannelOption() testChannelOption {
	return func(params *testChannelParams) {
		params.openChannel = true
	}
}

// localHtlcsOption is an option which allows setting of htlcs on the local
// commitment.
func localHtlcsOption(htlcs []HTLC) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.LocalCommitment.Htlcs = htlcs
	}
}

// remoteHtlcsOption is an option which allows setting of htlcs on the remote
// commitment.
func remoteHtlcsOption(htlcs []HTLC) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.RemoteCommitment.Htlcs = htlcs
	}
}
// loadFwdPkgs is a helper method that reads all forwarding packages for a
// particular packager.
func loadFwdPkgs(t *testing.T, db kvdb.Backend,
	packager FwdPackager) []*FwdPkg {

	var (
		fwdPkgs []*FwdPkg
		err     error
	)
	err = kvdb.View(db, func(tx kvdb.RTx) error {
		fwdPkgs, err = packager.LoadFwdPkgs(tx)
		return err
	}, func() {})
	require.NoError(t, err, "unable to load fwd pkgs")

	return fwdPkgs
}

// localShutdownOption is an option which sets the local upfront shutdown
// script for the channel.
func localShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.LocalShutdownScript = addr
	}
}

// remoteShutdownOption is an option which sets the remote upfront shutdown
// script for the channel.
func remoteShutdownOption(addr lnwire.DeliveryAddress) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.RemoteShutdownScript = addr
	}
}

// fundingPointOption is an option which sets the funding outpoint of the
// channel.
func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.FundingOutpoint = chanPoint
	}
}

// channelIDOption is an option which sets the short channel ID of the channel.
var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption {
	return func(params *testChannelParams) {
		params.channel.ShortChannelID = chanID
	}
}
// createTestChannel writes a test channel to the database. It takes a set of
// functional options which can be used to overwrite the default of creating
// a pending channel that was broadcast at height 100.
func createTestChannel(t *testing.T, cdb *ChannelStateDB,
	opts ...testChannelOption) *OpenChannel {

	// Create a default set of parameters.
	params := &testChannelParams{
		channel:       createTestChannelState(t, cdb),
		addr:          defaultAddr,
		openChannel:   false,
		pendingHeight: uint32(defaultPendingHeight),
	}

	// Apply all functional options to the test channel params.
	for _, o := range opts {
		o(params)
	}

	// Mark the channel as pending.
	err := params.channel.SyncPending(params.addr, params.pendingHeight)
	if err != nil {
		t.Fatalf("unable to save and serialize channel "+
			"state: %v", err)
	}

	// If the parameters do not specify that we should open the channel
	// fully, we return the pending channel.
	if !params.openChannel {
		return params.channel
	}

	// Mark the channel as open with the short channel id provided.
	err = params.channel.MarkAsOpen(params.channel.ShortChannelID)
	require.NoError(t, err, "unable to mark channel open")

	return params.channel
}
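
// For reference, the helpers above are meant to be composed at the call site.
// A sketch of typical usage (the exact combination varies per test below):
//
//	channel := createTestChannel(
//		t, cdb,
//		pendingHeightOption(105),
//		openChannelOption(),
//	)
//
// The options are applied in order on top of the defaults, so a later option
// can overwrite the effect of an earlier one.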
func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
	// Simulate channel updates by advancing the revocation store.
	producer, err := shachain.NewRevocationProducerFromBytes(key[:])
	require.NoError(t, err, "could not get producer")
	store := shachain.NewRevocationStore()
	for i := 0; i < 1; i++ {
		preImage, err := producer.AtIndex(uint64(i))
		if err != nil {
			t.Fatalf("could not get "+
				"preimage: %v", err)
		}
		if err := store.AddNextEntry(preImage); err != nil {
			t.Fatalf("could not add entry: %v", err)
		}
	}

	localCfg := ChannelConfig{
		ChannelConstraints: ChannelConstraints{
			DustLimit:        btcutil.Amount(rand.Int63()),
			MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()),
			ChanReserve:      btcutil.Amount(rand.Int63()),
			MinHTLC:          lnwire.MilliSatoshi(rand.Int63()),
			MaxAcceptedHtlcs: uint16(rand.Int31()),
			CsvDelay:         uint16(rand.Int31()),
		},
		MultiSigKey: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		RevocationBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		PaymentBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		DelayBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
		HtlcBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
		},
	}
	remoteCfg := ChannelConfig{
		ChannelConstraints: ChannelConstraints{
			DustLimit:        btcutil.Amount(rand.Int63()),
			MaxPendingAmount: lnwire.MilliSatoshi(rand.Int63()),
			ChanReserve:      btcutil.Amount(rand.Int63()),
			MinHTLC:          lnwire.MilliSatoshi(rand.Int63()),
			MaxAcceptedHtlcs: uint16(rand.Int31()),
			CsvDelay:         uint16(rand.Int31()),
		},
		MultiSigKey: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyMultiSig,
				Index:  9,
			},
		},
		RevocationBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyRevocationBase,
				Index:  8,
			},
		},
		PaymentBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyPaymentBase,
				Index:  7,
			},
		},
		DelayBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyDelayBase,
				Index:  6,
			},
		},
		HtlcBasePoint: keychain.KeyDescriptor{
			PubKey: privKey.PubKey(),
			KeyLocator: keychain.KeyLocator{
				Family: keychain.KeyFamilyHtlcBase,
				Index:  5,
			},
		},
	}

	chanID := lnwire.NewShortChanIDFromInt(uint64(rand.Int63()))

	// Increment the uniqueOutputIndex so we always get a unique value for
	// the funding outpoint.
	uniqueOutputIndex.Add(1)
	op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()}

	return &OpenChannel{
		ChanType:          SingleFunderBit | FrozenBit,
		ChainHash:         key,
		FundingOutpoint:   op,
		ShortChannelID:    chanID,
		IsInitiator:       true,
		IsPending:         true,
		IdentityPub:       pubKey,
		Capacity:          btcutil.Amount(10000),
		LocalChanCfg:      localCfg,
		RemoteChanCfg:     remoteCfg,
		TotalMSatSent:     8,
		TotalMSatReceived: 2,
		LocalCommitment: ChannelCommitment{
			CommitHeight:  0,
			LocalBalance:  lnwire.MilliSatoshi(9000),
			RemoteBalance: lnwire.MilliSatoshi(3000),
			CommitFee:     btcutil.Amount(rand.Int63()),
			FeePerKw:      btcutil.Amount(5000),
			CommitTx:      channels.TestFundingTx,
			CommitSig:     bytes.Repeat([]byte{1}, 71),
		},
		RemoteCommitment: ChannelCommitment{
			CommitHeight:  0,
			LocalBalance:  lnwire.MilliSatoshi(3000),
			RemoteBalance: lnwire.MilliSatoshi(9000),
			CommitFee:     btcutil.Amount(rand.Int63()),
			FeePerKw:      btcutil.Amount(5000),
			CommitTx:      channels.TestFundingTx,
			CommitSig:     bytes.Repeat([]byte{1}, 71),
		},
		NumConfsRequired:        4,
		RemoteCurrentRevocation: privKey.PubKey(),
		RemoteNextRevocation:    privKey.PubKey(),
		RevocationProducer:      producer,
		RevocationStore:         store,
		Db:                      cdb,
		Packager:                NewChannelPackager(chanID),
		FundingTxn:              channels.TestFundingTx,
		ThawHeight:              uint32(defaultPendingHeight),
		InitialLocalBalance:     lnwire.MilliSatoshi(9000),
		InitialRemoteBalance:    lnwire.MilliSatoshi(3000),
	}
}
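
// Note that createTestChannelState only builds the OpenChannel struct in
// memory; it is createTestChannel above that persists it, first via
// SyncPending and then, if requested, via MarkAsOpen.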
func TestOpenChannelPutGetDelete(t *testing.T) {
	t.Parallel()

	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	// Create the test channel state, with additional htlcs on the local
	// and remote commitment.
	localHtlcs := []HTLC{
		{
			Signature:     testSig.Serialize(),
			Incoming:      true,
			Amt:           10,
			RHash:         key,
			RefundTimeout: 1,
			OnionBlob:     lnmock.MockOnion(),
		},
	}
	remoteHtlcs := []HTLC{
		{
			Signature:     testSig.Serialize(),
			Incoming:      false,
			Amt:           10,
			RHash:         key,
			RefundTimeout: 1,
			OnionBlob:     lnmock.MockOnion(),
		},
	}
	state := createTestChannel(
		t, cdb,
		remoteHtlcsOption(remoteHtlcs),
		localHtlcsOption(localHtlcs),
	)

	openChannels, err := cdb.FetchOpenChannels(state.IdentityPub)
	require.NoError(t, err, "unable to fetch open channel")
	newState := openChannels[0]

	// The decoded channel state should be identical to what we stored
	// above.
	if !reflect.DeepEqual(state, newState) {
		t.Fatalf("channel state doesn't match: %v vs %v",
			spew.Sdump(state), spew.Sdump(newState))
	}

	// We'll also test that the channel is properly able to hot swap the
	// next revocation for the state machine. This tests the initial
	// post-funding revocation exchange.
	nextRevKey, err := btcec.NewPrivateKey()
	require.NoError(t, err, "unable to create new private key")
	if err := state.InsertNextRevocation(nextRevKey.PubKey()); err != nil {
		t.Fatalf("unable to update revocation: %v", err)
	}

	openChannels, err = cdb.FetchOpenChannels(state.IdentityPub)
	require.NoError(t, err, "unable to fetch open channel")
	updatedChan := openChannels[0]

	// Ensure that the revocation was set properly.
	if !nextRevKey.PubKey().IsEqual(updatedChan.RemoteNextRevocation) {
		t.Fatalf("next revocation wasn't updated")
	}

	// Finally, to wrap up the test, delete the state of the channel within
	// the database. This involves "closing" the channel which removes all
	// written state, and creates a small "summary" elsewhere within the
	// database.
	closeSummary := &ChannelCloseSummary{
		ChanPoint:         state.FundingOutpoint,
		RemotePub:         state.IdentityPub,
		SettledBalance:    btcutil.Amount(500),
		TimeLockedBalance: btcutil.Amount(10000),
		IsPending:         false,
		CloseType:         CooperativeClose,
	}
	if err := state.CloseChannel(closeSummary); err != nil {
		t.Fatalf("unable to close channel: %v", err)
	}

	// As the channel is now closed, attempting to fetch all open channels
	// for our fake node ID should return an empty slice.
	openChans, err := cdb.FetchOpenChannels(state.IdentityPub)
	require.NoError(t, err, "unable to fetch open channels")
	if len(openChans) != 0 {
		t.Fatalf("all channels not deleted, found %v", len(openChans))
	}

	// Additionally, attempting to fetch all the open channels globally
	// should yield no results.
	openChans, err = cdb.FetchAllChannels()
	if err != nil {
		t.Fatal("unable to fetch all open chans")
	}
	if len(openChans) != 0 {
		t.Fatalf("all channels not deleted, found %v", len(openChans))
	}
}
// TestOptionalShutdown tests the reading and writing of channels with and
// without optional shutdown script fields.
func TestOptionalShutdown(t *testing.T) {
	local := lnwire.DeliveryAddress([]byte("local shutdown script"))
	remote := lnwire.DeliveryAddress([]byte("remote shutdown script"))
	if _, err := rand.Read(remote); err != nil {
		t.Fatalf("Could not create random script: %v", err)
	}

	tests := []struct {
		name           string
		localShutdown  lnwire.DeliveryAddress
		remoteShutdown lnwire.DeliveryAddress
	}{
		{
			name:           "no shutdown scripts",
			localShutdown:  nil,
			remoteShutdown: nil,
		},
		{
			name:           "local shutdown script",
			localShutdown:  local,
			remoteShutdown: nil,
		},
		{
			name:           "remote shutdown script",
			localShutdown:  nil,
			remoteShutdown: remote,
		},
		{
			name:           "both scripts set",
			localShutdown:  local,
			remoteShutdown: remote,
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			fullDB, err := MakeTestDB(t)
			if err != nil {
				t.Fatalf("unable to make test database: %v",
					err)
			}

			cdb := fullDB.ChannelStateDB()

			// Create a channel with upfront scripts set as
			// specified in the test.
			state := createTestChannel(
				t, cdb,
				localShutdownOption(test.localShutdown),
				remoteShutdownOption(test.remoteShutdown),
			)

			openChannels, err := cdb.FetchOpenChannels(
				state.IdentityPub,
			)
			if err != nil {
				t.Fatalf("unable to fetch open"+
					" channel: %v", err)
			}

			if len(openChannels) != 1 {
				t.Fatalf("Expected one channel open,"+
					" got: %v", len(openChannels))
			}

			if !bytes.Equal(openChannels[0].LocalShutdownScript,
				test.localShutdown) {

				t.Fatalf("Expected local: %x, got: %x",
					test.localShutdown,
					openChannels[0].LocalShutdownScript)
			}

			if !bytes.Equal(openChannels[0].RemoteShutdownScript,
				test.remoteShutdown) {

				t.Fatalf("Expected remote: %x, got: %x",
					test.remoteShutdown,
					openChannels[0].RemoteShutdownScript)
			}
		})
	}
}
func assertCommitmentEqual(t *testing.T, a, b *ChannelCommitment) {
	if !reflect.DeepEqual(a, b) {
		_, _, line, _ := runtime.Caller(1)
		t.Fatalf("line %v: commitments don't match: %v vs %v",
			line, spew.Sdump(a), spew.Sdump(b))
	}
}

// assertRevocationLogEntryEqual asserts that, for all the fields of a given
// revocation log entry, their values match those on a given ChannelCommitment.
func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment,
	r *RevocationLog) {

	// Check the common fields.
	require.EqualValues(
		t, r.CommitTxHash, c.CommitTx.TxHash(), "CommitTx mismatch",
	)

	// Now check the common fields from the HTLCs.
	require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch")
	for i, rHtlc := range r.HTLCEntries {
		cHtlc := c.Htlcs[i]
		require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch")
		require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(),
			"Amt mismatch")
		require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout,
			"RefundTimeout mismatch")
		require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex,
			"OutputIndex mismatch")
		require.Equal(t, rHtlc.Incoming, cHtlc.Incoming,
			"Incoming mismatch")
	}
}
func TestChannelStateTransition(t *testing.T) {
	t.Parallel()

	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	// First create a minimal channel, then perform a full sync in order to
	// persist the data.
	channel := createTestChannel(t, cdb)

	// Add some HTLCs which were added during this new state transition.
	// Half of the HTLCs are incoming, while the other half are outgoing.
	var (
		htlcs   []HTLC
		htlcAmt lnwire.MilliSatoshi
	)
	for i := uint32(0); i < 10; i++ {
		var incoming bool
		if i > 5 {
			incoming = true
		}
		htlc := HTLC{
			Signature:     testSig.Serialize(),
			Incoming:      incoming,
			Amt:           10,
			RHash:         key,
			RefundTimeout: i,
			OutputIndex:   int32(i * 3),
			LogIndex:      uint64(i * 2),
			HtlcIndex:     uint64(i),
		}
		copy(
			htlc.OnionBlob[:],
			bytes.Repeat([]byte{2}, lnwire.OnionPacketSize),
		)
		htlcs = append(htlcs, htlc)
		htlcAmt += htlc.Amt
	}

	// Create a new channel delta which includes the above HTLCs, some
	// balance updates, and an increment of the current commitment height.
	// Additionally, modify the signature and commitment transaction.
	newSequence := uint32(129498)
	newSig := bytes.Repeat([]byte{3}, 71)
	newTx := channel.LocalCommitment.CommitTx.Copy()
	newTx.TxIn[0].Sequence = newSequence
	commitment := ChannelCommitment{
		CommitHeight:    1,
		LocalLogIndex:   2,
		LocalHtlcIndex:  1,
		RemoteLogIndex:  2,
		RemoteHtlcIndex: 1,
		LocalBalance:    lnwire.MilliSatoshi(1e8),
		RemoteBalance:   lnwire.MilliSatoshi(1e8),
		CommitFee:       55,
		FeePerKw:        99,
		CommitTx:        newTx,
		CommitSig:       newSig,
		Htlcs:           htlcs,
	}

	// First update the local node's broadcastable state and also add a
	// CommitDiff for the remote node in order to simulate a proper state
	// transition.
	unsignedAckedUpdates := []LogUpdate{
		{
			LogIndex: 2,
			UpdateMsg: &lnwire.UpdateAddHTLC{
				ChanID: lnwire.ChannelID{1, 2, 3},
			},
		},
	}

	_, err = channel.UpdateCommitment(&commitment, unsignedAckedUpdates)
	require.NoError(t, err, "unable to update commitment")

	// Assert that the update is correctly written to the database.
	dbUnsignedAckedUpdates, err := channel.UnsignedAckedUpdates()
	require.NoError(t, err, "unable to fetch dangling remote updates")
	if len(dbUnsignedAckedUpdates) != 1 {
		t.Fatalf("unexpected number of dangling remote updates")
	}
	if !reflect.DeepEqual(
		dbUnsignedAckedUpdates[0], unsignedAckedUpdates[0],
	) {
		t.Fatalf("unexpected update: expected %v, got %v",
			spew.Sdump(unsignedAckedUpdates[0]),
			spew.Sdump(dbUnsignedAckedUpdates))
	}

	// The balances, new update, the HTLCs and the changes to the fake
	// commitment transaction along with the modified signature should all
	// have been updated.
	updatedChannel, err := cdb.FetchOpenChannels(channel.IdentityPub)
	require.NoError(t, err, "unable to fetch updated channel")
	assertCommitmentEqual(t, &commitment, &updatedChannel[0].LocalCommitment)
	numDiskUpdates, err := updatedChannel[0].CommitmentHeight()
	require.NoError(t, err, "unable to read commitment height from disk")
	if numDiskUpdates != uint64(commitment.CommitHeight) {
		t.Fatalf("num disk updates doesn't match: %v vs %v",
			numDiskUpdates, commitment.CommitHeight)
	}

	// Attempting to query for a commitment diff should return
	// ErrNoPendingCommit as we haven't yet created a new state for them.
	_, err = channel.RemoteCommitChainTip()
	if err != ErrNoPendingCommit {
		t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
	}

	// To simulate us extending a new state to the remote party, we'll also
	// create a new commit diff for them.
	remoteCommit := commitment
	remoteCommit.LocalBalance = lnwire.MilliSatoshi(2e8)
	remoteCommit.RemoteBalance = lnwire.MilliSatoshi(3e8)
	remoteCommit.CommitHeight = 1
	commitDiff := &CommitDiff{
		Commitment: remoteCommit,
		CommitSig: &lnwire.CommitSig{
			ChanID:    lnwire.ChannelID(key),
			CommitSig: wireSig,
			HtlcSigs: []lnwire.Sig{
				wireSig,
				wireSig,
			},
		},
		LogUpdates: []LogUpdate{
			{
				LogIndex: 1,
				UpdateMsg: &lnwire.UpdateAddHTLC{
					ID:     1,
					Amount: lnwire.NewMSatFromSatoshis(100),
					Expiry: 25,
				},
			},
			{
				LogIndex: 2,
				UpdateMsg: &lnwire.UpdateAddHTLC{
					ID:     2,
					Amount: lnwire.NewMSatFromSatoshis(200),
					Expiry: 50,
				},
			},
		},
		OpenedCircuitKeys: []models.CircuitKey{},
		ClosedCircuitKeys: []models.CircuitKey{},
	}
	copy(commitDiff.LogUpdates[0].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
		bytes.Repeat([]byte{1}, 32))
	copy(commitDiff.LogUpdates[1].UpdateMsg.(*lnwire.UpdateAddHTLC).PaymentHash[:],
		bytes.Repeat([]byte{2}, 32))
	if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
		t.Fatalf("unable to add to commit chain: %v", err)
	}

	// The commitment tip should now match the commitment that we just
	// inserted.
	diskCommitDiff, err := channel.RemoteCommitChainTip()
	require.NoError(t, err, "unable to fetch commit diff")
	if !reflect.DeepEqual(commitDiff, diskCommitDiff) {
		t.Fatalf("commit diffs don't match: %v vs %v",
			spew.Sdump(remoteCommit), spew.Sdump(diskCommitDiff))
	}

	// We'll save the old remote commitment as this will be added to the
	// revocation log shortly.
	oldRemoteCommit := channel.RemoteCommitment

	// Next, write to the log which tracks the necessary revocation state
	// needed to rectify any fishy behavior by the remote party. Modify the
	// current uncollapsed revocation state to simulate a state transition
	// by the remote party.
	channel.RemoteCurrentRevocation = channel.RemoteNextRevocation
	newPriv, err := btcec.NewPrivateKey()
	require.NoError(t, err, "unable to generate key")
	channel.RemoteNextRevocation = newPriv.PubKey()

	fwdPkg := NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
		diskCommitDiff.LogUpdates, nil)

	err = channel.AdvanceCommitChainTail(
		fwdPkg, nil, dummyLocalOutputIndex, dummyRemoteOutIndex,
	)
	require.NoError(t, err, "unable to append to revocation log")

	// At this point, the remote commit chain should be nil, and the posted
	// remote commitment should match the one we added as a diff above.
	if _, err := channel.RemoteCommitChainTip(); err != ErrNoPendingCommit {
		t.Fatalf("expected ErrNoPendingCommit, instead got %v", err)
	}

	// We should be able to fetch the channel delta created above by its
	// update number with all the state properly reconstructed.
	diskPrevCommit, _, err := channel.FindPreviousState(
		oldRemoteCommit.CommitHeight,
	)
	require.NoError(t, err, "unable to fetch past delta")

	// Check the output indexes are saved as expected.
	require.EqualValues(
		t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex,
	)
	require.EqualValues(
		t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex,
	)

	// The two deltas (the original vs the on-disk version) should be
	// identical, and all HTLC data should properly be retained.
	assertRevocationLogEntryEqual(t, &oldRemoteCommit, diskPrevCommit)

	// The state number recovered from the tail of the revocation log
	// should be identical to this current state.
	logTailHeight, err := channel.revocationLogTailCommitHeight()
	require.NoError(t, err, "unable to retrieve log")
	if logTailHeight != oldRemoteCommit.CommitHeight {
		t.Fatal("update number doesn't match")
	}

	oldRemoteCommit = channel.RemoteCommitment

	// Next modify the posted diff commitment slightly, then create a new
	// commitment diff and advance the tail.
	commitDiff.Commitment.CommitHeight = 2
	commitDiff.Commitment.LocalBalance -= htlcAmt
	commitDiff.Commitment.RemoteBalance += htlcAmt
	commitDiff.LogUpdates = []LogUpdate{}
	if err := channel.AppendRemoteCommitChain(commitDiff); err != nil {
		t.Fatalf("unable to add to commit chain: %v", err)
	}

	fwdPkg = NewFwdPkg(channel.ShortChanID(), oldRemoteCommit.CommitHeight,
		nil, nil)

	err = channel.AdvanceCommitChainTail(
		fwdPkg, nil, dummyLocalOutputIndex, dummyRemoteOutIndex,
	)
	require.NoError(t, err, "unable to append to revocation log")

	// Once again, fetch the state and ensure it has been properly updated.
	prevCommit, _, err := channel.FindPreviousState(
		oldRemoteCommit.CommitHeight,
	)
	require.NoError(t, err, "unable to fetch past delta")

	// Check the output indexes are saved as expected.
	require.EqualValues(
		t, dummyLocalOutputIndex, diskPrevCommit.OurOutputIndex,
	)
	require.EqualValues(
		t, dummyRemoteOutIndex, diskPrevCommit.TheirOutputIndex,
	)

	assertRevocationLogEntryEqual(t, &oldRemoteCommit, prevCommit)

	// Once again, the state number recovered from the tail of the
	// revocation log should be identical to this current state.
	logTailHeight, err = channel.revocationLogTailCommitHeight()
	require.NoError(t, err, "unable to retrieve log")
	if logTailHeight != oldRemoteCommit.CommitHeight {
		t.Fatal("update number doesn't match")
	}

	// The revocation state stored on-disk should now also be identical.
	updatedChannel, err = cdb.FetchOpenChannels(channel.IdentityPub)
	require.NoError(t, err, "unable to fetch updated channel")
	if !channel.RemoteCurrentRevocation.IsEqual(updatedChannel[0].RemoteCurrentRevocation) {
		t.Fatalf("revocation state was not synced")
	}
	if !channel.RemoteNextRevocation.IsEqual(updatedChannel[0].RemoteNextRevocation) {
		t.Fatalf("revocation state was not synced")
	}

	// At this point, we should have 2 forwarding packages added.
	fwdPkgs := loadFwdPkgs(t, cdb.backend, channel.Packager)
	require.Len(t, fwdPkgs, 2, "wrong number of forwarding packages")

	// Now attempt to delete the channel from the database.
	closeSummary := &ChannelCloseSummary{
		ChanPoint:         channel.FundingOutpoint,
		RemotePub:         channel.IdentityPub,
		SettledBalance:    btcutil.Amount(500),
		TimeLockedBalance: btcutil.Amount(10000),
		IsPending:         false,
		CloseType:         RemoteForceClose,
	}
	if err := updatedChannel[0].CloseChannel(closeSummary); err != nil {
		t.Fatalf("unable to delete updated channel: %v", err)
	}

	// If we attempt to fetch the target channel again, it shouldn't be
	// found.
	channels, err := cdb.FetchOpenChannels(channel.IdentityPub)
	require.NoError(t, err, "unable to fetch updated channels")
	if len(channels) != 0 {
		t.Fatalf("%v channels found, but none should be",
			len(channels))
	}

	// Attempting to find previous states on the channel should fail as the
	// revocation log has been deleted.
	_, _, err = updatedChannel[0].FindPreviousState(
		oldRemoteCommit.CommitHeight,
	)
	if err == nil {
		t.Fatal("revocation log search should have failed")
	}

	// All forwarding packages of this channel have been deleted too.
	fwdPkgs = loadFwdPkgs(t, cdb.backend, channel.Packager)
	require.Empty(t, fwdPkgs, "no forwarding packages should exist")
}
func TestFetchPendingChannels(t *testing.T) {
	t.Parallel()

	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	// Create a pending channel that was broadcast at height 99.
	const broadcastHeight = 99
	createTestChannel(t, cdb, pendingHeightOption(broadcastHeight))

	pendingChannels, err := cdb.FetchPendingChannels()
	require.NoError(t, err, "unable to list pending channels")

	if len(pendingChannels) != 1 {
		t.Fatalf("incorrect number of pending channels: expecting %v,"+
			"got %v", 1, len(pendingChannels))
	}

	// The broadcast height of the pending channel should have been set
	// properly.
	if pendingChannels[0].FundingBroadcastHeight != broadcastHeight {
		t.Fatalf("broadcast height mismatch: expected %v, got %v",
			pendingChannels[0].FundingBroadcastHeight,
			broadcastHeight)
	}

	chanOpenLoc := lnwire.ShortChannelID{
		BlockHeight: 5,
		TxIndex:     10,
		TxPosition:  15,
	}
	err = pendingChannels[0].MarkAsOpen(chanOpenLoc)
	require.NoError(t, err, "unable to mark channel as open")

	if pendingChannels[0].IsPending {
		t.Fatalf("channel marked open should no longer be pending")
	}

	if pendingChannels[0].ShortChanID() != chanOpenLoc {
		t.Fatalf("channel opening height not updated: expected %v, "+
			"got %v", spew.Sdump(pendingChannels[0].ShortChanID()),
			chanOpenLoc)
	}

	// Next, we'll re-fetch the channel to ensure that the open height was
	// properly set.
	openChans, err := cdb.FetchAllChannels()
	require.NoError(t, err, "unable to fetch channels")
	if openChans[0].ShortChanID() != chanOpenLoc {
		t.Fatalf("channel opening heights don't match: expected %v, "+
			"got %v", spew.Sdump(openChans[0].ShortChanID()),
			chanOpenLoc)
	}
	if openChans[0].FundingBroadcastHeight != broadcastHeight {
		t.Fatalf("broadcast height mismatch: expected %v, got %v",
			openChans[0].FundingBroadcastHeight,
			broadcastHeight)
	}

	pendingChannels, err = cdb.FetchPendingChannels()
	require.NoError(t, err, "unable to list pending channels")

	if len(pendingChannels) != 0 {
		t.Fatalf("incorrect number of pending channels: expecting %v,"+
			"got %v", 0, len(pendingChannels))
	}
}
func TestFetchClosedChannels(t *testing.T) {
	t.Parallel()

	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	// Create an open channel in the database.
	state := createTestChannel(t, cdb, openChannelOption())

	// Next, close the channel by including a close channel summary in the
	// database.
	summary := &ChannelCloseSummary{
		ChanPoint:         state.FundingOutpoint,
		ClosingTXID:       rev,
		RemotePub:         state.IdentityPub,
		Capacity:          state.Capacity,
		SettledBalance:    state.LocalCommitment.LocalBalance.ToSatoshis(),
		TimeLockedBalance: state.RemoteCommitment.LocalBalance.ToSatoshis() + 10000,
		CloseType:         RemoteForceClose,
		IsPending:         true,
		LocalChanConfig:   state.LocalChanCfg,
	}
	if err := state.CloseChannel(summary); err != nil {
		t.Fatalf("unable to close channel: %v", err)
	}

	// Query the database to ensure that the channel has now been properly
	// closed. We should get the same result whether querying for pending
	// channels only, or not.
	pendingClosed, err := cdb.FetchClosedChannels(true)
	require.NoError(t, err, "failed fetching closed channels")
	if len(pendingClosed) != 1 {
		t.Fatalf("incorrect number of pending closed channels: "+
			"expecting %v, got %v", 1, len(pendingClosed))
	}
	if !reflect.DeepEqual(summary, pendingClosed[0]) {
		t.Fatalf("database summaries don't match: expected %v got %v",
			spew.Sdump(summary), spew.Sdump(pendingClosed[0]))
	}
	closed, err := cdb.FetchClosedChannels(false)
	require.NoError(t, err, "failed fetching all closed channels")
	if len(closed) != 1 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 1, len(closed))
	}
	if !reflect.DeepEqual(summary, closed[0]) {
		t.Fatalf("database summaries don't match: expected %v got %v",
			spew.Sdump(summary), spew.Sdump(closed[0]))
	}

	// Mark the channel as fully closed.
	err = cdb.MarkChanFullyClosed(&state.FundingOutpoint)
	require.NoError(t, err, "failed fully closing channel")

	// The channel should no longer be considered pending, but should still
	// be retrieved when fetching all the closed channels.
	closed, err = cdb.FetchClosedChannels(false)
	require.NoError(t, err, "failed fetching closed channels")
	if len(closed) != 1 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 1, len(closed))
	}
	pendingClose, err := cdb.FetchClosedChannels(true)
	require.NoError(t, err, "failed fetching channels pending close")
	if len(pendingClose) != 0 {
		t.Fatalf("incorrect number of closed channels: expecting %v, "+
			"got %v", 0, len(closed))
	}
}
// TestFetchWaitingCloseChannels ensures that the correct channels that are
// waiting to be closed are returned.
func TestFetchWaitingCloseChannels(t *testing.T) {
	t.Parallel()

	const numChannels = 2
	const broadcastHeight = 99

	// We'll start by creating two channels within our test database. One
	// of them will have its funding transaction confirmed on-chain, while
	// the other one will remain unconfirmed.
	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	channels := make([]*OpenChannel, numChannels)
	for i := 0; i < numChannels; i++ {
		// Create a pending channel in the database at the broadcast
		// height.
		channels[i] = createTestChannel(
			t, cdb, pendingHeightOption(broadcastHeight),
		)
	}

	// We'll only confirm the first one.
	channelConf := lnwire.ShortChannelID{
		BlockHeight: broadcastHeight + 1,
		TxIndex:     10,
		TxPosition:  15,
	}
	if err := channels[0].MarkAsOpen(channelConf); err != nil {
		t.Fatalf("unable to mark channel as open: %v", err)
	}

	// Then, we'll mark the channels as if their commitments were
	// broadcast. This would happen in the event of a force close and
	// should make the channels enter a state of waiting close.
	for _, channel := range channels {
		closeTx := wire.NewMsgTx(2)
		closeTx.AddTxIn(
			&wire.TxIn{
				PreviousOutPoint: channel.FundingOutpoint,
			},
		)
		if err := channel.MarkCommitmentBroadcasted(closeTx, true); err != nil {
			t.Fatalf("unable to mark commitment broadcast: %v", err)
		}

		// Now try marking a coop close with a nil tx. This should
		// succeed, but no closing transaction should exist when
		// queried.
		if err = channel.MarkCoopBroadcasted(nil, true); err != nil {
			t.Fatalf("unable to mark nil coop broadcast: %v", err)
		}
		_, err := channel.BroadcastedCooperative()
		if err != ErrNoCloseTx {
			t.Fatalf("expected no closing tx error, got: %v", err)
		}

		// Finally, modify the close tx deterministically and also mark
		// it as coop closed. Later we will test that distinct
		// transactions are returned for both coop and force closes.
		closeTx.TxIn[0].PreviousOutPoint.Index ^= 1
		if err := channel.MarkCoopBroadcasted(closeTx, true); err != nil {
			t.Fatalf("unable to mark coop broadcast: %v", err)
		}
	}

	// Now, we'll fetch all the channels waiting to be closed from the
	// database. We should expect to see both channels above, even though
	// one of them hasn't had its funding transaction confirmed on-chain.
	waitingCloseChannels, err := cdb.FetchWaitingCloseChannels()
	require.NoError(t, err, "unable to fetch all waiting close channels")
	if len(waitingCloseChannels) != numChannels {
		t.Fatalf("expected %d channels waiting to be closed, got %d", 2,
			len(waitingCloseChannels))
	}
	expectedChannels := make(map[wire.OutPoint]struct{})
	for _, channel := range channels {
		expectedChannels[channel.FundingOutpoint] = struct{}{}
	}
	for _, channel := range waitingCloseChannels {
		if _, ok := expectedChannels[channel.FundingOutpoint]; !ok {
			t.Fatalf("expected channel %v to be waiting close",
				channel.FundingOutpoint)
		}

		chanPoint := channel.FundingOutpoint

		// Assert that the force close transaction is retrievable.
		forceCloseTx, err := channel.BroadcastedCommitment()
		if err != nil {
			t.Fatalf("Unable to retrieve commitment: %v", err)
		}

		if forceCloseTx.TxIn[0].PreviousOutPoint != chanPoint {
			t.Fatalf("expected outpoint %v, got %v",
				chanPoint,
				forceCloseTx.TxIn[0].PreviousOutPoint)
		}

		// Assert that the coop close transaction is retrievable.
		coopCloseTx, err := channel.BroadcastedCooperative()
		if err != nil {
			t.Fatalf("unable to retrieve coop close: %v", err)
		}

		chanPoint.Index ^= 1
		if coopCloseTx.TxIn[0].PreviousOutPoint != chanPoint {
			t.Fatalf("expected outpoint %v, got %v",
				chanPoint,
				coopCloseTx.TxIn[0].PreviousOutPoint)
		}
	}
}
// TestShutdownInfo tests that a channel's shutdown info can correctly be
// persisted and retrieved.
func TestShutdownInfo(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name      string
		localInit bool
	}{
		{
			name:      "local node initiated",
			localInit: true,
		},
		{
			name:      "remote node initiated",
			localInit: false,
		},
	}

	for _, test := range tests {
		test := test
		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			testShutdownInfo(t, test.localInit)
		})
	}
}

func testShutdownInfo(t *testing.T, locallyInitiated bool) {
	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	// First, create a test channel.
	channel := createTestChannel(t, cdb)

	// We haven't persisted any shutdown info for this channel yet.
	_, err = channel.ShutdownInfo()
	require.ErrorIs(t, err, ErrNoShutdownInfo)

	// Construct a new delivery script and create a new ShutdownInfo object
	// from it.
	script := []byte{1, 3, 4, 5}
	shutdownInfo := NewShutdownInfo(script, locallyInitiated)

	// Persist the shutdown info.
	require.NoError(t, channel.MarkShutdownSent(shutdownInfo))

	// We should now be able to retrieve the shutdown info.
	info, err := channel.ShutdownInfo()
	require.NoError(t, err)
	require.True(t, info.IsSome())

	// Assert that the decoded values of the shutdown info are correct.
	info.WhenSome(func(info ShutdownInfo) {
		require.EqualValues(t, script, info.DeliveryScript.Val)
		require.Equal(t, locallyInitiated, info.LocalInitiator.Val)
	})
}
// TestRefresh asserts that Refresh updates the in-memory state of another
// OpenChannel to reflect a preceding call to MarkOpen on a different
// OpenChannel.
func TestRefresh(t *testing.T) {
	t.Parallel()

	fullDB, err := MakeTestDB(t)
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	// First create a test channel.
	state := createTestChannel(t, cdb)

	// Next, locate the pending channel within the database.
	pendingChannels, err := cdb.FetchPendingChannels()
	if err != nil {
		t.Fatalf("unable to load pending channels: %v", err)
	}

	var pendingChannel *OpenChannel
	for _, channel := range pendingChannels {
		if channel.FundingOutpoint == state.FundingOutpoint {
			pendingChannel = channel
			break
		}
	}
	if pendingChannel == nil {
		t.Fatalf("unable to find pending channel with funding "+
			"outpoint=%v: %v", state.FundingOutpoint, err)
	}

	// Next, simulate the confirmation of the channel by marking it as open
	// within the database.
	chanOpenLoc := lnwire.ShortChannelID{
		BlockHeight: 105,
		TxIndex:     10,
		TxPosition:  15,
	}

	err = state.MarkAsOpen(chanOpenLoc)
	require.NoError(t, err, "unable to mark channel open")

	// The short_chan_id of the receiver to MarkAsOpen should reflect the
	// open location, but the other pending channel should remain
	// unchanged.
	if state.ShortChanID() == pendingChannel.ShortChanID() {
		t.Fatalf("pending channel short_chan_ID should not have been " +
			"updated before refreshing short_chan_id")
	}

	// Now that the receiver's short channel id has been updated, check to
	// ensure that the channel packager's source has been updated as well.
	// This ensures that the packager will read and write to buckets
	// corresponding to the new short chan id, instead of the prior.
	if state.Packager.(*ChannelPackager).source != chanOpenLoc {
		t.Fatalf("channel packager source was not updated: want %v, "+
			"got %v", chanOpenLoc,
			state.Packager.(*ChannelPackager).source)
	}

	// Now, refresh the state of the pending channel.
	err = pendingChannel.Refresh()
	require.NoError(t, err, "unable to refresh short_chan_id")

	// This should result in both OpenChannels now having the same
	// ShortChanID.
	if state.ShortChanID() != pendingChannel.ShortChanID() {
		t.Fatalf("expected pending channel short_chan_id to be "+
			"refreshed: want %v, got %v", state.ShortChanID(),
			pendingChannel.ShortChanID())
	}

	// Check to ensure that the _other_ OpenChannel's channel packager
	// source has also been updated after the refresh. This ensures that
	// the other packagers will read and write to buckets corresponding to
	// the updated short chan id.
	if pendingChannel.Packager.(*ChannelPackager).source != chanOpenLoc {
		t.Fatalf("channel packager source was not updated: want %v, "+
			"got %v", chanOpenLoc,
			pendingChannel.Packager.(*ChannelPackager).source)
	}

	// Check to ensure that this channel is no longer pending and this
	// field is up to date.
	if pendingChannel.IsPending {
		t.Fatalf("channel pending state wasn't updated: want false " +
			"got true")
	}
}
// TestCloseInitiator tests the setting of close initiator statuses for
// cooperative closes and local force closes.
func TestCloseInitiator(t *testing.T) {
	tests := []struct {
		name string
		// updateChannel is called to update the channel as broadcast,
		// cooperatively or not, based on the test's requirements.
		updateChannel    func(c *OpenChannel) error
		expectedStatuses []ChannelStatus
	}{
		{
			name: "local coop close",
			// Mark the channel as cooperatively closed, initiated
			// by the local party.
			updateChannel: func(c *OpenChannel) error {
				return c.MarkCoopBroadcasted(
					&wire.MsgTx{}, true,
				)
			},
			expectedStatuses: []ChannelStatus{
				ChanStatusLocalCloseInitiator,
				ChanStatusCoopBroadcasted,
			},
		},
		{
			name: "remote coop close",
			// Mark the channel as cooperatively closed, initiated
			// by the remote party.
			updateChannel: func(c *OpenChannel) error {
				return c.MarkCoopBroadcasted(
					&wire.MsgTx{}, false,
				)
			},
			expectedStatuses: []ChannelStatus{
				ChanStatusRemoteCloseInitiator,
				ChanStatusCoopBroadcasted,
			},
		},
		{
			name: "local force close",
			// Mark the channel's commitment as broadcast with
			// local initiator.
			updateChannel: func(c *OpenChannel) error {
				return c.MarkCommitmentBroadcasted(
					&wire.MsgTx{}, true,
				)
			},
			expectedStatuses: []ChannelStatus{
				ChanStatusLocalCloseInitiator,
				ChanStatusCommitBroadcasted,
			},
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			fullDB, err := MakeTestDB(t)
			if err != nil {
				t.Fatalf("unable to make test database: %v",
					err)
			}

			cdb := fullDB.ChannelStateDB()

			// Create an open channel.
			channel := createTestChannel(
				t, cdb, openChannelOption(),
			)

			err = test.updateChannel(channel)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}

			// Lookup open channels in the database.
			dbChans, err := fetchChannels(
				cdb, pendingChannelFilter(false),
			)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(dbChans) != 1 {
				t.Fatalf("expected 1 channel, got: %v",
					len(dbChans))
			}

			// Check that the statuses that we expect were written
			// to disk.
			for _, status := range test.expectedStatuses {
				if !dbChans[0].HasChanStatus(status) {
					t.Fatalf("expected channel to have "+
						"status: %v, has status: %v",
						status, dbChans[0].chanStatus)
				}
			}
		})
	}
}
// TestCloseChannelStatus tests setting of a channel status on the historical
// channel on channel close.
func TestCloseChannelStatus(t *testing.T) {
	fullDB, err := MakeTestDB(t)
	if err != nil {
		t.Fatalf("unable to make test database: %v", err)
	}

	cdb := fullDB.ChannelStateDB()

	// Create an open channel.
	channel := createTestChannel(
		t, cdb, openChannelOption(),
	)

	if err := channel.CloseChannel(
		&ChannelCloseSummary{
			ChanPoint: channel.FundingOutpoint,
			RemotePub: channel.IdentityPub,
		}, ChanStatusRemoteCloseInitiator,
	); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}

	histChan, err := channel.Db.FetchHistoricalChannel(
		&channel.FundingOutpoint,
	)
	require.NoError(t, err, "unexpected error")

	if !histChan.HasChanStatus(ChanStatusRemoteCloseInitiator) {
		t.Fatalf("channel should have status")
	}
}

// TestHasChanStatus asserts the behavior of HasChanStatus for various status
// flags, as well as for the special case of ChanStatusDefault, which is
// treated like a flag in the code base even though it isn't one.
func TestHasChanStatus(t *testing.T) {
	tests := []struct {
		name   string
		status ChannelStatus
		expHas map[ChannelStatus]bool
	}{
		{
			name:   "default",
			status: ChanStatusDefault,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault: true,
				ChanStatusBorked:  false,
			},
		},
		{
			name:   "single flag",
			status: ChanStatusBorked,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault: false,
				ChanStatusBorked:  true,
			},
		},
		{
			name:   "multiple flags",
			status: ChanStatusBorked | ChanStatusLocalDataLoss,
			expHas: map[ChannelStatus]bool{
				ChanStatusDefault:       false,
				ChanStatusBorked:        true,
				ChanStatusLocalDataLoss: true,
			},
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			c := &OpenChannel{
				chanStatus: test.status,
			}

			for status, expHas := range test.expHas {
				has := c.HasChanStatus(status)
				if has == expHas {
					continue
				}

				t.Fatalf("expected chan status to "+
					"have %s? %t, got: %t",
					status, expHas, has)
			}
		})
	}
}
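
// hasStatusSketch is a minimal, hypothetical illustration of the special case
// exercised by TestHasChanStatus above; it is not part of the package API,
// not the production implementation of HasChanStatus, and not used by any
// test. ChanStatusDefault is the zero value, so a plain bitwise AND can never
// match it and it has to be compared for equality instead.
func hasStatusSketch(current, query ChannelStatus) bool {
	// The default status is the absence of any flag bits, so it is only
	// "had" when no other bits are set at all.
	if query == ChanStatusDefault {
		return current == ChanStatusDefault
	}

	// Any other status is a flag (or a combination of flags) that can be
	// tested with a bitwise AND.
	return current&query == query
}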

// TestKeyLocatorEncoding tests that we are able to serialize a given
// keychain.KeyLocator. After successfully encoding, we check that the decode
// output arrives at the same initial KeyLocator.
func TestKeyLocatorEncoding(t *testing.T) {
	keyLoc := keychain.KeyLocator{
		Family: keychain.KeyFamilyRevocationRoot,
		Index:  keyLocIndex,
	}

	// First, we'll encode the KeyLocator into a buffer.
	var (
		b   bytes.Buffer
		buf [8]byte
	)

	err := EKeyLocator(&b, &keyLoc, &buf)
	require.NoError(t, err, "unable to encode key locator")
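
	// A KeyLocator is expected to encode as two uint32 fields (the key
	// family followed by the index), which is why an 8 byte scratch
	// buffer is used above and a record length of 8 is passed to
	// DKeyLocator below.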

	// Next, we'll attempt to decode the bytes into a new KeyLocator.
	r := bytes.NewReader(b.Bytes())
	var decodedKeyLoc keychain.KeyLocator

	err = DKeyLocator(r, &decodedKeyLoc, &buf, 8)
	require.NoError(t, err, "unable to decode key locator")

	// Finally, we'll compare that the original KeyLocator and the decoded
	// version are equal.
	require.Equal(t, keyLoc, decodedKeyLoc)
}

// TestFinalHtlcs tests final htlc storage and retrieval.
func TestFinalHtlcs(t *testing.T) {
	t.Parallel()

	fullDB, err := MakeTestDB(t, OptionStoreFinalHtlcResolutions(true))
	require.NoError(t, err, "unable to make test database")

	cdb := fullDB.ChannelStateDB()

	chanID := lnwire.ShortChannelID{
		BlockHeight: 1,
		TxIndex:     2,
		TxPosition:  3,
	}

	// Test unknown htlc lookup.
	const unknownHtlcID = 999
	_, err = cdb.LookupFinalHtlc(chanID, unknownHtlcID)
	require.ErrorIs(t, err, ErrHtlcUnknown)

	// Test offchain final htlcs.
	const offchainHtlcID = 1
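
	// Write the off-chain resolution directly into the channel's final
	// htlc bucket with a raw kvdb transaction; the test only needs the
	// record to exist so that the lookup path below can be exercised.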
	err = kvdb.Update(cdb.backend, func(tx kvdb.RwTx) error {
		bucket, err := fetchFinalHtlcsBucketRw(
			tx, chanID,
		)
		require.NoError(t, err)

		return putFinalHtlc(bucket, offchainHtlcID, FinalHtlcInfo{
			Settled:  true,
			Offchain: true,
		})
	}, func() {})
	require.NoError(t, err)

	info, err := cdb.LookupFinalHtlc(chanID, offchainHtlcID)
	require.NoError(t, err)
	require.True(t, info.Settled)
	require.True(t, info.Offchain)

	// Test onchain final htlcs.
	const onchainHtlcID = 2
	err = cdb.PutOnchainFinalHtlcOutcome(chanID, onchainHtlcID, true)
	require.NoError(t, err)

	info, err = cdb.LookupFinalHtlc(chanID, onchainHtlcID)
	require.NoError(t, err)
	require.True(t, info.Settled)
	require.False(t, info.Offchain)

	// Test unknown htlc lookup for existing channel.
	_, err = cdb.LookupFinalHtlc(chanID, unknownHtlcID)
	require.ErrorIs(t, err, ErrHtlcUnknown)
}

// TestHTLCsExtraData tests serialization and deserialization of HTLCs
// combined with extra data.
func TestHTLCsExtraData(t *testing.T) {
	t.Parallel()

	mockHtlc := HTLC{
		Signature:     testSig.Serialize(),
		Incoming:      false,
		Amt:           10,
		RHash:         key,
		RefundTimeout: 1,
		OnionBlob:     lnmock.MockOnion(),
	}

	// Add a blinding point to a htlc.
	blindingPointHTLC := HTLC{
		Signature:     testSig.Serialize(),
		Incoming:      false,
		Amt:           10,
		RHash:         key,
		RefundTimeout: 1,
		OnionBlob:     lnmock.MockOnion(),
		BlindingPoint: tlv.SomeRecordT(
			tlv.NewPrimitiveRecord[lnwire.BlindingPointTlvType](
				pubKey,
			),
		),
	}
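
	// The blinding point is carried as an optional TLV record. When the
	// htlcs are serialized, such optional records are packed into the
	// htlc's opaque ExtraData field, which is why the assertions below
	// reset ExtraData before comparing the decoded htlcs to the
	// originals.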

	testCases := []struct {
		name        string
		htlcs       []HTLC
		blindingIdx int
	}{
		{
			// Serialize multiple HTLCs with no extra data to
			// assert that there is no regression for HTLCs with
			// no extra data.
			name: "no extra data",
			htlcs: []HTLC{
				mockHtlc, mockHtlc,
			},
		},
		{
			// Some HTLCs with extra data, some without.
			name: "mixed extra data",
			htlcs: []HTLC{
				mockHtlc,
				blindingPointHTLC,
				mockHtlc,
			},
		},
	}

	for _, testCase := range testCases {
		testCase := testCase
		t.Run(testCase.name, func(t *testing.T) {
			t.Parallel()

			var b bytes.Buffer
			err := SerializeHtlcs(&b, testCase.htlcs...)
			require.NoError(t, err)

			r := bytes.NewReader(b.Bytes())
			htlcs, err := DeserializeHtlcs(r)
			require.NoError(t, err)

			require.EqualValues(t, len(testCase.htlcs), len(htlcs))
			for i, htlc := range htlcs {
				// We use the extra data field when we
				// serialize, so we set it to nil to be able
				// to assert equality for the test.
				htlc.ExtraData = nil
				require.Equal(t, testCase.htlcs[i], htlc)
			}
		})
	}
}

// TestOnionBlobIncorrectLength tests HTLC deserialization in the case where
// the OnionBlob saved on disk is of an unexpected length. This error case is
// only expected in the case of database corruption (or some severe protocol
// breakdown/bug). An HTLC is manually serialized because we cannot force a
// case where we write an onion blob of incorrect length.
func TestOnionBlobIncorrectLength(t *testing.T) {
	t.Parallel()

	var b bytes.Buffer

	// Number of HTLCs.
	var numHtlcs uint16 = 1
	require.NoError(t, WriteElement(&b, numHtlcs))

	require.NoError(t, WriteElements(
		&b,
		// Signature, incoming, amount, Rhash, Timeout.
		testSig.Serialize(), false, lnwire.MilliSatoshi(10), key,
		uint32(1),
		// Write an onion blob that is half of our expected size.
		bytes.Repeat([]byte{1}, lnwire.OnionPacketSize/2),
	))

	_, err := DeserializeHtlcs(&b)
	require.ErrorIs(t, err, ErrOnionBlobLength)
}