gguf.go 14 KB


  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "strings"
  9. )
  10. type containerGGUF struct {
  11. ByteOrder binary.ByteOrder
  12. Version uint32
  13. V1 struct {
  14. NumTensor uint32
  15. NumKV uint32
  16. }
  17. V2 struct {
  18. NumTensor uint64
  19. NumKV uint64
  20. }
  21. V3 struct {
  22. NumTensor uint64
  23. NumKV uint64
  24. }
  25. maxArraySize int
  26. }
  27. func (c *containerGGUF) canCollectArray(size int) bool {
  28. return c.maxArraySize < 0 || size <= c.maxArraySize
  29. }
  30. func (c *containerGGUF) Name() string {
  31. return "gguf"
  32. }
  33. func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
  34. if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
  35. return nil, err
  36. }
  37. var err error
  38. switch c.Version {
  39. case 1:
  40. err = binary.Read(rs, c.ByteOrder, &c.V1)
  41. case 2:
  42. err = binary.Read(rs, c.ByteOrder, &c.V2)
  43. default:
  44. err = binary.Read(rs, c.ByteOrder, &c.V3)
  45. }
  46. if err != nil {
  47. return nil, err
  48. }
  49. model := newGGUF(c)
  50. if err := model.Decode(rs); err != nil {
  51. return nil, err
  52. }
  53. return model, nil
  54. }
  55. const (
  56. ggufTypeUint8 uint32 = iota
  57. ggufTypeInt8
  58. ggufTypeUint16
  59. ggufTypeInt16
  60. ggufTypeUint32
  61. ggufTypeInt32
  62. ggufTypeFloat32
  63. ggufTypeBool
  64. ggufTypeString
  65. ggufTypeArray
  66. ggufTypeUint64
  67. ggufTypeInt64
  68. ggufTypeFloat64
  69. )
  70. type gguf struct {
  71. *containerGGUF
  72. kv KV
  73. tensors []*Tensor
  74. parameters uint64
  75. scratch [16 << 10]byte
  76. }
  77. func newGGUF(container *containerGGUF) *gguf {
  78. return &gguf{
  79. containerGGUF: container,
  80. kv: make(KV),
  81. }
  82. }
  83. func NewGGUFV3(bo binary.ByteOrder) *gguf {
  84. return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
  85. }
