safetensors.go 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. package convert
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "slices"
  12. "strings"
  13. "github.com/d4l3k/go-bfloat16"
  14. "github.com/x448/float16"
  15. "github.com/ollama/ollama/llm"
  16. )
// safetensorWriterTo writes a single tensor's data out of a .safetensors
// file, decoding it to float32 and re-encoding it in the target storage
// kind as it is written (see WriteTo).
type safetensorWriterTo struct {
	t      *llm.Tensor
	params *Params
	bo     ByteOrder
	// filename is the source .safetensors file the tensor bytes live in.
	filename string
	// dtype is the on-disk element type from the file header ("F32", "F16", "BF16").
	dtype string
	// offset and size locate the raw tensor bytes in the file: offset is
	// absolute (already adjusted past the length prefix and JSON header),
	// size is the byte length of the tensor data.
	offset, size int64
	// repacker, when non-nil, is applied to the decoded float32 values
	// (given the tensor name and shape) before re-encoding.
	repacker func(string, []float32, []uint64) ([]float32, error)
}
// safetensorMetadata mirrors one tensor entry in the JSON header of a
// .safetensors file.
type safetensorMetadata struct {
	Type  string   `json:"dtype"`
	Shape []uint64 `json:"shape"`
	// Offsets holds [begin, end) byte offsets of the tensor data,
	// relative to the start of the file's data section.
	Offsets []int64 `json:"data_offsets"`
}
  31. type SafetensorFormat struct{}
  32. func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
  33. var tensors []llm.Tensor
  34. matches, err := filepath.Glob(filepath.Join(dirpath, "*.safetensors"))
  35. if err != nil {
  36. return nil, err
  37. }
  38. var offset uint64
  39. for _, f := range matches {
  40. var t []llm.Tensor
  41. var err error
  42. t, offset, err = m.readTensors(f, offset, params)
  43. if err != nil {
  44. return nil, err
  45. }
  46. tensors = append(tensors, t...)
  47. }
  48. return tensors, nil
  49. }
// readTensors parses the header of a single .safetensors file and returns
// an llm.Tensor descriptor for each tensor it contains, plus the updated
// running output offset.
//
// On-disk layout: an 8-byte little-endian header length n, then n bytes of
// JSON metadata, then the raw tensor data.
func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
	f, err := os.Open(fn)
	if err != nil {
		return nil, 0, err
	}
	defer f.Close()

	// Length prefix for the JSON header.
	var n int64
	if err := binary.Read(f, binary.LittleEndian, &n); err != nil {
		return nil, 0, err
	}

	// Read exactly n header bytes into memory, then decode them.
	b := bytes.NewBuffer(make([]byte, 0, n))
	if _, err = io.CopyN(b, f, n); err != nil {
		return nil, 0, err
	}

	var headers map[string]safetensorMetadata
	if err := json.NewDecoder(b).Decode(&headers); err != nil {
		return nil, 0, err
	}

	// Collect tensor names, dropping rotary-embedding inverse frequencies,
	// and sort for deterministic ordering (map iteration is random).
	var keys []string
	for key := range headers {
		if !strings.HasSuffix(key, "self_attn.rotary_embd.inv_freq") {
			keys = append(keys, key)
		}
	}
	slices.Sort(keys)

	var tensors []llm.Tensor
	for _, key := range keys {
		value := headers[key]

		// Storage kind: 0 (float32) by default; 2-D tensors are stored
		// as kind 1 (float16). See safetensorWriterTo.WriteTo.
		var kind uint32
		switch len(value.Shape) {
		case 0:
			// Entry with no shape carries no tensor data
			// (presumably file-level metadata) — skip it.
			continue
		case 2:
			kind = 1
		}

		name, err := m.GetLayerName(key)
		if err != nil {
			return nil, 0, err
		}

		shape := make([]uint64, len(value.Shape))
		copy(shape, value.Shape)

		// Header offsets are relative to the data section; pad converts
		// them to absolute file offsets (8-byte length prefix + n header
		// bytes). Captures n from the enclosing scope.
		pad := func(s int64) int64 {
			return 8 + n + s
		}

		t := llm.Tensor{
			Name:   name,
			Kind:   kind,
			Offset: offset,
			Shape:  shape,
		}

		t.WriterTo = safetensorWriterTo{
			t:        &t,
			params:   params,
			bo:       params.ByteOrder,
			filename: fn,
			dtype:    value.Type,
			offset:   pad(value.Offsets[0]),
			size:     pad(value.Offsets[1]) - pad(value.Offsets[0]),
		}

		// Advance the output offset by the converted tensor's size.
		offset += t.Size()
		tensors = append(tensors, t)
	}
	return tensors, offset, nil
}
  115. func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
  116. f, err := os.Open(filepath.Join(dirpath, "config.json"))
  117. if err != nil {
  118. return nil, err
  119. }
  120. defer f.Close()
  121. var params Params
  122. if err := json.NewDecoder(f).Decode(&params); err != nil {
  123. return nil, err
  124. }
  125. params.ByteOrder = binary.LittleEndian
  126. return &params, nil
  127. }
  128. func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
  129. directMap := map[string]string{
  130. "model.embed_tokens.weight": "token_embd.weight",
  131. "lm_head.weight": "output.weight",
  132. "model.norm.weight": "output_norm.weight",
  133. }
  134. tMap := map[string]string{
  135. "model.layers.(\\d+).input_layernorm.weight": "blk.$1.attn_norm.weight",
  136. "model.layers.(\\d+).mlp.down_proj.weight": "blk.$1.ffn_down.weight",
  137. "model.layers.(\\d+).mlp.gate_proj.weight": "blk.$1.ffn_gate.weight",
  138. "model.layers.(\\d+).mlp.up_proj.weight": "blk.$1.ffn_up.weight",
  139. "model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
  140. "model.layers.(\\d+).self_attn.k_proj.weight": "blk.$1.attn_k.weight",
  141. "model.layers.(\\d+).self_attn.o_proj.weight": "blk.$1.attn_output.weight",
  142. "model.layers.(\\d+).self_attn.q_proj.weight": "blk.$1.attn_q.weight",
  143. "model.layers.(\\d+).self_attn.v_proj.weight": "blk.$1.attn_v.weight",
  144. "model.layers.(\\d+).block_sparse_moe.gate.weight": "blk.$1.ffn_gate_inp.weight",
  145. "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w1.weight": "blk.$1.ffn_gate.$2.weight",
  146. "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w2.weight": "blk.$1.ffn_down.$2.weight",
  147. "model.layers.(\\d+).block_sparse_moe.experts.(\\d+).w3.weight": "blk.$1.ffn_up.$2.weight",
  148. }
  149. v, ok := directMap[n]
  150. if ok {
  151. return v, nil
  152. }
  153. // quick hack to rename the layers to gguf format
  154. for k, v := range tMap {
  155. re := regexp.MustCompile(k)
  156. newName := re.ReplaceAllString(n, v)
  157. if newName != n {
  158. return newName, nil
  159. }
  160. }
  161. return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
  162. }
