payload.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. package llm
  2. import (
  3. "compress/gzip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "runtime"
  12. "slices"
  13. "strings"
  14. "golang.org/x/sync/errgroup"
  15. "github.com/ollama/ollama/gpu"
  16. )
  17. var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
  18. func Init() error {
  19. payloadsDir, err := gpu.PayloadsDir()
  20. if err != nil {
  21. return err
  22. }
  23. if runtime.GOOS != "windows" {
  24. slog.Info("extracting embedded files", "dir", payloadsDir)
  25. binGlob := "build/*/*/*/bin/*"
  26. // extract server libraries
  27. err = extractFiles(payloadsDir, binGlob)
  28. if err != nil {
  29. return fmt.Errorf("extract binaries: %v", err)
  30. }
  31. }
  32. var variants []string
  33. for v := range getAvailableServers() {
  34. variants = append(variants, v)
  35. }
  36. slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
  37. slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
  38. return nil
  39. }
  40. // binary names may contain an optional variant separated by '_'
  41. // For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
  42. // Any library without a variant is the lowest common denominator
  43. func getAvailableServers() map[string]string {
  44. payloadsDir, err := gpu.PayloadsDir()
  45. if err != nil {
  46. slog.Error("payload lookup error", "error", err)
  47. return nil
  48. }
  49. // glob payloadsDir for files that start with ollama_
  50. pattern := filepath.Join(payloadsDir, "*", "ollama_*")
  51. files, err := filepath.Glob(pattern)
  52. if err != nil {
  53. slog.Debug("could not glob", "pattern", pattern, "error", err)
  54. return nil
  55. }
  56. servers := make(map[string]string)
  57. for _, file := range files {
  58. slog.Debug("availableServers : found", "file", file)
  59. servers[filepath.Base(filepath.Dir(file))] = filepath.Dir(file)
  60. }
  61. return servers
  62. }
  63. // serversForGpu returns a list of compatible servers give the provided GPU
  64. // info, ordered by performance. assumes Init() has been called
  65. // TODO - switch to metadata based mapping
  66. func serversForGpu(info gpu.GpuInfo) []string {
  67. // glob workDir for files that start with ollama_
  68. availableServers := getAvailableServers()
  69. requested := info.Library
  70. if info.Variant != gpu.CPUCapabilityNone {
  71. requested += "_" + info.Variant.String()
  72. }
  73. servers := []string{}
  74. // exact match first
  75. for a := range availableServers {
  76. if a == requested {
  77. servers = []string{a}
  78. if a == "metal" {
  79. return servers
  80. }
  81. break
  82. }
  83. }
  84. alt := []string{}
  85. // Then for GPUs load alternates and sort the list for consistent load ordering
  86. if info.Library != "cpu" {
  87. for a := range availableServers {
  88. if info.Library == strings.Split(a, "_")[0] && a != requested {
  89. alt = append(alt, a)
  90. }
  91. }
  92. slices.Sort(alt)
  93. servers = append(servers, alt...)
  94. }
  95. if !(runtime.GOOS == "darwin" && runtime.GOARCH == "arm64") {
  96. // Load up the best CPU variant if not primary requested
  97. if info.Library != "cpu" {
  98. variant := gpu.GetCPUCapability()
  99. // If no variant, then we fall back to default
  100. // If we have a variant, try that if we find an exact match
  101. // Attempting to run the wrong CPU instructions will panic the
  102. // process
  103. if variant != gpu.CPUCapabilityNone {
  104. for cmp := range availableServers {
  105. if cmp == "cpu_"+variant.String() {
  106. servers = append(servers, cmp)
  107. break
  108. }
  109. }
  110. } else {
  111. servers = append(servers, "cpu")
  112. }
  113. }
  114. if len(servers) == 0 {
  115. servers = []string{"cpu"}
  116. }
  117. }
  118. return servers
  119. }
  120. // Return the optimal server for this CPU architecture
  121. func serverForCpu() string {
  122. if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
  123. return "metal"
  124. }
  125. variant := gpu.GetCPUCapability()
  126. availableServers := getAvailableServers()
  127. if variant != gpu.CPUCapabilityNone {
  128. for cmp := range availableServers {
  129. if cmp == "cpu_"+variant.String() {
  130. return cmp
  131. }
  132. }
  133. }
  134. return "cpu"
  135. }
  136. // extract extracts the embedded files to the target directory
  137. func extractFiles(targetDir string, glob string) error {
  138. files, err := fs.Glob(libEmbed, glob)
  139. if err != nil || len(files) == 0 {
  140. return errPayloadMissing
  141. }
  142. if err := os.MkdirAll(targetDir, 0o755); err != nil {
  143. return fmt.Errorf("extractFiles could not mkdir %s: %v", targetDir, err)
  144. }
  145. g := new(errgroup.Group)
  146. // build/$OS/$GOARCH/$VARIANT/{bin,lib}/$FILE
  147. for _, file := range files {
  148. filename := file
  149. variant := filepath.Base(filepath.Dir(filepath.Dir(filename)))
  150. slog.Debug("extracting", "variant", variant, "file", filename)
  151. g.Go(func() error {
  152. srcf, err := libEmbed.Open(filename)
  153. if err != nil {
  154. return err
  155. }
  156. defer srcf.Close()
  157. src := io.Reader(srcf)
  158. if strings.HasSuffix(filename, ".gz") {
  159. src, err = gzip.NewReader(src)
  160. if err != nil {
  161. return fmt.Errorf("decompress payload %s: %v", filename, err)
  162. }
  163. filename = strings.TrimSuffix(filename, ".gz")
  164. }
  165. variantDir := filepath.Join(targetDir, variant)
  166. if err := os.MkdirAll(variantDir, 0o755); err != nil {
  167. return fmt.Errorf("extractFiles could not mkdir %s: %v", variantDir, err)
  168. }
  169. base := filepath.Base(filename)
  170. destFilename := filepath.Join(variantDir, base)
  171. _, err = os.Stat(destFilename)
  172. switch {
  173. case errors.Is(err, os.ErrNotExist):
  174. destFile, err := os.OpenFile(destFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
  175. if err != nil {
  176. return fmt.Errorf("write payload %s: %v", filename, err)
  177. }
  178. defer destFile.Close()
  179. if _, err := io.Copy(destFile, src); err != nil {
  180. return fmt.Errorf("copy payload %s: %v", filename, err)
  181. }
  182. case err != nil:
  183. return fmt.Errorf("stat payload %s: %v", filename, err)
  184. }
  185. return nil
  186. })
  187. }
  188. err = g.Wait()
  189. if err != nil {
  190. // If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
  191. gpu.Cleanup()
  192. return err
  193. }
  194. return nil
  195. }