123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539 |
- package http
- import (
- "context"
- "crypto/tls"
- "errors"
- "fmt"
- "net/http"
- "strings"
- "testing"
- "github.com/stretchr/testify/require"
- )
- func TestMiddlewareAuth(t *testing.T) {
- servers := []struct {
- name string
- http Config
- auth AuthConfig
- user string
- pass string
- }{
- {
- name: "Basic",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- },
- auth: AuthConfig{
- Realm: "test",
- BasicUser: "test",
- BasicPass: "test",
- },
- user: "test",
- pass: "test",
- },
- {
- name: "Htpasswd/MD5",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- },
- auth: AuthConfig{
- Realm: "test",
- HtPasswd: "./testdata/.htpasswd",
- },
- user: "md5",
- pass: "md5",
- },
- {
- name: "Htpasswd/SHA",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- },
- auth: AuthConfig{
- Realm: "test",
- HtPasswd: "./testdata/.htpasswd",
- },
- user: "sha",
- pass: "sha",
- },
- {
- name: "Htpasswd/Bcrypt",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- },
- auth: AuthConfig{
- Realm: "test",
- HtPasswd: "./testdata/.htpasswd",
- },
- user: "bcrypt",
- pass: "bcrypt",
- },
- {
- name: "Custom",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- },
- auth: AuthConfig{
- Realm: "test",
- CustomAuthFn: func(user, pass string) (value interface{}, err error) {
- if user == "custom" && pass == "custom" {
- return true, nil
- }
- return nil, errors.New("invalid credentials")
- },
- },
- user: "custom",
- pass: "custom",
- },
- }
- for _, ss := range servers {
- t.Run(ss.name, func(t *testing.T) {
- s, err := NewServer(context.Background(), WithConfig(ss.http), WithAuth(ss.auth))
- require.NoError(t, err)
- defer func() {
- require.NoError(t, s.Shutdown())
- }()
- expected := []byte("secret-page")
- s.Router().Mount("/", testEchoHandler(expected))
- s.Serve()
- url := testGetServerURL(t, s)
- t.Run("NoCreds", func(t *testing.T) {
- client := &http.Client{}
- req, err := http.NewRequest("GET", url, nil)
- require.NoError(t, err)
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "using no creds should return unauthorized")
- wwwAuthHeader := resp.Header.Get("WWW-Authenticate")
- require.NotEmpty(t, wwwAuthHeader, "resp should contain WWW-Authtentication header")
- require.Contains(t, wwwAuthHeader, fmt.Sprintf("realm=%q", ss.auth.Realm), "WWW-Authtentication header should contain relam")
- })
- t.Run("BadCreds", func(t *testing.T) {
- client := &http.Client{}
- req, err := http.NewRequest("GET", url, nil)
- require.NoError(t, err)
- req.SetBasicAuth(ss.user+"BAD", ss.pass+"BAD")
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "using bad creds should return unauthorized")
- wwwAuthHeader := resp.Header.Get("WWW-Authenticate")
- require.NotEmpty(t, wwwAuthHeader, "resp should contain WWW-Authtentication header")
- require.Contains(t, wwwAuthHeader, fmt.Sprintf("realm=%q", ss.auth.Realm), "WWW-Authtentication header should contain relam")
- })
- t.Run("GoodCreds", func(t *testing.T) {
- client := &http.Client{}
- req, err := http.NewRequest("GET", url, nil)
- require.NoError(t, err)
- req.SetBasicAuth(ss.user, ss.pass)
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, http.StatusOK, resp.StatusCode, "using good creds should return ok")
- testExpectRespBody(t, resp, expected)
- })
- })
- }
- }
- func TestMiddlewareAuthCertificateUser(t *testing.T) {
- serverCertBytes := testReadTestdataFile(t, "local.crt")
- serverKeyBytes := testReadTestdataFile(t, "local.key")
- clientCertBytes := testReadTestdataFile(t, "client.crt")
- clientKeyBytes := testReadTestdataFile(t, "client.key")
- clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
- require.NoError(t, err)
- emptyCertBytes := testReadTestdataFile(t, "emptyclient.crt")
- emptyKeyBytes := testReadTestdataFile(t, "emptyclient.key")
- emptyCert, err := tls.X509KeyPair(emptyCertBytes, emptyKeyBytes)
- require.NoError(t, err)
- invalidCert, err := tls.X509KeyPair(serverCertBytes, serverKeyBytes)
- require.NoError(t, err)
- servers := []struct {
- name string
- wantErr bool
- status int
- result string
- http Config
- auth AuthConfig
- clientCerts []tls.Certificate
- }{
- {
- name: "Missing",
- wantErr: true,
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- TLSCertBody: serverCertBytes,
- TLSKeyBody: serverKeyBytes,
- MinTLSVersion: "tls1.0",
- ClientCA: "./testdata/client-ca.crt",
- },
- },
- {
- name: "Invalid",
- wantErr: true,
- clientCerts: []tls.Certificate{invalidCert},
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- TLSCertBody: serverCertBytes,
- TLSKeyBody: serverKeyBytes,
- MinTLSVersion: "tls1.0",
- ClientCA: "./testdata/client-ca.crt",
- },
- },
- {
- name: "EmptyCommonName",
- status: http.StatusUnauthorized,
- result: fmt.Sprintf("%s\n", http.StatusText(http.StatusUnauthorized)),
- clientCerts: []tls.Certificate{emptyCert},
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- TLSCertBody: serverCertBytes,
- TLSKeyBody: serverKeyBytes,
- MinTLSVersion: "tls1.0",
- ClientCA: "./testdata/client-ca.crt",
- },
- },
- {
- name: "Valid",
- status: http.StatusOK,
- result: "rclone-dev-client",
- clientCerts: []tls.Certificate{clientCert},
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- TLSCertBody: serverCertBytes,
- TLSKeyBody: serverKeyBytes,
- MinTLSVersion: "tls1.0",
- ClientCA: "./testdata/client-ca.crt",
- },
- },
- {
- name: "CustomAuth/Invalid",
- status: http.StatusUnauthorized,
- result: fmt.Sprintf("%d %s\n", http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)),
- clientCerts: []tls.Certificate{clientCert},
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- TLSCertBody: serverCertBytes,
- TLSKeyBody: serverKeyBytes,
- MinTLSVersion: "tls1.0",
- ClientCA: "./testdata/client-ca.crt",
- },
- auth: AuthConfig{
- Realm: "test",
- CustomAuthFn: func(user, pass string) (value interface{}, err error) {
- if user == "custom" && pass == "custom" {
- return true, nil
- }
- return nil, errors.New("invalid credentials")
- },
- },
- },
- {
- name: "CustomAuth/Valid",
- status: http.StatusOK,
- result: "rclone-dev-client",
- clientCerts: []tls.Certificate{clientCert},
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- TLSCertBody: serverCertBytes,
- TLSKeyBody: serverKeyBytes,
- MinTLSVersion: "tls1.0",
- ClientCA: "./testdata/client-ca.crt",
- },
- auth: AuthConfig{
- Realm: "test",
- CustomAuthFn: func(user, pass string) (value interface{}, err error) {
- fmt.Println("CUSTOMAUTH", user, pass)
- if user == "rclone-dev-client" && pass == "" {
- return true, nil
- }
- return nil, errors.New("invalid credentials")
- },
- },
- },
- }
- for _, ss := range servers {
- t.Run(ss.name, func(t *testing.T) {
- s, err := NewServer(context.Background(), WithConfig(ss.http), WithAuth(ss.auth))
- require.NoError(t, err)
- defer func() {
- require.NoError(t, s.Shutdown())
- }()
- s.Router().Mount("/", testAuthUserHandler())
- s.Serve()
- url := testGetServerURL(t, s)
- client := &http.Client{
- Transport: &http.Transport{
- TLSClientConfig: &tls.Config{
- Certificates: ss.clientCerts,
- InsecureSkipVerify: true,
- },
- },
- }
- req, err := http.NewRequest("GET", url, nil)
- require.NoError(t, err)
- resp, err := client.Do(req)
- if ss.wantErr {
- require.Error(t, err)
- return
- }
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, ss.status, resp.StatusCode, fmt.Sprintf("should return status %d", ss.status))
- testExpectRespBody(t, resp, []byte(ss.result))
- })
- }
- }
// _testCORSHeaderKeys names the response headers the CORS tests look for.
// They must be present when an AllowOrigin is configured and absent when it
// is empty.
var _testCORSHeaderKeys = []string{"Access-Control-Allow-Origin", "Access-Control-Allow-Headers", "Access-Control-Allow-Methods"}
- func TestMiddlewareCORS(t *testing.T) {
- servers := []struct {
- name string
- http Config
- tryRoot bool
- method string
- status int
- }{
- {
- name: "CustomOrigin",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- AllowOrigin: "http://test.rclone.org",
- },
- method: "GET",
- status: http.StatusOK,
- },
- {
- name: "WithBaseURL",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- AllowOrigin: "http://test.rclone.org",
- BaseURL: "/baseurl/",
- },
- method: "GET",
- status: http.StatusOK,
- },
- {
- name: "WithBaseURLTryRootGET",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- AllowOrigin: "http://test.rclone.org",
- BaseURL: "/baseurl/",
- },
- method: "GET",
- status: http.StatusNotFound,
- tryRoot: true,
- },
- {
- name: "WithBaseURLTryRootOPTIONS",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- AllowOrigin: "http://test.rclone.org",
- BaseURL: "/baseurl/",
- },
- method: "OPTIONS",
- status: http.StatusOK,
- tryRoot: true,
- },
- }
- for _, ss := range servers {
- t.Run(ss.name, func(t *testing.T) {
- s, err := NewServer(context.Background(), WithConfig(ss.http))
- require.NoError(t, err)
- defer func() {
- require.NoError(t, s.Shutdown())
- }()
- expected := []byte("data")
- s.Router().Mount("/", testEchoHandler(expected))
- s.Serve()
- url := testGetServerURL(t, s)
- // Try the query on the root, ignoring the baseURL
- if ss.tryRoot {
- slash := strings.LastIndex(url[:len(url)-1], "/")
- url = url[:slash+1]
- }
- client := &http.Client{}
- req, err := http.NewRequest(ss.method, url, nil)
- require.NoError(t, err)
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, ss.status, resp.StatusCode, "should return expected error code")
- if ss.status == http.StatusNotFound {
- return
- }
- testExpectRespBody(t, resp, expected)
- for _, key := range _testCORSHeaderKeys {
- require.Contains(t, resp.Header, key, "CORS headers should be sent")
- }
- expectedOrigin := url
- if ss.http.AllowOrigin != "" {
- expectedOrigin = ss.http.AllowOrigin
- }
- require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
- })
- }
- }
- func TestMiddlewareCORSEmptyOrigin(t *testing.T) {
- servers := []struct {
- name string
- http Config
- }{
- {
- name: "EmptyOrigin",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- AllowOrigin: "",
- },
- },
- }
- for _, ss := range servers {
- t.Run(ss.name, func(t *testing.T) {
- s, err := NewServer(context.Background(), WithConfig(ss.http))
- require.NoError(t, err)
- defer func() {
- require.NoError(t, s.Shutdown())
- }()
- expected := []byte("data")
- s.Router().Mount("/", testEchoHandler(expected))
- s.Serve()
- url := testGetServerURL(t, s)
- client := &http.Client{}
- req, err := http.NewRequest("GET", url, nil)
- require.NoError(t, err)
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
- testExpectRespBody(t, resp, expected)
- for _, key := range _testCORSHeaderKeys {
- require.NotContains(t, resp.Header, key, "CORS headers should not be sent")
- }
- })
- }
- }
- func TestMiddlewareCORSWithAuth(t *testing.T) {
- authServers := []struct {
- name string
- http Config
- auth AuthConfig
- }{
- {
- name: "ServerWithAuth",
- http: Config{
- ListenAddr: []string{"127.0.0.1:0"},
- AllowOrigin: "http://test.rclone.org",
- },
- auth: AuthConfig{
- Realm: "test",
- BasicUser: "test_user",
- BasicPass: "test_pass",
- },
- },
- }
- for _, ss := range authServers {
- t.Run(ss.name, func(t *testing.T) {
- s, err := NewServer(context.Background(), WithConfig(ss.http))
- require.NoError(t, err)
- defer func() {
- require.NoError(t, s.Shutdown())
- }()
- s.Router().Mount("/", testEmptyHandler())
- s.Serve()
- url := testGetServerURL(t, s)
- client := &http.Client{}
- req, err := http.NewRequest("OPTIONS", url, nil)
- require.NoError(t, err)
- resp, err := client.Do(req)
- require.NoError(t, err)
- defer func() {
- _ = resp.Body.Close()
- }()
- require.Equal(t, http.StatusOK, resp.StatusCode, "OPTIONS should return ok even if not authenticated")
- testExpectRespBody(t, resp, []byte{})
- for _, key := range _testCORSHeaderKeys {
- require.Contains(t, resp.Header, key, "CORS headers should be sent even if not authenticated")
- }
- expectedOrigin := url
- if ss.http.AllowOrigin != "" {
- expectedOrigin = ss.http.AllowOrigin
- }
- require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
- })
- }
- }
|