bit_stream.h 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Contains utils for reading, writing and debug printing bit streams.
  15. #ifndef SOURCE_COMP_BIT_STREAM_H_
  16. #define SOURCE_COMP_BIT_STREAM_H_
  17. #include <algorithm>
  18. #include <bitset>
  19. #include <cassert>
  20. #include <cstdint>
  21. #include <cstring>
  22. #include <functional>
  23. #include <sstream>
  24. #include <string>
  25. #include <utility>
  26. #include <vector>
  27. namespace spvtools {
  28. namespace comp {
  29. // Terminology:
  30. // Bits - usually used for a uint64 word, first bit is the lowest.
  31. // Stream - std::string of '0' and '1', read left-to-right,
  32. // i.e. first bit is at the front and not at the end as in
  33. // std::bitset::to_string().
  34. // Bitset - std::bitset corresponding to uint64 bits and to reverse(stream).
  // Converts a number of bits to the number of N-bit chunks needed to hold
  // them, rounding up. For example, NumBitsToNumWords<8> returns how many bytes
  // are needed to store |num_bits|, and NumBitsToNumWords<64> returns how many
  // uint64 words are needed.
  template <size_t N>
  inline size_t NumBitsToNumWords(size_t num_bits) {
    static_assert(N > 0, "Chunk size must be positive");
    // Divide first and add one for a partial chunk, instead of computing
    // (num_bits + N - 1) / N, which wraps around for num_bits near SIZE_MAX.
    return num_bits / N + (num_bits % N != 0 ? 1 : 0);
  }
  // Returns a value of the same type as |in| in which every bit except the
  // lowest |num_bits| is cleared. If |num_bits| is at least the bit width of
  // |T|, |in| is returned unchanged.
  // NOTE(review): intended for unsigned |T|; for signed types the mask
  // computation can shift into the sign bit.
  template <typename T>
  inline T GetLowerBits(T in, size_t num_bits) {
    // Shifting by the full width of the type (or more) is undefined behavior,
    // so any request covering all bits short-circuits to |in|.
    if (num_bits >= sizeof(T) * 8) return in;
    return in & T((T(1) << num_bits) - T(1));
  }
  // Encodes a signed integer as unsigned. This is a generalized version of
  // ZigZag encoding, designed to favor small positive numbers.
  // Values are transformed in blocks of 2^|block_exponent|: even blocks hold
  // non-negative values and odd blocks hold negative values.
  // If |block_exponent| is zero, this degenerates into normal ZigZag encoding.
  // Example when |block_exponent| is 1 (return value is the index):
  // 0, 1, -1, -2, 2, 3, -3, -4, 4, 5, -5, -6, 6, 7, -7, -8
  // Example when |block_exponent| is 2:
  // 0, 1, 2, 3, -1, -2, -3, -4, 4, 5, 6, 7, -5, -6, -7, -8
  // |block_exponent| must be less than 64.
  inline uint64_t EncodeZigZag(int64_t val, size_t block_exponent) {
    assert(block_exponent < 64);
    const bool negative = val < 0;
    // Map negatives onto [0, INT64_MAX] (-1 -> 0, -2 -> 1, ...). Written as
    // -(val + 1) rather than -val - 1 so that INT64_MIN does not overflow.
    const uint64_t magnitude =
        static_cast<uint64_t>(negative ? -(val + 1) : val);
    // block_exponent < 64, so this shift is always well-defined.
    const uint64_t low_mask = (uint64_t{1} << block_exponent) - 1;
    const uint64_t block_num =
        ((magnitude >> block_exponent) << 1) + (negative ? 1 : 0);
    return (block_num << block_exponent) + (magnitude & low_mask);
  }
  // Inverse of EncodeZigZag: recovers the signed integer stored in |val|.
  // |block_exponent| must be the same value that was used for encoding and
  // must be less than 64.
  inline int64_t DecodeZigZag(uint64_t val, size_t block_exponent) {
    assert(block_exponent < 64);
    const uint64_t offset_in_block =
        val & ((uint64_t{1} << block_exponent) - 1);
    const uint64_t block_index = val >> block_exponent;
    // Each pair of blocks (even = non-negative, odd = negative) covers the
    // same magnitude range, starting at (block_index / 2) * 2^block_exponent.
    const uint64_t block_start = (block_index >> 1) << block_exponent;
    if ((block_index & 1) == 0) {
      return block_start + offset_in_block;
    }
    return -1LL - block_start - offset_in_block;
  }
  // Converts the lowest |num_bits| bits of |bits| into a left-to-right stream
  // of '0'/'1' characters: the first character is the lowest bit.
  // |num_bits| must be no greater than 64.
  inline std::string BitsToStream(uint64_t bits, size_t num_bits = 64) {
    assert(num_bits <= 64);
    // A uint64 only has 64 bits to emit, and shifting by 64 or more would be
    // undefined behavior, so clamp in builds where the assert is disabled.
    const size_t count = std::min(num_bits, size_t{64});
    std::string stream(count, '0');
    for (size_t i = 0; i < count; ++i) {
      if ((bits >> i) & 1) stream[i] = '1';
    }
    return stream;
  }
  85. // Base class for writing sequences of bits.
  86. class BitWriterInterface {
  87. public:
  88. BitWriterInterface() = default;
  89. virtual ~BitWriterInterface() = default;
  90. // Writes lower |num_bits| in |bits| to the stream.
  91. // |num_bits| must be no greater than 64.
  92. virtual void WriteBits(uint64_t bits, size_t num_bits) = 0;
  93. // Writes bits from value of type |T| to the stream. No encoding is done.
  94. // Always writes 8 * sizeof(T) bits.
  95. template <typename T>
  96. void WriteUnencoded(T val) {
  97. static_assert(sizeof(T) <= 64, "Type size too large");
  98. uint64_t bits = 0;
  99. memcpy(&bits, &val, sizeof(T));
  100. WriteBits(bits, sizeof(T) * 8);
  101. }
  102. // Writes |val| in chunks of size |chunk_length| followed by a signal bit:
  103. // 0 - no more chunks to follow
  104. // 1 - more chunks to follow
  105. // for example 255 is encoded into 1111 1 1111 0 for chunk length 4.
  106. // The last chunk can be truncated and signal bit omitted, if the entire
  107. // payload (for example 16 bit for uint16_t has already been written).
  108. void WriteVariableWidthU64(uint64_t val, size_t chunk_length);
  109. void WriteVariableWidthU32(uint32_t val, size_t chunk_length);
  110. void WriteVariableWidthU16(uint16_t val, size_t chunk_length);
  111. void WriteVariableWidthS64(int64_t val, size_t chunk_length,
  112. size_t zigzag_exponent);
  113. // Returns number of bits written.
  114. virtual size_t GetNumBits() const = 0;
  115. // Provides direct access to the buffer data if implemented.
  116. virtual const uint8_t* GetData() const { return nullptr; }
  117. // Returns buffer size in bytes.
  118. size_t GetDataSizeBytes() const { return NumBitsToNumWords<8>(GetNumBits()); }
  119. // Generates and returns byte array containing written bits.
  120. virtual std::vector<uint8_t> GetDataCopy() const = 0;
  121. BitWriterInterface(const BitWriterInterface&) = delete;
  122. BitWriterInterface& operator=(const BitWriterInterface&) = delete;
  123. };
  124. // This class is an implementation of BitWriterInterface, using
  125. // std::vector<uint64_t> to store written bits.
  126. class BitWriterWord64 : public BitWriterInterface {
  127. public:
  128. explicit BitWriterWord64(size_t reserve_bits = 64);
  129. void WriteBits(uint64_t bits, size_t num_bits) override;
  130. size_t GetNumBits() const override { return end_; }
  131. const uint8_t* GetData() const override {
  132. return reinterpret_cast<const uint8_t*>(buffer_.data());
  133. }
  134. std::vector<uint8_t> GetDataCopy() const override {
  135. return std::vector<uint8_t>(GetData(), GetData() + GetDataSizeBytes());
  136. }
  137. // Sets callback to emit bit sequences after every write.
  138. void SetCallback(std::function<void(const std::string&)> callback) {
  139. callback_ = callback;
  140. }
  141. protected:
  142. // Sends string generated from arguments to callback_ if defined.
  143. void EmitSequence(uint64_t bits, size_t num_bits) const {
  144. if (callback_) callback_(BitsToStream(bits, num_bits));
  145. }
  146. private:
  147. std::vector<uint64_t> buffer_;
  148. // Total number of bits written so far. Named 'end' as analogy to std::end().
  149. size_t end_;
  150. // If not null, the writer will use the callback to emit the written bit
  151. // sequence as a string of '0' and '1'.
  152. std::function<void(const std::string&)> callback_;
  153. };
  // Abstract base class for reading sequences of bits.
  class BitReaderInterface {
   public:
    BitReaderInterface() = default;
    virtual ~BitReaderInterface() = default;

    // Reads |num_bits| from the stream and stores them in |bits|.
    // Returns the number of bits read. |num_bits| must be no greater than 64.
    virtual size_t ReadBits(uint64_t* bits, size_t num_bits) = 0;

    // Reads 8 * sizeof(T) bits and stores their raw bytes in |val|, so |T| can
    // be at most 8 bytes wide. Returns false, leaving |val| untouched, if fewer
    // bits than requested were available.
    // NOTE(review): memcpy copies from the lowest addresses of |bits|; these
    // are the low-order bits only on little-endian hosts.
    template <typename T>
    bool ReadUnencoded(T* val) {
      // |bits| is a single uint64_t: copying more than sizeof(uint64_t) bytes
      // out of it would read past its end.
      static_assert(sizeof(T) <= sizeof(uint64_t), "Type size too large");
      uint64_t bits = 0;
      const size_t num_read = ReadBits(&bits, sizeof(T) * 8);
      if (num_read != sizeof(T) * 8) return false;
      memcpy(val, &bits, sizeof(T));
      return true;
    }

    // Returns the number of bits already read.
    virtual size_t GetNumReadBits() const = 0;

    // The next two functions define 'hard' and 'soft' EOF.
    //
    // Returns true if the end of the buffer was reached.
    virtual bool ReachedEnd() const = 0;

    // Returns true if the end of the buffer was reached, or if it is near and
    // only zero bits are left to read. Implementations may return a false
    // negative when the end was not reached, i.e. they may return false even
    // if only zeroes are left. The consumer is assumed to expect the stream to
    // end with padding zeroes and to accept this as a 'soft' EOF.
    // Implementations do not have to support this; the default simply
    // delegates to ReachedEnd().
    virtual bool OnlyZeroesLeft() const { return ReachedEnd(); }

    // Reads a value encoded with WriteVariableWidthXXX (see
    // BitWriterInterface). The reader and writer must use the same
    // |chunk_length| and variable type.
    // Returns true on success, false if the bit stream ends prematurely.
    bool ReadVariableWidthU64(uint64_t* val, size_t chunk_length);
    bool ReadVariableWidthU32(uint32_t* val, size_t chunk_length);
    bool ReadVariableWidthU16(uint16_t* val, size_t chunk_length);
    bool ReadVariableWidthS64(int64_t* val, size_t chunk_length,
                              size_t zigzag_exponent);

    BitReaderInterface(const BitReaderInterface&) = delete;
    BitReaderInterface& operator=(const BitReaderInterface&) = delete;
  };
  198. // This class is an implementation of BitReaderInterface which accepts both
  199. // uint8_t and uint64_t buffers as input. uint64_t buffers are consumed and
  200. // owned. uint8_t buffers are copied.
  201. class BitReaderWord64 : public BitReaderInterface {
  202. public:
  203. // Consumes and owns the buffer.
  204. explicit BitReaderWord64(std::vector<uint64_t>&& buffer);
  205. // Copies the buffer and casts it to uint64.
  206. // Consuming the original buffer and casting it to uint64 is difficult,
  207. // as it would potentially cause data misalignment and poor performance.
  208. explicit BitReaderWord64(const std::vector<uint8_t>& buffer);
  209. BitReaderWord64(const void* buffer, size_t num_bytes);
  210. size_t ReadBits(uint64_t* bits, size_t num_bits) override;
  211. size_t GetNumReadBits() const override { return pos_; }
  212. bool ReachedEnd() const override;
  213. bool OnlyZeroesLeft() const override;
  214. BitReaderWord64() = delete;
  215. // Sets callback to emit bit sequences after every read.
  216. void SetCallback(std::function<void(const std::string&)> callback) {
  217. callback_ = callback;
  218. }
  219. protected:
  220. // Sends string generated from arguments to callback_ if defined.
  221. void EmitSequence(uint64_t bits, size_t num_bits) const {
  222. if (callback_) callback_(BitsToStream(bits, num_bits));
  223. }
  224. private:
  225. const std::vector<uint64_t> buffer_;
  226. size_t pos_;
  227. // If not null, the reader will use the callback to emit the read bit
  228. // sequence as a string of '0' and '1'.
  229. std::function<void(const std::string&)> callback_;
  230. };
  231. } // namespace comp
  232. } // namespace spvtools
  233. #endif // SOURCE_COMP_BIT_STREAM_H_