utils_test.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. //go:build integration
  2. package integration
  3. import (
  4. "bytes"
  5. "context"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "log/slog"
  10. "math/rand"
  11. "net"
  12. "net/http"
  13. "net/url"
  14. "os"
  15. "path/filepath"
  16. "runtime"
  17. "strconv"
  18. "strings"
  19. "sync"
  20. "testing"
  21. "time"
  22. "github.com/ollama/ollama/api"
  23. "github.com/ollama/ollama/app/lifecycle"
  24. "github.com/stretchr/testify/require"
  25. )
  26. func Init() {
  27. lifecycle.InitLogging()
  28. }
  29. func FindPort() string {
  30. port := 0
  31. if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
  32. var l *net.TCPListener
  33. if l, err = net.ListenTCP("tcp", a); err == nil {
  34. port = l.Addr().(*net.TCPAddr).Port
  35. l.Close()
  36. }
  37. }
  38. if port == 0 {
  39. port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
  40. }
  41. return strconv.Itoa(port)
  42. }
  43. func GetTestEndpoint() (*api.Client, string) {
  44. defaultPort := "11434"
  45. ollamaHost := os.Getenv("OLLAMA_HOST")
  46. scheme, hostport, ok := strings.Cut(ollamaHost, "://")
  47. if !ok {
  48. scheme, hostport = "http", ollamaHost
  49. }
  50. // trim trailing slashes
  51. hostport = strings.TrimRight(hostport, "/")
  52. host, port, err := net.SplitHostPort(hostport)
  53. if err != nil {
  54. host, port = "127.0.0.1", defaultPort
  55. if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
  56. host = ip.String()
  57. } else if hostport != "" {
  58. host = hostport
  59. }
  60. }
  61. if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
  62. port = FindPort()
  63. }
  64. slog.Info("server connection", "host", host, "port", port)
  65. return api.NewClient(
  66. &url.URL{
  67. Scheme: scheme,
  68. Host: net.JoinHostPort(host, port),
  69. },
  70. http.DefaultClient), fmt.Sprintf("%s:%s", host, port)
  71. }
  72. var serverMutex sync.Mutex
  73. var serverReady bool
  74. func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
  75. // Make sure the server has been built
  76. CLIName, err := filepath.Abs("../ollama")
  77. if err != nil {
  78. return err
  79. }
  80. if runtime.GOOS == "windows" {
  81. CLIName += ".exe"
  82. }
  83. _, err = os.Stat(CLIName)
  84. if err != nil {
  85. return fmt.Errorf("CLI missing, did you forget to build first? %w", err)
  86. }
  87. serverMutex.Lock()
  88. defer serverMutex.Unlock()
  89. if serverReady {
  90. return nil
  91. }
  92. if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
  93. slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
  94. t.Setenv("OLLAMA_HOST", ollamaHost)
  95. }
  96. slog.Info("starting server", "url", ollamaHost)
  97. done, err := lifecycle.SpawnServer(ctx, "../ollama")
  98. if err != nil {
  99. return fmt.Errorf("failed to start server: %w", err)
  100. }
  101. go func() {
  102. <-ctx.Done()
  103. serverMutex.Lock()
  104. defer serverMutex.Unlock()
  105. exitCode := <-done
  106. if exitCode > 0 {
  107. slog.Warn("server failure", "exit", exitCode)
  108. }
  109. serverReady = false
  110. }()
  111. // TODO wait only long enough for the server to be responsive...
  112. time.Sleep(500 * time.Millisecond)
  113. serverReady = true
  114. return nil
  115. }
  116. func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
  117. slog.Info("checking status of model", "model", modelName)
  118. showReq := &api.ShowRequest{Name: modelName}
  119. showCtx, cancel := context.WithDeadlineCause(
  120. ctx,
  121. time.Now().Add(10*time.Second),
  122. fmt.Errorf("show for existing model %s took too long", modelName),
  123. )
  124. defer cancel()
  125. _, err := client.Show(showCtx, showReq)
  126. var statusError api.StatusError
  127. switch {
  128. case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
  129. break
  130. case err != nil:
  131. return err
  132. default:
  133. slog.Info("model already present", "model", modelName)
  134. return nil
  135. }
  136. slog.Info("model missing", "model", modelName)
  137. stallDuration := 30 * time.Second // This includes checksum verification, which can take a while on larger models
  138. stallTimer := time.NewTimer(stallDuration)
  139. fn := func(resp api.ProgressResponse) error {
  140. // fmt.Print(".")
  141. if !stallTimer.Reset(stallDuration) {
  142. return fmt.Errorf("stall was detected, aborting status reporting")
  143. }
  144. return nil
  145. }
  146. stream := true
  147. pullReq := &api.PullRequest{Name: modelName, Stream: &stream}
  148. var pullError error
  149. done := make(chan int)
  150. go func() {
  151. pullError = client.Pull(ctx, pullReq, fn)
  152. done <- 0
  153. }()
  154. select {
  155. case <-stallTimer.C:
  156. return fmt.Errorf("download stalled")
  157. case <-done:
  158. return pullError
  159. }
  160. }
  161. var serverProcMutex sync.Mutex
  162. // Returns an Client, the testEndpoint, and a cleanup function, fails the test on errors
  163. // Starts the server if needed
  164. func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, string, func()) {
  165. client, testEndpoint := GetTestEndpoint()
  166. if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
  167. serverProcMutex.Lock()
  168. fp, err := os.CreateTemp("", "ollama-server-*.log")
  169. if err != nil {
  170. t.Fatalf("failed to generate log file: %s", err)
  171. }
  172. lifecycle.ServerLogFile = fp.Name()
  173. fp.Close()
  174. require.NoError(t, startServer(t, ctx, testEndpoint))
  175. }
  176. return client, testEndpoint, func() {
  177. if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
  178. defer serverProcMutex.Unlock()
  179. if t.Failed() {
  180. fp, err := os.Open(lifecycle.ServerLogFile)
  181. if err != nil {
  182. slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
  183. return
  184. }
  185. data, err := io.ReadAll(fp)
  186. if err != nil {
  187. slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
  188. return
  189. }
  190. slog.Warn("SERVER LOG FOLLOWS")
  191. os.Stderr.Write(data)
  192. slog.Warn("END OF SERVER")
  193. }
  194. err := os.Remove(lifecycle.ServerLogFile)
  195. if err != nil && !os.IsNotExist(err) {
  196. slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
  197. }
  198. }
  199. }
  200. }
  201. func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
  202. client, _, cleanup := InitServerConnection(ctx, t)
  203. defer cleanup()
  204. require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
  205. DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
  206. }
  207. func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
  208. stallTimer := time.NewTimer(initialTimeout)
  209. var buf bytes.Buffer
  210. fn := func(response api.GenerateResponse) error {
  211. // fmt.Print(".")
  212. buf.Write([]byte(response.Response))
  213. if !stallTimer.Reset(streamTimeout) {
  214. return fmt.Errorf("stall was detected while streaming response, aborting")
  215. }
  216. return nil
  217. }
  218. stream := true
  219. genReq.Stream = &stream
  220. done := make(chan int)
  221. var genErr error
  222. go func() {
  223. genErr = client.Generate(ctx, &genReq, fn)
  224. done <- 0
  225. }()
  226. select {
  227. case <-stallTimer.C:
  228. if buf.Len() == 0 {
  229. t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
  230. } else {
  231. t.Errorf("generate stalled. Response so far:%s", buf.String())
  232. }
  233. case <-done:
  234. require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
  235. // Verify the response contains the expected data
  236. response := buf.String()
  237. atLeastOne := false
  238. for _, resp := range anyResp {
  239. if strings.Contains(strings.ToLower(response), resp) {
  240. atLeastOne = true
  241. break
  242. }
  243. }
  244. require.True(t, atLeastOne, "none of %v found in %s", anyResp, response)
  245. slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
  246. case <-ctx.Done():
  247. t.Error("outer test context done while waiting for generate")
  248. }
  249. }
  250. // Generate a set of requests
  251. // By default each request uses orca-mini as the model
  252. func GenerateRequests() ([]api.GenerateRequest, [][]string) {
  253. return []api.GenerateRequest{
  254. {
  255. Model: "orca-mini",
  256. Prompt: "why is the ocean blue?",
  257. Stream: &stream,
  258. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  259. Options: map[string]interface{}{
  260. "seed": 42,
  261. "temperature": 0.0,
  262. },
  263. }, {
  264. Model: "orca-mini",
  265. Prompt: "why is the color of dirt brown?",
  266. Stream: &stream,
  267. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  268. Options: map[string]interface{}{
  269. "seed": 42,
  270. "temperature": 0.0,
  271. },
  272. }, {
  273. Model: "orca-mini",
  274. Prompt: "what is the origin of the us thanksgiving holiday?",
  275. Stream: &stream,
  276. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  277. Options: map[string]interface{}{
  278. "seed": 42,
  279. "temperature": 0.0,
  280. },
  281. }, {
  282. Model: "orca-mini",
  283. Prompt: "what is the origin of independence day?",
  284. Stream: &stream,
  285. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  286. Options: map[string]interface{}{
  287. "seed": 42,
  288. "temperature": 0.0,
  289. },
  290. }, {
  291. Model: "orca-mini",
  292. Prompt: "what is the composition of air?",
  293. Stream: &stream,
  294. KeepAlive: &api.Duration{Duration: 10 * time.Second},
  295. Options: map[string]interface{}{
  296. "seed": 42,
  297. "temperature": 0.0,
  298. },
  299. },
  300. },
  301. [][]string{
  302. []string{"sunlight"},
  303. []string{"soil", "organic", "earth", "black", "tan"},
  304. []string{"england", "english", "massachusetts", "pilgrims", "british"},
  305. []string{"fourth", "july", "declaration", "independence"},
  306. []string{"nitrogen", "oxygen", "carbon", "dioxide"},
  307. }
  308. }