123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926 |
- // Copyright (c) 2018 Google LLC
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- #include "source/comp/markv_decoder.h"
- #include <cstring>
- #include <iterator>
- #include <numeric>
- #include "source/ext_inst.h"
- #include "source/opcode.h"
- #include "spirv-tools/libspirv.hpp"
- namespace spvtools {
- namespace comp {
- spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
- auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
- if (codec) {
- uint64_t decoded_value = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to decode non-id word with Huffman";
- if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
- // The word decoded successfully.
- *word = uint32_t(decoded_value);
- assert(*word == decoded_value);
- return SPV_SUCCESS;
- }
- // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
- }
- const size_t chunk_length =
- model_->GetOperandVariableWidthChunkLength(operand_.type);
- if (chunk_length) {
- if (!reader_.ReadVariableWidthU32(word, chunk_length))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to decode non-id word with varint";
- } else {
- if (!reader_.ReadUnencoded(word))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read unencoded non-id word";
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
- uint32_t* opcode, uint32_t* num_operands) {
- // First try to use the Markov chain codec.
- auto* codec =
- model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
- if (codec) {
- uint64_t decoded_value = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode opcode_and_num_operands, previous opcode is "
- << spvOpcodeString(GetPrevOpcode());
- if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
- // The word was successfully decoded.
- *opcode = uint32_t(decoded_value & 0xFFFF);
- *num_operands = uint32_t(decoded_value >> 16);
- return SPV_SUCCESS;
- }
- // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
- }
- // Fallback to base-rate codec.
- codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
- assert(codec);
- uint64_t decoded_value = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode opcode_and_num_operands with global codec";
- if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
- // Received kMarkvNoneOfTheAbove signal, fallback further.
- return SPV_UNSUPPORTED;
- }
- *opcode = uint32_t(decoded_value & 0xFFFF);
- *num_operands = uint32_t(decoded_value >> 16);
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
- uint32_t fallback_method,
- uint32_t* rank) {
- const auto* codec = GetMtfHuffmanCodec(mtf);
- if (!codec) {
- assert(fallback_method != kMtfNone);
- codec = GetMtfHuffmanCodec(fallback_method);
- }
- if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
- uint32_t decoded_value = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
- return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
- if (decoded_value == kMtfRankEncodedByValueSignal) {
- // Decode by value.
- if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode MTF rank with varint";
- *rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
- } else {
- // Decode using Huffman coding.
- assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
- *rank = decoded_value;
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
- auto* codec =
- model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
- uint64_t mtf = kMtfNone;
- if (codec) {
- uint64_t decoded_value = 0;
- if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode descriptor with Huffman";
- if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
- const uint32_t long_descriptor = uint32_t(decoded_value);
- mtf = GetMtfLongIdDescriptor(long_descriptor);
- }
- }
- if (mtf == kMtfNone) {
- if (model_->id_fallback_strategy() !=
- MarkvModel::IdFallbackStrategy::kShortDescriptor) {
- return SPV_UNSUPPORTED;
- }
- uint64_t decoded_value = 0;
- if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
- return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
- const uint32_t short_descriptor = uint32_t(decoded_value);
- if (short_descriptor == 0) {
- // Forward declared id.
- return SPV_UNSUPPORTED;
- }
- mtf = GetMtfShortIdDescriptor(short_descriptor);
- }
- return DecodeExistingId(mtf, id);
- }
- spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
- assert(multi_mtf_.GetSize(mtf) > 0);
- *id = 0;
- uint32_t rank = 0;
- if (multi_mtf_.GetSize(mtf) == 1) {
- rank = 1;
- } else {
- const spv_result_t result =
- DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
- if (result != SPV_SUCCESS) return result;
- }
- assert(rank);
- if (!multi_mtf_.ValueFromRank(mtf, rank, id))
- return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
- {
- const spv_result_t result = DecodeIdWithDescriptor(id);
- if (result != SPV_UNSUPPORTED) return result;
- }
- const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
- SpvOp(inst_.opcode))(operand_index_);
- uint32_t rank = 0;
- *id = 0;
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- uint64_t mtf = GetRuleBasedMtf();
- if (mtf != kMtfNone && !can_forward_declare) {
- return DecodeExistingId(mtf, id);
- }
- if (mtf == kMtfNone) mtf = kMtfAll;
- {
- const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
- if (result != SPV_SUCCESS) return result;
- }
- if (rank == 0) {
- // This is the first occurrence of a forward declared id.
- *id = GetIdBound();
- SetIdBound(*id + 1);
- multi_mtf_.Insert(kMtfAll, *id);
- multi_mtf_.Insert(kMtfForwardDeclared, *id);
- if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
- } else {
- if (!multi_mtf_.ValueFromRank(mtf, rank, id))
- return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
- }
- } else {
- assert(can_forward_declare);
- if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to decode MTF rank with varint";
- if (rank == 0) {
- // This is the first occurrence of a forward declared id.
- *id = GetIdBound();
- SetIdBound(*id + 1);
- multi_mtf_.Insert(kMtfForwardDeclared, *id);
- } else {
- if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
- return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
- }
- }
- assert(*id);
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeTypeId() {
- if (inst_.opcode == SpvOpFunctionParameter) {
- assert(!remaining_function_parameter_types_.empty());
- inst_.type_id = remaining_function_parameter_types_.front();
- remaining_function_parameter_types_.pop_front();
- return SPV_SUCCESS;
- }
- {
- const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
- if (result != SPV_UNSUPPORTED) return result;
- }
- assert(model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased);
- uint64_t mtf = GetRuleBasedMtf();
- assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
- operand_index_));
- if (mtf == kMtfNone) {
- mtf = kMtfTypeNonFunction;
- // Function types should have been handled by GetRuleBasedMtf.
- assert(inst_.opcode != SpvOpFunction);
- }
- return DecodeExistingId(mtf, &inst_.type_id);
- }
- spv_result_t MarkvDecoder::DecodeResultId() {
- uint32_t rank = 0;
- const uint64_t num_still_forward_declared =
- multi_mtf_.GetSize(kMtfForwardDeclared);
- if (num_still_forward_declared) {
- // Some ids were forward declared. Check if this id is one of them.
- uint64_t id_was_forward_declared;
- if (!reader_.ReadBits(&id_was_forward_declared, 1))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read id_was_forward_declared flag";
- if (id_was_forward_declared) {
- if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read MTF rank of forward declared id";
- if (rank) {
- // The id was forward declared, recover it from kMtfForwardDeclared.
- if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
- &inst_.result_id))
- return Diag(SPV_ERROR_INTERNAL)
- << "Forward declared MTF rank is out of bounds";
- // We can now remove the id from kMtfForwardDeclared.
- if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
- return Diag(SPV_ERROR_INTERNAL)
- << "Failed to remove id from kMtfForwardDeclared";
- }
- }
- }
- if (inst_.result_id == 0) {
- // The id was not forward declared, issue a new id.
- inst_.result_id = GetIdBound();
- SetIdBound(inst_.result_id + 1);
- }
- if (model_->id_fallback_strategy() ==
- MarkvModel::IdFallbackStrategy::kRuleBased) {
- if (!rank) {
- multi_mtf_.Insert(kMtfAll, inst_.result_id);
- }
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeLiteralNumber(
- const spv_parsed_operand_t& operand) {
- if (operand.number_bit_width <= 32) {
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- } else {
- assert(operand.number_bit_width <= 64);
- uint64_t word = 0;
- if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
- if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
- } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
- int64_t val = 0;
- if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
- model_->s64_block_exponent()))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
- std::memcpy(&word, &val, 8);
- } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
- if (!reader_.ReadUnencoded(&word))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
- } else {
- return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
- }
- inst_words_.push_back(static_cast<uint32_t>(word));
- inst_words_.push_back(static_cast<uint32_t>(word >> 32));
- }
- return SPV_SUCCESS;
- }
- bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
- const size_t num_bits_to_next_byte =
- GetNumBitsToNextByte(reader_.GetNumReadBits());
- if (num_bits_to_next_byte == 0 ||
- num_bits_to_next_byte > byte_break_if_less_than)
- return true;
- uint64_t bits = 0;
- if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
- assert(bits == 0);
- if (bits != 0) return false;
- return true;
- }
- spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
- const bool header_read_success =
- reader_.ReadUnencoded(&header_.magic_number) &&
- reader_.ReadUnencoded(&header_.markv_version) &&
- reader_.ReadUnencoded(&header_.markv_model) &&
- reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
- reader_.ReadUnencoded(&header_.spirv_version) &&
- reader_.ReadUnencoded(&header_.spirv_generator);
- if (!header_read_success)
- return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
- if (header_.markv_length_in_bits == 0)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Header markv_length_in_bits field is zero";
- if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary has incorrect magic number";
- // TODO(atgoo@github.com): Print version strings.
- if (header_.markv_version != MarkvCodec::GetMarkvVersion())
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary and the codec have different versions";
- const uint32_t model_type = header_.markv_model >> 16;
- const uint32_t model_version = header_.markv_model & 0xFFFF;
- if (model_type != model_->model_type())
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary and the codec use different MARK-V models";
- if (model_version != model_->model_version())
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary and the codec use different versions if the same "
- << "MARK-V model";
- spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic.
- spirv_.resize(5, 0);
- spirv_[0] = SpvMagicNumber;
- spirv_[1] = header_.spirv_version;
- spirv_[2] = header_.spirv_generator;
- if (logger_) {
- reader_.SetCallback(
- [this](const std::string& str) { logger_->AppendBitSequence(str); });
- }
- while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
- inst_ = {};
- const spv_result_t decode_result = DecodeInstruction();
- if (decode_result != SPV_SUCCESS) return decode_result;
- }
- if (validator_options_) {
- spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
- const spv_result_t result = spvValidateWithOptions(
- context_, validator_options_, &validation_binary, nullptr);
- if (result != SPV_SUCCESS) return result;
- }
- // Validate the decode binary
- if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
- !reader_.OnlyZeroesLeft()) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "MARK-V binary has wrong stated bit length "
- << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
- }
- // Decoding of the module is finished, validation state should have correct
- // id bound.
- spirv_[3] = GetIdBound();
- *spirv_binary = std::move(spirv_);
- return SPV_SUCCESS;
- }
- // TODO(atgoo@github.com): The implementation borrows heavily from
- // Parser::parseOperand.
- // Consider coupling them together in some way once MARK-V codec is more mature.
- // For now it's better to keep the code independent for experimentation
- // purposes.
- spv_result_t MarkvDecoder::DecodeOperand(
- size_t operand_offset, const spv_operand_type_t type,
- spv_operand_pattern_t* expected_operands) {
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- memset(&operand_, 0, sizeof(operand_));
- assert((operand_offset >> 16) == 0);
- operand_.offset = static_cast<uint16_t>(operand_offset);
- operand_.type = type;
- // Set default values, may be updated later.
- operand_.number_kind = SPV_NUMBER_NONE;
- operand_.number_bit_width = 0;
- const size_t first_word_index = inst_words_.size();
- switch (type) {
- case SPV_OPERAND_TYPE_RESULT_ID: {
- const spv_result_t result = DecodeResultId();
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(inst_.result_id);
- SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
- PromoteIfNeeded(inst_.result_id);
- break;
- }
- case SPV_OPERAND_TYPE_TYPE_ID: {
- const spv_result_t result = DecodeTypeId();
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(inst_.type_id);
- SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
- PromoteIfNeeded(inst_.type_id);
- break;
- }
- case SPV_OPERAND_TYPE_ID:
- case SPV_OPERAND_TYPE_OPTIONAL_ID:
- case SPV_OPERAND_TYPE_SCOPE_ID:
- case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
- uint32_t id = 0;
- const spv_result_t result = DecodeRefId(&id);
- if (result != SPV_SUCCESS) return result;
- if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
- if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
- operand_.type = SPV_OPERAND_TYPE_ID;
- if (opcode == SpvOpExtInst && operand_.offset == 3) {
- // The current word is the extended instruction set id.
- // Set the extended instruction set type for the current
- // instruction.
- auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
- if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
- return Diag(SPV_ERROR_INVALID_ID)
- << "OpExtInst set id " << id
- << " does not reference an OpExtInstImport result Id";
- }
- inst_.ext_inst_type = ext_inst_type_iter->second;
- }
- }
- inst_words_.push_back(id);
- SetIdBound(std::max(GetIdBound(), id + 1));
- PromoteIfNeeded(id);
- break;
- }
- case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- assert(SpvOpExtInst == opcode);
- assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
- spv_ext_inst_desc ext_inst;
- if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid extended instruction number: " << word;
- spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
- break;
- }
- case SPV_OPERAND_TYPE_LITERAL_INTEGER:
- case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
- // These are regular single-word literal integer operands.
- // Post-parsing validation should check the range of the parsed value.
- operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
- // It turns out they are always unsigned integers!
- operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
- operand_.number_bit_width = 32;
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- break;
- }
- case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
- case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
- operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
- if (opcode == SpvOpSwitch) {
- // The literal operands have the same type as the value
- // referenced by the selector Id.
- const uint32_t selector_id = inst_words_.at(1);
- const auto type_id_iter = id_to_type_id_.find(selector_id);
- if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid OpSwitch: selector id " << selector_id
- << " has no type";
- }
- uint32_t type_id = type_id_iter->second;
- if (selector_id == type_id) {
- // Recall that by convention, a result ID that is a type definition
- // maps to itself.
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid OpSwitch: selector id " << selector_id
- << " is a type, not a value";
- }
- if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
- return error;
- if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
- operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid OpSwitch: selector id " << selector_id
- << " is not a scalar integer";
- }
- } else {
- assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
- // The literal number type is determined by the type Id for the
- // constant.
- assert(inst_.type_id);
- if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
- return error;
- }
- if (auto error = DecodeLiteralNumber(operand_)) return error;
- break;
- }
- case SPV_OPERAND_TYPE_LITERAL_STRING:
- case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
- operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
- std::vector<char> str;
- auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
- if (codec) {
- std::string decoded_string;
- const bool huffman_result =
- codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
- assert(huffman_result);
- if (!huffman_result)
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read literal string";
- if (decoded_string != "kMarkvNoneOfTheAbove") {
- std::copy(decoded_string.begin(), decoded_string.end(),
- std::back_inserter(str));
- str.push_back('\0');
- }
- }
- // The loop is expected to terminate once we encounter '\0' or exhaust
- // the bit stream.
- if (str.empty()) {
- while (true) {
- char ch = 0;
- if (!reader_.ReadUnencoded(&ch))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read literal string";
- str.push_back(ch);
- if (ch == '\0') break;
- }
- }
- while (str.size() % 4 != 0) str.push_back('\0');
- inst_words_.resize(inst_words_.size() + str.size() / 4);
- std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
- if (SpvOpExtInstImport == opcode) {
- // Record the extended instruction type for the ID for this import.
- // There is only one string literal argument to OpExtInstImport,
- // so it's sufficient to guard this just on the opcode.
- const spv_ext_inst_type_t ext_inst_type =
- spvExtInstImportTypeGet(str.data());
- if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid extended instruction import '" << str.data()
- << "'";
- }
- // We must have parsed a valid result ID. It's a condition
- // of the grammar, and we only accept non-zero result Ids.
- assert(inst_.result_id);
- const bool inserted =
- import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
- .second;
- (void)inserted;
- assert(inserted);
- }
- break;
- }
- case SPV_OPERAND_TYPE_CAPABILITY:
- case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
- case SPV_OPERAND_TYPE_EXECUTION_MODEL:
- case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
- case SPV_OPERAND_TYPE_MEMORY_MODEL:
- case SPV_OPERAND_TYPE_EXECUTION_MODE:
- case SPV_OPERAND_TYPE_STORAGE_CLASS:
- case SPV_OPERAND_TYPE_DIMENSIONALITY:
- case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
- case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
- case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
- case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
- case SPV_OPERAND_TYPE_LINKAGE_TYPE:
- case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
- case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
- case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
- case SPV_OPERAND_TYPE_DECORATION:
- case SPV_OPERAND_TYPE_BUILT_IN:
- case SPV_OPERAND_TYPE_GROUP_OPERATION:
- case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
- case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
- // A single word that is a plain enum value.
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- // Map an optional operand type to its corresponding concrete type.
- if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
- operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
- spv_operand_desc entry;
- if (grammar_.lookupOperand(type, word, &entry)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid " << spvOperandTypeStr(operand_.type)
- << " operand: " << word;
- }
- // Prepare to accept operands to this operand, if needed.
- spvPushOperandTypes(entry->operandTypes, expected_operands);
- break;
- }
- case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
- case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
- case SPV_OPERAND_TYPE_LOOP_CONTROL:
- case SPV_OPERAND_TYPE_IMAGE:
- case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
- case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
- case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
- // This operand is a mask.
- uint32_t word = 0;
- const spv_result_t result = DecodeNonIdWord(&word);
- if (result != SPV_SUCCESS) return result;
- inst_words_.push_back(word);
- // Map an optional operand type to its corresponding concrete type.
- if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
- operand_.type = SPV_OPERAND_TYPE_IMAGE;
- else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
- operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
- // Check validity of set mask bits. Also prepare for operands for those
- // masks if they have any. To get operand order correct, scan from
- // MSB to LSB since we can only prepend operands to a pattern.
- // The only case in the grammar where you have more than one mask bit
- // having an operand is for image operands. See SPIR-V 3.14 Image
- // Operands.
- uint32_t remaining_word = word;
- for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
- if (remaining_word & mask) {
- spv_operand_desc entry;
- if (grammar_.lookupOperand(type, mask, &entry)) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Invalid " << spvOperandTypeStr(operand_.type)
- << " operand: " << word << " has invalid mask component "
- << mask;
- }
- remaining_word ^= mask;
- spvPushOperandTypes(entry->operandTypes, expected_operands);
- }
- }
- if (word == 0) {
- // An all-zeroes mask *might* also be valid.
- spv_operand_desc entry;
- if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
- // Prepare for its operands, if any.
- spvPushOperandTypes(entry->operandTypes, expected_operands);
- }
- }
- break;
- }
- default:
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Internal error: Unhandled operand type: " << type;
- }
- operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
- assert(spvOperandIsConcrete(operand_.type));
- parsed_operands_.push_back(operand_);
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::DecodeInstruction() {
- parsed_operands_.clear();
- inst_words_.clear();
- // Opcode/num_words placeholder, the word will be filled in later.
- inst_words_.push_back(0);
- bool num_operands_still_unknown = true;
- {
- uint32_t opcode = 0;
- uint32_t num_operands = 0;
- const spv_result_t opcode_decoding_result =
- DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
- if (opcode_decoding_result < 0) return opcode_decoding_result;
- if (opcode_decoding_result == SPV_SUCCESS) {
- inst_.num_operands = static_cast<uint16_t>(num_operands);
- num_operands_still_unknown = false;
- } else {
- if (!reader_.ReadVariableWidthU32(&opcode,
- model_->opcode_chunk_length())) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read opcode of instruction";
- }
- }
- inst_.opcode = static_cast<uint16_t>(opcode);
- }
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- spv_opcode_desc opcode_desc;
- if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
- return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
- }
- spv_operand_pattern_t expected_operands;
- expected_operands.reserve(opcode_desc->numTypes);
- for (auto i = 0; i < opcode_desc->numTypes; i++) {
- expected_operands.push_back(
- opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
- }
- if (num_operands_still_unknown) {
- if (!OpcodeHasFixedNumberOfOperands(opcode)) {
- if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
- model_->num_operands_chunk_length()))
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Failed to read num_operands of instruction";
- } else {
- inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
- }
- }
- for (operand_index_ = 0;
- operand_index_ < static_cast<size_t>(inst_.num_operands);
- ++operand_index_) {
- assert(!expected_operands.empty());
- const spv_operand_type_t type =
- spvTakeFirstMatchableOperand(&expected_operands);
- const size_t operand_offset = inst_words_.size();
- const spv_result_t decode_result =
- DecodeOperand(operand_offset, type, &expected_operands);
- if (decode_result != SPV_SUCCESS) return decode_result;
- }
- assert(inst_.num_operands == parsed_operands_.size());
- // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
- // next DecodeInstruction call).
- inst_.words = inst_words_.data();
- inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
- inst_.num_words = static_cast<uint16_t>(inst_words_.size());
- inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
- std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
- assert(inst_.num_words ==
- std::accumulate(
- parsed_operands_.begin(), parsed_operands_.end(), 1,
- [](int num_words, const spv_parsed_operand_t& operand) {
- return num_words += operand.num_words;
- }) &&
- "num_words in instruction doesn't correspond to the sum of num_words"
- "in the operands");
- RecordNumberType();
- ProcessCurInstruction();
- if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
- return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
- if (logger_) {
- logger_->NewLine();
- std::stringstream ss;
- ss << spvOpcodeString(opcode) << " ";
- for (size_t index = 1; index < inst_words_.size(); ++index)
- ss << inst_words_[index] << " ";
- logger_->AppendText(ss.str());
- logger_->NewLine();
- logger_->NewLine();
- if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
- }
- return SPV_SUCCESS;
- }
- spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
- spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
- assert(type_id != 0);
- auto type_info_iter = type_id_to_number_type_info_.find(type_id);
- if (type_info_iter == type_id_to_number_type_info_.end()) {
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Type Id " << type_id << " is not a type";
- }
- const NumberType& info = type_info_iter->second;
- if (info.type == SPV_NUMBER_NONE) {
- // This is a valid type, but for something other than a scalar number.
- return Diag(SPV_ERROR_INVALID_BINARY)
- << "Type Id " << type_id << " is not a scalar numeric type";
- }
- parsed_operand->number_kind = info.type;
- parsed_operand->number_bit_width = info.bit_width;
- // Round up the word count.
- parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
- return SPV_SUCCESS;
- }
- void MarkvDecoder::RecordNumberType() {
- const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
- if (spvOpcodeGeneratesType(opcode)) {
- NumberType info = {SPV_NUMBER_NONE, 0};
- if (SpvOpTypeInt == opcode) {
- info.bit_width = inst_.words[inst_.operands[1].offset];
- info.type = inst_.words[inst_.operands[2].offset]
- ? SPV_NUMBER_SIGNED_INT
- : SPV_NUMBER_UNSIGNED_INT;
- } else if (SpvOpTypeFloat == opcode) {
- info.bit_width = inst_.words[inst_.operands[1].offset];
- info.type = SPV_NUMBER_FLOATING;
- }
- // The *result* Id of a type generating instruction is the type Id.
- type_id_to_number_type_info_[inst_.result_id] = info;
- }
- }
- } // namespace comp
- } // namespace spvtools
|