convert_test.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. //go:build slow
  2. package convert
  3. import (
  4. "os"
  5. "path/filepath"
  6. "testing"
  7. "github.com/ollama/ollama/llm"
  8. )
  9. func convertFull(t *testing.T, p string) (llm.KV, llm.Tensors) {
  10. t.Helper()
  11. mf, err := GetModelFormat(p)
  12. if err != nil {
  13. t.Fatal(err)
  14. }
  15. params, err := mf.GetParams(p)
  16. if err != nil {
  17. t.Fatal(err)
  18. }
  19. arch, err := mf.GetModelArch("", p, params)
  20. if err != nil {
  21. t.Fatal(err)
  22. }
  23. if err := arch.LoadVocab(); err != nil {
  24. t.Fatal(err)
  25. }
  26. if err := arch.GetTensors(); err != nil {
  27. t.Fatal(err)
  28. }
  29. f, err := os.CreateTemp(t.TempDir(), "f16")
  30. if err != nil {
  31. t.Fatal(err)
  32. }
  33. defer f.Close()
  34. if err := arch.WriteGGUF(f); err != nil {
  35. t.Fatal(err)
  36. }
  37. r, err := os.Open(f.Name())
  38. if err != nil {
  39. t.Fatal(err)
  40. }
  41. defer r.Close()
  42. m, _, err := llm.DecodeGGML(r)
  43. if err != nil {
  44. t.Fatal(err)
  45. }
  46. return m.KV(), m.Tensors()
  47. }
  48. func TestConvertFull(t *testing.T) {
  49. cases := []struct {
  50. path string
  51. arch string
  52. tensors int
  53. layers int
  54. }{
  55. {"Meta-Llama-3-8B-Instruct", "llama", 291, 35},
  56. {"Mistral-7B-Instruct-v0.2", "llama", 291, 35},
  57. {"Mixtral-8x7B-Instruct-v0.1", "llama", 291, 35},
  58. {"gemma-2b-it", "gemma", 164, 20},
  59. }
  60. for _, tt := range cases {
  61. t.Run(tt.path, func(t *testing.T) {
  62. p := filepath.Join("testdata", tt.path)
  63. if _, err := os.Stat(p); err != nil {
  64. t.Skipf("%s not found", p)
  65. }
  66. kv, tensors := convertFull(t, p)
  67. if kv.Architecture() != tt.arch {
  68. t.Fatalf("expected llama, got %s", kv.Architecture())
  69. }
  70. if kv.FileType().String() != "F16" {
  71. t.Fatalf("expected F16, got %s", kv.FileType())
  72. }
  73. if len(tensors) != tt.tensors {
  74. t.Fatalf("expected %d tensors, got %d", tt.tensors, len(tensors))
  75. }
  76. layers := tensors.Layers()
  77. if len(layers) != tt.layers {
  78. t.Fatalf("expected %d layers, got %d", tt.layers, len(layers))
  79. }
  80. })
  81. }
  82. }