torch.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. package convert
  2. import (
  3. "encoding/binary"
  4. "encoding/json"
  5. "fmt"
  6. "io"
  7. "log/slog"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "strings"
  12. "github.com/nlpodyssey/gopickle/pytorch"
  13. "github.com/nlpodyssey/gopickle/types"
  14. "github.com/x448/float16"
  15. "github.com/ollama/ollama/llm"
  16. )
  17. type torchWriterTo struct {
  18. t *llm.Tensor
  19. params *Params
  20. bo ByteOrder
  21. storage pytorch.StorageInterface
  22. repacker func(string, []float32, []uint64) ([]float32, error)
  23. }
  24. type TorchFormat struct{}
  25. func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
  26. slog.Debug("getting torch tensors")
  27. var files []string
  28. if pt, _ := filepath.Glob(filepath.Join(dirpath, "consolidated*.pth")); len(pt) > 0 {
  29. files = append(files, pt...)
  30. } else if pt, _ := filepath.Glob(filepath.Join(dirpath, "pytorch_model*.pth")); len(pt) > 0 {
  31. files = append(files, pt...)
  32. }
  33. var offset uint64
  34. var tensors []llm.Tensor
  35. for _, fn := range files {
  36. m, err := pytorch.Load(fn)
  37. if err != nil {
  38. slog.Error(fmt.Sprintf("error unpickling: %q", err))
  39. return []llm.Tensor{}, err
  40. }
  41. for _, k := range m.(*types.Dict).Keys() {
  42. if strings.HasSuffix(k.(string), "self_attn.rotary_emb.inv_freq") {
  43. continue
  44. }
  45. t, _ := m.(*types.Dict).Get(k)
  46. tshape := t.(*pytorch.Tensor).Size
  47. var size uint64
  48. var kind uint32
  49. switch len(tshape) {
  50. case 0:
  51. continue
  52. case 1:
  53. // convert to float32
  54. kind = 0
  55. size = uint64(tshape[0] * 4)
  56. case 2:
  57. // convert to float16
  58. kind = 1
  59. size = uint64(tshape[0] * tshape[1] * 2)
  60. }
  61. ggufName, err := tf.GetLayerName(k.(string))
  62. if err != nil {
  63. slog.Error(err.Error())
  64. return nil, err
  65. }
  66. slog.Debug(fmt.Sprintf("'%35s': '%30s' %10d [%#v]", k.(string), ggufName, size, tshape))
  67. shape := []uint64{0, 0, 0, 0}
  68. for i := range tshape {
  69. shape[i] = uint64(tshape[i])
  70. }
  71. tensor := llm.Tensor{
  72. Name: ggufName,
  73. Kind: kind,
  74. Offset: offset, // calculate the offset
  75. Shape: shape,
  76. }
  77. tensor.WriterTo = torchWriterTo{
  78. t: &tensor,
  79. params: params,
  80. bo: params.ByteOrder,
  81. storage: t.(*pytorch.Tensor).Source,
  82. }
  83. tensors = append(tensors, tensor)
  84. offset += size
  85. }
  86. }
  87. return tensors, nil
  88. }
  89. func getAltParams(dirpath string) (*Params, error) {
  90. f, err := os.Open(filepath.Join(dirpath, "params.json"))
  91. if err != nil {
  92. slog.Error("no params.json")
  93. return nil, err
  94. }
  95. defer f.Close()
  96. type TorchParams struct {
  97. HiddenSize int `json:"dim"`
  98. AttentionHeads int `json:"n_heads"`
  99. KeyValHeads int `json:"n_kv_heads"`
  100. HiddenLayers int `json:"n_layers"`
  101. RopeTheta float64 `json:"rope_theta"`
  102. NormEPS float64 `json:"norm_eps"`
  103. }
  104. var tparams TorchParams
  105. d := json.NewDecoder(f)
  106. err = d.Decode(&tparams)
  107. if err != nil {
  108. return nil, err
  109. }
  110. params := &Params{
  111. Architectures: []string{"LlamaForCausalLM"},
  112. HiddenSize: tparams.HiddenSize,
  113. AttentionHeads: tparams.AttentionHeads,
  114. KeyValHeads: tparams.KeyValHeads,
  115. HiddenLayers: tparams.HiddenLayers,
  116. NormEPS: tparams.NormEPS,
  117. }
  118. switch {
  119. case tparams.RopeTheta == 1000000:
  120. // Codellama
  121. params.ContextSize = 16384
  122. case tparams.NormEPS == 1e-06:
  123. // llama2
  124. slog.Debug("Found llama2 - setting context size to 4096")
  125. params.ContextSize = 4096
  126. default:
  127. params.ContextSize = 2048
  128. }
  129. params.ByteOrder = binary.LittleEndian
  130. return params, nil
  131. }
  132. func (m *TorchFormat) GetParams(dirpath string) (*Params, error) {
  133. f, err := os.Open(filepath.Join(dirpath, "config.json"))
  134. if err != nil {
  135. if os.IsNotExist(err) {
  136. // try params.json instead
  137. return getAltParams(dirpath)
  138. } else {
  139. return nil, err
  140. }
  141. }
  142. var params Params
  143. d := json.NewDecoder(f)
  144. err = d.Decode(&params)
  145. if err != nil {
  146. return nil, err
  147. }
  148. params.ByteOrder = binary.LittleEndian
  149. return &params, nil
  150. }
  151. func (m *TorchFormat) GetLayerName(n string) (string, error) {
  152. directMap := map[string]string{
  153. "tok_embeddings.weight": "token_embd.weight",
  154. "output.weight": "output.weight",
  155. "norm.weight": "output_norm.weight",
  156. "rope.freqs": "rope_freqs.weight",
  157. "model.embed_tokens.weight": "token_embd.weight",
  158. "lm_head.weight": "output.weight",
  159. "model.norm.weight": "output_norm.weight",
  160. }
  161. lMap := map[string]string{
  162. "layers.(\\d+).attention_norm.weight": "blk.$1.attn_norm.weight",
  163. "layers.(\\d+).attention_output_norm.weight": "blk.$1.attn_norm.weight",
  164. "layers.(\\d+).feed_forward.w2.weight": "blk.$1.ffn_down.weight",
  165. "layers.(\\d+).feed_forward.w1.weight": "blk.$1.ffn_gate.weight",
  166. "layers.(\\d+).feed_forward.w3.weight": "blk.$1.ffn_up.weight",
  167. "layers.(\\d+).ffn_norm.weight": "blk.$1.ffn_norm.weight",
  168. "layers.(\\d+).attention.wk.weight": "blk.$1.attn_k.weight",
  169. "layers.(\\d+).attention.wo.weight": "blk.$1.attn_output.weight",
  170. "layers.(\\d+).attention.wq.weight": "blk.$1.attn_q.weight",
  171. "layers.(\\d+).attention.wv.weight": "blk.$1.attn_v.weight",
  172. "model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
  173. "model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
  174. "model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
  175. "model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
  176. "model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
  177. "model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
  178. "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
  179. "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
  180. "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
  181. }
  182. v, ok := directMap[n]
  183. if ok {
  184. return v, nil
  185. }
  186. // quick hack to rename the layers to gguf format
  187. for k, v := range lMap {
  188. re := regexp.MustCompile(k)
  189. newName := re.ReplaceAllString(n, v)
  190. if newName != n {
  191. return newName, nil
  192. }
  193. }
  194. return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
  195. }
  196. func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
  197. var f32s []float32
  198. switch s := r.storage.(type) {
  199. case *pytorch.FloatStorage:
  200. f32s = s.Data
  201. case *pytorch.HalfStorage:
  202. f32s = s.Data
  203. case *pytorch.BFloat16Storage:
  204. f32s = s.Data
  205. default:
  206. return 0, fmt.Errorf("unknown data type: %T", s)
  207. }
  208. if r.repacker != nil {
  209. f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
  210. if err != nil {
  211. return 0, err
  212. }
  213. }
  214. switch r.t.Kind {
  215. case 0:
  216. return 0, binary.Write(w, r.bo, f32s)
  217. case 1:
  218. f16s := make([]uint16, len(f32s))
  219. for i := range f32s {
  220. f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
  221. }
  222. return 0, binary.Write(w, r.bo, f16s)
  223. default:
  224. return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
  225. }
  226. }
  227. func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
  228. switch len(params.Architectures) {
  229. case 0:
  230. return nil, fmt.Errorf("No architecture specified to convert")
  231. case 1:
  232. switch params.Architectures[0] {
  233. case "LlamaForCausalLM":
  234. return &LlamaModel{
  235. ModelData{
  236. Name: name,
  237. Path: dirPath,
  238. Params: params,
  239. Format: m,
  240. },
  241. }, nil
  242. default:
  243. return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
  244. }
  245. }
  246. return nil, fmt.Errorf("Unknown error")
  247. }