writer.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package lzw
  5. import (
  6. "bufio"
  7. "errors"
  8. "fmt"
  9. "io"
  10. )
  // writer is the output sink used by the encoder: it accepts compressed
  // output one byte at a time (io.ByteWriter) and can be flushed on demand.
  type writer interface {
  	Flush() error
  	io.ByteWriter
  }
  // errWriteCloser is an io.WriteCloser whose Write and Close calls always
  // fail with the same stored error.
  type errWriteCloser struct {
  	err error
  }

  // Write consumes nothing and reports the stored error.
  func (w *errWriteCloser) Write(_ []byte) (int, error) {
  	return 0, w.err
  }

  // Close reports the stored error.
  func (w *errWriteCloser) Close() error {
  	return w.err
  }
  const (
  	// Codes are 12-bit values. While encoding they are held in uint32
  	// variables so that bit shifts need no type conversions.
  	// maxCode is the largest 12-bit code; invalidCode is an all-ones 32-bit
  	// value that can never be a real code.
  	maxCode     = (1 << 12) - 1
  	invalidCode = (1 << 32) - 1

  	// At most 1<<12 codes exist, which bounds the number of live hash table
  	// entries at any moment. The table has four times that many slots, and
  	// tableMask reduces a hash to a slot index.
  	tableSize = 4 * (1 << 12)
  	tableMask = tableSize - 1

  	// Hash table entries are uint32 values. Zero marks an unused slot: the
  	// low 12 bits of a real entry hold a non-literal code, so a real entry
  	// is never zero.
  	invalidEntry = 0
  )
  39. // encoder is LZW compressor.
  40. type encoder struct {
  41. // w is the writer that compressed bytes are written to.
  42. w writer
  43. // order, write, bits, nBits and width are the state for
  44. // converting a code stream into a byte stream.
  45. order Order
  46. write func(*encoder, uint32) error
  47. bits uint32
  48. nBits uint
  49. width uint
  50. // litWidth is the width in bits of literal codes.
  51. litWidth uint
  52. // hi is the code implied by the next code emission.
  53. // overflow is the code at which hi overflows the code width.
  54. hi, overflow uint32
  55. // savedCode is the accumulated code at the end of the most recent Write
  56. // call. It is equal to invalidCode if there was no such call.
  57. savedCode uint32
  58. // err is the first error encountered during writing. Closing the encoder
  59. // will make any future Write calls return errClosed
  60. err error
  61. // table is the hash table from 20-bit keys to 12-bit values. Each table
  62. // entry contains key<<12|val and collisions resolve by linear probing.
  63. // The keys consist of a 12-bit code prefix and an 8-bit byte suffix.
  64. // The values are a 12-bit code.
  65. table [tableSize]uint32
  66. }
  67. // writeLSB writes the code c for "Least Significant Bits first" data.
  68. func (e *encoder) writeLSB(c uint32) error {
  69. e.bits |= c << e.nBits
  70. e.nBits += e.width
  71. for e.nBits >= 8 {
  72. if err := e.w.WriteByte(uint8(e.bits)); err != nil {
  73. return err
  74. }
  75. e.bits >>= 8
  76. e.nBits -= 8
  77. }
  78. return nil
  79. }
  80. // writeMSB writes the code c for "Most Significant Bits first" data.
  81. func (e *encoder) writeMSB(c uint32) error {
  82. e.bits |= c << (32 - e.width - e.nBits)
  83. e.nBits += e.width
  84. for e.nBits >= 8 {
  85. if err := e.w.WriteByte(uint8(e.bits >> 24)); err != nil {
  86. return err
  87. }
  88. e.bits <<= 8
  89. e.nBits -= 8
  90. }
  91. return nil
  92. }
  93. // errOutOfCodes is an internal error that means that the encoder has run out
  94. // of unused codes and a clear code needs to be sent next.
  95. var errOutOfCodes = errors.New("lzw: out of codes")
  96. // incHi increments e.hi and checks for both overflow and running out of
  97. // unused codes. In the latter case, incHi sends a clear code, resets the
  98. // encoder state and returns errOutOfCodes.
  99. func (e *encoder) incHi() error {
  100. e.hi++
  101. if e.hi == e.overflow {
  102. e.width++
  103. e.overflow <<= 1
  104. }
  105. if e.hi == maxCode {
  106. clear := uint32(1) << e.litWidth
  107. if err := e.write(e, clear); err != nil {
  108. return err
  109. }
  110. e.width = uint(e.litWidth) + 1
  111. e.hi = clear + 1
  112. e.overflow = clear << 1
  113. for i := range e.table {
  114. e.table[i] = invalidEntry
  115. }
  116. return errOutOfCodes
  117. }
  118. return nil
  119. }
  120. // Write writes a compressed representation of p to e's underlying writer.
  121. func (e *encoder) Write(p []byte) (n int, err error) {
  122. if e.err != nil {
  123. return 0, e.err
  124. }
  125. if len(p) == 0 {
  126. return 0, nil
  127. }
  128. n = len(p)
  129. litMask := uint32(1<<e.litWidth - 1)
  130. code := e.savedCode
  131. if code == invalidCode {
  132. // The first code sent is always a literal code.
  133. code, p = uint32(p[0])&litMask, p[1:]
  134. }
  135. loop:
  136. for _, x := range p {
  137. literal := uint32(x) & litMask
  138. key := code<<8 | literal
  139. // If there is a hash table hit for this key then we continue the loop
  140. // and do not emit a code yet.
  141. hash := (key>>12 ^ key) & tableMask
  142. for h, t := hash, e.table[hash]; t != invalidEntry; {
  143. if key == t>>12 {
  144. code = t & maxCode
  145. continue loop
  146. }
  147. h = (h + 1) & tableMask
  148. t = e.table[h]
  149. }
  150. // Otherwise, write the current code, and literal becomes the start of
  151. // the next emitted code.
  152. if e.err = e.write(e, code); e.err != nil {
  153. return 0, e.err
  154. }
  155. code = literal
  156. // Increment e.hi, the next implied code. If we run out of codes, reset
  157. // the encoder state (including clearing the hash table) and continue.
  158. if err1 := e.incHi(); err1 != nil {
  159. if err1 == errOutOfCodes {
  160. continue
  161. }
  162. e.err = err1
  163. return 0, e.err
  164. }
  165. // Otherwise, insert key -> e.hi into the map that e.table represents.
  166. for {
  167. if e.table[hash] == invalidEntry {
  168. e.table[hash] = (key << 12) | e.hi
  169. break
  170. }
  171. hash = (hash + 1) & tableMask
  172. }
  173. }
  174. e.savedCode = code
  175. return n, nil
  176. }
  177. // Close closes the encoder, flushing any pending output. It does not close or
  178. // flush e's underlying writer.
  179. func (e *encoder) Close() error {
  180. if e.err != nil {
  181. if e.err == errClosed {
  182. return nil
  183. }
  184. return e.err
  185. }
  186. // Make any future calls to Write return errClosed.
  187. e.err = errClosed
  188. // Write the savedCode if valid.
  189. if e.savedCode != invalidCode {
  190. if err := e.write(e, e.savedCode); err != nil {
  191. return err
  192. }
  193. if err := e.incHi(); err != nil && err != errOutOfCodes {
  194. return err
  195. }
  196. }
  197. // Write the eof code.
  198. eof := uint32(1)<<e.litWidth + 1
  199. if err := e.write(e, eof); err != nil {
  200. return err
  201. }
  202. // Write the final bits.
  203. if e.nBits > 0 {
  204. if e.order == MSB {
  205. e.bits >>= 24
  206. }
  207. if err := e.w.WriteByte(uint8(e.bits)); err != nil {
  208. return err
  209. }
  210. }
  211. return e.w.Flush()
  212. }
  213. // NewWriter creates a new io.WriteCloser.
  214. // Writes to the returned io.WriteCloser are compressed and written to w.
  215. // It is the caller's responsibility to call Close on the WriteCloser when
  216. // finished writing.
  217. // The number of bits to use for literal codes, litWidth, must be in the
  218. // range [2,8] and is typically 8.
  219. func NewWriter(w io.Writer, order Order, litWidth int) io.WriteCloser {
  220. var write func(*encoder, uint32) error
  221. switch order {
  222. case LSB:
  223. write = (*encoder).writeLSB
  224. case MSB:
  225. write = (*encoder).writeMSB
  226. default:
  227. return &errWriteCloser{errors.New("lzw: unknown order")}
  228. }
  229. if litWidth < 2 || 8 < litWidth {
  230. return &errWriteCloser{fmt.Errorf("lzw: litWidth %d out of range", litWidth)}
  231. }
  232. bw, ok := w.(writer)
  233. if !ok {
  234. bw = bufio.NewWriter(w)
  235. }
  236. lw := uint(litWidth)
  237. return &encoder{
  238. w: bw,
  239. order: order,
  240. write: write,
  241. width: 1 + lw,
  242. litWidth: lw,
  243. hi: 1<<lw + 1,
  244. overflow: 1 << (lw + 1),
  245. savedCode: invalidCode,
  246. }
  247. }