validate.cpp 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. // Copyright (c) 2015-2016 The Khronos Group Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
#include "source/val/validate.h"

#include <cassert>
#include <cstdio>

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "source/binary.h"
#include "source/diagnostic.h"
#include "source/enum_string_mapping.h"
#include "source/extensions.h"
#include "source/instruction.h"
#include "source/opcode.h"
#include "source/operand.h"
#include "source/spirv_constant.h"
#include "source/spirv_endian.h"
#include "source/spirv_target_env.h"
#include "source/spirv_validator_options.h"
#include "source/val/construct.h"
#include "source/val/function.h"
#include "source/val/instruction.h"
#include "source/val/validation_state.h"
#include "spirv-tools/libspirv.h"
  40. namespace {
  41. // TODO(issue 1950): The validator only returns a single message anyway, so no
  42. // point in generating more than 1 warning.
  43. static uint32_t kDefaultMaxNumOfWarnings = 1;
  44. } // namespace
  45. namespace spvtools {
  46. namespace val {
  47. namespace {
  48. // TODO(umar): Validate header
  49. // TODO(umar): The binary parser validates the magic word, and the length of the
  50. // header, but nothing else.
  51. spv_result_t setHeader(void* user_data, spv_endianness_t, uint32_t,
  52. uint32_t version, uint32_t generator, uint32_t id_bound,
  53. uint32_t) {
  54. // Record the ID bound so that the validator can ensure no ID is out of bound.
  55. ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
  56. _.setIdBound(id_bound);
  57. _.setGenerator(generator);
  58. _.setVersion(version);
  59. return SPV_SUCCESS;
  60. }
  61. // Parses OpExtension instruction and registers extension.
  62. void RegisterExtension(ValidationState_t& _,
  63. const spv_parsed_instruction_t* inst) {
  64. const std::string extension_str = spvtools::GetExtensionString(inst);
  65. Extension extension;
  66. if (!GetExtensionFromString(extension_str.c_str(), &extension)) {
  67. // The error will be logged in the ProcessInstruction pass.
  68. return;
  69. }
  70. _.RegisterExtension(extension);
  71. }
  72. // Parses the beginning of the module searching for OpExtension instructions.
  73. // Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION
  74. // once an instruction which is not SpvOpCapability and SpvOpExtension is
  75. // encountered. According to the SPIR-V spec extensions are declared after
  76. // capabilities and before everything else.
  77. spv_result_t ProcessExtensions(void* user_data,
  78. const spv_parsed_instruction_t* inst) {
  79. const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
  80. if (opcode == SpvOpCapability) return SPV_SUCCESS;
  81. if (opcode == SpvOpExtension) {
  82. ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
  83. RegisterExtension(_, inst);
  84. return SPV_SUCCESS;
  85. }
  86. // OpExtension block is finished, requesting termination.
  87. return SPV_REQUESTED_TERMINATION;
  88. }
  89. spv_result_t ProcessInstruction(void* user_data,
  90. const spv_parsed_instruction_t* inst) {
  91. ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
  92. auto* instruction = _.AddOrderedInstruction(inst);
  93. _.RegisterDebugInstruction(instruction);
  94. return SPV_SUCCESS;
  95. }
  96. void printDot(const ValidationState_t& _, const BasicBlock& other) {
  97. std::string block_string;
  98. if (other.successors()->empty()) {
  99. block_string += "end ";
  100. } else {
  101. for (auto block : *other.successors()) {
  102. block_string += _.getIdName(block->id()) + " ";
  103. }
  104. }
  105. printf("%10s -> {%s\b}\n", _.getIdName(other.id()).c_str(),
  106. block_string.c_str());
  107. }
  108. void PrintBlocks(ValidationState_t& _, Function func) {
  109. assert(func.first_block());
  110. printf("%10s -> %s\n", _.getIdName(func.id()).c_str(),
  111. _.getIdName(func.first_block()->id()).c_str());
  112. for (const auto& block : func.ordered_blocks()) {
  113. printDot(_, *block);
  114. }
  115. }
  116. #ifdef __clang__
  117. #define UNUSED(func) [[gnu::unused]] func
  118. #elif defined(__GNUC__)
  119. #define UNUSED(func) \
  120. func __attribute__((unused)); \
  121. func
  122. #elif defined(_MSC_VER)
  123. #define UNUSED(func) func
  124. #endif
  125. UNUSED(void PrintDotGraph(ValidationState_t& _, Function func)) {
  126. if (func.first_block()) {
  127. std::string func_name(_.getIdName(func.id()));
  128. printf("digraph %s {\n", func_name.c_str());
  129. PrintBlocks(_, func);
  130. printf("}\n");
  131. }
  132. }
  133. spv_result_t ValidateForwardDecls(ValidationState_t& _) {
  134. if (_.unresolved_forward_id_count() == 0) return SPV_SUCCESS;
  135. std::stringstream ss;
  136. std::vector<uint32_t> ids = _.UnresolvedForwardIds();
  137. std::transform(
  138. std::begin(ids), std::end(ids),
  139. std::ostream_iterator<std::string>(ss, " "),
  140. bind(&ValidationState_t::getIdName, std::ref(_), std::placeholders::_1));
  141. auto id_str = ss.str();
  142. return _.diag(SPV_ERROR_INVALID_ID, nullptr)
  143. << "The following forward referenced IDs have not been defined:\n"
  144. << id_str.substr(0, id_str.size() - 1);
  145. }
  146. std::vector<std::string> CalculateNamesForEntryPoint(ValidationState_t& _,
  147. const uint32_t id) {
  148. auto id_descriptions = _.entry_point_descriptions(id);
  149. auto id_names = std::vector<std::string>();
  150. id_names.reserve((id_descriptions.size()));
  151. for (auto description : id_descriptions) id_names.push_back(description.name);
  152. return id_names;
  153. }
  154. spv_result_t ValidateEntryPointNameUnique(ValidationState_t& _,
  155. const uint32_t id) {
  156. auto id_names = CalculateNamesForEntryPoint(_, id);
  157. const auto names =
  158. std::unordered_set<std::string>(id_names.begin(), id_names.end());
  159. if (id_names.size() != names.size()) {
  160. std::sort(id_names.begin(), id_names.end());
  161. for (size_t i = 0; i < id_names.size() - 1; i++) {
  162. if (id_names[i] == id_names[i + 1]) {
  163. return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(id))
  164. << "Entry point name \"" << id_names[i]
  165. << "\" is not unique, which is not allow in WebGPU env.";
  166. }
  167. }
  168. }
  169. for (const auto other_id : _.entry_points()) {
  170. if (other_id == id) continue;
  171. const auto other_id_names = CalculateNamesForEntryPoint(_, other_id);
  172. for (const auto other_id_name : other_id_names) {
  173. if (names.find(other_id_name) != names.end()) {
  174. return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(id))
  175. << "Entry point name \"" << other_id_name
  176. << "\" is not unique, which is not allow in WebGPU env.";
  177. }
  178. }
  179. }
  180. return SPV_SUCCESS;
  181. }
  182. spv_result_t ValidateEntryPointNamesUnique(ValidationState_t& _) {
  183. for (const auto id : _.entry_points()) {
  184. auto result = ValidateEntryPointNameUnique(_, id);
  185. if (result != SPV_SUCCESS) return result;
  186. }
  187. return SPV_SUCCESS;
  188. }
  189. // Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the
  190. // SPIRV spec:
  191. // * There is at least one OpEntryPoint instruction, unless the Linkage
  192. // capability is being used.
  193. // * No function can be targeted by both an OpEntryPoint instruction and an
  194. // OpFunctionCall instruction.
  195. //
  196. // Additionally enforces that entry points for Vulkan and WebGPU should not have
  197. // recursion. And that entry names should be unique for WebGPU.
  198. spv_result_t ValidateEntryPoints(ValidationState_t& _) {
  199. _.ComputeFunctionToEntryPointMapping();
  200. _.ComputeRecursiveEntryPoints();
  201. if (_.entry_points().empty() && !_.HasCapability(SpvCapabilityLinkage)) {
  202. return _.diag(SPV_ERROR_INVALID_BINARY, nullptr)
  203. << "No OpEntryPoint instruction was found. This is only allowed if "
  204. "the Linkage capability is being used.";
  205. }
  206. for (const auto& entry_point : _.entry_points()) {
  207. if (_.IsFunctionCallTarget(entry_point)) {
  208. return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
  209. << "A function (" << entry_point
  210. << ") may not be targeted by both an OpEntryPoint instruction and "
  211. "an OpFunctionCall instruction.";
  212. }
  213. // For Vulkan and WebGPU, the static function-call graph for an entry point
  214. // must not contain cycles.
  215. if (spvIsVulkanOrWebGPUEnv(_.context()->target_env)) {
  216. if (_.recursive_entry_points().find(entry_point) !=
  217. _.recursive_entry_points().end()) {
  218. return _.diag(SPV_ERROR_INVALID_BINARY, _.FindDef(entry_point))
  219. << "Entry points may not have a call graph with cycles.";
  220. }
  221. }
  222. // For WebGPU all entry point names must be unique.
  223. if (spvIsWebGPUEnv(_.context()->target_env)) {
  224. const auto result = ValidateEntryPointNamesUnique(_);
  225. if (result != SPV_SUCCESS) return result;
  226. }
  227. }
  228. return SPV_SUCCESS;
  229. }
  230. spv_result_t ValidateBinaryUsingContextAndValidationState(
  231. const spv_context_t& context, const uint32_t* words, const size_t num_words,
  232. spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
  233. auto binary = std::unique_ptr<spv_const_binary_t>(
  234. new spv_const_binary_t{words, num_words});
  235. spv_endianness_t endian;
  236. spv_position_t position = {};
  237. if (spvBinaryEndianness(binary.get(), &endian)) {
  238. return DiagnosticStream(position, context.consumer, "",
  239. SPV_ERROR_INVALID_BINARY)
  240. << "Invalid SPIR-V magic number.";
  241. }
  242. spv_header_t header;
  243. if (spvBinaryHeaderGet(binary.get(), endian, &header)) {
  244. return DiagnosticStream(position, context.consumer, "",
  245. SPV_ERROR_INVALID_BINARY)
  246. << "Invalid SPIR-V header.";
  247. }
  248. if (header.version > spvVersionForTargetEnv(context.target_env)) {
  249. return DiagnosticStream(position, context.consumer, "",
  250. SPV_ERROR_WRONG_VERSION)
  251. << "Invalid SPIR-V binary version "
  252. << SPV_SPIRV_VERSION_MAJOR_PART(header.version) << "."
  253. << SPV_SPIRV_VERSION_MINOR_PART(header.version)
  254. << " for target environment "
  255. << spvTargetEnvDescription(context.target_env) << ".";
  256. }
  257. if (header.bound > vstate->options()->universal_limits_.max_id_bound) {
  258. return DiagnosticStream(position, context.consumer, "",
  259. SPV_ERROR_INVALID_BINARY)
  260. << "Invalid SPIR-V. The id bound is larger than the max id bound "
  261. << vstate->options()->universal_limits_.max_id_bound << ".";
  262. }
  263. // Look for OpExtension instructions and register extensions.
  264. // This parse should not produce any error messages. Hijack the context and
  265. // replace the message consumer so that we do not pollute any state in input
  266. // consumer.
  267. spv_context_t hijacked_context = context;
  268. hijacked_context.consumer = [](spv_message_level_t, const char*,
  269. const spv_position_t&, const char*) {};
  270. spvBinaryParse(&hijacked_context, vstate, words, num_words,
  271. /* parsed_header = */ nullptr, ProcessExtensions,
  272. /* diagnostic = */ nullptr);
  273. // Parse the module and perform inline validation checks. These checks do
  274. // not require the the knowledge of the whole module.
  275. if (auto error = spvBinaryParse(&context, vstate, words, num_words, setHeader,
  276. ProcessInstruction, pDiagnostic)) {
  277. return error;
  278. }
  279. for (auto& instruction : vstate->ordered_instructions()) {
  280. {
  281. // In order to do this work outside of Process Instruction we need to be
  282. // able to, briefly, de-const the instruction.
  283. Instruction* inst = const_cast<Instruction*>(&instruction);
  284. if (inst->opcode() == SpvOpEntryPoint) {
  285. const auto entry_point = inst->GetOperandAs<uint32_t>(1);
  286. const auto execution_model = inst->GetOperandAs<SpvExecutionModel>(0);
  287. const char* str = reinterpret_cast<const char*>(
  288. inst->words().data() + inst->operand(2).offset);
  289. ValidationState_t::EntryPointDescription desc;
  290. desc.name = str;
  291. std::vector<uint32_t> interfaces;
  292. for (size_t j = 3; j < inst->operands().size(); ++j)
  293. desc.interfaces.push_back(inst->word(inst->operand(j).offset));
  294. vstate->RegisterEntryPoint(entry_point, execution_model,
  295. std::move(desc));
  296. }
  297. if (inst->opcode() == SpvOpFunctionCall) {
  298. if (!vstate->in_function_body()) {
  299. return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction)
  300. << "A FunctionCall must happen within a function body.";
  301. }
  302. const auto called_id = inst->GetOperandAs<uint32_t>(2);
  303. if (spvIsWebGPUEnv(context.target_env) &&
  304. !vstate->IsFunctionCallDefined(called_id)) {
  305. return vstate->diag(SPV_ERROR_INVALID_LAYOUT, &instruction)
  306. << "For WebGPU, functions need to be defined before being "
  307. "called.";
  308. }
  309. vstate->AddFunctionCallTarget(called_id);
  310. }
  311. if (vstate->in_function_body()) {
  312. inst->set_function(&(vstate->current_function()));
  313. inst->set_block(vstate->current_function().current_block());
  314. if (vstate->in_block() && spvOpcodeIsBlockTerminator(inst->opcode())) {
  315. vstate->current_function().current_block()->set_terminator(inst);
  316. }
  317. }
  318. if (auto error = IdPass(*vstate, inst)) return error;
  319. }
  320. if (auto error = CapabilityPass(*vstate, &instruction)) return error;
  321. if (auto error = DataRulesPass(*vstate, &instruction)) return error;
  322. if (auto error = ModuleLayoutPass(*vstate, &instruction)) return error;
  323. if (auto error = CfgPass(*vstate, &instruction)) return error;
  324. if (auto error = InstructionPass(*vstate, &instruction)) return error;
  325. // Now that all of the checks are done, update the state.
  326. {
  327. Instruction* inst = const_cast<Instruction*>(&instruction);
  328. vstate->RegisterInstruction(inst);
  329. }
  330. }
  331. if (!vstate->has_memory_model_specified())
  332. return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
  333. << "Missing required OpMemoryModel instruction.";
  334. if (vstate->in_function_body())
  335. return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
  336. << "Missing OpFunctionEnd at end of module.";
  337. // Catch undefined forward references before performing further checks.
  338. if (auto error = ValidateForwardDecls(*vstate)) return error;
  339. // ID usage needs be handled in its own iteration of the instructions,
  340. // between the two others. It depends on the first loop to have been
  341. // finished, so that all instructions have been registered. And the following
  342. // loop depends on all of the usage data being populated. Thus it cannot live
  343. // in either of those iterations.
  344. // It should also live after the forward declaration check, since it will
  345. // have problems with missing forward declarations, but give less useful error
  346. // messages.
  347. for (size_t i = 0; i < vstate->ordered_instructions().size(); ++i) {
  348. auto& instruction = vstate->ordered_instructions()[i];
  349. if (auto error = UpdateIdUse(*vstate, &instruction)) return error;
  350. }
  351. // Validate individual opcodes.
  352. for (size_t i = 0; i < vstate->ordered_instructions().size(); ++i) {
  353. auto& instruction = vstate->ordered_instructions()[i];
  354. // Keep these passes in the order they appear in the SPIR-V specification
  355. // sections to maintain test consistency.
  356. // Miscellaneous
  357. if (auto error = DebugPass(*vstate, &instruction)) return error;
  358. if (auto error = AnnotationPass(*vstate, &instruction)) return error;
  359. if (auto error = ExtensionPass(*vstate, &instruction)) return error;
  360. if (auto error = ModeSettingPass(*vstate, &instruction)) return error;
  361. if (auto error = TypePass(*vstate, &instruction)) return error;
  362. if (auto error = ConstantPass(*vstate, &instruction)) return error;
  363. if (auto error = MemoryPass(*vstate, &instruction)) return error;
  364. if (auto error = FunctionPass(*vstate, &instruction)) return error;
  365. if (auto error = ImagePass(*vstate, &instruction)) return error;
  366. if (auto error = ConversionPass(*vstate, &instruction)) return error;
  367. if (auto error = CompositesPass(*vstate, &instruction)) return error;
  368. if (auto error = ArithmeticsPass(*vstate, &instruction)) return error;
  369. if (auto error = BitwisePass(*vstate, &instruction)) return error;
  370. if (auto error = LogicalsPass(*vstate, &instruction)) return error;
  371. if (auto error = ControlFlowPass(*vstate, &instruction)) return error;
  372. if (auto error = DerivativesPass(*vstate, &instruction)) return error;
  373. if (auto error = AtomicsPass(*vstate, &instruction)) return error;
  374. if (auto error = PrimitivesPass(*vstate, &instruction)) return error;
  375. if (auto error = BarriersPass(*vstate, &instruction)) return error;
  376. // Group
  377. // Device-Side Enqueue
  378. // Pipe
  379. if (auto error = NonUniformPass(*vstate, &instruction)) return error;
  380. if (auto error = LiteralsPass(*vstate, &instruction)) return error;
  381. }
  382. // Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi
  383. // must only be preceeded by SpvOpLabel, SpvOpPhi, or SpvOpLine.
  384. if (auto error = ValidateAdjacency(*vstate)) return error;
  385. if (auto error = ValidateEntryPoints(*vstate)) return error;
  386. // CFG checks are performed after the binary has been parsed
  387. // and the CFGPass has collected information about the control flow
  388. if (auto error = PerformCfgChecks(*vstate)) return error;
  389. if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error;
  390. if (auto error = ValidateDecorations(*vstate)) return error;
  391. if (auto error = ValidateInterfaces(*vstate)) return error;
  392. // TODO(dsinclair): Restructure ValidateBuiltins so we can move into the
  393. // for() above as it loops over all ordered_instructions internally.
  394. if (auto error = ValidateBuiltIns(*vstate)) return error;
  395. // These checks must be performed after individual opcode checks because
  396. // those checks register the limitation checked here.
  397. for (const auto inst : vstate->ordered_instructions()) {
  398. if (auto error = ValidateExecutionLimitations(*vstate, &inst)) return error;
  399. }
  400. return SPV_SUCCESS;
  401. }
  402. } // namespace
  403. spv_result_t ValidateBinaryAndKeepValidationState(
  404. const spv_const_context context, spv_const_validator_options options,
  405. const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
  406. std::unique_ptr<ValidationState_t>* vstate) {
  407. spv_context_t hijack_context = *context;
  408. if (pDiagnostic) {
  409. *pDiagnostic = nullptr;
  410. UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
  411. }
  412. vstate->reset(new ValidationState_t(&hijack_context, options, words,
  413. num_words, kDefaultMaxNumOfWarnings));
  414. return ValidateBinaryUsingContextAndValidationState(
  415. hijack_context, words, num_words, pDiagnostic, vstate->get());
  416. }
  417. } // namespace val
  418. } // namespace spvtools
  419. spv_result_t spvValidate(const spv_const_context context,
  420. const spv_const_binary binary,
  421. spv_diagnostic* pDiagnostic) {
  422. return spvValidateBinary(context, binary->code, binary->wordCount,
  423. pDiagnostic);
  424. }
  425. spv_result_t spvValidateBinary(const spv_const_context context,
  426. const uint32_t* words, const size_t num_words,
  427. spv_diagnostic* pDiagnostic) {
  428. spv_context_t hijack_context = *context;
  429. if (pDiagnostic) {
  430. *pDiagnostic = nullptr;
  431. spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
  432. }
  433. // This interface is used for default command line options.
  434. spv_validator_options default_options = spvValidatorOptionsCreate();
  435. // Create the ValidationState using the context and default options.
  436. spvtools::val::ValidationState_t vstate(&hijack_context, default_options,
  437. words, num_words,
  438. kDefaultMaxNumOfWarnings);
  439. spv_result_t result =
  440. spvtools::val::ValidateBinaryUsingContextAndValidationState(
  441. hijack_context, words, num_words, pDiagnostic, &vstate);
  442. spvValidatorOptionsDestroy(default_options);
  443. return result;
  444. }
  445. spv_result_t spvValidateWithOptions(const spv_const_context context,
  446. spv_const_validator_options options,
  447. const spv_const_binary binary,
  448. spv_diagnostic* pDiagnostic) {
  449. spv_context_t hijack_context = *context;
  450. if (pDiagnostic) {
  451. *pDiagnostic = nullptr;
  452. spvtools::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
  453. }
  454. // Create the ValidationState using the context.
  455. spvtools::val::ValidationState_t vstate(&hijack_context, options,
  456. binary->code, binary->wordCount,
  457. kDefaultMaxNumOfWarnings);
  458. return spvtools::val::ValidateBinaryUsingContextAndValidationState(
  459. hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate);
  460. }