mistral.go 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. package convert
import (
	"fmt"
	"io"
	"regexp"

	"github.com/ollama/ollama/llm"
)
// MistralModel is the converter for Mistral checkpoints. All of its state
// (Path, Params, Format, Vocab, Tensors, Name) is promoted from the embedded
// ModelData; WriteGGUF emits it under the "llama" GGUF architecture.
type MistralModel struct {
	ModelData
}
  10. func (m *MistralModel) GetTensors() error {
  11. t, err := m.Format.GetTensors(m.Path, m.Params)
  12. if err != nil {
  13. return err
  14. }
  15. pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
  16. re, err := regexp.Compile(pattern)
  17. if err != nil {
  18. return err
  19. }
  20. for _, l := range t {
  21. matches := re.FindAllStringSubmatch(l.Name, -1)
  22. if len(matches) > 0 {
  23. wt := l.WriterTo.(safetensorWriterTo)
  24. wt.repacker = m.Repack
  25. l.WriterTo = wt
  26. }
  27. m.Tensors = append(m.Tensors, l)
  28. }
  29. return nil
  30. }
  31. func (m *MistralModel) LoadVocab() error {
  32. v, err := LoadSentencePieceTokens(m.Path, m.Params)
  33. if err != nil {
  34. return err
  35. }
  36. m.Vocab = v
  37. return nil
  38. }
  39. func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
  40. kv := llm.KV{
  41. "general.architecture": "llama",
  42. "general.name": m.Name,
  43. "llama.context_length": uint32(m.Params.ContextSize),
  44. "llama.embedding_length": uint32(m.Params.HiddenSize),
  45. "llama.block_count": uint32(m.Params.HiddenLayers),
  46. "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
  47. "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
  48. "llama.attention.head_count": uint32(m.Params.AttentionHeads),
  49. "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  50. "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  51. "general.file_type": uint32(1),
  52. "tokenizer.ggml.model": "llama",
  53. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  54. "tokenizer.ggml.scores": m.Vocab.Scores,
  55. "tokenizer.ggml.token_type": m.Vocab.Types,
  56. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  57. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  58. "tokenizer.ggml.add_bos_token": true,
  59. "tokenizer.ggml.add_eos_token": false,
  60. "tokenizer.ggml.unknown_token_id": uint32(0),
  61. }
  62. return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
  63. }
  64. func (m *MistralModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
  65. return llamaRepack(name, m.Params, data, shape)
  66. }