connection_test.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259
  1. package connection
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "fmt"
  6. "io"
  7. "math/big"
  8. "net/http"
  9. "testing"
  10. "time"
  11. pkgerrors "github.com/pkg/errors"
  12. "github.com/rs/zerolog"
  13. "github.com/stretchr/testify/require"
  14. cfdflow "github.com/cloudflare/cloudflared/flow"
  15. "github.com/cloudflare/cloudflared/stream"
  16. "github.com/cloudflare/cloudflared/tracing"
  17. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  18. "github.com/cloudflare/cloudflared/websocket"
  19. )
  20. const (
  21. largeFileSize = 2 * 1024 * 1024
  22. testGracePeriod = time.Millisecond * 100
  23. )
  24. var (
  25. testOrchestrator = &mockOrchestrator{
  26. originProxy: &mockOriginProxy{},
  27. }
  28. log = zerolog.Nop()
  29. testLargeResp = make([]byte, largeFileSize)
  30. )
  31. var _ ReadWriteAcker = (*HTTPResponseReadWriteAcker)(nil)
  32. type testRequest struct {
  33. name string
  34. endpoint string
  35. expectedStatus int
  36. expectedBody []byte
  37. isProxyError bool
  38. }
  39. type mockOrchestrator struct {
  40. originProxy OriginProxy
  41. }
  42. func (mcr *mockOrchestrator) GetConfigJSON() ([]byte, error) {
  43. return nil, fmt.Errorf("not implemented")
  44. }
  45. func (*mockOrchestrator) UpdateConfig(version int32, config []byte) *tunnelpogs.UpdateConfigurationResponse {
  46. return &tunnelpogs.UpdateConfigurationResponse{
  47. LastAppliedVersion: version,
  48. }
  49. }
  50. func (mcr *mockOrchestrator) GetOriginProxy() (OriginProxy, error) {
  51. return mcr.originProxy, nil
  52. }
  53. func (mcr *mockOrchestrator) WarpRoutingEnabled() (enabled bool) {
  54. return true
  55. }
  56. type mockOriginProxy struct{}
  57. func (moc *mockOriginProxy) ProxyHTTP(
  58. w ResponseWriter,
  59. tr *tracing.TracedHTTPRequest,
  60. isWebsocket bool,
  61. ) error {
  62. req := tr.Request
  63. if isWebsocket {
  64. switch req.URL.Path {
  65. case "/ws/echo":
  66. return wsEchoEndpoint(w, req)
  67. case "/ws/flaky":
  68. return wsFlakyEndpoint(w, req)
  69. default:
  70. originRespEndpoint(w, http.StatusNotFound, []byte("ws endpoint not found"))
  71. return fmt.Errorf("unknown websocket endpoint %s", req.URL.Path)
  72. }
  73. }
  74. switch req.URL.Path {
  75. case "/ok":
  76. originRespEndpoint(w, http.StatusOK, []byte(http.StatusText(http.StatusOK)))
  77. case "/large_file":
  78. originRespEndpoint(w, http.StatusOK, testLargeResp)
  79. case "/400":
  80. originRespEndpoint(w, http.StatusBadRequest, []byte(http.StatusText(http.StatusBadRequest)))
  81. case "/500":
  82. originRespEndpoint(w, http.StatusInternalServerError, []byte(http.StatusText(http.StatusInternalServerError)))
  83. case "/error":
  84. return fmt.Errorf("Failed to proxy to origin")
  85. default:
  86. originRespEndpoint(w, http.StatusNotFound, []byte("page not found"))
  87. }
  88. return nil
  89. }
  90. func (moc *mockOriginProxy) ProxyTCP(
  91. ctx context.Context,
  92. rwa ReadWriteAcker,
  93. r *TCPRequest,
  94. ) error {
  95. if r.CfTraceID == "flow-rate-limited" {
  96. return pkgerrors.Wrap(cfdflow.ErrTooManyActiveFlows, "tcp flow rate limited")
  97. }
  98. return nil
  99. }
  100. type echoPipe struct {
  101. reader *io.PipeReader
  102. writer *io.PipeWriter
  103. }
  104. func (ep *echoPipe) Read(p []byte) (int, error) {
  105. return ep.reader.Read(p)
  106. }
  107. func (ep *echoPipe) Write(p []byte) (int, error) {
  108. return ep.writer.Write(p)
  109. }
  110. // A mock origin that echos data by streaming like a tcpOverWSConnection
  111. // https://github.com/cloudflare/cloudflared/blob/master/ingress/origin_connection.go
  112. func wsEchoEndpoint(w ResponseWriter, r *http.Request) error {
  113. resp := &http.Response{
  114. StatusCode: http.StatusSwitchingProtocols,
  115. }
  116. if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
  117. return err
  118. }
  119. wsCtx, cancel := context.WithCancel(r.Context())
  120. readPipe, writePipe := io.Pipe()
  121. wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
  122. go func() {
  123. select {
  124. case <-wsCtx.Done():
  125. case <-r.Context().Done():
  126. }
  127. readPipe.Close()
  128. writePipe.Close()
  129. }()
  130. originConn := &echoPipe{reader: readPipe, writer: writePipe}
  131. stream.Pipe(wsConn, originConn, &log)
  132. cancel()
  133. wsConn.Close()
  134. return nil
  135. }
  136. type flakyConn struct {
  137. closeAt time.Time
  138. }
  139. func (fc *flakyConn) Read(p []byte) (int, error) {
  140. if time.Now().After(fc.closeAt) {
  141. return 0, io.EOF
  142. }
  143. n := copy(p, "Read from flaky connection")
  144. return n, nil
  145. }
  146. func (fc *flakyConn) Write(p []byte) (int, error) {
  147. if time.Now().After(fc.closeAt) {
  148. return 0, fmt.Errorf("flaky connection closed")
  149. }
  150. return len(p), nil
  151. }
  152. func wsFlakyEndpoint(w ResponseWriter, r *http.Request) error {
  153. resp := &http.Response{
  154. StatusCode: http.StatusSwitchingProtocols,
  155. }
  156. if err := w.WriteRespHeaders(resp.StatusCode, resp.Header); err != nil {
  157. return err
  158. }
  159. wsCtx, cancel := context.WithCancel(r.Context())
  160. wsConn := websocket.NewConn(wsCtx, NewHTTPResponseReadWriterAcker(w, w.(http.Flusher), r), &log)
  161. rInt, _ := rand.Int(rand.Reader, big.NewInt(50))
  162. closedAfter := time.Millisecond * time.Duration(rInt.Int64())
  163. originConn := &flakyConn{closeAt: time.Now().Add(closedAfter)}
  164. stream.Pipe(wsConn, originConn, &log)
  165. cancel()
  166. wsConn.Close()
  167. return nil
  168. }
  169. func originRespEndpoint(w ResponseWriter, status int, data []byte) {
  170. resp := &http.Response{
  171. StatusCode: status,
  172. }
  173. _ = w.WriteRespHeaders(resp.StatusCode, resp.Header)
  174. _, _ = w.Write(data)
  175. }
  176. type mockConnectedFuse struct{}
  177. func (mcf mockConnectedFuse) Connected() {}
  178. func (mcf mockConnectedFuse) IsConnected() bool {
  179. return true
  180. }
  181. func TestShouldFlushHeaders(t *testing.T) {
  182. tests := []struct {
  183. headers map[string]string
  184. shouldFlush bool
  185. }{
  186. {
  187. headers: map[string]string{contentTypeHeader: "application/json", contentLengthHeader: "1"},
  188. shouldFlush: false,
  189. },
  190. {
  191. headers: map[string]string{contentTypeHeader: "text/html", contentLengthHeader: "1"},
  192. shouldFlush: false,
  193. },
  194. {
  195. headers: map[string]string{contentTypeHeader: "text/event-stream", contentLengthHeader: "1"},
  196. shouldFlush: true,
  197. },
  198. {
  199. headers: map[string]string{contentTypeHeader: "application/grpc", contentLengthHeader: "1"},
  200. shouldFlush: true,
  201. },
  202. {
  203. headers: map[string]string{contentTypeHeader: "application/x-ndjson", contentLengthHeader: "1"},
  204. shouldFlush: true,
  205. },
  206. {
  207. headers: map[string]string{contentTypeHeader: "application/json"},
  208. shouldFlush: true,
  209. },
  210. {
  211. headers: map[string]string{contentTypeHeader: "application/json", contentLengthHeader: "-1", transferEncodingHeader: "chunked"},
  212. shouldFlush: true,
  213. },
  214. }
  215. for _, test := range tests {
  216. headers := http.Header{}
  217. for k, v := range test.headers {
  218. headers.Add(k, v)
  219. }
  220. require.Equal(t, test.shouldFlush, shouldFlush(headers))
  221. }
  222. }