validate_arithmetics.cpp 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Performs validation of arithmetic instructions.
  15. #include "source/val/validate.h"
  16. #include <vector>
  17. #include "source/diagnostic.h"
  18. #include "source/opcode.h"
  19. #include "source/val/instruction.h"
  20. #include "source/val/validation_state.h"
  21. namespace spvtools {
  22. namespace val {
  23. // Validates correctness of arithmetic instructions.
  24. spv_result_t ArithmeticsPass(ValidationState_t& _, const Instruction* inst) {
  25. const SpvOp opcode = inst->opcode();
  26. const uint32_t result_type = inst->type_id();
  27. switch (opcode) {
  28. case SpvOpFAdd:
  29. case SpvOpFSub:
  30. case SpvOpFMul:
  31. case SpvOpFDiv:
  32. case SpvOpFRem:
  33. case SpvOpFMod:
  34. case SpvOpFNegate: {
  35. bool supportsCoopMat =
  36. (opcode != SpvOpFMul && opcode != SpvOpFRem && opcode != SpvOpFMod);
  37. if (!_.IsFloatScalarType(result_type) &&
  38. !_.IsFloatVectorType(result_type) &&
  39. !(supportsCoopMat && _.IsFloatCooperativeMatrixType(result_type)))
  40. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  41. << "Expected floating scalar or vector type as Result Type: "
  42. << spvOpcodeString(opcode);
  43. for (size_t operand_index = 2; operand_index < inst->operands().size();
  44. ++operand_index) {
  45. if (_.GetOperandTypeId(inst, operand_index) != result_type)
  46. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  47. << "Expected arithmetic operands to be of Result Type: "
  48. << spvOpcodeString(opcode) << " operand index "
  49. << operand_index;
  50. }
  51. break;
  52. }
  53. case SpvOpUDiv:
  54. case SpvOpUMod: {
  55. bool supportsCoopMat = (opcode == SpvOpUDiv);
  56. if (!_.IsUnsignedIntScalarType(result_type) &&
  57. !_.IsUnsignedIntVectorType(result_type) &&
  58. !(supportsCoopMat &&
  59. _.IsUnsignedIntCooperativeMatrixType(result_type)))
  60. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  61. << "Expected unsigned int scalar or vector type as Result Type: "
  62. << spvOpcodeString(opcode);
  63. for (size_t operand_index = 2; operand_index < inst->operands().size();
  64. ++operand_index) {
  65. if (_.GetOperandTypeId(inst, operand_index) != result_type)
  66. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  67. << "Expected arithmetic operands to be of Result Type: "
  68. << spvOpcodeString(opcode) << " operand index "
  69. << operand_index;
  70. }
  71. break;
  72. }
  73. case SpvOpISub:
  74. case SpvOpIAdd:
  75. case SpvOpIMul:
  76. case SpvOpSDiv:
  77. case SpvOpSMod:
  78. case SpvOpSRem:
  79. case SpvOpSNegate: {
  80. bool supportsCoopMat =
  81. (opcode != SpvOpIMul && opcode != SpvOpSRem && opcode != SpvOpSMod);
  82. if (!_.IsIntScalarType(result_type) && !_.IsIntVectorType(result_type) &&
  83. !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type)))
  84. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  85. << "Expected int scalar or vector type as Result Type: "
  86. << spvOpcodeString(opcode);
  87. const uint32_t dimension = _.GetDimension(result_type);
  88. const uint32_t bit_width = _.GetBitWidth(result_type);
  89. for (size_t operand_index = 2; operand_index < inst->operands().size();
  90. ++operand_index) {
  91. const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
  92. if (!type_id ||
  93. (!_.IsIntScalarType(type_id) && !_.IsIntVectorType(type_id) &&
  94. !(supportsCoopMat && _.IsIntCooperativeMatrixType(result_type))))
  95. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  96. << "Expected int scalar or vector type as operand: "
  97. << spvOpcodeString(opcode) << " operand index "
  98. << operand_index;
  99. if (_.GetDimension(type_id) != dimension)
  100. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  101. << "Expected arithmetic operands to have the same dimension "
  102. << "as Result Type: " << spvOpcodeString(opcode)
  103. << " operand index " << operand_index;
  104. if (_.GetBitWidth(type_id) != bit_width)
  105. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  106. << "Expected arithmetic operands to have the same bit width "
  107. << "as Result Type: " << spvOpcodeString(opcode)
  108. << " operand index " << operand_index;
  109. }
  110. break;
  111. }
  112. case SpvOpDot: {
  113. if (!_.IsFloatScalarType(result_type))
  114. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  115. << "Expected float scalar type as Result Type: "
  116. << spvOpcodeString(opcode);
  117. uint32_t first_vector_num_components = 0;
  118. for (size_t operand_index = 2; operand_index < inst->operands().size();
  119. ++operand_index) {
  120. const uint32_t type_id = _.GetOperandTypeId(inst, operand_index);
  121. if (!type_id || !_.IsFloatVectorType(type_id))
  122. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  123. << "Expected float vector as operand: "
  124. << spvOpcodeString(opcode) << " operand index "
  125. << operand_index;
  126. const uint32_t component_type = _.GetComponentType(type_id);
  127. if (component_type != result_type)
  128. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  129. << "Expected component type to be equal to Result Type: "
  130. << spvOpcodeString(opcode) << " operand index "
  131. << operand_index;
  132. const uint32_t num_components = _.GetDimension(type_id);
  133. if (operand_index == 2) {
  134. first_vector_num_components = num_components;
  135. } else if (num_components != first_vector_num_components) {
  136. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  137. << "Expected operands to have the same number of componenets: "
  138. << spvOpcodeString(opcode);
  139. }
  140. }
  141. break;
  142. }
  143. case SpvOpVectorTimesScalar: {
  144. if (!_.IsFloatVectorType(result_type))
  145. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  146. << "Expected float vector type as Result Type: "
  147. << spvOpcodeString(opcode);
  148. const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
  149. if (result_type != vector_type_id)
  150. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  151. << "Expected vector operand type to be equal to Result Type: "
  152. << spvOpcodeString(opcode);
  153. const uint32_t component_type = _.GetComponentType(vector_type_id);
  154. const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
  155. if (component_type != scalar_type_id)
  156. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  157. << "Expected scalar operand type to be equal to the component "
  158. << "type of the vector operand: " << spvOpcodeString(opcode);
  159. break;
  160. }
  161. case SpvOpMatrixTimesScalar: {
  162. if (!_.IsFloatMatrixType(result_type) &&
  163. !_.IsCooperativeMatrixType(result_type))
  164. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  165. << "Expected float matrix type as Result Type: "
  166. << spvOpcodeString(opcode);
  167. const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
  168. if (result_type != matrix_type_id)
  169. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  170. << "Expected matrix operand type to be equal to Result Type: "
  171. << spvOpcodeString(opcode);
  172. const uint32_t component_type = _.GetComponentType(matrix_type_id);
  173. const uint32_t scalar_type_id = _.GetOperandTypeId(inst, 3);
  174. if (component_type != scalar_type_id)
  175. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  176. << "Expected scalar operand type to be equal to the component "
  177. << "type of the matrix operand: " << spvOpcodeString(opcode);
  178. break;
  179. }
  180. case SpvOpVectorTimesMatrix: {
  181. const uint32_t vector_type_id = _.GetOperandTypeId(inst, 2);
  182. const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 3);
  183. if (!_.IsFloatVectorType(result_type))
  184. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  185. << "Expected float vector type as Result Type: "
  186. << spvOpcodeString(opcode);
  187. const uint32_t res_component_type = _.GetComponentType(result_type);
  188. if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
  189. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  190. << "Expected float vector type as left operand: "
  191. << spvOpcodeString(opcode);
  192. if (res_component_type != _.GetComponentType(vector_type_id))
  193. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  194. << "Expected component types of Result Type and vector to be "
  195. << "equal: " << spvOpcodeString(opcode);
  196. uint32_t matrix_num_rows = 0;
  197. uint32_t matrix_num_cols = 0;
  198. uint32_t matrix_col_type = 0;
  199. uint32_t matrix_component_type = 0;
  200. if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
  201. &matrix_num_cols, &matrix_col_type,
  202. &matrix_component_type))
  203. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  204. << "Expected float matrix type as right operand: "
  205. << spvOpcodeString(opcode);
  206. if (res_component_type != matrix_component_type)
  207. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  208. << "Expected component types of Result Type and matrix to be "
  209. << "equal: " << spvOpcodeString(opcode);
  210. if (matrix_num_cols != _.GetDimension(result_type))
  211. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  212. << "Expected number of columns of the matrix to be equal to "
  213. << "Result Type vector size: " << spvOpcodeString(opcode);
  214. if (matrix_num_rows != _.GetDimension(vector_type_id))
  215. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  216. << "Expected number of rows of the matrix to be equal to the "
  217. << "vector operand size: " << spvOpcodeString(opcode);
  218. break;
  219. }
  220. case SpvOpMatrixTimesVector: {
  221. const uint32_t matrix_type_id = _.GetOperandTypeId(inst, 2);
  222. const uint32_t vector_type_id = _.GetOperandTypeId(inst, 3);
  223. if (!_.IsFloatVectorType(result_type))
  224. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  225. << "Expected float vector type as Result Type: "
  226. << spvOpcodeString(opcode);
  227. uint32_t matrix_num_rows = 0;
  228. uint32_t matrix_num_cols = 0;
  229. uint32_t matrix_col_type = 0;
  230. uint32_t matrix_component_type = 0;
  231. if (!_.GetMatrixTypeInfo(matrix_type_id, &matrix_num_rows,
  232. &matrix_num_cols, &matrix_col_type,
  233. &matrix_component_type))
  234. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  235. << "Expected float matrix type as left operand: "
  236. << spvOpcodeString(opcode);
  237. if (result_type != matrix_col_type)
  238. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  239. << "Expected column type of the matrix to be equal to Result "
  240. "Type: "
  241. << spvOpcodeString(opcode);
  242. if (!vector_type_id || !_.IsFloatVectorType(vector_type_id))
  243. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  244. << "Expected float vector type as right operand: "
  245. << spvOpcodeString(opcode);
  246. if (matrix_component_type != _.GetComponentType(vector_type_id))
  247. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  248. << "Expected component types of the operands to be equal: "
  249. << spvOpcodeString(opcode);
  250. if (matrix_num_cols != _.GetDimension(vector_type_id))
  251. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  252. << "Expected number of columns of the matrix to be equal to the "
  253. << "vector size: " << spvOpcodeString(opcode);
  254. break;
  255. }
  256. case SpvOpMatrixTimesMatrix: {
  257. const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
  258. const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
  259. uint32_t res_num_rows = 0;
  260. uint32_t res_num_cols = 0;
  261. uint32_t res_col_type = 0;
  262. uint32_t res_component_type = 0;
  263. if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
  264. &res_col_type, &res_component_type))
  265. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  266. << "Expected float matrix type as Result Type: "
  267. << spvOpcodeString(opcode);
  268. uint32_t left_num_rows = 0;
  269. uint32_t left_num_cols = 0;
  270. uint32_t left_col_type = 0;
  271. uint32_t left_component_type = 0;
  272. if (!_.GetMatrixTypeInfo(left_type_id, &left_num_rows, &left_num_cols,
  273. &left_col_type, &left_component_type))
  274. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  275. << "Expected float matrix type as left operand: "
  276. << spvOpcodeString(opcode);
  277. uint32_t right_num_rows = 0;
  278. uint32_t right_num_cols = 0;
  279. uint32_t right_col_type = 0;
  280. uint32_t right_component_type = 0;
  281. if (!_.GetMatrixTypeInfo(right_type_id, &right_num_rows, &right_num_cols,
  282. &right_col_type, &right_component_type))
  283. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  284. << "Expected float matrix type as right operand: "
  285. << spvOpcodeString(opcode);
  286. if (!_.IsFloatScalarType(res_component_type))
  287. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  288. << "Expected float matrix type as Result Type: "
  289. << spvOpcodeString(opcode);
  290. if (res_col_type != left_col_type)
  291. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  292. << "Expected column types of Result Type and left matrix to be "
  293. << "equal: " << spvOpcodeString(opcode);
  294. if (res_component_type != right_component_type)
  295. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  296. << "Expected component types of Result Type and right matrix to "
  297. "be "
  298. << "equal: " << spvOpcodeString(opcode);
  299. if (res_num_cols != right_num_cols)
  300. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  301. << "Expected number of columns of Result Type and right matrix "
  302. "to "
  303. << "be equal: " << spvOpcodeString(opcode);
  304. if (left_num_cols != right_num_rows)
  305. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  306. << "Expected number of columns of left matrix and number of "
  307. "rows "
  308. << "of right matrix to be equal: " << spvOpcodeString(opcode);
  309. assert(left_num_rows == res_num_rows);
  310. break;
  311. }
  312. case SpvOpOuterProduct: {
  313. const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
  314. const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
  315. uint32_t res_num_rows = 0;
  316. uint32_t res_num_cols = 0;
  317. uint32_t res_col_type = 0;
  318. uint32_t res_component_type = 0;
  319. if (!_.GetMatrixTypeInfo(result_type, &res_num_rows, &res_num_cols,
  320. &res_col_type, &res_component_type))
  321. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  322. << "Expected float matrix type as Result Type: "
  323. << spvOpcodeString(opcode);
  324. if (left_type_id != res_col_type)
  325. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  326. << "Expected column type of Result Type to be equal to the type "
  327. << "of the left operand: " << spvOpcodeString(opcode);
  328. if (!right_type_id || !_.IsFloatVectorType(right_type_id))
  329. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  330. << "Expected float vector type as right operand: "
  331. << spvOpcodeString(opcode);
  332. if (res_component_type != _.GetComponentType(right_type_id))
  333. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  334. << "Expected component types of the operands to be equal: "
  335. << spvOpcodeString(opcode);
  336. if (res_num_cols != _.GetDimension(right_type_id))
  337. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  338. << "Expected number of columns of the matrix to be equal to the "
  339. << "vector size of the right operand: "
  340. << spvOpcodeString(opcode);
  341. break;
  342. }
  343. case SpvOpIAddCarry:
  344. case SpvOpISubBorrow:
  345. case SpvOpUMulExtended:
  346. case SpvOpSMulExtended: {
  347. std::vector<uint32_t> result_types;
  348. if (!_.GetStructMemberTypes(result_type, &result_types))
  349. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  350. << "Expected a struct as Result Type: "
  351. << spvOpcodeString(opcode);
  352. if (result_types.size() != 2)
  353. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  354. << "Expected Result Type struct to have two members: "
  355. << spvOpcodeString(opcode);
  356. if (opcode == SpvOpSMulExtended) {
  357. if (!_.IsIntScalarType(result_types[0]) &&
  358. !_.IsIntVectorType(result_types[0]))
  359. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  360. << "Expected Result Type struct member types to be integer "
  361. "scalar "
  362. << "or vector: " << spvOpcodeString(opcode);
  363. } else {
  364. if (!_.IsUnsignedIntScalarType(result_types[0]) &&
  365. !_.IsUnsignedIntVectorType(result_types[0]))
  366. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  367. << "Expected Result Type struct member types to be unsigned "
  368. << "integer scalar or vector: " << spvOpcodeString(opcode);
  369. }
  370. if (result_types[0] != result_types[1])
  371. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  372. << "Expected Result Type struct member types to be identical: "
  373. << spvOpcodeString(opcode);
  374. const uint32_t left_type_id = _.GetOperandTypeId(inst, 2);
  375. const uint32_t right_type_id = _.GetOperandTypeId(inst, 3);
  376. if (left_type_id != result_types[0] || right_type_id != result_types[0])
  377. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  378. << "Expected both operands to be of Result Type member type: "
  379. << spvOpcodeString(opcode);
  380. break;
  381. }
  382. case SpvOpCooperativeMatrixMulAddNV: {
  383. const uint32_t D_type_id = _.GetOperandTypeId(inst, 1);
  384. const uint32_t A_type_id = _.GetOperandTypeId(inst, 2);
  385. const uint32_t B_type_id = _.GetOperandTypeId(inst, 3);
  386. const uint32_t C_type_id = _.GetOperandTypeId(inst, 4);
  387. if (!_.IsCooperativeMatrixType(A_type_id)) {
  388. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  389. << "Expected cooperative matrix type as A Type: "
  390. << spvOpcodeString(opcode);
  391. }
  392. if (!_.IsCooperativeMatrixType(B_type_id)) {
  393. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  394. << "Expected cooperative matrix type as B Type: "
  395. << spvOpcodeString(opcode);
  396. }
  397. if (!_.IsCooperativeMatrixType(C_type_id)) {
  398. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  399. << "Expected cooperative matrix type as C Type: "
  400. << spvOpcodeString(opcode);
  401. }
  402. if (!_.IsCooperativeMatrixType(D_type_id)) {
  403. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  404. << "Expected cooperative matrix type as Result Type: "
  405. << spvOpcodeString(opcode);
  406. }
  407. const auto A = _.FindDef(A_type_id);
  408. const auto B = _.FindDef(B_type_id);
  409. const auto C = _.FindDef(C_type_id);
  410. const auto D = _.FindDef(D_type_id);
  411. std::tuple<bool, bool, uint32_t> A_scope, B_scope, C_scope, D_scope,
  412. A_rows, B_rows, C_rows, D_rows, A_cols, B_cols, C_cols, D_cols;
  413. A_scope = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(2));
  414. B_scope = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(2));
  415. C_scope = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(2));
  416. D_scope = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(2));
  417. A_rows = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(3));
  418. B_rows = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(3));
  419. C_rows = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(3));
  420. D_rows = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(3));
  421. A_cols = _.EvalInt32IfConst(A->GetOperandAs<uint32_t>(4));
  422. B_cols = _.EvalInt32IfConst(B->GetOperandAs<uint32_t>(4));
  423. C_cols = _.EvalInt32IfConst(C->GetOperandAs<uint32_t>(4));
  424. D_cols = _.EvalInt32IfConst(D->GetOperandAs<uint32_t>(4));
  425. const auto notEqual = [](std::tuple<bool, bool, uint32_t> X,
  426. std::tuple<bool, bool, uint32_t> Y) {
  427. return (std::get<1>(X) && std::get<1>(Y) &&
  428. std::get<2>(X) != std::get<2>(Y));
  429. };
  430. if (notEqual(A_scope, B_scope) || notEqual(A_scope, C_scope) ||
  431. notEqual(A_scope, D_scope) || notEqual(B_scope, C_scope) ||
  432. notEqual(B_scope, D_scope) || notEqual(C_scope, D_scope)) {
  433. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  434. << "Cooperative matrix scopes must match: "
  435. << spvOpcodeString(opcode);
  436. }
  437. if (notEqual(A_rows, C_rows) || notEqual(A_rows, D_rows) ||
  438. notEqual(C_rows, D_rows)) {
  439. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  440. << "Cooperative matrix 'M' mismatch: "
  441. << spvOpcodeString(opcode);
  442. }
  443. if (notEqual(B_cols, C_cols) || notEqual(B_cols, D_cols) ||
  444. notEqual(C_cols, D_cols)) {
  445. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  446. << "Cooperative matrix 'N' mismatch: "
  447. << spvOpcodeString(opcode);
  448. }
  449. if (notEqual(A_cols, B_rows)) {
  450. return _.diag(SPV_ERROR_INVALID_DATA, inst)
  451. << "Cooperative matrix 'K' mismatch: "
  452. << spvOpcodeString(opcode);
  453. }
  454. break;
  455. }
  456. default:
  457. break;
  458. }
  459. return SPV_SUCCESS;
  460. }
  461. } // namespace val
  462. } // namespace spvtools