// KV returns the metadata collected by Decode, including the
// "general.parameter_count" entry that Decode adds. The internal map is
// returned directly; no copy is made.
func (llm *gguf) KV() KV {
	return llm.kv
}
// Tensors returns the tensor descriptors read by Decode, in file order. No copy
// is made.
func (llm *gguf) Tensors() Tensors {
	return llm.tensors
}
  92. func (llm *gguf) numTensor() uint64 {
  93. switch llm.Version {
  94. case 1:
  95. return uint64(llm.V1.NumTensor)
  96. case 2:
  97. return llm.V2.NumTensor
  98. default:
  99. return llm.V3.NumTensor
  100. }
  101. }
  102. func (llm *gguf) numKV() uint64 {
  103. switch llm.Version {
  104. case 1:
  105. return uint64(llm.V1.NumKV)
  106. case 2:
  107. return llm.V2.NumKV
  108. default:
  109. return llm.V3.NumKV
  110. }
  111. }
  112. func (llm *gguf) Decode(rs io.ReadSeeker) error {
  113. // decode key-values
  114. for i := 0; uint64(i) < llm.numKV(); i++ {
  115. k, err := readGGUFString(llm, rs)
  116. if err != nil {
  117. return err
  118. }
  119. t, err := readGGUF[uint32](llm, rs)
  120. if err != nil {
  121. return err
  122. }
  123. var v any
  124. switch t {
  125. case ggufTypeUint8:
  126. v, err = readGGUF[uint8](llm, rs)
  127. case ggufTypeInt8:
  128. v, err = readGGUF[int8](llm, rs)
  129. case ggufTypeUint16:
  130. v, err = readGGUF[uint16](llm, rs)
  131. case ggufTypeInt16:
  132. v, err = readGGUF[int16](llm, rs)
  133. case ggufTypeUint32:
  134. v, err = readGGUF[uint32](llm, rs)
  135. case ggufTypeInt32:
  136. v, err = readGGUF[int32](llm, rs)
  137. case ggufTypeUint64:
  138. v, err = readGGUF[uint64](llm, rs)
  139. case ggufTypeInt64:
  140. v, err = readGGUF[int64](llm, rs)
  141. case ggufTypeFloat32:
  142. v, err = readGGUF[float32](llm, rs)
  143. case ggufTypeFloat64:
  144. v, err = readGGUF[float64](llm, rs)
  145. case ggufTypeBool:
  146. v, err = readGGUF[bool](llm, rs)
  147. case ggufTypeString:
  148. v, err = readGGUFString(llm, rs)
  149. case ggufTypeArray:
  150. v, err = readGGUFArray(llm, rs)
  151. default:
  152. return fmt.Errorf("invalid type: %d", t)
  153. }
  154. if err != nil {
  155. return err
  156. }
  157. llm.kv[k] = v
  158. }
  159. // decode tensors
  160. for range llm.numTensor() {
  161. name, err := readGGUFString(llm, rs)
  162. if err != nil {
  163. return fmt.Errorf("failed to read tensor name: %w", err)
  164. }
  165. // dims is the number of dimensions in the tensor
  166. dims, err := readGGUF[uint32](llm, rs)
  167. if err != nil {
  168. return fmt.Errorf("failed to read tensor dimensions: %w", err)
  169. }
  170. shape := [4]uint64{1, 1, 1, 1}
  171. for i := 0; uint32(i) < dims; i++ {
  172. shape[i], err = readGGUF[uint64](llm, rs)
  173. if err != nil {
  174. return fmt.Errorf("failed to read tensor shape: %w", err)
  175. }
  176. }
  177. kind, err := readGGUF[uint32](llm, rs)
  178. if err != nil {
  179. return fmt.Errorf("failed to read tensor kind: %w", err)
  180. }
  181. offset, err := readGGUF[uint64](llm, rs)
  182. if err != nil {
  183. return fmt.Errorf("failed to read tensor offset: %w", err)
  184. }
  185. tensor := Tensor{
  186. Name: name,
  187. Kind: kind,
  188. Offset: offset,
  189. Shape: shape[:],
  190. }
  191. llm.tensors = append(llm.tensors, &tensor)
  192. llm.parameters += tensor.parameters()
  193. }
  194. // patch KV with parameter count
  195. llm.kv["general.parameter_count"] = llm.parameters
  196. alignment, ok := llm.kv["general.alignment"].(uint32)
  197. if !ok {
  198. alignment = 32
  199. }
  200. for _, tensor := range llm.tensors {
  201. offset, err := rs.Seek(0, io.SeekCurrent)
  202. if err != nil {
  203. return fmt.Errorf("failed to get current offset: %w", err)
  204. }
  205. padding := llm.padding(offset, int64(alignment))
  206. if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
  207. return fmt.Errorf("failed to seek to init padding: %w", err)
  208. }
  209. if _, err := rs.Seek(int64(tensor.Size()), io.SeekCurrent); err != nil {
  210. return fmt.Errorf("failed to seek to tensor: %w", err)
  211. }
  212. }
  213. return nil
  214. }
  215. func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
  216. var t T
  217. err := binary.Read(r, llm.ByteOrder, &t)
  218. return t, err
  219. }
  220. func writeGGUF[V any](llm *gguf, w io.Writer, t uint32, v V) error {
  221. if err := binary.Write(w, llm.ByteOrder, t); err != nil {
  222. return err
  223. }
  224. return binary.Write(w, llm.ByteOrder, v)
  225. }
  226. func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
  227. var length uint64
  228. if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
  229. return "", err
  230. }
  231. var b bytes.Buffer
  232. if _, err := io.CopyN(&b, r, int64(length)); err != nil {
  233. return "", err
  234. }
  235. // gguf v1 strings are null-terminated
  236. b.Truncate(b.Len() - 1)
  237. return b.String(), nil
  238. }
  239. func discardGGUFString(llm *gguf, r io.Reader) error {
  240. buf := llm.scratch[:8]
  241. _, err := io.ReadFull(r, buf)
  242. if err != nil {
  243. return err
  244. }
  245. size := int(llm.ByteOrder.Uint64(buf))
  246. for size > 0 {
  247. n, err := r.Read(llm.scratch[:min(size, cap(llm.scratch))])
  248. if err != nil {
  249. return err
  250. }
  251. size -= n
  252. }
  253. return nil
  254. }
  255. func readGGUFString(llm *gguf, r io.Reader) (string, error) {
  256. if llm.Version == 1 {
  257. return readGGUFV1String(llm, r)
  258. }
  259. buf := llm.scratch[:8]
  260. _, err := io.ReadFull(r, buf)
  261. if err != nil {
  262. return "", err
  263. }
  264. length := int(llm.ByteOrder.Uint64(buf))
  265. if length > len(llm.scratch) {
  266. buf = make([]byte, length)
  267. } else {
  268. buf = llm.scratch[:length]
  269. }
  270. clear(buf)
  271. _, err = io.ReadFull(r, buf)
  272. if err != nil {
  273. return "", err
  274. }
  275. return string(buf), nil
  276. }
  277. func writeGGUFString(llm *gguf, w io.Writer, s string) error {
  278. if err := binary.Write(w, llm.ByteOrder, ggufTypeString); err != nil {
  279. return err
  280. }
  281. if err := binary.Write(w, llm.ByteOrder, uint64(len(s))); err != nil {
  282. return err
  283. }
  284. _, err := io.Copy(w, strings.NewReader(s))
  285. return err
  286. }
