proxy_test.go 26 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020
  1. package proxy
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "flag"
  7. "fmt"
  8. "io"
  9. "net"
  10. "net/http"
  11. "net/http/httptest"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "time"
  16. "github.com/gobwas/ws/wsutil"
  17. gorillaWS "github.com/gorilla/websocket"
  18. "github.com/rs/zerolog"
  19. "github.com/stretchr/testify/assert"
  20. "github.com/stretchr/testify/require"
  21. "github.com/urfave/cli/v2"
  22. "go.uber.org/mock/gomock"
  23. "golang.org/x/sync/errgroup"
  24. "github.com/cloudflare/cloudflared/mocks"
  25. cfdflow "github.com/cloudflare/cloudflared/flow"
  26. "github.com/cloudflare/cloudflared/cfio"
  27. "github.com/cloudflare/cloudflared/config"
  28. "github.com/cloudflare/cloudflared/connection"
  29. "github.com/cloudflare/cloudflared/hello"
  30. "github.com/cloudflare/cloudflared/ingress"
  31. "github.com/cloudflare/cloudflared/logger"
  32. "github.com/cloudflare/cloudflared/tracing"
  33. "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  34. )
  35. var (
  36. testTags = []pogs.Tag{{Name: "Name", Value: "value"}}
  37. noWarpRouting = ingress.WarpRoutingConfig{}
  38. testWarpRouting = ingress.WarpRoutingConfig{
  39. ConnectTimeout: config.CustomDuration{Duration: time.Second},
  40. }
  41. )
  // mockHTTPRespWriter is the response writer the tests pass to the proxy for
  // plain HTTP requests. Status code, headers and body are recorded by the
  // embedded httptest.ResponseRecorder.
  type mockHTTPRespWriter struct {
  	*httptest.ResponseRecorder
  }

  // newMockHTTPRespWriter returns a writer backed by a fresh recorder.
  func newMockHTTPRespWriter() *mockHTTPRespWriter {
  	return &mockHTTPRespWriter{
  		ResponseRecorder: httptest.NewRecorder(),
  	}
  }

  // WriteResponse is a no-op that always reports success.
  func (w *mockHTTPRespWriter) WriteResponse() error {
  	return nil
  }

  // WriteRespHeaders records the response headers and then the status code.
  //
  // The headers are stored before WriteHeader is called: the recorder snapshots
  // its header map when the status is written, so headers added afterwards
  // would be visible through Header() but missing from Result().
  func (w *mockHTTPRespWriter) WriteRespHeaders(status int, header http.Header) error {
  	recorded := w.Header()
  	for name, values := range header {
  		recorded[name] = values
  	}
  	w.WriteHeader(status)
  	return nil
  }

  // AddTrailer is a no-op; trailers are not recorded by this mock.
  func (w *mockHTTPRespWriter) AddTrailer(trailerName, trailerValue string) {
  	// do nothing
  }

  // Read always fails: this mock has no request stream to read from.
  func (w *mockHTTPRespWriter) Read(data []byte) (int, error) {
  	return 0, fmt.Errorf("mockHTTPRespWriter doesn't implement io.Reader")
  }

  // Hijack is not supported by this mock and panics when called.
  func (w *mockHTTPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  	panic("Hijack not implemented")
  }
  69. type mockWSRespWriter struct {
  70. *mockHTTPRespWriter
  71. writeNotification chan []byte
  72. reader io.Reader
  73. }
  74. func newMockWSRespWriter(reader io.Reader) *mockWSRespWriter {
  75. return &mockWSRespWriter{
  76. newMockHTTPRespWriter(),
  77. make(chan []byte),
  78. reader,
  79. }
  80. }
  81. func (w *mockWSRespWriter) Write(data []byte) (int, error) {
  82. w.writeNotification <- data
  83. return len(data), nil
  84. }
  85. func (w *mockWSRespWriter) respBody() io.ReadWriter {
  86. data := <-w.writeNotification
  87. return bytes.NewBuffer(data)
  88. }
  89. func (w *mockWSRespWriter) Close() error {
  90. close(w.writeNotification)
  91. return nil
  92. }
  93. func (w *mockWSRespWriter) Read(data []byte) (int, error) {
  94. return w.reader.Read(data)
  95. }
  96. func (w *mockWSRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  97. panic("Hijack not implemented")
  98. }
  99. type mockSSERespWriter struct {
  100. *mockHTTPRespWriter
  101. writeNotification chan []byte
  102. }
  103. func newMockSSERespWriter() *mockSSERespWriter {
  104. return &mockSSERespWriter{
  105. newMockHTTPRespWriter(),
  106. make(chan []byte),
  107. }
  108. }
  109. func (w *mockSSERespWriter) Write(data []byte) (int, error) {
  110. newData := make([]byte, len(data))
  111. copy(newData, data)
  112. w.writeNotification <- newData
  113. return len(data), nil
  114. }
  115. func (w *mockSSERespWriter) WriteString(str string) (int, error) {
  116. return w.Write([]byte(str))
  117. }
  118. func (w *mockSSERespWriter) ReadBytes() []byte {
  119. return <-w.writeNotification
  120. }
  121. func TestProxySingleOrigin(t *testing.T) {
  122. log := zerolog.Nop()
  123. ctx, cancel := context.WithCancel(context.Background())
  124. flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
  125. flagSet.Bool("hello-world", true, "")
  126. cliCtx := cli.NewContext(cli.NewApp(), flagSet, nil)
  127. err := cliCtx.Set("hello-world", "true")
  128. require.NoError(t, err)
  129. ingressRule, err := ingress.ParseIngressFromConfigAndCLI(&config.Configuration{}, cliCtx, &log)
  130. require.NoError(t, err)
  131. require.NoError(t, ingressRule.StartOrigins(&log, ctx.Done()))
  132. proxy := NewOriginProxy(ingressRule, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
  133. t.Run("testProxyHTTP", testProxyHTTP(proxy))
  134. t.Run("testProxyWebsocket", testProxyWebsocket(proxy))
  135. t.Run("testProxySSE", testProxySSE(proxy))
  136. cancel()
  137. }
  138. func testProxyHTTP(proxy connection.OriginProxy) func(t *testing.T) {
  139. return func(t *testing.T) {
  140. responseWriter := newMockHTTPRespWriter()
  141. req, err := http.NewRequest(http.MethodGet, "http://localhost:8080", nil)
  142. require.NoError(t, err)
  143. log := zerolog.Nop()
  144. err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
  145. require.NoError(t, err)
  146. for _, tag := range testTags {
  147. assert.Equal(t, tag.Value, req.Header.Get(TagHeaderNamePrefix+tag.Name))
  148. }
  149. assert.Equal(t, http.StatusOK, responseWriter.Code)
  150. }
  151. }
  152. func testProxyWebsocket(proxy connection.OriginProxy) func(t *testing.T) {
  153. return func(t *testing.T) {
  154. // WSRoute is a websocket echo handler
  155. const testTimeout = 5 * time.Second * 1000
  156. ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
  157. defer cancel()
  158. readPipe, writePipe := io.Pipe()
  159. req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s", hello.WSRoute), readPipe)
  160. req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
  161. req.Header.Set("Connection", "Upgrade")
  162. req.Header.Set("Upgrade", "websocket")
  163. responseWriter := newMockWSRespWriter(nil)
  164. finished := make(chan struct{})
  165. errGroup, ctx := errgroup.WithContext(ctx)
  166. errGroup.Go(func() error {
  167. log := zerolog.Nop()
  168. err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), true)
  169. require.NoError(t, err)
  170. require.Equal(t, http.StatusSwitchingProtocols, responseWriter.Code)
  171. return nil
  172. })
  173. errGroup.Go(func() error {
  174. select {
  175. case <-finished:
  176. case <-ctx.Done():
  177. }
  178. if ctx.Err() == context.DeadlineExceeded {
  179. t.Errorf("Test timed out")
  180. readPipe.Close()
  181. writePipe.Close()
  182. responseWriter.Close()
  183. }
  184. return nil
  185. })
  186. msg := []byte("test websocket")
  187. err = wsutil.WriteClientText(writePipe, msg)
  188. require.NoError(t, err)
  189. // ReadServerText reads next data message from rw, considering that caller represents proxy side.
  190. returnedMsg, err := wsutil.ReadServerText(responseWriter.respBody())
  191. require.NoError(t, err)
  192. require.Equal(t, msg, returnedMsg)
  193. err = wsutil.WriteClientBinary(writePipe, msg)
  194. require.NoError(t, err)
  195. returnedMsg, err = wsutil.ReadServerBinary(responseWriter.respBody())
  196. require.NoError(t, err)
  197. require.Equal(t, msg, returnedMsg)
  198. _ = readPipe.Close()
  199. _ = writePipe.Close()
  200. _ = responseWriter.Close()
  201. close(finished)
  202. _ = errGroup.Wait()
  203. }
  204. }
  205. func testProxySSE(proxy connection.OriginProxy) func(t *testing.T) {
  206. return func(t *testing.T) {
  207. var (
  208. pushCount = 50
  209. pushFreq = time.Millisecond * 10
  210. )
  211. responseWriter := newMockSSERespWriter()
  212. ctx, cancel := context.WithCancel(context.Background())
  213. req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://localhost:8080%s?freq=%s", hello.SSERoute, pushFreq), nil)
  214. require.NoError(t, err)
  215. var wg sync.WaitGroup
  216. wg.Add(1)
  217. go func() {
  218. defer wg.Done()
  219. log := zerolog.Nop()
  220. err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
  221. require.Equal(t, "context canceled", err.Error())
  222. require.Equal(t, http.StatusOK, responseWriter.Code)
  223. }()
  224. for i := 0; i < pushCount; i++ {
  225. line := responseWriter.ReadBytes()
  226. expect := fmt.Sprintf("%d\n\n", i)
  227. require.Equal(t, []byte(expect), line, "Expect to read %v, got %v", expect, line)
  228. }
  229. cancel()
  230. wg.Wait()
  231. }
  232. }
  233. // Regression test to guarantee that we always write the contents downstream even if EOF is reached without
  234. // hitting the delimiter
  235. func TestProxySSEAllData(t *testing.T) {
  236. eyeballReader := io.NopCloser(strings.NewReader("data\r\r"))
  237. responseWriter := newMockSSERespWriter()
  238. // responseWriter uses an unbuffered channel, so we call in a different go-routine
  239. go func() {
  240. _, _ = cfio.Copy(responseWriter, eyeballReader)
  241. }()
  242. result := string(<-responseWriter.writeNotification)
  243. require.Equal(t, "data\r\r", result)
  244. }
  245. func TestProxyMultipleOrigins(t *testing.T) {
  246. api := httptest.NewServer(mockAPI{})
  247. defer api.Close()
  248. unvalidatedIngress := []config.UnvalidatedIngressRule{
  249. {
  250. Hostname: "api.example.com",
  251. Service: api.URL,
  252. },
  253. {
  254. Hostname: "hello.example.com",
  255. Service: "hello-world",
  256. },
  257. {
  258. Hostname: "health.example.com",
  259. Path: "/health",
  260. Service: "http_status:200",
  261. },
  262. {
  263. Hostname: "*",
  264. Service: "http_status:404",
  265. },
  266. }
  267. tests := []MultipleIngressTest{
  268. {
  269. url: "http://api.example.com",
  270. expectedStatus: http.StatusCreated,
  271. expectedBody: []byte("Created"),
  272. },
  273. {
  274. url: fmt.Sprintf("http://hello.example.com%s", hello.HealthRoute),
  275. expectedStatus: http.StatusOK,
  276. expectedBody: []byte("ok"),
  277. },
  278. {
  279. url: "http://health.example.com/health",
  280. expectedStatus: http.StatusOK,
  281. },
  282. {
  283. url: "http://health.example.com/",
  284. expectedStatus: http.StatusNotFound,
  285. },
  286. {
  287. url: "http://not-found.example.com",
  288. expectedStatus: http.StatusNotFound,
  289. },
  290. }
  291. runIngressTestScenarios(t, unvalidatedIngress, tests)
  292. }
  // MultipleIngressTest is one scenario for runIngressTestScenarios: a request
  // URL and the response expected for it.
  type MultipleIngressTest struct {
  	// url is the target of the GET request.
  	url string
  	// expectedStatus is the status code the response must carry.
  	expectedStatus int
  	// expectedBody is the exact response body; nil means the body must be empty.
  	expectedBody []byte
  }
  298. func runIngressTestScenarios(t *testing.T, unvalidatedIngress []config.UnvalidatedIngressRule, tests []MultipleIngressTest) {
  299. ingress, err := ingress.ParseIngress(&config.Configuration{
  300. TunnelID: t.Name(),
  301. Ingress: unvalidatedIngress,
  302. })
  303. require.NoError(t, err)
  304. log := zerolog.Nop()
  305. ctx, cancel := context.WithCancel(context.Background())
  306. require.NoError(t, ingress.StartOrigins(&log, ctx.Done()))
  307. proxy := NewOriginProxy(ingress, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
  308. for _, test := range tests {
  309. responseWriter := newMockHTTPRespWriter()
  310. req, err := http.NewRequest(http.MethodGet, test.url, nil)
  311. require.NoError(t, err)
  312. err = proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false)
  313. require.NoError(t, err)
  314. assert.Equal(t, test.expectedStatus, responseWriter.Code)
  315. if test.expectedBody != nil {
  316. assert.Equal(t, test.expectedBody, responseWriter.Body.Bytes())
  317. } else {
  318. assert.Equal(t, 0, responseWriter.Body.Len())
  319. }
  320. }
  321. cancel()
  322. }
  // mockAPI is an http.Handler that answers every request with 201 Created and
  // the body "Created".
  type mockAPI struct{}

  // ServeHTTP writes the fixed status and body; the request is not inspected.
  func (mockAPI) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
  	const body = "Created"
  	w.WriteHeader(http.StatusCreated)
  	_, _ = w.Write([]byte(body))
  }
  // errorOriginTransport is a round tripper that fails every request.
  type errorOriginTransport struct{}

  // RoundTrip never produces a response; it always returns the same error.
  func (errorOriginTransport) RoundTrip(_ *http.Request) (*http.Response, error) {
  	err := fmt.Errorf("Proxy error")
  	return nil, err
  }
  332. func TestProxyError(t *testing.T) {
  333. ing := ingress.Ingress{
  334. Rules: []ingress.Rule{
  335. {
  336. Hostname: "*",
  337. Path: nil,
  338. Service: ingress.MockOriginHTTPService{
  339. Transport: errorOriginTransport{},
  340. },
  341. },
  342. },
  343. }
  344. log := zerolog.Nop()
  345. proxy := NewOriginProxy(ing, noWarpRouting, testTags, cfdflow.NewLimiter(0), time.Duration(0), &log)
  346. responseWriter := newMockHTTPRespWriter()
  347. req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
  348. require.NoError(t, err)
  349. require.Error(t, proxy.ProxyHTTP(responseWriter, tracing.NewTracedHTTPRequest(req, 0, &log), false))
  350. }
  // replayer is a mutex-guarded bytes.Buffer. The tests use it to collect what
  // the proxy wrote so it can be compared with the expected message.
  type replayer struct {
  	sync.RWMutex
  	rw *bytes.Buffer
  }

  // Read drains up to len(p) bytes from the buffer.
  //
  // It takes the write lock: bytes.Buffer.Read advances the buffer's read
  // offset, so it is a mutation and must not run under a shared read lock.
  func (r *replayer) Read(p []byte) (int, error) {
  	r.Lock()
  	defer r.Unlock()
  	return r.rw.Read(p)
  }

  // Write appends p to the buffer.
  func (r *replayer) Write(p []byte) (int, error) {
  	r.Lock()
  	defer r.Unlock()
  	return r.rw.Write(p)
  }

  // String returns the unread part of the buffer as a string.
  func (r *replayer) String() string {
  	r.Lock()
  	defer r.Unlock()
  	return r.rw.String()
  }

  // Bytes returns a copy of the unread part of the buffer.
  //
  // A copy is returned because the slice handed out by bytes.Buffer.Bytes
  // aliases the buffer's storage and would be overwritten by later writes or
  // resets once the lock is released. A nil result stays nil and an empty one
  // stays non-nil, matching what the underlying buffer reports.
  func (r *replayer) Bytes() []byte {
  	r.Lock()
  	defer r.Unlock()
  	src := r.rw.Bytes()
  	if src == nil {
  		return nil
  	}
  	out := make([]byte, len(src))
  	copy(out, src)
  	return out
  }
  376. // TestConnections tests every possible permutation of connection protocols
  377. // proxied by cloudflared.
  378. //
  379. // WS - WS : When a websocket based ingress is configured on the origin and
  380. // the eyeball is also a websocket client streaming data.
  381. // TCP - TCP : When teamnet is enabled and an http or tcp service is running
  382. // on the origin.
  383. // TCP - WS: When teamnet is enabled and a websocket based service is running
  384. // on the origin.
  385. // WS - TCP: When a tcp based ingress is configured on the origin and the
  386. // eyeball sends tcp packets wrapped in websockets. (E.g: cloudflared access).
  387. func TestConnections(t *testing.T) {
  388. logger := logger.Create(nil)
  389. replayer := &replayer{rw: bytes.NewBuffer([]byte{})}
  390. type args struct {
  391. ingressServiceScheme string
  392. originService func(*testing.T, net.Listener)
  393. eyeballResponseWriter connection.ResponseWriter
  394. eyeballRequestBody io.ReadCloser
  395. // Can be set to nil to show warp routing is not enabled.
  396. warpRoutingService *ingress.WarpRoutingService
  397. // eyeball connection type.
  398. connectionType connection.Type
  399. // requestheaders to be sent in the call to proxy.Proxy
  400. requestHeaders http.Header
  401. // flowLimiterResponse is the response of the cfdflow.Limiter#Acquire method call
  402. flowLimiterResponse error
  403. }
  404. type want struct {
  405. message []byte
  406. headers http.Header
  407. err bool
  408. }
  409. var tests = []struct {
  410. name string
  411. args args
  412. want want
  413. }{
  414. {
  415. name: "ws-ws proxy",
  416. args: args{
  417. ingressServiceScheme: "ws://",
  418. originService: runEchoWSService,
  419. eyeballResponseWriter: newWSRespWriter(replayer),
  420. eyeballRequestBody: newWSRequestBody([]byte("test1")),
  421. connectionType: connection.TypeWebsocket,
  422. requestHeaders: map[string][]string{
  423. // Example key from https://tools.ietf.org/html/rfc6455#section-1.2
  424. "Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
  425. "Test-Cloudflared-Echo": {"Echo"},
  426. },
  427. },
  428. want: want{
  429. message: []byte("echo-test1"),
  430. headers: map[string][]string{
  431. "Connection": {"Upgrade"},
  432. "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
  433. "Upgrade": {"websocket"},
  434. "Test-Cloudflared-Echo": {"Echo"},
  435. },
  436. },
  437. },
  438. {
  439. name: "tcp-tcp proxy",
  440. args: args{
  441. ingressServiceScheme: "tcp://",
  442. originService: runEchoTCPService,
  443. eyeballResponseWriter: newTCPRespWriter(replayer),
  444. eyeballRequestBody: newTCPRequestBody([]byte("test2")),
  445. warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
  446. connectionType: connection.TypeTCP,
  447. requestHeaders: map[string][]string{
  448. "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
  449. },
  450. },
  451. want: want{
  452. message: []byte("echo-test2"),
  453. headers: http.Header{},
  454. },
  455. },
  456. {
  457. name: "tcp-ws proxy",
  458. args: args{
  459. ingressServiceScheme: "ws://",
  460. originService: runEchoWSService,
  461. // eyeballResponseWriter gets set after roundtrip dial.
  462. eyeballRequestBody: newPipedWSRequestBody([]byte("test3")),
  463. warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
  464. requestHeaders: map[string][]string{
  465. "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
  466. },
  467. connectionType: connection.TypeTCP,
  468. },
  469. want: want{
  470. message: []byte("echo-test3"),
  471. // We expect no headers here because they are sent back via
  472. // the stream.
  473. headers: http.Header{},
  474. },
  475. },
  476. {
  477. name: "ws-tcp proxy",
  478. args: args{
  479. ingressServiceScheme: "tcp://",
  480. originService: runEchoTCPService,
  481. eyeballResponseWriter: newWSRespWriter(replayer),
  482. eyeballRequestBody: newWSRequestBody([]byte("test4")),
  483. connectionType: connection.TypeWebsocket,
  484. requestHeaders: map[string][]string{
  485. // Example key from https://tools.ietf.org/html/rfc6455#section-1.2
  486. "Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
  487. },
  488. },
  489. want: want{
  490. message: []byte("echo-test4"),
  491. headers: map[string][]string{
  492. "Connection": {"Upgrade"},
  493. "Sec-Websocket-Accept": {"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="},
  494. "Upgrade": {"websocket"},
  495. },
  496. },
  497. },
  498. {
  499. // Send (unexpected) HTTP when origin expects WS (to unwrap for raw TCP)
  500. name: "http-(ws)tcp proxy",
  501. args: args{
  502. ingressServiceScheme: "tcp://",
  503. originService: runEchoTCPService,
  504. eyeballResponseWriter: newMockHTTPRespWriter(),
  505. eyeballRequestBody: http.NoBody,
  506. connectionType: connection.TypeHTTP,
  507. requestHeaders: map[string][]string{
  508. "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
  509. },
  510. },
  511. want: want{
  512. message: []byte{},
  513. headers: map[string][]string{},
  514. },
  515. },
  516. {
  517. name: "tcp-tcp proxy without warpRoutingService enabled",
  518. args: args{
  519. ingressServiceScheme: "tcp://",
  520. originService: runEchoTCPService,
  521. eyeballResponseWriter: newTCPRespWriter(replayer),
  522. eyeballRequestBody: newTCPRequestBody([]byte("test2")),
  523. connectionType: connection.TypeTCP,
  524. requestHeaders: map[string][]string{
  525. "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
  526. },
  527. },
  528. want: want{
  529. message: []byte{},
  530. err: true,
  531. },
  532. },
  533. {
  534. name: "ws-ws proxy when origin is different",
  535. args: args{
  536. ingressServiceScheme: "ws://",
  537. originService: runEchoWSService,
  538. eyeballResponseWriter: newWSRespWriter(replayer),
  539. eyeballRequestBody: newWSRequestBody([]byte("test1")),
  540. connectionType: connection.TypeWebsocket,
  541. requestHeaders: map[string][]string{
  542. // Example key from https://tools.ietf.org/html/rfc6455#section-1.2
  543. "Sec-Websocket-Key": {"dGhlIHNhbXBsZSBub25jZQ=="},
  544. "Origin": {"Different origin"},
  545. },
  546. },
  547. want: want{
  548. message: []byte("Forbidden\n"),
  549. err: false,
  550. headers: map[string][]string{
  551. "Content-Length": {"10"},
  552. "Content-Type": {"text/plain; charset=utf-8"},
  553. "Sec-Websocket-Version": {"13"},
  554. "X-Content-Type-Options": {"nosniff"},
  555. },
  556. },
  557. },
  558. {
  559. name: "tcp-* proxy when origin service has already closed the connection/ is no longer running",
  560. args: args{
  561. ingressServiceScheme: "tcp://",
  562. originService: func(t *testing.T, ln net.Listener) {
  563. // closing the listener created by the test.
  564. ln.Close()
  565. },
  566. eyeballResponseWriter: newTCPRespWriter(replayer),
  567. eyeballRequestBody: newTCPRequestBody([]byte("test2")),
  568. connectionType: connection.TypeTCP,
  569. requestHeaders: map[string][]string{
  570. "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
  571. },
  572. },
  573. want: want{
  574. message: []byte{},
  575. err: true,
  576. },
  577. },
  578. {
  579. name: "tcp-* proxy rate limited flow",
  580. args: args{
  581. ingressServiceScheme: "tcp://",
  582. originService: runEchoTCPService,
  583. eyeballResponseWriter: newTCPRespWriter(replayer),
  584. eyeballRequestBody: newTCPRequestBody([]byte("rate-limited")),
  585. warpRoutingService: ingress.NewWarpRoutingService(testWarpRouting, time.Duration(0)),
  586. connectionType: connection.TypeTCP,
  587. requestHeaders: map[string][]string{
  588. "Cf-Cloudflared-Proxy-Src": {"non-blank-value"},
  589. },
  590. flowLimiterResponse: cfdflow.ErrTooManyActiveFlows,
  591. },
  592. want: want{
  593. message: []byte{},
  594. err: true,
  595. },
  596. },
  597. }
  598. for _, test := range tests {
  599. t.Run(test.name, func(t *testing.T) {
  600. ctx, cancel := context.WithCancel(context.Background())
  601. ln, err := net.Listen("tcp", "127.0.0.1:0")
  602. require.NoError(t, err)
  603. // Starts origin service
  604. test.args.originService(t, ln)
  605. ingressRule := createSingleIngressConfig(t, test.args.ingressServiceScheme+ln.Addr().String())
  606. _ = ingressRule.StartOrigins(logger, ctx.Done())
  607. // Mock flow limiter
  608. ctrl := gomock.NewController(t)
  609. defer ctrl.Finish()
  610. flowLimiter := mocks.NewMockLimiter(ctrl)
  611. flowLimiter.EXPECT().Acquire("tcp").AnyTimes().Return(test.args.flowLimiterResponse)
  612. flowLimiter.EXPECT().Release().AnyTimes()
  613. proxy := NewOriginProxy(ingressRule, testWarpRouting, testTags, flowLimiter, time.Duration(0), logger)
  614. proxy.warpRouting = test.args.warpRoutingService
  615. dest := ln.Addr().String()
  616. req, err := http.NewRequest(
  617. http.MethodGet,
  618. test.args.ingressServiceScheme+ln.Addr().String(),
  619. test.args.eyeballRequestBody,
  620. )
  621. require.NoError(t, err)
  622. req.Header = test.args.requestHeaders
  623. respWriter := test.args.eyeballResponseWriter
  624. if pipedReqBody, ok := test.args.eyeballRequestBody.(*pipedRequestBody); ok {
  625. respWriter = newTCPRespWriter(pipedReqBody.pipedConn)
  626. go func() {
  627. resp := pipedReqBody.roundtrip(test.args.ingressServiceScheme + ln.Addr().String())
  628. _, _ = replayer.Write(resp)
  629. }()
  630. }
  631. if test.args.connectionType == connection.TypeTCP {
  632. rwa := connection.NewHTTPResponseReadWriterAcker(respWriter, respWriter.(http.Flusher), req)
  633. err = proxy.ProxyTCP(ctx, rwa, &connection.TCPRequest{Dest: dest})
  634. } else {
  635. log := zerolog.Nop()
  636. err = proxy.ProxyHTTP(respWriter, tracing.NewTracedHTTPRequest(req, 0, &log), test.args.connectionType == connection.TypeWebsocket)
  637. }
  638. cancel()
  639. require.Equal(t, test.want.err, err != nil)
  640. require.Equal(t, test.want.message, replayer.Bytes())
  641. require.Equal(t, test.want.headers, respWriter.Header())
  642. replayer.rw.Reset()
  643. })
  644. }
  645. }
  646. type requestBody struct {
  647. pw *io.PipeWriter
  648. pr *io.PipeReader
  649. }
  650. func newWSRequestBody(data []byte) *requestBody {
  651. pr, pw := io.Pipe()
  652. go func() {
  653. _ = wsutil.WriteClientBinary(pw, data)
  654. }()
  655. return &requestBody{
  656. pr: pr,
  657. pw: pw,
  658. }
  659. }
  660. func newTCPRequestBody(data []byte) *requestBody {
  661. pr, pw := io.Pipe()
  662. go func() {
  663. _, _ = pw.Write(data)
  664. }()
  665. return &requestBody{
  666. pr: pr,
  667. pw: pw,
  668. }
  669. }
  670. func (r *requestBody) Read(p []byte) (n int, err error) {
  671. return r.pr.Read(p)
  672. }
  673. func (r *requestBody) Close() error {
  674. _ = r.pw.Close()
  675. _ = r.pr.Close()
  676. return nil
  677. }
  678. type pipedRequestBody struct {
  679. dialer gorillaWS.Dialer
  680. pipedConn net.Conn
  681. wsConn net.Conn
  682. messageToWrite []byte
  683. }
  684. func newPipedWSRequestBody(data []byte) *pipedRequestBody {
  685. conn1, conn2 := net.Pipe()
  686. dialer := gorillaWS.Dialer{
  687. NetDial: func(network, addr string) (net.Conn, error) {
  688. return conn2, nil
  689. },
  690. }
  691. return &pipedRequestBody{
  692. dialer: dialer,
  693. pipedConn: conn1,
  694. wsConn: conn2,
  695. messageToWrite: data,
  696. }
  697. }
  698. func (p *pipedRequestBody) roundtrip(addr string) []byte {
  699. header := http.Header{}
  700. conn, resp, err := p.dialer.Dial(addr, header)
  701. if err != nil {
  702. panic(err)
  703. }
  704. defer conn.Close()
  705. defer resp.Body.Close()
  706. if resp.StatusCode != http.StatusSwitchingProtocols {
  707. panic(fmt.Errorf("resp returned status code: %d", resp.StatusCode))
  708. }
  709. err = conn.WriteMessage(gorillaWS.TextMessage, p.messageToWrite)
  710. if err != nil {
  711. panic(err)
  712. }
  713. _, data, err := conn.ReadMessage()
  714. if err != nil {
  715. panic(err)
  716. }
  717. return data
  718. }
  719. func (p *pipedRequestBody) Read(data []byte) (n int, err error) {
  720. return p.pipedConn.Read(data)
  721. }
  722. func (p *pipedRequestBody) Close() error {
  723. return nil
  724. }
  725. type wsRespWriter struct {
  726. w io.Writer
  727. responseHeaders http.Header
  728. code int
  729. }
  730. // newWSRespWriter uses wsutil.WriteClientText to generate websocket frames.
  731. // and wsutil.ReadClientText to translate frames from server to byte data.
  732. // In essence, this acts as a wsClient.
  733. func newWSRespWriter(w io.Writer) *wsRespWriter {
  734. return &wsRespWriter{
  735. w: w,
  736. }
  737. }
  738. // Write is written to by ingress.Stream and serves as the output to the client.
  739. func (w *wsRespWriter) Write(p []byte) (int, error) {
  740. returnedMsg, err := wsutil.ReadServerBinary(bytes.NewBuffer(p))
  741. if err != nil {
  742. // The data was not returned by a websocket connection.
  743. if err != io.ErrUnexpectedEOF {
  744. return w.w.Write(p)
  745. }
  746. }
  747. return w.w.Write(returnedMsg)
  748. }
  749. func (w *wsRespWriter) WriteRespHeaders(status int, header http.Header) error {
  750. w.responseHeaders = header
  751. w.code = status
  752. return nil
  753. }
  754. func (w *wsRespWriter) Flush() {}
  755. func (w *wsRespWriter) AddTrailer(trailerName, trailerValue string) {
  756. // do nothing
  757. }
  758. // respHeaders is a test function to read respHeaders
  759. func (w *wsRespWriter) Header() http.Header {
  760. // Removing indeterminstic header because it cannot be asserted.
  761. w.responseHeaders.Del("Date")
  762. return w.responseHeaders
  763. }
  764. func (w *wsRespWriter) WriteHeader(status int) {
  765. // unused
  766. }
  767. func (m *wsRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  768. panic("Hijack not implemented")
  769. }
  // mockTCPRespWriter is a response writer that forwards every write to an
  // underlying writer and remembers the status code and headers passed to
  // WriteRespHeaders.
  type mockTCPRespWriter struct {
  	w               io.Writer
  	responseHeaders http.Header
  	code            int
  }

  // newTCPRespWriter returns a mockTCPRespWriter that forwards writes to w.
  func newTCPRespWriter(w io.Writer) *mockTCPRespWriter {
  	writer := new(mockTCPRespWriter)
  	writer.w = w
  	return writer
  }

  // Read reports len(p) bytes read without storing anything in p.
  func (m *mockTCPRespWriter) Read(p []byte) (n int, err error) {
  	n = len(p)
  	return n, nil
  }

  // Write forwards p to the underlying writer.
  func (m *mockTCPRespWriter) Write(p []byte) (n int, err error) {
  	n, err = m.w.Write(p)
  	return n, err
  }

  // Flush is a no-op.
  func (m *mockTCPRespWriter) Flush() {}

  // AddTrailer is a no-op.
  func (m *mockTCPRespWriter) AddTrailer(trailerName, trailerValue string) {
  	// do nothing
  }

  // WriteRespHeaders stores the headers and status code for later inspection.
  func (m *mockTCPRespWriter) WriteRespHeaders(status int, header http.Header) error {
  	m.code, m.responseHeaders = status, header
  	return nil
  }

  // Header is a test helper that returns the headers recorded by
  // WriteRespHeaders, or nil if it has not been called.
  func (m *mockTCPRespWriter) Header() http.Header {
  	return m.responseHeaders
  }

  // WriteHeader does nothing; the status is only recorded by WriteRespHeaders.
  func (m *mockTCPRespWriter) WriteHeader(status int) {
  	// do nothing
  }

  // Hijack is not supported by this mock and panics when called.
  func (m *mockTCPRespWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  	panic("Hijack not implemented")
  }
  805. func createSingleIngressConfig(t *testing.T, service string) ingress.Ingress {
  806. ingressConfig := &config.Configuration{
  807. Ingress: []config.UnvalidatedIngressRule{
  808. {
  809. Hostname: "*",
  810. Service: service,
  811. },
  812. },
  813. }
  814. ingressRule, err := ingress.ParseIngress(ingressConfig)
  815. require.NoError(t, err)
  816. return ingressRule
  817. }
  // runEchoTCPService starts a goroutine that serves exactly one connection
  // accepted from l: it reads once (up to 1024 bytes), writes the data back
  // prefixed with "echo-", and closes the connection.
  //
  // Closing after the single echo is deliberate; callers rely on the origin
  // hanging up to end the proxied stream.
  func runEchoTCPService(t *testing.T, l net.Listener) {
  	go func() {
  		conn, err := l.Accept()
  		if err != nil {
  			// The listener was closed before a client connected. Return
  			// quietly instead of panicking, which would take down the whole
  			// test binary from a background goroutine.
  			return
  		}
  		defer conn.Close()

  		buf := make([]byte, 1024)
  		size, err := conn.Read(buf)
  		if err != nil {
  			// EOF just means the peer hung up without sending anything.
  			if err != io.EOF {
  				t.Log(err)
  			}
  			return
  		}

  		reply := append([]byte("echo-"), buf[:size]...)
  		if _, err := conn.Write(reply); err != nil {
  			t.Log(err)
  		}
  	}()
  }
  843. func runEchoWSService(t *testing.T, l net.Listener) {
  844. var upgrader = gorillaWS.Upgrader{
  845. ReadBufferSize: 10,
  846. WriteBufferSize: 10,
  847. }
  848. var ws = func(w http.ResponseWriter, r *http.Request) {
  849. header := make(http.Header)
  850. for k, vs := range r.Header {
  851. if k == "Test-Cloudflared-Echo" {
  852. header[k] = vs
  853. }
  854. }
  855. conn, err := upgrader.Upgrade(w, r, header)
  856. if err != nil {
  857. t.Log(err)
  858. return
  859. }
  860. defer conn.Close()
  861. for {
  862. messageType, p, err := conn.ReadMessage()
  863. if err != nil {
  864. return
  865. }
  866. data := []byte("echo-")
  867. data = append(data, p...)
  868. if err := conn.WriteMessage(messageType, data); err != nil {
  869. return
  870. }
  871. }
  872. }
  873. // nolint: gosec
  874. server := http.Server{
  875. Handler: http.HandlerFunc(ws),
  876. }
  877. go func() {
  878. err := server.Serve(l)
  879. if err != nil {
  880. panic(err)
  881. }
  882. }()
  883. }