validation.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. package validation
  2. import (
  3. "context"
  4. "fmt"
  5. "net"
  6. "net/http"
  7. "net/url"
  8. "strings"
  9. "time"
  10. "github.com/pkg/errors"
  11. "golang.org/x/net/idna"
  12. "gopkg.in/coreos/go-oidc.v2"
  13. )
  14. const (
  15. defaultScheme = "http"
  16. accessDomain = "cloudflareaccess.com"
  17. accessCertPath = "/cdn-cgi/access/certs"
  18. accessJwtHeader = "Cf-access-jwt-assertion"
  19. )
  20. var (
  21. supportedProtocols = []string{"http", "https", "rdp", "ssh", "smb", "tcp"}
  22. validationTimeout = time.Duration(30 * time.Second)
  23. )
  24. func ValidateHostname(hostname string) (string, error) {
  25. if hostname == "" {
  26. return "", nil
  27. }
  28. // users gives url(contains schema) not just hostname
  29. if strings.Contains(hostname, ":") || strings.Contains(hostname, "%3A") {
  30. unescapeHostname, err := url.PathUnescape(hostname)
  31. if err != nil {
  32. return "", fmt.Errorf("Hostname(actually a URL) %s has invalid escape characters %s", hostname, unescapeHostname)
  33. }
  34. hostnameToURL, err := url.Parse(unescapeHostname)
  35. if err != nil {
  36. return "", fmt.Errorf("Hostname(actually a URL) %s has invalid format %s", hostname, hostnameToURL)
  37. }
  38. asciiHostname, err := idna.ToASCII(hostnameToURL.Hostname())
  39. if err != nil {
  40. return "", fmt.Errorf("Hostname(actually a URL) %s has invalid ASCII encdoing %s", hostname, asciiHostname)
  41. }
  42. return asciiHostname, nil
  43. }
  44. asciiHostname, err := idna.ToASCII(hostname)
  45. if err != nil {
  46. return "", fmt.Errorf("Hostname %s has invalid ASCII encdoing %s", hostname, asciiHostname)
  47. }
  48. hostnameToURL, err := url.Parse(asciiHostname)
  49. if err != nil {
  50. return "", fmt.Errorf("Hostname %s is not valid", hostnameToURL)
  51. }
  52. return hostnameToURL.RequestURI(), nil
  53. }
  54. // ValidateUrl returns a validated version of `originUrl` with a scheme prepended (by default http://).
  55. // Note: when originUrl contains a scheme, the path is removed:
  56. // ValidateUrl("https://localhost:8080/api/") => "https://localhost:8080"
  57. // but when it does not, the path is preserved:
  58. // ValidateUrl("localhost:8080/api/") => "http://localhost:8080/api/"
  59. // This is arguably a bug, but changing it might break some cloudflared users.
  60. func ValidateUrl(originUrl string) (*url.URL, error) {
  61. urlStr, err := validateUrlString(originUrl)
  62. if err != nil {
  63. return nil, err
  64. }
  65. return url.Parse(urlStr)
  66. }
  67. func validateUrlString(originUrl string) (string, error) {
  68. if originUrl == "" {
  69. return "", fmt.Errorf("URL should not be empty")
  70. }
  71. if net.ParseIP(originUrl) != nil {
  72. return validateIP("", originUrl, "")
  73. } else if strings.HasPrefix(originUrl, "[") && strings.HasSuffix(originUrl, "]") {
  74. // ParseIP doesn't recoginze [::1]
  75. return validateIP("", originUrl[1:len(originUrl)-1], "")
  76. }
  77. host, port, err := net.SplitHostPort(originUrl)
  78. // user might pass in an ip address like 127.0.0.1
  79. if err == nil && net.ParseIP(host) != nil {
  80. return validateIP("", host, port)
  81. }
  82. unescapedUrl, err := url.PathUnescape(originUrl)
  83. if err != nil {
  84. return "", fmt.Errorf("URL %s has invalid escape characters %s", originUrl, unescapedUrl)
  85. }
  86. parsedUrl, err := url.Parse(unescapedUrl)
  87. if err != nil {
  88. return "", fmt.Errorf("URL %s has invalid format", originUrl)
  89. }
  90. // if the url is in the form of host:port, IsAbs() will think host is the schema
  91. var hostname string
  92. hasScheme := parsedUrl.IsAbs() && parsedUrl.Host != ""
  93. if hasScheme {
  94. err := validateScheme(parsedUrl.Scheme)
  95. if err != nil {
  96. return "", err
  97. }
  98. // The earlier check for ip address will miss the case http://[::1]
  99. // and http://[::1]:8080
  100. if net.ParseIP(parsedUrl.Hostname()) != nil {
  101. return validateIP(parsedUrl.Scheme, parsedUrl.Hostname(), parsedUrl.Port())
  102. }
  103. hostname, err = ValidateHostname(parsedUrl.Hostname())
  104. if err != nil {
  105. return "", fmt.Errorf("URL %s has invalid format", originUrl)
  106. }
  107. if parsedUrl.Port() != "" {
  108. return fmt.Sprintf("%s://%s", parsedUrl.Scheme, net.JoinHostPort(hostname, parsedUrl.Port())), nil
  109. }
  110. return fmt.Sprintf("%s://%s", parsedUrl.Scheme, hostname), nil
  111. } else {
  112. if host == "" {
  113. hostname, err = ValidateHostname(originUrl)
  114. if err != nil {
  115. return "", fmt.Errorf("URL no %s has invalid format", originUrl)
  116. }
  117. return fmt.Sprintf("%s://%s", defaultScheme, hostname), nil
  118. } else {
  119. hostname, err = ValidateHostname(host)
  120. if err != nil {
  121. return "", fmt.Errorf("URL %s has invalid format", originUrl)
  122. }
  123. // This is why the path is preserved when `originUrl` doesn't have a schema.
  124. // Using `parsedUrl.Port()` here, instead of `port`, would remove the path
  125. return fmt.Sprintf("%s://%s", defaultScheme, net.JoinHostPort(hostname, port)), nil
  126. }
  127. }
  128. }
  129. func validateScheme(scheme string) error {
  130. for _, protocol := range supportedProtocols {
  131. if scheme == protocol {
  132. return nil
  133. }
  134. }
  135. return fmt.Errorf("Currently Cloudflare Tunnel does not support %s protocol.", scheme)
  136. }
  137. func validateIP(scheme, host, port string) (string, error) {
  138. if scheme == "" {
  139. scheme = defaultScheme
  140. }
  141. if port != "" {
  142. return fmt.Sprintf("%s://%s", scheme, net.JoinHostPort(host, port)), nil
  143. } else if strings.Contains(host, ":") {
  144. // IPv6
  145. return fmt.Sprintf("%s://[%s]", scheme, host), nil
  146. }
  147. return fmt.Sprintf("%s://%s", scheme, host), nil
  148. }
  149. // originURL shouldn't be a pointer, because this function might change the scheme
  150. func ValidateHTTPService(originURL string, hostname string, transport http.RoundTripper) error {
  151. parsedURL, err := url.Parse(originURL)
  152. if err != nil {
  153. return err
  154. }
  155. client := &http.Client{
  156. Transport: transport,
  157. CheckRedirect: func(req *http.Request, via []*http.Request) error {
  158. return http.ErrUseLastResponse
  159. },
  160. Timeout: validationTimeout,
  161. }
  162. initialRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
  163. if err != nil {
  164. return err
  165. }
  166. initialRequest.Host = hostname
  167. resp, initialErr := client.Do(initialRequest)
  168. if initialErr == nil {
  169. resp.Body.Close()
  170. return nil
  171. }
  172. // Attempt the same endpoint via the other protocol (http/https); maybe we have better luck?
  173. oldScheme := parsedURL.Scheme
  174. parsedURL.Scheme = toggleProtocol(oldScheme)
  175. secondRequest, err := http.NewRequest("GET", parsedURL.String(), nil)
  176. if err != nil {
  177. return err
  178. }
  179. secondRequest.Host = hostname
  180. resp, secondErr := client.Do(secondRequest)
  181. if secondErr == nil { // Worked this time--advise the user to switch protocols
  182. _ = resp.Body.Close()
  183. return errors.Errorf(
  184. "%s doesn't seem to work over %s, but does seem to work over %s. Reason: %v. Consider changing the origin URL to %v",
  185. parsedURL.Host,
  186. oldScheme,
  187. parsedURL.Scheme,
  188. initialErr,
  189. originURL,
  190. )
  191. }
  192. return initialErr
  193. }
  194. func toggleProtocol(httpProtocol string) string {
  195. switch httpProtocol {
  196. case "http":
  197. return "https"
  198. case "https":
  199. return "http"
  200. default:
  201. return httpProtocol
  202. }
  203. }
