// proxy_test.go (5.9 KB)

  1. package dbconnect
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "io/ioutil"
  7. "net"
  8. "net/http"
  9. "net/http/httptest"
  10. "strings"
  11. "testing"
  12. "github.com/gorilla/mux"
  13. "github.com/stretchr/testify/assert"
  14. )
  15. func TestNewInsecureProxy(t *testing.T) {
  16. origins := []string{
  17. "",
  18. ":/",
  19. "http://localhost",
  20. "tcp://localhost:9000?debug=true",
  21. "mongodb://127.0.0.1",
  22. }
  23. for _, origin := range origins {
  24. proxy, err := NewInsecureProxy(context.Background(), origin)
  25. assert.Error(t, err)
  26. assert.Empty(t, proxy)
  27. }
  28. }
  29. func TestProxyIsAllowed(t *testing.T) {
  30. proxy := helperNewProxy(t)
  31. req := httptest.NewRequest("GET", "https://1.1.1.1/ping", nil)
  32. assert.True(t, proxy.IsAllowed(req))
  33. proxy = helperNewProxy(t, true)
  34. req.Header.Set("Cf-access-jwt-assertion", "xxx")
  35. assert.False(t, proxy.IsAllowed(req))
  36. }
  37. func TestProxyStart(t *testing.T) {
  38. proxy := helperNewProxy(t)
  39. ctx := context.Background()
  40. listenerC := make(chan net.Listener)
  41. err := proxy.Start(ctx, "1.1.1.1:", listenerC)
  42. assert.Error(t, err)
  43. err = proxy.Start(ctx, "127.0.0.1:-1", listenerC)
  44. assert.Error(t, err)
  45. ctx, cancel := context.WithTimeout(ctx, 0)
  46. defer cancel()
  47. err = proxy.Start(ctx, "127.0.0.1:", listenerC)
  48. assert.IsType(t, http.ErrServerClosed, err)
  49. }
  50. func TestProxyHTTPRouter(t *testing.T) {
  51. proxy := helperNewProxy(t)
  52. router := proxy.httpRouter()
  53. tests := []struct {
  54. path string
  55. method string
  56. valid bool
  57. }{
  58. {"", "GET", false},
  59. {"/", "GET", false},
  60. {"/ping", "GET", true},
  61. {"/ping", "HEAD", true},
  62. {"/ping", "POST", false},
  63. {"/submit", "POST", true},
  64. {"/submit", "GET", false},
  65. {"/submit/extra", "POST", false},
  66. }
  67. for _, test := range tests {
  68. match := &mux.RouteMatch{}
  69. ok := router.Match(httptest.NewRequest(test.method, "https://1.1.1.1"+test.path, nil), match)
  70. assert.True(t, ok == test.valid, test.path)
  71. }
  72. }
  73. func TestProxyHTTPPing(t *testing.T) {
  74. proxy := helperNewProxy(t)
  75. server := httptest.NewServer(proxy.httpPing())
  76. defer server.Close()
  77. client := server.Client()
  78. res, err := client.Get(server.URL)
  79. assert.NoError(t, err)
  80. assert.Equal(t, http.StatusOK, res.StatusCode)
  81. assert.Equal(t, int64(2), res.ContentLength)
  82. res, err = client.Head(server.URL)
  83. assert.NoError(t, err)
  84. assert.Equal(t, http.StatusOK, res.StatusCode)
  85. assert.Equal(t, int64(-1), res.ContentLength)
  86. }
  87. func TestProxyHTTPSubmit(t *testing.T) {
  88. proxy := helperNewProxy(t)
  89. server := httptest.NewServer(proxy.httpSubmit())
  90. defer server.Close()
  91. client := server.Client()
  92. tests := []struct {
  93. input string
  94. status int
  95. output string
  96. }{
  97. {"", http.StatusBadRequest, "request body cannot be empty"},
  98. {"{}", http.StatusBadRequest, "cannot provide an empty statement"},
  99. {"{\"statement\":\"Ok\"}", http.StatusUnprocessableEntity, "cannot provide invalid sql mode: ''"},
  100. {"{\"statement\":\"Ok\",\"mode\":\"query\"}", http.StatusUnprocessableEntity, "near \"Ok\": syntax error"},
  101. {"{\"statement\":\"CREATE TABLE t (a INT);\",\"mode\":\"exec\"}", http.StatusOK, "{\"last_insert_id\":0,\"rows_affected\":0}\n"},
  102. }
  103. for _, test := range tests {
  104. res, err := client.Post(server.URL, "application/json", strings.NewReader(test.input))
  105. assert.NoError(t, err)
  106. assert.Equal(t, test.status, res.StatusCode)
  107. if res.StatusCode > http.StatusOK {
  108. assert.Equal(t, "text/plain; charset=utf-8", res.Header.Get("Content-type"))
  109. } else {
  110. assert.Equal(t, "application/json", res.Header.Get("Content-type"))
  111. }
  112. data, err := ioutil.ReadAll(res.Body)
  113. defer res.Body.Close()
  114. str := string(data)
  115. assert.NoError(t, err)
  116. assert.Equal(t, test.output, str)
  117. }
  118. }
  119. func TestProxyHTTPSubmitForbidden(t *testing.T) {
  120. proxy := helperNewProxy(t, true)
  121. server := httptest.NewServer(proxy.httpSubmit())
  122. defer server.Close()
  123. client := server.Client()
  124. res, err := client.Get(server.URL)
  125. assert.NoError(t, err)
  126. assert.Equal(t, http.StatusForbidden, res.StatusCode)
  127. assert.Zero(t, res.ContentLength)
  128. }
  129. func TestProxyHTTPRespond(t *testing.T) {
  130. proxy := helperNewProxy(t)
  131. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  132. proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
  133. }))
  134. defer server.Close()
  135. client := server.Client()
  136. res, err := client.Get(server.URL)
  137. assert.NoError(t, err)
  138. assert.Equal(t, http.StatusAccepted, res.StatusCode)
  139. assert.Equal(t, int64(5), res.ContentLength)
  140. data, err := ioutil.ReadAll(res.Body)
  141. defer res.Body.Close()
  142. assert.Equal(t, []byte("Hello"), data)
  143. }
  144. func TestProxyHTTPRespondForbidden(t *testing.T) {
  145. proxy := helperNewProxy(t, true)
  146. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  147. proxy.httpRespond(w, r, http.StatusAccepted, "Hello")
  148. }))
  149. defer server.Close()
  150. client := server.Client()
  151. res, err := client.Get(server.URL)
  152. assert.NoError(t, err)
  153. assert.Equal(t, http.StatusAccepted, res.StatusCode)
  154. assert.Equal(t, int64(0), res.ContentLength)
  155. }
  156. func TestHTTPError(t *testing.T) {
  157. _, errTimeout := net.DialTimeout("tcp", "127.0.0.1", 0)
  158. assert.Error(t, errTimeout)
  159. tests := []struct {
  160. input error
  161. status int
  162. output error
  163. }{
  164. {nil, http.StatusNotImplemented, fmt.Errorf("error expected but found none")},
  165. {io.EOF, http.StatusBadRequest, fmt.Errorf("request body cannot be empty")},
  166. {context.DeadlineExceeded, http.StatusRequestTimeout, nil},
  167. {context.Canceled, 444, nil},
  168. {errTimeout, http.StatusRequestTimeout, nil},
  169. {fmt.Errorf(""), http.StatusInternalServerError, nil},
  170. }
  171. for _, test := range tests {
  172. status, err := httpError(http.StatusInternalServerError, test.input)
  173. assert.Error(t, err)
  174. assert.Equal(t, test.status, status)
  175. if test.output == nil {
  176. test.output = test.input
  177. }
  178. assert.Equal(t, test.output, err)
  179. }
  180. }
  181. func helperNewProxy(t *testing.T, secure ...bool) *Proxy {
  182. t.Helper()
  183. proxy, err := NewSecureProxy(context.Background(), "file::memory:?cache=shared", "test.cloudflareaccess.com", "")
  184. assert.NoError(t, err)
  185. assert.NotNil(t, proxy)
  186. if len(secure) == 0 {
  187. proxy.accessValidator = nil // Mark as insecure
  188. }
  189. return proxy
  190. }