validation_state.cpp 33 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114
  1. // Copyright (c) 2015-2016 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include "source/val/validation_state.h"
  15. #include <cassert>
  16. #include <stack>
  17. #include <utility>
  18. #include "source/opcode.h"
  19. #include "source/spirv_target_env.h"
  20. #include "source/val/basic_block.h"
  21. #include "source/val/construct.h"
  22. #include "source/val/function.h"
  23. #include "spirv-tools/libspirv.h"
  24. namespace spvtools {
  25. namespace val {
  26. namespace {
  27. bool IsInstructionInLayoutSection(ModuleLayoutSection layout, SpvOp op) {
  28. // See Section 2.4
  29. bool out = false;
  30. // clang-format off
  31. switch (layout) {
  32. case kLayoutCapabilities: out = op == SpvOpCapability; break;
  33. case kLayoutExtensions: out = op == SpvOpExtension; break;
  34. case kLayoutExtInstImport: out = op == SpvOpExtInstImport; break;
  35. case kLayoutMemoryModel: out = op == SpvOpMemoryModel; break;
  36. case kLayoutEntryPoint: out = op == SpvOpEntryPoint; break;
  37. case kLayoutExecutionMode:
  38. out = op == SpvOpExecutionMode || op == SpvOpExecutionModeId;
  39. break;
  40. case kLayoutDebug1:
  41. switch (op) {
  42. case SpvOpSourceContinued:
  43. case SpvOpSource:
  44. case SpvOpSourceExtension:
  45. case SpvOpString:
  46. out = true;
  47. break;
  48. default: break;
  49. }
  50. break;
  51. case kLayoutDebug2:
  52. switch (op) {
  53. case SpvOpName:
  54. case SpvOpMemberName:
  55. out = true;
  56. break;
  57. default: break;
  58. }
  59. break;
  60. case kLayoutDebug3:
  61. // Only OpModuleProcessed is allowed here.
  62. out = (op == SpvOpModuleProcessed);
  63. break;
  64. case kLayoutAnnotations:
  65. switch (op) {
  66. case SpvOpDecorate:
  67. case SpvOpMemberDecorate:
  68. case SpvOpGroupDecorate:
  69. case SpvOpGroupMemberDecorate:
  70. case SpvOpDecorationGroup:
  71. case SpvOpDecorateId:
  72. case SpvOpDecorateStringGOOGLE:
  73. case SpvOpMemberDecorateStringGOOGLE:
  74. out = true;
  75. break;
  76. default: break;
  77. }
  78. break;
  79. case kLayoutTypes:
  80. if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
  81. out = true;
  82. break;
  83. }
  84. switch (op) {
  85. case SpvOpTypeForwardPointer:
  86. case SpvOpVariable:
  87. case SpvOpLine:
  88. case SpvOpNoLine:
  89. case SpvOpUndef:
  90. out = true;
  91. break;
  92. default: break;
  93. }
  94. break;
  95. case kLayoutFunctionDeclarations:
  96. case kLayoutFunctionDefinitions:
  97. // NOTE: These instructions should NOT be in these layout sections
  98. if (spvOpcodeGeneratesType(op) || spvOpcodeIsConstant(op)) {
  99. out = false;
  100. break;
  101. }
  102. switch (op) {
  103. case SpvOpCapability:
  104. case SpvOpExtension:
  105. case SpvOpExtInstImport:
  106. case SpvOpMemoryModel:
  107. case SpvOpEntryPoint:
  108. case SpvOpExecutionMode:
  109. case SpvOpExecutionModeId:
  110. case SpvOpSourceContinued:
  111. case SpvOpSource:
  112. case SpvOpSourceExtension:
  113. case SpvOpString:
  114. case SpvOpName:
  115. case SpvOpMemberName:
  116. case SpvOpModuleProcessed:
  117. case SpvOpDecorate:
  118. case SpvOpMemberDecorate:
  119. case SpvOpGroupDecorate:
  120. case SpvOpGroupMemberDecorate:
  121. case SpvOpDecorationGroup:
  122. case SpvOpTypeForwardPointer:
  123. out = false;
  124. break;
  125. default:
  126. out = true;
  127. break;
  128. }
  129. }
  130. // clang-format on
  131. return out;
  132. }
  133. // Counts the number of instructions and functions in the file.
  134. spv_result_t CountInstructions(void* user_data,
  135. const spv_parsed_instruction_t* inst) {
  136. ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
  137. if (inst->opcode == SpvOpFunction) _.increment_total_functions();
  138. _.increment_total_instructions();
  139. return SPV_SUCCESS;
  140. }
  141. } // namespace
  142. ValidationState_t::ValidationState_t(const spv_const_context ctx,
  143. const spv_const_validator_options opt,
  144. const uint32_t* words,
  145. const size_t num_words,
  146. const uint32_t max_warnings)
  147. : context_(ctx),
  148. options_(opt),
  149. words_(words),
  150. num_words_(num_words),
  151. unresolved_forward_ids_{},
  152. operand_names_{},
  153. current_layout_section_(kLayoutCapabilities),
  154. module_functions_(),
  155. module_capabilities_(),
  156. module_extensions_(),
  157. ordered_instructions_(),
  158. all_definitions_(),
  159. global_vars_(),
  160. local_vars_(),
  161. struct_nesting_depth_(),
  162. struct_has_nested_blockorbufferblock_struct_(),
  163. grammar_(ctx),
  164. addressing_model_(SpvAddressingModelMax),
  165. memory_model_(SpvMemoryModelMax),
  166. pointer_size_and_alignment_(0),
  167. in_function_(false),
  168. num_of_warnings_(0),
  169. max_num_of_warnings_(max_warnings) {
  170. assert(opt && "Validator options may not be Null.");
  171. const auto env = context_->target_env;
  172. if (spvIsVulkanEnv(env)) {
  173. // Vulkan 1.1 includes VK_KHR_relaxed_block_layout in core.
  174. if (env != SPV_ENV_VULKAN_1_0) {
  175. features_.env_relaxed_block_layout = true;
  176. }
  177. }
  178. switch (env) {
  179. case SPV_ENV_WEBGPU_0:
  180. features_.bans_op_undef = true;
  181. break;
  182. default:
  183. break;
  184. }
  185. // Only attempt to count if we have words, otherwise let the other validation
  186. // fail and generate an error.
  187. if (num_words > 0) {
  188. // Count the number of instructions in the binary.
  189. // This parse should not produce any error messages. Hijack the context and
  190. // replace the message consumer so that we do not pollute any state in input
  191. // consumer.
  192. spv_context_t hijacked_context = *ctx;
  193. hijacked_context.consumer = [](spv_message_level_t, const char*,
  194. const spv_position_t&, const char*) {};
  195. spvBinaryParse(&hijacked_context, this, words, num_words,
  196. /* parsed_header = */ nullptr, CountInstructions,
  197. /* diagnostic = */ nullptr);
  198. preallocateStorage();
  199. }
  200. friendly_mapper_ = spvtools::MakeUnique<spvtools::FriendlyNameMapper>(
  201. context_, words_, num_words_);
  202. name_mapper_ = friendly_mapper_->GetNameMapper();
  203. }
  204. void ValidationState_t::preallocateStorage() {
  205. ordered_instructions_.reserve(total_instructions_);
  206. module_functions_.reserve(total_functions_);
  207. }
  208. spv_result_t ValidationState_t::ForwardDeclareId(uint32_t id) {
  209. unresolved_forward_ids_.insert(id);
  210. return SPV_SUCCESS;
  211. }
  212. spv_result_t ValidationState_t::RemoveIfForwardDeclared(uint32_t id) {
  213. unresolved_forward_ids_.erase(id);
  214. return SPV_SUCCESS;
  215. }
  216. spv_result_t ValidationState_t::RegisterForwardPointer(uint32_t id) {
  217. forward_pointer_ids_.insert(id);
  218. return SPV_SUCCESS;
  219. }
  220. bool ValidationState_t::IsForwardPointer(uint32_t id) const {
  221. return (forward_pointer_ids_.find(id) != forward_pointer_ids_.end());
  222. }
  223. void ValidationState_t::AssignNameToId(uint32_t id, std::string name) {
  224. operand_names_[id] = name;
  225. }
  226. std::string ValidationState_t::getIdName(uint32_t id) const {
  227. const std::string id_name = name_mapper_(id);
  228. std::stringstream out;
  229. out << id << "[%" << id_name << "]";
  230. return out.str();
  231. }
  232. size_t ValidationState_t::unresolved_forward_id_count() const {
  233. return unresolved_forward_ids_.size();
  234. }
  235. std::vector<uint32_t> ValidationState_t::UnresolvedForwardIds() const {
  236. std::vector<uint32_t> out(std::begin(unresolved_forward_ids_),
  237. std::end(unresolved_forward_ids_));
  238. return out;
  239. }
  240. bool ValidationState_t::IsDefinedId(uint32_t id) const {
  241. return all_definitions_.find(id) != std::end(all_definitions_);
  242. }
  243. const Instruction* ValidationState_t::FindDef(uint32_t id) const {
  244. auto it = all_definitions_.find(id);
  245. if (it == all_definitions_.end()) return nullptr;
  246. return it->second;
  247. }
  248. Instruction* ValidationState_t::FindDef(uint32_t id) {
  249. auto it = all_definitions_.find(id);
  250. if (it == all_definitions_.end()) return nullptr;
  251. return it->second;
  252. }
  253. ModuleLayoutSection ValidationState_t::current_layout_section() const {
  254. return current_layout_section_;
  255. }
  256. void ValidationState_t::ProgressToNextLayoutSectionOrder() {
  257. // Guard against going past the last element(kLayoutFunctionDefinitions)
  258. if (current_layout_section_ <= kLayoutFunctionDefinitions) {
  259. current_layout_section_ =
  260. static_cast<ModuleLayoutSection>(current_layout_section_ + 1);
  261. }
  262. }
  263. bool ValidationState_t::IsOpcodeInCurrentLayoutSection(SpvOp op) {
  264. return IsInstructionInLayoutSection(current_layout_section_, op);
  265. }
  266. DiagnosticStream ValidationState_t::diag(spv_result_t error_code,
  267. const Instruction* inst) {
  268. if (error_code == SPV_WARNING) {
  269. if (num_of_warnings_ == max_num_of_warnings_) {
  270. DiagnosticStream({0, 0, 0}, context_->consumer, "", error_code)
  271. << "Other warnings have been suppressed.\n";
  272. }
  273. if (num_of_warnings_ >= max_num_of_warnings_) {
  274. return DiagnosticStream({0, 0, 0}, nullptr, "", error_code);
  275. }
  276. ++num_of_warnings_;
  277. }
  278. std::string disassembly;
  279. if (inst) disassembly = Disassemble(*inst);
  280. return DiagnosticStream({0, 0, inst ? inst->LineNum() : 0},
  281. context_->consumer, disassembly, error_code);
  282. }
  283. std::vector<Function>& ValidationState_t::functions() {
  284. return module_functions_;
  285. }
  286. Function& ValidationState_t::current_function() {
  287. assert(in_function_body());
  288. return module_functions_.back();
  289. }
  290. const Function& ValidationState_t::current_function() const {
  291. assert(in_function_body());
  292. return module_functions_.back();
  293. }
  294. const Function* ValidationState_t::function(uint32_t id) const {
  295. const auto it = id_to_function_.find(id);
  296. if (it == id_to_function_.end()) return nullptr;
  297. return it->second;
  298. }
  299. Function* ValidationState_t::function(uint32_t id) {
  300. auto it = id_to_function_.find(id);
  301. if (it == id_to_function_.end()) return nullptr;
  302. return it->second;
  303. }
  304. bool ValidationState_t::in_function_body() const { return in_function_; }
  305. bool ValidationState_t::in_block() const {
  306. return module_functions_.empty() == false &&
  307. module_functions_.back().current_block() != nullptr;
  308. }
  309. void ValidationState_t::RegisterCapability(SpvCapability cap) {
  310. // Avoid redundant work. Otherwise the recursion could induce work
  311. // quadrdatic in the capability dependency depth. (Ok, not much, but
  312. // it's something.)
  313. if (module_capabilities_.Contains(cap)) return;
  314. module_capabilities_.Add(cap);
  315. spv_operand_desc desc;
  316. if (SPV_SUCCESS ==
  317. grammar_.lookupOperand(SPV_OPERAND_TYPE_CAPABILITY, cap, &desc)) {
  318. CapabilitySet(desc->numCapabilities, desc->capabilities)
  319. .ForEach([this](SpvCapability c) { RegisterCapability(c); });
  320. }
  321. switch (cap) {
  322. case SpvCapabilityKernel:
  323. features_.group_ops_reduce_and_scans = true;
  324. break;
  325. case SpvCapabilityInt8:
  326. features_.use_int8_type = true;
  327. features_.declare_int8_type = true;
  328. break;
  329. case SpvCapabilityStorageBuffer8BitAccess:
  330. case SpvCapabilityUniformAndStorageBuffer8BitAccess:
  331. case SpvCapabilityStoragePushConstant8:
  332. features_.declare_int8_type = true;
  333. break;
  334. case SpvCapabilityInt16:
  335. features_.declare_int16_type = true;
  336. break;
  337. case SpvCapabilityFloat16:
  338. case SpvCapabilityFloat16Buffer:
  339. features_.declare_float16_type = true;
  340. break;
  341. case SpvCapabilityStorageUniformBufferBlock16:
  342. case SpvCapabilityStorageUniform16:
  343. case SpvCapabilityStoragePushConstant16:
  344. case SpvCapabilityStorageInputOutput16:
  345. features_.declare_int16_type = true;
  346. features_.declare_float16_type = true;
  347. features_.free_fp_rounding_mode = true;
  348. break;
  349. case SpvCapabilityVariablePointers:
  350. features_.variable_pointers = true;
  351. features_.variable_pointers_storage_buffer = true;
  352. break;
  353. case SpvCapabilityVariablePointersStorageBuffer:
  354. features_.variable_pointers_storage_buffer = true;
  355. break;
  356. default:
  357. break;
  358. }
  359. }
  360. void ValidationState_t::RegisterExtension(Extension ext) {
  361. if (module_extensions_.Contains(ext)) return;
  362. module_extensions_.Add(ext);
  363. switch (ext) {
  364. case kSPV_AMD_gpu_shader_half_float:
  365. case kSPV_AMD_gpu_shader_half_float_fetch:
  366. // SPV_AMD_gpu_shader_half_float enables float16 type.
  367. // https://github.com/KhronosGroup/SPIRV-Tools/issues/1375
  368. features_.declare_float16_type = true;
  369. break;
  370. case kSPV_AMD_gpu_shader_int16:
  371. // This is not yet in the extension, but it's recommended for it.
  372. // See https://github.com/KhronosGroup/glslang/issues/848
  373. features_.uconvert_spec_constant_op = true;
  374. break;
  375. case kSPV_AMD_shader_ballot:
  376. // The grammar doesn't encode the fact that SPV_AMD_shader_ballot
  377. // enables the use of group operations Reduce, InclusiveScan,
  378. // and ExclusiveScan. Enable it manually.
  379. // https://github.com/KhronosGroup/SPIRV-Tools/issues/991
  380. features_.group_ops_reduce_and_scans = true;
  381. break;
  382. default:
  383. break;
  384. }
  385. }
  386. bool ValidationState_t::HasAnyOfCapabilities(
  387. const CapabilitySet& capabilities) const {
  388. return module_capabilities_.HasAnyOf(capabilities);
  389. }
  390. bool ValidationState_t::HasAnyOfExtensions(
  391. const ExtensionSet& extensions) const {
  392. return module_extensions_.HasAnyOf(extensions);
  393. }
  394. void ValidationState_t::set_addressing_model(SpvAddressingModel am) {
  395. addressing_model_ = am;
  396. switch (am) {
  397. case SpvAddressingModelPhysical32:
  398. pointer_size_and_alignment_ = 4;
  399. break;
  400. default:
  401. // fall through
  402. case SpvAddressingModelPhysical64:
  403. case SpvAddressingModelPhysicalStorageBuffer64EXT:
  404. pointer_size_and_alignment_ = 8;
  405. break;
  406. }
  407. }
  408. SpvAddressingModel ValidationState_t::addressing_model() const {
  409. return addressing_model_;
  410. }
  411. void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
  412. memory_model_ = mm;
  413. }
  414. SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; }
  415. spv_result_t ValidationState_t::RegisterFunction(
  416. uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
  417. uint32_t function_type_id) {
  418. assert(in_function_body() == false &&
  419. "RegisterFunction can only be called when parsing the binary outside "
  420. "of another function");
  421. in_function_ = true;
  422. module_functions_.emplace_back(id, ret_type_id, function_control,
  423. function_type_id);
  424. id_to_function_.emplace(id, &current_function());
  425. // TODO(umar): validate function type and type_id
  426. return SPV_SUCCESS;
  427. }
  428. spv_result_t ValidationState_t::RegisterFunctionEnd() {
  429. assert(in_function_body() == true &&
  430. "RegisterFunctionEnd can only be called when parsing the binary "
  431. "inside of another function");
  432. assert(in_block() == false &&
  433. "RegisterFunctionParameter can only be called when parsing the binary "
  434. "ouside of a block");
  435. current_function().RegisterFunctionEnd();
  436. in_function_ = false;
  437. return SPV_SUCCESS;
  438. }
  439. Instruction* ValidationState_t::AddOrderedInstruction(
  440. const spv_parsed_instruction_t* inst) {
  441. ordered_instructions_.emplace_back(inst);
  442. ordered_instructions_.back().SetLineNum(ordered_instructions_.size());
  443. return &ordered_instructions_.back();
  444. }
  445. // Improves diagnostic messages by collecting names of IDs
  446. void ValidationState_t::RegisterDebugInstruction(const Instruction* inst) {
  447. switch (inst->opcode()) {
  448. case SpvOpName: {
  449. const auto target = inst->GetOperandAs<uint32_t>(0);
  450. const auto* str = reinterpret_cast<const char*>(inst->words().data() +
  451. inst->operand(1).offset);
  452. AssignNameToId(target, str);
  453. break;
  454. }
  455. case SpvOpMemberName: {
  456. const auto target = inst->GetOperandAs<uint32_t>(0);
  457. const auto* str = reinterpret_cast<const char*>(inst->words().data() +
  458. inst->operand(2).offset);
  459. AssignNameToId(target, str);
  460. break;
  461. }
  462. case SpvOpSourceContinued:
  463. case SpvOpSource:
  464. case SpvOpSourceExtension:
  465. case SpvOpString:
  466. case SpvOpLine:
  467. case SpvOpNoLine:
  468. default:
  469. break;
  470. }
  471. }
  472. void ValidationState_t::RegisterInstruction(Instruction* inst) {
  473. if (inst->id()) all_definitions_.insert(std::make_pair(inst->id(), inst));
  474. // If the instruction is using an OpTypeSampledImage as an operand, it should
  475. // be recorded. The validator will ensure that all usages of an
  476. // OpTypeSampledImage and its definition are in the same basic block.
  477. for (uint16_t i = 0; i < inst->operands().size(); ++i) {
  478. const spv_parsed_operand_t& operand = inst->operand(i);
  479. if (SPV_OPERAND_TYPE_ID == operand.type) {
  480. const uint32_t operand_word = inst->word(operand.offset);
  481. Instruction* operand_inst = FindDef(operand_word);
  482. if (operand_inst && SpvOpSampledImage == operand_inst->opcode()) {
  483. RegisterSampledImageConsumer(operand_word, inst->id());
  484. }
  485. }
  486. }
  487. }
  488. std::vector<uint32_t> ValidationState_t::getSampledImageConsumers(
  489. uint32_t sampled_image_id) const {
  490. std::vector<uint32_t> result;
  491. auto iter = sampled_image_consumers_.find(sampled_image_id);
  492. if (iter != sampled_image_consumers_.end()) {
  493. result = iter->second;
  494. }
  495. return result;
  496. }
  497. void ValidationState_t::RegisterSampledImageConsumer(uint32_t sampled_image_id,
  498. uint32_t consumer_id) {
  499. sampled_image_consumers_[sampled_image_id].push_back(consumer_id);
  500. }
  501. uint32_t ValidationState_t::getIdBound() const { return id_bound_; }
  502. void ValidationState_t::setIdBound(const uint32_t bound) { id_bound_ = bound; }
  503. bool ValidationState_t::RegisterUniqueTypeDeclaration(const Instruction* inst) {
  504. std::vector<uint32_t> key;
  505. key.push_back(static_cast<uint32_t>(inst->opcode()));
  506. for (size_t index = 0; index < inst->operands().size(); ++index) {
  507. const spv_parsed_operand_t& operand = inst->operand(index);
  508. if (operand.type == SPV_OPERAND_TYPE_RESULT_ID) continue;
  509. const int words_begin = operand.offset;
  510. const int words_end = words_begin + operand.num_words;
  511. assert(words_end <= static_cast<int>(inst->words().size()));
  512. key.insert(key.end(), inst->words().begin() + words_begin,
  513. inst->words().begin() + words_end);
  514. }
  515. return unique_type_declarations_.insert(std::move(key)).second;
  516. }
  517. uint32_t ValidationState_t::GetTypeId(uint32_t id) const {
  518. const Instruction* inst = FindDef(id);
  519. return inst ? inst->type_id() : 0;
  520. }
  521. SpvOp ValidationState_t::GetIdOpcode(uint32_t id) const {
  522. const Instruction* inst = FindDef(id);
  523. return inst ? inst->opcode() : SpvOpNop;
  524. }
  525. uint32_t ValidationState_t::GetComponentType(uint32_t id) const {
  526. const Instruction* inst = FindDef(id);
  527. assert(inst);
  528. switch (inst->opcode()) {
  529. case SpvOpTypeFloat:
  530. case SpvOpTypeInt:
  531. case SpvOpTypeBool:
  532. return id;
  533. case SpvOpTypeVector:
  534. return inst->word(2);
  535. case SpvOpTypeMatrix:
  536. return GetComponentType(inst->word(2));
  537. case SpvOpTypeCooperativeMatrixNV:
  538. return inst->word(2);
  539. default:
  540. break;
  541. }
  542. if (inst->type_id()) return GetComponentType(inst->type_id());
  543. assert(0);
  544. return 0;
  545. }
  546. uint32_t ValidationState_t::GetDimension(uint32_t id) const {
  547. const Instruction* inst = FindDef(id);
  548. assert(inst);
  549. switch (inst->opcode()) {
  550. case SpvOpTypeFloat:
  551. case SpvOpTypeInt:
  552. case SpvOpTypeBool:
  553. return 1;
  554. case SpvOpTypeVector:
  555. case SpvOpTypeMatrix:
  556. return inst->word(3);
  557. case SpvOpTypeCooperativeMatrixNV:
  558. // Actual dimension isn't known, return 0
  559. return 0;
  560. default:
  561. break;
  562. }
  563. if (inst->type_id()) return GetDimension(inst->type_id());
  564. assert(0);
  565. return 0;
  566. }
  567. uint32_t ValidationState_t::GetBitWidth(uint32_t id) const {
  568. const uint32_t component_type_id = GetComponentType(id);
  569. const Instruction* inst = FindDef(component_type_id);
  570. assert(inst);
  571. if (inst->opcode() == SpvOpTypeFloat || inst->opcode() == SpvOpTypeInt)
  572. return inst->word(2);
  573. if (inst->opcode() == SpvOpTypeBool) return 1;
  574. assert(0);
  575. return 0;
  576. }
  577. bool ValidationState_t::IsFloatScalarType(uint32_t id) const {
  578. const Instruction* inst = FindDef(id);
  579. assert(inst);
  580. return inst->opcode() == SpvOpTypeFloat;
  581. }
  582. bool ValidationState_t::IsFloatVectorType(uint32_t id) const {
  583. const Instruction* inst = FindDef(id);
  584. assert(inst);
  585. if (inst->opcode() == SpvOpTypeVector) {
  586. return IsFloatScalarType(GetComponentType(id));
  587. }
  588. return false;
  589. }
  590. bool ValidationState_t::IsFloatScalarOrVectorType(uint32_t id) const {
  591. const Instruction* inst = FindDef(id);
  592. assert(inst);
  593. if (inst->opcode() == SpvOpTypeFloat) {
  594. return true;
  595. }
  596. if (inst->opcode() == SpvOpTypeVector) {
  597. return IsFloatScalarType(GetComponentType(id));
  598. }
  599. return false;
  600. }
  601. bool ValidationState_t::IsIntScalarType(uint32_t id) const {
  602. const Instruction* inst = FindDef(id);
  603. assert(inst);
  604. return inst->opcode() == SpvOpTypeInt;
  605. }
  606. bool ValidationState_t::IsIntVectorType(uint32_t id) const {
  607. const Instruction* inst = FindDef(id);
  608. assert(inst);
  609. if (inst->opcode() == SpvOpTypeVector) {
  610. return IsIntScalarType(GetComponentType(id));
  611. }
  612. return false;
  613. }
  614. bool ValidationState_t::IsIntScalarOrVectorType(uint32_t id) const {
  615. const Instruction* inst = FindDef(id);
  616. assert(inst);
  617. if (inst->opcode() == SpvOpTypeInt) {
  618. return true;
  619. }
  620. if (inst->opcode() == SpvOpTypeVector) {
  621. return IsIntScalarType(GetComponentType(id));
  622. }
  623. return false;
  624. }
  625. bool ValidationState_t::IsUnsignedIntScalarType(uint32_t id) const {
  626. const Instruction* inst = FindDef(id);
  627. assert(inst);
  628. return inst->opcode() == SpvOpTypeInt && inst->word(3) == 0;
  629. }
  630. bool ValidationState_t::IsUnsignedIntVectorType(uint32_t id) const {
  631. const Instruction* inst = FindDef(id);
  632. assert(inst);
  633. if (inst->opcode() == SpvOpTypeVector) {
  634. return IsUnsignedIntScalarType(GetComponentType(id));
  635. }
  636. return false;
  637. }
  638. bool ValidationState_t::IsSignedIntScalarType(uint32_t id) const {
  639. const Instruction* inst = FindDef(id);
  640. assert(inst);
  641. return inst->opcode() == SpvOpTypeInt && inst->word(3) == 1;
  642. }
  643. bool ValidationState_t::IsSignedIntVectorType(uint32_t id) const {
  644. const Instruction* inst = FindDef(id);
  645. assert(inst);
  646. if (inst->opcode() == SpvOpTypeVector) {
  647. return IsSignedIntScalarType(GetComponentType(id));
  648. }
  649. return false;
  650. }
  651. bool ValidationState_t::IsBoolScalarType(uint32_t id) const {
  652. const Instruction* inst = FindDef(id);
  653. assert(inst);
  654. return inst->opcode() == SpvOpTypeBool;
  655. }
  656. bool ValidationState_t::IsBoolVectorType(uint32_t id) const {
  657. const Instruction* inst = FindDef(id);
  658. assert(inst);
  659. if (inst->opcode() == SpvOpTypeVector) {
  660. return IsBoolScalarType(GetComponentType(id));
  661. }
  662. return false;
  663. }
  664. bool ValidationState_t::IsBoolScalarOrVectorType(uint32_t id) const {
  665. const Instruction* inst = FindDef(id);
  666. assert(inst);
  667. if (inst->opcode() == SpvOpTypeBool) {
  668. return true;
  669. }
  670. if (inst->opcode() == SpvOpTypeVector) {
  671. return IsBoolScalarType(GetComponentType(id));
  672. }
  673. return false;
  674. }
  675. bool ValidationState_t::IsFloatMatrixType(uint32_t id) const {
  676. const Instruction* inst = FindDef(id);
  677. assert(inst);
  678. if (inst->opcode() == SpvOpTypeMatrix) {
  679. return IsFloatScalarType(GetComponentType(id));
  680. }
  681. return false;
  682. }
  683. bool ValidationState_t::GetMatrixTypeInfo(uint32_t id, uint32_t* num_rows,
  684. uint32_t* num_cols,
  685. uint32_t* column_type,
  686. uint32_t* component_type) const {
  687. if (!id) return false;
  688. const Instruction* mat_inst = FindDef(id);
  689. assert(mat_inst);
  690. if (mat_inst->opcode() != SpvOpTypeMatrix) return false;
  691. const uint32_t vec_type = mat_inst->word(2);
  692. const Instruction* vec_inst = FindDef(vec_type);
  693. assert(vec_inst);
  694. if (vec_inst->opcode() != SpvOpTypeVector) {
  695. assert(0);
  696. return false;
  697. }
  698. *num_cols = mat_inst->word(3);
  699. *num_rows = vec_inst->word(3);
  700. *column_type = mat_inst->word(2);
  701. *component_type = vec_inst->word(2);
  702. return true;
  703. }
  704. bool ValidationState_t::GetStructMemberTypes(
  705. uint32_t struct_type_id, std::vector<uint32_t>* member_types) const {
  706. member_types->clear();
  707. if (!struct_type_id) return false;
  708. const Instruction* inst = FindDef(struct_type_id);
  709. assert(inst);
  710. if (inst->opcode() != SpvOpTypeStruct) return false;
  711. *member_types =
  712. std::vector<uint32_t>(inst->words().cbegin() + 2, inst->words().cend());
  713. if (member_types->empty()) return false;
  714. return true;
  715. }
  716. bool ValidationState_t::IsPointerType(uint32_t id) const {
  717. const Instruction* inst = FindDef(id);
  718. assert(inst);
  719. return inst->opcode() == SpvOpTypePointer;
  720. }
  721. bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
  722. uint32_t* storage_class) const {
  723. if (!id) return false;
  724. const Instruction* inst = FindDef(id);
  725. assert(inst);
  726. if (inst->opcode() != SpvOpTypePointer) return false;
  727. *storage_class = inst->word(2);
  728. *data_type = inst->word(3);
  729. return true;
  730. }
  731. bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
  732. const Instruction* inst = FindDef(id);
  733. assert(inst);
  734. return inst->opcode() == SpvOpTypeCooperativeMatrixNV;
  735. }
  736. bool ValidationState_t::IsFloatCooperativeMatrixType(uint32_t id) const {
  737. if (!IsCooperativeMatrixType(id)) return false;
  738. return IsFloatScalarType(FindDef(id)->word(2));
  739. }
  740. bool ValidationState_t::IsIntCooperativeMatrixType(uint32_t id) const {
  741. if (!IsCooperativeMatrixType(id)) return false;
  742. return IsIntScalarType(FindDef(id)->word(2));
  743. }
  744. bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
  745. if (!IsCooperativeMatrixType(id)) return false;
  746. return IsUnsignedIntScalarType(FindDef(id)->word(2));
  747. }
  748. spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
  749. const Instruction* inst, uint32_t m1, uint32_t m2) {
  750. const auto m1_type = FindDef(m1);
  751. const auto m2_type = FindDef(m2);
  752. if (m1_type->opcode() != SpvOpTypeCooperativeMatrixNV ||
  753. m2_type->opcode() != SpvOpTypeCooperativeMatrixNV) {
  754. return diag(SPV_ERROR_INVALID_DATA, inst)
  755. << "Expected cooperative matrix types";
  756. }
  757. uint32_t m1_scope_id = m1_type->GetOperandAs<uint32_t>(2);
  758. uint32_t m1_rows_id = m1_type->GetOperandAs<uint32_t>(3);
  759. uint32_t m1_cols_id = m1_type->GetOperandAs<uint32_t>(4);
  760. uint32_t m2_scope_id = m2_type->GetOperandAs<uint32_t>(2);
  761. uint32_t m2_rows_id = m2_type->GetOperandAs<uint32_t>(3);
  762. uint32_t m2_cols_id = m2_type->GetOperandAs<uint32_t>(4);
  763. bool m1_is_int32 = false, m1_is_const_int32 = false, m2_is_int32 = false,
  764. m2_is_const_int32 = false;
  765. uint32_t m1_value = 0, m2_value = 0;
  766. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  767. EvalInt32IfConst(m1_scope_id);
  768. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  769. EvalInt32IfConst(m2_scope_id);
  770. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  771. return diag(SPV_ERROR_INVALID_DATA, inst)
  772. << "Expected scopes of Matrix and Result Type to be "
  773. << "identical";
  774. }
  775. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  776. EvalInt32IfConst(m1_rows_id);
  777. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  778. EvalInt32IfConst(m2_rows_id);
  779. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  780. return diag(SPV_ERROR_INVALID_DATA, inst)
  781. << "Expected rows of Matrix type and Result Type to be "
  782. << "identical";
  783. }
  784. std::tie(m1_is_int32, m1_is_const_int32, m1_value) =
  785. EvalInt32IfConst(m1_cols_id);
  786. std::tie(m2_is_int32, m2_is_const_int32, m2_value) =
  787. EvalInt32IfConst(m2_cols_id);
  788. if (m1_is_const_int32 && m2_is_const_int32 && m1_value != m2_value) {
  789. return diag(SPV_ERROR_INVALID_DATA, inst)
  790. << "Expected columns of Matrix type and Result Type to be "
  791. << "identical";
  792. }
  793. return SPV_SUCCESS;
  794. }
  795. uint32_t ValidationState_t::GetOperandTypeId(const Instruction* inst,
  796. size_t operand_index) const {
  797. return GetTypeId(inst->GetOperandAs<uint32_t>(operand_index));
  798. }
  799. bool ValidationState_t::GetConstantValUint64(uint32_t id, uint64_t* val) const {
  800. const Instruction* inst = FindDef(id);
  801. if (!inst) {
  802. assert(0 && "Instruction not found");
  803. return false;
  804. }
  805. if (inst->opcode() != SpvOpConstant && inst->opcode() != SpvOpSpecConstant)
  806. return false;
  807. if (!IsIntScalarType(inst->type_id())) return false;
  808. if (inst->words().size() == 4) {
  809. *val = inst->word(3);
  810. } else {
  811. assert(inst->words().size() == 5);
  812. *val = inst->word(3);
  813. *val |= uint64_t(inst->word(4)) << 32;
  814. }
  815. return true;
  816. }
  817. std::tuple<bool, bool, uint32_t> ValidationState_t::EvalInt32IfConst(
  818. uint32_t id) const {
  819. const Instruction* const inst = FindDef(id);
  820. assert(inst);
  821. const uint32_t type = inst->type_id();
  822. if (type == 0 || !IsIntScalarType(type) || GetBitWidth(type) != 32) {
  823. return std::make_tuple(false, false, 0);
  824. }
  825. // Spec constant values cannot be evaluated so don't consider constant for
  826. // the purpose of this method.
  827. if (!spvOpcodeIsConstant(inst->opcode()) ||
  828. spvOpcodeIsSpecConstant(inst->opcode())) {
  829. return std::make_tuple(true, false, 0);
  830. }
  831. if (inst->opcode() == SpvOpConstantNull) {
  832. return std::make_tuple(true, true, 0);
  833. }
  834. assert(inst->words().size() == 4);
  835. return std::make_tuple(true, true, inst->word(3));
  836. }
  837. void ValidationState_t::ComputeFunctionToEntryPointMapping() {
  838. for (const uint32_t entry_point : entry_points()) {
  839. std::stack<uint32_t> call_stack;
  840. std::set<uint32_t> visited;
  841. call_stack.push(entry_point);
  842. while (!call_stack.empty()) {
  843. const uint32_t called_func_id = call_stack.top();
  844. call_stack.pop();
  845. if (!visited.insert(called_func_id).second) continue;
  846. function_to_entry_points_[called_func_id].push_back(entry_point);
  847. const Function* called_func = function(called_func_id);
  848. if (called_func) {
  849. // Other checks should error out on this invalid SPIR-V.
  850. for (const uint32_t new_call : called_func->function_call_targets()) {
  851. call_stack.push(new_call);
  852. }
  853. }
  854. }
  855. }
  856. }
  857. void ValidationState_t::ComputeRecursiveEntryPoints() {
  858. for (const Function func : functions()) {
  859. std::stack<uint32_t> call_stack;
  860. std::set<uint32_t> visited;
  861. for (const uint32_t new_call : func.function_call_targets()) {
  862. call_stack.push(new_call);
  863. }
  864. while (!call_stack.empty()) {
  865. const uint32_t called_func_id = call_stack.top();
  866. call_stack.pop();
  867. if (!visited.insert(called_func_id).second) continue;
  868. if (called_func_id == func.id()) {
  869. for (const uint32_t entry_point :
  870. function_to_entry_points_[called_func_id])
  871. recursive_entry_points_.insert(entry_point);
  872. break;
  873. }
  874. const Function* called_func = function(called_func_id);
  875. if (called_func) {
  876. // Other checks should error out on this invalid SPIR-V.
  877. for (const uint32_t new_call : called_func->function_call_targets()) {
  878. call_stack.push(new_call);
  879. }
  880. }
  881. }
  882. }
  883. }
  884. const std::vector<uint32_t>& ValidationState_t::FunctionEntryPoints(
  885. uint32_t func) const {
  886. auto iter = function_to_entry_points_.find(func);
  887. if (iter == function_to_entry_points_.end()) {
  888. return empty_ids_;
  889. } else {
  890. return iter->second;
  891. }
  892. }
  893. std::set<uint32_t> ValidationState_t::EntryPointReferences(uint32_t id) const {
  894. std::set<uint32_t> referenced_entry_points;
  895. const auto inst = FindDef(id);
  896. if (!inst) return referenced_entry_points;
  897. std::vector<const Instruction*> stack;
  898. stack.push_back(inst);
  899. while (!stack.empty()) {
  900. const auto current_inst = stack.back();
  901. stack.pop_back();
  902. if (const auto func = current_inst->function()) {
  903. // Instruction lives in a function, we can stop searching.
  904. const auto function_entry_points = FunctionEntryPoints(func->id());
  905. referenced_entry_points.insert(function_entry_points.begin(),
  906. function_entry_points.end());
  907. } else {
  908. // Instruction is in the global scope, keep searching its uses.
  909. for (auto pair : current_inst->uses()) {
  910. const auto next_inst = pair.first;
  911. stack.push_back(next_inst);
  912. }
  913. }
  914. }
  915. return referenced_entry_points;
  916. }
  917. std::string ValidationState_t::Disassemble(const Instruction& inst) const {
  918. const spv_parsed_instruction_t& c_inst(inst.c_inst());
  919. return Disassemble(c_inst.words, c_inst.num_words);
  920. }
  921. std::string ValidationState_t::Disassemble(const uint32_t* words,
  922. uint16_t num_words) const {
  923. uint32_t disassembly_options = SPV_BINARY_TO_TEXT_OPTION_NO_HEADER |
  924. SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES;
  925. return spvInstructionBinaryToText(context()->target_env, words, num_words,
  926. words_, num_words_, disassembly_options);
  927. }
  928. } // namespace val
  929. } // namespace spvtools