tincremental.nim 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. discard """
  2. output: '''heavy_calc_impl is called
  3. sub_calc1_impl is called
  4. sub_calc2_impl is called
  5. ** no changes recompute effectively
  6. ** change one input and recompute effectively
  7. heavy_calc_impl is called
  8. sub_calc2_impl is called'''
  9. """
  10. # sample incremental
  11. import tables
  12. import macros
  13. var inputs = initTable[string, float]()
  14. var cache = initTable[string, float]()
  15. var dep_tree {.compileTime.} = initTable[string, string]()
  16. macro symHash(s: typed{nkSym}): string =
  17. result = newStrLitNode(symBodyHash(s))
#######################################################################################
# Custom pragma that marks a routine as a node of the incremental graph.
# `key` is the node's unique name; `tag` reads it back from a routine's
# pragma list.
template graph_node(key: string) {.pragma.}
  20. proc tag(n: NimNode): NimNode =
  21. ## returns graph node unique name of a function or nil if it is not a graph node
  22. expectKind(n, {nnkProcDef, nnkFuncDef})
  23. for p in n.pragma:
  24. if p.len > 0 and p[0] == bindSym"graph_node":
  25. return p[1]
  26. return nil
  27. macro graph_node_key(n: typed{nkSym}): untyped =
  28. result = newStrLitNode(n.symBodyHash)
  29. macro graph_discovery(n: typed{nkSym}): untyped =
  30. # discovers graph dependency tree and updated dep_tree global var
  31. let mytag = newStrLitNode(n.symBodyHash)
  32. var visited: seq[NimNode]
  33. proc discover(n: NimNode) =
  34. case n.kind:
  35. of nnkNone..pred(nnkSym), succ(nnkSym)..nnkNilLit: discard
  36. of nnkSym:
  37. if n.symKind in {nskFunc, nskProc}:
  38. if n notin visited:
  39. visited.add n
  40. let tag = n.getImpl.tag
  41. if tag != nil:
  42. dep_tree[tag.strVal] = mytag.strVal
  43. else:
  44. discover(n.getImpl.body)
  45. else:
  46. for child in n:
  47. discover(child)
  48. discover(n.getImpl.body)
  49. result = newEmptyNode()
  50. #######################################################################################
  51. macro incremental_input(key: static[string], n: untyped{nkFuncDef}): untyped =
  52. # mark leaf nodes of the graph
  53. template getInput(key) {.dirty.} =
  54. {.noSideEffect.}:
  55. inputs[key]
  56. result = n
  57. result.pragma = nnkPragma.newTree(nnkCall.newTree(bindSym"graph_node", newStrLitNode(key)))
  58. result.body = getAst(getInput(key))
  59. macro incremental(n: untyped{nkFuncDef}): untyped =
  60. ## incrementalize side effect free computation
  61. ## wraps function into caching layer, mark caching function as a graph_node
  62. ## injects dependency discovery between graph nodes
  63. template cache_func_body(func_name, func_name_str, func_call) {.dirty.} =
  64. {.noSideEffect.}:
  65. graph_discovery(func_name)
  66. let key = graph_node_key(func_name)
  67. if key in cache:
  68. result = cache[key]
  69. else:
  70. echo func_name_str & " is called"
  71. result = func_call
  72. cache[key] = result
  73. let func_name = n.name.strVal & "_impl"
  74. let func_call = nnkCall.newTree(ident func_name)
  75. for i in 1..<n.params.len:
  76. func_call.add n.params[i][0]
  77. let cache_func = n.copyNimTree
  78. cache_func.body = getAst(cache_func_body(ident func_name, func_name, func_call))
  79. cache_func.pragma = nnkPragma.newTree(newCall(bindSym"graph_node",
  80. newCall(bindSym"symHash", ident func_name)))
  81. n.name = ident(func_name)
  82. result = nnkStmtList.newTree(n, cache_func)
###########################################################################
### Example
###########################################################################
# Leaf inputs: body-less declarations whose values are read from `inputs`
# under the keys "a1" and "a2" (set via `set_input`).
func input1(): float {.incremental_input("a1").}
func input2(): float {.incremental_input("a2").}
# Derived nodes, each memoized in `cache` by the `incremental` macro.
# Call graph: heavy_calc -> sub_calc1 -> input1
#             heavy_calc -> sub_calc2 -> input2
func sub_calc1(a: float): float {.incremental.} =
  a + input1()
func sub_calc2(b: float): float {.incremental.} =
  b + input2()
func heavy_calc(a: float, b: float): float {.incremental.} =
  sub_calc1(a) + sub_calc2(b)
  94. ###########################################################################
  95. ## graph finalize and inputs
  96. ###########################################################################
  97. macro finalize_dep_tree(): untyped =
  98. result = nnkTableConstr.newNimNode
  99. for key, val in dep_tree:
  100. result.add nnkExprColonExpr.newTree(newStrLitNode key, newStrLitNode val)
  101. result = nnkCall.newTree(bindSym"toTable", result)
  102. const dep_tree_final = finalize_dep_tree()
  103. proc set_input(key: string, val: float) =
  104. ## set input value
  105. ## all affected nodes of graph are invalidated
  106. inputs[key] = val
  107. var k = key
  108. while k != "":
  109. k = dep_tree_final.getOrDefault(k , "")
  110. cache.del(k)
###########################################################################
## demo
###########################################################################
# Seed both leaf inputs, then compute: every derived node runs once.
set_input("a1", 5)
set_input("a2", 2)
discard heavy_calc(5.0, 10.0)
echo "** no changes recompute effectively"
# Nothing changed, so every result comes from `cache` and nothing is printed.
discard heavy_calc(5.0, 10.0)
echo "** change one input and recompute effectively"
# Changing "a2" evicts sub_calc2 and heavy_calc; sub_calc1 stays cached.
set_input("a2", 10)
discard heavy_calc(5.0, 10.0)