scalar_analysis.cpp 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222
  1. // Copyright (c) 2018 Google LLC.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. #include <memory>
  15. #include <string>
  16. #include <unordered_set>
  17. #include <vector>
  18. #include "gmock/gmock.h"
  19. #include "source/opt/iterator.h"
  20. #include "source/opt/loop_descriptor.h"
  21. #include "source/opt/pass.h"
  22. #include "source/opt/scalar_analysis.h"
  23. #include "source/opt/tree_iterator.h"
  24. #include "test/opt/assembly_builder.h"
  25. #include "test/opt/function_utils.h"
  26. #include "test/opt/pass_fixture.h"
  27. #include "test/opt/pass_utils.h"
  28. namespace spvtools {
  29. namespace opt {
  30. namespace {
  31. using ::testing::UnorderedElementsAre;
  32. using ScalarAnalysisTest = PassTest<::testing::Test>;
  33. /*
  34. Generated from the following GLSL + --eliminate-local-multi-store
  35. #version 410 core
  36. layout (location = 1) out float array[10];
  37. void main() {
  38. for (int i = 0; i < 10; ++i) {
  39. array[i] = array[i+1];
  40. }
  41. }
  42. */
  43. TEST_F(ScalarAnalysisTest, BasicEvolutionTest) {
  44. const std::string text = R"(
  45. OpCapability Shader
  46. %1 = OpExtInstImport "GLSL.std.450"
  47. OpMemoryModel Logical GLSL450
  48. OpEntryPoint Fragment %4 "main" %24
  49. OpExecutionMode %4 OriginUpperLeft
  50. OpSource GLSL 410
  51. OpName %4 "main"
  52. OpName %24 "array"
  53. OpDecorate %24 Location 1
  54. %2 = OpTypeVoid
  55. %3 = OpTypeFunction %2
  56. %6 = OpTypeInt 32 1
  57. %7 = OpTypePointer Function %6
  58. %9 = OpConstant %6 0
  59. %16 = OpConstant %6 10
  60. %17 = OpTypeBool
  61. %19 = OpTypeFloat 32
  62. %20 = OpTypeInt 32 0
  63. %21 = OpConstant %20 10
  64. %22 = OpTypeArray %19 %21
  65. %23 = OpTypePointer Output %22
  66. %24 = OpVariable %23 Output
  67. %27 = OpConstant %6 1
  68. %29 = OpTypePointer Output %19
  69. %4 = OpFunction %2 None %3
  70. %5 = OpLabel
  71. OpBranch %10
  72. %10 = OpLabel
  73. %35 = OpPhi %6 %9 %5 %34 %13
  74. OpLoopMerge %12 %13 None
  75. OpBranch %14
  76. %14 = OpLabel
  77. %18 = OpSLessThan %17 %35 %16
  78. OpBranchConditional %18 %11 %12
  79. %11 = OpLabel
  80. %28 = OpIAdd %6 %35 %27
  81. %30 = OpAccessChain %29 %24 %28
  82. %31 = OpLoad %19 %30
  83. %32 = OpAccessChain %29 %24 %35
  84. OpStore %32 %31
  85. OpBranch %13
  86. %13 = OpLabel
  87. %34 = OpIAdd %6 %35 %27
  88. OpBranch %10
  89. %12 = OpLabel
  90. OpReturn
  91. OpFunctionEnd
  92. )";
  93. // clang-format on
  94. std::unique_ptr<IRContext> context =
  95. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  96. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  97. Module* module = context->module();
  98. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  99. << text << std::endl;
  100. const Function* f = spvtest::GetFunction(module, 4);
  101. ScalarEvolutionAnalysis analysis{context.get()};
  102. const Instruction* store = nullptr;
  103. const Instruction* load = nullptr;
  104. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 11)) {
  105. if (inst.opcode() == SpvOp::SpvOpStore) {
  106. store = &inst;
  107. }
  108. if (inst.opcode() == SpvOp::SpvOpLoad) {
  109. load = &inst;
  110. }
  111. }
  112. EXPECT_NE(load, nullptr);
  113. EXPECT_NE(store, nullptr);
  114. Instruction* access_chain =
  115. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  116. Instruction* child = context->get_def_use_mgr()->GetDef(
  117. access_chain->GetSingleWordInOperand(1));
  118. const SENode* node = analysis.AnalyzeInstruction(child);
  119. EXPECT_NE(node, nullptr);
  120. // Unsimplified node should have the form of ADD(REC(0,1), 1)
  121. EXPECT_EQ(node->GetType(), SENode::Add);
  122. const SENode* child_1 = node->GetChild(0);
  123. EXPECT_TRUE(child_1->GetType() == SENode::Constant ||
  124. child_1->GetType() == SENode::RecurrentAddExpr);
  125. const SENode* child_2 = node->GetChild(1);
  126. EXPECT_TRUE(child_2->GetType() == SENode::Constant ||
  127. child_2->GetType() == SENode::RecurrentAddExpr);
  128. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  129. // Simplified should be in the form of REC(1,1)
  130. EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
  131. EXPECT_EQ(simplified->GetChild(0)->GetType(), SENode::Constant);
  132. EXPECT_EQ(simplified->GetChild(0)->AsSEConstantNode()->FoldToSingleValue(),
  133. 1);
  134. EXPECT_EQ(simplified->GetChild(1)->GetType(), SENode::Constant);
  135. EXPECT_EQ(simplified->GetChild(1)->AsSEConstantNode()->FoldToSingleValue(),
  136. 1);
  137. EXPECT_EQ(simplified->GetChild(0), simplified->GetChild(1));
  138. }
  139. /*
  140. Generated from the following GLSL + --eliminate-local-multi-store
  141. #version 410 core
  142. layout (location = 1) out float array[10];
  143. layout (location = 2) flat in int loop_invariant;
  144. void main() {
  145. for (int i = 0; i < 10; ++i) {
  146. array[i] = array[i+loop_invariant];
  147. }
  148. }
  149. */
  150. TEST_F(ScalarAnalysisTest, LoadTest) {
  151. const std::string text = R"(
  152. OpCapability Shader
  153. %1 = OpExtInstImport "GLSL.std.450"
  154. OpMemoryModel Logical GLSL450
  155. OpEntryPoint Fragment %2 "main" %3 %4
  156. OpExecutionMode %2 OriginUpperLeft
  157. OpSource GLSL 430
  158. OpName %2 "main"
  159. OpName %3 "array"
  160. OpName %4 "loop_invariant"
  161. OpDecorate %3 Location 1
  162. OpDecorate %4 Flat
  163. OpDecorate %4 Location 2
  164. %5 = OpTypeVoid
  165. %6 = OpTypeFunction %5
  166. %7 = OpTypeInt 32 1
  167. %8 = OpTypePointer Function %7
  168. %9 = OpConstant %7 0
  169. %10 = OpConstant %7 10
  170. %11 = OpTypeBool
  171. %12 = OpTypeFloat 32
  172. %13 = OpTypeInt 32 0
  173. %14 = OpConstant %13 10
  174. %15 = OpTypeArray %12 %14
  175. %16 = OpTypePointer Output %15
  176. %3 = OpVariable %16 Output
  177. %17 = OpTypePointer Input %7
  178. %4 = OpVariable %17 Input
  179. %18 = OpTypePointer Output %12
  180. %19 = OpConstant %7 1
  181. %2 = OpFunction %5 None %6
  182. %20 = OpLabel
  183. OpBranch %21
  184. %21 = OpLabel
  185. %22 = OpPhi %7 %9 %20 %23 %24
  186. OpLoopMerge %25 %24 None
  187. OpBranch %26
  188. %26 = OpLabel
  189. %27 = OpSLessThan %11 %22 %10
  190. OpBranchConditional %27 %28 %25
  191. %28 = OpLabel
  192. %29 = OpLoad %7 %4
  193. %30 = OpIAdd %7 %22 %29
  194. %31 = OpAccessChain %18 %3 %30
  195. %32 = OpLoad %12 %31
  196. %33 = OpAccessChain %18 %3 %22
  197. OpStore %33 %32
  198. OpBranch %24
  199. %24 = OpLabel
  200. %23 = OpIAdd %7 %22 %19
  201. OpBranch %21
  202. %25 = OpLabel
  203. OpReturn
  204. OpFunctionEnd
  205. )";
  206. // clang-format on
  207. std::unique_ptr<IRContext> context =
  208. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  209. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  210. Module* module = context->module();
  211. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  212. << text << std::endl;
  213. const Function* f = spvtest::GetFunction(module, 2);
  214. ScalarEvolutionAnalysis analysis{context.get()};
  215. const Instruction* load = nullptr;
  216. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 28)) {
  217. if (inst.opcode() == SpvOp::SpvOpLoad) {
  218. load = &inst;
  219. }
  220. }
  221. EXPECT_NE(load, nullptr);
  222. Instruction* access_chain =
  223. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  224. Instruction* child = context->get_def_use_mgr()->GetDef(
  225. access_chain->GetSingleWordInOperand(1));
  226. // const SENode* node =
  227. // analysis.GetNodeFromInstruction(child->unique_id());
  228. const SENode* node = analysis.AnalyzeInstruction(child);
  229. EXPECT_NE(node, nullptr);
  230. // Unsimplified node should have the form of ADD(REC(0,1), X)
  231. EXPECT_EQ(node->GetType(), SENode::Add);
  232. const SENode* child_1 = node->GetChild(0);
  233. EXPECT_TRUE(child_1->GetType() == SENode::ValueUnknown ||
  234. child_1->GetType() == SENode::RecurrentAddExpr);
  235. const SENode* child_2 = node->GetChild(1);
  236. EXPECT_TRUE(child_2->GetType() == SENode::ValueUnknown ||
  237. child_2->GetType() == SENode::RecurrentAddExpr);
  238. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  239. EXPECT_EQ(simplified->GetType(), SENode::RecurrentAddExpr);
  240. const SERecurrentNode* rec = simplified->AsSERecurrentNode();
  241. EXPECT_NE(rec->GetChild(0), rec->GetChild(1));
  242. EXPECT_EQ(rec->GetOffset()->GetType(), SENode::ValueUnknown);
  243. EXPECT_EQ(rec->GetCoefficient()->GetType(), SENode::Constant);
  244. EXPECT_EQ(rec->GetCoefficient()->AsSEConstantNode()->FoldToSingleValue(), 1u);
  245. }
  246. /*
  247. Generated from the following GLSL + --eliminate-local-multi-store
  248. #version 410 core
  249. layout (location = 1) out float array[10];
  250. layout (location = 2) flat in int loop_invariant;
  251. void main() {
  252. array[0] = array[loop_invariant * 2 + 4 + 5 - 24 - loop_invariant -
  253. loop_invariant+ 16 * 3];
  254. }
  255. */
  256. TEST_F(ScalarAnalysisTest, SimplifySimple) {
  257. const std::string text = R"(
  258. OpCapability Shader
  259. %1 = OpExtInstImport "GLSL.std.450"
  260. OpMemoryModel Logical GLSL450
  261. OpEntryPoint Fragment %2 "main" %3 %4
  262. OpExecutionMode %2 OriginUpperLeft
  263. OpSource GLSL 430
  264. OpName %2 "main"
  265. OpName %3 "array"
  266. OpName %4 "loop_invariant"
  267. OpDecorate %3 Location 1
  268. OpDecorate %4 Flat
  269. OpDecorate %4 Location 2
  270. %5 = OpTypeVoid
  271. %6 = OpTypeFunction %5
  272. %7 = OpTypeFloat 32
  273. %8 = OpTypeInt 32 0
  274. %9 = OpConstant %8 10
  275. %10 = OpTypeArray %7 %9
  276. %11 = OpTypePointer Output %10
  277. %3 = OpVariable %11 Output
  278. %12 = OpTypeInt 32 1
  279. %13 = OpConstant %12 0
  280. %14 = OpTypePointer Input %12
  281. %4 = OpVariable %14 Input
  282. %15 = OpConstant %12 2
  283. %16 = OpConstant %12 4
  284. %17 = OpConstant %12 5
  285. %18 = OpConstant %12 24
  286. %19 = OpConstant %12 48
  287. %20 = OpTypePointer Output %7
  288. %2 = OpFunction %5 None %6
  289. %21 = OpLabel
  290. %22 = OpLoad %12 %4
  291. %23 = OpIMul %12 %22 %15
  292. %24 = OpIAdd %12 %23 %16
  293. %25 = OpIAdd %12 %24 %17
  294. %26 = OpISub %12 %25 %18
  295. %28 = OpISub %12 %26 %22
  296. %30 = OpISub %12 %28 %22
  297. %31 = OpIAdd %12 %30 %19
  298. %32 = OpAccessChain %20 %3 %31
  299. %33 = OpLoad %7 %32
  300. %34 = OpAccessChain %20 %3 %13
  301. OpStore %34 %33
  302. OpReturn
  303. OpFunctionEnd
  304. )";
  305. // clang-format on
  306. std::unique_ptr<IRContext> context =
  307. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  308. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  309. Module* module = context->module();
  310. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  311. << text << std::endl;
  312. const Function* f = spvtest::GetFunction(module, 2);
  313. ScalarEvolutionAnalysis analysis{context.get()};
  314. const Instruction* load = nullptr;
  315. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
  316. if (inst.opcode() == SpvOp::SpvOpLoad && inst.result_id() == 33) {
  317. load = &inst;
  318. }
  319. }
  320. EXPECT_NE(load, nullptr);
  321. Instruction* access_chain =
  322. context->get_def_use_mgr()->GetDef(load->GetSingleWordInOperand(0));
  323. Instruction* child = context->get_def_use_mgr()->GetDef(
  324. access_chain->GetSingleWordInOperand(1));
  325. const SENode* node = analysis.AnalyzeInstruction(child);
  326. // Unsimplified is a very large graph with an add at the top.
  327. EXPECT_NE(node, nullptr);
  328. EXPECT_EQ(node->GetType(), SENode::Add);
  329. // Simplified node should resolve down to a constant expression as the loads
  330. // will eliminate themselves.
  331. SENode* simplified = analysis.SimplifyExpression(const_cast<SENode*>(node));
  332. EXPECT_EQ(simplified->GetType(), SENode::Constant);
  333. EXPECT_EQ(simplified->AsSEConstantNode()->FoldToSingleValue(), 33u);
  334. }
  335. /*
  336. Generated from the following GLSL + --eliminate-local-multi-store
  337. #version 410 core
  338. layout(location = 0) in vec4 c;
  339. layout (location = 1) out float array[10];
  340. void main() {
  341. int N = int(c.x);
  342. for (int i = 0; i < 10; ++i) {
  343. array[i] = array[i];
  344. array[i] = array[i-1];
  345. array[i] = array[i+1];
  346. array[i+1] = array[i+1];
  347. array[i+N] = array[i+N];
  348. array[i] = array[i+N];
  349. }
  350. }
  351. */
  352. TEST_F(ScalarAnalysisTest, Simplify) {
  353. const std::string text = R"( OpCapability Shader
  354. %1 = OpExtInstImport "GLSL.std.450"
  355. OpMemoryModel Logical GLSL450
  356. OpEntryPoint Fragment %4 "main" %12 %33
  357. OpExecutionMode %4 OriginUpperLeft
  358. OpSource GLSL 410
  359. OpName %4 "main"
  360. OpName %8 "N"
  361. OpName %12 "c"
  362. OpName %19 "i"
  363. OpName %33 "array"
  364. OpDecorate %12 Location 0
  365. OpDecorate %33 Location 1
  366. %2 = OpTypeVoid
  367. %3 = OpTypeFunction %2
  368. %6 = OpTypeInt 32 1
  369. %7 = OpTypePointer Function %6
  370. %9 = OpTypeFloat 32
  371. %10 = OpTypeVector %9 4
  372. %11 = OpTypePointer Input %10
  373. %12 = OpVariable %11 Input
  374. %13 = OpTypeInt 32 0
  375. %14 = OpConstant %13 0
  376. %15 = OpTypePointer Input %9
  377. %20 = OpConstant %6 0
  378. %27 = OpConstant %6 10
  379. %28 = OpTypeBool
  380. %30 = OpConstant %13 10
  381. %31 = OpTypeArray %9 %30
  382. %32 = OpTypePointer Output %31
  383. %33 = OpVariable %32 Output
  384. %36 = OpTypePointer Output %9
  385. %42 = OpConstant %6 1
  386. %4 = OpFunction %2 None %3
  387. %5 = OpLabel
  388. %8 = OpVariable %7 Function
  389. %19 = OpVariable %7 Function
  390. %16 = OpAccessChain %15 %12 %14
  391. %17 = OpLoad %9 %16
  392. %18 = OpConvertFToS %6 %17
  393. OpStore %8 %18
  394. OpStore %19 %20
  395. OpBranch %21
  396. %21 = OpLabel
  397. %78 = OpPhi %6 %20 %5 %77 %24
  398. OpLoopMerge %23 %24 None
  399. OpBranch %25
  400. %25 = OpLabel
  401. %29 = OpSLessThan %28 %78 %27
  402. OpBranchConditional %29 %22 %23
  403. %22 = OpLabel
  404. %37 = OpAccessChain %36 %33 %78
  405. %38 = OpLoad %9 %37
  406. %39 = OpAccessChain %36 %33 %78
  407. OpStore %39 %38
  408. %43 = OpISub %6 %78 %42
  409. %44 = OpAccessChain %36 %33 %43
  410. %45 = OpLoad %9 %44
  411. %46 = OpAccessChain %36 %33 %78
  412. OpStore %46 %45
  413. %49 = OpIAdd %6 %78 %42
  414. %50 = OpAccessChain %36 %33 %49
  415. %51 = OpLoad %9 %50
  416. %52 = OpAccessChain %36 %33 %78
  417. OpStore %52 %51
  418. %54 = OpIAdd %6 %78 %42
  419. %56 = OpIAdd %6 %78 %42
  420. %57 = OpAccessChain %36 %33 %56
  421. %58 = OpLoad %9 %57
  422. %59 = OpAccessChain %36 %33 %54
  423. OpStore %59 %58
  424. %62 = OpIAdd %6 %78 %18
  425. %65 = OpIAdd %6 %78 %18
  426. %66 = OpAccessChain %36 %33 %65
  427. %67 = OpLoad %9 %66
  428. %68 = OpAccessChain %36 %33 %62
  429. OpStore %68 %67
  430. %72 = OpIAdd %6 %78 %18
  431. %73 = OpAccessChain %36 %33 %72
  432. %74 = OpLoad %9 %73
  433. %75 = OpAccessChain %36 %33 %78
  434. OpStore %75 %74
  435. OpBranch %24
  436. %24 = OpLabel
  437. %77 = OpIAdd %6 %78 %42
  438. OpStore %19 %77
  439. OpBranch %21
  440. %23 = OpLabel
  441. OpReturn
  442. OpFunctionEnd
  443. )";
  444. // clang-format on
  445. std::unique_ptr<IRContext> context =
  446. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  447. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  448. Module* module = context->module();
  449. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  450. << text << std::endl;
  451. const Function* f = spvtest::GetFunction(module, 4);
  452. ScalarEvolutionAnalysis analysis{context.get()};
  453. const Instruction* loads[6];
  454. const Instruction* stores[6];
  455. int load_count = 0;
  456. int store_count = 0;
  457. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 22)) {
  458. if (inst.opcode() == SpvOp::SpvOpLoad) {
  459. loads[load_count] = &inst;
  460. ++load_count;
  461. }
  462. if (inst.opcode() == SpvOp::SpvOpStore) {
  463. stores[store_count] = &inst;
  464. ++store_count;
  465. }
  466. }
  467. EXPECT_EQ(load_count, 6);
  468. EXPECT_EQ(store_count, 6);
  469. Instruction* load_access_chain;
  470. Instruction* store_access_chain;
  471. Instruction* load_child;
  472. Instruction* store_child;
  473. SENode* load_node;
  474. SENode* store_node;
  475. SENode* subtract_node;
  476. SENode* simplified_node;
  477. // Testing [i] - [i] == 0
  478. load_access_chain =
  479. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  480. store_access_chain =
  481. context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
  482. load_child = context->get_def_use_mgr()->GetDef(
  483. load_access_chain->GetSingleWordInOperand(1));
  484. store_child = context->get_def_use_mgr()->GetDef(
  485. store_access_chain->GetSingleWordInOperand(1));
  486. load_node = analysis.AnalyzeInstruction(load_child);
  487. store_node = analysis.AnalyzeInstruction(store_child);
  488. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  489. simplified_node = analysis.SimplifyExpression(subtract_node);
  490. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  491. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  492. // Testing [i] - [i-1] == 1
  493. load_access_chain =
  494. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  495. store_access_chain =
  496. context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  497. load_child = context->get_def_use_mgr()->GetDef(
  498. load_access_chain->GetSingleWordInOperand(1));
  499. store_child = context->get_def_use_mgr()->GetDef(
  500. store_access_chain->GetSingleWordInOperand(1));
  501. load_node = analysis.AnalyzeInstruction(load_child);
  502. store_node = analysis.AnalyzeInstruction(store_child);
  503. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  504. simplified_node = analysis.SimplifyExpression(subtract_node);
  505. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  506. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 1u);
  507. // Testing [i] - [i+1] == -1
  508. load_access_chain =
  509. context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  510. store_access_chain =
  511. context->get_def_use_mgr()->GetDef(stores[2]->GetSingleWordInOperand(0));
  512. load_child = context->get_def_use_mgr()->GetDef(
  513. load_access_chain->GetSingleWordInOperand(1));
  514. store_child = context->get_def_use_mgr()->GetDef(
  515. store_access_chain->GetSingleWordInOperand(1));
  516. load_node = analysis.AnalyzeInstruction(load_child);
  517. store_node = analysis.AnalyzeInstruction(store_child);
  518. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  519. simplified_node = analysis.SimplifyExpression(subtract_node);
  520. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  521. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), -1);
  522. // Testing [i+1] - [i+1] == 0
  523. load_access_chain =
  524. context->get_def_use_mgr()->GetDef(loads[3]->GetSingleWordInOperand(0));
  525. store_access_chain =
  526. context->get_def_use_mgr()->GetDef(stores[3]->GetSingleWordInOperand(0));
  527. load_child = context->get_def_use_mgr()->GetDef(
  528. load_access_chain->GetSingleWordInOperand(1));
  529. store_child = context->get_def_use_mgr()->GetDef(
  530. store_access_chain->GetSingleWordInOperand(1));
  531. load_node = analysis.AnalyzeInstruction(load_child);
  532. store_node = analysis.AnalyzeInstruction(store_child);
  533. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  534. simplified_node = analysis.SimplifyExpression(subtract_node);
  535. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  536. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  537. // Testing [i+N] - [i+N] == 0
  538. load_access_chain =
  539. context->get_def_use_mgr()->GetDef(loads[4]->GetSingleWordInOperand(0));
  540. store_access_chain =
  541. context->get_def_use_mgr()->GetDef(stores[4]->GetSingleWordInOperand(0));
  542. load_child = context->get_def_use_mgr()->GetDef(
  543. load_access_chain->GetSingleWordInOperand(1));
  544. store_child = context->get_def_use_mgr()->GetDef(
  545. store_access_chain->GetSingleWordInOperand(1));
  546. load_node = analysis.AnalyzeInstruction(load_child);
  547. store_node = analysis.AnalyzeInstruction(store_child);
  548. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  549. simplified_node = analysis.SimplifyExpression(subtract_node);
  550. EXPECT_EQ(simplified_node->GetType(), SENode::Constant);
  551. EXPECT_EQ(simplified_node->AsSEConstantNode()->FoldToSingleValue(), 0u);
  552. // Testing [i] - [i+N] == -N
  553. load_access_chain =
  554. context->get_def_use_mgr()->GetDef(loads[5]->GetSingleWordInOperand(0));
  555. store_access_chain =
  556. context->get_def_use_mgr()->GetDef(stores[5]->GetSingleWordInOperand(0));
  557. load_child = context->get_def_use_mgr()->GetDef(
  558. load_access_chain->GetSingleWordInOperand(1));
  559. store_child = context->get_def_use_mgr()->GetDef(
  560. store_access_chain->GetSingleWordInOperand(1));
  561. load_node = analysis.AnalyzeInstruction(load_child);
  562. store_node = analysis.AnalyzeInstruction(store_child);
  563. subtract_node = analysis.CreateSubtraction(store_node, load_node);
  564. simplified_node = analysis.SimplifyExpression(subtract_node);
  565. EXPECT_EQ(simplified_node->GetType(), SENode::Negative);
  566. }
  567. /*
  568. Generated from the following GLSL + --eliminate-local-multi-store
  569. #version 430
  570. layout(location = 1) out float array[10];
  571. layout(location = 2) flat in int loop_invariant;
  572. void main(void) {
  573. for (int i = 0; i < 10; ++i) {
  574. array[i * 2 + i * 5] = array[i * i * 2];
  575. array[i * 2] = array[i * 5];
  576. }
  577. }
  578. */
  579. TEST_F(ScalarAnalysisTest, SimplifyMultiplyInductions) {
  580. const std::string text = R"(
  581. OpCapability Shader
  582. %1 = OpExtInstImport "GLSL.std.450"
  583. OpMemoryModel Logical GLSL450
  584. OpEntryPoint Fragment %2 "main" %3 %4
  585. OpExecutionMode %2 OriginUpperLeft
  586. OpSource GLSL 430
  587. OpName %2 "main"
  588. OpName %5 "i"
  589. OpName %3 "array"
  590. OpName %4 "loop_invariant"
  591. OpDecorate %3 Location 1
  592. OpDecorate %4 Flat
  593. OpDecorate %4 Location 2
  594. %6 = OpTypeVoid
  595. %7 = OpTypeFunction %6
  596. %8 = OpTypeInt 32 1
  597. %9 = OpTypePointer Function %8
  598. %10 = OpConstant %8 0
  599. %11 = OpConstant %8 10
  600. %12 = OpTypeBool
  601. %13 = OpTypeFloat 32
  602. %14 = OpTypeInt 32 0
  603. %15 = OpConstant %14 10
  604. %16 = OpTypeArray %13 %15
  605. %17 = OpTypePointer Output %16
  606. %3 = OpVariable %17 Output
  607. %18 = OpConstant %8 2
  608. %19 = OpConstant %8 5
  609. %20 = OpTypePointer Output %13
  610. %21 = OpConstant %8 1
  611. %22 = OpTypePointer Input %8
  612. %4 = OpVariable %22 Input
  613. %2 = OpFunction %6 None %7
  614. %23 = OpLabel
  615. %5 = OpVariable %9 Function
  616. OpStore %5 %10
  617. OpBranch %24
  618. %24 = OpLabel
  619. %25 = OpPhi %8 %10 %23 %26 %27
  620. OpLoopMerge %28 %27 None
  621. OpBranch %29
  622. %29 = OpLabel
  623. %30 = OpSLessThan %12 %25 %11
  624. OpBranchConditional %30 %31 %28
  625. %31 = OpLabel
  626. %32 = OpIMul %8 %25 %18
  627. %33 = OpIMul %8 %25 %19
  628. %34 = OpIAdd %8 %32 %33
  629. %35 = OpIMul %8 %25 %25
  630. %36 = OpIMul %8 %35 %18
  631. %37 = OpAccessChain %20 %3 %36
  632. %38 = OpLoad %13 %37
  633. %39 = OpAccessChain %20 %3 %34
  634. OpStore %39 %38
  635. %40 = OpIMul %8 %25 %18
  636. %41 = OpIMul %8 %25 %19
  637. %42 = OpAccessChain %20 %3 %41
  638. %43 = OpLoad %13 %42
  639. %44 = OpAccessChain %20 %3 %40
  640. OpStore %44 %43
  641. OpBranch %27
  642. %27 = OpLabel
  643. %26 = OpIAdd %8 %25 %21
  644. OpStore %5 %26
  645. OpBranch %24
  646. %28 = OpLabel
  647. OpReturn
  648. OpFunctionEnd
  649. )";
  650. std::unique_ptr<IRContext> context =
  651. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  652. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  653. Module* module = context->module();
  654. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  655. << text << std::endl;
  656. const Function* f = spvtest::GetFunction(module, 2);
  657. ScalarEvolutionAnalysis analysis{context.get()};
  658. const Instruction* loads[2] = {nullptr, nullptr};
  659. const Instruction* stores[2] = {nullptr, nullptr};
  660. int load_count = 0;
  661. int store_count = 0;
  662. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 31)) {
  663. if (inst.opcode() == SpvOp::SpvOpLoad) {
  664. loads[load_count] = &inst;
  665. ++load_count;
  666. }
  667. if (inst.opcode() == SpvOp::SpvOpStore) {
  668. stores[store_count] = &inst;
  669. ++store_count;
  670. }
  671. }
  672. EXPECT_EQ(load_count, 2);
  673. EXPECT_EQ(store_count, 2);
  674. Instruction* load_access_chain =
  675. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  676. Instruction* store_access_chain =
  677. context->get_def_use_mgr()->GetDef(stores[0]->GetSingleWordInOperand(0));
  678. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  679. load_access_chain->GetSingleWordInOperand(1));
  680. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  681. store_access_chain->GetSingleWordInOperand(1));
  682. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  683. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  684. load_access_chain =
  685. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  686. store_access_chain =
  687. context->get_def_use_mgr()->GetDef(stores[1]->GetSingleWordInOperand(0));
  688. load_child = context->get_def_use_mgr()->GetDef(
  689. load_access_chain->GetSingleWordInOperand(1));
  690. store_child = context->get_def_use_mgr()->GetDef(
  691. store_access_chain->GetSingleWordInOperand(1));
  692. SENode* second_store =
  693. analysis.SimplifyExpression(analysis.AnalyzeInstruction(store_child));
  694. SENode* second_load =
  695. analysis.SimplifyExpression(analysis.AnalyzeInstruction(load_child));
  696. SENode* combined_add = analysis.SimplifyExpression(
  697. analysis.CreateAddNode(second_load, second_store));
  698. // We're checking that the two recurrent expression have been correctly
  699. // folded. In store_simplified they will have been folded as the entire
  700. // expression was simplified as one. In combined_add the two expressions have
  701. // been simplified one after the other which means the recurrent expressions
  702. // aren't exactly the same but should still be folded as they are with respect
  703. // to the same loop.
  704. EXPECT_EQ(combined_add, store_simplified);
  705. }
  /*
  Generated from the following GLSL + --eliminate-local-multi-store
  #version 430
  layout(location = 1) out float array[10];
  layout(location = 2) flat in int loop_invariant;
  void main(void) {
    for (int i = 0; i < 10; --i) {
      array[i] = array[i];
    }
  }
  */
  715. TEST_F(ScalarAnalysisTest, SimplifyNegativeSteps) {
  716. const std::string text = R"(
  717. OpCapability Shader
  718. %1 = OpExtInstImport "GLSL.std.450"
  719. OpMemoryModel Logical GLSL450
  720. OpEntryPoint Fragment %2 "main" %3 %4
  721. OpExecutionMode %2 OriginUpperLeft
  722. OpSource GLSL 430
  723. OpName %2 "main"
  724. OpName %5 "i"
  725. OpName %3 "array"
  726. OpName %4 "loop_invariant"
  727. OpDecorate %3 Location 1
  728. OpDecorate %4 Flat
  729. OpDecorate %4 Location 2
  730. %6 = OpTypeVoid
  731. %7 = OpTypeFunction %6
  732. %8 = OpTypeInt 32 1
  733. %9 = OpTypePointer Function %8
  734. %10 = OpConstant %8 0
  735. %11 = OpConstant %8 10
  736. %12 = OpTypeBool
  737. %13 = OpTypeFloat 32
  738. %14 = OpTypeInt 32 0
  739. %15 = OpConstant %14 10
  740. %16 = OpTypeArray %13 %15
  741. %17 = OpTypePointer Output %16
  742. %3 = OpVariable %17 Output
  743. %18 = OpTypePointer Output %13
  744. %19 = OpConstant %8 1
  745. %20 = OpTypePointer Input %8
  746. %4 = OpVariable %20 Input
  747. %2 = OpFunction %6 None %7
  748. %21 = OpLabel
  749. %5 = OpVariable %9 Function
  750. OpStore %5 %10
  751. OpBranch %22
  752. %22 = OpLabel
  753. %23 = OpPhi %8 %10 %21 %24 %25
  754. OpLoopMerge %26 %25 None
  755. OpBranch %27
  756. %27 = OpLabel
  757. %28 = OpSLessThan %12 %23 %11
  758. OpBranchConditional %28 %29 %26
  759. %29 = OpLabel
  760. %30 = OpAccessChain %18 %3 %23
  761. %31 = OpLoad %13 %30
  762. %32 = OpAccessChain %18 %3 %23
  763. OpStore %32 %31
  764. OpBranch %25
  765. %25 = OpLabel
  766. %24 = OpISub %8 %23 %19
  767. OpStore %5 %24
  768. OpBranch %22
  769. %26 = OpLabel
  770. OpReturn
  771. OpFunctionEnd
  772. )";
  773. std::unique_ptr<IRContext> context =
  774. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  775. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  776. Module* module = context->module();
  777. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  778. << text << std::endl;
  779. const Function* f = spvtest::GetFunction(module, 2);
  780. ScalarEvolutionAnalysis analysis{context.get()};
  781. const Instruction* loads[1] = {nullptr};
  782. int load_count = 0;
  783. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 29)) {
  784. if (inst.opcode() == SpvOp::SpvOpLoad) {
  785. loads[load_count] = &inst;
  786. ++load_count;
  787. }
  788. }
  789. EXPECT_EQ(load_count, 1);
  790. Instruction* load_access_chain =
  791. context->get_def_use_mgr()->GetDef(loads[0]->GetSingleWordInOperand(0));
  792. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  793. load_access_chain->GetSingleWordInOperand(1));
  794. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  795. EXPECT_TRUE(load_node);
  796. EXPECT_EQ(load_node->GetType(), SENode::RecurrentAddExpr);
  797. EXPECT_TRUE(load_node->AsSERecurrentNode());
  798. SENode* child_1 = load_node->AsSERecurrentNode()->GetCoefficient();
  799. SENode* child_2 = load_node->AsSERecurrentNode()->GetOffset();
  800. EXPECT_EQ(child_1->GetType(), SENode::Constant);
  801. EXPECT_EQ(child_2->GetType(), SENode::Constant);
  802. EXPECT_EQ(child_1->AsSEConstantNode()->FoldToSingleValue(), -1);
  803. EXPECT_EQ(child_2->AsSEConstantNode()->FoldToSingleValue(), 0u);
  804. SERecurrentNode* load_simplified =
  805. analysis.SimplifyExpression(load_node)->AsSERecurrentNode();
  806. EXPECT_TRUE(load_simplified);
  807. EXPECT_EQ(load_node, load_simplified);
  808. EXPECT_EQ(load_simplified->GetType(), SENode::RecurrentAddExpr);
  809. EXPECT_TRUE(load_simplified->AsSERecurrentNode());
  810. SENode* simplified_child_1 =
  811. load_simplified->AsSERecurrentNode()->GetCoefficient();
  812. SENode* simplified_child_2 =
  813. load_simplified->AsSERecurrentNode()->GetOffset();
  814. EXPECT_EQ(child_1, simplified_child_1);
  815. EXPECT_EQ(child_2, simplified_child_2);
  816. }
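/*
Generated from the following GLSL + --eliminate-local-multi-store
(GLSL reconstructed from the SPIR-V in the test below.)
#version 430
layout(location = 1) out float array[10];
layout(location = 2) flat in int N;
void main(void) {
  for (int i = 0; i < 10; --i) {
    array[i + 2 * N] = array[i + N];
    array[2 * i + 2 * N + 1] = array[2 * i + N + 1];
  }
}
*/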
  826. TEST_F(ScalarAnalysisTest, SimplifyInductionsAndLoads) {
  827. const std::string text = R"(
  828. OpCapability Shader
  829. %1 = OpExtInstImport "GLSL.std.450"
  830. OpMemoryModel Logical GLSL450
  831. OpEntryPoint Fragment %2 "main" %3 %4
  832. OpExecutionMode %2 OriginUpperLeft
  833. OpSource GLSL 430
  834. OpName %2 "main"
  835. OpName %5 "i"
  836. OpName %3 "array"
  837. OpName %4 "N"
  838. OpDecorate %3 Location 1
  839. OpDecorate %4 Flat
  840. OpDecorate %4 Location 2
  841. %6 = OpTypeVoid
  842. %7 = OpTypeFunction %6
  843. %8 = OpTypeInt 32 1
  844. %9 = OpTypePointer Function %8
  845. %10 = OpConstant %8 0
  846. %11 = OpConstant %8 10
  847. %12 = OpTypeBool
  848. %13 = OpTypeFloat 32
  849. %14 = OpTypeInt 32 0
  850. %15 = OpConstant %14 10
  851. %16 = OpTypeArray %13 %15
  852. %17 = OpTypePointer Output %16
  853. %3 = OpVariable %17 Output
  854. %18 = OpConstant %8 2
  855. %19 = OpTypePointer Input %8
  856. %4 = OpVariable %19 Input
  857. %20 = OpTypePointer Output %13
  858. %21 = OpConstant %8 1
  859. %2 = OpFunction %6 None %7
  860. %22 = OpLabel
  861. %5 = OpVariable %9 Function
  862. OpStore %5 %10
  863. OpBranch %23
  864. %23 = OpLabel
  865. %24 = OpPhi %8 %10 %22 %25 %26
  866. OpLoopMerge %27 %26 None
  867. OpBranch %28
  868. %28 = OpLabel
  869. %29 = OpSLessThan %12 %24 %11
  870. OpBranchConditional %29 %30 %27
  871. %30 = OpLabel
  872. %31 = OpLoad %8 %4
  873. %32 = OpIMul %8 %18 %31
  874. %33 = OpIAdd %8 %24 %32
  875. %35 = OpIAdd %8 %24 %31
  876. %36 = OpAccessChain %20 %3 %35
  877. %37 = OpLoad %13 %36
  878. %38 = OpAccessChain %20 %3 %33
  879. OpStore %38 %37
  880. %39 = OpIMul %8 %18 %24
  881. %41 = OpIMul %8 %18 %31
  882. %42 = OpIAdd %8 %39 %41
  883. %43 = OpIAdd %8 %42 %21
  884. %44 = OpIMul %8 %18 %24
  885. %46 = OpIAdd %8 %44 %31
  886. %47 = OpIAdd %8 %46 %21
  887. %48 = OpAccessChain %20 %3 %47
  888. %49 = OpLoad %13 %48
  889. %50 = OpAccessChain %20 %3 %43
  890. OpStore %50 %49
  891. OpBranch %26
  892. %26 = OpLabel
  893. %25 = OpISub %8 %24 %21
  894. OpStore %5 %25
  895. OpBranch %23
  896. %27 = OpLabel
  897. OpReturn
  898. OpFunctionEnd
  899. )";
  900. std::unique_ptr<IRContext> context =
  901. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  902. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  903. Module* module = context->module();
  904. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  905. << text << std::endl;
  906. const Function* f = spvtest::GetFunction(module, 2);
  907. ScalarEvolutionAnalysis analysis{context.get()};
  908. std::vector<const Instruction*> loads{};
  909. std::vector<const Instruction*> stores{};
  910. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 30)) {
  911. if (inst.opcode() == SpvOp::SpvOpLoad) {
  912. loads.push_back(&inst);
  913. }
  914. if (inst.opcode() == SpvOp::SpvOpStore) {
  915. stores.push_back(&inst);
  916. }
  917. }
  918. EXPECT_EQ(loads.size(), 3u);
  919. EXPECT_EQ(stores.size(), 2u);
  920. {
  921. Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
  922. stores[0]->GetSingleWordInOperand(0));
  923. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  924. store_access_chain->GetSingleWordInOperand(1));
  925. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  926. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  927. Instruction* load_access_chain =
  928. context->get_def_use_mgr()->GetDef(loads[1]->GetSingleWordInOperand(0));
  929. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  930. load_access_chain->GetSingleWordInOperand(1));
  931. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  932. SENode* load_simplified = analysis.SimplifyExpression(load_node);
  933. SENode* difference =
  934. analysis.CreateSubtraction(store_simplified, load_simplified);
  935. SENode* difference_simplified = analysis.SimplifyExpression(difference);
  936. // Check that i+2*N - i*N, turns into just N when both sides have already
  937. // been simplified into a single recurrent expression.
  938. EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
  939. // Check that the inverse, i*N - i+2*N turns into -N.
  940. SENode* difference_inverse = analysis.SimplifyExpression(
  941. analysis.CreateSubtraction(load_simplified, store_simplified));
  942. EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
  943. EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
  944. EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  945. }
  946. {
  947. Instruction* store_access_chain = context->get_def_use_mgr()->GetDef(
  948. stores[1]->GetSingleWordInOperand(0));
  949. Instruction* store_child = context->get_def_use_mgr()->GetDef(
  950. store_access_chain->GetSingleWordInOperand(1));
  951. SENode* store_node = analysis.AnalyzeInstruction(store_child);
  952. SENode* store_simplified = analysis.SimplifyExpression(store_node);
  953. Instruction* load_access_chain =
  954. context->get_def_use_mgr()->GetDef(loads[2]->GetSingleWordInOperand(0));
  955. Instruction* load_child = context->get_def_use_mgr()->GetDef(
  956. load_access_chain->GetSingleWordInOperand(1));
  957. SENode* load_node = analysis.AnalyzeInstruction(load_child);
  958. SENode* load_simplified = analysis.SimplifyExpression(load_node);
  959. SENode* difference =
  960. analysis.CreateSubtraction(store_simplified, load_simplified);
  961. SENode* difference_simplified = analysis.SimplifyExpression(difference);
  962. // Check that 2*i + 2*N + 1 - 2*i + N + 1, turns into just N when both
  963. // sides have already been simplified into a single recurrent expression.
  964. EXPECT_EQ(difference_simplified->GetType(), SENode::ValueUnknown);
  965. // Check that the inverse, (2*i + N + 1) - (2*i + 2*N + 1) turns into -N.
  966. SENode* difference_inverse = analysis.SimplifyExpression(
  967. analysis.CreateSubtraction(load_simplified, store_simplified));
  968. EXPECT_EQ(difference_inverse->GetType(), SENode::Negative);
  969. EXPECT_EQ(difference_inverse->GetChild(0)->GetType(), SENode::ValueUnknown);
  970. EXPECT_EQ(difference_inverse->GetChild(0), difference_simplified);
  971. }
  972. }
  973. /* Generated from the following GLSL + --eliminate-local-multi-store
  974. #version 430
  975. layout(location = 1) out float array[10];
  976. layout(location = 2) flat in int N;
  977. void main(void) {
  978. int step = 0;
  979. for (int i = 0; i < N; i += step) {
  980. step++;
  981. }
  982. }
  983. */
  984. TEST_F(ScalarAnalysisTest, InductionWithVariantStep) {
  985. const std::string text = R"(
  986. OpCapability Shader
  987. %1 = OpExtInstImport "GLSL.std.450"
  988. OpMemoryModel Logical GLSL450
  989. OpEntryPoint Fragment %2 "main" %3 %4
  990. OpExecutionMode %2 OriginUpperLeft
  991. OpSource GLSL 430
  992. OpName %2 "main"
  993. OpName %5 "step"
  994. OpName %6 "i"
  995. OpName %3 "N"
  996. OpName %4 "array"
  997. OpDecorate %3 Flat
  998. OpDecorate %3 Location 2
  999. OpDecorate %4 Location 1
  1000. %7 = OpTypeVoid
  1001. %8 = OpTypeFunction %7
  1002. %9 = OpTypeInt 32 1
  1003. %10 = OpTypePointer Function %9
  1004. %11 = OpConstant %9 0
  1005. %12 = OpTypePointer Input %9
  1006. %3 = OpVariable %12 Input
  1007. %13 = OpTypeBool
  1008. %14 = OpConstant %9 1
  1009. %15 = OpTypeFloat 32
  1010. %16 = OpTypeInt 32 0
  1011. %17 = OpConstant %16 10
  1012. %18 = OpTypeArray %15 %17
  1013. %19 = OpTypePointer Output %18
  1014. %4 = OpVariable %19 Output
  1015. %2 = OpFunction %7 None %8
  1016. %20 = OpLabel
  1017. %5 = OpVariable %10 Function
  1018. %6 = OpVariable %10 Function
  1019. OpStore %5 %11
  1020. OpStore %6 %11
  1021. OpBranch %21
  1022. %21 = OpLabel
  1023. %22 = OpPhi %9 %11 %20 %23 %24
  1024. %25 = OpPhi %9 %11 %20 %26 %24
  1025. OpLoopMerge %27 %24 None
  1026. OpBranch %28
  1027. %28 = OpLabel
  1028. %29 = OpLoad %9 %3
  1029. %30 = OpSLessThan %13 %25 %29
  1030. OpBranchConditional %30 %31 %27
  1031. %31 = OpLabel
  1032. %23 = OpIAdd %9 %22 %14
  1033. OpStore %5 %23
  1034. OpBranch %24
  1035. %24 = OpLabel
  1036. %26 = OpIAdd %9 %25 %23
  1037. OpStore %6 %26
  1038. OpBranch %21
  1039. %27 = OpLabel
  1040. OpReturn
  1041. OpFunctionEnd
  1042. )";
  1043. std::unique_ptr<IRContext> context =
  1044. BuildModule(SPV_ENV_UNIVERSAL_1_1, nullptr, text,
  1045. SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
  1046. Module* module = context->module();
  1047. EXPECT_NE(nullptr, module) << "Assembling failed for shader:\n"
  1048. << text << std::endl;
  1049. const Function* f = spvtest::GetFunction(module, 2);
  1050. ScalarEvolutionAnalysis analysis{context.get()};
  1051. std::vector<const Instruction*> phis{};
  1052. for (const Instruction& inst : *spvtest::GetBasicBlock(f, 21)) {
  1053. if (inst.opcode() == SpvOp::SpvOpPhi) {
  1054. phis.push_back(&inst);
  1055. }
  1056. }
  1057. EXPECT_EQ(phis.size(), 2u);
  1058. SENode* phi_node_1 = analysis.AnalyzeInstruction(phis[0]);
  1059. SENode* phi_node_2 = analysis.AnalyzeInstruction(phis[1]);
  1060. phi_node_1->DumpDot(std::cout, true);
  1061. EXPECT_NE(phi_node_1, nullptr);
  1062. EXPECT_NE(phi_node_2, nullptr);
  1063. EXPECT_EQ(phi_node_1->GetType(), SENode::RecurrentAddExpr);
  1064. EXPECT_EQ(phi_node_2->GetType(), SENode::CanNotCompute);
  1065. SENode* simplified_1 = analysis.SimplifyExpression(phi_node_1);
  1066. SENode* simplified_2 = analysis.SimplifyExpression(phi_node_2);
  1067. EXPECT_EQ(simplified_1->GetType(), SENode::RecurrentAddExpr);
  1068. EXPECT_EQ(simplified_2->GetType(), SENode::CanNotCompute);
  1069. }
  1070. } // namespace
  1071. } // namespace opt
  1072. } // namespace spvtools