origin_cert.go 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package credentials
  2. import (
  3. "bytes"
  4. "encoding/json"
  5. "encoding/pem"
  6. "fmt"
  7. "os"
  8. "path/filepath"
  9. "strings"
  10. "github.com/mitchellh/go-homedir"
  11. "github.com/rs/zerolog"
  12. "github.com/cloudflare/cloudflared/config"
  13. )
  14. const (
  15. DefaultCredentialFile = "cert.pem"
  16. )
  17. type OriginCert struct {
  18. ZoneID string `json:"zoneID"`
  19. AccountID string `json:"accountID"`
  20. APIToken string `json:"apiToken"`
  21. Endpoint string `json:"endpoint,omitempty"`
  22. }
  23. func (oc *OriginCert) UnmarshalJSON(data []byte) error {
  24. var aux struct {
  25. ZoneID string `json:"zoneID"`
  26. AccountID string `json:"accountID"`
  27. APIToken string `json:"apiToken"`
  28. Endpoint string `json:"endpoint,omitempty"`
  29. }
  30. if err := json.Unmarshal(data, &aux); err != nil {
  31. return fmt.Errorf("error parsing OriginCert: %v", err)
  32. }
  33. oc.ZoneID = aux.ZoneID
  34. oc.AccountID = aux.AccountID
  35. oc.APIToken = aux.APIToken
  36. oc.Endpoint = strings.ToLower(aux.Endpoint)
  37. return nil
  38. }
  39. // FindDefaultOriginCertPath returns the first path that contains a cert.pem file. If none of the
  40. // DefaultConfigSearchDirectories contains a cert.pem file, return empty string
  41. func FindDefaultOriginCertPath() string {
  42. for _, defaultConfigDir := range config.DefaultConfigSearchDirectories() {
  43. originCertPath, _ := homedir.Expand(filepath.Join(defaultConfigDir, DefaultCredentialFile))
  44. if ok := fileExists(originCertPath); ok {
  45. return originCertPath
  46. }
  47. }
  48. return ""
  49. }
  50. func DecodeOriginCert(blocks []byte) (*OriginCert, error) {
  51. return decodeOriginCert(blocks)
  52. }
  53. func (cert *OriginCert) EncodeOriginCert() ([]byte, error) {
  54. if cert == nil {
  55. return nil, fmt.Errorf("originCert cannot be nil")
  56. }
  57. buffer, err := json.Marshal(cert)
  58. if err != nil {
  59. return nil, fmt.Errorf("originCert marshal failed: %v", err)
  60. }
  61. block := pem.Block{
  62. Type: "ARGO TUNNEL TOKEN",
  63. Headers: map[string]string{},
  64. Bytes: buffer,
  65. }
  66. var out bytes.Buffer
  67. err = pem.Encode(&out, &block)
  68. if err != nil {
  69. return nil, fmt.Errorf("pem encoding failed: %v", err)
  70. }
  71. return out.Bytes(), nil
  72. }
  73. func decodeOriginCert(blocks []byte) (*OriginCert, error) {
  74. if len(blocks) == 0 {
  75. return nil, fmt.Errorf("cannot decode empty certificate")
  76. }
  77. originCert := OriginCert{}
  78. block, rest := pem.Decode(blocks)
  79. for block != nil {
  80. switch block.Type {
  81. case "PRIVATE KEY", "CERTIFICATE":
  82. // this is for legacy purposes.
  83. case "ARGO TUNNEL TOKEN":
  84. if originCert.ZoneID != "" || originCert.APIToken != "" {
  85. return nil, fmt.Errorf("found multiple tokens in the certificate")
  86. }
  87. // The token is a string,
  88. // Try the newer JSON format
  89. _ = json.Unmarshal(block.Bytes, &originCert)
  90. default:
  91. return nil, fmt.Errorf("unknown block %s in the certificate", block.Type)
  92. }
  93. block, rest = pem.Decode(rest)
  94. }
  95. if originCert.ZoneID == "" || originCert.APIToken == "" {
  96. return nil, fmt.Errorf("missing token in the certificate")
  97. }
  98. return &originCert, nil
  99. }
  100. func readOriginCert(originCertPath string) ([]byte, error) {
  101. originCert, err := os.ReadFile(originCertPath)
  102. if err != nil {
  103. return nil, fmt.Errorf("cannot read %s to load origin certificate", originCertPath)
  104. }
  105. return originCert, nil
  106. }
  107. // FindOriginCert will check to make sure that the certificate exists at the specified file path.
  108. func FindOriginCert(originCertPath string, log *zerolog.Logger) (string, error) {
  109. if originCertPath == "" {
  110. log.Error().Msgf("Cannot determine default origin certificate path. No file %s in %v. You need to specify the origin certificate path by specifying the origincert option in the configuration file, or set TUNNEL_ORIGIN_CERT environment variable", DefaultCredentialFile, config.DefaultConfigSearchDirectories())
  111. return "", fmt.Errorf("client didn't specify origincert path")
  112. }
  113. var err error
  114. originCertPath, err = homedir.Expand(originCertPath)
  115. if err != nil {
  116. log.Err(err).Msgf("Cannot resolve origin certificate path")
  117. return "", fmt.Errorf("cannot resolve path %s", originCertPath)
  118. }
  119. // Check that the user has acquired a certificate using the login command
  120. ok := fileExists(originCertPath)
  121. if !ok {
  122. log.Error().Msgf(`Cannot find a valid certificate for your origin at the path:
  123. %s
  124. If the path above is wrong, specify the path with the -origincert option.
  125. If you don't have a certificate signed by Cloudflare, run the command:
  126. cloudflared login
  127. `, originCertPath)
  128. return "", fmt.Errorf("cannot find a valid certificate at the path %s", originCertPath)
  129. }
  130. return originCertPath, nil
  131. }
  132. // FileExists checks to see if a file exist at the provided path.
  133. func fileExists(path string) bool {
  134. fileStat, err := os.Stat(path)
  135. if err != nil {
  136. return false
  137. }
  138. return !fileStat.IsDir()
  139. }