123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328 |
- package tunnelstore
- import (
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "net/url"
- "path"
- "strings"
- "time"
- "github.com/google/uuid"
- "github.com/pkg/errors"
- "github.com/cloudflare/cloudflared/logger"
- )
// defaultTimeout is the single time limit used by the REST client: it bounds
// the TLS handshake, the wait for response headers and the request as a whole.
const defaultTimeout = 15 * time.Second

// jsonContentType is the Content-Type header value attached to requests that
// carry a JSON body.
const jsonContentType = "application/json"
- var (
- ErrTunnelNameConflict = errors.New("tunnel with name already exists")
- ErrUnauthorized = errors.New("unauthorized")
- ErrBadRequest = errors.New("incorrect request parameters")
- ErrNotFound = errors.New("not found")
- )
- type Tunnel struct {
- ID uuid.UUID `json:"id"`
- Name string `json:"name"`
- CreatedAt time.Time `json:"created_at"`
- DeletedAt time.Time `json:"deleted_at"`
- Connections []Connection `json:"connections"`
- }
- type Connection struct {
- ColoName string `json:"colo_name"`
- ID uuid.UUID `json:"uuid"`
- IsPendingReconnect bool `json:"is_pending_reconnect"`
- }
// Route is a record type that can route traffic to a tunnel. Implementations
// must be JSON-serializable because the route is sent as a request body.
type Route interface {
	json.Marshaler

	// RecordType returns the short identifier of the kind of record.
	RecordType() string

	// SuccessSummary explains what will route to this tunnel once it has
	// been provisioned successfully.
	SuccessSummary() string
}
- type DNSRoute struct {
- userHostname string
- }
- func NewDNSRoute(userHostname string) Route {
- return &DNSRoute{
- userHostname: userHostname,
- }
- }
- func (dr *DNSRoute) MarshalJSON() ([]byte, error) {
- s := struct {
- Type string `json:"type"`
- UserHostname string `json:"user_hostname"`
- }{
- Type: dr.RecordType(),
- UserHostname: dr.userHostname,
- }
- return json.Marshal(&s)
- }
- func (dr *DNSRoute) RecordType() string {
- return "dns"
- }
- func (dr *DNSRoute) SuccessSummary() string {
- return fmt.Sprintf("%s will route to your tunnel", dr.userHostname)
- }
- type LBRoute struct {
- lbName string
- lbPool string
- }
- func NewLBRoute(lbName, lbPool string) Route {
- return &LBRoute{
- lbName: lbName,
- lbPool: lbPool,
- }
- }
- func (lr *LBRoute) MarshalJSON() ([]byte, error) {
- s := struct {
- Type string `json:"type"`
- LBName string `json:"lb_name"`
- LBPool string `json:"lb_pool"`
- }{
- Type: lr.RecordType(),
- LBName: lr.lbName,
- LBPool: lr.lbPool,
- }
- return json.Marshal(&s)
- }
- func (lr *LBRoute) RecordType() string {
- return "lb"
- }
- func (lr *LBRoute) SuccessSummary() string {
- return fmt.Sprintf("Load balancer %s will route to this tunnel through pool %s", lr.lbName, lr.lbPool)
- }
- type Client interface {
- CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error)
- GetTunnel(tunnelID uuid.UUID) (*Tunnel, error)
- DeleteTunnel(tunnelID uuid.UUID) error
- ListTunnels(filter *Filter) ([]*Tunnel, error)
- CleanupConnections(tunnelID uuid.UUID) error
- RouteTunnel(tunnelID uuid.UUID, route Route) error
- }
- type RESTClient struct {
- baseEndpoints *baseEndpoints
- authToken string
- userAgent string
- client http.Client
- logger logger.Service
- }
// baseEndpoints groups the base URLs that request paths are built from.
// They are stored by value, so assigning one to a local yields a copy whose
// fields can be modified without altering the stored URL.
type baseEndpoints struct {
	// accountLevel is the base URL for account-scoped tunnel operations.
	accountLevel url.URL
	// zoneLevel is the base URL for zone-scoped tunnel operations.
	zoneLevel url.URL
}
- var _ Client = (*RESTClient)(nil)
- func NewRESTClient(baseURL, accountTag, zoneTag, authToken, userAgent string, logger logger.Service) (*RESTClient, error) {
- if strings.HasSuffix(baseURL, "/") {
- baseURL = baseURL[:len(baseURL)-1]
- }
- accountLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/accounts/%s/tunnels", baseURL, accountTag))
- if err != nil {
- return nil, errors.Wrap(err, "failed to create account level endpoint")
- }
- zoneLevelEndpoint, err := url.Parse(fmt.Sprintf("%s/zones/%s/tunnels", baseURL, zoneTag))
- if err != nil {
- return nil, errors.Wrap(err, "failed to create account level endpoint")
- }
- return &RESTClient{
- baseEndpoints: &baseEndpoints{
- accountLevel: *accountLevelEndpoint,
- zoneLevel: *zoneLevelEndpoint,
- },
- authToken: authToken,
- userAgent: userAgent,
- client: http.Client{
- Transport: &http.Transport{
- TLSHandshakeTimeout: defaultTimeout,
- ResponseHeaderTimeout: defaultTimeout,
- },
- Timeout: defaultTimeout,
- },
- logger: logger,
- }, nil
- }
// newTunnel is the JSON request body sent when creating a tunnel.
type newTunnel struct {
	// Name is the name requested for the tunnel.
	Name string `json:"name"`
	// TunnelSecret is the tunnel's secret; encoding/json serializes []byte
	// values as base64 strings.
	TunnelSecret []byte `json:"tunnel_secret"`
}
- func (r *RESTClient) CreateTunnel(name string, tunnelSecret []byte) (*Tunnel, error) {
- if name == "" {
- return nil, errors.New("tunnel name required")
- }
- if _, err := uuid.Parse(name); err == nil {
- return nil, errors.New("you cannot use UUIDs as tunnel names")
- }
- body := &newTunnel{
- Name: name,
- TunnelSecret: tunnelSecret,
- }
- resp, err := r.sendRequest("POST", r.baseEndpoints.accountLevel, body)
- if err != nil {
- return nil, errors.Wrap(err, "REST request failed")
- }
- defer resp.Body.Close()
- switch resp.StatusCode {
- case http.StatusOK:
- return unmarshalTunnel(resp.Body)
- case http.StatusConflict:
- return nil, ErrTunnelNameConflict
- }
- return nil, r.statusCodeToError("create tunnel", resp)
- }
- func (r *RESTClient) GetTunnel(tunnelID uuid.UUID) (*Tunnel, error) {
- endpoint := r.baseEndpoints.accountLevel
- endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
- resp, err := r.sendRequest("GET", endpoint, nil)
- if err != nil {
- return nil, errors.Wrap(err, "REST request failed")
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusOK {
- return unmarshalTunnel(resp.Body)
- }
- return nil, r.statusCodeToError("get tunnel", resp)
- }
- func (r *RESTClient) DeleteTunnel(tunnelID uuid.UUID) error {
- endpoint := r.baseEndpoints.accountLevel
- endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v", tunnelID))
- resp, err := r.sendRequest("DELETE", endpoint, nil)
- if err != nil {
- return errors.Wrap(err, "REST request failed")
- }
- defer resp.Body.Close()
- return r.statusCodeToError("delete tunnel", resp)
- }
- func (r *RESTClient) ListTunnels(filter *Filter) ([]*Tunnel, error) {
- endpoint := r.baseEndpoints.accountLevel
- endpoint.RawQuery = filter.encode()
- resp, err := r.sendRequest("GET", endpoint, nil)
- if err != nil {
- return nil, errors.Wrap(err, "REST request failed")
- }
- defer resp.Body.Close()
- if resp.StatusCode == http.StatusOK {
- var tunnels []*Tunnel
- if err := json.NewDecoder(resp.Body).Decode(&tunnels); err != nil {
- return nil, errors.Wrap(err, "failed to decode response")
- }
- return tunnels, nil
- }
- return nil, r.statusCodeToError("list tunnels", resp)
- }
- func (r *RESTClient) CleanupConnections(tunnelID uuid.UUID) error {
- endpoint := r.baseEndpoints.accountLevel
- endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/connections", tunnelID))
- resp, err := r.sendRequest("DELETE", endpoint, nil)
- if err != nil {
- return errors.Wrap(err, "REST request failed")
- }
- defer resp.Body.Close()
- return r.statusCodeToError("cleanup connections", resp)
- }
- func (r *RESTClient) RouteTunnel(tunnelID uuid.UUID, route Route) error {
- endpoint := r.baseEndpoints.zoneLevel
- endpoint.Path = path.Join(endpoint.Path, fmt.Sprintf("%v/routes", tunnelID))
- resp, err := r.sendRequest("PUT", endpoint, route)
- if err != nil {
- return errors.Wrap(err, "REST request failed")
- }
- defer resp.Body.Close()
- return r.statusCodeToError("add route", resp)
- }
- func (r *RESTClient) sendRequest(method string, url url.URL, body interface{}) (*http.Response, error) {
- var bodyReader io.Reader
- if body != nil {
- if bodyBytes, err := json.Marshal(body); err != nil {
- return nil, errors.Wrap(err, "failed to serialize json body")
- } else {
- bodyReader = bytes.NewBuffer(bodyBytes)
- }
- }
- req, err := http.NewRequest(method, url.String(), bodyReader)
- if err != nil {
- return nil, errors.Wrapf(err, "can't create %s request", method)
- }
- req.Header.Set("User-Agent", r.userAgent)
- if bodyReader != nil {
- req.Header.Set("Content-Type", jsonContentType)
- }
- req.Header.Add("X-Auth-User-Service-Key", r.authToken)
- return r.client.Do(req)
- }
- func unmarshalTunnel(reader io.Reader) (*Tunnel, error) {
- var tunnel Tunnel
- if err := json.NewDecoder(reader).Decode(&tunnel); err != nil {
- return nil, errors.Wrap(err, "failed to decode response")
- }
- return &tunnel, nil
- }
- func (r *RESTClient) statusCodeToError(op string, resp *http.Response) error {
- if resp.Header.Get("Content-Type") == "application/json" {
- var errorsResp struct{
- Error string `json:"error"`
- }
- if json.NewDecoder(resp.Body).Decode(&errorsResp) == nil && errorsResp.Error != ""{
- return errors.Errorf("Failed to %s: %s", op, errorsResp.Error)
- }
- }
- switch resp.StatusCode {
- case http.StatusOK:
- return nil
- case http.StatusBadRequest:
- return ErrBadRequest
- case http.StatusUnauthorized, http.StatusForbidden:
- return ErrUnauthorized
- case http.StatusNotFound:
- return ErrNotFound
- }
- return errors.Errorf("API call to %s failed with status %d: %s", op,
- resp.StatusCode, http.StatusText(resp.StatusCode))
- }
|