noise_test.go 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703
  1. package brontide
  2. import (
  3. "bytes"
  4. "encoding/hex"
  5. "fmt"
  6. "io"
  7. "math"
  8. "net"
  9. "testing"
  10. "testing/iotest"
  11. "github.com/btcsuite/btcd/btcec/v2"
  12. "github.com/lightningnetwork/lnd/keychain"
  13. "github.com/lightningnetwork/lnd/lnwire"
  14. "github.com/lightningnetwork/lnd/tor"
  15. "github.com/stretchr/testify/require"
  16. )
  17. type maybeNetConn struct {
  18. conn net.Conn
  19. err error
  20. }
  21. func makeListener() (*Listener, *lnwire.NetAddress, error) {
  22. // First, generate the long-term private keys for the brontide listener.
  23. localPriv, err := btcec.NewPrivateKey()
  24. if err != nil {
  25. return nil, nil, err
  26. }
  27. localKeyECDH := &keychain.PrivKeyECDH{PrivKey: localPriv}
  28. // Having a port of ":0" means a random port, and interface will be
  29. // chosen for our listener.
  30. addr := "localhost:0"
  31. // Our listener will be local, and the connection remote.
  32. listener, err := NewListener(localKeyECDH, addr)
  33. if err != nil {
  34. return nil, nil, err
  35. }
  36. netAddr := &lnwire.NetAddress{
  37. IdentityKey: localPriv.PubKey(),
  38. Address: listener.Addr().(*net.TCPAddr),
  39. }
  40. return listener, netAddr, nil
  41. }
  42. func establishTestConnection(t testing.TB) (net.Conn, net.Conn, error) {
  43. listener, netAddr, err := makeListener()
  44. if err != nil {
  45. return nil, nil, err
  46. }
  47. t.Cleanup(func() {
  48. listener.Close()
  49. })
  50. // Nos, generate the long-term private keys remote end of the connection
  51. // within our test.
  52. remotePriv, err := btcec.NewPrivateKey()
  53. if err != nil {
  54. return nil, nil, err
  55. }
  56. remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
  57. // Initiate a connection with a separate goroutine, and listen with our
  58. // main one. If both errors are nil, then encryption+auth was
  59. // successful.
  60. remoteConnChan := make(chan maybeNetConn, 1)
  61. go func() {
  62. remoteConn, err := Dial(
  63. remoteKeyECDH, netAddr,
  64. tor.DefaultConnTimeout, net.DialTimeout,
  65. )
  66. remoteConnChan <- maybeNetConn{remoteConn, err}
  67. }()
  68. localConnChan := make(chan maybeNetConn, 1)
  69. go func() {
  70. localConn, err := listener.Accept()
  71. localConnChan <- maybeNetConn{localConn, err}
  72. }()
  73. remote := <-remoteConnChan
  74. if remote.err != nil {
  75. return nil, nil, err
  76. }
  77. local := <-localConnChan
  78. if local.err != nil {
  79. return nil, nil, err
  80. }
  81. t.Cleanup(func() {
  82. local.conn.Close()
  83. remote.conn.Close()
  84. })
  85. return local.conn, remote.conn, nil
  86. }
  87. func TestConnectionCorrectness(t *testing.T) {
  88. // Create a test connection, grabbing either side of the connection
  89. // into local variables. If the initial crypto handshake fails, then
  90. // we'll get a non-nil error here.
  91. localConn, remoteConn, err := establishTestConnection(t)
  92. require.NoError(t, err, "unable to establish test connection")
  93. // Test out some message full-message reads.
  94. for i := 0; i < 10; i++ {
  95. msg := []byte(fmt.Sprintf("hello%d", i))
  96. if _, err := localConn.Write(msg); err != nil {
  97. t.Fatalf("remote conn failed to write: %v", err)
  98. }
  99. readBuf := make([]byte, len(msg))
  100. if _, err := remoteConn.Read(readBuf); err != nil {
  101. t.Fatalf("local conn failed to read: %v", err)
  102. }
  103. if !bytes.Equal(readBuf, msg) {
  104. t.Fatalf("messages don't match, %v vs %v",
  105. string(readBuf), string(msg))
  106. }
  107. }
  108. // Now try incremental message reads. This simulates first writing a
  109. // message header, then a message body.
  110. outMsg := []byte("hello world")
  111. if _, err := localConn.Write(outMsg); err != nil {
  112. t.Fatalf("remote conn failed to write: %v", err)
  113. }
  114. readBuf := make([]byte, len(outMsg))
  115. if _, err := remoteConn.Read(readBuf[:len(outMsg)/2]); err != nil {
  116. t.Fatalf("local conn failed to read: %v", err)
  117. }
  118. if _, err := remoteConn.Read(readBuf[len(outMsg)/2:]); err != nil {
  119. t.Fatalf("local conn failed to read: %v", err)
  120. }
  121. if !bytes.Equal(outMsg, readBuf) {
  122. t.Fatalf("messages don't match, %v vs %v",
  123. string(readBuf), string(outMsg))
  124. }
  125. }
  126. // TestConecurrentHandshakes verifies the listener's ability to not be blocked
  127. // by other pending handshakes. This is tested by opening multiple tcp
  128. // connections with the listener, without completing any of the brontide acts.
  129. // The test passes if real brontide dialer connects while the others are
  130. // stalled.
  131. func TestConcurrentHandshakes(t *testing.T) {
  132. listener, netAddr, err := makeListener()
  133. require.NoError(t, err, "unable to create listener connection")
  134. defer listener.Close()
  135. const nblocking = 5
  136. // Open a handful of tcp connections, that do not complete any steps of
  137. // the brontide handshake.
  138. connChan := make(chan maybeNetConn)
  139. for i := 0; i < nblocking; i++ {
  140. go func() {
  141. conn, err := net.Dial("tcp", listener.Addr().String())
  142. connChan <- maybeNetConn{conn, err}
  143. }()
  144. }
  145. // Receive all connections/errors from our blocking tcp dials. We make a
  146. // pass to gather all connections and errors to make sure we defer the
  147. // calls to Close() on all successful connections.
  148. tcpErrs := make([]error, 0, nblocking)
  149. for i := 0; i < nblocking; i++ {
  150. result := <-connChan
  151. if result.conn != nil {
  152. defer result.conn.Close()
  153. }
  154. if result.err != nil {
  155. tcpErrs = append(tcpErrs, result.err)
  156. }
  157. }
  158. for _, tcpErr := range tcpErrs {
  159. if tcpErr != nil {
  160. t.Fatalf("unable to tcp dial listener: %v", tcpErr)
  161. }
  162. }
  163. // Now, construct a new private key and use the brontide dialer to
  164. // connect to the listener.
  165. remotePriv, err := btcec.NewPrivateKey()
  166. require.NoError(t, err, "unable to generate private key")
  167. remoteKeyECDH := &keychain.PrivKeyECDH{PrivKey: remotePriv}
  168. go func() {
  169. remoteConn, err := Dial(
  170. remoteKeyECDH, netAddr,
  171. tor.DefaultConnTimeout, net.DialTimeout,
  172. )
  173. connChan <- maybeNetConn{remoteConn, err}
  174. }()
  175. // This connection should be accepted without error, as the brontide
  176. // connection should bypass stalled tcp connections.
  177. conn, err := listener.Accept()
  178. require.NoError(t, err, "unable to accept dial")
  179. defer conn.Close()
  180. result := <-connChan
  181. if result.err != nil {
  182. t.Fatalf("unable to dial %v: %v", netAddr, result.err)
  183. }
  184. result.conn.Close()
  185. }
  186. func TestMaxPayloadLength(t *testing.T) {
  187. t.Parallel()
  188. b := Machine{}
  189. b.split()
  190. // Create a payload that's only *slightly* above the maximum allotted
  191. // payload length.
  192. payloadToReject := make([]byte, math.MaxUint16+1)
  193. // A write of the payload generated above to the state machine should
  194. // be rejected as it's over the max payload length.
  195. err := b.WriteMessage(payloadToReject)
  196. if err != ErrMaxMessageLengthExceeded {
  197. t.Fatalf("payload is over the max allowed length, the write " +
  198. "should have been rejected")
  199. }
  200. // Generate another payload which should be accepted as a valid
  201. // payload.
  202. payloadToAccept := make([]byte, math.MaxUint16-1)
  203. if err := b.WriteMessage(payloadToAccept); err != nil {
  204. t.Fatalf("write for payload was rejected, should have been " +
  205. "accepted")
  206. }
  207. // Generate a final payload which is only *slightly* above the max payload length
  208. // when the MAC is accounted for.
  209. payloadToReject = make([]byte, math.MaxUint16+1)
  210. // This payload should be rejected.
  211. err = b.WriteMessage(payloadToReject)
  212. if err != ErrMaxMessageLengthExceeded {
  213. t.Fatalf("payload is over the max allowed length, the write " +
  214. "should have been rejected")
  215. }
  216. }
  217. func TestWriteMessageChunking(t *testing.T) {
  218. // Create a test connection, grabbing either side of the connection
  219. // into local variables. If the initial crypto handshake fails, then
  220. // we'll get a non-nil error here.
  221. localConn, remoteConn, err := establishTestConnection(t)
  222. require.NoError(t, err, "unable to establish test connection")
  223. // Attempt to write a message which is over 3x the max allowed payload
  224. // size.
  225. largeMessage := bytes.Repeat([]byte("kek"), math.MaxUint16*3)
  226. // Launch a new goroutine to write the large message generated above in
  227. // chunks. We spawn a new goroutine because otherwise, we may block as
  228. // the kernel waits for the buffer to flush.
  229. errCh := make(chan error)
  230. go func() {
  231. defer close(errCh)
  232. bytesWritten, err := localConn.Write(largeMessage)
  233. if err != nil {
  234. errCh <- fmt.Errorf("unable to write message: %w", err)
  235. return
  236. }
  237. // The entire message should have been written out to the remote
  238. // connection.
  239. if bytesWritten != len(largeMessage) {
  240. errCh <- fmt.Errorf("bytes not fully written")
  241. return
  242. }
  243. }()
  244. // Attempt to read the entirety of the message generated above.
  245. buf := make([]byte, len(largeMessage))
  246. if _, err := io.ReadFull(remoteConn, buf); err != nil {
  247. t.Fatalf("unable to read message: %v", err)
  248. }
  249. err = <-errCh
  250. if err != nil {
  251. t.Fatal(err)
  252. }
  253. // Finally, the message the remote end of the connection received
  254. // should be identical to what we sent from the local connection.
  255. if !bytes.Equal(buf, largeMessage) {
  256. t.Fatalf("bytes don't match")
  257. }
  258. }
  259. // TestBolt0008TestVectors ensures that our implementation of brontide exactly
  260. // matches the test vectors within the specification.
  261. func TestBolt0008TestVectors(t *testing.T) {
  262. t.Parallel()
  263. // First, we'll generate the state of the initiator from the test
  264. // vectors at the appendix of BOLT-0008
  265. initiatorKeyBytes, err := hex.DecodeString("1111111111111111111111" +
  266. "111111111111111111111111111111111111111111")
  267. require.NoError(t, err, "unable to decode hex")
  268. initiatorPriv, _ := btcec.PrivKeyFromBytes(
  269. initiatorKeyBytes,
  270. )
  271. initiatorKeyECDH := &keychain.PrivKeyECDH{PrivKey: initiatorPriv}
  272. // We'll then do the same for the responder.
  273. responderKeyBytes, err := hex.DecodeString("212121212121212121212121" +
  274. "2121212121212121212121212121212121212121")
  275. require.NoError(t, err, "unable to decode hex")
  276. responderPriv, responderPub := btcec.PrivKeyFromBytes(
  277. responderKeyBytes,
  278. )
  279. responderKeyECDH := &keychain.PrivKeyECDH{PrivKey: responderPriv}
  280. // With the initiator's key data parsed, we'll now define a custom
  281. // EphemeralGenerator function for the state machine to ensure that the
  282. // initiator and responder both generate the ephemeral public key
  283. // defined within the test vectors.
  284. initiatorEphemeral := EphemeralGenerator(func() (*btcec.PrivateKey, error) {
  285. e := "121212121212121212121212121212121212121212121212121212" +
  286. "1212121212"
  287. eBytes, err := hex.DecodeString(e)
  288. if err != nil {
  289. return nil, err
  290. }
  291. priv, _ := btcec.PrivKeyFromBytes(eBytes)
  292. return priv, nil
  293. })
  294. responderEphemeral := EphemeralGenerator(func() (*btcec.PrivateKey, error) {
  295. e := "222222222222222222222222222222222222222222222222222" +
  296. "2222222222222"
  297. eBytes, err := hex.DecodeString(e)
  298. if err != nil {
  299. return nil, err
  300. }
  301. priv, _ := btcec.PrivKeyFromBytes(eBytes)
  302. return priv, nil
  303. })
  304. // Finally, we'll create both brontide state machines, so we can begin
  305. // our test.
  306. initiator := NewBrontideMachine(
  307. true, initiatorKeyECDH, responderPub, initiatorEphemeral,
  308. )
  309. responder := NewBrontideMachine(
  310. false, responderKeyECDH, nil, responderEphemeral,
  311. )
  312. // We'll start with the initiator generating the initial payload for
  313. // act one. This should consist of exactly 50 bytes. We'll assert that
  314. // the payload return is _exactly_ the same as what's specified within
  315. // the test vectors.
  316. actOne, err := initiator.GenActOne()
  317. require.NoError(t, err, "unable to generate act one")
  318. expectedActOne, err := hex.DecodeString("00036360e856310ce5d294e" +
  319. "8be33fc807077dc56ac80d95d9cd4ddbd21325eff73f70df608655115" +
  320. "1f58b8afe6c195782c6a")
  321. require.NoError(t, err, "unable to parse expected act one")
  322. if !bytes.Equal(expectedActOne, actOne[:]) {
  323. t.Fatalf("act one mismatch: expected %x, got %x",
  324. expectedActOne, actOne)
  325. }
  326. // With the assertion above passed, we'll now process the act one
  327. // payload with the responder of the crypto handshake.
  328. if err := responder.RecvActOne(actOne); err != nil {
  329. t.Fatalf("responder unable to process act one: %v", err)
  330. }
  331. // Next, we'll start the second act by having the responder generate
  332. // its contribution to the crypto handshake. We'll also verify that we
  333. // produce the _exact_ same byte stream as advertised within the spec's
  334. // test vectors.
  335. actTwo, err := responder.GenActTwo()
  336. require.NoError(t, err, "unable to generate act two")
  337. expectedActTwo, err := hex.DecodeString("0002466d7fcae563e5cb09a0" +
  338. "d1870bb580344804617879a14949cf22285f1bae3f276e2470b93aac58" +
  339. "3c9ef6eafca3f730ae")
  340. require.NoError(t, err, "unable to parse expected act two")
  341. if !bytes.Equal(expectedActTwo, actTwo[:]) {
  342. t.Fatalf("act two mismatch: expected %x, got %x",
  343. expectedActTwo, actTwo)
  344. }
  345. // Moving the handshake along, we'll also ensure that the initiator
  346. // accepts the act two payload.
  347. if err := initiator.RecvActTwo(actTwo); err != nil {
  348. t.Fatalf("initiator unable to process act two: %v", err)
  349. }
  350. // At the final step, we'll generate the last act from the initiator
  351. // and once again verify that it properly matches the test vectors.
  352. actThree, err := initiator.GenActThree()
  353. require.NoError(t, err, "unable to generate act three")
  354. expectedActThree, err := hex.DecodeString("00b9e3a702e93e3a9948c2e" +
  355. "d6e5fd7590a6e1c3a0344cfc9d5b57357049aa22355361aa02e55a8f" +
  356. "c28fef5bd6d71ad0c38228dc68b1c466263b47fdf31e560e139ba")
  357. require.NoError(t, err, "unable to parse expected act three")
  358. if !bytes.Equal(expectedActThree, actThree[:]) {
  359. t.Fatalf("act three mismatch: expected %x, got %x",
  360. expectedActThree, actThree)
  361. }
  362. // Finally, we'll ensure that the responder itself also properly parses
  363. // the last payload in the crypto handshake.
  364. if err := responder.RecvActThree(actThree); err != nil {
  365. t.Fatalf("responder unable to process act three: %v", err)
  366. }
  367. // As a final assertion, we'll ensure that both sides have derived the
  368. // proper symmetric encryption keys.
  369. sendingKey, err := hex.DecodeString("969ab31b4d288cedf6218839b27a3e2" +
  370. "140827047f2c0f01bf5c04435d43511a9")
  371. require.NoError(t, err, "unable to parse sending key")
  372. recvKey, err := hex.DecodeString("bb9020b8965f4df047e07f955f3c4b884" +
  373. "18984aadc5cdb35096b9ea8fa5c3442")
  374. require.NoError(t, err, "unable to parse receiving key")
  375. chainKey, err := hex.DecodeString("919219dbb2920afa8db80f9a51787a840" +
  376. "bcf111ed8d588caf9ab4be716e42b01")
  377. require.NoError(t, err, "unable to parse chaining key")
  378. if !bytes.Equal(initiator.sendCipher.secretKey[:], sendingKey) {
  379. t.Fatalf("sending key mismatch: expected %x, got %x",
  380. initiator.sendCipher.secretKey[:], sendingKey)
  381. }
  382. if !bytes.Equal(initiator.recvCipher.secretKey[:], recvKey) {
  383. t.Fatalf("receiving key mismatch: expected %x, got %x",
  384. initiator.recvCipher.secretKey[:], recvKey)
  385. }
  386. if !bytes.Equal(initiator.chainingKey[:], chainKey) {
  387. t.Fatalf("chaining key mismatch: expected %x, got %x",
  388. initiator.chainingKey[:], chainKey)
  389. }
  390. if !bytes.Equal(responder.sendCipher.secretKey[:], recvKey) {
  391. t.Fatalf("sending key mismatch: expected %x, got %x",
  392. responder.sendCipher.secretKey[:], recvKey)
  393. }
  394. if !bytes.Equal(responder.recvCipher.secretKey[:], sendingKey) {
  395. t.Fatalf("receiving key mismatch: expected %x, got %x",
  396. responder.recvCipher.secretKey[:], sendingKey)
  397. }
  398. if !bytes.Equal(responder.chainingKey[:], chainKey) {
  399. t.Fatalf("chaining key mismatch: expected %x, got %x",
  400. responder.chainingKey[:], chainKey)
  401. }
  402. // Now test as per section "transport-message test" in Test Vectors
  403. // (the transportMessageVectors ciphertexts are from this section of BOLT 8);
  404. // we do slightly greater than 1000 encryption/decryption operations
  405. // to ensure that the key rotation algorithm is operating as expected.
  406. // The starting point for enc/decr is already guaranteed correct from the
  407. // above tests of sendingKey, receivingKey, chainingKey.
  408. transportMessageVectors := map[int]string{
  409. 0: "cf2b30ddf0cf3f80e7c35a6e6730b59fe802473180f396d88a8fb0db8cb" +
  410. "cf25d2f214cf9ea1d95",
  411. 1: "72887022101f0b6753e0c7de21657d35a4cb2a1f5cde2650528bbc8f837" +
  412. "d0f0d7ad833b1a256a1",
  413. 500: "178cb9d7387190fa34db9c2d50027d21793c9bc2d40b1e14dcf30ebeeeb2" +
  414. "20f48364f7a4c68bf8",
  415. 501: "1b186c57d44eb6de4c057c49940d79bb838a145cb528d6e8fd26dbe50a6" +
  416. "0ca2c104b56b60e45bd",
  417. 1000: "4a2f3cc3b5e78ddb83dcb426d9863d9d9a723b0337c89dd0b005d89f8d3" +
  418. "c05c52b76b29b740f09",
  419. 1001: "2ecd8c8a5629d0d02ab457a0fdd0f7b90a192cd46be5ecb6ca570bfc5e2" +
  420. "68338b1a16cf4ef2d36",
  421. }
  422. // Payload for every message is the string "hello".
  423. payload := []byte("hello")
  424. var buf bytes.Buffer
  425. for i := 0; i < 1002; i++ {
  426. err = initiator.WriteMessage(payload)
  427. if err != nil {
  428. t.Fatalf("could not write message %s", payload)
  429. }
  430. _, err = initiator.Flush(&buf)
  431. if err != nil {
  432. t.Fatalf("could not flush message: %v", err)
  433. }
  434. if val, ok := transportMessageVectors[i]; ok {
  435. binaryVal, err := hex.DecodeString(val)
  436. if err != nil {
  437. t.Fatalf("Failed to decode hex string %s", val)
  438. }
  439. if !bytes.Equal(buf.Bytes(), binaryVal) {
  440. t.Fatalf("Ciphertext %x was not equal to expected %s",
  441. buf.String()[:], val)
  442. }
  443. }
  444. // Responder decrypts the bytes, in every iteration, and
  445. // should always be able to decrypt the same payload message.
  446. plaintext, err := responder.ReadMessage(&buf)
  447. if err != nil {
  448. t.Fatalf("failed to read message in responder: %v", err)
  449. }
  450. // Ensure decryption succeeded
  451. if !bytes.Equal(plaintext, payload) {
  452. t.Fatalf("Decryption failed to receive plaintext: %s, got %s",
  453. payload, plaintext)
  454. }
  455. // Clear out the buffer for the next iteration
  456. buf.Reset()
  457. }
  458. }
  459. // timeoutWriter wraps an io.Writer and throws an iotest.ErrTimeout after
  460. // writing n bytes.
  461. type timeoutWriter struct {
  462. w io.Writer
  463. n int64
  464. }
  465. func NewTimeoutWriter(w io.Writer, n int64) io.Writer {
  466. return &timeoutWriter{w, n}
  467. }
  468. func (t *timeoutWriter) Write(p []byte) (int, error) {
  469. n := len(p)
  470. if int64(n) > t.n {
  471. n = int(t.n)
  472. }
  473. n, err := t.w.Write(p[:n])
  474. t.n -= int64(n)
  475. if err == nil && t.n == 0 {
  476. return n, iotest.ErrTimeout
  477. }
  478. return n, err
  479. }
  480. const payloadSize = 10
  481. type flushChunk struct {
  482. errAfter int64
  483. expN int
  484. expErr error
  485. }
  486. type flushTest struct {
  487. name string
  488. chunks []flushChunk
  489. }
  490. var flushTests = []flushTest{
  491. {
  492. name: "partial header write",
  493. chunks: []flushChunk{
  494. // Write 18-byte header in two parts, 16 then 2.
  495. {
  496. errAfter: encHeaderSize - 2,
  497. expN: 0,
  498. expErr: iotest.ErrTimeout,
  499. },
  500. {
  501. errAfter: 2,
  502. expN: 0,
  503. expErr: iotest.ErrTimeout,
  504. },
  505. // Write payload and MAC in one go.
  506. {
  507. errAfter: -1,
  508. expN: payloadSize,
  509. },
  510. },
  511. },
  512. {
  513. name: "full payload then full mac",
  514. chunks: []flushChunk{
  515. // Write entire header and entire payload w/o MAC.
  516. {
  517. errAfter: encHeaderSize + payloadSize,
  518. expN: payloadSize,
  519. expErr: iotest.ErrTimeout,
  520. },
  521. // Write the entire MAC.
  522. {
  523. errAfter: -1,
  524. expN: 0,
  525. },
  526. },
  527. },
  528. {
  529. name: "payload-only, straddle, mac-only",
  530. chunks: []flushChunk{
  531. // Write header and all but last byte of payload.
  532. {
  533. errAfter: encHeaderSize + payloadSize - 1,
  534. expN: payloadSize - 1,
  535. expErr: iotest.ErrTimeout,
  536. },
  537. // Write last byte of payload and first byte of MAC.
  538. {
  539. errAfter: 2,
  540. expN: 1,
  541. expErr: iotest.ErrTimeout,
  542. },
  543. // Write 10 bytes of the MAC.
  544. {
  545. errAfter: 10,
  546. expN: 0,
  547. expErr: iotest.ErrTimeout,
  548. },
  549. // Write the remaining 5 MAC bytes.
  550. {
  551. errAfter: -1,
  552. expN: 0,
  553. },
  554. },
  555. },
  556. }
  557. // TestFlush asserts a Machine's ability to handle timeouts during Flush that
  558. // cause partial writes, and that the machine can properly resume writes on
  559. // subsequent calls to Flush.
  560. func TestFlush(t *testing.T) {
  561. // Run each test individually, to assert that they pass in isolation.
  562. for _, test := range flushTests {
  563. t.Run(test.name, func(t *testing.T) {
  564. var (
  565. w bytes.Buffer
  566. b Machine
  567. )
  568. b.split()
  569. testFlush(t, test, &b, &w)
  570. })
  571. }
  572. // Finally, run the tests serially as if all on one connection.
  573. t.Run("flush serial", func(t *testing.T) {
  574. var (
  575. w bytes.Buffer
  576. b Machine
  577. )
  578. b.split()
  579. for _, test := range flushTests {
  580. testFlush(t, test, &b, &w)
  581. }
  582. })
  583. }
  584. // testFlush buffers a message on the Machine, then flushes it to the io.Writer
  585. // in chunks. Once complete, a final call to flush is made to assert that Write
  586. // is not called again.
  587. func testFlush(t *testing.T, test flushTest, b *Machine, w io.Writer) {
  588. payload := make([]byte, payloadSize)
  589. if err := b.WriteMessage(payload); err != nil {
  590. t.Fatalf("unable to write message: %v", err)
  591. }
  592. for _, chunk := range test.chunks {
  593. assertFlush(t, b, w, chunk.errAfter, chunk.expN, chunk.expErr)
  594. }
  595. // We should always be able to call Flush after a message has been
  596. // successfully written, and it should result in a NOP.
  597. assertFlush(t, b, w, 0, 0, nil)
  598. }
  599. // assertFlush flushes a chunk to the passed io.Writer. If n >= 0, a
  600. // timeoutWriter will be used the flush should stop with iotest.ErrTimeout after
  601. // n bytes. The method asserts that the returned error matches expErr and that
  602. // the number of bytes written by Flush matches expN.
  603. func assertFlush(t *testing.T, b *Machine, w io.Writer, n int64, expN int,
  604. expErr error) {
  605. t.Helper()
  606. if n >= 0 {
  607. w = NewTimeoutWriter(w, n)
  608. }
  609. nn, err := b.Flush(w)
  610. if err != expErr {
  611. t.Fatalf("expected flush err: %v, got: %v", expErr, err)
  612. }
  613. if nn != expN {
  614. t.Fatalf("expected n: %d, got: %d", expN, nn)
  615. }
  616. }