roundTripCrash.c 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. /*
  2. * Copyright (c) 2016-2021, Yann Collet, Facebook, Inc.
  3. * All rights reserved.
  4. *
  5. * This source code is licensed under both the BSD-style license (found in the
  6. * LICENSE file in the root directory of this source tree) and the GPLv2 (found
  7. * in the COPYING file in the root directory of this source tree).
  8. * You may select, at your option, one of the above-listed licenses.
  9. */
/*
This program takes a file as input,
performs a zstd round-trip test (compression, then decompression),
compares the result with the original,
and crashes on corruption detection (abort() in fuzzing builds, exit() otherwise).
*/
  16. /*===========================================
  17. * Dependencies
  18. *==========================================*/
  19. #include <stddef.h> /* size_t */
  20. #include <stdlib.h> /* malloc, free, exit */
  21. #include <stdio.h> /* fprintf */
  22. #include <string.h> /* strcmp */
  23. #include <sys/types.h> /* stat */
  24. #include <sys/stat.h> /* stat */
  25. #include "xxhash.h"
  26. #define ZSTD_STATIC_LINKING_ONLY
  27. #include "zstd.h"
  28. /*===========================================
  29. * Macros
  30. *==========================================*/
  31. #define MIN(a,b) ( (a) < (b) ? (a) : (b) )
  32. static void crash(int errorCode){
  33. /* abort if AFL/libfuzzer, exit otherwise */
  34. #ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION /* could also use __AFL_COMPILER */
  35. abort();
  36. #else
  37. exit(errorCode);
  38. #endif
  39. }
  40. #define CHECK_Z(f) { \
  41. size_t const err = f; \
  42. if (ZSTD_isError(err)) { \
  43. fprintf(stderr, \
  44. "Error=> %s: %s", \
  45. #f, ZSTD_getErrorName(err)); \
  46. crash(1); \
  47. } }
  48. /** roundTripTest() :
  49. * Compresses `srcBuff` into `compressedBuff`,
  50. * then decompresses `compressedBuff` into `resultBuff`.
  51. * Compression level used is derived from first content byte.
  52. * @return : result of decompression, which should be == `srcSize`
  53. * or an error code if either compression or decompression fails.
  54. * Note : `compressedBuffCapacity` should be `>= ZSTD_compressBound(srcSize)`
  55. * for compression to be guaranteed to work */
  56. static size_t roundTripTest(void* resultBuff, size_t resultBuffCapacity,
  57. void* compressedBuff, size_t compressedBuffCapacity,
  58. const void* srcBuff, size_t srcBuffSize)
  59. {
  60. static const int maxClevel = 19;
  61. size_t const hashLength = MIN(128, srcBuffSize);
  62. unsigned const h32 = XXH32(srcBuff, hashLength, 0);
  63. int const cLevel = h32 % maxClevel;
  64. size_t const cSize = ZSTD_compress(compressedBuff, compressedBuffCapacity, srcBuff, srcBuffSize, cLevel);
  65. if (ZSTD_isError(cSize)) {
  66. fprintf(stderr, "Compression error : %s \n", ZSTD_getErrorName(cSize));
  67. return cSize;
  68. }
  69. return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, cSize);
  70. }
  71. /** cctxParamRoundTripTest() :
  72. * Same as roundTripTest() except allows experimenting with ZSTD_CCtx_params. */
  73. static size_t cctxParamRoundTripTest(void* resultBuff, size_t resultBuffCapacity,
  74. void* compressedBuff, size_t compressedBuffCapacity,
  75. const void* srcBuff, size_t srcBuffSize)
  76. {
  77. ZSTD_CCtx* const cctx = ZSTD_createCCtx();
  78. ZSTD_CCtx_params* const cctxParams = ZSTD_createCCtxParams();
  79. ZSTD_inBuffer inBuffer = { srcBuff, srcBuffSize, 0 };
  80. ZSTD_outBuffer outBuffer = { compressedBuff, compressedBuffCapacity, 0 };
  81. static const int maxClevel = 19;
  82. size_t const hashLength = MIN(128, srcBuffSize);
  83. unsigned const h32 = XXH32(srcBuff, hashLength, 0);
  84. int const cLevel = h32 % maxClevel;
  85. /* Set parameters */
  86. CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_compressionLevel, cLevel) );
  87. CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_nbWorkers, 2) );
  88. CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_overlapLog, 5) );
  89. /* Apply parameters */
  90. CHECK_Z( ZSTD_CCtx_setParametersUsingCCtxParams(cctx, cctxParams) );
  91. CHECK_Z (ZSTD_compressStream2(cctx, &outBuffer, &inBuffer, ZSTD_e_end) );
  92. ZSTD_freeCCtxParams(cctxParams);
  93. ZSTD_freeCCtx(cctx);
  94. return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, outBuffer.pos);
  95. }
  96. static size_t checkBuffers(const void* buff1, const void* buff2, size_t buffSize)
  97. {
  98. const char* ip1 = (const char*)buff1;
  99. const char* ip2 = (const char*)buff2;
  100. size_t pos;
  101. for (pos=0; pos<buffSize; pos++)
  102. if (ip1[pos]!=ip2[pos])
  103. break;
  104. return pos;
  105. }
  106. static void roundTripCheck(const void* srcBuff, size_t srcBuffSize, int testCCtxParams)
  107. {
  108. size_t const cBuffSize = ZSTD_compressBound(srcBuffSize);
  109. void* cBuff = malloc(cBuffSize);
  110. void* rBuff = malloc(cBuffSize);
  111. if (!cBuff || !rBuff) {
  112. fprintf(stderr, "not enough memory ! \n");
  113. exit (1);
  114. }
  115. { size_t const result = testCCtxParams ?
  116. cctxParamRoundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize)
  117. : roundTripTest(rBuff, cBuffSize, cBuff, cBuffSize, srcBuff, srcBuffSize);
  118. if (ZSTD_isError(result)) {
  119. fprintf(stderr, "roundTripTest error : %s \n", ZSTD_getErrorName(result));
  120. crash(1);
  121. }
  122. if (result != srcBuffSize) {
  123. fprintf(stderr, "Incorrect regenerated size : %u != %u\n", (unsigned)result, (unsigned)srcBuffSize);
  124. crash(1);
  125. }
  126. if (checkBuffers(srcBuff, rBuff, srcBuffSize) != srcBuffSize) {
  127. fprintf(stderr, "Silent decoding corruption !!!");
  128. crash(1);
  129. }
  130. }
  131. free(cBuff);
  132. free(rBuff);
  133. }
  134. static size_t getFileSize(const char* infilename)
  135. {
  136. int r;
  137. #if defined(_MSC_VER)
  138. struct _stat64 statbuf;
  139. r = _stat64(infilename, &statbuf);
  140. if (r || !(statbuf.st_mode & S_IFREG)) return 0; /* No good... */
  141. #else
  142. struct stat statbuf;
  143. r = stat(infilename, &statbuf);
  144. if (r || !S_ISREG(statbuf.st_mode)) return 0; /* No good... */
  145. #endif
  146. return (size_t)statbuf.st_size;
  147. }
  148. static int isDirectory(const char* infilename)
  149. {
  150. int r;
  151. #if defined(_MSC_VER)
  152. struct _stat64 statbuf;
  153. r = _stat64(infilename, &statbuf);
  154. if (!r && (statbuf.st_mode & _S_IFDIR)) return 1;
  155. #else
  156. struct stat statbuf;
  157. r = stat(infilename, &statbuf);
  158. if (!r && S_ISDIR(statbuf.st_mode)) return 1;
  159. #endif
  160. return 0;
  161. }
  162. /** loadFile() :
  163. * requirement : `buffer` size >= `fileSize` */
  164. static void loadFile(void* buffer, const char* fileName, size_t fileSize)
  165. {
  166. FILE* const f = fopen(fileName, "rb");
  167. if (isDirectory(fileName)) {
  168. fprintf(stderr, "Ignoring %s directory \n", fileName);
  169. exit(2);
  170. }
  171. if (f==NULL) {
  172. fprintf(stderr, "Impossible to open %s \n", fileName);
  173. exit(3);
  174. }
  175. { size_t const readSize = fread(buffer, 1, fileSize, f);
  176. if (readSize != fileSize) {
  177. fprintf(stderr, "Error reading %s \n", fileName);
  178. exit(5);
  179. } }
  180. fclose(f);
  181. }
  182. static void fileCheck(const char* fileName, int testCCtxParams)
  183. {
  184. size_t const fileSize = getFileSize(fileName);
  185. void* const buffer = malloc(fileSize + !fileSize /* avoid 0 */);
  186. if (!buffer) {
  187. fprintf(stderr, "not enough memory \n");
  188. exit(4);
  189. }
  190. loadFile(buffer, fileName, fileSize);
  191. roundTripCheck(buffer, fileSize, testCCtxParams);
  192. free (buffer);
  193. }
  194. int main(int argCount, const char** argv) {
  195. int argNb = 1;
  196. int testCCtxParams = 0;
  197. if (argCount < 2) {
  198. fprintf(stderr, "Error : no argument : need input file \n");
  199. exit(9);
  200. }
  201. if (!strcmp(argv[argNb], "--cctxParams")) {
  202. testCCtxParams = 1;
  203. argNb++;
  204. }
  205. fileCheck(argv[argNb], testCCtxParams);
  206. fprintf(stderr, "no pb detected\n");
  207. return 0;
  208. }