tokenizer.go

package convert

import (
	"cmp"
	"crypto/sha256"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"slices"

	"golang.org/x/exp/maps"
)
// Tokenizer mirrors the subset of a HuggingFace tokenizer.json file
// that the converter needs: the vocabulary model, the added tokens,
// and the pre-tokenizer split patterns.
type Tokenizer struct {
	Version     string         `json:"version"`
	AddedTokens []Token        `json:"added_tokens"`
	Model       TokenizerModel `json:"model"`

	PreTokenizer struct {
		PreTokenizers []struct {
			Type    string `json:"type"`
			Pattern struct {
				Regex string `json:"Regex"`
			} `json:"pattern"`
		} `json:"pretokenizers"`
	} `json:"pre_tokenizer"`
}
type TokenizerModel struct {
	Type   string         `json:"type"`
	Vocab  map[string]int `json:"vocab"`
	Merges []string       `json:"merges"`
	Tokens []Token
}
type Token struct {
	ID          int    `json:"id"`
	Content     string `json:"content"`
	Special     bool   `json:"special"`
	UserDefined bool
}
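
// The structs above decode only part of a tokenizer.json file. An
// abridged, illustrative example of such a file (the field values here
// are made up, not taken from any real model):
//
//	{
//	  "version": "1.0",
//	  "added_tokens": [
//	    {"id": 32000, "content": "<|user|>", "special": true}
//	  ],
//	  "model": {
//	    "type": "BPE",
//	    "vocab": {"hello": 15043, "world": 3186},
//	    "merges": ["h e", "he llo"]
//	  },
//	  "pre_tokenizer": {
//	    "type": "Sequence",
//	    "pretokenizers": [
//	      {"type": "Split", "pattern": {"Regex": "\\s+"}}
//	    ]
//	  }
//	}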
// Type maps a token to its numeric token-type ID: special tokens are
// control tokens, added tokens are user defined, and everything else
// is a normal token.
func (t *Token) Type() int32 {
	switch {
	case t.Special:
		return tokenTypeControl
	case t.UserDefined:
		return tokenTypeUserDefined
	default:
		return tokenTypeNormal
	}
}
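
// The tokenType* constants are defined elsewhere in the package. A
// minimal sketch of matching definitions, assuming the usual GGUF
// token-type numbering (the exact values are an assumption, not taken
// from this file):
//
//	const (
//		_ int32 = iota
//		tokenTypeNormal
//		tokenTypeUnknown
//		tokenTypeControl
//		tokenTypeUserDefined
//		tokenTypeUnused
//		tokenTypeByte
//	)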
// maxID returns the largest token ID across the base vocabulary and the
// added tokens. Note that slices.Max and slices.MaxFunc panic on empty
// inputs, so this assumes a non-empty vocab and at least one added token.
func (t *Tokenizer) maxID() int {
	return max(
		slices.Max(maps.Values(t.Model.Vocab)),
		slices.MaxFunc(t.AddedTokens, func(a, b Token) int {
			return cmp.Compare(a.ID, b.ID)
		}).ID,
	)
}
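
// For example, with base vocabulary IDs 0..31999 and a single added
// token with ID 32000, maxID returns 32000, so parseTokens below
// allocates a token slice of length 32001.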
// parseTokens reads the HuggingFace tokenizer.json at dirpath and
// returns the detected pretokenizer name, a dense slice of tokens
// indexed by token ID, and the BPE merge rules.
func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, err error) {
	f, err := os.Open(dirpath)
	if err != nil {
		return "", nil, nil, err
	}
	defer f.Close()

	var t Tokenizer
	if err := json.NewDecoder(f).Decode(&t); err != nil {
		return "", nil, nil, err
	}

	// Lay the vocabulary out as a dense slice indexed by token ID.
	tokens = make([]Token, t.maxID()+1)
	for k, v := range t.Model.Vocab {
		tokens[v] = Token{ID: v, Content: k, Special: false, UserDefined: false}
	}

	// Added tokens overwrite any base-vocabulary entry with the same ID.
	for _, v := range t.AddedTokens {
		v.UserDefined = true
		tokens[v.ID] = v
	}

	// Fingerprint the pretokenizer by hashing its Split regexes in order.
	sha256sum := sha256.New()
	for _, pt := range t.PreTokenizer.PreTokenizers {
		if pt.Type == "Split" && pt.Pattern.Regex != "" {
			sha256sum.Write([]byte(pt.Pattern.Regex))
		}
	}

	switch digest := fmt.Sprintf("%x", sha256sum.Sum(nil)); digest {
	case "d98f9631be1e9607a9848c26c1f9eac1aa9fc21ac6ba82a2fc0741af9780a48f":
		pre = "llama-bpe"
	case "03df5c5863ad70781dcfdef491ead25140f895fe8010964be0daefe27be32b02":
		pre = "deepseek-llm"
	case "21cde974d587f0d54dc8d56b183cc1e6239600172035c68fbd6d4b9f8da0576e":
		pre = "deepseek-coder"
	default:
		slog.Warn("unknown pretokenizer, using default", "digest", digest)
		pre = "default"
	}

	return pre, tokens, t.Model.Merges, nil
}
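
// A hypothetical in-package caller (the path is illustrative and most
// error handling is elided):
//
//	pre, tokens, merges, err := parseTokens("model/tokenizer.json")
//	if err != nil {
//		return err
//	}
//	slog.Info("parsed tokenizer",
//		"pretokenizer", pre,
//		"tokens", len(tokens),
//		"merges", len(merges))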