cors.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
/*
Package cors is a net/http handler that handles CORS related requests
as defined by http://www.w3.org/TR/cors/

You can configure it by passing an option struct to cors.New:

	c := cors.New(cors.Options{
		AllowedOrigins:   []string{"foo.com"},
		AllowedMethods:   []string{"GET", "POST", "DELETE"},
		AllowCredentials: true,
	})

Then insert the handler in the chain:

	handler = c.Handler(handler)

See Options documentation for more options.

The resulting handler is a standard net/http handler.
*/
  15. package cors
  16. import (
  17. "log"
  18. "net/http"
  19. "os"
  20. "strconv"
  21. "strings"
  22. "github.com/rs/xhandler"
  23. "golang.org/x/net/context"
  24. )
  // Options is a configuration container to setup the CORS middleware.
// It is consumed by New, which normalizes every list it contains.
type Options struct {
	// AllowedOrigins is a list of origins a cross-domain request can be executed from.
	// If the special "*" value is present in the list, all origins will be allowed.
	// An origin may contain a wildcard (*) to replace 0 or more characters
	// (e.g. http://*.domain.com). Usage of wildcards implies a small performance penalty.
	// Only one wildcard can be used per origin.
	// Default value is ["*"].
	AllowedOrigins []string
	// AllowOriginFunc is a custom function to validate the origin. It takes the origin
	// as argument and returns true if allowed or false otherwise. If this option is
	// set, the content of AllowedOrigins is ignored.
	AllowOriginFunc func(origin string) bool
	// AllowedMethods is a list of methods the client is allowed to use with
	// cross-domain requests. Default value is simple methods (GET and POST).
	AllowedMethods []string
	// AllowedHeaders is a list of non simple headers the client is allowed to use with
	// cross-domain requests.
	// If the special "*" value is present in the list, all headers will be allowed.
	// When the list is empty, New falls back to "Origin", "Accept" and "Content-Type".
	// When the list is not empty, "Origin" is always appended to it.
	AllowedHeaders []string
	// ExposedHeaders lists the headers advertised to the client through the
	// Access-Control-Expose-Headers response header on actual requests.
	ExposedHeaders []string
	// AllowCredentials indicates whether the request can include user credentials like
	// cookies, HTTP authentication or client side SSL certificates.
	AllowCredentials bool
	// MaxAge indicates how long (in seconds) the results of a preflight request
	// can be cached. The Access-Control-Max-Age header is only sent when the
	// value is greater than zero.
	MaxAge int
	// OptionsPassthrough instructs preflight to let other potential next handlers
	// process the OPTIONS method. Turn this on if your application handles OPTIONS.
	OptionsPassthrough bool
	// Debug adds additional output to help debugging server side CORS issues.
	// When set, New installs a logger that writes to standard output.
	Debug bool
}
  // Cors is the CORS http handler. It stores the normalized form of the Options
