auth.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package server
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "crypto/sha256"
  6. "encoding/base64"
  7. "encoding/hex"
  8. "encoding/json"
  9. "fmt"
  10. "io"
  11. "net/http"
  12. "net/url"
  13. "strconv"
  14. "strings"
  15. "time"
  16. "github.com/ollama/ollama/api"
  17. "github.com/ollama/ollama/auth"
  18. )
  19. type registryChallenge struct {
  20. Realm string
  21. Service string
  22. Scope string
  23. }
  24. func (r registryChallenge) URL() (*url.URL, error) {
  25. redirectURL, err := url.Parse(r.Realm)
  26. if err != nil {
  27. return nil, err
  28. }
  29. values := redirectURL.Query()
  30. values.Add("service", r.Service)
  31. for _, s := range strings.Split(r.Scope, " ") {
  32. values.Add("scope", s)
  33. }
  34. values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
  35. nonce, err := auth.NewNonce(rand.Reader, 16)
  36. if err != nil {
  37. return nil, err
  38. }
  39. values.Add("nonce", nonce)
  40. redirectURL.RawQuery = values.Encode()
  41. return redirectURL, nil
  42. }
  43. func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
  44. redirectURL, err := challenge.URL()
  45. if err != nil {
  46. return "", err
  47. }
  48. sha256sum := sha256.Sum256(nil)
  49. data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
  50. headers := make(http.Header)
  51. signature, err := auth.Sign(ctx, data)
  52. if err != nil {
  53. return "", err
  54. }
  55. headers.Add("Authorization", signature)
  56. response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
  57. if err != nil {
  58. return "", err
  59. }
  60. defer response.Body.Close()
  61. body, err := io.ReadAll(response.Body)
  62. if err != nil {
  63. return "", fmt.Errorf("%d: %v", response.StatusCode, err)
  64. }
  65. if response.StatusCode >= http.StatusBadRequest {
  66. if len(body) > 0 {
  67. return "", fmt.Errorf("%d: %s", response.StatusCode, body)
  68. } else {
  69. return "", fmt.Errorf("%d", response.StatusCode)
  70. }
  71. }
  72. var token api.TokenResponse
  73. if err := json.Unmarshal(body, &token); err != nil {
  74. return "", err
  75. }
  76. return token.Token, nil
  77. }