client.go 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. package tunnelstore
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "net/http"
  8. "net/url"
  9. "path"
  10. "strings"
  11. "time"
  12. "github.com/google/uuid"
  13. "github.com/pkg/errors"
  14. "github.com/cloudflare/cloudflared/logger"
  15. )
  // Settings shared by every request the REST client makes.
  const (
  	// defaultTimeout is applied by NewRESTClient to the TLS handshake, to the
  	// wait for response headers, and to each request as a whole.
  	defaultTimeout = 15 * time.Second

  	// jsonContentType is the media type used for JSON request bodies.
  	jsonContentType = "application/json"
  )
  // Sentinel errors returned by RESTClient methods.
  var (
  	// ErrTunnelNameConflict is returned by CreateTunnel when the API answers
  	// 409 Conflict.
  	ErrTunnelNameConflict = errors.New("tunnel with name already exists")

  	// ErrUnauthorized is returned for 401 Unauthorized and 403 Forbidden
  	// responses that carry no JSON error message.
  	ErrUnauthorized = errors.New("unauthorized")

  	// ErrBadRequest is returned for 400 Bad Request responses that carry no
  	// JSON error message.
  	ErrBadRequest = errors.New("incorrect request parameters")

  	// ErrNotFound is returned for 404 Not Found responses that carry no JSON
  	// error message.
  	ErrNotFound = errors.New("not found")
  )
  26. type Tunnel struct {
  27. ID uuid.UUID `json:"id"`
  28. Name string `json:"name"`
  29. CreatedAt time.Time `json:"created_at"`
  30. DeletedAt time.Time `json:"deleted_at"`
  31. Connections []Connection `json:"connections"`
  32. }
  33. type Connection struct {
  34. ColoName string `json:"colo_name"`
  35. ID uuid.UUID `json:"uuid"`
  36. IsPendingReconnect bool `json:"is_pending_reconnect"`
  37. }
  // Route represents a record type that can route to a tunnel. A Route is
  // JSON-encoded (through its MarshalJSON method) as the request body of
  // RESTClient.RouteTunnel.
  type Route interface {
  	json.Marshaler

  	// RecordType names the kind of route, e.g. "dns" or "lb".
  	RecordType() string

  	// SuccessSummary explains what will route to this tunnel when it's
  	// provisioned successfully.
  	SuccessSummary() string
  }
  45. type DNSRoute struct {
  46. userHostname string
  47. }
  48. func NewDNSRoute(userHostname string) Route {
  49. return &DNSRoute{
  50. userHostname: userHostname,
  51. }
  52. }
  53. func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
  54. s := struct {
  55. Type string `json:"type"`
  56. UserHostname string `json:"user_hostname"`
  57. }{
  58. Type: dr.RecordType(),
  59. UserHostname: dr.userHostname,
  60. }
  61. return json.Marshal(&s)
  62. }
  63. func (dr *DNSRoute) RecordType() string {
  64. return "dns"
  65. }
  66. func (dr *DNSRoute) SuccessSummary() string {
  67. return fmt.Sprintf("%s will route to your tunnel", dr.userHostname)
  68. }
  69. type LBRoute struct {
  70. lbName string
  71. lbPool string
  72. }
  73. func NewLBRoute(lbName, lbPool string) Route {
  74. return &LBRoute{
  75. lbName: lbName,
  76. lbPool: lbPool,
  77. }
  78. }
  79. func (lr *LBRoute) MarshalJSON() ([]byte, error) {
  80. s := struct {
  81. Type string `json:"type"`
  82. LBName string `json:"lb_name"`
  83. LBPool string `json:"lb_pool"`
  84. }{
  85. Type: lr.RecordType(),
  86. LBName: lr.lbName,
  87. LBPool: lr.lbPool,
  88. }
  89. return json.Marshal(&s)
  90. }
  91. func (lr *LBRoute) RecordType() string {
  92. return "lb"
  93. }
  94. func (lr *LBRoute) SuccessSummary() string {
  95. return fmt.Sprintf("Load balancer %s will route to this tunnel through pool %s", lr.lbName, lr.lbPool)
  96. }
  97. type Client interface {
  98. CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
  99. GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
  100. DeleteTunnel(tunnelID uuid.UUID) error
  101. ListTunnels(filter *Filter) ([]*Tunnel, error)
  102. CleanupConnections(tunnelID uuid.UUID) error
  103. RouteTunnel(tunnelID uuid.UUID, route Route) error
  104. }
  105. type RESTClient struct {
  106. baseEndpoints *baseEndpoints
  107. authToken string
  108. userAgent string
  109. client http.Client
  110. logger logger.Service
  111. }
  112. type baseEndpoints struct {
  113. accountLevel url.URL
  114. zoneLevel url.URL
  115. }
  116. var _ Client = (*RESTClient)(nil)
  117. func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, logger logger.Service) (*RESTClient, error) {
  118. if strings.HasSuffix(baseURL, "/") {
  119. baseURL = baseURL[:len(baseURL)-1]
  120. }
  121. accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag))
  122. if err != nil {
  123. return nil, errors.Wrap(err, "failed to create account level endpoint")
  124. }
  125. zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
  126. if err != nil {
  127. return nil, errors.Wrap(err, "failed to create account level endpoint")
  128. }
  129. return &RESTClient{
  130. baseEndpoints: &baseEndpoints{
  131. accountLevel: *accountLevelEndpoint,
  132. zoneLevel: *zoneLevelEndpoint,
  133. },
  134. authToken: authToken,
  135. userAgent: userAgent,
  136. client: http.Client{
  137. Transport: &http.Transport{
  138. TLSHandshakeTimeout: defaultTimeout,
  139. ResponseHeaderTimeout: defaultTimeout,
  140. },
  141. Timeout: defaultTimeout,
  142. },
  143. logger: logger,
  144. }, nil
  145. }
  // newTunnel is the JSON request body sent by CreateTunnel.
  type newTunnel struct {
  	Name string `json:"name"`
  	// TunnelSecret is a []byte, which encoding/json emits as a base64 string.
  	TunnelSecret []byte `json:"tunnel_secret"`
  }
  150. func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
  151. if name == "" {
  152. return nil, errors.New("tunnel name required")
  153. }
  154. if _, err := uuid.Parse(name); err == nil {
  155. return nil, errors.New("you cannot use UUIDs as tunnel names")
  156. }
  157. body := &newTunnel{
  158. Name: name,
  159. TunnelSecret: tunnelSecret,
  160. }
  161. resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, body)
  162. if err != nil {
  163. return nil, errors.Wrap(err, "REST request failed")
  164. }
  165. defer resp.Body.Close()
  166. switch resp.StatusCode {
  167. case http.StatusOK:
  168. return unmarshalTunnel(resp.Body)
  169. case http.StatusConflict:
  170. return nil, ErrTunnelNameConflict
  171. }
  172. return nil, r.statusCodeToError("create tunnel", resp)
  173. }
  174. func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
  175. endpoint := r.baseEndpoints.accountLevel
  176. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
  177. resp, err := r.sendRequest("GET", endpoint, nil)
  178. if err != nil {
  179. return nil, errors.Wrap(err, "REST request failed")
  180. }
  181. defer resp.Body.Close()
  182. if resp.StatusCode == http.StatusOK {
  183. return unmarshalTunnel(resp.Body)
  184. }
  185. return nil, r.statusCodeToError("get tunnel", resp)
  186. }
  187. func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
  188. endpoint := r.baseEndpoints.accountLevel
  189. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
  190. resp, err := r.sendRequest("DELETE", endpoint, nil)
  191. if err != nil {
  192. return errors.Wrap(err, "REST request failed")
  193. }
  194. defer resp.Body.Close()
  195. return r.statusCodeToError("delete tunnel", resp)
  196. }
  197. func (r *RESTClient) ListTunnels(filter *Filter) ([]*Tunnel, error) {
  198. endpoint := r.baseEndpoints.accountLevel
  199. endpoint.RawQuery = filter.encode()
  200. resp, err := r.sendRequest("GET", endpoint, nil)
  201. if err != nil {
  202. return nil, errors.Wrap(err, "REST request failed")
  203. }
  204. defer resp.Body.Close()
  205. if resp.StatusCode == http.StatusOK {
  206. var tunnels []*Tunnel
  207. if err := json.NewDecoder(resp.Body).Decode(&tunnels); err != nil {
  208. return nil, errors.Wrap(err, "failed to decode response")
  209. }
  210. return tunnels, nil
  211. }
  212. return nil, r.statusCodeToError("list tunnels", resp)
  213. }
  214. func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
  215. endpoint := r.baseEndpoints.accountLevel
  216. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID))
  217. resp, err := r.sendRequest("DELETE", endpoint, nil)
  218. if err != nil {
  219. return errors.Wrap(err, "REST request failed")
  220. }
  221. defer resp.Body.Close()
  222. return r.statusCodeToError("cleanup connections", resp)
  223. }
  224. func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error {
  225. endpoint := r.baseEndpoints.zoneLevel
  226. endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
  227. resp, err := r.sendRequest("PUT", endpoint, route)
  228. if err != nil {
  229. return errors.Wrap(err, "REST request failed")
  230. }
  231. defer resp.Body.Close()
  232. return r.statusCodeToError("add route", resp)
  233. }
  234. func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) {
  235. var bodyReader io.Reader
  236. if body != nil {
  237. if bodyBytes, err := json.Marshal(body); err != nil {
  238. return nil, errors.Wrap(err, "failed to serialize json body")
  239. } else {
  240. bodyReader = bytes.NewBuffer(bodyBytes)
  241. }
  242. }
  243. req, err := http.NewRequest(method, url.String(), bodyReader)
  244. if err != nil {
  245. return nil, errors.Wrapf(err, "can't create %s request", method)
  246. }
  247. req.Header.Set("User-Agent", r.userAgent)
  248. if bodyReader != nil {
  249. req.Header.Set("Content-Type", jsonContentType)
  250. }
  251. req.Header.Add("X-Auth-User-Service-Key", r.authToken)
  252. return r.client.Do(req)
  253. }
  254. func unmarshalTunnel(reader io.Reader) (*Tunnel, error) {
  255. var tunnel Tunnel
  256. if err := json.NewDecoder(reader).Decode(&tunnel); err != nil {
  257. return nil, errors.Wrap(err, "failed to decode response")
  258. }
  259. return &tunnel, nil
  260. }
  261. func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
  262. if resp.Header.Get("Content-Type") == "application/json" {
  263. var errorsResp struct{
  264. Error string `json:"error"`
  265. }
  266. if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil && errorsResp.Error != ""{
  267. return errors.Errorf("Failed to %s: %s", op, errorsResp.Error)
  268. }
  269. }
  270. switch resp.StatusCode {
  271. case http.StatusOK:
  272. return nil
  273. case http.StatusBadRequest:
  274. return ErrBadRequest
  275. case http.StatusUnauthorized, http.StatusForbidden:
  276. return ErrUnauthorized
  277. case http.StatusNotFound:
  278. return ErrNotFound
  279. }
  280. return errors.Errorf("API call to %s failed with status %d: %s", op,
  281. resp.StatusCode, http.StatusText(resp.StatusCode))
  282. }