markv_decoder.cpp 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/comp/markv_decoder.h"
  15. #include <cstring>
  16. #include <iterator>
  17. #include <numeric>
  18. #include "source/ext_inst.h"
  19. #include "source/opcode.h"
  20. #include "spirv-tools/libspirv.hpp"
  21. namespace spvtools {
  22. namespace comp {
  23. spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
  24. auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
  25. if (codec) {
  26. uint64_t decoded_value = 0;
  27. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  28. return Diag(SPV_ERROR_INVALID_BINARY)
  29. << "Failed to decode non-id word with Huffman";
  30. if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
  31. // The word decoded successfully.
  32. *word = uint32_t(decoded_value);
  33. assert(*word == decoded_value);
  34. return SPV_SUCCESS;
  35. }
  36. // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
  37. }
  38. const size_t chunk_length =
  39. model_->GetOperandVariableWidthChunkLength(operand_.type);
  40. if (chunk_length) {
  41. if (!reader_.ReadVariableWidthU32(word, chunk_length))
  42. return Diag(SPV_ERROR_INVALID_BINARY)
  43. << "Failed to decode non-id word with varint";
  44. } else {
  45. if (!reader_.ReadUnencoded(word))
  46. return Diag(SPV_ERROR_INVALID_BINARY)
  47. << "Failed to read unencoded non-id word";
  48. }
  49. return SPV_SUCCESS;
  50. }
  51. spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
  52. uint32_t* opcode, uint32_t* num_operands) {
  53. // First try to use the Markov chain codec.
  54. auto* codec =
  55. model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
  56. if (codec) {
  57. uint64_t decoded_value = 0;
  58. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  59. return Diag(SPV_ERROR_INTERNAL)
  60. << "Failed to decode opcode_and_num_operands, previous opcode is "
  61. << spvOpcodeString(GetPrevOpcode());
  62. if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
  63. // The word was successfully decoded.
  64. *opcode = uint32_t(decoded_value & 0xFFFF);
  65. *num_operands = uint32_t(decoded_value >> 16);
  66. return SPV_SUCCESS;
  67. }
  68. // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
  69. }
  70. // Fallback to base-rate codec.
  71. codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
  72. assert(codec);
  73. uint64_t decoded_value = 0;
  74. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  75. return Diag(SPV_ERROR_INTERNAL)
  76. << "Failed to decode opcode_and_num_operands with global codec";
  77. if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
  78. // Received kMarkvNoneOfTheAbove signal, fallback further.
  79. return SPV_UNSUPPORTED;
  80. }
  81. *opcode = uint32_t(decoded_value & 0xFFFF);
  82. *num_operands = uint32_t(decoded_value >> 16);
  83. return SPV_SUCCESS;
  84. }
  85. spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
  86. uint32_t fallback_method,
  87. uint32_t* rank) {
  88. const auto* codec = GetMtfHuffmanCodec(mtf);
  89. if (!codec) {
  90. assert(fallback_method != kMtfNone);
  91. codec = GetMtfHuffmanCodec(fallback_method);
  92. }
  93. if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
  94. uint32_t decoded_value = 0;
  95. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  96. return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
  97. if (decoded_value == kMtfRankEncodedByValueSignal) {
  98. // Decode by value.
  99. if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
  100. return Diag(SPV_ERROR_INTERNAL)
  101. << "Failed to decode MTF rank with varint";
  102. *rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
  103. } else {
  104. // Decode using Huffman coding.
  105. assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
  106. *rank = decoded_value;
  107. }
  108. return SPV_SUCCESS;
  109. }
  110. spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
  111. auto* codec =
  112. model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
  113. uint64_t mtf = kMtfNone;
  114. if (codec) {
  115. uint64_t decoded_value = 0;
  116. if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
  117. return Diag(SPV_ERROR_INTERNAL)
  118. << "Failed to decode descriptor with Huffman";
  119. if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
  120. const uint32_t long_descriptor = uint32_t(decoded_value);
  121. mtf = GetMtfLongIdDescriptor(long_descriptor);
  122. }
  123. }
  124. if (mtf == kMtfNone) {
  125. if (model_->id_fallback_strategy() !=
  126. MarkvModel::IdFallbackStrategy::kShortDescriptor) {
  127. return SPV_UNSUPPORTED;
  128. }
  129. uint64_t decoded_value = 0;
  130. if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
  131. return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
  132. const uint32_t short_descriptor = uint32_t(decoded_value);
  133. if (short_descriptor == 0) {
  134. // Forward declared id.
  135. return SPV_UNSUPPORTED;
  136. }
  137. mtf = GetMtfShortIdDescriptor(short_descriptor);
  138. }
  139. return DecodeExistingId(mtf, id);
  140. }
  141. spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
  142. assert(multi_mtf_.GetSize(mtf) > 0);
  143. *id = 0;
  144. uint32_t rank = 0;
  145. if (multi_mtf_.GetSize(mtf) == 1) {
  146. rank = 1;
  147. } else {
  148. const spv_result_t result =
  149. DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
  150. if (result != SPV_SUCCESS) return result;
  151. }
  152. assert(rank);
  153. if (!multi_mtf_.ValueFromRank(mtf, rank, id))
  154. return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
  155. return SPV_SUCCESS;
  156. }
  157. spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
  158. {
  159. const spv_result_t result = DecodeIdWithDescriptor(id);
  160. if (result != SPV_UNSUPPORTED) return result;
  161. }
  162. const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
  163. SpvOp(inst_.opcode))(operand_index_);
  164. uint32_t rank = 0;
  165. *id = 0;
  166. if (model_->id_fallback_strategy() ==
  167. MarkvModel::IdFallbackStrategy::kRuleBased) {
  168. uint64_t mtf = GetRuleBasedMtf();
  169. if (mtf != kMtfNone && !can_forward_declare) {
  170. return DecodeExistingId(mtf, id);
  171. }
  172. if (mtf == kMtfNone) mtf = kMtfAll;
  173. {
  174. const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
  175. if (result != SPV_SUCCESS) return result;
  176. }
  177. if (rank == 0) {
  178. // This is the first occurrence of a forward declared id.
  179. *id = GetIdBound();
  180. SetIdBound(*id + 1);
  181. multi_mtf_.Insert(kMtfAll, *id);
  182. multi_mtf_.Insert(kMtfForwardDeclared, *id);
  183. if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
  184. } else {
  185. if (!multi_mtf_.ValueFromRank(mtf, rank, id))
  186. return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
  187. }
  188. } else {
  189. assert(can_forward_declare);
  190. if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
  191. return Diag(SPV_ERROR_INTERNAL)
  192. << "Failed to decode MTF rank with varint";
  193. if (rank == 0) {
  194. // This is the first occurrence of a forward declared id.
  195. *id = GetIdBound();
  196. SetIdBound(*id + 1);
  197. multi_mtf_.Insert(kMtfForwardDeclared, *id);
  198. } else {
  199. if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
  200. return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
  201. }
  202. }
  203. assert(*id);
  204. return SPV_SUCCESS;
  205. }
  206. spv_result_t MarkvDecoder::DecodeTypeId() {
  207. if (inst_.opcode == SpvOpFunctionParameter) {
  208. assert(!remaining_function_parameter_types_.empty());
  209. inst_.type_id = remaining_function_parameter_types_.front();
  210. remaining_function_parameter_types_.pop_front();
  211. return SPV_SUCCESS;
  212. }
  213. {
  214. const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
  215. if (result != SPV_UNSUPPORTED) return result;
  216. }
  217. assert(model_->id_fallback_strategy() ==
  218. MarkvModel::IdFallbackStrategy::kRuleBased);
  219. uint64_t mtf = GetRuleBasedMtf();
  220. assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
  221. operand_index_));
  222. if (mtf == kMtfNone) {
  223. mtf = kMtfTypeNonFunction;
  224. // Function types should have been handled by GetRuleBasedMtf.
  225. assert(inst_.opcode != SpvOpFunction);
  226. }
  227. return DecodeExistingId(mtf, &inst_.type_id);
  228. }
  229. spv_result_t MarkvDecoder::DecodeResultId() {
  230. uint32_t rank = 0;
  231. const uint64_t num_still_forward_declared =
  232. multi_mtf_.GetSize(kMtfForwardDeclared);
  233. if (num_still_forward_declared) {
  234. // Some ids were forward declared. Check if this id is one of them.
  235. uint64_t id_was_forward_declared;
  236. if (!reader_.ReadBits(&id_was_forward_declared, 1))
  237. return Diag(SPV_ERROR_INVALID_BINARY)
  238. << "Failed to read id_was_forward_declared flag";
  239. if (id_was_forward_declared) {
  240. if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
  241. return Diag(SPV_ERROR_INVALID_BINARY)
  242. << "Failed to read MTF rank of forward declared id";
  243. if (rank) {
  244. // The id was forward declared, recover it from kMtfForwardDeclared.
  245. if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
  246. &inst_.result_id))
  247. return Diag(SPV_ERROR_INTERNAL)
  248. << "Forward declared MTF rank is out of bounds";
  249. // We can now remove the id from kMtfForwardDeclared.
  250. if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
  251. return Diag(SPV_ERROR_INTERNAL)
  252. << "Failed to remove id from kMtfForwardDeclared";
  253. }
  254. }
  255. }
  256. if (inst_.result_id == 0) {
  257. // The id was not forward declared, issue a new id.
  258. inst_.result_id = GetIdBound();
  259. SetIdBound(inst_.result_id + 1);
  260. }
  261. if (model_->id_fallback_strategy() ==
  262. MarkvModel::IdFallbackStrategy::kRuleBased) {
  263. if (!rank) {
  264. multi_mtf_.Insert(kMtfAll, inst_.result_id);
  265. }
  266. }
  267. return SPV_SUCCESS;
  268. }
  269. spv_result_t MarkvDecoder::DecodeLiteralNumber(
  270. const spv_parsed_operand_t& operand) {
  271. if (operand.number_bit_width <= 32) {
  272. uint32_t word = 0;
  273. const spv_result_t result = DecodeNonIdWord(&word);
  274. if (result != SPV_SUCCESS) return result;
  275. inst_words_.push_back(word);
  276. } else {
  277. assert(operand.number_bit_width <= 64);
  278. uint64_t word = 0;
  279. if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
  280. if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
  281. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
  282. } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
  283. int64_t val = 0;
  284. if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
  285. model_->s64_block_exponent()))
  286. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
  287. std::memcpy(&word, &val, 8);
  288. } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
  289. if (!reader_.ReadUnencoded(&word))
  290. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
  291. } else {
  292. return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
  293. }
  294. inst_words_.push_back(static_cast<uint32_t>(word));
  295. inst_words_.push_back(static_cast<uint32_t>(word >> 32));
  296. }
  297. return SPV_SUCCESS;
  298. }
  299. bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
  300. const size_t num_bits_to_next_byte =
  301. GetNumBitsToNextByte(reader_.GetNumReadBits());
  302. if (num_bits_to_next_byte == 0 ||
  303. num_bits_to_next_byte > byte_break_if_less_than)
  304. return true;
  305. uint64_t bits = 0;
  306. if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
  307. assert(bits == 0);
  308. if (bits != 0) return false;
  309. return true;
  310. }
  311. spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
  312. const bool header_read_success =
  313. reader_.ReadUnencoded(&header_.magic_number) &&
  314. reader_.ReadUnencoded(&header_.markv_version) &&
  315. reader_.ReadUnencoded(&header_.markv_model) &&
  316. reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
  317. reader_.ReadUnencoded(&header_.spirv_version) &&
  318. reader_.ReadUnencoded(&header_.spirv_generator);
  319. if (!header_read_success)
  320. return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
  321. if (header_.markv_length_in_bits == 0)
  322. return Diag(SPV_ERROR_INVALID_BINARY)
  323. << "Header markv_length_in_bits field is zero";
  324. if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
  325. return Diag(SPV_ERROR_INVALID_BINARY)
  326. << "MARK-V binary has incorrect magic number";
  327. // TODO(atgoo@github.com): Print version strings.
  328. if (header_.markv_version != MarkvCodec::GetMarkvVersion())
  329. return Diag(SPV_ERROR_INVALID_BINARY)
  330. << "MARK-V binary and the codec have different versions";
  331. const uint32_t model_type = header_.markv_model >> 16;
  332. const uint32_t model_version = header_.markv_model & 0xFFFF;
  333. if (model_type != model_->model_type())
  334. return Diag(SPV_ERROR_INVALID_BINARY)
  335. << "MARK-V binary and the codec use different MARK-V models";
  336. if (model_version != model_->model_version())
  337. return Diag(SPV_ERROR_INVALID_BINARY)
  338. << "MARK-V binary and the codec use different versions if the same "
  339. << "MARK-V model";
  340. spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
  341. spirv_.resize(5, 0);
  342. spirv_[0] = SpvMagicNumber;
  343. spirv_[1] = header_.spirv_version;
  344. spirv_[2] = header_.spirv_generator;
  345. if (logger_) {
  346. reader_.SetCallback(
  347. [this](const std::string& str) { logger_->AppendBitSequence(str); });
  348. }
  349. while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
  350. inst_ = {};
  351. const spv_result_t decode_result = DecodeInstruction();
  352. if (decode_result != SPV_SUCCESS) return decode_result;
  353. }
  354. if (validator_options_) {
  355. spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
  356. const spv_result_t result = spvValidateWithOptions(
  357. context_, validator_options_, &validation_binary, nullptr);
  358. if (result != SPV_SUCCESS) return result;
  359. }
  360. // Validate the decode binary
  361. if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
  362. !reader_.OnlyZeroesLeft()) {
  363. return Diag(SPV_ERROR_INVALID_BINARY)
  364. << "MARK-V binary has wrong stated bit length "
  365. << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
  366. }
  367. // Decoding of the module is finished, validation state should have correct
  368. // id bound.
  369. spirv_[3] = GetIdBound();
  370. *spirv_binary = std::move(spirv_);
  371. return SPV_SUCCESS;
  372. }
  373. // TODO(atgoo@github.com): The implementation borrows heavily from
  374. // Parser::parseOperand.
  375. // Consider coupling them together in some way once MARK-V codec is more mature.
  376. // For now it's better to keep the code independent for experimentation
  377. // purposes.
  378. spv_result_t MarkvDecoder::DecodeOperand(
  379. size_t operand_offset, const spv_operand_type_t type,
  380. spv_operand_pattern_t* expected_operands) {
  381. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  382. memset(&operand_, 0, sizeof(operand_));
  383. assert((operand_offset >> 16) == 0);
  384. operand_.offset = static_cast<uint16_t>(operand_offset);
  385. operand_.type = type;
  386. // Set default values, may be updated later.
  387. operand_.number_kind = SPV_NUMBER_NONE;
  388. operand_.number_bit_width = 0;
  389. const size_t first_word_index = inst_words_.size();
  390. switch (type) {
  391. case SPV_OPERAND_TYPE_RESULT_ID: {
  392. const spv_result_t result = DecodeResultId();
  393. if (result != SPV_SUCCESS) return result;
  394. inst_words_.push_back(inst_.result_id);
  395. SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
  396. PromoteIfNeeded(inst_.result_id);
  397. break;
  398. }
  399. case SPV_OPERAND_TYPE_TYPE_ID: {
  400. const spv_result_t result = DecodeTypeId();
  401. if (result != SPV_SUCCESS) return result;
  402. inst_words_.push_back(inst_.type_id);
  403. SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
  404. PromoteIfNeeded(inst_.type_id);
  405. break;
  406. }
  407. case SPV_OPERAND_TYPE_ID:
  408. case SPV_OPERAND_TYPE_OPTIONAL_ID:
  409. case SPV_OPERAND_TYPE_SCOPE_ID:
  410. case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
  411. uint32_t id = 0;
  412. const spv_result_t result = DecodeRefId(&id);
  413. if (result != SPV_SUCCESS) return result;
  414. if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
  415. if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
  416. operand_.type = SPV_OPERAND_TYPE_ID;
  417. if (opcode == SpvOpExtInst && operand_.offset == 3) {
  418. // The current word is the extended instruction set id.
  419. // Set the extended instruction set type for the current
  420. // instruction.
  421. auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
  422. if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
  423. return Diag(SPV_ERROR_INVALID_ID)
  424. << "OpExtInst set id " << id
  425. << " does not reference an OpExtInstImport result Id";
  426. }
  427. inst_.ext_inst_type = ext_inst_type_iter->second;
  428. }
  429. }
  430. inst_words_.push_back(id);
  431. SetIdBound(std::max(GetIdBound(), id + 1));
  432. PromoteIfNeeded(id);
  433. break;
  434. }
  435. case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
  436. uint32_t word = 0;
  437. const spv_result_t result = DecodeNonIdWord(&word);
  438. if (result != SPV_SUCCESS) return result;
  439. inst_words_.push_back(word);
  440. assert(SpvOpExtInst == opcode);
  441. assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
  442. spv_ext_inst_desc ext_inst;
  443. if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
  444. return Diag(SPV_ERROR_INVALID_BINARY)
  445. << "Invalid extended instruction number: " << word;
  446. spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
  447. break;
  448. }
  449. case SPV_OPERAND_TYPE_LITERAL_INTEGER:
  450. case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
  451. // These are regular single-word literal integer operands.
  452. // Post-parsing validation should check the range of the parsed value.
  453. operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
  454. // It turns out they are always unsigned integers!
  455. operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
  456. operand_.number_bit_width = 32;
  457. uint32_t word = 0;
  458. const spv_result_t result = DecodeNonIdWord(&word);
  459. if (result != SPV_SUCCESS) return result;
  460. inst_words_.push_back(word);
  461. break;
  462. }
  463. case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
  464. case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
  465. operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
  466. if (opcode == SpvOpSwitch) {
  467. // The literal operands have the same type as the value
  468. // referenced by the selector Id.
  469. const uint32_t selector_id = inst_words_.at(1);
  470. const auto type_id_iter = id_to_type_id_.find(selector_id);
  471. if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
  472. return Diag(SPV_ERROR_INVALID_BINARY)
  473. << "Invalid OpSwitch: selector id " << selector_id
  474. << " has no type";
  475. }
  476. uint32_t type_id = type_id_iter->second;
  477. if (selector_id == type_id) {
  478. // Recall that by convention, a result ID that is a type definition
  479. // maps to itself.
  480. return Diag(SPV_ERROR_INVALID_BINARY)
  481. << "Invalid OpSwitch: selector id " << selector_id
  482. << " is a type, not a value";
  483. }
  484. if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
  485. return error;
  486. if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
  487. operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
  488. return Diag(SPV_ERROR_INVALID_BINARY)
  489. << "Invalid OpSwitch: selector id " << selector_id
  490. << " is not a scalar integer";
  491. }
  492. } else {
  493. assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
  494. // The literal number type is determined by the type Id for the
  495. // constant.
  496. assert(inst_.type_id);
  497. if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
  498. return error;
  499. }
  500. if (auto error = DecodeLiteralNumber(operand_)) return error;
  501. break;
  502. }
  503. case SPV_OPERAND_TYPE_LITERAL_STRING:
  504. case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
  505. operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
  506. std::vector<char> str;
  507. auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
  508. if (codec) {
  509. std::string decoded_string;
  510. const bool huffman_result =
  511. codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
  512. assert(huffman_result);
  513. if (!huffman_result)
  514. return Diag(SPV_ERROR_INVALID_BINARY)
  515. << "Failed to read literal string";
  516. if (decoded_string != "kMarkvNoneOfTheAbove") {
  517. std::copy(decoded_string.begin(), decoded_string.end(),
  518. std::back_inserter(str));
  519. str.push_back('\0');
  520. }
  521. }
  522. // The loop is expected to terminate once we encounter '\0' or exhaust
  523. // the bit stream.
  524. if (str.empty()) {
  525. while (true) {
  526. char ch = 0;
  527. if (!reader_.ReadUnencoded(&ch))
  528. return Diag(SPV_ERROR_INVALID_BINARY)
  529. << "Failed to read literal string";
  530. str.push_back(ch);
  531. if (ch == '\0') break;
  532. }
  533. }
  534. while (str.size() % 4 != 0) str.push_back('\0');
  535. inst_words_.resize(inst_words_.size() + str.size() / 4);
  536. std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
  537. if (SpvOpExtInstImport == opcode) {
  538. // Record the extended instruction type for the ID for this import.
  539. // There is only one string literal argument to OpExtInstImport,
  540. // so it's sufficient to guard this just on the opcode.
  541. const spv_ext_inst_type_t ext_inst_type =
  542. spvExtInstImportTypeGet(str.data());
  543. if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
  544. return Diag(SPV_ERROR_INVALID_BINARY)
  545. << "Invalid extended instruction import '" << str.data()
  546. << "'";
  547. }
  548. // We must have parsed a valid result ID. It's a condition
  549. // of the grammar, and we only accept non-zero result Ids.
  550. assert(inst_.result_id);
  551. const bool inserted =
  552. import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
  553. .second;
  554. (void)inserted;
  555. assert(inserted);
  556. }
  557. break;
  558. }
  559. case SPV_OPERAND_TYPE_CAPABILITY:
  560. case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
  561. case SPV_OPERAND_TYPE_EXECUTION_MODEL:
  562. case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
  563. case SPV_OPERAND_TYPE_MEMORY_MODEL:
  564. case SPV_OPERAND_TYPE_EXECUTION_MODE:
  565. case SPV_OPERAND_TYPE_STORAGE_CLASS:
  566. case SPV_OPERAND_TYPE_DIMENSIONALITY:
  567. case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
  568. case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
  569. case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
  570. case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
  571. case SPV_OPERAND_TYPE_LINKAGE_TYPE:
  572. case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
  573. case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
  574. case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
  575. case SPV_OPERAND_TYPE_DECORATION:
  576. case SPV_OPERAND_TYPE_BUILT_IN:
  577. case SPV_OPERAND_TYPE_GROUP_OPERATION:
  578. case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
  579. case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
  580. // A single word that is a plain enum value.
  581. uint32_t word = 0;
  582. const spv_result_t result = DecodeNonIdWord(&word);
  583. if (result != SPV_SUCCESS) return result;
  584. inst_words_.push_back(word);
  585. // Map an optional operand type to its corresponding concrete type.
  586. if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
  587. operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
  588. spv_operand_desc entry;
  589. if (grammar_.lookupOperand(type, word, &entry)) {
  590. return Diag(SPV_ERROR_INVALID_BINARY)
  591. << "Invalid " << spvOperandTypeStr(operand_.type)
  592. << " operand: " << word;
  593. }
  594. // Prepare to accept operands to this operand, if needed.
  595. spvPushOperandTypes(entry->operandTypes, expected_operands);
  596. break;
  597. }
  598. case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
  599. case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
  600. case SPV_OPERAND_TYPE_LOOP_CONTROL:
  601. case SPV_OPERAND_TYPE_IMAGE:
  602. case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
  603. case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
  604. case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
  605. // This operand is a mask.
  606. uint32_t word = 0;
  607. const spv_result_t result = DecodeNonIdWord(&word);
  608. if (result != SPV_SUCCESS) return result;
  609. inst_words_.push_back(word);
  610. // Map an optional operand type to its corresponding concrete type.
  611. if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
  612. operand_.type = SPV_OPERAND_TYPE_IMAGE;
  613. else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
  614. operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
  615. // Check validity of set mask bits. Also prepare for operands for those
  616. // masks if they have any. To get operand order correct, scan from
  617. // MSB to LSB since we can only prepend operands to a pattern.
  618. // The only case in the grammar where you have more than one mask bit
  619. // having an operand is for image operands. See SPIR-V 3.14 Image
  620. // Operands.
  621. uint32_t remaining_word = word;
  622. for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
  623. if (remaining_word & mask) {
  624. spv_operand_desc entry;
  625. if (grammar_.lookupOperand(type, mask, &entry)) {
  626. return Diag(SPV_ERROR_INVALID_BINARY)
  627. << "Invalid " << spvOperandTypeStr(operand_.type)
  628. << " operand: " << word << " has invalid mask component "
  629. << mask;
  630. }
  631. remaining_word ^= mask;
  632. spvPushOperandTypes(entry->operandTypes, expected_operands);
  633. }
  634. }
  635. if (word == 0) {
  636. // An all-zeroes mask *might* also be valid.
  637. spv_operand_desc entry;
  638. if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
  639. // Prepare for its operands, if any.
  640. spvPushOperandTypes(entry->operandTypes, expected_operands);
  641. }
  642. }
  643. break;
  644. }
  645. default:
  646. return Diag(SPV_ERROR_INVALID_BINARY)
  647. << "Internal error: Unhandled operand type: " << type;
  648. }
  649. operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
  650. assert(spvOperandIsConcrete(operand_.type));
  651. parsed_operands_.push_back(operand_);
  652. return SPV_SUCCESS;
  653. }
  654. spv_result_t MarkvDecoder::DecodeInstruction() {
  655. parsed_operands_.clear();
  656. inst_words_.clear();
  657. // Opcode/num_words placeholder, the word will be filled in later.
  658. inst_words_.push_back(0);
  659. bool num_operands_still_unknown = true;
  660. {
  661. uint32_t opcode = 0;
  662. uint32_t num_operands = 0;
  663. const spv_result_t opcode_decoding_result =
  664. DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
  665. if (opcode_decoding_result < 0) return opcode_decoding_result;
  666. if (opcode_decoding_result == SPV_SUCCESS) {
  667. inst_.num_operands = static_cast<uint16_t>(num_operands);
  668. num_operands_still_unknown = false;
  669. } else {
  670. if (!reader_.ReadVariableWidthU32(&opcode,
  671. model_->opcode_chunk_length())) {
  672. return Diag(SPV_ERROR_INVALID_BINARY)
  673. << "Failed to read opcode of instruction";
  674. }
  675. }
  676. inst_.opcode = static_cast<uint16_t>(opcode);
  677. }
  678. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  679. spv_opcode_desc opcode_desc;
  680. if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
  681. return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
  682. }
  683. spv_operand_pattern_t expected_operands;
  684. expected_operands.reserve(opcode_desc->numTypes);
  685. for (auto i = 0; i < opcode_desc->numTypes; i++) {
  686. expected_operands.push_back(
  687. opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
  688. }
  689. if (num_operands_still_unknown) {
  690. if (!OpcodeHasFixedNumberOfOperands(opcode)) {
  691. if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
  692. model_->num_operands_chunk_length()))
  693. return Diag(SPV_ERROR_INVALID_BINARY)
  694. << "Failed to read num_operands of instruction";
  695. } else {
  696. inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
  697. }
  698. }
  699. for (operand_index_ = 0;
  700. operand_index_ < static_cast<size_t>(inst_.num_operands);
  701. ++operand_index_) {
  702. assert(!expected_operands.empty());
  703. const spv_operand_type_t type =
  704. spvTakeFirstMatchableOperand(&expected_operands);
  705. const size_t operand_offset = inst_words_.size();
  706. const spv_result_t decode_result =
  707. DecodeOperand(operand_offset, type, &expected_operands);
  708. if (decode_result != SPV_SUCCESS) return decode_result;
  709. }
  710. assert(inst_.num_operands == parsed_operands_.size());
  711. // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
  712. // next DecodeInstruction call).
  713. inst_.words = inst_words_.data();
  714. inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
  715. inst_.num_words = static_cast<uint16_t>(inst_words_.size());
  716. inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
  717. std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
  718. assert(inst_.num_words ==
  719. std::accumulate(
  720. parsed_operands_.begin(), parsed_operands_.end(), 1,
  721. [](int num_words, const spv_parsed_operand_t& operand) {
  722. return num_words += operand.num_words;
  723. }) &&
  724. "num_words in instruction doesn't correspond to the sum of num_words"
  725. "in the operands");
  726. RecordNumberType();
  727. ProcessCurInstruction();
  728. if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
  729. return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
  730. if (logger_) {
  731. logger_->NewLine();
  732. std::stringstream ss;
  733. ss << spvOpcodeString(opcode) << " ";
  734. for (size_t index = 1; index < inst_words_.size(); ++index)
  735. ss << inst_words_[index] << " ";
  736. logger_->AppendText(ss.str());
  737. logger_->NewLine();
  738. logger_->NewLine();
  739. if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
  740. }
  741. return SPV_SUCCESS;
  742. }
  743. spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
  744. spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
  745. assert(type_id != 0);
  746. auto type_info_iter = type_id_to_number_type_info_.find(type_id);
  747. if (type_info_iter == type_id_to_number_type_info_.end()) {
  748. return Diag(SPV_ERROR_INVALID_BINARY)
  749. << "Type Id " << type_id << " is not a type";
  750. }
  751. const NumberType& info = type_info_iter->second;
  752. if (info.type == SPV_NUMBER_NONE) {
  753. // This is a valid type, but for something other than a scalar number.
  754. return Diag(SPV_ERROR_INVALID_BINARY)
  755. << "Type Id " << type_id << " is not a scalar numeric type";
  756. }
  757. parsed_operand->number_kind = info.type;
  758. parsed_operand->number_bit_width = info.bit_width;
  759. // Round up the word count.
  760. parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
  761. return SPV_SUCCESS;
  762. }
  763. void MarkvDecoder::RecordNumberType() {
  764. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  765. if (spvOpcodeGeneratesType(opcode)) {
  766. NumberType info = {SPV_NUMBER_NONE, 0};
  767. if (SpvOpTypeInt == opcode) {
  768. info.bit_width = inst_.words[inst_.operands[1].offset];
  769. info.type = inst_.words[inst_.operands[2].offset]
  770. ? SPV_NUMBER_SIGNED_INT
  771. : SPV_NUMBER_UNSIGNED_INT;
  772. } else if (SpvOpTypeFloat == opcode) {
  773. info.bit_width = inst_.words[inst_.operands[1].offset];
  774. info.type = SPV_NUMBER_FLOATING;
  775. }
  776. // The *result* Id of a type generating instruction is the type Id.
  777. type_id_to_number_type_info_[inst_.result_id] = info;
  778. }
  779. }
  780. } // namespace comp
  781. } // namespace spvtools