http2_test.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625
  1. package connection
  2. import (
  3. "bytes"
  4. "context"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "net/http"
  10. "net/http/httptest"
  11. "sync"
  12. "testing"
  13. "time"
  14. "github.com/gobwas/ws/wsutil"
  15. "github.com/google/uuid"
  16. "github.com/rs/zerolog"
  17. "github.com/stretchr/testify/assert"
  18. "github.com/stretchr/testify/require"
  19. "golang.org/x/net/http2"
  20. "github.com/cloudflare/cloudflared/client"
  21. "github.com/cloudflare/cloudflared/tracing"
  22. "github.com/cloudflare/cloudflared/tunnelrpc"
  23. "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  24. )
  25. var (
  26. testTransport = http2.Transport{}
  27. )
  28. func newTestHTTP2Connection() (*HTTP2Connection, net.Conn) {
  29. edgeConn, cfdConn := net.Pipe()
  30. var connIndex = uint8(0)
  31. log := zerolog.Nop()
  32. obs := NewObserver(&log, &log)
  33. controlStream := NewControlStream(
  34. obs,
  35. mockConnectedFuse{},
  36. &TunnelProperties{},
  37. connIndex,
  38. nil,
  39. nil,
  40. 1*time.Second,
  41. nil,
  42. 1*time.Second,
  43. HTTP2,
  44. )
  45. return NewHTTP2Connection(
  46. cfdConn,
  47. // OriginProxy is set in testConfigManager
  48. testOrchestrator,
  49. &client.ConnectionOptionsSnapshot{},
  50. obs,
  51. connIndex,
  52. controlStream,
  53. &log,
  54. ), edgeConn
  55. }
  56. func TestHTTP2ConfigurationSet(t *testing.T) {
  57. http2Conn, edgeConn := newTestHTTP2Connection()
  58. ctx, cancel := context.WithCancel(context.Background())
  59. var wg sync.WaitGroup
  60. wg.Add(1)
  61. go func() {
  62. defer wg.Done()
  63. _ = http2Conn.Serve(ctx)
  64. }()
  65. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  66. require.NoError(t, err)
  67. reqBody := []byte(`{
  68. "version": 2,
  69. "config": {"warp-routing": {"enabled": true}, "originRequest" : {"connectTimeout": 10}, "ingress" : [ {"hostname": "test", "service": "https://localhost:8000" } , {"service": "http_status:404"} ]}}
  70. `)
  71. reader := bytes.NewReader(reqBody)
  72. req, err := http.NewRequestWithContext(ctx, http.MethodPut, "http://localhost:8080/ok", reader)
  73. require.NoError(t, err)
  74. req.Header.Set(InternalUpgradeHeader, ConfigurationUpdate)
  75. resp, err := edgeHTTP2Conn.RoundTrip(req)
  76. require.NoError(t, err)
  77. require.Equal(t, http.StatusOK, resp.StatusCode)
  78. bdy, err := io.ReadAll(resp.Body)
  79. defer resp.Body.Close()
  80. require.NoError(t, err)
  81. assert.Equal(t, `{"lastAppliedVersion":2,"err":null}`, string(bdy))
  82. cancel()
  83. wg.Wait()
  84. }
  85. func TestServeHTTP(t *testing.T) {
  86. tests := []testRequest{
  87. {
  88. name: "ok",
  89. endpoint: "ok",
  90. expectedStatus: http.StatusOK,
  91. expectedBody: []byte(http.StatusText(http.StatusOK)),
  92. },
  93. {
  94. name: "large_file",
  95. endpoint: "large_file",
  96. expectedStatus: http.StatusOK,
  97. expectedBody: testLargeResp,
  98. },
  99. {
  100. name: "Bad request",
  101. endpoint: "400",
  102. expectedStatus: http.StatusBadRequest,
  103. expectedBody: []byte(http.StatusText(http.StatusBadRequest)),
  104. },
  105. {
  106. name: "Internal server error",
  107. endpoint: "500",
  108. expectedStatus: http.StatusInternalServerError,
  109. expectedBody: []byte(http.StatusText(http.StatusInternalServerError)),
  110. },
  111. {
  112. name: "Proxy error",
  113. endpoint: "error",
  114. expectedStatus: http.StatusBadGateway,
  115. expectedBody: nil,
  116. isProxyError: true,
  117. },
  118. }
  119. http2Conn, edgeConn := newTestHTTP2Connection()
  120. ctx, cancel := context.WithCancel(context.Background())
  121. var wg sync.WaitGroup
  122. wg.Add(1)
  123. go func() {
  124. defer wg.Done()
  125. _ = http2Conn.Serve(ctx)
  126. }()
  127. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  128. require.NoError(t, err)
  129. for _, test := range tests {
  130. endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
  131. req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
  132. require.NoError(t, err)
  133. resp, err := edgeHTTP2Conn.RoundTrip(req)
  134. require.NoError(t, err)
  135. require.Equal(t, test.expectedStatus, resp.StatusCode)
  136. if test.expectedBody != nil {
  137. respBody, err := io.ReadAll(resp.Body)
  138. require.NoError(t, err)
  139. require.Equal(t, test.expectedBody, respBody)
  140. }
  141. _ = resp.Body.Close()
  142. if test.isProxyError {
  143. require.Equal(t, responseMetaHeaderCfd, resp.Header.Get(ResponseMetaHeader))
  144. } else {
  145. require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
  146. }
  147. }
  148. cancel()
  149. wg.Wait()
  150. }
  151. type mockNamedTunnelRPCClient struct {
  152. shouldFail error
  153. registered chan struct{}
  154. unregistered chan struct{}
  155. }
  156. func (mc mockNamedTunnelRPCClient) SendLocalConfiguration(c context.Context, config []byte) error {
  157. return nil
  158. }
  159. func (mc mockNamedTunnelRPCClient) RegisterConnection(
  160. ctx context.Context,
  161. auth pogs.TunnelAuth,
  162. tunnelID uuid.UUID,
  163. options *pogs.ConnectionOptions,
  164. connIndex uint8,
  165. edgeAddress net.IP,
  166. ) (*pogs.ConnectionDetails, error) {
  167. if mc.shouldFail != nil {
  168. return nil, mc.shouldFail
  169. }
  170. close(mc.registered)
  171. return &pogs.ConnectionDetails{
  172. Location: "LIS",
  173. UUID: uuid.New(),
  174. TunnelIsRemotelyManaged: false,
  175. }, nil
  176. }
  177. func (mc mockNamedTunnelRPCClient) GracefulShutdown(ctx context.Context, gracePeriod time.Duration) error {
  178. close(mc.unregistered)
  179. return nil
  180. }
  181. func (mockNamedTunnelRPCClient) Close() {}
  182. type mockRPCClientFactory struct {
  183. shouldFail error
  184. registered chan struct{}
  185. unregistered chan struct{}
  186. }
  187. func (mf *mockRPCClientFactory) newMockRPCClient(context.Context, io.ReadWriteCloser, time.Duration) tunnelrpc.RegistrationClient {
  188. return &mockNamedTunnelRPCClient{
  189. shouldFail: mf.shouldFail,
  190. registered: mf.registered,
  191. unregistered: mf.unregistered,
  192. }
  193. }
  194. type wsRespWriter struct {
  195. *httptest.ResponseRecorder
  196. readPipe *io.PipeReader
  197. writePipe *io.PipeWriter
  198. closed bool
  199. panicked bool
  200. }
  201. func newWSRespWriter() *wsRespWriter {
  202. readPipe, writePipe := io.Pipe()
  203. return &wsRespWriter{
  204. httptest.NewRecorder(),
  205. readPipe,
  206. writePipe,
  207. false,
  208. false,
  209. }
  210. }
  211. type nowriter struct {
  212. io.Reader
  213. }
  214. func (nowriter) Write(_ []byte) (int, error) {
  215. return 0, fmt.Errorf("writer not implemented")
  216. }
  217. func (w *wsRespWriter) RespBody() io.ReadWriter {
  218. return nowriter{w.readPipe}
  219. }
  220. func (w *wsRespWriter) Write(data []byte) (n int, err error) {
  221. if w.closed {
  222. w.panicked = true
  223. return 0, errors.New("wsRespWriter panicked")
  224. }
  225. return w.writePipe.Write(data)
  226. }
  227. func (w *wsRespWriter) close() {
  228. w.closed = true
  229. }
  230. func TestServeWS(t *testing.T) {
  231. http2Conn, _ := newTestHTTP2Connection()
  232. ctx, cancel := context.WithCancel(context.Background())
  233. respWriter := newWSRespWriter()
  234. readPipe, writePipe := io.Pipe()
  235. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/ws/echo", readPipe)
  236. require.NoError(t, err)
  237. req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
  238. serveDone := make(chan struct{})
  239. go func() {
  240. defer close(serveDone)
  241. http2Conn.ServeHTTP(respWriter, req)
  242. respWriter.close()
  243. }()
  244. data := []byte("test websocket")
  245. err = wsutil.WriteClientBinary(writePipe, data)
  246. require.NoError(t, err)
  247. respBody, err := wsutil.ReadServerBinary(respWriter.RespBody())
  248. require.NoError(t, err)
  249. require.Equal(t, data, respBody, "expect %s, got %s", string(data), string(respBody))
  250. cancel()
  251. resp := respWriter.Result()
  252. defer resp.Body.Close()
  253. // http2RespWriter should rewrite status 101 to 200
  254. require.Equal(t, http.StatusOK, resp.StatusCode)
  255. require.Equal(t, responseMetaHeaderOrigin, resp.Header.Get(ResponseMetaHeader))
  256. <-serveDone
  257. require.False(t, respWriter.panicked)
  258. }
  259. // TestNoWriteAfterServeHTTPReturns is a regression test of https://jira.cfops.it/browse/TUN-5184
  260. // to make sure we don't write to the ResponseWriter after the ServeHTTP method returns
  261. func TestNoWriteAfterServeHTTPReturns(t *testing.T) {
  262. cfdHTTP2Conn, edgeTCPConn := newTestHTTP2Connection()
  263. ctx, cancel := context.WithCancel(context.Background())
  264. var wg sync.WaitGroup
  265. serverDone := make(chan struct{})
  266. go func() {
  267. defer close(serverDone)
  268. _ = cfdHTTP2Conn.Serve(ctx)
  269. }()
  270. edgeTransport := http2.Transport{}
  271. edgeHTTP2Conn, err := edgeTransport.NewClientConn(edgeTCPConn)
  272. require.NoError(t, err)
  273. message := []byte(t.Name())
  274. for i := 0; i < 100; i++ {
  275. wg.Add(1)
  276. go func() {
  277. defer wg.Done()
  278. readPipe, writePipe := io.Pipe()
  279. reqCtx, reqCancel := context.WithCancel(ctx)
  280. req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, "http://localhost:8080/ws/flaky", readPipe)
  281. assert.NoError(t, err)
  282. req.Header.Set(InternalUpgradeHeader, WebsocketUpgrade)
  283. resp, err := edgeHTTP2Conn.RoundTrip(req)
  284. assert.NoError(t, err)
  285. _ = resp.Body.Close()
  286. // http2RespWriter should rewrite status 101 to 200
  287. assert.Equal(t, http.StatusOK, resp.StatusCode)
  288. wg.Add(1)
  289. go func() {
  290. defer wg.Done()
  291. for {
  292. select {
  293. case <-reqCtx.Done():
  294. return
  295. default:
  296. }
  297. _ = wsutil.WriteClientBinary(writePipe, message)
  298. }
  299. }()
  300. time.Sleep(time.Millisecond * 100)
  301. reqCancel()
  302. }()
  303. }
  304. wg.Wait()
  305. cancel()
  306. <-serverDone
  307. }
  308. func TestServeControlStream(t *testing.T) {
  309. http2Conn, edgeConn := newTestHTTP2Connection()
  310. rpcClientFactory := mockRPCClientFactory{
  311. registered: make(chan struct{}),
  312. unregistered: make(chan struct{}),
  313. }
  314. obs := NewObserver(&log, &log)
  315. controlStream := NewControlStream(
  316. obs,
  317. mockConnectedFuse{},
  318. &TunnelProperties{},
  319. 1,
  320. nil,
  321. rpcClientFactory.newMockRPCClient,
  322. 1*time.Second,
  323. nil,
  324. 1*time.Second,
  325. HTTP2,
  326. )
  327. http2Conn.controlStreamHandler = controlStream
  328. ctx, cancel := context.WithCancel(context.Background())
  329. var wg sync.WaitGroup
  330. wg.Add(1)
  331. go func() {
  332. defer wg.Done()
  333. _ = http2Conn.Serve(ctx)
  334. }()
  335. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
  336. require.NoError(t, err)
  337. req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
  338. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  339. require.NoError(t, err)
  340. wg.Add(1)
  341. go func() {
  342. defer wg.Done()
  343. // nolint: bodyclose
  344. _, _ = edgeHTTP2Conn.RoundTrip(req)
  345. }()
  346. <-rpcClientFactory.registered
  347. cancel()
  348. <-rpcClientFactory.unregistered
  349. assert.False(t, http2Conn.stoppedGracefully)
  350. wg.Wait()
  351. }
  352. func TestFailRegistration(t *testing.T) {
  353. http2Conn, edgeConn := newTestHTTP2Connection()
  354. rpcClientFactory := mockRPCClientFactory{
  355. shouldFail: errDuplicationConnection,
  356. registered: make(chan struct{}),
  357. unregistered: make(chan struct{}),
  358. }
  359. obs := NewObserver(&log, &log)
  360. controlStream := NewControlStream(
  361. obs,
  362. mockConnectedFuse{},
  363. &TunnelProperties{},
  364. http2Conn.connIndex,
  365. nil,
  366. rpcClientFactory.newMockRPCClient,
  367. 1*time.Second,
  368. nil,
  369. 1*time.Second,
  370. HTTP2,
  371. )
  372. http2Conn.controlStreamHandler = controlStream
  373. ctx, cancel := context.WithCancel(context.Background())
  374. var wg sync.WaitGroup
  375. wg.Add(1)
  376. go func() {
  377. defer wg.Done()
  378. _ = http2Conn.Serve(ctx)
  379. }()
  380. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
  381. require.NoError(t, err)
  382. req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
  383. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  384. require.NoError(t, err)
  385. resp, err := edgeHTTP2Conn.RoundTrip(req)
  386. require.NoError(t, err)
  387. defer resp.Body.Close()
  388. require.Equal(t, http.StatusBadGateway, resp.StatusCode)
  389. require.Error(t, http2Conn.controlStreamErr)
  390. cancel()
  391. wg.Wait()
  392. }
  393. func TestGracefulShutdownHTTP2(t *testing.T) {
  394. http2Conn, edgeConn := newTestHTTP2Connection()
  395. rpcClientFactory := mockRPCClientFactory{
  396. registered: make(chan struct{}),
  397. unregistered: make(chan struct{}),
  398. }
  399. events := &eventCollectorSink{}
  400. shutdownC := make(chan struct{})
  401. obs := NewObserver(&log, &log)
  402. obs.RegisterSink(events)
  403. controlStream := NewControlStream(
  404. obs,
  405. mockConnectedFuse{},
  406. &TunnelProperties{},
  407. http2Conn.connIndex,
  408. nil,
  409. rpcClientFactory.newMockRPCClient,
  410. 1*time.Second,
  411. shutdownC,
  412. 1*time.Second,
  413. HTTP2,
  414. )
  415. http2Conn.controlStreamHandler = controlStream
  416. ctx, cancel := context.WithCancel(context.Background())
  417. var wg sync.WaitGroup
  418. wg.Add(1)
  419. go func() {
  420. defer wg.Done()
  421. _ = http2Conn.Serve(ctx)
  422. }()
  423. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080/", nil)
  424. require.NoError(t, err)
  425. req.Header.Set(InternalUpgradeHeader, ControlStreamUpgrade)
  426. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  427. require.NoError(t, err)
  428. wg.Add(1)
  429. go func() {
  430. defer wg.Done()
  431. // nolint: bodyclose
  432. _, _ = edgeHTTP2Conn.RoundTrip(req)
  433. }()
  434. select {
  435. case <-rpcClientFactory.registered:
  436. break // ok
  437. case <-time.Tick(time.Second):
  438. t.Fatal("timeout out waiting for registration")
  439. }
  440. // signal graceful shutdown
  441. close(shutdownC)
  442. select {
  443. case <-rpcClientFactory.unregistered:
  444. break // ok
  445. case <-time.Tick(time.Second):
  446. t.Fatal("timeout out waiting for unregistered signal")
  447. }
  448. assert.True(t, controlStream.IsStopped())
  449. cancel()
  450. wg.Wait()
  451. events.assertSawEvent(t, Event{
  452. Index: http2Conn.connIndex,
  453. EventType: Unregistering,
  454. })
  455. }
  456. func TestServeTCP_RateLimited(t *testing.T) {
  457. ctx, cancel := context.WithCancel(context.Background())
  458. http2Conn, edgeConn := newTestHTTP2Connection()
  459. var wg sync.WaitGroup
  460. wg.Add(1)
  461. go func() {
  462. defer wg.Done()
  463. _ = http2Conn.Serve(ctx)
  464. }()
  465. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  466. require.NoError(t, err)
  467. req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost:8080", nil)
  468. require.NoError(t, err)
  469. req.Header.Set(InternalTCPProxySrcHeader, "tcp")
  470. req.Header.Set(tracing.TracerContextName, "flow-rate-limited")
  471. resp, err := edgeHTTP2Conn.RoundTrip(req)
  472. require.NoError(t, err)
  473. defer resp.Body.Close()
  474. require.Equal(t, http.StatusBadGateway, resp.StatusCode)
  475. require.Equal(t, responseMetaHeaderCfdFlowRateLimited, resp.Header.Get(ResponseMetaHeader))
  476. cancel()
  477. wg.Wait()
  478. }
  479. func benchmarkServeHTTP(b *testing.B, test testRequest) {
  480. http2Conn, edgeConn := newTestHTTP2Connection()
  481. ctx, cancel := context.WithCancel(context.Background())
  482. var wg sync.WaitGroup
  483. wg.Add(1)
  484. go func() {
  485. defer wg.Done()
  486. _ = http2Conn.Serve(ctx)
  487. }()
  488. endpoint := fmt.Sprintf("http://localhost:8080/%s", test.endpoint)
  489. req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
  490. require.NoError(b, err)
  491. edgeHTTP2Conn, err := testTransport.NewClientConn(edgeConn)
  492. require.NoError(b, err)
  493. b.ResetTimer()
  494. for i := 0; i < b.N; i++ {
  495. b.StartTimer()
  496. resp, err := edgeHTTP2Conn.RoundTrip(req)
  497. b.StopTimer()
  498. require.NoError(b, err)
  499. require.Equal(b, test.expectedStatus, resp.StatusCode)
  500. if test.expectedBody != nil {
  501. respBody, err := io.ReadAll(resp.Body)
  502. require.NoError(b, err)
  503. require.Equal(b, test.expectedBody, respBody)
  504. }
  505. resp.Body.Close()
  506. }
  507. cancel()
  508. wg.Wait()
  509. }
  510. func BenchmarkServeHTTPSimple(b *testing.B) {
  511. test := testRequest{
  512. name: "ok",
  513. endpoint: "ok",
  514. expectedStatus: http.StatusOK,
  515. expectedBody: []byte(http.StatusText(http.StatusOK)),
  516. }
  517. benchmarkServeHTTP(b, test)
  518. }
  519. func BenchmarkServeHTTPLargeFile(b *testing.B) {
  520. test := testRequest{
  521. name: "large_file",
  522. endpoint: "large_file",
  523. expectedStatus: http.StatusOK,
  524. expectedBody: testLargeResp,
  525. }
  526. benchmarkServeHTTP(b, test)
  527. }