// array is a decoded GGUF metadata array.
type array struct {
	// size is the element count declared in the file.
	size int
	// values holds the decoded elements. It is nil when canCollectArray
	// rejected the array's size, in which case the elements were read past
	// but not kept.
	values []any
}
// MarshalJSON encodes only the collected element values as a JSON array; size
// is not included. If the values were not collected (values is nil), the
// result is JSON null.
func (a *array) MarshalJSON() ([]byte, error) {
	return json.Marshal(a.values)
}
  294. func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
  295. t, err := readGGUF[uint32](llm, r)
  296. if err != nil {
  297. return nil, err
  298. }
  299. n, err := readGGUF[uint32](llm, r)
  300. if err != nil {
  301. return nil, err
  302. }
  303. a := &array{size: int(n)}
  304. if llm.canCollectArray(int(n)) {
  305. a.values = make([]any, 0, int(n))
  306. }
  307. for i := range n {
  308. var e any
  309. switch t {
  310. case ggufTypeUint8:
  311. e, err = readGGUF[uint8](llm, r)
  312. case ggufTypeInt8:
  313. e, err = readGGUF[int8](llm, r)
  314. case ggufTypeUint16:
  315. e, err = readGGUF[uint16](llm, r)
  316. case ggufTypeInt16:
  317. e, err = readGGUF[int16](llm, r)
  318. case ggufTypeUint32:
  319. e, err = readGGUF[uint32](llm, r)
  320. case ggufTypeInt32:
  321. e, err = readGGUF[int32](llm, r)
  322. case ggufTypeUint64:
  323. e, err = readGGUF[uint64](llm, r)
  324. case ggufTypeInt64:
  325. e, err = readGGUF[int64](llm, r)
  326. case ggufTypeFloat32:
  327. e, err = readGGUF[float32](llm, r)
  328. case ggufTypeFloat64:
  329. e, err = readGGUF[float64](llm, r)
  330. case ggufTypeBool:
  331. e, err = readGGUF[bool](llm, r)
  332. case ggufTypeString:
  333. e, err = readGGUFV1String(llm, r)
  334. default:
  335. return nil, fmt.Errorf("invalid array type: %d", t)
  336. }
  337. if err != nil {
  338. return nil, err
  339. }
  340. if a.values != nil {
  341. a.values[i] = e
  342. }
  343. }
  344. return a, nil
  345. }
  346. func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
  347. if llm.Version == 1 {
  348. return readGGUFV1Array(llm, r)
  349. }
  350. t, err := readGGUF[uint32](llm, r)
  351. if err != nil {
  352. return nil, err
  353. }
  354. n, err := readGGUF[uint64](llm, r)
  355. if err != nil {
  356. return nil, err
  357. }
  358. a := &array{size: int(n)}
  359. if llm.canCollectArray(int(n)) {
  360. a.values = make([]any, int(n))
  361. }
  362. for i := range n {
  363. var e any
  364. switch t {
  365. case ggufTypeUint8:
  366. e, err = readGGUF[uint8](llm, r)
  367. case ggufTypeInt8:
  368. e, err = readGGUF[int8](llm, r)
  369. case ggufTypeUint16:
  370. e, err = readGGUF[uint16](llm, r)
  371. case ggufTypeInt16:
  372. e, err = readGGUF[int16](llm, r)
  373. case ggufTypeUint32:
  374. e, err = readGGUF[uint32](llm, r)
  375. case ggufTypeInt32:
  376. e, err = readGGUF[int32](llm, r)
  377. case ggufTypeUint64:
  378. e, err = readGGUF[uint64](llm, r)
  379. case ggufTypeInt64:
  380. e, err = readGGUF[int64](llm, r)
  381. case ggufTypeFloat32:
  382. e, err = readGGUF[float32](llm, r)
  383. case ggufTypeFloat64:
  384. e, err = readGGUF[float64](llm, r)
  385. case ggufTypeBool:
  386. e, err = readGGUF[bool](llm, r)
  387. case ggufTypeString:
  388. if a.values != nil {
  389. e, err = readGGUFString(llm, r)
  390. } else {
  391. err = discardGGUFString(llm, r)
  392. }
  393. default:
  394. return nil, fmt.Errorf("invalid array type: %d", t)
  395. }
  396. if err != nil {
  397. return nil, err
  398. }
  399. if a.values != nil {
  400. a.values[i] = e
  401. }
  402. }
  403. return a, nil
  404. }
  405. func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
  406. if err := binary.Write(w, llm.ByteOrder, ggufTypeArray); err != nil {
  407. return err
  408. }
  409. if err := binary.Write(w, llm.ByteOrder, t); err != nil {
  410. return err
  411. }
  412. if err := binary.Write(w, llm.ByteOrder, uint64(len(s))); err != nil {
  413. return err
  414. }
  415. for _, e := range s {
  416. if err := binary.Write(w, llm.ByteOrder, e); err != nil {
  417. return err
  418. }
  419. }
  420. return nil
  421. }
