llama.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package convert
  2. import (
  3. "cmp"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "os"
  8. "path/filepath"
  9. "regexp"
  10. "strings"
  11. "github.com/pdevine/tensor"
  12. "github.com/pdevine/tensor/native"
  13. "github.com/ollama/ollama/llm"
  14. )
  15. type LlamaModel struct {
  16. ModelData
  17. }
  18. func (m *LlamaModel) GetTensors() error {
  19. t, err := m.Format.GetTensors(m.Path, m.Params)
  20. if err != nil {
  21. return err
  22. }
  23. pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
  24. re, err := regexp.Compile(pattern)
  25. if err != nil {
  26. return err
  27. }
  28. for _, l := range t {
  29. matches := re.FindAllStringSubmatch(l.Name, -1)
  30. if len(matches) > 0 {
  31. switch m.Format.(type) {
  32. case *TorchFormat:
  33. wt := l.WriterTo.(torchWriterTo)
  34. wt.repacker = m.Repack
  35. l.WriterTo = wt
  36. case *SafetensorFormat:
  37. wt := l.WriterTo.(safetensorWriterTo)
  38. wt.repacker = m.Repack
  39. l.WriterTo = wt
  40. }
  41. }
  42. m.Tensors = append(m.Tensors, l)
  43. }
  44. return nil
  45. }
  46. func (m *LlamaModel) LoadVocab() (err error) {
  47. pre, ts, merges, err := parseTokens(filepath.Join(m.Path, "tokenizer.json"))
  48. if errors.Is(err, os.ErrNotExist) {
  49. return nil
  50. } else if err != nil {
  51. return err
  52. }
  53. m.Vocab = &Vocab{}
  54. for _, t := range ts {
  55. m.Vocab.Tokens = append(m.Vocab.Tokens, t.Content)
  56. m.Vocab.Types = append(m.Vocab.Types, t.Type())
  57. }
  58. m.Vocab.Merges = merges
  59. m.Params.PreTokenizer = pre
  60. return nil
  61. }
  62. func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
  63. kv := llm.KV{
  64. "general.architecture": "llama",
  65. "general.name": m.Name,
  66. "llama.vocab_size": uint32(len(m.Vocab.Tokens)),
  67. "llama.context_length": uint32(m.Params.ContextSize),
  68. "llama.embedding_length": uint32(m.Params.HiddenSize),
  69. "llama.block_count": uint32(m.Params.HiddenLayers),
  70. "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
  71. "llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
  72. "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
  73. "llama.attention.head_count": uint32(m.Params.AttentionHeads),
  74. "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  75. "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  76. "general.file_type": uint32(1),
  77. "tokenizer.ggml.model": "gpt2",
  78. "tokenizer.ggml.pre": m.Params.PreTokenizer,
  79. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  80. "tokenizer.ggml.token_type": m.Vocab.Types,
  81. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  82. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  83. "tokenizer.ggml.unknown_token_id": uint32(0),
  84. }
  85. if len(m.Vocab.Merges) > 0 {
  86. kv["tokenizer.ggml.merges"] = m.Vocab.Merges
  87. } else {
  88. kv["tokenizer.ggml.scores"] = m.Vocab.Scores
  89. }
  90. return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
  91. }
  92. func (m *LlamaModel) Repack(name string, data []float32, shape []uint64) ([]float32, error) {
  93. return llamaRepack(name, m.Params, data, shape)
  94. }
  95. func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([]float32, error) {
  96. var dims []int
  97. for _, dim := range shape {
  98. if dim != 0 {
  99. dims = append(dims, int(dim))
  100. }
  101. }
  102. var heads int
  103. switch {
  104. case strings.HasSuffix(name, "attn_q.weight"):
  105. heads = params.AttentionHeads
  106. case strings.HasSuffix(name, "attn_k.weight"):
  107. heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
  108. default:
  109. return nil, fmt.Errorf("unknown tensor name: %s", name)
  110. }
  111. n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
  112. if err := n.Reshape(append([]int{heads, 2, dims[0] / heads / 2}, dims[1:]...)...); err != nil {
  113. return nil, err
  114. }
  115. if err := n.T(0, 2, 1, 3); err != nil {
  116. return nil, err
  117. }
  118. if err := n.Reshape(dims...); err != nil {
  119. return nil, err
  120. }
  121. if err := n.Transpose(); err != nil {
  122. return nil, err
  123. }
  124. ts, err := native.SelectF32(n, 1)
  125. if err != nil {
  126. return nil, err
  127. }
  128. var f32s []float32
  129. for _, t := range ts {
  130. f32s = append(f32s, t...)
  131. }
  132. return f32s, nil
  133. }