middleware_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. package http
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net/http"
  8. "strings"
  9. "testing"
  10. "github.com/stretchr/testify/require"
  11. )
  12. func TestMiddlewareAuth(t *testing.T) {
  13. servers := []struct {
  14. name string
  15. http Config
  16. auth AuthConfig
  17. user string
  18. pass string
  19. }{
  20. {
  21. name: "Basic",
  22. http: Config{
  23. ListenAddr: []string{"127.0.0.1:0"},
  24. },
  25. auth: AuthConfig{
  26. Realm: "test",
  27. BasicUser: "test",
  28. BasicPass: "test",
  29. },
  30. user: "test",
  31. pass: "test",
  32. },
  33. {
  34. name: "Htpasswd/MD5",
  35. http: Config{
  36. ListenAddr: []string{"127.0.0.1:0"},
  37. },
  38. auth: AuthConfig{
  39. Realm: "test",
  40. HtPasswd: "./testdata/.htpasswd",
  41. },
  42. user: "md5",
  43. pass: "md5",
  44. },
  45. {
  46. name: "Htpasswd/SHA",
  47. http: Config{
  48. ListenAddr: []string{"127.0.0.1:0"},
  49. },
  50. auth: AuthConfig{
  51. Realm: "test",
  52. HtPasswd: "./testdata/.htpasswd",
  53. },
  54. user: "sha",
  55. pass: "sha",
  56. },
  57. {
  58. name: "Htpasswd/Bcrypt",
  59. http: Config{
  60. ListenAddr: []string{"127.0.0.1:0"},
  61. },
  62. auth: AuthConfig{
  63. Realm: "test",
  64. HtPasswd: "./testdata/.htpasswd",
  65. },
  66. user: "bcrypt",
  67. pass: "bcrypt",
  68. },
  69. {
  70. name: "Custom",
  71. http: Config{
  72. ListenAddr: []string{"127.0.0.1:0"},
  73. },
  74. auth: AuthConfig{
  75. Realm: "test",
  76. CustomAuthFn: func(user, pass string) (value interface{}, err error) {
  77. if user == "custom" && pass == "custom" {
  78. return true, nil
  79. }
  80. return nil, errors.New("invalid credentials")
  81. },
  82. },
  83. user: "custom",
  84. pass: "custom",
  85. },
  86. }
  87. for _, ss := range servers {
  88. t.Run(ss.name, func(t *testing.T) {
  89. s, err := NewServer(context.Background(), WithConfig(ss.http), WithAuth(ss.auth))
  90. require.NoError(t, err)
  91. defer func() {
  92. require.NoError(t, s.Shutdown())
  93. }()
  94. expected := []byte("secret-page")
  95. s.Router().Mount("/", testEchoHandler(expected))
  96. s.Serve()
  97. url := testGetServerURL(t, s)
  98. t.Run("NoCreds", func(t *testing.T) {
  99. client := &http.Client{}
  100. req, err := http.NewRequest("GET", url, nil)
  101. require.NoError(t, err)
  102. resp, err := client.Do(req)
  103. require.NoError(t, err)
  104. defer func() {
  105. _ = resp.Body.Close()
  106. }()
  107. require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "using no creds should return unauthorized")
  108. wwwAuthHeader := resp.Header.Get("WWW-Authenticate")
  109. require.NotEmpty(t, wwwAuthHeader, "resp should contain WWW-Authtentication header")
  110. require.Contains(t, wwwAuthHeader, fmt.Sprintf("realm=%q", ss.auth.Realm), "WWW-Authtentication header should contain relam")
  111. })
  112. t.Run("BadCreds", func(t *testing.T) {
  113. client := &http.Client{}
  114. req, err := http.NewRequest("GET", url, nil)
  115. require.NoError(t, err)
  116. req.SetBasicAuth(ss.user+"BAD", ss.pass+"BAD")
  117. resp, err := client.Do(req)
  118. require.NoError(t, err)
  119. defer func() {
  120. _ = resp.Body.Close()
  121. }()
  122. require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "using bad creds should return unauthorized")
  123. wwwAuthHeader := resp.Header.Get("WWW-Authenticate")
  124. require.NotEmpty(t, wwwAuthHeader, "resp should contain WWW-Authtentication header")
  125. require.Contains(t, wwwAuthHeader, fmt.Sprintf("realm=%q", ss.auth.Realm), "WWW-Authtentication header should contain relam")
  126. })
  127. t.Run("GoodCreds", func(t *testing.T) {
  128. client := &http.Client{}
  129. req, err := http.NewRequest("GET", url, nil)
  130. require.NoError(t, err)
  131. req.SetBasicAuth(ss.user, ss.pass)
  132. resp, err := client.Do(req)
  133. require.NoError(t, err)
  134. defer func() {
  135. _ = resp.Body.Close()
  136. }()
  137. require.Equal(t, http.StatusOK, resp.StatusCode, "using good creds should return ok")
  138. testExpectRespBody(t, resp, expected)
  139. })
  140. })
  141. }
  142. }
  143. func TestMiddlewareAuthCertificateUser(t *testing.T) {
  144. serverCertBytes := testReadTestdataFile(t, "local.crt")
  145. serverKeyBytes := testReadTestdataFile(t, "local.key")
  146. clientCertBytes := testReadTestdataFile(t, "client.crt")
  147. clientKeyBytes := testReadTestdataFile(t, "client.key")
  148. clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
  149. require.NoError(t, err)
  150. emptyCertBytes := testReadTestdataFile(t, "emptyclient.crt")
  151. emptyKeyBytes := testReadTestdataFile(t, "emptyclient.key")
  152. emptyCert, err := tls.X509KeyPair(emptyCertBytes, emptyKeyBytes)
  153. require.NoError(t, err)
  154. invalidCert, err := tls.X509KeyPair(serverCertBytes, serverKeyBytes)
  155. require.NoError(t, err)
  156. servers := []struct {
  157. name string
  158. wantErr bool
  159. status int
  160. result string
  161. http Config
  162. auth AuthConfig
  163. clientCerts []tls.Certificate
  164. }{
  165. {
  166. name: "Missing",
  167. wantErr: true,
  168. http: Config{
  169. ListenAddr: []string{"127.0.0.1:0"},
  170. TLSCertBody: serverCertBytes,
  171. TLSKeyBody: serverKeyBytes,
  172. MinTLSVersion: "tls1.0",
  173. ClientCA: "./testdata/client-ca.crt",
  174. },
  175. },
  176. {
  177. name: "Invalid",
  178. wantErr: true,
  179. clientCerts: []tls.Certificate{invalidCert},
  180. http: Config{
  181. ListenAddr: []string{"127.0.0.1:0"},
  182. TLSCertBody: serverCertBytes,
  183. TLSKeyBody: serverKeyBytes,
  184. MinTLSVersion: "tls1.0",
  185. ClientCA: "./testdata/client-ca.crt",
  186. },
  187. },
  188. {
  189. name: "EmptyCommonName",
  190. status: http.StatusUnauthorized,
  191. result: fmt.Sprintf("%s\n", http.StatusText(http.StatusUnauthorized)),
  192. clientCerts: []tls.Certificate{emptyCert},
  193. http: Config{
  194. ListenAddr: []string{"127.0.0.1:0"},
  195. TLSCertBody: serverCertBytes,
  196. TLSKeyBody: serverKeyBytes,
  197. MinTLSVersion: "tls1.0",
  198. ClientCA: "./testdata/client-ca.crt",
  199. },
  200. },
  201. {
  202. name: "Valid",
  203. status: http.StatusOK,
  204. result: "rclone-dev-client",
  205. clientCerts: []tls.Certificate{clientCert},
  206. http: Config{
  207. ListenAddr: []string{"127.0.0.1:0"},
  208. TLSCertBody: serverCertBytes,
  209. TLSKeyBody: serverKeyBytes,
  210. MinTLSVersion: "tls1.0",
  211. ClientCA: "./testdata/client-ca.crt",
  212. },
  213. },
  214. {
  215. name: "CustomAuth/Invalid",
  216. status: http.StatusUnauthorized,
  217. result: fmt.Sprintf("%d %s\n", http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized)),
  218. clientCerts: []tls.Certificate{clientCert},
  219. http: Config{
  220. ListenAddr: []string{"127.0.0.1:0"},
  221. TLSCertBody: serverCertBytes,
  222. TLSKeyBody: serverKeyBytes,
  223. MinTLSVersion: "tls1.0",
  224. ClientCA: "./testdata/client-ca.crt",
  225. },
  226. auth: AuthConfig{
  227. Realm: "test",
  228. CustomAuthFn: func(user, pass string) (value interface{}, err error) {
  229. if user == "custom" && pass == "custom" {
  230. return true, nil
  231. }
  232. return nil, errors.New("invalid credentials")
  233. },
  234. },
  235. },
  236. {
  237. name: "CustomAuth/Valid",
  238. status: http.StatusOK,
  239. result: "rclone-dev-client",
  240. clientCerts: []tls.Certificate{clientCert},
  241. http: Config{
  242. ListenAddr: []string{"127.0.0.1:0"},
  243. TLSCertBody: serverCertBytes,
  244. TLSKeyBody: serverKeyBytes,
  245. MinTLSVersion: "tls1.0",
  246. ClientCA: "./testdata/client-ca.crt",
  247. },
  248. auth: AuthConfig{
  249. Realm: "test",
  250. CustomAuthFn: func(user, pass string) (value interface{}, err error) {
  251. fmt.Println("CUSTOMAUTH", user, pass)
  252. if user == "rclone-dev-client" && pass == "" {
  253. return true, nil
  254. }
  255. return nil, errors.New("invalid credentials")
  256. },
  257. },
  258. },
  259. }
  260. for _, ss := range servers {
  261. t.Run(ss.name, func(t *testing.T) {
  262. s, err := NewServer(context.Background(), WithConfig(ss.http), WithAuth(ss.auth))
  263. require.NoError(t, err)
  264. defer func() {
  265. require.NoError(t, s.Shutdown())
  266. }()
  267. s.Router().Mount("/", testAuthUserHandler())
  268. s.Serve()
  269. url := testGetServerURL(t, s)
  270. client := &http.Client{
  271. Transport: &http.Transport{
  272. TLSClientConfig: &tls.Config{
  273. Certificates: ss.clientCerts,
  274. InsecureSkipVerify: true,
  275. },
  276. },
  277. }
  278. req, err := http.NewRequest("GET", url, nil)
  279. require.NoError(t, err)
  280. resp, err := client.Do(req)
  281. if ss.wantErr {
  282. require.Error(t, err)
  283. return
  284. }
  285. require.NoError(t, err)
  286. defer func() {
  287. _ = resp.Body.Close()
  288. }()
  289. require.Equal(t, ss.status, resp.StatusCode, fmt.Sprintf("should return status %d", ss.status))
  290. testExpectRespBody(t, resp, []byte(ss.result))
  291. })
  292. }
  293. }
  294. var _testCORSHeaderKeys = []string{
  295. "Access-Control-Allow-Origin",
  296. "Access-Control-Allow-Headers",
  297. "Access-Control-Allow-Methods",
  298. }
  299. func TestMiddlewareCORS(t *testing.T) {
  300. servers := []struct {
  301. name string
  302. http Config
  303. tryRoot bool
  304. method string
  305. status int
  306. }{
  307. {
  308. name: "CustomOrigin",
  309. http: Config{
  310. ListenAddr: []string{"127.0.0.1:0"},
  311. AllowOrigin: "http://test.rclone.org",
  312. },
  313. method: "GET",
  314. status: http.StatusOK,
  315. },
  316. {
  317. name: "WithBaseURL",
  318. http: Config{
  319. ListenAddr: []string{"127.0.0.1:0"},
  320. AllowOrigin: "http://test.rclone.org",
  321. BaseURL: "/baseurl/",
  322. },
  323. method: "GET",
  324. status: http.StatusOK,
  325. },
  326. {
  327. name: "WithBaseURLTryRootGET",
  328. http: Config{
  329. ListenAddr: []string{"127.0.0.1:0"},
  330. AllowOrigin: "http://test.rclone.org",
  331. BaseURL: "/baseurl/",
  332. },
  333. method: "GET",
  334. status: http.StatusNotFound,
  335. tryRoot: true,
  336. },
  337. {
  338. name: "WithBaseURLTryRootOPTIONS",
  339. http: Config{
  340. ListenAddr: []string{"127.0.0.1:0"},
  341. AllowOrigin: "http://test.rclone.org",
  342. BaseURL: "/baseurl/",
  343. },
  344. method: "OPTIONS",
  345. status: http.StatusOK,
  346. tryRoot: true,
  347. },
  348. }
  349. for _, ss := range servers {
  350. t.Run(ss.name, func(t *testing.T) {
  351. s, err := NewServer(context.Background(), WithConfig(ss.http))
  352. require.NoError(t, err)
  353. defer func() {
  354. require.NoError(t, s.Shutdown())
  355. }()
  356. expected := []byte("data")
  357. s.Router().Mount("/", testEchoHandler(expected))
  358. s.Serve()
  359. url := testGetServerURL(t, s)
  360. // Try the query on the root, ignoring the baseURL
  361. if ss.tryRoot {
  362. slash := strings.LastIndex(url[:len(url)-1], "/")
  363. url = url[:slash+1]
  364. }
  365. client := &http.Client{}
  366. req, err := http.NewRequest(ss.method, url, nil)
  367. require.NoError(t, err)
  368. resp, err := client.Do(req)
  369. require.NoError(t, err)
  370. defer func() {
  371. _ = resp.Body.Close()
  372. }()
  373. require.Equal(t, ss.status, resp.StatusCode, "should return expected error code")
  374. if ss.status == http.StatusNotFound {
  375. return
  376. }
  377. testExpectRespBody(t, resp, expected)
  378. for _, key := range _testCORSHeaderKeys {
  379. require.Contains(t, resp.Header, key, "CORS headers should be sent")
  380. }
  381. expectedOrigin := url
  382. if ss.http.AllowOrigin != "" {
  383. expectedOrigin = ss.http.AllowOrigin
  384. }
  385. require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
  386. })
  387. }
  388. }
  389. func TestMiddlewareCORSEmptyOrigin(t *testing.T) {
  390. servers := []struct {
  391. name string
  392. http Config
  393. }{
  394. {
  395. name: "EmptyOrigin",
  396. http: Config{
  397. ListenAddr: []string{"127.0.0.1:0"},
  398. AllowOrigin: "",
  399. },
  400. },
  401. }
  402. for _, ss := range servers {
  403. t.Run(ss.name, func(t *testing.T) {
  404. s, err := NewServer(context.Background(), WithConfig(ss.http))
  405. require.NoError(t, err)
  406. defer func() {
  407. require.NoError(t, s.Shutdown())
  408. }()
  409. expected := []byte("data")
  410. s.Router().Mount("/", testEchoHandler(expected))
  411. s.Serve()
  412. url := testGetServerURL(t, s)
  413. client := &http.Client{}
  414. req, err := http.NewRequest("GET", url, nil)
  415. require.NoError(t, err)
  416. resp, err := client.Do(req)
  417. require.NoError(t, err)
  418. defer func() {
  419. _ = resp.Body.Close()
  420. }()
  421. require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
  422. testExpectRespBody(t, resp, expected)
  423. for _, key := range _testCORSHeaderKeys {
  424. require.NotContains(t, resp.Header, key, "CORS headers should not be sent")
  425. }
  426. })
  427. }
  428. }
  429. func TestMiddlewareCORSWithAuth(t *testing.T) {
  430. authServers := []struct {
  431. name string
  432. http Config
  433. auth AuthConfig
  434. }{
  435. {
  436. name: "ServerWithAuth",
  437. http: Config{
  438. ListenAddr: []string{"127.0.0.1:0"},
  439. AllowOrigin: "http://test.rclone.org",
  440. },
  441. auth: AuthConfig{
  442. Realm: "test",
  443. BasicUser: "test_user",
  444. BasicPass: "test_pass",
  445. },
  446. },
  447. }
  448. for _, ss := range authServers {
  449. t.Run(ss.name, func(t *testing.T) {
  450. s, err := NewServer(context.Background(), WithConfig(ss.http))
  451. require.NoError(t, err)
  452. defer func() {
  453. require.NoError(t, s.Shutdown())
  454. }()
  455. s.Router().Mount("/", testEmptyHandler())
  456. s.Serve()
  457. url := testGetServerURL(t, s)
  458. client := &http.Client{}
  459. req, err := http.NewRequest("OPTIONS", url, nil)
  460. require.NoError(t, err)
  461. resp, err := client.Do(req)
  462. require.NoError(t, err)
  463. defer func() {
  464. _ = resp.Body.Close()
  465. }()
  466. require.Equal(t, http.StatusOK, resp.StatusCode, "OPTIONS should return ok even if not authenticated")
  467. testExpectRespBody(t, resp, []byte{})
  468. for _, key := range _testCORSHeaderKeys {
  469. require.Contains(t, resp.Header, key, "CORS headers should be sent even if not authenticated")
  470. }
  471. expectedOrigin := url
  472. if ss.http.AllowOrigin != "" {
  473. expectedOrigin = ss.http.AllowOrigin
  474. }
  475. require.Equal(t, expectedOrigin, resp.Header.Get("Access-Control-Allow-Origin"), "allow origin should match")
  476. })
  477. }
  478. }