convert.go

package convert

import (
	"cmp"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"slices"
	"strings"

	"google.golang.org/protobuf/proto"

	"github.com/ollama/ollama/convert/sentencepiece"
	"github.com/ollama/ollama/llm"
)
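
// Token type values written alongside each vocabulary entry in the GGUF
// metadata; the zero value is skipped so tokenTypeNormal starts at 1.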
const (
	_ int32 = iota
	tokenTypeNormal
	tokenTypeUnknown
	tokenTypeControl
	tokenTypeUserDefined
	tokenTypeUnused
	tokenTypeByte
)
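
// Params mirrors the subset of a Hugging Face config.json that the
// converter needs; field names map to the corresponding JSON keys.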
type Params struct {
	Architectures     []string `json:"architectures"`
	VocabSize         int      `json:"vocab_size"`
	HiddenSize        int      `json:"hidden_size"`       // n_embd
	HiddenLayers      int      `json:"num_hidden_layers"` // n_layer
	ContextSize       int      `json:"max_position_embeddings"`
	IntermediateSize  int      `json:"intermediate_size"`
	AttentionHeads    int      `json:"num_attention_heads"` // n_head
	KeyValHeads       int      `json:"num_key_value_heads"`
	NormEPS           float64  `json:"rms_norm_eps"`
	BoSTokenID        int      `json:"bos_token_id"`
	EoSTokenID        int      `json:"eos_token_id"`
	HeadDimension     int      `json:"head_dim"`
	PaddingTokenID    int      `json:"pad_token_id"`
	RopeFrequencyBase float64  `json:"rope_theta"`
	Experts           int      `json:"num_local_experts"`
	ExpertsUsed       int      `json:"num_experts_per_tok"`

	PreTokenizer string

	ByteOrder
}
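
// ByteOrder combines binary.ByteOrder and binary.AppendByteOrder so a
// single value can both decode and append fixed-size values.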
type ByteOrder interface {
	binary.ByteOrder
	binary.AppendByteOrder
}
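
// ModelArch is implemented per supported model architecture; it loads the
// model's tensors and vocabulary and writes the converted GGUF output.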
type ModelArch interface {
	GetTensors() error
	LoadVocab() error
	WriteGGUF(io.WriteSeeker) error
}
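
// ModelFormat abstracts the source checkpoint layout (safetensors or
// PyTorch .bin/.pth) that a model is converted from. A rough call sequence
// is sketched below; it is illustrative only, and the argument order of
// GetModelArch (name, directory) is assumed:
//
//	format, _ := GetModelFormat(dir)
//	params, _ := format.GetParams(dir)
//	arch, _ := format.GetModelArch(name, dir, params)
//	_ = arch.GetTensors()
//	_ = arch.LoadVocab()
//	_ = arch.WriteGGUF(out)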
type ModelFormat interface {
	GetLayerName(string) (string, error)
	GetTensors(string, *Params) ([]llm.Tensor, error)
	GetParams(string) (*Params, error)
	GetModelArch(string, string, *Params) (ModelArch, error)
}
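
// ModelData collects everything gathered about a model during conversion:
// its location, parameters, vocabulary, tensors, and source format.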
type ModelData struct {
	Path    string
	Name    string
	Params  *Params
	Vocab   *Vocab
	Tensors []llm.Tensor
	Format  ModelFormat
}
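
// GetModelFormat inspects the files in dirname and selects a ModelFormat by
// extension: .safetensors selects SafetensorFormat, .bin or .pth selects
// TorchFormat.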
func GetModelFormat(dirname string) (ModelFormat, error) {
	files, err := filepath.Glob(filepath.Join(dirname, "*"))
	if err != nil {
		return nil, err
	}

	for _, fn := range files {
		if strings.HasSuffix(fn, ".safetensors") {
			return &SafetensorFormat{}, nil
		} else if strings.HasSuffix(fn, ".bin") || strings.HasSuffix(fn, ".pth") {
			slog.Debug("model is torch")
			return &TorchFormat{}, nil
		}
	}

	return nil, fmt.Errorf("couldn't determine model format")
}
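
// Vocab holds the flattened tokenizer vocabulary that is written to the GGUF
// file: one token string, score, and type per entry, plus any merge rules.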
// Details on gguf's tokenizer can be found at:
// https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#tokenizer
type Vocab struct {
	Tokens []string
	Scores []float32
	Types  []int32
	Merges []string
}
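
// LoadSentencePieceTokens reads the SentencePiece model at
// dirpath/tokenizer.model and builds a Vocab from it, then appends any
// user-defined tokens from added_tokens.json and pads the vocabulary with
// dummy tokens up to params.VocabSize.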
func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
	slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
	in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
	if err != nil {
		return nil, err
	}

	// To regenerate sentencepiece from the protobufs use:
	// protoc -I=./ --go_out=./ sentencepiece_model.proto
	modelProto := &sentencepiece.ModelProto{}
	if err := proto.Unmarshal(in, modelProto); err != nil {
		return nil, err
	}

	v := &Vocab{
		Tokens: make([]string, 0),
		Scores: make([]float32, 0),
		Types:  make([]int32, 0),
	}

	pieces := modelProto.GetPieces()
	for _, p := range pieces {
		v.Tokens = append(v.Tokens, p.GetPiece())
		v.Scores = append(v.Scores, p.GetScore())
		t := p.GetType()
		switch t {
		case sentencepiece.ModelProto_SentencePiece_UNKNOWN:
		case sentencepiece.ModelProto_SentencePiece_CONTROL:
		case sentencepiece.ModelProto_SentencePiece_UNUSED:
		case sentencepiece.ModelProto_SentencePiece_BYTE:
		default:
			t = sentencepiece.ModelProto_SentencePiece_NORMAL
		}
		v.Types = append(v.Types, int32(t))
	}
	slog.Info(fmt.Sprintf("vocab size: %d", len(v.Tokens)))

	// add any additional tokens
	addIn, err := os.ReadFile(filepath.Join(dirpath, "added_tokens.json"))
	if os.IsNotExist(err) {
		return v, nil
	} else if err != nil {
		return nil, err
	}

	slog.Info("reading user defined tokens")
	var extraTokenData map[string]int
	if err := json.Unmarshal(addIn, &extraTokenData); err != nil {
		return nil, err
	}

	type token struct {
		key string
		pos int
	}

	extraTokens := make([]token, 0)
	for k, id := range extraTokenData {
		extraTokens = append(extraTokens, token{k, id})
	}
	slices.SortFunc(extraTokens, func(a, b token) int {
		return cmp.Compare(a.pos, b.pos)
	})

	numToks := len(v.Tokens)

	for cnt, t := range extraTokens {
		// added token ids must be contiguous and start right after the base
		// vocabulary, i.e. each id equals its index plus the base vocab size
		if t.pos != cnt+numToks {
			return nil, fmt.Errorf("token ID '%d' for '%s' doesn't match total token size", t.pos, t.key)
		}
		v.Tokens = append(v.Tokens, t.key)
		v.Scores = append(v.Scores, -1000.0)
		v.Types = append(v.Types, tokenTypeUserDefined)
	}
	slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))

	if params.VocabSize > len(v.Tokens) {
		missingTokens := params.VocabSize - len(v.Tokens)
		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
		for cnt := range missingTokens {
			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
			v.Scores = append(v.Scores, -1)
			v.Types = append(v.Types, tokenTypeUserDefined)
		}
	}

	return v, nil
}