gemma.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package convert
  2. import (
  3. "fmt"
  4. "io"
  5. "log/slog"
  6. "strings"
  7. "github.com/pdevine/tensor"
  8. "github.com/pdevine/tensor/native"
  9. "github.com/ollama/ollama/llm"
  10. )
  11. type GemmaModel struct {
  12. ModelData
  13. }
  14. func addOnes(data []float32, vectorSize int) ([]float32, error) {
  15. n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
  16. ones := tensor.Ones(tensor.Float32, vectorSize)
  17. n, err := n.Add(ones)
  18. if err != nil {
  19. return nil, err
  20. }
  21. ts, err := native.SelectF32(n, 0)
  22. if err != nil {
  23. return nil, err
  24. }
  25. var f32s []float32
  26. for _, t := range ts {
  27. f32s = append(f32s, t...)
  28. }
  29. return f32s, nil
  30. }
  31. func (m *GemmaModel) GetTensors() error {
  32. t, err := m.Format.GetTensors(m.Path, m.Params)
  33. if err != nil {
  34. return err
  35. }
  36. slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
  37. for _, l := range t {
  38. if strings.HasSuffix(l.Name, "norm.weight") {
  39. wt := l.WriterTo.(safetensorWriterTo)
  40. wt.repacker = m.Repack
  41. l.WriterTo = wt
  42. }
  43. m.Tensors = append(m.Tensors, l)
  44. }
  45. return nil
  46. }
  47. func (m *GemmaModel) LoadVocab() error {
  48. v, err := LoadSentencePieceTokens(m.Path, m.Params)
  49. if err != nil {
  50. return err
  51. }
  52. m.Vocab = v
  53. return nil
  54. }
  55. func (m *GemmaModel) Repack(_ string, data []float32, shape []uint64) ([]float32, error) {
  56. return addOnes(data, int(shape[0]))
  57. }
  58. func (m *GemmaModel) WriteGGUF(ws io.WriteSeeker) error {
  59. kv := llm.KV{
  60. "general.architecture": "gemma",
  61. "general.name": m.Name,
  62. "gemma.context_length": uint32(m.Params.ContextSize),
  63. "gemma.embedding_length": uint32(m.Params.HiddenSize),
  64. "gemma.block_count": uint32(m.Params.HiddenLayers),
  65. "gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
  66. "gemma.attention.head_count": uint32(m.Params.AttentionHeads),
  67. "gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  68. "gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  69. "gemma.attention.key_length": uint32(m.Params.HeadDimension),
  70. "gemma.attention.value_length": uint32(m.Params.HeadDimension),
  71. "general.file_type": uint32(1),
  72. "tokenizer.ggml.model": "llama",
  73. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  74. "tokenizer.ggml.scores": m.Vocab.Scores,
  75. "tokenizer.ggml.token_type": m.Vocab.Types,
  76. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  77. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  78. "tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
  79. "tokenizer.ggml.unknown_token_id": uint32(3),
  80. "tokenizer.ggml.add_bos_token": true,
  81. "tokenizer.ggml.add_eos_token": false,
  82. }
  83. return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
  84. }