123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239 |
- package dbconnect
- import (
- "context"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "net/http"
- "net/http/httptest"
- "strings"
- "testing"
- "github.com/gorilla/mux"
- "github.com/stretchr/testify/assert"
- )
- func TestNewInsecureProxy(t *testing.T) {
- origins := []string{
- "",
- ":/",
- "http://localhost",
- "tcp://localhost:9000?debug=true",
- "mongodb://127.0.0.1",
- }
- for _, origin := range origins {
- proxy, err := NewInsecureProxy(context.Background(), origin)
- assert.Error(t, err)
- assert.Empty(t, proxy)
- }
- }
- func TestProxyIsAllowed(t *testing.T) {
- proxy := helperNewProxy(t)
- req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
- assert.True(t, proxy.IsAllowed(req))
- proxy = helperNewProxy(t, true)
- req.Header.Set("Cf-access-jwt-assertion", "xxx")
- assert.False(t, proxy.IsAllowed(req))
- }
- func TestProxyStart(t *testing.T) {
- proxy := helperNewProxy(t)
- ctx := context.Background()
- listenerC := make(chan net.Listener)
- err := proxy.Start(ctx, "1.1.1.1:", listenerC)
- assert.Error(t, err)
- err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
- assert.Error(t, err)
- ctx, cancel := context.WithTimeout(ctx, 0)
- defer cancel()
- err = proxy.Start(ctx, "127.0.0.1:", listenerC)
- assert.IsType(t, http.ErrServerClosed, err)
- }
- func TestProxyHTTPRouter(t *testing.T) {
- proxy := helperNewProxy(t)
- router := proxy.httpRouter()
- tests := []struct {
- path string
- method string
- valid bool
- }{
- {"", "GET", false},
- {"/", "GET", false},
- {"/ping", "GET", true},
- {"/ping", "HEAD", true},
- {"/ping", "POST", false},
- {"/submit", "POST", true},
- {"/submit", "GET", false},
- {"/submit/extra", "POST", false},
- }
- for _, test := range tests {
- match := &mux.RouteMatch{}
- ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
- assert.True(t, ok == test.valid, test.path)
- }
- }
- func TestProxyHTTPPing(t *testing.T) {
- proxy := helperNewProxy(t)
- server := httptest.NewServer(proxy.httpPing())
- defer server.Close()
- client := server.Client()
- res, err := client.Get(server.URL)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, res.StatusCode)
- assert.Equal(t, int64(2), res.ContentLength)
- res, err = client.Head(server.URL)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusOK, res.StatusCode)
- assert.Equal(t, int64(-1), res.ContentLength)
- }
- func TestProxyHTTPSubmit(t *testing.T) {
- proxy := helperNewProxy(t)
- server := httptest.NewServer(proxy.httpSubmit())
- defer server.Close()
- client := server.Client()
- tests := []struct {
- input string
- status int
- output string
- }{
- {"", http.StatusBadRequest, "request body cannot be empty"},
- {"{}", http.StatusBadRequest, "cannot provide an empty statement"},
- {"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
- {"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
- {"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
- }
- for _, test := range tests {
- res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
- assert.NoError(t, err)
- assert.Equal(t, test.status, res.StatusCode)
- if res.StatusCode > http.StatusOK {
- assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
- } else {
- assert.Equal(t, "application/json", res.Header.Get("Content-type"))
- }
- data, err := ioutil.ReadAll(res.Body)
- defer res.Body.Close()
- str := string(data)
- assert.NoError(t, err)
- assert.Equal(t, test.output, str)
- }
- }
- func TestProxyHTTPSubmitForbidden(t *testing.T) {
- proxy := helperNewProxy(t, true)
- server := httptest.NewServer(proxy.httpSubmit())
- defer server.Close()
- client := server.Client()
- res, err := client.Get(server.URL)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusForbidden, res.StatusCode)
- assert.Zero(t, res.ContentLength)
- }
- func TestProxyHTTPRespond(t *testing.T) {
- proxy := helperNewProxy(t)
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
- }))
- defer server.Close()
- client := server.Client()
- res, err := client.Get(server.URL)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusAccepted, res.StatusCode)
- assert.Equal(t, int64(5), res.ContentLength)
- data, err := ioutil.ReadAll(res.Body)
- defer res.Body.Close()
- assert.Equal(t, []byte("Hello"), data)
- }
- func TestProxyHTTPRespondForbidden(t *testing.T) {
- proxy := helperNewProxy(t, true)
- server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
- }))
- defer server.Close()
- client := server.Client()
- res, err := client.Get(server.URL)
- assert.NoError(t, err)
- assert.Equal(t, http.StatusAccepted, res.StatusCode)
- assert.Equal(t, int64(0), res.ContentLength)
- }
- func TestHTTPError(t *testing.T) {
- _, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
- assert.Error(t, errTimeout)
- tests := []struct {
- input error
- status int
- output error
- }{
- {nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
- {io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
- {context.DeadlineExceeded, http.StatusRequestTimeout, nil},
- {context.Canceled, 444, nil},
- {errTimeout, http.StatusRequestTimeout, nil},
- {fmt.Errorf(""), http.StatusInternalServerError, nil},
- }
- for _, test := range tests {
- status, err := httpError(http.StatusInternalServerError, test.input)
- assert.Error(t, err)
- assert.Equal(t, test.status, status)
- if test.output == nil {
- test.output = test.input
- }
- assert.Equal(t, test.output, err)
- }
- }
- func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
- t.Helper()
- proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
- assert.NoError(t, err)
- assert.NotNil(t, proxy)
- if len(secure) == 0 {
- proxy.accessValidator = nil // Mark as insecure
- }
- return proxy
- }
|