inflate.hpp 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. #pragma once
// A bad implementation of inflate, based on puff from zlib/minizip.
// TODO: replace with Talarubi's version.
  4. #include <setjmp.h>
  5. namespace nall::Decode {
namespace puff {
// Forward declaration so that inflate() below can call the decoder, which is
// defined at the end of this file. Returns 0 on success; see the definition
// for the meaning of the other return values.
inline auto puff(
  unsigned char* dest, unsigned long* destlen,
  unsigned char* source, unsigned long* sourcelen
) -> int;
}
  12. inline auto inflate(
  13. uint8_t* target, uint targetLength,
  14. const uint8_t* source, uint sourceLength
  15. ) -> bool {
  16. unsigned long tl = targetLength, sl = sourceLength;
  17. int result = puff::puff((unsigned char*)target, &tl, (unsigned char*)source, &sl);
  18. return result == 0;
  19. }
  20. namespace puff {
// Size limits for the decoder's code tables.
enum : uint {
  MAXBITS = 15,                      // longest Huffman code, in bits
  MAXLCODES = 286,                   // most literal/length codes a dynamic block may declare
  MAXDCODES = 30,                    // most distance codes (also the size of the fixed distance table)
  FIXLCODES = 288,                   // number of literal/length codes in the fixed table
  MAXCODES = MAXLCODES + MAXDCODES,  // literal/length plus distance code lengths in a dynamic block
};
  28. struct state {
  29. unsigned char* out;
  30. unsigned long outlen;
  31. unsigned long outcnt;
  32. unsigned char* in;
  33. unsigned long inlen;
  34. unsigned long incnt;
  35. int bitbuf;
  36. int bitcnt;
  37. jmp_buf env;
  38. };
// Canonical Huffman decoding table, filled in by construct() and read by decode().
struct huffman {
  short* count;   // count[len]: number of codes of bit length len, for len in 0..MAXBITS
  short* symbol;  // used symbols sorted by code length, then by symbol value
};
  43. inline auto bits(state* s, int need) -> int {
  44. long val;
  45. val = s->bitbuf;
  46. while(s->bitcnt < need) {
  47. if(s->incnt == s->inlen) longjmp(s->env, 1);
  48. val |= (long)(s->in[s->incnt++]) << s->bitcnt;
  49. s->bitcnt += 8;
  50. }
  51. s->bitbuf = (int)(val >> need);
  52. s->bitcnt -= need;
  53. return (int)(val & ((1L << need) - 1));
  54. }
  55. inline auto stored(state* s) -> int {
  56. uint len;
  57. s->bitbuf = 0;
  58. s->bitcnt = 0;
  59. if(s->incnt + 4 > s->inlen) return 2;
  60. len = s->in[s->incnt++];
  61. len |= s->in[s->incnt++] << 8;
  62. if(s->in[s->incnt++] != (~len & 0xff) ||
  63. s->in[s->incnt++] != ((~len >> 8) & 0xff)
  64. ) return 2;
  65. if(s->incnt + len > s->inlen) return 2;
  66. if(s->out != nullptr) {
  67. if(s->outcnt + len > s->outlen) return 1;
  68. while(len--) s->out[s->outcnt++] = s->in[s->incnt++];
  69. } else {
  70. s->outcnt += len;
  71. s->incnt += len;
  72. }
  73. return 0;
  74. }
// Decode one symbol from the input using Huffman table h.
// Codes are canonical: the codes of each bit length are consecutive integers.
// The code is therefore built one bit at a time and compared against the range
// of codes of the current length. Each new stream bit becomes the code's low
// bit (code <<= 1 then |= bit), so earlier bits end up more significant.
// Returns the symbol (>= 0), or -10 if no code of up to MAXBITS bits matched.
// longjmp()s to s->env if the input runs out.
inline auto decode(state* s, huffman* h) -> int {
  int len, code, first, count, index, bitbuf, left;
  short* next;
  // Start from the bits already buffered in the state.
  bitbuf = s->bitbuf;
  left = s->bitcnt;
  // code: bits read so far; first: first code of length len;
  // index: position of that first code in h->symbol.
  code = first = index = 0;
  len = 1;
  next = h->count + 1;  // walks h->count[len] as len grows
  while(true) {
    while(left--) {
      code |= bitbuf & 1;  // append the next stream bit as the code's low bit
      bitbuf >>= 1;
      count = *next++;
      if(code - count < first) {  // code lies within the codes of length len
        s->bitbuf = bitbuf;
        // Bits still unused in the last byte: original count minus len, modulo 8.
        s->bitcnt = (s->bitcnt - len) & 7;
        return h->symbol[index + (code - first)];
      }
      // Not a code of this length: move on to the codes one bit longer.
      index += count;
      first += count;
      first <<= 1;
      code <<= 1;
      len++;
    }
    // Buffered bits exhausted: fetch another byte unless MAXBITS bits were tried.
    left = (MAXBITS + 1) - len;
    if(left == 0) break;
    if(s->incnt == s->inlen) longjmp(s->env, 1);
    bitbuf = s->in[s->incnt++];
    if(left > 8) left = 8;
  }
  return -10;  // no code matched
}
  107. inline auto construct(huffman* h, short* length, int n) -> int {
  108. int symbol, len, left;
  109. short offs[MAXBITS + 1];
  110. for(len = 0; len <= MAXBITS; len++) h->count[len] = 0;
  111. for(symbol = 0; symbol < n; symbol++) h->count[length[symbol]]++;
  112. if(h->count[0] == n) return 0;
  113. left = 1;
  114. for(len = 1; len <= MAXBITS; len++) {
  115. left <<= 1;
  116. left -= h->count[len];
  117. if(left < 0) return left;
  118. }
  119. offs[1] = 0;
  120. for(len = 1; len < MAXBITS; len++) offs[len + 1] = offs[len] + h->count[len];
  121. for(symbol = 0; symbol < n; symbol++) {
  122. if(length[symbol] != 0) h->symbol[offs[length[symbol]]++] = symbol;
  123. }
  124. return left;
  125. }
  126. inline auto codes(state* s, huffman* lencode, huffman* distcode) -> int {
  127. int symbol, len;
  128. uint dist;
  129. static const short lens[29] = {
  130. 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
  131. 35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258
  132. };
  133. static const short lext[29] = {
  134. 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
  135. 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0
  136. };
  137. static const short dists[30] = {
  138. 1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
  139. 257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
  140. 8193, 12289, 16385, 24577
  141. };
  142. static const short dext[30] = {
  143. 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
  144. 7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
  145. 12, 12, 13, 13
  146. };
  147. do {
  148. symbol = decode(s, lencode);
  149. if(symbol < 0) return symbol;
  150. if(symbol < 256) {
  151. if(s->out != nullptr) {
  152. if(s->outcnt == s->outlen) return 1;
  153. s->out[s->outcnt] = symbol;
  154. }
  155. s->outcnt++;
  156. } else if(symbol > 256) {
  157. symbol -= 257;
  158. if(symbol >= 29) return -10;
  159. len = lens[symbol] + bits(s, lext[symbol]);
  160. symbol = decode(s, distcode);
  161. if(symbol < 0) return symbol;
  162. dist = dists[symbol] + bits(s, dext[symbol]);
  163. #ifndef INFLATE_ALLOW_INVALID_DISTANCE_TOO_FAR
  164. if(dist > s->outcnt) return -11;
  165. #endif
  166. if(s->out != nullptr) {
  167. if(s->outcnt + len > s->outlen) return 1;
  168. while(len--) {
  169. s->out[s->outcnt] =
  170. #ifdef INFLATE_ALLOW_INVALID_DISTANCE_TOO_FAR
  171. dist > s->outcnt ? 0 :
  172. #endif
  173. s->out[s->outcnt - dist];
  174. s->outcnt++;
  175. }
  176. } else {
  177. s->outcnt += len;
  178. }
  179. }
  180. } while(symbol != 256);
  181. return 0;
  182. }
  183. inline auto fixed(state* s) -> int {
  184. static int virgin = 1;
  185. static short lencnt[MAXBITS + 1], lensym[FIXLCODES];
  186. static short distcnt[MAXBITS + 1], distsym[MAXDCODES];
  187. static huffman lencode, distcode;
  188. if(virgin) {
  189. int symbol = 0;
  190. short lengths[FIXLCODES];
  191. lencode.count = lencnt;
  192. lencode.symbol = lensym;
  193. distcode.count = distcnt;
  194. distcode.symbol = distsym;
  195. for(; symbol < 144; symbol++) lengths[symbol] = 8;
  196. for(; symbol < 256; symbol++) lengths[symbol] = 9;
  197. for(; symbol < 280; symbol++) lengths[symbol] = 7;
  198. for(; symbol < FIXLCODES; symbol++) lengths[symbol] = 8;
  199. construct(&lencode, lengths, FIXLCODES);
  200. for(symbol = 0; symbol < MAXDCODES; symbol++) lengths[symbol] = 5;
  201. construct(&distcode, lengths, MAXDCODES);
  202. virgin = 0;
  203. }
  204. return codes(s, &lencode, &distcode);
  205. }
  206. inline auto dynamic(state* s) -> int {
  207. int nlen, ndist, ncode, index, err;
  208. short lengths[MAXCODES];
  209. short lencnt[MAXBITS + 1], lensym[MAXLCODES];
  210. short distcnt[MAXBITS + 1], distsym[MAXDCODES];
  211. huffman lencode, distcode;
  212. static const short order[19] = {
  213. 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
  214. };
  215. lencode.count = lencnt;
  216. lencode.symbol = lensym;
  217. distcode.count = distcnt;
  218. distcode.symbol = distsym;
  219. nlen = bits(s, 5) + 257;
  220. ndist = bits(s, 5) + 1;
  221. ncode = bits(s, 4) + 4;
  222. if(nlen > MAXLCODES || ndist > MAXDCODES) return -3;
  223. for(index = 0; index < ncode; index++) lengths[order[index]] = bits(s, 3);
  224. for(; index < 19; index++) lengths[order[index]] = 0;
  225. err = construct(&lencode, lengths, 19);
  226. if(err != 0) return -4;
  227. index = 0;
  228. while(index < nlen + ndist) {
  229. int symbol, len;
  230. symbol = decode(s, &lencode);
  231. if(symbol < 16) {
  232. lengths[index++] = symbol;
  233. } else {
  234. len = 0;
  235. if(symbol == 16) {
  236. if(index == 0) return -5;
  237. len = lengths[index - 1];
  238. symbol = 3 + bits(s, 2);
  239. } else if(symbol == 17) {
  240. symbol = 3 + bits(s, 3);
  241. } else {
  242. symbol = 11 + bits(s, 7);
  243. }
  244. if(index + symbol > nlen + ndist) return -6;
  245. while(symbol--) lengths[index++] = len;
  246. }
  247. }
  248. if(lengths[256] == 0) return -9;
  249. err = construct(&lencode, lengths, nlen);
  250. if(err < 0 || (err > 0 && nlen - lencode.count[0] != 1)) return -7;
  251. err = construct(&distcode, lengths + nlen, ndist);
  252. if(err < 0 || (err > 0 && ndist - distcode.count[0] != 1)) return -8;
  253. return codes(s, &lencode, &distcode);
  254. }
// Decompress raw DEFLATE data source[0..*sourcelen) into dest[0..*destlen).
// dest may be null; then nothing is written and only the output size is counted.
// Returns:
//   0         success
//   2         input ended before the final block was complete
//   1         dest is too small
//   -1        invalid block type (3)
//   other negative values report malformed data (see dynamic() and codes())
// When the result is <= 0, *destlen and *sourcelen are set to the number of
// bytes produced and consumed; otherwise they are left unchanged.
inline auto puff(
  unsigned char* dest, unsigned long* destlen,
  unsigned char* source, unsigned long* sourcelen
) -> int {
  state s;
  int last, type, err;
  s.out = dest;
  s.outlen = *destlen;
  s.outcnt = 0;
  s.in = source;
  s.inlen = *sourcelen;
  s.incnt = 0;
  s.bitbuf = 0;
  s.bitcnt = 0;
  // bits() and decode() longjmp() back here when they run out of input.
  if(setjmp(s.env) != 0) {
    err = 2;
  } else {
    do {
      last = bits(&s, 1);  // 1 on the final block
      type = bits(&s, 2);  // 0 stored, 1 fixed Huffman, 2 dynamic Huffman, 3 invalid
      err = type == 0 ? stored(&s)
          : type == 1 ? fixed(&s)
          : type == 2 ? dynamic(&s)
          : -1;
      if(err != 0) break;
    } while(!last);
  }
  if(err <= 0) {
    *destlen = s.outcnt;
    *sourcelen = s.incnt;
  }
  return err;
}
  288. }
  289. }