server_test.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. package http
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "io"
  6. "net"
  7. "net/http"
  8. "os"
  9. "path/filepath"
  10. "strings"
  11. "testing"
  12. "github.com/stretchr/testify/require"
  13. )
  // testEmptyHandler returns a handler that ignores the request and writes
  // nothing to the response.
  func testEmptyHandler() http.Handler {
  	var noop http.HandlerFunc = func(http.ResponseWriter, *http.Request) {}
  	return noop
  }
  // testEchoHandler returns a handler that answers every request by writing
  // data to the response. The result of the write is deliberately ignored.
  func testEchoHandler(data []byte) http.Handler {
  	echo := func(w http.ResponseWriter, _ *http.Request) {
  		_, _ = w.Write(data)
  	}
  	return http.HandlerFunc(echo)
  }
  22. func testAuthUserHandler() http.Handler {
  23. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  24. userID, ok := CtxGetUser(r.Context())
  25. if !ok {
  26. http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
  27. }
  28. _, _ = w.Write([]byte(userID))
  29. })
  30. }
  31. func testExpectRespBody(t *testing.T, resp *http.Response, expected []byte) {
  32. body, err := io.ReadAll(resp.Body)
  33. require.NoError(t, err)
  34. require.Equal(t, expected, body)
  35. }
  36. func testGetServerURL(t *testing.T, s *Server) string {
  37. urls := s.URLs()
  38. require.GreaterOrEqual(t, len(urls), 1, "server should return at least one url")
  39. return urls[0]
  40. }
  // testNewHTTPClientUnix returns an HTTP client whose transport ignores the
  // network and address derived from the request URL and instead connects to
  // the unix domain socket at path.
  func testNewHTTPClientUnix(path string) *http.Client {
  	var dialer net.Dialer
  	return &http.Client{
  		Transport: &http.Transport{
  			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
  				// Dial with the request's context so cancellation and
  				// deadlines can abort the connection attempt.
  				return dialer.DialContext(ctx, "unix", path)
  			},
  		},
  	}
  }
  50. func testReadTestdataFile(t *testing.T, path string) []byte {
  51. data, err := os.ReadFile(filepath.Join("./testdata", path))
  52. require.NoError(t, err, "")
  53. return data
  54. }
  55. func TestNewServerUnix(t *testing.T) {
  56. ctx := context.Background()
  57. tempDir := t.TempDir()
  58. path := filepath.Join(tempDir, "rclone.sock")
  59. cfg := DefaultCfg()
  60. cfg.ListenAddr = []string{path}
  61. auth := AuthConfig{
  62. BasicUser: "test",
  63. BasicPass: "test",
  64. }
  65. s, err := NewServer(ctx, WithConfig(cfg), WithAuth(auth))
  66. require.NoError(t, err)
  67. defer func() {
  68. require.NoError(t, s.Shutdown())
  69. _, err := os.Stat(path)
  70. require.ErrorIs(t, err, os.ErrNotExist, "shutdown should remove socket")
  71. }()
  72. require.Empty(t, s.URLs(), "unix socket should not appear in URLs")
  73. expected := []byte("hello world")
  74. s.Router().Mount("/", testEchoHandler(expected))
  75. s.Serve()
  76. client := testNewHTTPClientUnix(path)
  77. req, err := http.NewRequest("GET", "http://unix", nil)
  78. require.NoError(t, err)
  79. resp, err := client.Do(req)
  80. require.NoError(t, err)
  81. testExpectRespBody(t, resp, expected)
  82. require.Equal(t, http.StatusOK, resp.StatusCode, "unix sockets should ignore auth")
  83. for _, key := range _testCORSHeaderKeys {
  84. require.NotContains(t, resp.Header, key, "unix sockets should not be sent CORS headers")
  85. }
  86. }
  87. func TestNewServerHTTP(t *testing.T) {
  88. ctx := context.Background()
  89. cfg := DefaultCfg()
  90. cfg.ListenAddr = []string{"127.0.0.1:0"}
  91. auth := AuthConfig{
  92. BasicUser: "test",
  93. BasicPass: "test",
  94. }
  95. s, err := NewServer(ctx, WithConfig(cfg), WithAuth(auth))
  96. require.NoError(t, err)
  97. defer func() {
  98. require.NoError(t, s.Shutdown())
  99. }()
  100. url := testGetServerURL(t, s)
  101. require.True(t, strings.HasPrefix(url, "http://"), "url should have http scheme")
  102. expected := []byte("hello world")
  103. s.Router().Mount("/", testEchoHandler(expected))
  104. s.Serve()
  105. t.Run("StatusUnauthorized", func(t *testing.T) {
  106. client := &http.Client{}
  107. req, err := http.NewRequest("GET", url, nil)
  108. require.NoError(t, err)
  109. resp, err := client.Do(req)
  110. require.NoError(t, err)
  111. defer func() {
  112. _ = resp.Body.Close()
  113. }()
  114. require.Equal(t, http.StatusUnauthorized, resp.StatusCode, "no basic auth creds should return unauthorized")
  115. })
  116. t.Run("StatusOK", func(t *testing.T) {
  117. client := &http.Client{}
  118. req, err := http.NewRequest("GET", url, nil)
  119. require.NoError(t, err)
  120. req.SetBasicAuth(auth.BasicUser, auth.BasicPass)
  121. resp, err := client.Do(req)
  122. require.NoError(t, err)
  123. defer func() {
  124. _ = resp.Body.Close()
  125. }()
  126. require.Equal(t, http.StatusOK, resp.StatusCode, "using basic auth creds should return ok")
  127. testExpectRespBody(t, resp, expected)
  128. })
  129. }
  130. func TestNewServerBaseURL(t *testing.T) {
  131. servers := []struct {
  132. name string
  133. cfg Config
  134. suffix string
  135. }{
  136. {
  137. name: "Empty",
  138. cfg: Config{
  139. ListenAddr: []string{"127.0.0.1:0"},
  140. BaseURL: "",
  141. },
  142. suffix: "/",
  143. },
  144. {
  145. name: "Single/NoTrailingSlash",
  146. cfg: Config{
  147. ListenAddr: []string{"127.0.0.1:0"},
  148. BaseURL: "/rclone",
  149. },
  150. suffix: "/rclone/",
  151. },
  152. {
  153. name: "Single/TrailingSlash",
  154. cfg: Config{
  155. ListenAddr: []string{"127.0.0.1:0"},
  156. BaseURL: "/rclone/",
  157. },
  158. suffix: "/rclone/",
  159. },
  160. {
  161. name: "Multi/NoTrailingSlash",
  162. cfg: Config{
  163. ListenAddr: []string{"127.0.0.1:0"},
  164. BaseURL: "/rclone/test/base/url",
  165. },
  166. suffix: "/rclone/test/base/url/",
  167. },
  168. {
  169. name: "Multi/TrailingSlash",
  170. cfg: Config{
  171. ListenAddr: []string{"127.0.0.1:0"},
  172. BaseURL: "/rclone/test/base/url/",
  173. },
  174. suffix: "/rclone/test/base/url/",
  175. },
  176. }
  177. for _, ss := range servers {
  178. t.Run(ss.name, func(t *testing.T) {
  179. s, err := NewServer(context.Background(), WithConfig(ss.cfg))
  180. require.NoError(t, err)
  181. defer func() {
  182. require.NoError(t, s.Shutdown())
  183. }()
  184. expected := []byte("data")
  185. s.Router().Get("/", testEchoHandler(expected).ServeHTTP)
  186. s.Serve()
  187. url := testGetServerURL(t, s)
  188. require.True(t, strings.HasPrefix(url, "http://"), "url should have http scheme")
  189. require.True(t, strings.HasSuffix(url, ss.suffix), "url should have the expected suffix")
  190. client := &http.Client{}
  191. req, err := http.NewRequest("GET", url, nil)
  192. require.NoError(t, err)
  193. resp, err := client.Do(req)
  194. require.NoError(t, err)
  195. defer func() {
  196. _ = resp.Body.Close()
  197. }()
  198. t.Log(url, resp.Request.URL)
  199. require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
  200. testExpectRespBody(t, resp, expected)
  201. })
  202. }
  203. }
  204. func TestNewServerTLS(t *testing.T) {
  205. serverCertBytes := testReadTestdataFile(t, "local.crt")
  206. serverKeyBytes := testReadTestdataFile(t, "local.key")
  207. clientCertBytes := testReadTestdataFile(t, "client.crt")
  208. clientKeyBytes := testReadTestdataFile(t, "client.key")
  209. clientCert, err := tls.X509KeyPair(clientCertBytes, clientKeyBytes)
  210. require.NoError(t, err)
  211. // TODO: generate a proper cert with SAN
  212. servers := []struct {
  213. name string
  214. clientCerts []tls.Certificate
  215. wantErr bool
  216. wantClientErr bool
  217. err error
  218. http Config
  219. }{
  220. {
  221. name: "FromFile/Valid",
  222. http: Config{
  223. ListenAddr: []string{"127.0.0.1:0"},
  224. TLSCert: "./testdata/local.crt",
  225. TLSKey: "./testdata/local.key",
  226. MinTLSVersion: "tls1.0",
  227. },
  228. },
  229. {
  230. name: "FromFile/NoCert",
  231. wantErr: true,
  232. err: ErrTLSFileMismatch,
  233. http: Config{
  234. ListenAddr: []string{"127.0.0.1:0"},
  235. TLSCert: "",
  236. TLSKey: "./testdata/local.key",
  237. MinTLSVersion: "tls1.0",
  238. },
  239. },
  240. {
  241. name: "FromFile/InvalidCert",
  242. wantErr: true,
  243. http: Config{
  244. ListenAddr: []string{"127.0.0.1:0"},
  245. TLSCert: "./testdata/local.crt.invalid",
  246. TLSKey: "./testdata/local.key",
  247. MinTLSVersion: "tls1.0",
  248. },
  249. },
  250. {
  251. name: "FromFile/NoKey",
  252. wantErr: true,
  253. err: ErrTLSFileMismatch,
  254. http: Config{
  255. ListenAddr: []string{"127.0.0.1:0"},
  256. TLSCert: "./testdata/local.crt",
  257. TLSKey: "",
  258. MinTLSVersion: "tls1.0",
  259. },
  260. },
  261. {
  262. name: "FromFile/InvalidKey",
  263. wantErr: true,
  264. http: Config{
  265. ListenAddr: []string{"127.0.0.1:0"},
  266. TLSCert: "./testdata/local.crt",
  267. TLSKey: "./testdata/local.key.invalid",
  268. MinTLSVersion: "tls1.0",
  269. },
  270. },
  271. {
  272. name: "FromBody/Valid",
  273. http: Config{
  274. ListenAddr: []string{"127.0.0.1:0"},
  275. TLSCertBody: serverCertBytes,
  276. TLSKeyBody: serverKeyBytes,
  277. MinTLSVersion: "tls1.0",
  278. },
  279. },
  280. {
  281. name: "FromBody/NoCert",
  282. wantErr: true,
  283. err: ErrTLSBodyMismatch,
  284. http: Config{
  285. ListenAddr: []string{"127.0.0.1:0"},
  286. TLSCertBody: nil,
  287. TLSKeyBody: serverKeyBytes,
  288. MinTLSVersion: "tls1.0",
  289. },
  290. },
  291. {
  292. name: "FromBody/InvalidCert",
  293. wantErr: true,
  294. http: Config{
  295. ListenAddr: []string{"127.0.0.1:0"},
  296. TLSCertBody: []byte("JUNK DATA"),
  297. TLSKeyBody: serverKeyBytes,
  298. MinTLSVersion: "tls1.0",
  299. },
  300. },
  301. {
  302. name: "FromBody/NoKey",
  303. wantErr: true,
  304. err: ErrTLSBodyMismatch,
  305. http: Config{
  306. ListenAddr: []string{"127.0.0.1:0"},
  307. TLSCertBody: serverCertBytes,
  308. TLSKeyBody: nil,
  309. MinTLSVersion: "tls1.0",
  310. },
  311. },
  312. {
  313. name: "FromBody/InvalidKey",
  314. wantErr: true,
  315. http: Config{
  316. ListenAddr: []string{"127.0.0.1:0"},
  317. TLSCertBody: serverCertBytes,
  318. TLSKeyBody: []byte("JUNK DATA"),
  319. MinTLSVersion: "tls1.0",
  320. },
  321. },
  322. {
  323. name: "MinTLSVersion/Valid/1.1",
  324. http: Config{
  325. ListenAddr: []string{"127.0.0.1:0"},
  326. TLSCertBody: serverCertBytes,
  327. TLSKeyBody: serverKeyBytes,
  328. MinTLSVersion: "tls1.1",
  329. },
  330. },
  331. {
  332. name: "MinTLSVersion/Valid/1.2",
  333. http: Config{
  334. ListenAddr: []string{"127.0.0.1:0"},
  335. TLSCertBody: serverCertBytes,
  336. TLSKeyBody: serverKeyBytes,
  337. MinTLSVersion: "tls1.2",
  338. },
  339. },
  340. {
  341. name: "MinTLSVersion/Valid/1.3",
  342. http: Config{
  343. ListenAddr: []string{"127.0.0.1:0"},
  344. TLSCertBody: serverCertBytes,
  345. TLSKeyBody: serverKeyBytes,
  346. MinTLSVersion: "tls1.3",
  347. },
  348. },
  349. {
  350. name: "MinTLSVersion/Invalid",
  351. wantErr: true,
  352. err: ErrInvalidMinTLSVersion,
  353. http: Config{
  354. ListenAddr: []string{"127.0.0.1:0"},
  355. TLSCertBody: serverCertBytes,
  356. TLSKeyBody: serverKeyBytes,
  357. MinTLSVersion: "tls0.9",
  358. },
  359. },
  360. {
  361. name: "MutualTLS/InvalidCA",
  362. clientCerts: []tls.Certificate{clientCert},
  363. wantErr: true,
  364. http: Config{
  365. ListenAddr: []string{"127.0.0.1:0"},
  366. TLSCertBody: serverCertBytes,
  367. TLSKeyBody: serverKeyBytes,
  368. MinTLSVersion: "tls1.0",
  369. ClientCA: "./testdata/client-ca.crt.invalid",
  370. },
  371. },
  372. {
  373. name: "MutualTLS/InvalidClient",
  374. clientCerts: []tls.Certificate{},
  375. wantClientErr: true,
  376. http: Config{
  377. ListenAddr: []string{"127.0.0.1:0"},
  378. TLSCertBody: serverCertBytes,
  379. TLSKeyBody: serverKeyBytes,
  380. MinTLSVersion: "tls1.0",
  381. ClientCA: "./testdata/client-ca.crt",
  382. },
  383. },
  384. {
  385. name: "MutualTLS/Valid",
  386. clientCerts: []tls.Certificate{clientCert},
  387. http: Config{
  388. ListenAddr: []string{"127.0.0.1:0"},
  389. TLSCertBody: serverCertBytes,
  390. TLSKeyBody: serverKeyBytes,
  391. MinTLSVersion: "tls1.0",
  392. ClientCA: "./testdata/client-ca.crt",
  393. },
  394. },
  395. }
  396. for _, ss := range servers {
  397. t.Run(ss.name, func(t *testing.T) {
  398. s, err := NewServer(context.Background(), WithConfig(ss.http))
  399. if ss.wantErr == true {
  400. if ss.err != nil {
  401. require.ErrorIs(t, err, ss.err, "new server should return the expected error")
  402. } else {
  403. require.Error(t, err, "new server should return error for invalid TLS config")
  404. }
  405. return
  406. }
  407. require.NoError(t, err)
  408. defer func() {
  409. require.NoError(t, s.Shutdown())
  410. }()
  411. expected := []byte("secret-page")
  412. s.Router().Mount("/", testEchoHandler(expected))
  413. s.Serve()
  414. url := testGetServerURL(t, s)
  415. require.True(t, strings.HasPrefix(url, "https://"), "url should have https scheme")
  416. client := &http.Client{
  417. Transport: &http.Transport{
  418. DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
  419. dest := strings.TrimPrefix(url, "https://")
  420. dest = strings.TrimSuffix(dest, "/")
  421. return net.Dial("tcp", dest)
  422. },
  423. TLSClientConfig: &tls.Config{
  424. Certificates: ss.clientCerts,
  425. InsecureSkipVerify: true,
  426. },
  427. },
  428. }
  429. req, err := http.NewRequest("GET", "https://dev.rclone.org", nil)
  430. require.NoError(t, err)
  431. resp, err := client.Do(req)
  432. if ss.wantClientErr {
  433. require.Error(t, err, "new server client should return error")
  434. return
  435. }
  436. require.NoError(t, err)
  437. defer func() {
  438. _ = resp.Body.Close()
  439. }()
  440. require.Equal(t, http.StatusOK, resp.StatusCode, "should return ok")
  441. testExpectRespBody(t, resp, expected)
  442. })
  443. }
  444. }
  445. func TestHelpPrefixServer(t *testing.T) {
  446. // This test assumes template variables are placed correctly.
  447. const testPrefix = "server-help-test"
  448. helpMessage := Help(testPrefix)
  449. if !strings.Contains(helpMessage, testPrefix) {
  450. t.Fatal("flag prefix not found")
  451. }
  452. }