// given to New and is read by the Handler, HandlerC, HandlerFunc and ServeHTTP
// entry points.
type Cors struct {
	// Log is the debug logger. Debug output is skipped while it is nil.
	Log *log.Logger
	// Set to true when allowed origins contains a "*"
	allowedOriginsAll bool
	// Normalized (lower-cased) list of plain allowed origins
	allowedOrigins []string
	// List of allowed origins containing wildcards
	allowedWOrigins []wildcard
	// Optional origin validator function; when set it takes precedence over
	// the origin lists above
	allowOriginFunc func(origin string) bool
	// Set to true when allowed headers contains a "*"
	allowedHeadersAll bool
	// Normalized (canonical header key) list of allowed headers
	allowedHeaders []string
	// Normalized (upper-cased) list of allowed methods
	allowedMethods []string
	// Normalized (canonical header key) list of exposed headers
	exposedHeaders []string
	// Whether Access-Control-Allow-Credentials: true is added to responses
	allowCredentials bool
	// Value sent in Access-Control-Max-Age when greater than zero
	maxAge int
	// Whether OPTIONS requests are forwarded to the next handler
	optionPassthrough bool
}
  85. // New creates a new Cors handler with the provided options.
  86. func New(options Options) *Cors {
  87. c := &Cors{
  88. exposedHeaders: convert(options.ExposedHeaders, http.CanonicalHeaderKey),
  89. allowOriginFunc: options.AllowOriginFunc,
  90. allowCredentials: options.AllowCredentials,
  91. maxAge: options.MaxAge,
  92. optionPassthrough: options.OptionsPassthrough,
  93. }
  94. if options.Debug {
  95. c.Log = log.New(os.Stdout, "[cors] ", log.LstdFlags)
  96. }
  97. // Normalize options
  98. // Note: for origins and methods matching, the spec requires a case-sensitive matching.
  99. // As it may error prone, we chose to ignore the spec here.
  100. // Allowed Origins
  101. if len(options.AllowedOrigins) == 0 {
  102. // Default is all origins
  103. c.allowedOriginsAll = true
  104. } else {
  105. c.allowedOrigins = []string{}
  106. c.allowedWOrigins = []wildcard{}
  107. for _, origin := range options.AllowedOrigins {
  108. // Normalize
  109. origin = strings.ToLower(origin)
  110. if origin == "*" {
  111. // If "*" is present in the list, turn the whole list into a match all
  112. c.allowedOriginsAll = true
  113. c.allowedOrigins = nil
  114. c.allowedWOrigins = nil
  115. break
  116. } else if i := strings.IndexByte(origin, '*'); i >= 0 {
  117. // Split the origin in two: start and end string without the *
  118. w := wildcard{origin[0:i], origin[i+1 : len(origin)]}
  119. c.allowedWOrigins = append(c.allowedWOrigins, w)
  120. } else {
  121. c.allowedOrigins = append(c.allowedOrigins, origin)
  122. }
  123. }
  124. }
  125. // Allowed Headers
  126. if len(options.AllowedHeaders) == 0 {
  127. // Use sensible defaults
  128. c.allowedHeaders = []string{"Origin", "Accept", "Content-Type"}
  129. } else {
  130. // Origin is always appended as some browsers will always request for this header at preflight
  131. c.allowedHeaders = convert(append(options.AllowedHeaders, "Origin"), http.CanonicalHeaderKey)
  132. for _, h := range options.AllowedHeaders {
  133. if h == "*" {
  134. c.allowedHeadersAll = true
  135. c.allowedHeaders = nil
  136. break
  137. }
  138. }
  139. }
  140. // Allowed Methods
  141. if len(options.AllowedMethods) == 0 {
  142. // Default is spec's "simple" methods
  143. c.allowedMethods = []string{"GET", "POST"}
  144. } else {
  145. c.allowedMethods = convert(options.AllowedMethods, strings.ToUpper)
  146. }
  147. return c
  148. }
  149. // Default creates a new Cors handler with default options
  150. func Default() *Cors {
  151. return New(Options{})
  152. }
  153. // Handler apply the CORS specification on the request, and add relevant CORS headers
  154. // as necessary.
  155. func (c *Cors) Handler(h http.Handler) http.Handler {
  156. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  157. if r.Method == "OPTIONS" {
  158. c.logf("Handler: Preflight request")
  159. c.handlePreflight(w, r)
  160. // Preflight requests are standalone and should stop the chain as some other
  161. // middleware may not handle OPTIONS requests correctly. One typical example
  162. // is authentication middleware ; OPTIONS requests won't carry authentication
  163. // headers (see #1)
  164. if c.optionPassthrough {
  165. h.ServeHTTP(w, r)
  166. } else {
  167. w.WriteHeader(http.StatusOK)
  168. }
  169. } else {
  170. c.logf("Handler: Actual request")
  171. c.handleActualRequest(w, r)
  172. h.ServeHTTP(w, r)
  173. }
  174. })
  175. }
  176. // HandlerC is net/context aware handler
  177. func (c *Cors) HandlerC(h xhandler.HandlerC) xhandler.HandlerC {
  178. return xhandler.HandlerFuncC(func(ctx context.Context, w http.ResponseWriter, r *http.Request) {
  179. if r.Method == "OPTIONS" {
  180. c.logf("Handler: Preflight request")
  181. c.handlePreflight(w, r)
  182. // Preflight requests are standalone and should stop the chain as some other
  183. // middleware may not handle OPTIONS requests correctly. One typical example
  184. // is authentication middleware ; OPTIONS requests won't carry authentication
  185. // headers (see #1)
  186. if c.optionPassthrough {
  187. h.ServeHTTPC(ctx, w, r)
  188. } else {
  189. w.WriteHeader(http.StatusOK)
  190. }
  191. } else {
  192. c.logf("Handler: Actual request")
  193. c.handleActualRequest(w, r)
  194. h.ServeHTTPC(ctx, w, r)
  195. }
  196. })
  197. }
  198. // HandlerFunc provides Martini compatible handler
  199. func (c *Cors) HandlerFunc(w http.ResponseWriter, r *http.Request) {
  200. if r.Method == "OPTIONS" {
  201. c.logf("HandlerFunc: Preflight request")
  202. c.handlePreflight(w, r)
  203. } else {
  204. c.logf("HandlerFunc: Actual request")
  205. c.handleActualRequest(w, r)
  206. }
  207. }
  208. // Negroni compatible interface
  209. func (c *Cors) ServeHTTP(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
  210. if r.Method == "OPTIONS" {
  211. c.logf("ServeHTTP: Preflight request")
  212. c.handlePreflight(w, r)
  213. // Preflight requests are standalone and should stop the chain as some other
  214. // middleware may not handle OPTIONS requests correctly. One typical example
  215. // is authentication middleware ; OPTIONS requests won't carry authentication
  216. // headers (see #1)
  217. if c.optionPassthrough {
  218. next(w, r)
  219. } else {
  220. w.WriteHeader(http.StatusOK)
  221. }
  222. } else {
  223. c.logf("ServeHTTP: Actual request")
  224. c.handleActualRequest(w, r)
  225. next(w, r)
  226. }
  227. }
  228. // handlePreflight handles pre-flight CORS requests
  229. func (c *Cors) handlePreflight(w http.ResponseWriter, r *http.Request) {
  230. headers := w.Header()
  231. origin := r.Header.Get("Origin")
  232. if r.Method != "OPTIONS" {
  233. c.logf(" Preflight aborted: %s!=OPTIONS", r.Method)
  234. return
  235. }
  236. // Always set Vary headers
  237. // see https://github.com/rs/cors/issues/10,
  238. // https://github.com/rs/cors/commit/dbdca4d95feaa7511a46e6f1efb3b3aa505bc43f#commitcomment-12352001
  239. headers.Add("Vary", "Origin")
  240. headers.Add("Vary", "Access-Control-Request-Method")
  241. headers.Add("Vary", "Access-Control-Request-Headers")
  242. if origin == "" {
  243. c.logf(" Preflight aborted: empty origin")
  244. return
  245. }
  246. if !c.isOriginAllowed(origin) {
  247. c.logf(" Preflight aborted: origin '%s' not allowed", origin)
  248. return
  249. }
  250. reqMethod := r.Header.Get("Access-Control-Request-Method")
  251. if !c.isMethodAllowed(reqMethod) {
  252. c.logf(" Preflight aborted: method '%s' not allowed", reqMethod)
  253. return
  254. }
  255. reqHeaders := parseHeaderList(r.Header.Get("Access-Control-Request-Headers"))
  256. if !c.areHeadersAllowed(reqHeaders) {
  257. c.logf(" Preflight aborted: headers '%v' not allowed", reqHeaders)
  258. return
  259. }
  260. headers.Set("Access-Control-Allow-Origin", origin)
  261. // Spec says: Since the list of methods can be unbounded, simply returning the method indicated
  262. // by Access-Control-Request-Method (if supported) can be enough
  263. headers.Set("Access-Control-Allow-Methods", strings.ToUpper(reqMethod))
  264. if len(reqHeaders) > 0 {
  265. // Spec says: Since the list of headers can be unbounded, simply returning supported headers
  266. // from Access-Control-Request-Headers can be enough
  267. headers.Set("Access-Control-Allow-Headers", strings.Join(reqHeaders, ", "))
  268. }
  269. if c.allowCredentials {
  270. headers.Set("Access-Control-Allow-Credentials", "true")
  271. }
  272. if c.maxAge > 0 {
  273. headers.Set("Access-Control-Max-Age", strconv.Itoa(c.maxAge))
  274. }
  275. c.logf(" Preflight response headers: %v", headers)
  276. }
  277. // handleActualRequest handles simple cross-origin requests, actual request or redirects
  278. func (c *Cors) handleActualRequest(w http.ResponseWriter, r *http.Request) {
  279. headers := w.Header()
  280. origin := r.Header.Get("Origin")
  281. if r.Method == "OPTIONS" {
  282. c.logf(" Actual request no headers added: method == %s", r.Method)
  283. return
  284. }
  285. // Always set Vary, see https://github.com/rs/cors/issues/10
  286. headers.Add("Vary", "Origin")
  287. if origin == "" {
  288. c.logf(" Actual request no headers added: missing origin")
  289. return
  290. }
  291. if !c.isOriginAllowed(origin) {
  292. c.logf(" Actual request no headers added: origin '%s' not allowed", origin)
  293. return
  294. }
  295. // Note that spec does define a way to specifically disallow a simple method like GET or
  296. // POST. Access-Control-Allow-Methods is only used for pre-flight requests and the
  297. // spec doesn't instruct to check the allowed methods for simple cross-origin requests.
  298. // We think it's a nice feature to be able to have control on those methods though.
  299. if !c.isMethodAllowed(r.Method) {
  300. c.logf(" Actual request no headers added: method '%s' not allowed", r.Method)
  301. return
  302. }
  303. headers.Set("Access-Control-Allow-Origin", origin)
  304. if len(c.exposedHeaders) > 0 {
  305. headers.Set("Access-Control-Expose-Headers", strings.Join(c.exposedHeaders, ", "))
  306. }
  307. if c.allowCredentials {
  308. headers.Set("Access-Control-Allow-Credentials", "true")
  309. }
  310. c.logf(" Actual response added headers: %v", headers)
  311. }
  312. // convenience method. checks if debugging is turned on before printing
  313. func (c *Cors) logf(format string, a ...interface{}) {
  314. if c.Log != nil {
  315. c.Log.Printf(format, a...)
  316. }
  317. }
  318. // isOriginAllowed checks if a given origin is allowed to perform cross-domain requests
  319. // on the endpoint
  320. func (c *Cors) isOriginAllowed(origin string) bool {
  321. if c.allowOriginFunc != nil {
  322. return c.allowOriginFunc(origin)
  323. }
  324. if c.allowedOriginsAll {
  325. return true
  326. }
  327. origin = strings.ToLower(origin)
  328. for _, o := range c.allowedOrigins {
  329. if o == origin {
  330. return true
  331. }
  332. }
  333. for _, w := range c.allowedWOrigins {
  334. if w.match(origin) {
  335. return true
  336. }
  337. }
  338. return false
  339. }
  340. // isMethodAllowed checks if a given method can be used as part of a cross-domain request
  341. // on the endpoing
  342. func (c *Cors) isMethodAllowed(method string) bool {
  343. if len(c.allowedMethods) == 0 {
  344. // If no method allowed, always return false, even for preflight request
  345. return false
  346. }
  347. method = strings.ToUpper(method)
  348. if method == "OPTIONS" {
  349. // Always allow preflight requests
  350. return true
  351. }
  352. for _, m := range c.allowedMethods {
  353. if m == method {
  354. return true
  355. }
  356. }
  357. return false
  358. }
  359. // areHeadersAllowed checks if a given list of headers are allowed to used within
  360. // a cross-domain request.
  361. func (c *Cors) areHeadersAllowed(requestedHeaders []string) bool {
  362. if c.allowedHeadersAll || len(requestedHeaders) == 0 {
  363. return true
  364. }
  365. for _, header := range requestedHeaders {
  366. header = http.CanonicalHeaderKey(header)
  367. found := false
  368. for _, h := range c.allowedHeaders {
  369. if h == header {
  370. found = true
  371. }
  372. }
  373. if !found {
  374. return false
  375. }
  376. }
  377. return true
  378. }