strength_reduction_test.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. // Copyright (c) 2017 Google Inc.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <algorithm>
  15. #include <cstdarg>
  16. #include <iostream>
  17. #include <sstream>
  18. #include <string>
  19. #include <unordered_set>
  20. #include <vector>
  21. #include "gmock/gmock.h"
  22. #include "test/opt/assembly_builder.h"
  23. #include "test/opt/pass_fixture.h"
  24. #include "test/opt/pass_utils.h"
  25. namespace spvtools {
  26. namespace opt {
  27. namespace {
  28. using ::testing::HasSubstr;
  29. using ::testing::MatchesRegex;
  30. using StrengthReductionBasicTest = PassTest<::testing::Test>;
  31. // Test to make sure we replace 5*8.
  32. TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy8) {
  33. const std::vector<const char*> text = {
  34. // clang-format off
  35. "OpCapability Shader",
  36. "%1 = OpExtInstImport \"GLSL.std.450\"",
  37. "OpMemoryModel Logical GLSL450",
  38. "OpEntryPoint Vertex %main \"main\"",
  39. "OpName %main \"main\"",
  40. "%void = OpTypeVoid",
  41. "%4 = OpTypeFunction %void",
  42. "%uint = OpTypeInt 32 0",
  43. "%uint_5 = OpConstant %uint 5",
  44. "%uint_8 = OpConstant %uint 8",
  45. "%main = OpFunction %void None %4",
  46. "%8 = OpLabel",
  47. "%9 = OpIMul %uint %uint_5 %uint_8",
  48. "OpReturn",
  49. "OpFunctionEnd"
  50. // clang-format on
  51. };
  52. auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
  53. JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
  54. EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  55. const std::string& output = std::get<0>(result);
  56. EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
  57. EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_3"));
  58. }
  59. // TODO(dneto): Add Effcee as required dependency, and make this unconditional.
  60. // Test to make sure we replace 16*5
  61. // Also demonstrate use of Effcee matching.
  62. TEST_F(StrengthReductionBasicTest, BasicReplaceMulBy16) {
  63. const std::string text = R"(
  64. OpCapability Shader
  65. %1 = OpExtInstImport "GLSL.std.450"
  66. OpMemoryModel Logical GLSL450
  67. OpEntryPoint Vertex %main "main"
  68. OpName %main "main"
  69. %void = OpTypeVoid
  70. %4 = OpTypeFunction %void
  71. ; We know disassembly will produce %uint here, but
  72. ; CHECK: %uint = OpTypeInt 32 0
  73. ; CHECK-DAG: [[five:%[a-zA-Z_\d]+]] = OpConstant %uint 5
  74. ; We have RE2 regular expressions, so \w matches [_a-zA-Z0-9].
  75. ; This shows the preferred pattern for matching SPIR-V identifiers.
  76. ; (We could have cheated in this case since we know the disassembler will
  77. ; generate the 'nice' name of "%uint_4".
  78. ; CHECK-DAG: [[four:%\w+]] = OpConstant %uint 4
  79. %uint = OpTypeInt 32 0
  80. %uint_5 = OpConstant %uint 5
  81. %uint_16 = OpConstant %uint 16
  82. %main = OpFunction %void None %4
  83. ; CHECK: OpLabel
  84. %8 = OpLabel
  85. ; CHECK-NEXT: OpShiftLeftLogical %uint [[five]] [[four]]
  86. ; The multiplication disappears.
  87. ; CHECK-NOT: OpIMul
  88. %9 = OpIMul %uint %uint_16 %uint_5
  89. OpReturn
  90. ; CHECK: OpFunctionEnd
  91. OpFunctionEnd)";
  92. SinglePassRunAndMatch<StrengthReductionPass>(text, false);
  93. }
  94. // Test to make sure we replace a multiple of 32 and 4.
  95. TEST_F(StrengthReductionBasicTest, BasicTwoPowersOf2) {
  96. // In this case, we have two powers of 2. Need to make sure we replace only
  97. // one of them for the bit shift.
  98. // clang-format off
  99. const std::string text = R"(
  100. OpCapability Shader
  101. %1 = OpExtInstImport "GLSL.std.450"
  102. OpMemoryModel Logical GLSL450
  103. OpEntryPoint Vertex %main "main"
  104. OpName %main "main"
  105. %void = OpTypeVoid
  106. %4 = OpTypeFunction %void
  107. %int = OpTypeInt 32 1
  108. %int_32 = OpConstant %int 32
  109. %int_4 = OpConstant %int 4
  110. %main = OpFunction %void None %4
  111. %8 = OpLabel
  112. %9 = OpIMul %int %int_32 %int_4
  113. OpReturn
  114. OpFunctionEnd
  115. )";
  116. // clang-format on
  117. auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
  118. text, /* skip_nop = */ true, /* do_validation = */ false);
  119. EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  120. const std::string& output = std::get<0>(result);
  121. EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
  122. EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %int %int_4 %uint_5"));
  123. }
  124. // Test to make sure we don't replace 0*5.
  125. TEST_F(StrengthReductionBasicTest, BasicDontReplace0) {
  126. const std::vector<const char*> text = {
  127. // clang-format off
  128. "OpCapability Shader",
  129. "%1 = OpExtInstImport \"GLSL.std.450\"",
  130. "OpMemoryModel Logical GLSL450",
  131. "OpEntryPoint Vertex %main \"main\"",
  132. "OpName %main \"main\"",
  133. "%void = OpTypeVoid",
  134. "%4 = OpTypeFunction %void",
  135. "%int = OpTypeInt 32 1",
  136. "%int_0 = OpConstant %int 0",
  137. "%int_5 = OpConstant %int 5",
  138. "%main = OpFunction %void None %4",
  139. "%8 = OpLabel",
  140. "%9 = OpIMul %int %int_0 %int_5",
  141. "OpReturn",
  142. "OpFunctionEnd"
  143. // clang-format on
  144. };
  145. auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
  146. JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
  147. EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
  148. }
  149. // Test to make sure we do not replace a multiple of 5 and 7.
  150. TEST_F(StrengthReductionBasicTest, BasicNoChange) {
  151. const std::vector<const char*> text = {
  152. // clang-format off
  153. "OpCapability Shader",
  154. "%1 = OpExtInstImport \"GLSL.std.450\"",
  155. "OpMemoryModel Logical GLSL450",
  156. "OpEntryPoint Vertex %2 \"main\"",
  157. "OpName %2 \"main\"",
  158. "%3 = OpTypeVoid",
  159. "%4 = OpTypeFunction %3",
  160. "%5 = OpTypeInt 32 1",
  161. "%6 = OpTypeInt 32 0",
  162. "%7 = OpConstant %5 5",
  163. "%8 = OpConstant %5 7",
  164. "%2 = OpFunction %3 None %4",
  165. "%9 = OpLabel",
  166. "%10 = OpIMul %5 %7 %8",
  167. "OpReturn",
  168. "OpFunctionEnd",
  169. // clang-format on
  170. };
  171. auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
  172. JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
  173. EXPECT_EQ(Pass::Status::SuccessWithoutChange, std::get<1>(result));
  174. }
  175. // Test to make sure constants and types are reused and not duplicated.
  176. TEST_F(StrengthReductionBasicTest, NoDuplicateConstantsAndTypes) {
  177. const std::vector<const char*> text = {
  178. // clang-format off
  179. "OpCapability Shader",
  180. "%1 = OpExtInstImport \"GLSL.std.450\"",
  181. "OpMemoryModel Logical GLSL450",
  182. "OpEntryPoint Vertex %main \"main\"",
  183. "OpName %main \"main\"",
  184. "%void = OpTypeVoid",
  185. "%4 = OpTypeFunction %void",
  186. "%uint = OpTypeInt 32 0",
  187. "%uint_8 = OpConstant %uint 8",
  188. "%uint_3 = OpConstant %uint 3",
  189. "%main = OpFunction %void None %4",
  190. "%8 = OpLabel",
  191. "%9 = OpIMul %uint %uint_8 %uint_3",
  192. "OpReturn",
  193. "OpFunctionEnd",
  194. // clang-format on
  195. };
  196. auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
  197. JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
  198. EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  199. const std::string& output = std::get<0>(result);
  200. EXPECT_THAT(output,
  201. Not(MatchesRegex(".*OpConstant %uint 3.*OpConstant %uint 3.*")));
  202. EXPECT_THAT(output, Not(MatchesRegex(".*OpTypeInt 32 0.*OpTypeInt 32 0.*")));
  203. }
  204. // Test to make sure we generate the constants only once
  205. TEST_F(StrengthReductionBasicTest, BasicCreateOneConst) {
  206. const std::vector<const char*> text = {
  207. // clang-format off
  208. "OpCapability Shader",
  209. "%1 = OpExtInstImport \"GLSL.std.450\"",
  210. "OpMemoryModel Logical GLSL450",
  211. "OpEntryPoint Vertex %main \"main\"",
  212. "OpName %main \"main\"",
  213. "%void = OpTypeVoid",
  214. "%4 = OpTypeFunction %void",
  215. "%uint = OpTypeInt 32 0",
  216. "%uint_5 = OpConstant %uint 5",
  217. "%uint_9 = OpConstant %uint 9",
  218. "%uint_128 = OpConstant %uint 128",
  219. "%main = OpFunction %void None %4",
  220. "%8 = OpLabel",
  221. "%9 = OpIMul %uint %uint_5 %uint_128",
  222. "%10 = OpIMul %uint %uint_9 %uint_128",
  223. "OpReturn",
  224. "OpFunctionEnd"
  225. // clang-format on
  226. };
  227. auto result = SinglePassRunAndDisassemble<StrengthReductionPass>(
  228. JoinAllInsts(text), /* skip_nop = */ true, /* do_validation = */ false);
  229. EXPECT_EQ(Pass::Status::SuccessWithChange, std::get<1>(result));
  230. const std::string& output = std::get<0>(result);
  231. EXPECT_THAT(output, Not(HasSubstr("OpIMul")));
  232. EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_5 %uint_7"));
  233. EXPECT_THAT(output, HasSubstr("OpShiftLeftLogical %uint %uint_9 %uint_7"));
  234. }
  235. // Test to make sure we generate the instructions in the correct position and
  236. // that the uses get replaced as well. Here we check that the use in the return
  237. // is replaced, we also check that we can replace two OpIMuls when one feeds the
  238. // other.
  239. TEST_F(StrengthReductionBasicTest, BasicCheckPositionAndReplacement) {
  240. // This is just the preamble to set up the test.
  241. const std::vector<const char*> common_text = {
  242. // clang-format off
  243. "OpCapability Shader",
  244. "%1 = OpExtInstImport \"GLSL.std.450\"",
  245. "OpMemoryModel Logical GLSL450",
  246. "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
  247. "OpExecutionMode %main OriginUpperLeft",
  248. "OpName %main \"main\"",
  249. "OpName %foo_i1_ \"foo(i1;\"",
  250. "OpName %n \"n\"",
  251. "OpName %gl_FragColor \"gl_FragColor\"",
  252. "OpName %param \"param\"",
  253. "OpDecorate %gl_FragColor Location 0",
  254. "%void = OpTypeVoid",
  255. "%3 = OpTypeFunction %void",
  256. "%int = OpTypeInt 32 1",
  257. "%_ptr_Function_int = OpTypePointer Function %int",
  258. "%8 = OpTypeFunction %int %_ptr_Function_int",
  259. "%int_256 = OpConstant %int 256",
  260. "%int_2 = OpConstant %int 2",
  261. "%float = OpTypeFloat 32",
  262. "%v4float = OpTypeVector %float 4",
  263. "%_ptr_Output_v4float = OpTypePointer Output %v4float",
  264. "%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
  265. "%float_1 = OpConstant %float 1",
  266. "%int_10 = OpConstant %int 10",
  267. "%float_0_375 = OpConstant %float 0.375",
  268. "%float_0_75 = OpConstant %float 0.75",
  269. "%uint = OpTypeInt 32 0",
  270. "%uint_8 = OpConstant %uint 8",
  271. "%uint_1 = OpConstant %uint 1",
  272. "%main = OpFunction %void None %3",
  273. "%5 = OpLabel",
  274. "%param = OpVariable %_ptr_Function_int Function",
  275. "OpStore %param %int_10",
  276. "%26 = OpFunctionCall %int %foo_i1_ %param",
  277. "%27 = OpConvertSToF %float %26",
  278. "%28 = OpFDiv %float %float_1 %27",
  279. "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
  280. "OpStore %gl_FragColor %31",
  281. "OpReturn",
  282. "OpFunctionEnd"
  283. // clang-format on
  284. };
  285. // This is the real test. The two OpIMul should be replaced. The expected
  286. // output is in |foo_after|.
  287. const std::vector<const char*> foo_before = {
  288. // clang-format off
  289. "%foo_i1_ = OpFunction %int None %8",
  290. "%n = OpFunctionParameter %_ptr_Function_int",
  291. "%11 = OpLabel",
  292. "%12 = OpLoad %int %n",
  293. "%14 = OpIMul %int %12 %int_256",
  294. "%16 = OpIMul %int %14 %int_2",
  295. "OpReturnValue %16",
  296. "OpFunctionEnd",
  297. // clang-format on
  298. };
  299. const std::vector<const char*> foo_after = {
  300. // clang-format off
  301. "%foo_i1_ = OpFunction %int None %8",
  302. "%n = OpFunctionParameter %_ptr_Function_int",
  303. "%11 = OpLabel",
  304. "%12 = OpLoad %int %n",
  305. "%33 = OpShiftLeftLogical %int %12 %uint_8",
  306. "%34 = OpShiftLeftLogical %int %33 %uint_1",
  307. "OpReturnValue %34",
  308. "OpFunctionEnd",
  309. // clang-format on
  310. };
  311. SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  312. SinglePassRunAndCheck<StrengthReductionPass>(
  313. JoinAllInsts(Concat(common_text, foo_before)),
  314. JoinAllInsts(Concat(common_text, foo_after)),
  315. /* skip_nop = */ true, /* do_validate = */ true);
  316. }
  317. // Test that, when the result of an OpIMul instruction has more than 1 use, and
  318. // the instruction is replaced, all of the uses of the results are replace with
  319. // the new result.
  320. TEST_F(StrengthReductionBasicTest, BasicTestMultipleReplacements) {
  321. // This is just the preamble to set up the test.
  322. const std::vector<const char*> common_text = {
  323. // clang-format off
  324. "OpCapability Shader",
  325. "%1 = OpExtInstImport \"GLSL.std.450\"",
  326. "OpMemoryModel Logical GLSL450",
  327. "OpEntryPoint Fragment %main \"main\" %gl_FragColor",
  328. "OpExecutionMode %main OriginUpperLeft",
  329. "OpName %main \"main\"",
  330. "OpName %foo_i1_ \"foo(i1;\"",
  331. "OpName %n \"n\"",
  332. "OpName %gl_FragColor \"gl_FragColor\"",
  333. "OpName %param \"param\"",
  334. "OpDecorate %gl_FragColor Location 0",
  335. "%void = OpTypeVoid",
  336. "%3 = OpTypeFunction %void",
  337. "%int = OpTypeInt 32 1",
  338. "%_ptr_Function_int = OpTypePointer Function %int",
  339. "%8 = OpTypeFunction %int %_ptr_Function_int",
  340. "%int_256 = OpConstant %int 256",
  341. "%int_2 = OpConstant %int 2",
  342. "%float = OpTypeFloat 32",
  343. "%v4float = OpTypeVector %float 4",
  344. "%_ptr_Output_v4float = OpTypePointer Output %v4float",
  345. "%gl_FragColor = OpVariable %_ptr_Output_v4float Output",
  346. "%float_1 = OpConstant %float 1",
  347. "%int_10 = OpConstant %int 10",
  348. "%float_0_375 = OpConstant %float 0.375",
  349. "%float_0_75 = OpConstant %float 0.75",
  350. "%uint = OpTypeInt 32 0",
  351. "%uint_8 = OpConstant %uint 8",
  352. "%uint_1 = OpConstant %uint 1",
  353. "%main = OpFunction %void None %3",
  354. "%5 = OpLabel",
  355. "%param = OpVariable %_ptr_Function_int Function",
  356. "OpStore %param %int_10",
  357. "%26 = OpFunctionCall %int %foo_i1_ %param",
  358. "%27 = OpConvertSToF %float %26",
  359. "%28 = OpFDiv %float %float_1 %27",
  360. "%31 = OpCompositeConstruct %v4float %28 %float_0_375 %float_0_75 %float_1",
  361. "OpStore %gl_FragColor %31",
  362. "OpReturn",
  363. "OpFunctionEnd"
  364. // clang-format on
  365. };
  366. // This is the real test. The two OpIMul instructions should be replaced. In
  367. // particular, we want to be sure that both uses of %16 are changed to use the
  368. // new result.
  369. const std::vector<const char*> foo_before = {
  370. // clang-format off
  371. "%foo_i1_ = OpFunction %int None %8",
  372. "%n = OpFunctionParameter %_ptr_Function_int",
  373. "%11 = OpLabel",
  374. "%12 = OpLoad %int %n",
  375. "%14 = OpIMul %int %12 %int_256",
  376. "%16 = OpIMul %int %14 %int_2",
  377. "%17 = OpIAdd %int %14 %16",
  378. "OpReturnValue %17",
  379. "OpFunctionEnd",
  380. // clang-format on
  381. };
  382. const std::vector<const char*> foo_after = {
  383. // clang-format off
  384. "%foo_i1_ = OpFunction %int None %8",
  385. "%n = OpFunctionParameter %_ptr_Function_int",
  386. "%11 = OpLabel",
  387. "%12 = OpLoad %int %n",
  388. "%34 = OpShiftLeftLogical %int %12 %uint_8",
  389. "%35 = OpShiftLeftLogical %int %34 %uint_1",
  390. "%17 = OpIAdd %int %34 %35",
  391. "OpReturnValue %17",
  392. "OpFunctionEnd",
  393. // clang-format on
  394. };
  395. SetAssembleOptions(SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  396. SinglePassRunAndCheck<StrengthReductionPass>(
  397. JoinAllInsts(Concat(common_text, foo_before)),
  398. JoinAllInsts(Concat(common_text, foo_after)),
  399. /* skip_nop = */ true, /* do_validate = */ true);
  400. }
  401. } // namespace
  402. } // namespace opt
  403. } // namespace spvtools