// WriteTo implements io.WriterTo for a single tensor: it reads the raw
// bytes from the source .safetensors file, decodes them to float32,
// optionally repacks them, and writes them to w in the target storage
// kind (0 = float32, 1 = float16).
//
// NOTE(review): the returned byte count is always 0, which does not
// follow the io.WriterTo contract ("the number of bytes written") —
// confirm callers only inspect err before relying on n.
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
	f, err := os.Open(r.filename)
	if err != nil {
		return 0, err
	}
	defer f.Close()

	// Seek to the tensor's absolute byte offset within the source file.
	if _, err = f.Seek(r.offset, io.SeekStart); err != nil {
		return 0, err
	}

	// Decode the on-disk representation to float32. Element counts are
	// derived from the byte size: 4 bytes per F32, 2 per F16/BF16.
	var f32s []float32
	switch r.dtype {
	case "F32":
		f32s = make([]float32, r.size/4)
		if err = binary.Read(f, r.bo, f32s); err != nil {
			return 0, err
		}
	case "F16":
		u16s := make([]uint16, r.size/2)
		if err = binary.Read(f, r.bo, u16s); err != nil {
			return 0, err
		}
		for _, b := range u16s {
			f32s = append(f32s, float16.Frombits(b).Float32())
		}
	case "BF16":
		// bfloat16 decoding operates on the raw byte stream.
		u8s := make([]uint8, r.size)
		if err = binary.Read(f, r.bo, u8s); err != nil {
			return 0, err
		}
		f32s = bfloat16.DecodeFloat32(u8s)
	default:
		return 0, fmt.Errorf("unknown data type: %s", r.dtype)
	}

	// Optional hook to rearrange the decoded values before re-encoding.
	if r.repacker != nil {
		f32s, err = r.repacker(r.t.Name, f32s, r.t.Shape)
		if err != nil {
			return 0, err
		}
	}

	// Re-encode in the tensor's output storage kind.
	switch r.t.Kind {
	case 0: // float32: write the decoded values as-is
		return 0, binary.Write(w, r.bo, f32s)
	case 1: // float16: convert each value down before writing
		f16s := make([]uint16, len(f32s))
		for i := range f32s {
			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
		}
		return 0, binary.Write(w, r.bo, f16s)
	default:
		return 0, fmt.Errorf("unknown storage type: %d", r.t.Kind)
	}
}
  215. func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
  216. switch len(params.Architectures) {
  217. case 0:
  218. return nil, fmt.Errorf("No architecture specified to convert")
  219. case 1:
  220. switch params.Architectures[0] {
  221. case "LlamaForCausalLM":
  222. return &LlamaModel{
  223. ModelData{
  224. Name: name,
  225. Path: dirPath,
  226. Params: params,
  227. Format: m,
  228. },
  229. }, nil
  230. case "MistralForCausalLM":
  231. return &MistralModel{
  232. ModelData{
  233. Name: name,
  234. Path: dirPath,
  235. Params: params,
  236. Format: m,
  237. },
  238. }, nil
  239. case "MixtralForCausalLM":
  240. return &MixtralModel{
  241. ModelData{
  242. Name: name,
  243. Path: dirPath,
  244. Params: params,
  245. Format: m,
  246. },
  247. }, nil
  248. case "GemmaForCausalLM":
  249. return &GemmaModel{
  250. ModelData{
  251. Name: name,
  252. Path: dirPath,
  253. Params: params,
  254. Format: m,
  255. },
  256. }, nil
  257. default:
  258. return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
  259. }
  260. }
  261. return nil, fmt.Errorf("Unknown error")
  262. }