123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- package ingress
- import (
- "bytes"
- "context"
- "crypto/tls"
- "fmt"
- "io/ioutil"
- "net"
- "net/http"
- "net/http/httptest"
- "net/url"
- "testing"
- "time"
- "github.com/cloudflare/cloudflared/logger"
- "github.com/cloudflare/cloudflared/socks"
- "github.com/gobwas/ws/wsutil"
- gorillaWS "github.com/gorilla/websocket"
- "github.com/stretchr/testify/assert"
- "github.com/stretchr/testify/require"
- "golang.org/x/net/proxy"
- "golang.org/x/sync/errgroup"
- )
const (
	// testStreamTimeout is the deadline applied to the context of every
	// streaming test in this file.
	testStreamTimeout = 3 * time.Second
	// echoHeaderName is the HTTP header the tests send to the origin and
	// expect to be reflected back in the response.
	echoHeaderName = "Test-Cloudflared-Echo"
)
- var (
- testLogger = logger.Create(nil)
- testMessage = []byte("TestStreamOriginConnection")
- testResponse = []byte(fmt.Sprintf("echo-%s", testMessage))
- )
- func TestStreamTCPConnection(t *testing.T) {
- cfdConn, originConn := net.Pipe()
- tcpConn := tcpConnection{
- conn: cfdConn,
- }
- eyeballConn, edgeConn := net.Pipe()
- ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
- defer cancel()
- errGroup, ctx := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- _, err := eyeballConn.Write(testMessage)
- readBuffer := make([]byte, len(testResponse))
- _, err = eyeballConn.Read(readBuffer)
- require.NoError(t, err)
- require.Equal(t, testResponse, readBuffer)
- return nil
- })
- errGroup.Go(func() error {
- echoTCPOrigin(t, originConn)
- originConn.Close()
- return nil
- })
- tcpConn.Stream(ctx, edgeConn, testLogger)
- require.NoError(t, errGroup.Wait())
- }
- func TestDefaultStreamWSOverTCPConnection(t *testing.T) {
- cfdConn, originConn := net.Pipe()
- tcpOverWSConn := tcpOverWSConnection{
- conn: cfdConn,
- streamHandler: DefaultStreamHandler,
- }
- eyeballConn, edgeConn := net.Pipe()
- ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
- defer cancel()
- errGroup, ctx := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- echoWSEyeball(t, eyeballConn)
- return nil
- })
- errGroup.Go(func() error {
- echoTCPOrigin(t, originConn)
- originConn.Close()
- return nil
- })
- tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
- require.NoError(t, errGroup.Wait())
- }
- // TestSocksStreamWSOverTCPConnection simulates proxying in socks mode.
- // Eyeball side runs cloudflared accesss tcp with --url flag to start a websocket forwarder which
- // wraps SOCKS5 traffic in websocket
- // Origin side runs a tcpOverWSConnection with socks.StreamHandler
- func TestSocksStreamWSOverTCPConnection(t *testing.T) {
- var (
- sendMessage = t.Name()
- echoHeaderIncomingValue = fmt.Sprintf("header-%s", sendMessage)
- echoMessage = fmt.Sprintf("echo-%s", sendMessage)
- echoHeaderReturnValue = fmt.Sprintf("echo-%s", echoHeaderIncomingValue)
- )
- statusCodes := []int{
- http.StatusOK,
- http.StatusTemporaryRedirect,
- http.StatusBadRequest,
- http.StatusInternalServerError,
- }
- for _, status := range statusCodes {
- handler := func(w http.ResponseWriter, r *http.Request) {
- body, err := ioutil.ReadAll(r.Body)
- require.NoError(t, err)
- require.Equal(t, []byte(sendMessage), body)
- require.Equal(t, echoHeaderIncomingValue, r.Header.Get(echoHeaderName))
- w.Header().Set(echoHeaderName, echoHeaderReturnValue)
- w.WriteHeader(status)
- w.Write([]byte(echoMessage))
- }
- origin := httptest.NewServer(http.HandlerFunc(handler))
- defer origin.Close()
- originURL, err := url.Parse(origin.URL)
- require.NoError(t, err)
- originConn, err := net.Dial("tcp", originURL.Host)
- require.NoError(t, err)
- tcpOverWSConn := tcpOverWSConnection{
- conn: originConn,
- streamHandler: socks.StreamHandler,
- }
- wsForwarderOutConn, edgeConn := net.Pipe()
- ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
- defer cancel()
- errGroup, ctx := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- tcpOverWSConn.Stream(ctx, edgeConn, testLogger)
- return nil
- })
- wsForwarderListener, err := net.Listen("tcp", "127.0.0.1:0")
- require.NoError(t, err)
- errGroup.Go(func() error {
- wsForwarderInConn, err := wsForwarderListener.Accept()
- require.NoError(t, err)
- defer wsForwarderInConn.Close()
- Stream(wsForwarderInConn, &wsEyeball{wsForwarderOutConn}, testLogger)
- return nil
- })
- eyeballDialer, err := proxy.SOCKS5("tcp", wsForwarderListener.Addr().String(), nil, proxy.Direct)
- require.NoError(t, err)
- transport := &http.Transport{
- Dial: eyeballDialer.Dial,
- }
- // Request URL doesn't matter because the transport is using eyeballDialer to connectq
- req, err := http.NewRequestWithContext(ctx, "GET", "http://test-socks-stream.com", bytes.NewBuffer([]byte(sendMessage)))
- assert.NoError(t, err)
- req.Header.Set(echoHeaderName, echoHeaderIncomingValue)
- resp, err := transport.RoundTrip(req)
- assert.NoError(t, err)
- assert.Equal(t, status, resp.StatusCode)
- require.Equal(t, echoHeaderReturnValue, resp.Header.Get(echoHeaderName))
- body, err := ioutil.ReadAll(resp.Body)
- require.NoError(t, err)
- require.Equal(t, []byte(echoMessage), body)
- wsForwarderOutConn.Close()
- edgeConn.Close()
- tcpOverWSConn.Close()
- require.NoError(t, errGroup.Wait())
- }
- }
- func TestStreamWSConnection(t *testing.T) {
- eyeballConn, edgeConn := net.Pipe()
- origin := echoWSOrigin(t)
- defer origin.Close()
- req, err := http.NewRequest(http.MethodGet, origin.URL, nil)
- require.NoError(t, err)
- req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
- clientTLSConfig := &tls.Config{
- InsecureSkipVerify: true,
- }
- wsConn, resp, err := newWSConnection(clientTLSConfig, req)
- require.NoError(t, err)
- require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
- require.Equal(t, "Upgrade", resp.Header.Get("Connection"))
- require.Equal(t, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", resp.Header.Get("Sec-Websocket-Accept"))
- require.Equal(t, "websocket", resp.Header.Get("Upgrade"))
- ctx, cancel := context.WithTimeout(context.Background(), testStreamTimeout)
- defer cancel()
- errGroup, ctx := errgroup.WithContext(ctx)
- errGroup.Go(func() error {
- echoWSEyeball(t, eyeballConn)
- return nil
- })
- wsConn.Stream(ctx, edgeConn, testLogger)
- require.NoError(t, errGroup.Wait())
- }
// wsEyeball wraps a connection so that it can be read from and written to as
// the client side of a websocket: its Read method consumes binary messages
// sent by the server and its Write method emits client binary messages.
type wsEyeball struct {
	// conn is the underlying connection carrying the websocket frames.
	conn net.Conn
}
- func (wse *wsEyeball) Read(p []byte) (int, error) {
- data, err := wsutil.ReadServerBinary(wse.conn)
- if err != nil {
- return 0, err
- }
- return copy(p, data), nil
- }
- func (wse *wsEyeball) Write(p []byte) (int, error) {
- err := wsutil.WriteClientBinary(wse.conn, p)
- return len(p), err
- }
- func echoWSEyeball(t *testing.T, conn net.Conn) {
- require.NoError(t, wsutil.WriteClientBinary(conn, testMessage))
- readMsg, err := wsutil.ReadServerBinary(conn)
- require.NoError(t, err)
- require.Equal(t, testResponse, readMsg)
- require.NoError(t, conn.Close())
- }
- func echoWSOrigin(t *testing.T) *httptest.Server {
- var upgrader = gorillaWS.Upgrader{
- ReadBufferSize: 10,
- WriteBufferSize: 10,
- }
- ws := func(w http.ResponseWriter, r *http.Request) {
- header := make(http.Header)
- for k, vs := range r.Header {
- if k == "Test-Cloudflared-Echo" {
- header[k] = vs
- }
- }
- conn, err := upgrader.Upgrade(w, r, header)
- require.NoError(t, err)
- defer conn.Close()
- for {
- messageType, p, err := conn.ReadMessage()
- if err != nil {
- return
- }
- require.Equal(t, testMessage, p)
- if err := conn.WriteMessage(messageType, testResponse); err != nil {
- return
- }
- }
- }
- // NewTLSServer starts the server in another thread
- return httptest.NewTLSServer(http.HandlerFunc(ws))
- }
- func echoTCPOrigin(t *testing.T, conn net.Conn) {
- readBuffer := make([]byte, len(testMessage))
- _, err := conn.Read(readBuffer)
- assert.NoError(t, err)
- assert.Equal(t, testMessage, readBuffer)
- _, err = conn.Write(testResponse)
- assert.NoError(t, err)
- }
|