constant_fold.cpp 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. /*
  2. * Copyright 2011-2013 Blender Foundation
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "render/constant_fold.h"
  17. #include "render/graph.h"
  18. #include "util/util_foreach.h"
  19. #include "util/util_logging.h"
  20. CCL_NAMESPACE_BEGIN
  21. ConstantFolder::ConstantFolder(ShaderGraph *graph,
  22. ShaderNode *node,
  23. ShaderOutput *output,
  24. Scene *scene)
  25. : graph(graph), node(node), output(output), scene(scene)
  26. {
  27. }
  28. bool ConstantFolder::all_inputs_constant() const
  29. {
  30. foreach (ShaderInput *input, node->inputs) {
  31. if (input->link) {
  32. return false;
  33. }
  34. }
  35. return true;
  36. }
  37. void ConstantFolder::make_constant(float value) const
  38. {
  39. VLOG(1) << "Folding " << node->name << "::" << output->name() << " to constant (" << value
  40. << ").";
  41. foreach (ShaderInput *sock, output->links) {
  42. sock->set(value);
  43. }
  44. graph->disconnect(output);
  45. }
  46. void ConstantFolder::make_constant(float3 value) const
  47. {
  48. VLOG(1) << "Folding " << node->name << "::" << output->name() << " to constant " << value << ".";
  49. foreach (ShaderInput *sock, output->links) {
  50. sock->set(value);
  51. }
  52. graph->disconnect(output);
  53. }
  54. void ConstantFolder::make_constant_clamp(float value, bool clamp) const
  55. {
  56. make_constant(clamp ? saturate(value) : value);
  57. }
  58. void ConstantFolder::make_constant_clamp(float3 value, bool clamp) const
  59. {
  60. if (clamp) {
  61. value.x = saturate(value.x);
  62. value.y = saturate(value.y);
  63. value.z = saturate(value.z);
  64. }
  65. make_constant(value);
  66. }
  67. void ConstantFolder::make_zero() const
  68. {
  69. if (output->type() == SocketType::FLOAT) {
  70. make_constant(0.0f);
  71. }
  72. else if (SocketType::is_float3(output->type())) {
  73. make_constant(make_float3(0.0f, 0.0f, 0.0f));
  74. }
  75. else {
  76. assert(0);
  77. }
  78. }
  79. void ConstantFolder::make_one() const
  80. {
  81. if (output->type() == SocketType::FLOAT) {
  82. make_constant(1.0f);
  83. }
  84. else if (SocketType::is_float3(output->type())) {
  85. make_constant(make_float3(1.0f, 1.0f, 1.0f));
  86. }
  87. else {
  88. assert(0);
  89. }
  90. }
  91. void ConstantFolder::bypass(ShaderOutput *new_output) const
  92. {
  93. assert(new_output);
  94. VLOG(1) << "Folding " << node->name << "::" << output->name() << " to socket "
  95. << new_output->parent->name << "::" << new_output->name() << ".";
  96. /* Remove all outgoing links from socket and connect them to new_output instead.
  97. * The graph->relink method affects node inputs, so it's not safe to use in constant
  98. * folding if the node has multiple outputs and will thus be folded multiple times. */
  99. vector<ShaderInput *> outputs = output->links;
  100. graph->disconnect(output);
  101. foreach (ShaderInput *sock, outputs) {
  102. graph->connect(new_output, sock);
  103. }
  104. }
  105. void ConstantFolder::discard() const
  106. {
  107. assert(output->type() == SocketType::CLOSURE);
  108. VLOG(1) << "Discarding closure " << node->name << ".";
  109. graph->disconnect(output);
  110. }
  111. void ConstantFolder::bypass_or_discard(ShaderInput *input) const
  112. {
  113. assert(input->type() == SocketType::CLOSURE);
  114. if (input->link) {
  115. bypass(input->link);
  116. }
  117. else {
  118. discard();
  119. }
  120. }
  121. bool ConstantFolder::try_bypass_or_make_constant(ShaderInput *input, bool clamp) const
  122. {
  123. if (input->type() != output->type()) {
  124. return false;
  125. }
  126. else if (!input->link) {
  127. if (input->type() == SocketType::FLOAT) {
  128. make_constant_clamp(node->get_float(input->socket_type), clamp);
  129. return true;
  130. }
  131. else if (SocketType::is_float3(input->type())) {
  132. make_constant_clamp(node->get_float3(input->socket_type), clamp);
  133. return true;
  134. }
  135. }
  136. else if (!clamp) {
  137. bypass(input->link);
  138. return true;
  139. }
  140. else {
  141. /* disconnect other inputs if we can't fully bypass due to clamp */
  142. foreach (ShaderInput *other, node->inputs) {
  143. if (other != input && other->link) {
  144. graph->disconnect(other);
  145. }
  146. }
  147. }
  148. return false;
  149. }
  150. bool ConstantFolder::is_zero(ShaderInput *input) const
  151. {
  152. if (!input->link) {
  153. if (input->type() == SocketType::FLOAT) {
  154. return node->get_float(input->socket_type) == 0.0f;
  155. }
  156. else if (SocketType::is_float3(input->type())) {
  157. return node->get_float3(input->socket_type) == make_float3(0.0f, 0.0f, 0.0f);
  158. }
  159. }
  160. return false;
  161. }
  162. bool ConstantFolder::is_one(ShaderInput *input) const
  163. {
  164. if (!input->link) {
  165. if (input->type() == SocketType::FLOAT) {
  166. return node->get_float(input->socket_type) == 1.0f;
  167. }
  168. else if (SocketType::is_float3(input->type())) {
  169. return node->get_float3(input->socket_type) == make_float3(1.0f, 1.0f, 1.0f);
  170. }
  171. }
  172. return false;
  173. }
  174. /* Specific nodes */
  175. void ConstantFolder::fold_mix(NodeMix type, bool clamp) const
  176. {
  177. ShaderInput *fac_in = node->input("Fac");
  178. ShaderInput *color1_in = node->input("Color1");
  179. ShaderInput *color2_in = node->input("Color2");
  180. float fac = saturate(node->get_float(fac_in->socket_type));
  181. bool fac_is_zero = !fac_in->link && fac == 0.0f;
  182. bool fac_is_one = !fac_in->link && fac == 1.0f;
  183. /* remove no-op node when factor is 0.0 */
  184. if (fac_is_zero) {
  185. /* note that some of the modes will clamp out of bounds values even without use_clamp */
  186. if (!(type == NODE_MIX_LIGHT || type == NODE_MIX_DODGE || type == NODE_MIX_BURN)) {
  187. if (try_bypass_or_make_constant(color1_in, clamp)) {
  188. return;
  189. }
  190. }
  191. }
  192. switch (type) {
  193. case NODE_MIX_BLEND:
  194. /* remove useless mix colors nodes */
  195. if (color1_in->link && color2_in->link) {
  196. if (color1_in->link == color2_in->link) {
  197. try_bypass_or_make_constant(color1_in, clamp);
  198. break;
  199. }
  200. }
  201. else if (!color1_in->link && !color2_in->link) {
  202. float3 color1 = node->get_float3(color1_in->socket_type);
  203. float3 color2 = node->get_float3(color2_in->socket_type);
  204. if (color1 == color2) {
  205. try_bypass_or_make_constant(color1_in, clamp);
  206. break;
  207. }
  208. }
  209. /* remove no-op mix color node when factor is 1.0 */
  210. if (fac_is_one) {
  211. try_bypass_or_make_constant(color2_in, clamp);
  212. break;
  213. }
  214. break;
  215. case NODE_MIX_ADD:
  216. /* 0 + X (fac 1) == X */
  217. if (is_zero(color1_in) && fac_is_one) {
  218. try_bypass_or_make_constant(color2_in, clamp);
  219. }
  220. /* X + 0 (fac ?) == X */
  221. else if (is_zero(color2_in)) {
  222. try_bypass_or_make_constant(color1_in, clamp);
  223. }
  224. break;
  225. case NODE_MIX_SUB:
  226. /* X - 0 (fac ?) == X */
  227. if (is_zero(color2_in)) {
  228. try_bypass_or_make_constant(color1_in, clamp);
  229. }
  230. /* X - X (fac 1) == 0 */
  231. else if (color1_in->link && color1_in->link == color2_in->link && fac_is_one) {
  232. make_zero();
  233. }
  234. break;
  235. case NODE_MIX_MUL:
  236. /* X * 1 (fac ?) == X, 1 * X (fac 1) == X */
  237. if (is_one(color1_in) && fac_is_one) {
  238. try_bypass_or_make_constant(color2_in, clamp);
  239. }
  240. else if (is_one(color2_in)) {
  241. try_bypass_or_make_constant(color1_in, clamp);
  242. }
  243. /* 0 * ? (fac ?) == 0, ? * 0 (fac 1) == 0 */
  244. else if (is_zero(color1_in)) {
  245. make_zero();
  246. }
  247. else if (is_zero(color2_in) && fac_is_one) {
  248. make_zero();
  249. }
  250. break;
  251. case NODE_MIX_DIV:
  252. /* X / 1 (fac ?) == X */
  253. if (is_one(color2_in)) {
  254. try_bypass_or_make_constant(color1_in, clamp);
  255. }
  256. /* 0 / ? (fac ?) == 0 */
  257. else if (is_zero(color1_in)) {
  258. make_zero();
  259. }
  260. break;
  261. default:
  262. break;
  263. }
  264. }
  265. void ConstantFolder::fold_math(NodeMath type, bool clamp) const
  266. {
  267. ShaderInput *value1_in = node->input("Value1");
  268. ShaderInput *value2_in = node->input("Value2");
  269. switch (type) {
  270. case NODE_MATH_ADD:
  271. /* X + 0 == 0 + X == X */
  272. if (is_zero(value1_in)) {
  273. try_bypass_or_make_constant(value2_in, clamp);
  274. }
  275. else if (is_zero(value2_in)) {
  276. try_bypass_or_make_constant(value1_in, clamp);
  277. }
  278. break;
  279. case NODE_MATH_SUBTRACT:
  280. /* X - 0 == X */
  281. if (is_zero(value2_in)) {
  282. try_bypass_or_make_constant(value1_in, clamp);
  283. }
  284. break;
  285. case NODE_MATH_MULTIPLY:
  286. /* X * 1 == 1 * X == X */
  287. if (is_one(value1_in)) {
  288. try_bypass_or_make_constant(value2_in, clamp);
  289. }
  290. else if (is_one(value2_in)) {
  291. try_bypass_or_make_constant(value1_in, clamp);
  292. }
  293. /* X * 0 == 0 * X == 0 */
  294. else if (is_zero(value1_in) || is_zero(value2_in)) {
  295. make_zero();
  296. }
  297. break;
  298. case NODE_MATH_DIVIDE:
  299. /* X / 1 == X */
  300. if (is_one(value2_in)) {
  301. try_bypass_or_make_constant(value1_in, clamp);
  302. }
  303. /* 0 / X == 0 */
  304. else if (is_zero(value1_in)) {
  305. make_zero();
  306. }
  307. break;
  308. case NODE_MATH_POWER:
  309. /* 1 ^ X == X ^ 0 == 1 */
  310. if (is_one(value1_in) || is_zero(value2_in)) {
  311. make_one();
  312. }
  313. /* X ^ 1 == X */
  314. else if (is_one(value2_in)) {
  315. try_bypass_or_make_constant(value1_in, clamp);
  316. }
  317. default:
  318. break;
  319. }
  320. }
  321. void ConstantFolder::fold_vector_math(NodeVectorMath type) const
  322. {
  323. ShaderInput *vector1_in = node->input("Vector1");
  324. ShaderInput *vector2_in = node->input("Vector2");
  325. switch (type) {
  326. case NODE_VECTOR_MATH_ADD:
  327. /* X + 0 == 0 + X == X */
  328. if (is_zero(vector1_in)) {
  329. try_bypass_or_make_constant(vector2_in);
  330. }
  331. else if (is_zero(vector2_in)) {
  332. try_bypass_or_make_constant(vector1_in);
  333. }
  334. break;
  335. case NODE_VECTOR_MATH_SUBTRACT:
  336. /* X - 0 == X */
  337. if (is_zero(vector2_in)) {
  338. try_bypass_or_make_constant(vector1_in);
  339. }
  340. break;
  341. case NODE_VECTOR_MATH_DOT_PRODUCT:
  342. case NODE_VECTOR_MATH_CROSS_PRODUCT:
  343. /* X * 0 == 0 * X == 0 */
  344. if (is_zero(vector1_in) || is_zero(vector2_in)) {
  345. make_zero();
  346. }
  347. break;
  348. default:
  349. break;
  350. }
  351. }
  352. CCL_NAMESPACE_END