concurrency_test.go

//go:build integration

package integration

import (
	"context"
	"log/slog"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/stretchr/testify/require"
)
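
// TestMultiModelConcurrency issues one generate request against each of two
// different models at the same time and verifies each response against its
// expected strings.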
func TestMultiModelConcurrency(t *testing.T) {
	var (
		req = [2]api.GenerateRequest{
			{
				Model:     "orca-mini",
				Prompt:    "why is the ocean blue?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
				Model:     "tinydolphin",
				Prompt:    "what is the origin of the us thanksgiving holiday?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			},
		}
		resp = [2][]string{
			[]string{"sunlight"},
			[]string{"england", "english", "massachusetts", "pilgrims", "british"},
		}
	)
	var wg sync.WaitGroup
	wg.Add(len(req))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	for i := 0; i < len(req); i++ {
		require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
	}
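
	// Issue the generate requests concurrently, one goroutine per model.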
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			DoGenerate(ctx, t, client, req[i], resp[i], 60*time.Second, 10*time.Second)
		}(i)
	}
	wg.Wait()
}
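
// TestIntegrationConcurrentPredictOrcaMini warms the server with a single
// request, then issues several generate requests in parallel for multiple
// iterations, scaling the request and iteration counts down when
// OLLAMA_MAX_VRAM reports less than 4GiB.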
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	req, resp := GenerateRequests()
	reqLimit := len(req)
	iterLimit := 5

	vram := os.Getenv("OLLAMA_MAX_VRAM")
	if vram != "" {
		max, err := strconv.ParseUint(vram, 10, 64)
		require.NoError(t, err)
		// Don't hammer on small VRAM cards...
		if max < 4*1024*1024*1024 {
			reqLimit = min(reqLimit, 2)
			iterLimit = 2
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	// Get the server running (if applicable) and warm the model up with a single initial request
	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)

	var wg sync.WaitGroup
	wg.Add(reqLimit)
	for i := 0; i < reqLimit; i++ {
		go func(i int) {
			defer wg.Done()
			for j := 0; j < iterLimit; j++ {
				slog.Info("Starting", "req", i, "iter", j)
				// On slower GPUs it can take a while to process the concurrent requests
				// so we allow a much longer initial timeout
				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
			}
		}(i)
	}
	wg.Wait()
}

// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
	vram := os.Getenv("OLLAMA_MAX_VRAM")
	if vram == "" {
		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
	}
	max, err := strconv.ParseUint(vram, 10, 64)
	require.NoError(t, err)
	const MB = uint64(1024 * 1024)
	type model struct {
		name string
		size uint64 // approximate VRAM consumption when the model is fully loaded
	}
	smallModels := []model{
		{
			name: "orca-mini",
			size: 2992 * MB,
		},
		{
			name: "phi",
			size: 2616 * MB,
		},
		{
			name: "gemma:2b",
			size: 2364 * MB,
		},
		{
			name: "stable-code:3b",
			size: 2608 * MB,
		},
		{
			name: "starcoder2:3b",
			size: 2166 * MB,
		},
	}
	mediumModels := []model{
		{
			name: "llama2",
			size: 5118 * MB,
		},
		{
			name: "mistral",
			size: 4620 * MB,
		},
		{
			name: "orca-mini:7b",
			size: 5118 * MB,
		},
		{
			name: "dolphin-mistral",
			size: 4620 * MB,
		},
		{
			name: "gemma:7b",
			size: 5000 * MB,
		},
		// TODO - uncomment this once #3565 is merged and this is rebased on it
		// {
		// 	name: "codellama:7b",
		// 	size: 5118 * MB,
		// },
	}

	// These seem to be too slow to be useful...
	// largeModels := []model{
	// 	{
	// 		name: "llama2:13b",
	// 		size: 7400 * MB,
	// 	},
	// 	{
	// 		name: "codellama:13b",
	// 		size: 7400 * MB,
	// 	},
	// 	{
	// 		name: "orca-mini:13b",
	// 		size: 7400 * MB,
	// 	},
	// 	{
	// 		name: "gemma:7b",
	// 		size: 5000 * MB,
	// 	},
	// 	{
	// 		name: "starcoder2:15b",
	// 		size: 9100 * MB,
	// 	},
	// }

	var chosenModels []model
	switch {
	case max < 10000*MB:
		slog.Info("selecting small models")
		chosenModels = smallModels
	// case max < 30000*MB:
	default:
		slog.Info("selecting medium models")
		chosenModels = mediumModels
		// default:
		// 	slog.Info("selecting large models")
		// 	chosenModels = largeModels
	}

	req, resp := GenerateRequests()
	for i := range req {
		if i >= len(chosenModels) {
			break
		}
		req[i].Model = chosenModels[i].name
	}

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure all the models are pulled before we get started
	for _, r := range req {
		require.NoError(t, PullIfMissing(ctx, client, r.Model))
	}
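
	// Launch generate workers one model at a time until their combined
	// estimated footprint exceeds the reported VRAM, so the server has to
	// juggle more models than fit at once.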
	var wg sync.WaitGroup
	consumed := uint64(256 * MB) // Assume some baseline usage
	for i := 0; i < len(req); i++ {
		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
		if i > 1 && consumed > max {
			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
			break
		}
		consumed += chosenModels[i].size
		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)

		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			for j := 0; j < 3; j++ {
				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
			}
		}(i)
	}
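
	// Periodically log which models the server reports as loaded while the
	// generate goroutines run.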
	go func() {
		for {
			time.Sleep(2 * time.Second)
			select {
			case <-ctx.Done():
				return
			default:
				models, err := client.ListRunning(ctx)
				if err != nil {
					slog.Warn("failed to list running models", "error", err)
					continue
				}
				for _, m := range models.Models {
					slog.Info("loaded model snapshot", "model", m)
				}
			}
		}
	}()
	wg.Wait()
}