// markv_decoder.h
// Copyright (c) 2018 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_COMP_MARKV_DECODER_H_
#define SOURCE_COMP_MARKV_DECODER_H_

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

#include "source/comp/bit_stream.h"
#include "source/comp/markv.h"
#include "source/comp/markv_codec.h"
#include "source/comp/markv_logger.h"
#include "source/util/make_unique.h"
  21. namespace spvtools {
  22. namespace comp {
  23. class MarkvLogger;
  24. // Decodes MARK-V buffers written by MarkvEncoder.
  25. class MarkvDecoder : public MarkvCodec {
  26. public:
  27. // |model| is owned by the caller, must be not null and valid during the
  28. // lifetime of MarkvEncoder.
  29. MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
  30. const MarkvCodecOptions& options, const MarkvModel* model)
  31. : MarkvCodec(context, GetValidatorOptions(options), model),
  32. options_(options),
  33. reader_(markv) {
  34. SetIdBound(1);
  35. parsed_operands_.reserve(25);
  36. inst_words_.reserve(25);
  37. }
  38. ~MarkvDecoder() = default;
  39. // Creates an internal logger which writes comments on the decoding process.
  40. void CreateLogger(MarkvLogConsumer log_consumer,
  41. MarkvDebugConsumer debug_consumer) {
  42. logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
  43. }
  44. // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
  45. // Can be called only once. Fails if data of wrong format or ends prematurely,
  46. // of if validation fails.
  47. spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
  48. // Creates and returns validator options. Returned value owned by the caller.
  49. static spv_validator_options GetValidatorOptions(
  50. const MarkvCodecOptions& options) {
  51. return options.validate_spirv_binary ? spvValidatorOptionsCreate()
  52. : nullptr;
  53. }
  54. private:
  55. // Describes the format of a typed literal number.
  56. struct NumberType {
  57. spv_number_kind_t type;
  58. uint32_t bit_width;
  59. };
  60. // Reads a single bit from reader_. The read bit is stored in |bit|.
  61. // Returns false iff reader_ fails.
  62. bool ReadBit(bool* bit) {
  63. uint64_t bits = 0;
  64. const bool result = reader_.ReadBits(&bits, 1);
  65. if (result) *bit = bits ? true : false;
  66. return result;
  67. };
  68. // Returns ReadBit bound to the class object.
  69. std::function<bool(bool*)> GetReadBitCallback() {
  70. return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
  71. }
  72. // Reads a single non-id word from bit stream. operand_.type determines if
  73. // the word needs to be decoded and how.
  74. spv_result_t DecodeNonIdWord(uint32_t* word);
  75. // Reads and decodes both opcode and num_operands as a single code.
  76. // Returns SPV_UNSUPPORTED iff no suitable codec was found.
  77. spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
  78. uint32_t* num_operands);
  79. // Reads mtf rank from bit stream. |mtf| is used to determine the codec
  80. // scheme. |fallback_method| is used if no codec defined for |mtf|.
  81. spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
  82. uint32_t* rank);
  83. // Reads id using coding based on mtf associated with the id descriptor.
  84. // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
  85. spv_result_t DecodeIdWithDescriptor(uint32_t* id);
  86. // Reads id using coding based on the given |mtf|, which is expected to
  87. // contain the needed |id|.
  88. spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
  89. // Reads type id of the current instruction if can't be inferred.
  90. spv_result_t DecodeTypeId();
  91. // Reads result id of the current instruction if can't be inferred.
  92. spv_result_t DecodeResultId();
  93. // Reads id which is neither type nor result id.
  94. spv_result_t DecodeRefId(uint32_t* id);
  95. // Reads and discards bits until the beginning of the next byte if the
  96. // number of bits until the next byte is less than |byte_break_if_less_than|.
  97. bool ReadToByteBreak(size_t byte_break_if_less_than);
  98. // Returns instruction words decoded up to this point.
  99. const uint32_t* GetInstWords() const override { return inst_words_.data(); }
  100. // Reads a literal number as it is described in |operand| from the bit stream,
  101. // decodes and writes it to spirv_.
  102. spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
  103. // Reads instruction from bit stream, decodes and validates it.
  104. // Decoded instruction is valid until the next call of DecodeInstruction().
  105. spv_result_t DecodeInstruction();
  106. // Read operand from the stream decodes and validates it.
  107. spv_result_t DecodeOperand(size_t operand_offset,
  108. const spv_operand_type_t type,
  109. spv_operand_pattern_t* expected_operands);
  110. // Records the numeric type for an operand according to the type information
  111. // associated with the given non-zero type Id. This can fail if the type Id
  112. // is not a type Id, or if the type Id does not reference a scalar numeric
  113. // type. On success, return SPV_SUCCESS and populates the num_words,
  114. // number_kind, and number_bit_width fields of parsed_operand.
  115. spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
  116. uint32_t type_id);
  117. // Records the number type for the current instruction, if it generates a
  118. // type. For types that aren't scalar numbers, record something with number
  119. // kind SPV_NUMBER_NONE.
  120. void RecordNumberType();
  121. MarkvCodecOptions options_;
  122. // Temporary sink where decoded SPIR-V words are written. Once it contains the
  123. // entire module, the container is moved and returned.
  124. std::vector<uint32_t> spirv_;
  125. // Bit stream containing encoded data.
  126. BitReaderWord64 reader_;
  127. // Temporary storage for operands of the currently parsed instruction.
  128. // Valid until next DecodeInstruction call.
  129. std::vector<spv_parsed_operand_t> parsed_operands_;
  130. // Temporary storage for current instruction words.
  131. // Valid until next DecodeInstruction call.
  132. std::vector<uint32_t> inst_words_;
  133. // Maps a type ID to its number type description.
  134. std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
  135. // Maps an ExtInstImport id to the extended instruction type.
  136. std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
  137. };
  138. } // namespace comp
  139. } // namespace spvtools
  140. #endif // SOURCE_COMP_MARKV_DECODER_H_