payment_result.go 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. package htlcswitch
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "io"
  7. "sync"
  8. "github.com/lightningnetwork/lnd/channeldb"
  9. "github.com/lightningnetwork/lnd/kvdb"
  10. "github.com/lightningnetwork/lnd/lnwire"
  11. "github.com/lightningnetwork/lnd/multimutex"
  12. )
  13. var (
  14. // networkResultStoreBucketKey is used for the root level bucket that
  15. // stores the network result for each payment ID.
  16. networkResultStoreBucketKey = []byte("network-result-store-bucket")
  17. // ErrPaymentIDNotFound is an error returned if the given paymentID is
  18. // not found.
  19. ErrPaymentIDNotFound = errors.New("paymentID not found")
  20. // ErrPaymentIDAlreadyExists is returned if we try to write a pending
  21. // payment whose paymentID already exists.
  22. ErrPaymentIDAlreadyExists = errors.New("paymentID already exists")
  23. )
  24. // PaymentResult wraps a decoded result received from the network after a
  25. // payment attempt was made. This is what is eventually handed to the router
  26. // for processing.
  27. type PaymentResult struct {
  28. // Preimage is set by the switch in case a sent HTLC was settled.
  29. Preimage [32]byte
  30. // Error is non-nil in case a HTLC send failed, and the HTLC is now
  31. // irrevocably canceled. If the payment failed during forwarding, this
  32. // error will be a *ForwardingError.
  33. Error error
  34. }
  35. // networkResult is the raw result received from the network after a payment
  36. // attempt has been made. Since the switch doesn't always have the necessary
  37. // data to decode the raw message, we store it together with some meta data,
  38. // and decode it when the router query for the final result.
  39. type networkResult struct {
  40. // msg is the received result. This should be of type UpdateFulfillHTLC
  41. // or UpdateFailHTLC.
  42. msg lnwire.Message
  43. // unencrypted indicates whether the failure encoded in the message is
  44. // unencrypted, and hence doesn't need to be decrypted.
  45. unencrypted bool
  46. // isResolution indicates whether this is a resolution message, in
  47. // which the failure reason might not be included.
  48. isResolution bool
  49. }
  50. // serializeNetworkResult serializes the networkResult.
  51. func serializeNetworkResult(w io.Writer, n *networkResult) error {
  52. return channeldb.WriteElements(w, n.msg, n.unencrypted, n.isResolution)
  53. }
  54. // deserializeNetworkResult deserializes the networkResult.
  55. func deserializeNetworkResult(r io.Reader) (*networkResult, error) {
  56. n := &networkResult{}
  57. if err := channeldb.ReadElements(r,
  58. &n.msg, &n.unencrypted, &n.isResolution,
  59. ); err != nil {
  60. return nil, err
  61. }
  62. return n, nil
  63. }
  64. // networkResultStore is a persistent store that stores any results of HTLCs in
  65. // flight on the network. Since payment results are inherently asynchronous, it
  66. // is used as a common access point for senders of HTLCs, to know when a result
  67. // is back. The Switch will checkpoint any received result to the store, and
  68. // the store will keep results and notify the callers about them.
  69. type networkResultStore struct {
  70. backend kvdb.Backend
  71. // results is a map from paymentIDs to channels where subscribers to
  72. // payment results will be notified.
  73. results map[uint64][]chan *networkResult
  74. resultsMtx sync.Mutex
  75. // paymentIDMtx is a multimutex used to make sure the database and
  76. // result subscribers map is consistent for each payment ID in case of
  77. // concurrent callers.
  78. paymentIDMtx *multimutex.Mutex[uint64]
  79. }
  80. func newNetworkResultStore(db kvdb.Backend) *networkResultStore {
  81. return &networkResultStore{
  82. backend: db,
  83. results: make(map[uint64][]chan *networkResult),
  84. paymentIDMtx: multimutex.NewMutex[uint64](),
  85. }
  86. }
  87. // storeResult stores the networkResult for the given paymentID, and
  88. // notifies any subscribers.
  89. func (store *networkResultStore) storeResult(paymentID uint64,
  90. result *networkResult) error {
  91. // We get a mutex for this payment ID. This is needed to ensure
  92. // consistency between the database state and the subscribers in case
  93. // of concurrent calls.
  94. store.paymentIDMtx.Lock(paymentID)
  95. defer store.paymentIDMtx.Unlock(paymentID)
  96. log.Debugf("Storing result for paymentID=%v", paymentID)
  97. // Serialize the payment result.
  98. var b bytes.Buffer
  99. if err := serializeNetworkResult(&b, result); err != nil {
  100. return err
  101. }
  102. var paymentIDBytes [8]byte
  103. binary.BigEndian.PutUint64(paymentIDBytes[:], paymentID)
  104. err := kvdb.Batch(store.backend, func(tx kvdb.RwTx) error {
  105. networkResults, err := tx.CreateTopLevelBucket(
  106. networkResultStoreBucketKey,
  107. )
  108. if err != nil {
  109. return err
  110. }
  111. return networkResults.Put(paymentIDBytes[:], b.Bytes())
  112. })
  113. if err != nil {
  114. return err
  115. }
  116. // Now that the result is stored in the database, we can notify any
  117. // active subscribers.
  118. store.resultsMtx.Lock()
  119. for _, res := range store.results[paymentID] {
  120. res <- result
  121. }
  122. delete(store.results, paymentID)
  123. store.resultsMtx.Unlock()
  124. return nil
  125. }
  126. // subscribeResult is used to get the payment result for the given
  127. // payment ID. It returns a channel on which the result will be delivered when
  128. // ready.
  129. func (store *networkResultStore) subscribeResult(paymentID uint64) (
  130. <-chan *networkResult, error) {
  131. // We get a mutex for this payment ID. This is needed to ensure
  132. // consistency between the database state and the subscribers in case
  133. // of concurrent calls.
  134. store.paymentIDMtx.Lock(paymentID)
  135. defer store.paymentIDMtx.Unlock(paymentID)
  136. log.Debugf("Subscribing to result for paymentID=%v", paymentID)
  137. var (
  138. result *networkResult
  139. resultChan = make(chan *networkResult, 1)
  140. )
  141. err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
  142. var err error
  143. result, err = fetchResult(tx, paymentID)
  144. switch {
  145. // Result not yet available, we will notify once a result is
  146. // available.
  147. case err == ErrPaymentIDNotFound:
  148. return nil
  149. case err != nil:
  150. return err
  151. // The result was found, and will be returned immediately.
  152. default:
  153. return nil
  154. }
  155. }, func() {
  156. result = nil
  157. })
  158. if err != nil {
  159. return nil, err
  160. }
  161. // If the result was found, we can send it on the result channel
  162. // imemdiately.
  163. if result != nil {
  164. resultChan <- result
  165. return resultChan, nil
  166. }
  167. // Otherwise we store the result channel for when the result is
  168. // available.
  169. store.resultsMtx.Lock()
  170. store.results[paymentID] = append(
  171. store.results[paymentID], resultChan,
  172. )
  173. store.resultsMtx.Unlock()
  174. return resultChan, nil
  175. }
  176. // getResult attempts to immediately fetch the result for the given pid from
  177. // the store. If no result is available, ErrPaymentIDNotFound is returned.
  178. func (store *networkResultStore) getResult(pid uint64) (
  179. *networkResult, error) {
  180. var result *networkResult
  181. err := kvdb.View(store.backend, func(tx kvdb.RTx) error {
  182. var err error
  183. result, err = fetchResult(tx, pid)
  184. return err
  185. }, func() {
  186. result = nil
  187. })
  188. if err != nil {
  189. return nil, err
  190. }
  191. return result, nil
  192. }
  193. func fetchResult(tx kvdb.RTx, pid uint64) (*networkResult, error) {
  194. var paymentIDBytes [8]byte
  195. binary.BigEndian.PutUint64(paymentIDBytes[:], pid)
  196. networkResults := tx.ReadBucket(networkResultStoreBucketKey)
  197. if networkResults == nil {
  198. return nil, ErrPaymentIDNotFound
  199. }
  200. // Check whether a result is already available.
  201. resultBytes := networkResults.Get(paymentIDBytes[:])
  202. if resultBytes == nil {
  203. return nil, ErrPaymentIDNotFound
  204. }
  205. // Decode the result we found.
  206. r := bytes.NewReader(resultBytes)
  207. return deserializeNetworkResult(r)
  208. }
  209. // cleanStore removes all entries from the store, except the payment IDs given.
  210. // NOTE: Since every result not listed in the keep map will be deleted, care
  211. // should be taken to ensure no new payment attempts are being made
  212. // concurrently while this process is ongoing, as its result might end up being
  213. // deleted.
  214. func (store *networkResultStore) cleanStore(keep map[uint64]struct{}) error {
  215. return kvdb.Update(store.backend, func(tx kvdb.RwTx) error {
  216. networkResults, err := tx.CreateTopLevelBucket(
  217. networkResultStoreBucketKey,
  218. )
  219. if err != nil {
  220. return err
  221. }
  222. // Iterate through the bucket, deleting all items not in the
  223. // keep map.
  224. var toClean [][]byte
  225. if err := networkResults.ForEach(func(k, _ []byte) error {
  226. pid := binary.BigEndian.Uint64(k)
  227. if _, ok := keep[pid]; ok {
  228. return nil
  229. }
  230. toClean = append(toClean, k)
  231. return nil
  232. }); err != nil {
  233. return err
  234. }
  235. for _, k := range toClean {
  236. err := networkResults.Delete(k)
  237. if err != nil {
  238. return err
  239. }
  240. }
  241. if len(toClean) > 0 {
  242. log.Infof("Removed %d stale entries from network "+
  243. "result store", len(toClean))
  244. }
  245. return nil
  246. }, func() {})
  247. }