// Access checks if a JWT from Cloudflare Access is valid.
type Access struct {
	// verifier is the OIDC ID-token verifier that Validate delegates to.
	verifier *oidc.IDTokenVerifier
}
  208. func NewAccessValidator(ctx context.Context, domain, issuer, applicationAUD string) (*Access, error) {
  209. domainURL, err := validateUrlString(domain)
  210. if err != nil {
  211. return nil, err
  212. }
  213. issuerURL, err := validateUrlString(issuer)
  214. if err != nil {
  215. return nil, err
  216. }
  217. // An issuerURL from Cloudflare Access will always use HTTPS.
  218. issuerURL = strings.Replace(issuerURL, "http:", "https:", 1)
  219. keySet := oidc.NewRemoteKeySet(ctx, domainURL+accessCertPath)
  220. return &Access{oidc.NewVerifier(issuerURL, keySet, &oidc.Config{ClientID: applicationAUD})}, nil
  221. }
  222. func (a *Access) Validate(ctx context.Context, jwt string) error {
  223. token, err := a.verifier.Verify(ctx, jwt)
  224. if err != nil {
  225. return errors.Wrapf(err, "token is invalid: %s", jwt)
  226. }
  227. // Perform extra sanity checks, just to be safe.
  228. if token == nil {
  229. return fmt.Errorf("token is nil: %s", jwt)
  230. }
  231. if !strings.HasSuffix(token.Issuer, accessDomain) {
  232. return fmt.Errorf("token has non-cloudflare issuer of %s: %s", token.Issuer, jwt)
  233. }
  234. return nil
  235. }
  236. func (a *Access) ValidateRequest(ctx context.Context, r *http.Request) error {
  237. return a.Validate(ctx, r.Header.Get(accessJwtHeader))
  238. }