markv_codec.cpp 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794
  1. // Copyright (c) 2018 Google LLC
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // MARK-V is a compression format for SPIR-V binaries. It strips away
  15. // non-essential information (such as result IDs which can be regenerated) and
  16. // uses various bit reduction techniques to reduce the size of the binary.
  17. #include "source/comp/markv_codec.h"
  18. #include "source/comp/markv_logger.h"
  19. #include "source/latest_version_glsl_std_450_header.h"
  20. #include "source/latest_version_opencl_std_header.h"
  21. #include "source/opcode.h"
  22. #include "source/util/make_unique.h"
  23. namespace spvtools {
  24. namespace comp {
  25. namespace {
  26. // Custom hash function used to produce short descriptors.
  27. uint32_t ShortHashU32Array(const std::vector<uint32_t>& words) {
  28. // The hash function is a sum of hashes of each word seeded by word index.
  29. // Knuth's multiplicative hash is used to hash the words.
  30. const uint32_t kKnuthMulHash = 2654435761;
  31. uint32_t val = 0;
  32. for (uint32_t i = 0; i < words.size(); ++i) {
  33. val += (words[i] + i + 123) * kKnuthMulHash;
  34. }
  35. return 1 + val % ((1 << MarkvCodec::kShortDescriptorNumBits) - 1);
  36. }
  37. // Returns a set of mtf rank codecs based on a plausible hand-coded
  38. // distribution.
  39. std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>>
  40. GetMtfHuffmanCodecs() {
  41. std::map<uint64_t, std::unique_ptr<HuffmanCodec<uint32_t>>> codecs;
  42. std::unique_ptr<HuffmanCodec<uint32_t>> codec;
  43. codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
  44. {0, 5},
  45. {1, 40},
  46. {2, 10},
  47. {3, 5},
  48. {4, 5},
  49. {5, 5},
  50. {6, 3},
  51. {7, 3},
  52. {8, 3},
  53. {9, 3},
  54. {MarkvCodec::kMtfRankEncodedByValueSignal, 10},
  55. }));
  56. codecs.emplace(kMtfAll, std::move(codec));
  57. codec = MakeUnique<HuffmanCodec<uint32_t>>(std::map<uint32_t, uint32_t>({
  58. {1, 50},
  59. {2, 20},
  60. {3, 5},
  61. {4, 5},
  62. {5, 2},
  63. {6, 1},
  64. {7, 1},
  65. {8, 1},
  66. {9, 1},
  67. {MarkvCodec::kMtfRankEncodedByValueSignal, 10},
  68. }));
  69. codecs.emplace(kMtfGenericNonZeroRank, std::move(codec));
  70. return codecs;
  71. }
  72. } // namespace
  73. const uint32_t MarkvCodec::kMarkvMagicNumber = 0x07230303;
  74. const uint32_t MarkvCodec::kMtfSmallestRankEncodedByValue = 10;
  75. const uint32_t MarkvCodec::kMtfRankEncodedByValueSignal =
  76. std::numeric_limits<uint32_t>::max();
  77. const uint32_t MarkvCodec::kShortDescriptorNumBits = 8;
  78. const size_t MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte = 8;
  // Creates a codec bound to |context| and |model|.
  // Takes ownership of |validator_options|: ~MarkvCodec releases it with
  // spvValidatorOptionsDestroy. |model| is kept as a raw pointer and is
  // presumably required to outlive the codec (not verifiable here).
  79. MarkvCodec::MarkvCodec(spv_const_context context,
  80. spv_validator_options validator_options,
  81. const MarkvModel* model)
  82. : validator_options_(validator_options),
  83. grammar_(context),
  84. model_(model),
  // Short descriptors are computed with the file-local ShortHashU32Array.
  85. short_id_descriptors_(ShortHashU32Array),
  // Default MTF rank codecs built from a hand-coded rank distribution.
  86. mtf_huffman_codecs_(GetMtfHuffmanCodecs()),
  87. context_(context) {}
  // Releases the validator options whose ownership the constructor took.
  // NOTE(review): if MarkvCodec were copyable, two copies would destroy the
  // same options; confirm the header deletes copy construction/assignment.
  88. MarkvCodec::~MarkvCodec() { spvValidatorOptionsDestroy(validator_options_); }
  // Initializes a header with the MARK-V magic number and the current format
  // version (see GetMarkvVersion). Other header fields are not set here.
  89. MarkvCodec::MarkvHeader::MarkvHeader()
  90. : magic_number(MarkvCodec::kMarkvMagicNumber),
  91. markv_version(MarkvCodec::GetMarkvVersion()) {}
  92. // Defines and returns current MARK-V version.
  93. // static
  94. uint32_t MarkvCodec::GetMarkvVersion() {
  95. const uint32_t kVersionMajor = 1;
  96. const uint32_t kVersionMinor = 4;
  97. return kVersionMinor | (kVersionMajor << 16);
  98. }
  99. size_t MarkvCodec::GetNumBitsToNextByte(size_t bit_pos) const {
  100. return (8 - (bit_pos % 8)) % 8;
  101. }
  102. // Returns true if the opcode has a fixed number of operands. May return a
  103. // false negative.
  104. bool MarkvCodec::OpcodeHasFixedNumberOfOperands(SpvOp opcode) const {
  105. switch (opcode) {
  106. // TODO(atgoo@github.com) This is not a complete list.
  107. case SpvOpNop:
  108. case SpvOpName:
  109. case SpvOpUndef:
  110. case SpvOpSizeOf:
  111. case SpvOpLine:
  112. case SpvOpNoLine:
  113. case SpvOpDecorationGroup:
  114. case SpvOpExtension:
  115. case SpvOpExtInstImport:
  116. case SpvOpMemoryModel:
  117. case SpvOpCapability:
  118. case SpvOpTypeVoid:
  119. case SpvOpTypeBool:
  120. case SpvOpTypeInt:
  121. case SpvOpTypeFloat:
  122. case SpvOpTypeVector:
  123. case SpvOpTypeMatrix:
  124. case SpvOpTypeSampler:
  125. case SpvOpTypeSampledImage:
  126. case SpvOpTypeArray:
  127. case SpvOpTypePointer:
  128. case SpvOpConstantTrue:
  129. case SpvOpConstantFalse:
  130. case SpvOpLabel:
  131. case SpvOpBranch:
  132. case SpvOpFunction:
  133. case SpvOpFunctionParameter:
  134. case SpvOpFunctionEnd:
  135. case SpvOpBitcast:
  136. case SpvOpCopyObject:
  137. case SpvOpTranspose:
  138. case SpvOpSNegate:
  139. case SpvOpFNegate:
  140. case SpvOpIAdd:
  141. case SpvOpFAdd:
  142. case SpvOpISub:
  143. case SpvOpFSub:
  144. case SpvOpIMul:
  145. case SpvOpFMul:
  146. case SpvOpUDiv:
  147. case SpvOpSDiv:
  148. case SpvOpFDiv:
  149. case SpvOpUMod:
  150. case SpvOpSRem:
  151. case SpvOpSMod:
  152. case SpvOpFRem:
  153. case SpvOpFMod:
  154. case SpvOpVectorTimesScalar:
  155. case SpvOpMatrixTimesScalar:
  156. case SpvOpVectorTimesMatrix:
  157. case SpvOpMatrixTimesVector:
  158. case SpvOpMatrixTimesMatrix:
  159. case SpvOpOuterProduct:
  160. case SpvOpDot:
  161. return true;
  162. default:
  163. break;
  164. }
  165. return false;
  166. }
  // Records the instruction currently held in inst_ in the codec's state:
  //   - wraps inst_ in a new val::Instruction appended to instructions_ and,
  //     when inst_ has a result id, indexes it in id_to_def_instruction_ and
  //     id_to_type_id_;
  //   - tracks the function being processed: the ids local to it, its return
  //     type, and the parameter types still expected from the following
  //     OpFunctionParameter instructions;
  //   - on OpFunctionEnd, removes every function-local id from all MTFs;
  //   - inserts the result id into the MTF sequences selected by the model's
  //     id fallback strategy and by any long/short id descriptors.
  // NOTE(review): the order of the multi_mtf_ insertions presumably affects
  // MTF ranks and therefore the encoded bitstream; treat reordering, adding
  // or removing insertions as a format change and confirm before refactoring.
  167. void MarkvCodec::ProcessCurInstruction() {
  168. instructions_.emplace_back(new val::Instruction(&inst_));
  169. const SpvOp opcode = SpvOp(inst_.opcode);
  170. if (inst_.result_id) {
  171. id_to_def_instruction_.emplace(inst_.result_id, instructions_.back().get());
  172. // Collect ids local to the current function.
  // For OpFunction itself cur_function_id_ is still 0 here (assuming the
  // previous function was closed by OpFunctionEnd), so a function's own id is
  // not treated as local.
  173. if (cur_function_id_) {
  174. ids_local_to_cur_function_.push_back(inst_.result_id);
  175. }
  176. // Starting new function.
  177. if (opcode == SpvOpFunction) {
  178. cur_function_id_ = inst_.result_id;
  179. cur_function_return_type_ = inst_.type_id;
  180. if (model_->id_fallback_strategy() ==
  181. MarkvModel::IdFallbackStrategy::kRuleBased) {
  182. multi_mtf_.Insert(GetMtfFunctionWithReturnType(inst_.type_id),
  183. inst_.result_id);
  184. }
  185. // Store function parameter types in a queue, so that we know which types
  186. // to expect in the following OpFunctionParameter instructions.
  // Word 4 of OpFunction is the id of its OpTypeFunction.
  187. const val::Instruction* def_inst = FindDef(inst_.words[4]);
  188. assert(def_inst);
  189. assert(def_inst->opcode() == SpvOpTypeFunction);
  // OpTypeFunction words: [1] result id, [2] return type, [3..] parameter
  // types.
  190. for (uint32_t i = 3; i < def_inst->words().size(); ++i) {
  191. remaining_function_parameter_types_.push_back(def_inst->word(i));
  192. }
  193. }
  194. }
  195. // Remove local ids from MTFs if function end.
  196. if (opcode == SpvOpFunctionEnd) {
  197. cur_function_id_ = 0;
  198. for (uint32_t id : ids_local_to_cur_function_) multi_mtf_.RemoveFromAll(id);
  199. ids_local_to_cur_function_.clear();
  // Every declared parameter type must have been consumed by an
  // OpFunctionParameter before the function ends.
  200. assert(remaining_function_parameter_types_.empty());
  201. }
  // All remaining bookkeeping is keyed by result id.
  202. if (!inst_.result_id) return;
  203. {
  204. // Save the result ID to type ID mapping.
  205. // In the grammar, type ID always appears before result ID.
  206. // A regular value maps to its type. Some instructions (e.g. OpLabel)
  207. // have no type Id, and will map to 0. The result Id for a
  208. // type-generating instruction (e.g. OpTypeInt) maps to itself.
  209. auto insertion_result = id_to_type_id_.emplace(
  210. inst_.result_id, spvOpcodeGeneratesType(SpvOp(inst_.opcode))
  211. ? inst_.result_id
  212. : inst_.type_id);
  // Silences unused-variable warnings when asserts are compiled out.
  213. (void)insertion_result;
  // A result id must be defined only once.
  214. assert(insertion_result.second);
  215. }
  216. // Add result_id to MTFs.
  217. if (model_->id_fallback_strategy() ==
  218. MarkvModel::IdFallbackStrategy::kRuleBased) {
  // File ids by the opcode that generated them (types and ext inst imports).
  219. switch (opcode) {
  220. case SpvOpTypeFloat:
  221. case SpvOpTypeInt:
  222. case SpvOpTypeBool:
  223. case SpvOpTypeVector:
  224. case SpvOpTypePointer:
  225. case SpvOpExtInstImport:
  226. case SpvOpTypeSampledImage:
  227. case SpvOpTypeImage:
  228. case SpvOpTypeSampler:
  229. multi_mtf_.Insert(GetMtfIdGeneratedByOpcode(opcode), inst_.result_id);
  230. break;
  231. default:
  232. break;
  233. }
  234. if (spvOpcodeIsComposite(opcode)) {
  235. multi_mtf_.Insert(kMtfTypeComposite, inst_.result_id);
  236. }
  237. if (opcode == SpvOpLabel) {
  238. multi_mtf_.InsertOrPromote(kMtfLabel, inst_.result_id);
  239. }
  240. if (opcode == SpvOpTypeInt) {
  241. multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
  242. multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
  243. }
  244. if (opcode == SpvOpTypeFloat) {
  245. multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
  246. multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
  247. }
  248. if (opcode == SpvOpTypeBool) {
  249. multi_mtf_.Insert(kMtfTypeScalar, inst_.result_id);
  250. multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
  251. }
  // OpTypeVector words: [2] component type id, [3] component count.
  252. if (opcode == SpvOpTypeVector) {
  253. const uint32_t component_type_id = inst_.words[2];
  254. const uint32_t size = inst_.words[3];
  // Classify the vector type by its component type, which was filed above
  // when the scalar type was processed.
  255. if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeFloat),
  256. component_type_id)) {
  257. multi_mtf_.Insert(kMtfTypeFloatScalarOrVector, inst_.result_id);
  258. } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeInt),
  259. component_type_id)) {
  260. multi_mtf_.Insert(kMtfTypeIntScalarOrVector, inst_.result_id);
  261. } else if (multi_mtf_.HasValue(GetMtfIdGeneratedByOpcode(SpvOpTypeBool),
  262. component_type_id)) {
  263. multi_mtf_.Insert(kMtfTypeBoolScalarOrVector, inst_.result_id);
  264. }
  265. multi_mtf_.Insert(GetMtfTypeVectorOfSize(size), inst_.result_id);
  266. }
  // OpTypeFunction word 2 is the return type; it is recorded as "returned by
  // some function" in addition to filing the function type by return type.
  267. if (inst_.opcode == SpvOpTypeFunction) {
  268. const uint32_t return_type = inst_.words[2];
  269. multi_mtf_.Insert(kMtfTypeReturnedByFunction, return_type);
  270. multi_mtf_.Insert(GetMtfFunctionTypeWithReturnType(return_type),
  271. inst_.result_id);
  272. }
  // Instructions with a result type (values) are filed by properties of
  // that type.
  273. if (inst_.type_id) {
  274. const val::Instruction* type_inst = FindDef(inst_.type_id);
  275. assert(type_inst);
  276. multi_mtf_.Insert(kMtfObject, inst_.result_id);
  277. multi_mtf_.Insert(GetMtfIdOfType(inst_.type_id), inst_.result_id);
  278. if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, inst_.type_id)) {
  279. multi_mtf_.Insert(kMtfFloatScalarOrVector, inst_.result_id);
  280. }
  281. if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, inst_.type_id))
  282. multi_mtf_.Insert(kMtfIntScalarOrVector, inst_.result_id);
  283. if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, inst_.type_id))
  284. multi_mtf_.Insert(kMtfBoolScalarOrVector, inst_.result_id);
  285. if (multi_mtf_.HasValue(kMtfTypeComposite, inst_.type_id))
  286. multi_mtf_.Insert(kMtfComposite, inst_.result_id);
  // NOTE(review): OpTypeFloat is absent from this list (unlike the
  // generated-by-opcode switch above); changing it would alter MTF contents,
  // so it is documented rather than "fixed".
  287. switch (type_inst->opcode()) {
  288. case SpvOpTypeInt:
  289. case SpvOpTypeBool:
  290. case SpvOpTypePointer:
  291. case SpvOpTypeVector:
  292. case SpvOpTypeImage:
  293. case SpvOpTypeSampledImage:
  294. case SpvOpTypeSampler:
  295. multi_mtf_.Insert(
  296. GetMtfIdWithTypeGeneratedByOpcode(type_inst->opcode()),
  297. inst_.result_id);
  298. break;
  299. default:
  300. break;
  301. }
  302. if (type_inst->opcode() == SpvOpTypeVector) {
  303. const uint32_t component_type = type_inst->word(2);
  304. multi_mtf_.Insert(GetMtfVectorOfComponentType(component_type),
  305. inst_.result_id);
  306. }
  // OpTypePointer operands: [0] result id, [1] storage class, [2] pointee
  // type.
  307. if (type_inst->opcode() == SpvOpTypePointer) {
  308. assert(type_inst->operands().size() > 2);
  309. assert(type_inst->words().size() > type_inst->operands()[2].offset);
  310. const uint32_t data_type =
  311. type_inst->word(type_inst->operands()[2].offset);
  312. multi_mtf_.Insert(GetMtfPointerToType(data_type), inst_.result_id);
  313. if (multi_mtf_.HasValue(kMtfTypeComposite, data_type))
  314. multi_mtf_.Insert(kMtfTypePointerToComposite, inst_.result_id);
  315. }
  316. }
  317. if (spvOpcodeGeneratesType(opcode)) {
  318. if (opcode != SpvOpTypeFunction) {
  319. multi_mtf_.Insert(kMtfTypeNonFunction, inst_.result_id);
  320. }
  321. }
  322. }
  // Long descriptors are tracked only when the model defines a coding scheme
  // for at least one descriptor; the id is filed only under descriptors that
  // have one.
  323. if (model_->AnyDescriptorHasCodingScheme()) {
  324. const uint32_t long_descriptor =
  325. long_id_descriptors_.ProcessInstruction(inst_);
  326. if (model_->DescriptorHasCodingScheme(long_descriptor))
  327. multi_mtf_.Insert(GetMtfLongIdDescriptor(long_descriptor),
  328. inst_.result_id);
  329. }
  // With the short-descriptor fallback every result id is filed under its
  // short descriptor.
  330. if (model_->id_fallback_strategy() ==
  331. MarkvModel::IdFallbackStrategy::kShortDescriptor) {
  332. const uint32_t short_descriptor =
  333. short_id_descriptors_.ProcessInstruction(inst_);
  334. multi_mtf_.Insert(GetMtfShortIdDescriptor(short_descriptor),
  335. inst_.result_id);
  336. }
  337. }
  // Returns the MTF sequence most likely to contain the id expected at operand
  // slot operand_index_ of the current instruction inst_, using hand-written
  // per-opcode rules. Returns kMtfNone when no rule applies.
  //
  // Operand indices count operands after the opcode word: for instructions
  // with a result type, index 0 is the result type and index 1 the result id.
  // GetInstWords()[k] is used as the k-th instruction word, so word 1 is the
  // result type and word 3 the first operand after the result id.
  // NOTE(review): the returned MTF presumably has to match between encoder
  // and decoder, so changing any rule is likely a format change.
  338. uint64_t MarkvCodec::GetRuleBasedMtf() {
  339. // This function is only called for id operands (but not result ids).
  340. assert(spvIsIdType(operand_.type) ||
  341. operand_.type == SPV_OPERAND_TYPE_OPTIONAL_ID);
  342. assert(operand_.type != SPV_OPERAND_TYPE_RESULT_ID);
  343. const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
  344. // All operand slots which expect label id.
  // OpPhi operands alternate (value, parent label) from index 2, so labels
  // are at odd indices >= 3. For OpSwitch, index 0 is the selector; later id
  // operands are labels (literal operands never reach this function).
  345. if ((inst_.opcode == SpvOpLoopMerge && operand_index_ <= 1) ||
  346. (inst_.opcode == SpvOpSelectionMerge && operand_index_ == 0) ||
  347. (inst_.opcode == SpvOpBranch && operand_index_ == 0) ||
  348. (inst_.opcode == SpvOpBranchConditional &&
  349. (operand_index_ == 1 || operand_index_ == 2)) ||
  350. (inst_.opcode == SpvOpPhi && operand_index_ >= 3 &&
  351. operand_index_ % 2 == 1) ||
  352. (inst_.opcode == SpvOpSwitch && operand_index_ > 0)) {
  353. return kMtfLabel;
  354. }
  355. switch (opcode) {
  // Float arithmetic: operands share the result type.
  356. case SpvOpFAdd:
  357. case SpvOpFSub:
  358. case SpvOpFMul:
  359. case SpvOpFDiv:
  360. case SpvOpFRem:
  361. case SpvOpFMod:
  362. case SpvOpFNegate: {
  363. if (operand_index_ == 0) return kMtfTypeFloatScalarOrVector;
  364. return GetMtfIdOfType(inst_.type_id);
  365. }
  // Integer arithmetic: operands come from any int scalar/vector (signedness
  // of operands may differ from the result type).
  366. case SpvOpISub:
  367. case SpvOpIAdd:
  368. case SpvOpIMul:
  369. case SpvOpSDiv:
  370. case SpvOpUDiv:
  371. case SpvOpSMod:
  372. case SpvOpUMod:
  373. case SpvOpSRem:
  374. case SpvOpSNegate: {
  375. if (operand_index_ == 0) return kMtfTypeIntScalarOrVector;
  376. return kMtfIntScalarOrVector;
  377. }
  378. // TODO(atgoo@github.com) Add OpConvertFToU and other opcodes.
  // Float comparisons: the second operand has the same type as the first
  // (word 3), which has already been decoded.
  379. case SpvOpFOrdEqual:
  380. case SpvOpFUnordEqual:
  381. case SpvOpFOrdNotEqual:
  382. case SpvOpFUnordNotEqual:
  383. case SpvOpFOrdLessThan:
  384. case SpvOpFUnordLessThan:
  385. case SpvOpFOrdGreaterThan:
  386. case SpvOpFUnordGreaterThan:
  387. case SpvOpFOrdLessThanEqual:
  388. case SpvOpFUnordLessThanEqual:
  389. case SpvOpFOrdGreaterThanEqual:
  390. case SpvOpFUnordGreaterThanEqual: {
  391. if (operand_index_ == 0) return kMtfTypeBoolScalarOrVector;
  392. if (operand_index_ == 2) return kMtfFloatScalarOrVector;
  393. if (operand_index_ == 3) {
  394. const uint32_t first_operand_id = GetInstWords()[3];
  395. const uint32_t first_operand_type = id_to_type_id_.at(first_operand_id);
  396. return GetMtfIdOfType(first_operand_type);
  397. }
  398. break;
  399. }
  // The result vector size equals the number of component literals, i.e.
  // num_operands minus (result type, result id, vector 1, vector 2).
  400. case SpvOpVectorShuffle: {
  401. if (operand_index_ == 0) {
  402. assert(inst_.num_operands > 4);
  403. return GetMtfTypeVectorOfSize(inst_.num_operands - 4);
  404. }
  405. assert(inst_.type_id);
  406. if (operand_index_ == 2 || operand_index_ == 3)
  407. return GetMtfVectorOfComponentType(
  408. GetVectorComponentType(inst_.type_id));
  409. break;
  410. }
  // Operand 2 is the vector (same type as the result), operand 3 the scalar
  // (the result's component type).
  411. case SpvOpVectorTimesScalar: {
  412. if (operand_index_ == 0) {
  413. // TODO(atgoo@github.com) Could be narrowed to vector of floats.
  414. return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
  415. }
  416. assert(inst_.type_id);
  417. if (operand_index_ == 2) return GetMtfIdOfType(inst_.type_id);
  418. if (operand_index_ == 3)
  419. return GetMtfIdOfType(GetVectorComponentType(inst_.type_id));
  420. break;
  421. }
  // The result is the scalar component type of both vector operands; the
  // second vector has the same type as the first (word 3).
  422. case SpvOpDot: {
  423. if (operand_index_ == 0) return GetMtfIdGeneratedByOpcode(SpvOpTypeFloat);
  424. assert(inst_.type_id);
  425. if (operand_index_ == 2)
  426. return GetMtfVectorOfComponentType(inst_.type_id);
  427. if (operand_index_ == 3) {
  428. const uint32_t vector_id = GetInstWords()[3];
  429. const uint32_t vector_type = id_to_type_id_.at(vector_id);
  430. return GetMtfIdOfType(vector_type);
  431. }
  432. break;
  433. }
  // Type declarations: index 0 is the result id, so member/component types
  // start at index 1 (index 2 for OpTypePointer, after the storage class).
  434. case SpvOpTypeVector: {
  435. if (operand_index_ == 1) {
  436. return kMtfTypeScalar;
  437. }
  438. break;
  439. }
  440. case SpvOpTypeMatrix: {
  441. if (operand_index_ == 1) {
  442. return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
  443. }
  444. break;
  445. }
  446. case SpvOpTypePointer: {
  447. if (operand_index_ == 2) {
  448. return kMtfTypeNonFunction;
  449. }
  450. break;
  451. }
  452. case SpvOpTypeStruct: {
  453. if (operand_index_ >= 1) {
  454. return kMtfTypeNonFunction;
  455. }
  456. break;
  457. }
  // Return type (index 1) and parameter types (index >= 2) currently map to
  // the same MTF; the branches are kept separate as in the original rules.
  458. case SpvOpTypeFunction: {
  459. if (operand_index_ == 1) {
  460. return kMtfTypeNonFunction;
  461. }
  462. if (operand_index_ >= 2) {
  463. return kMtfTypeNonFunction;
  464. }
  465. break;
  466. }
  467. case SpvOpLoad: {
  468. if (operand_index_ == 0) return kMtfTypeNonFunction;
  469. if (operand_index_ == 2) {
  470. assert(inst_.type_id);
  471. return GetMtfPointerToType(inst_.type_id);
  472. }
  473. break;
  474. }
  // OpStore has no result: operand 0 is the pointer (word 1), operand 1 the
  // object, whose type is the pointer's pointee type.
  475. case SpvOpStore: {
  476. if (operand_index_ == 0)
  477. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypePointer);
  478. if (operand_index_ == 1) {
  479. const uint32_t pointer_id = GetInstWords()[1];
  480. const uint32_t pointer_type = id_to_type_id_.at(pointer_id);
  481. const val::Instruction* pointer_inst = FindDef(pointer_type);
  482. assert(pointer_inst);
  483. assert(pointer_inst->opcode() == SpvOpTypePointer);
  484. const uint32_t data_type =
  485. pointer_inst->word(pointer_inst->operands()[2].offset);
  486. return GetMtfIdOfType(data_type);
  487. }
  488. break;
  489. }
  490. case SpvOpVariable: {
  491. if (operand_index_ == 0)
  492. return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
  493. break;
  494. }
  // Operand 2 is the base pointer; operands 3.. are integer indices.
  495. case SpvOpAccessChain: {
  496. if (operand_index_ == 0)
  497. return GetMtfIdGeneratedByOpcode(SpvOpTypePointer);
  498. if (operand_index_ == 2) return kMtfTypePointerToComposite;
  499. if (operand_index_ >= 3)
  500. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeInt);
  501. break;
  502. }
  // Constituents are guessed from the element class (float/int/bool) of the
  // composite result type (word 1); otherwise falls through to kMtfNone.
  503. case SpvOpCompositeConstruct: {
  504. if (operand_index_ == 0) return kMtfTypeComposite;
  505. if (operand_index_ >= 2) {
  506. const uint32_t composite_type = GetInstWords()[1];
  507. if (multi_mtf_.HasValue(kMtfTypeFloatScalarOrVector, composite_type))
  508. return kMtfFloatScalarOrVector;
  509. if (multi_mtf_.HasValue(kMtfTypeIntScalarOrVector, composite_type))
  510. return kMtfIntScalarOrVector;
  511. if (multi_mtf_.HasValue(kMtfTypeBoolScalarOrVector, composite_type))
  512. return kMtfBoolScalarOrVector;
  513. }
  514. break;
  515. }
  516. case SpvOpCompositeExtract: {
  517. if (operand_index_ == 2) return kMtfComposite;
  518. break;
  519. }
  // For vector constants, constituents have the vector's component type
  // (OpTypeVector word 2).
  520. case SpvOpConstantComposite: {
  521. if (operand_index_ == 0) return kMtfTypeComposite;
  522. if (operand_index_ >= 2) {
  523. const val::Instruction* composite_type_inst = FindDef(inst_.type_id);
  524. assert(composite_type_inst);
  525. if (composite_type_inst->opcode() == SpvOpTypeVector) {
  526. return GetMtfIdOfType(composite_type_inst->word(2));
  527. }
  528. }
  529. break;
  530. }
  // OpExtInst operands: [2] instruction set id, [3] instruction number
  // (word 4, a literal), [4..] arguments.
  531. case SpvOpExtInst: {
  532. if (operand_index_ == 2)
  533. return GetMtfIdGeneratedByOpcode(SpvOpExtInstImport);
  534. if (operand_index_ >= 4) {
  535. const uint32_t return_type = GetInstWords()[1];
  536. const uint32_t ext_inst_type = inst_.ext_inst_type;
  537. const uint32_t ext_inst_index = GetInstWords()[4];
  538. // TODO(atgoo@github.com) The list of extended instructions is
  539. // incomplete. Only common instructions and low-hanging fruits listed.
  540. if (ext_inst_type == SPV_EXT_INST_TYPE_GLSL_STD_450) {
  541. switch (ext_inst_index) {
  542. case GLSLstd450FAbs:
  543. case GLSLstd450FClamp:
  544. case GLSLstd450FMax:
  545. case GLSLstd450FMin:
  546. case GLSLstd450FMix:
  547. case GLSLstd450Step:
  548. case GLSLstd450SmoothStep:
  549. case GLSLstd450Fma:
  550. case GLSLstd450Pow:
  551. case GLSLstd450Exp:
  552. case GLSLstd450Exp2:
  553. case GLSLstd450Log:
  554. case GLSLstd450Log2:
  555. case GLSLstd450Sqrt:
  556. case GLSLstd450InverseSqrt:
  557. case GLSLstd450Fract:
  558. case GLSLstd450Floor:
  559. case GLSLstd450Ceil:
  560. case GLSLstd450Radians:
  561. case GLSLstd450Degrees:
  562. case GLSLstd450Sin:
  563. case GLSLstd450Cos:
  564. case GLSLstd450Tan:
  565. case GLSLstd450Sinh:
  566. case GLSLstd450Cosh:
  567. case GLSLstd450Tanh:
  568. case GLSLstd450Asin:
  569. case GLSLstd450Acos:
  570. case GLSLstd450Atan:
  571. case GLSLstd450Atan2:
  572. case GLSLstd450Asinh:
  573. case GLSLstd450Acosh:
  574. case GLSLstd450Atanh:
  575. case GLSLstd450MatrixInverse:
  576. case GLSLstd450Cross:
  577. case GLSLstd450Normalize:
  578. case GLSLstd450Reflect:
  579. case GLSLstd450FaceForward:
  580. return GetMtfIdOfType(return_type);
  581. case GLSLstd450Length:
  582. case GLSLstd450Distance:
  583. case GLSLstd450Refract:
  584. return kMtfFloatScalarOrVector;
  585. default:
  586. break;
  587. }
  588. } else if (ext_inst_type == SPV_EXT_INST_TYPE_OPENCL_STD) {
  589. switch (ext_inst_index) {
  590. case OpenCLLIB::Fabs:
  591. case OpenCLLIB::FClamp:
  592. case OpenCLLIB::Fmax:
  593. case OpenCLLIB::Fmin:
  594. case OpenCLLIB::Step:
  595. case OpenCLLIB::Smoothstep:
  596. case OpenCLLIB::Fma:
  597. case OpenCLLIB::Pow:
  598. case OpenCLLIB::Exp:
  599. case OpenCLLIB::Exp2:
  600. case OpenCLLIB::Log:
  601. case OpenCLLIB::Log2:
  602. case OpenCLLIB::Sqrt:
  603. case OpenCLLIB::Rsqrt:
  604. case OpenCLLIB::Fract:
  605. case OpenCLLIB::Floor:
  606. case OpenCLLIB::Ceil:
  607. case OpenCLLIB::Radians:
  608. case OpenCLLIB::Degrees:
  609. case OpenCLLIB::Sin:
  610. case OpenCLLIB::Cos:
  611. case OpenCLLIB::Tan:
  612. case OpenCLLIB::Sinh:
  613. case OpenCLLIB::Cosh:
  614. case OpenCLLIB::Tanh:
  615. case OpenCLLIB::Asin:
  616. case OpenCLLIB::Acos:
  617. case OpenCLLIB::Atan:
  618. case OpenCLLIB::Atan2:
  619. case OpenCLLIB::Asinh:
  620. case OpenCLLIB::Acosh:
  621. case OpenCLLIB::Atanh:
  622. case OpenCLLIB::Cross:
  623. case OpenCLLIB::Normalize:
  624. return GetMtfIdOfType(return_type);
  625. case OpenCLLIB::Length:
  626. case OpenCLLIB::Distance:
  627. return kMtfFloatScalarOrVector;
  628. default:
  629. break;
  630. }
  631. }
  632. }
  633. break;
  634. }
  // OpFunction operands: [0] return type, [1] result id, [2] function
  // control (literal), [3] function type.
  635. case SpvOpFunction: {
  636. if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
  637. if (operand_index_ == 3) {
  638. const uint32_t return_type = GetInstWords()[1];
  639. return GetMtfFunctionTypeWithReturnType(return_type);
  640. }
  641. break;
  642. }
  // OpFunctionCall operands: [2] callee, [3..] arguments.
  643. case SpvOpFunctionCall: {
  644. if (operand_index_ == 0) return kMtfTypeReturnedByFunction;
  645. if (operand_index_ == 2) {
  646. const uint32_t return_type = GetInstWords()[1];
  647. return GetMtfFunctionWithReturnType(return_type);
  648. }
  649. if (operand_index_ >= 3) {
  650. const uint32_t function_id = GetInstWords()[3];
  651. const val::Instruction* function_inst = FindDef(function_id);
  // The callee may not be defined yet (presumably a forward reference to a
  // function declared later in the module).
  652. if (!function_inst) return kMtfObject;
  653. assert(function_inst->opcode() == SpvOpFunction);
  654. const uint32_t function_type_id = function_inst->word(4);
  655. const val::Instruction* function_type_inst = FindDef(function_type_id);
  656. assert(function_type_inst);
  657. assert(function_type_inst->opcode() == SpvOpTypeFunction);
  // Argument k sits at operand index 3 + k and its parameter type at word
  // 3 + k of the OpTypeFunction, so operand_index_ indexes the type directly.
  658. const uint32_t argument_type = function_type_inst->word(operand_index_);
  659. return GetMtfIdOfType(argument_type);
  660. }
  661. break;
  662. }
  663. case SpvOpReturnValue: {
  664. if (operand_index_ == 0) return GetMtfIdOfType(cur_function_return_type_);
  665. break;
  666. }
  667. case SpvOpBranchConditional: {
  668. if (operand_index_ == 0)
  669. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeBool);
  670. break;
  671. }
  672. case SpvOpSampledImage: {
  673. if (operand_index_ == 0)
  674. return GetMtfIdGeneratedByOpcode(SpvOpTypeSampledImage);
  675. if (operand_index_ == 2)
  676. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeImage);
  677. if (operand_index_ == 3)
  678. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampler);
  679. break;
  680. }
  // Operand 3 is the coordinate; the rule guesses a vector-typed value.
  681. case SpvOpImageSampleImplicitLod: {
  682. if (operand_index_ == 0)
  683. return GetMtfIdGeneratedByOpcode(SpvOpTypeVector);
  684. if (operand_index_ == 2)
  685. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeSampledImage);
  686. if (operand_index_ == 3)
  687. return GetMtfIdWithTypeGeneratedByOpcode(SpvOpTypeVector);
  688. break;
  689. }
  690. default:
  691. break;
  692. }
  693. return kMtfNone;
  694. }
  695. } // namespace comp
  696. } // namespace spvtools