// ggufKVOrder maps an architecture name to the metadata keys Encode is able to
// write, in the order it writes them. Encode only consults the "llama" entry,
// which is why that list also carries the gemma keys. A key missing from this
// list makes Encode fail.
var ggufKVOrder = map[string][]string{
	"llama": {
		// general
		"general.architecture",
		"general.name",
		// llama hyperparameters
		"llama.vocab_size",
		"llama.context_length",
		"llama.embedding_length",
		"llama.block_count",
		"llama.feed_forward_length",
		"llama.attention.head_count",
		"llama.attention.head_count_kv",
		"llama.attention.layer_norm_rms_epsilon",
		"llama.rope.freq_base",
		"llama.rope.dimension_count",
		"llama.expert_count",
		"llama.expert_used_count",
		// gemma hyperparameters
		"gemma.context_length",
		"gemma.embedding_length",
		"gemma.block_count",
		"gemma.feed_forward_length",
		"gemma.attention.head_count",
		"gemma.attention.head_count_kv",
		"gemma.attention.layer_norm_rms_epsilon",
		"gemma.attention.key_length",
		"gemma.attention.value_length",
		"general.file_type",
		// tokenizer
		"tokenizer.ggml.pre",
		"tokenizer.ggml.model",
		"tokenizer.ggml.tokens",
		"tokenizer.ggml.scores",
		"tokenizer.ggml.merges",
		"tokenizer.ggml.token_type",
		"tokenizer.ggml.bos_token_id",
		"tokenizer.ggml.eos_token_id",
		"tokenizer.ggml.unknown_token_id",
		"tokenizer.ggml.padding_token_id",
		"tokenizer.ggml.add_bos_token",
		"tokenizer.ggml.add_eos_token",
		"tokenizer.chat_template",
	},
}
  463. func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
  464. switch llm.Version {
  465. case 3:
  466. llm.V3.NumTensor = uint64(len(tensors))
  467. llm.V3.NumKV = uint64(len(kv))
  468. default:
  469. return fmt.Errorf("not implemented: ggufv%d", llm.Version)
  470. }
  471. if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
  472. return err
  473. }
  474. if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
  475. return err
  476. }
  477. if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
  478. return err
  479. }
  480. if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
  481. return err
  482. }
  483. kvCheck := make(map[string]bool)
  484. for k := range kv {
  485. kvCheck[k] = false
  486. }
  487. for _, k := range ggufKVOrder["llama"] {
  488. v, ok := kv[k]
  489. if !ok {
  490. continue
  491. }
  492. kvCheck[k] = true
  493. if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
  494. return err
  495. }
  496. if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
  497. return err
  498. }
  499. var err error
  500. switch v := v.(type) {
  501. case uint32:
  502. err = writeGGUF(llm, ws, ggufTypeUint32, v)
  503. case float32:
  504. err = writeGGUF(llm, ws, ggufTypeFloat32, v)
  505. case bool:
  506. err = writeGGUF(llm, ws, ggufTypeBool, v)
  507. case string:
  508. err = writeGGUFString(llm, ws, v)
  509. case []int32:
  510. err = writeGGUFArray(llm, ws, ggufTypeInt32, v)
  511. case []uint32:
  512. err = writeGGUFArray(llm, ws, ggufTypeUint32, v)
  513. case []float32:
  514. err = writeGGUFArray(llm, ws, ggufTypeFloat32, v)
  515. case []string:
  516. if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
  517. return err
  518. }
  519. if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
  520. return err
  521. }
  522. if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
  523. return err
  524. }
  525. for _, e := range v {
  526. if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
  527. return err
  528. }
  529. if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
  530. return err
  531. }
  532. }
  533. default:
  534. return fmt.Errorf("improper type for '%s'", k)
  535. }
  536. if err != nil {
  537. return err
  538. }
  539. }
  540. for k, v := range kvCheck {
  541. if !v {
  542. return fmt.Errorf("Didn't know how to write kv %s", k)
  543. }
  544. }
  545. for _, tensor := range tensors {
  546. if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
  547. return err
  548. }
  549. if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
  550. return err
  551. }
  552. var dims int
  553. for cnt := range len(tensor.Shape) {
  554. if tensor.Shape[cnt] > 0 {
  555. dims++
  556. }
  557. }
  558. if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
  559. return err
  560. }
  561. for i := range dims {
  562. if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
  563. return err
  564. }
  565. }
  566. if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
  567. return err
  568. }
  569. if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
  570. return err
  571. }
  572. }
  573. var alignment int64 = 32
  574. for _, tensor := range tensors {
  575. offset, err := ws.Seek(0, io.SeekCurrent)
  576. if err != nil {
  577. return err
  578. }
  579. padding := llm.padding(offset, alignment)
  580. if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
  581. return err
  582. }
  583. if _, err := tensor.WriteTo(ws); err != nil {
  584. return err
  585. }
  586. }
  587. return nil
  588. }
  589. func (gguf) padding(offset, align int64) int64 {
  590. return (align - offset%align) % align
  591. }