origin_connection_test.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. package ingress
import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"
	"time"

	"github.com/cloudflare/cloudflared/logger"
	"github.com/cloudflare/cloudflared/socks"
	"github.com/gobwas/ws/wsutil"
	gorillaWS "github.com/gorilla/websocket"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"golang.org/x/net/proxy"
	"golang.org/x/sync/errgroup"
)
  23. const (
  24. testStreamTimeout = time.Second * 3
  25. echoHeaderName = "Test-Cloudflared-Echo"
  26. )
  27. var (
  28. testLogger = logger.Create(nil)
  29. testMessage = []byte("TestStreamOriginConnection")
  30. testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
  31. )
  32. func TestStreamTCPConnection(t *testing.T) {
  33. cfdConn, originConn := net.Pipe()
  34. tcpConn := tcpConnection{
  35. conn: cfdConn,
  36. }
  37. eyeballConn, edgeConn := net.Pipe()
  38. ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
  39. defer cancel()
  40. errGroup, ctx := errgroup.WithContext(ctx)
  41. errGroup.Go(func() error {
  42. _, err := eyeballConn.Write(testMessage)
  43. readBuffer := make([]byte, len(testResponse))
  44. _, err = eyeballConn.Read(readBuffer)
  45. require.NoError(t, err)
  46. require.Equal(t, testResponse, readBuffer)
  47. return nil
  48. })
  49. errGroup.Go(func() error {
  50. echoTCPOrigin(t, originConn)
  51. originConn.Close()
  52. return nil
  53. })
  54. tcpConn.Stream(ctx, edgeConn, testLogger)
  55. require.NoError(t, errGroup.Wait())
  56. }
  57. func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
  58. cfdConn, originConn := net.Pipe()
  59. tcpOverWSConn := tcpOverWSConnection{
  60. conn: cfdConn,
  61. streamHandler: DefaultStreamHandler,
  62. }
  63. eyeballConn, edgeConn := net.Pipe()
  64. ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
  65. defer cancel()
  66. errGroup, ctx := errgroup.WithContext(ctx)
  67. errGroup.Go(func() error {
  68. echoWSEyeball(t, eyeballConn)
  69. return nil
  70. })
  71. errGroup.Go(func() error {
  72. echoTCPOrigin(t, originConn)
  73. originConn.Close()
  74. return nil
  75. })
  76. tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
  77. require.NoError(t, errGroup.Wait())
  78. }
  79. // TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
  80. // Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
  81. // wraps SOCKS5 traffic in websocket
  82. // Origin side runs a tcpOverWSConnection with socks.StreamHandler
  83. func TestSocksStreamWSOverTCPConnection(t *testing.T) {
  84. var (
  85. sendMessage = t.Name()
  86. echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
  87. echoMessage = fmt.Sprintf("echo-%s", sendMessage)
  88. echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
  89. )
  90. statusCodes := []int{
  91. http.StatusOK,
  92. http.StatusTemporaryRedirect,
  93. http.StatusBadRequest,
  94. http.StatusInternalServerError,
  95. }
  96. for _, status := range statusCodes {
  97. handler := func(w http.ResponseWriter, r *http.Request) {
  98. body, err := ioutil.ReadAll(r.Body)
  99. require.NoError(t, err)
  100. require.Equal(t, []byte(sendMessage), body)
  101. require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
  102. w.Header().Set(echoHeaderName, echoHeaderReturnValue)
  103. w.WriteHeader(status)
  104. w.Write([]byte(echoMessage))
  105. }
  106. origin := httptest.NewServer(http.HandlerFunc(handler))
  107. defer origin.Close()
  108. originURL, err := url.Parse(origin.URL)
  109. require.NoError(t, err)
  110. originConn, err := net.Dial("tcp", originURL.Host)
  111. require.NoError(t, err)
  112. tcpOverWSConn := tcpOverWSConnection{
  113. conn: originConn,
  114. streamHandler: socks.StreamHandler,
  115. }
  116. wsForwarderOutConn, edgeConn := net.Pipe()
  117. ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
  118. defer cancel()
  119. errGroup, ctx := errgroup.WithContext(ctx)
  120. errGroup.Go(func() error {
  121. tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
  122. return nil
  123. })
  124. wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
  125. require.NoError(t, err)
  126. errGroup.Go(func() error {
  127. wsForwarderInConn, err := wsForwarderListener.Accept()
  128. require.NoError(t, err)
  129. defer wsForwarderInConn.Close()
  130. Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
  131. return nil
  132. })
  133. eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
  134. require.NoError(t, err)
  135. transport := &http.Transport{
  136. Dial: eyeballDialer.Dial,
  137. }
  138. // Request URL doesn't matter because the transport is using eyeballDialer to connectq
  139. req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
  140. assert.NoError(t, err)
  141. req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
  142. resp, err := transport.RoundTrip(req)
  143. assert.NoError(t, err)
  144. assert.Equal(t, status, resp.StatusCode)
  145. require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
  146. body, err := ioutil.ReadAll(resp.Body)
  147. require.NoError(t, err)
  148. require.Equal(t, []byte(echoMessage), body)
  149. wsForwarderOutConn.Close()
  150. edgeConn.Close()
  151. tcpOverWSConn.Close()
  152. require.NoError(t, errGroup.Wait())
  153. }
  154. }
  155. func TestStreamWSConnection(t *testing.T) {
  156. eyeballConn, edgeConn := net.Pipe()
  157. origin := echoWSOrigin(t)
  158. defer origin.Close()
  159. req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
  160. require.NoError(t, err)
  161. req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
  162. clientTLSConfig := &tls.Config{
  163. InsecureSkipVerify: true,
  164. }
  165. wsConn, resp, err := newWSConnection(clientTLSConfig, req)
  166. require.NoError(t, err)
  167. require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
  168. require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
  169. require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
  170. require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
  171. ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
  172. defer cancel()
  173. errGroup, ctx := errgroup.WithContext(ctx)
  174. errGroup.Go(func() error {
  175. echoWSEyeball(t, eyeballConn)
  176. return nil
  177. })
  178. wsConn.Stream(ctx, edgeConn, testLogger)
  179. require.NoError(t, errGroup.Wait())
  180. }
  181. type wsEyeball struct {
  182. conn net.Conn
  183. }
  184. func (wse *wsEyeball) Read(p []byte) (int, error) {
  185. data, err := wsutil.ReadServerBinary(wse.conn)
  186. if err != nil {
  187. return 0, err
  188. }
  189. return copy(p, data), nil
  190. }
  191. func (wse *wsEyeball) Write(p []byte) (int, error) {
  192. err := wsutil.WriteClientBinary(wse.conn, p)
  193. return len(p), err
  194. }
  195. func echoWSEyeball(t *testing.T, conn net.Conn) {
  196. require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
  197. readMsg, err := wsutil.ReadServerBinary(conn)
  198. require.NoError(t, err)
  199. require.Equal(t, testResponse, readMsg)
  200. require.NoError(t, conn.Close())
  201. }
  202. func echoWSOrigin(t *testing.T) *httptest.Server {
  203. var upgrader = gorillaWS.Upgrader{
  204. ReadBufferSize: 10,
  205. WriteBufferSize: 10,
  206. }
  207. ws := func(w http.ResponseWriter, r *http.Request) {
  208. header := make(http.Header)
  209. for k, vs := range r.Header {
  210. if k == "Test-Cloudflared-Echo" {
  211. header[k] = vs
  212. }
  213. }
  214. conn, err := upgrader.Upgrade(w, r, header)
  215. require.NoError(t, err)
  216. defer conn.Close()
  217. for {
  218. messageType, p, err := conn.ReadMessage()
  219. if err != nil {
  220. return
  221. }
  222. require.Equal(t, testMessage, p)
  223. if err := conn.WriteMessage(messageType, testResponse); err != nil {
  224. return
  225. }
  226. }
  227. }
  228. // NewTLSServer starts the server in another thread
  229. return httptest.NewTLSServer(http.HandlerFunc(ws))
  230. }
  231. func echoTCPOrigin(t *testing.T, conn net.Conn) {
  232. readBuffer := make([]byte, len(testMessage))
  233. _, err := conn.Read(readBuffer)
  234. assert.NoError(t, err)
  235. assert.Equal(t, testMessage, readBuffer)
  236. _, err = conn.Write(testResponse)
  237. assert.NoError(t, err)
  238. }