ast_pattern_matching.nim 18 KB


  1. # this is a copy paste implementation of github.com/krux02/ast_pattern_matching
  2. # Please provide bugfixes upstream first before adding them here.
  3. import macros, strutils, tables
  4. export macros
  5. when isMainModule:
  6. template debug(args: varargs[untyped]): untyped =
  7. echo args
  8. else:
  9. template debug(args: varargs[untyped]): untyped =
  10. discard
  11. const
  12. nnkIntLiterals* = nnkCharLit..nnkUInt64Lit
  13. nnkStringLiterals* = nnkStrLit..nnkTripleStrLit
  14. nnkFloatLiterals* = nnkFloatLit..nnkFloat64Lit
  15. proc newLit[T: enum](arg: T): NimNode =
  16. newIdentNode($arg)
  17. proc newLit[T](arg: set[T]): NimNode =
  18. ## does not work for the empty sets
  19. result = nnkCurly.newTree
  20. for x in arg:
  21. result.add newLit(x)
  22. type SomeFloat = float | float32 | float64
  23. proc len[T](arg: set[T]): int = card(arg)
  24. type
  25. MatchingErrorKind* = enum
  26. NoError
  27. WrongKindLength
  28. WrongKindValue
  29. WrongIdent
  30. WrongCustomCondition
  31. MatchingError = object
  32. node*: NimNode
  33. expectedKind*: set[NimNodeKind]
  34. case kind*: MatchingErrorKind
  35. of NoError:
  36. discard
  37. of WrongKindLength:
  38. expectedLength*: int
  39. of WrongKindValue:
  40. expectedValue*: NimNode
  41. of WrongIdent, WrongCustomCondition:
  42. strVal*: string
  43. proc `$`*(arg: MatchingError): string =
  44. let n = arg.node
  45. case arg.kind
  46. of NoError:
  47. "no error"
  48. of WrongKindLength:
  49. let k = arg.expectedKind
  50. let l = arg.expectedLength
  51. var msg = "expected "
  52. if k.len == 0:
  53. msg.add "any node"
  54. elif k.len == 1:
  55. for el in k: # only one element but there is no index op for sets
  56. msg.add $el
  57. else:
  58. msg.add "a node in" & $k
  59. if l >= 0:
  60. msg.add " with " & $l & " child(ren)"
  61. msg.add ", but got " & $n.kind
  62. if l >= 0:
  63. msg.add " with " & $n.len & " child(ren)"
  64. msg
  65. of WrongKindValue:
  66. let k = $arg.expectedKind
  67. let v = arg.expectedValue.repr
  68. var msg = "expected " & k & " with value " & v & " but got " & n.lispRepr
  69. if n.kind in {nnkOpenSymChoice, nnkClosedSymChoice}:
  70. msg = msg & " (a sym-choice does not have a strVal member, maybe you should match with `ident`)"
  71. msg
  72. of WrongIdent:
  73. let prefix = "expected ident `" & arg.strVal & "` but got "
  74. if n.kind in {nnkIdent, nnkSym, nnkOpenSymChoice, nnkClosedSymChoice}:
  75. prefix & "`" & n.strVal & "`"
  76. else:
  77. prefix & $n.kind & " with " & $n.len & " child(ren)"
  78. of WrongCustomCondition:
  79. "custom condition check failed: " & arg.strVal
  80. proc failWithMatchingError*(arg: MatchingError): void {.compileTime, noReturn.} =
  81. error($arg, arg.node)
  82. proc expectValue(arg: NimNode; value: SomeInteger): void {.compileTime.} =
  83. arg.expectKind nnkLiterals
  84. if arg.intVal != int(value):
  85. error("expected value " & $value & " but got " & arg.repr, arg)
  86. proc expectValue(arg: NimNode; value: SomeFloat): void {.compileTime.} =
  87. arg.expectKind nnkLiterals
  88. if arg.floatVal != float(value):
  89. error("expected value " & $value & " but got " & arg.repr, arg)
  90. proc expectValue(arg: NimNode; value: string): void {.compileTime.} =
  91. arg.expectKind nnkLiterals
  92. if arg.strVal != value:
  93. error("expected value " & value & " but got " & arg.repr, arg)
  94. proc expectValue[T](arg: NimNode; value: pointer): void {.compileTime.} =
  95. arg.expectKind nnkLiterals
  96. if value != nil:
  97. error("Expect Value for pointers works only on `nil` when the argument is a pointer.")
  98. arg.expectKind nnkNilLit
  99. proc expectIdent(arg: NimNode; strVal: string): void {.compileTime.} =
  100. if not arg.eqIdent(strVal):
  101. error("Expect ident `" & strVal & "` but got " & arg.repr)
  102. proc matchLengthKind*(arg: NimNode; kind: set[NimNodeKind]; length: int): MatchingError {.compileTime.} =
  103. let kindFail = not(kind.card == 0 or arg.kind in kind)
  104. let lengthFail = not(length < 0 or length == arg.len)
  105. if kindFail or lengthFail:
  106. result.node = arg
  107. result.kind = WrongKindLength
  108. result.expectedLength = length
  109. result.expectedKind = kind
  110. proc matchLengthKind*(arg: NimNode; kind: NimNodeKind; length: int): MatchingError {.compileTime.} =
  111. matchLengthKind(arg, {kind}, length)
  112. proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeInteger): MatchingError {.compileTime.} =
  113. template kindFail: bool = not(kind.card == 0 or arg.kind in kind)
  114. template valueFail: bool = arg.intVal != int(value)
  115. if kindFail or valueFail:
  116. result.node = arg
  117. result.kind = WrongKindValue
  118. result.expectedKind = kind
  119. result.expectedValue = newLit(value)
  120. proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeInteger): MatchingError {.compileTime.} =
  121. matchValue(arg, {kind}, value)
  122. proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: SomeFloat): MatchingError {.compileTime.} =
  123. let kindFail = not(kind.card == 0 or arg.kind in kind)
  124. let valueFail = arg.floatVal != float(value)
  125. if kindFail or valueFail:
  126. result.node = arg
  127. result.kind = WrongKindValue
  128. result.expectedKind = kind
  129. result.expectedValue = newLit(value)
  130. proc matchValue(arg: NimNode; kind: NimNodeKind; value: SomeFloat): MatchingError {.compileTime.} =
  131. matchValue(arg, {kind}, value)
  132. const nnkStrValKinds = {nnkStrLit, nnkRStrLit, nnkTripleStrLit, nnkIdent, nnkSym}
  133. proc matchValue(arg: NimNode; kind: set[NimNodeKind]; value: string): MatchingError {.compileTime.} =
  134. # if kind * nnkStringLiterals TODO do something that ensures that here is only checked for string literals
  135. let kindFail = not(kind.card == 0 or arg.kind in kind)
  136. let valueFail =
  137. if kind.card == 0:
  138. false
  139. else:
  140. arg.kind notin (kind * nnkStrValKinds) or arg.strVal != value
  141. if kindFail or valueFail:
  142. result.node = arg
  143. result.kind = WrongKindValue
  144. result.expectedKind = kind
  145. result.expectedValue = newLit(value)
  146. proc matchValue(arg: NimNode; kind: NimNodeKind; value: string): MatchingError {.compileTime.} =
  147. matchValue(arg, {kind}, value)
  148. proc matchValue[T](arg: NimNode; value: pointer): MatchingError {.compileTime.} =
  149. if value != nil:
  150. error("Expect Value for pointers works only on `nil` when the argument is a pointer.")
  151. arg.matchLengthKind(nnkNilLit, -1)
  152. proc matchIdent*(arg:NimNode; value: string): MatchingError =
  153. if not arg.eqIdent(value):
  154. result.node = arg
  155. result.kind = Wrongident
  156. result.strVal = value
  157. proc checkCustomExpr*(arg: NimNode; cond: bool, exprstr: string): MatchingError =
  158. if not cond:
  159. result.node = arg
  160. result.kind = WrongCustomCondition
  161. result.strVal = exprstr
  162. static:
  163. var literals: array[19, NimNode]
  164. var i = 0
  165. for litKind in nnkLiterals:
  166. literals[i] = ident($litKind)
  167. i += 1
  168. var nameToKind = newTable[string, NimNodeKind]()
  169. for kind in NimNodeKind:
  170. nameToKind[ ($kind)[3..^1] ] = kind
  171. let identifierKinds = newLit({nnkSym, nnkIdent, nnkOpenSymChoice, nnkClosedSymChoice})
proc generateMatchingCode(astSym: NimNode, pattern: NimNode, depth: int, blockLabel, errorSym, localsArraySym: NimNode; dest: NimNode): int =
  ## Appends to `dest` the statements that check the node `astSym` against
  ## `pattern`. Each generated check stores its `MatchingError` in `errorSym`
  ## and breaks out of the block labeled `blockLabel` when the result is not
  ## `NoError`. Child nodes are stored in consecutive slots of the
  ## `localsArraySym` array. `depth` only controls the indentation of the
  ## `debug` output.
  ##
  ## return the number of indices used in the array for local variables.

  # next free slot in `localsArraySym`; also the value returned at the end
  var currentLocalIndex = 0

  proc nodeVisiting(astSym: NimNode, pattern: NimNode, depth: int): void =
    let ind = " ".repeat(depth) # indentation

    # emits `errorSym = astSym.matchProc(argSym1, argSym2)` followed by the
    # early exit from `blockLabel` on failure
    proc genMatchLogic(matchProc, argSym1, argSym2: NimNode): void =
      dest.add quote do:
        `errorSym` = `astSym`.`matchProc`(`argSym1`, `argSym2`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    # emits an identifier-name check (pattern `ident"name"`)
    proc genIdentMatchLogic(identValueLit: NimNode): void =
      dest.add quote do:
        `errorSym` = `astSym`.matchIdent(`identValueLit`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    # emits a check of an arbitrary boolean expression (pattern `p |= cond`);
    # the expression's source text is kept for the error message
    proc genCustomMatchLogic(conditionExpr: NimNode): void =
      let exprStr = newLit(conditionExpr.repr)
      dest.add quote do:
        `errorSym` = `astSym`.checkCustomExpr(`conditionExpr`, `exprStr`)
        if `errorSym`.kind != NoError:
          break `blockLabel`

    # proc handleKindMatching(kindExpr: NimNode): void =
    #   if kindExpr.eqIdent("_"):
    #     # this is the wildcard that matches any kind
    #     return
    #   else:
    #     genMatchLogic(bindSym"matchKind", kindExpr)

    # generate recursively a matching expression
    if pattern.kind == nnkCall:
      # `Kind(...)`, `{Kinds}(...)` or `_(...)` (an empty kind set = any kind)
      pattern.expectMinLen(1)

      debug ind, pattern[0].repr, "("

      let kindSet = if pattern[0].eqIdent("_"): nnkCurly.newTree else: pattern[0]
      # handleKindMatching(pattern[0])

      if pattern.len == 2 and pattern[1].kind == nnkExprEqExpr:
        # `Kind(field = value)`: compare the node's value. The field name is
        # only validated when the value is a literal; otherwise the matching
        # `matchValue` overload is picked by the value's type.
        if pattern[1][1].kind in nnkStringLiterals:
          pattern[1][0].expectIdent("strVal")
        elif pattern[1][1].kind in nnkIntLiterals:
          pattern[1][0].expectIdent("intVal")
        elif pattern[1][1].kind in nnkFloatLiterals:
          pattern[1][0].expectIdent("floatVal")

        genMatchLogic(bindSym"matchValue", kindSet, pattern[1][1])

      else:
        # `Kind(child1, child2, ...)`: the kind and the exact number of
        # children must match, then every child is matched recursively
        let lengthLit = newLit(pattern.len - 1)
        genMatchLogic(bindSym"matchLengthKind", kindSet, lengthLit)

        for i in 1 ..< pattern.len:
          # each child gets its own slot in the locals array
          let childSym = nnkBracketExpr.newTree(localsArraySym, newLit(currentLocalIndex))
          currentLocalIndex += 1
          let indexLit = newLit(i - 1)
          dest.add quote do:
            `childSym` = `astSym`[`indexLit`]
          nodeVisiting(childSym, pattern[i], depth + 1)
      debug ind, ")"
    elif pattern.kind == nnkCallStrLit and pattern[0].eqIdent("ident"):
      # `ident"name"`: identifier comparison via `matchIdent`
      genIdentMatchLogic(pattern[1])

    elif pattern.kind == nnkPar and pattern.len == 1:
      # `(pattern)`: parentheses are transparent
      nodeVisiting(astSym, pattern[0], depth)
    elif pattern.kind == nnkPrefix:
      error("prefix patterns not implemented", pattern)
    elif pattern.kind == nnkAccQuoted:
      # `` `name` ``: bind the current node to a new `let` without checking it
      debug ind, pattern.repr
      let matchedExpr = pattern[0]
      matchedExpr.expectKind nnkIdent
      dest.add quote do:
        let `matchedExpr` = `astSym`

    elif pattern.kind == nnkInfix and pattern[0].eqIdent("@"):
      # `` `name` @ pattern ``: bind the current node, then match `pattern`
      pattern[1].expectKind nnkAccQuoted

      let matchedExpr = pattern[1][0]
      matchedExpr.expectKind nnkIdent
      dest.add quote do:
        let `matchedExpr` = `astSym`

      debug ind, pattern[1].repr, " = "
      nodeVisiting(matchedExpr, pattern[2], depth + 1)

    elif pattern.kind == nnkInfix and pattern[0].eqIdent("|="):
      # `pattern |= condition`: match `pattern` first, then evaluate the
      # condition, which may refer to names bound inside `pattern`
      nodeVisiting(astSym, pattern[1], depth + 1)
      genCustomMatchLogic(pattern[2])

    elif pattern.kind in nnkCallKinds:
      error("only boring call syntax allowed, this is " & $pattern.kind & ".", pattern)
    elif pattern.kind in nnkLiterals:
      # bare literal: compare the value, accepting any node kind
      genMatchLogic(bindSym"matchValue", nnkCurly.newTree, pattern)
    elif not pattern.eqIdent("_"):
      # When it is not one of the other branches, it is simply treated
      # as an expression for the node kind, without checking child
      # nodes.
      debug ind, pattern.repr
      genMatchLogic(bindSym"matchLengthKind", pattern, newLit(-1))
    # a bare `_` matches anything, so no code is generated for it

  nodeVisiting(astSym, pattern, depth)

  return currentLocalIndex
macro matchAst*(astExpr: NimNode; args: varargs[untyped]): untyped =
  ## Matches the node `astExpr` against the patterns of the following `of`
  ## branches, in order, and runs the code of the first branch that matches.
  ##
  ## * An optional leading identifier (`matchAst(ast, errors):`) names a
  ##   `seq[MatchingError]` holding one entry per `of` branch. It is only
  ##   declared when there is an `else` branch.
  ## * An optional `else` branch runs when no pattern matched.
  ## * Without `else`, a mismatch aborts compilation: with the precise
  ##   matching error when there is a single `of` branch, otherwise with a
  ##   message that lists all expected patterns.
  # NOTE(review): `args[0]` / `args[^1]` assume at least one branch was given.
  let astSym = genSym(nskLet, "ast")
  let beginBranches = if args[0].kind == nnkIdent: 1 else: 0
  let endBranches = if args[^1].kind == nnkElse: args.len - 1 else: args.len
  for i in beginBranches ..< endBranches:
    args[i].expectKind nnkOfBranch

  # user-chosen name for the collected errors, or nil
  let outerErrorSym: NimNode =
    if beginBranches == 1:
      args[0].expectKind nnkIdent
      args[0]
    else:
      nil

  # body of the `else` branch, or nil
  let elseBranch: NimNode =
    if endBranches == args.len - 1:
      args[^1].expectKind(nnkElse)
      args[^1][0]
    else:
      nil

  # a successful branch breaks out of this outer block
  let outerBlockLabel = genSym(nskLabel, "matchingSection")
  let outerStmtList = newStmtList()
  let errorSymbols = nnkBracket.newTree

  ## the vm only allows 255 local variables. This sucks a lot and I
  ## have to work around it. So instead of creating a lot of local
  ## variables, I just create one array of local variables. This is
  ## just annoying.
  let localsArraySym = genSym(nskVar, "locals")
  var localsArrayLen: int = 0

  for i in beginBranches ..< endBranches:
    let ofBranch = args[i]
    ofBranch.expectKind(nnkOfBranch)
    ofBranch.expectLen(2)
    let pattern = ofBranch[0]
    let code = ofBranch[1]
    code.expectKind nnkStmtList
    let stmtList = newStmtList()
    let blockLabel = genSym(nskLabel, "matchingBranch")
    let errorSym = genSym(nskVar, "branchError")

    errorSymbols.add errorSym
    let numLocalsUsed = generateMatchingCode(astSym, pattern, 0, blockLabel, errorSym, localsArraySym, stmtList)
    # all branches share one locals array, sized for the largest branch
    localsArrayLen = max(localsArrayLen, numLocalsUsed)
    stmtList.add code
    # maybe there is a better mechanism disable errors for statement after return
    if code[^1].kind != nnkReturnStmt:
      stmtList.add nnkBreakStmt.newTree(outerBlockLabel)

    # a failed check breaks out of `blockLabel`, falling through to the
    # next branch
    outerStmtList.add quote do:
      var `errorSym`: MatchingError
      block `blockLabel`:
        `stmtList`

  # reached only when no branch matched
  if elseBranch != nil:
    if outerErrorSym != nil:
      outerStmtList.add quote do:
        let `outerErrorSym` = @`errorSymbols`
        `elseBranch`
    else:
      outerStmtList.add elseBranch

  else:
    if errorSymbols.len == 1:
      # there is only one of branch and no else branch
      # the error message can be very precise here.
      let errorSym = errorSymbols[0]
      outerStmtList.add quote do:
        failWithMatchingError(`errorSym`)
    else:

      var patterns: string = ""
      for i in beginBranches ..< endBranches:
        let ofBranch = args[i]
        let pattern = ofBranch[0]
        patterns.add pattern.repr
        patterns.add "\n"

      let patternsLit = newLit(patterns)
      outerStmtList.add quote do:
        error("Ast pattern mismatch: got " & `astSym`.lispRepr & "\nbut expected one of:\n" & `patternsLit`, `astSym`)

  let lengthLit = newLit(localsArrayLen)
  # `astExpr` is evaluated once and bound to `astSym`
  result = quote do:
    block `outerBlockLabel`:
      let `astSym` = `astExpr`
      var `localsArraySym`: array[`lengthLit`, NimNode]
      `outerStmtList`

  debug result.repr
  338. proc recursiveNodeVisiting*(arg: NimNode, callback: proc(arg: NimNode): bool) =
  339. ## if `callback` returns true, visitor continues to visit the
  340. ## children of `arg` otherwise it stops.
  341. if callback(arg):
  342. for child in arg:
  343. recursiveNodeVisiting(child, callback)
macro matchAstRecursive*(ast: NimNode; args: varargs[untyped]): untyped =
  ## Visits every node of `ast` with `recursiveNodeVisiting`. For each node,
  ## the `of` branches are tried in order and the code of the first matching
  ## branch is run. An `else` branch is rejected.
  # Does not recurse further on matched nodes.
  # NOTE(review): `args[^1]` assumes at least one branch was given.
  if args[^1].kind == nnkElse:
    error("Recursive matching with an else branch is pointless.", args[^1])

  # the generated callback and its parameter
  let visitor = genSym(nskProc, "visitor")
  let visitorArg = genSym(nskParam, "arg")

  let visitorStmtList = newStmtList()

  # a matching branch breaks out of this block, skipping `result = true`
  let matchingSection = genSym(nskLabel, "matchingSection")

  let localsArraySym = genSym(nskVar, "locals")
  # one error variable, shared by all branches
  let branchError = genSym(nskVar, "branchError")
  var localsArrayLen = 0

  for ofBranch in args:
    ofBranch.expectKind(nnkOfBranch)
    ofBranch.expectLen(2)
    let pattern = ofBranch[0]
    let code = ofBranch[1]
    code.expectkind(nnkStmtList)

    let stmtList = newStmtList()
    let matchingBranch = genSym(nskLabel, "matchingBranch")

    let numLocalsUsed = generateMatchingCode(visitorArg, pattern, 0, matchingBranch, branchError, localsArraySym, stmtList)
    localsArrayLen = max(localsArrayLen, numLocalsUsed)

    stmtList.add code
    stmtList.add nnkBreakStmt.newTree(matchingSection)

    # reset the shared error before trying this branch
    visitorStmtList.add quote do:
      `branchError`.kind = NoError
      block `matchingBranch`:
        `stmtList`

  let resultIdent = ident"result"

  let visitingProc = bindSym"recursiveNodeVisiting"
  let lengthLit = newLit(localsArrayLen)

  # the visitor returns true (descend into children) only when no branch
  # matched, because a matching branch breaks out before `result = true`
  result = quote do:
    proc `visitor`(`visitorArg`: NimNode): bool =
      block `matchingSection`:
        var `localsArraySym`: array[`lengthLit`, NimNode]
        var `branchError`: MatchingError
        `visitorStmtList`
        `resultIdent` = true

    `visitingProc`(`ast`, `visitor`)

  debug result.repr
################################################################################
################################# Example Code #################################
################################################################################

when isMainModule:
  static:
    # NOTE(review): `mykinds` is not used by the examples below
    let mykinds = {nnkIdent, nnkCall}

  # Two `of` branches and no `else`: a mismatch aborts compilation with a
  # message that lists both patterns.
  macro foo(arg: untyped): untyped =
    matchAst(arg, matchError):
    of nnkStmtList(nnkIdent, nnkIdent, nnkIdent):
      echo(88*88+33*33)
    of nnkStmtList(
      _(
        nnkIdentDefs(
          ident"a",
          nnkEmpty, nnkIntLit(intVal = 123)
        )
      ),
      _,
      nnkForStmt(
        nnkIdent(strVal = "i"),
        nnkInfix,
        `mysym` @ nnkStmtList
      )
    ):
      echo "The AST did match!!!"
      echo "The matched sub tree is the following:"
      echo mysym.lispRepr
    #else:
    #  echo "sadly the AST did not match :("
    #  echo arg.treeRepr
    #  failWithMatchingError(matchError[1])

  foo:
    let a = 123
    let b = 342
    for i in a ..< b:
      echo "Hallo", i

  static:
    var ast = quote do:
      type
        A[T: static[int]] = object

    ast = ast[0]
    ast.matchAst(err):  # this is a sub ast for this a findAst or something like that is useful
    of nnkTypeDef(_, nnkGenericParams( nnkIdentDefs( nnkIdent(strVal = "T"), `staticTy`, nnkEmpty )), _):
      echo "`", staticTy.repr, "` used to be of nnkStaticTy, now it is ", staticTy.kind, " with ", staticTy[0].repr
    ast = quote do:
      if cond1: expr1 elif cond2: expr2 else: expr3

    # a `{...}` set in kind position accepts any of the listed kinds
    ast.matchAst:
    of {nnkIfExpr, nnkIfStmt}(
      {nnkElifExpr, nnkElifBranch}(`cond1`, `expr1`),
      {nnkElifExpr, nnkElifBranch}(`cond2`, `expr2`),
      {nnkElseExpr, nnkElse}(`expr3`)
    ):
      echo "ok"

    let ast2 = nnkStmtList.newTree( newLit(1) )

    # a bare literal pattern compares the value of the node at that position
    ast2.matchAst:
    of nnkIntLit( 1 ):
      echo "fail"
    of nnkStmtList( 1 ):
      echo "ok"

    # sym-choices must be matched with `ident"..."`, not with `strVal`
    ast = bindSym"[]"
    ast.matchAst(errors):
    of nnkClosedSymChoice(strVal = "[]"):
      echo "fail, this is the wrong syntax, a sym choice does not have a `strVal` member."
    of ident"[]":
      echo "ok"

    const myConst = 123
    ast = newLit(123)

    # the expected value may be any constant expression, not only a literal
    ast.matchAst:
    of _(intVal = myConst):
      echo "ok"

  # recursive matching: every node of the tree is tried against the branches
  macro testRecCase(ast: untyped): untyped =
    ast.matchAstRecursive:
    of nnkIdentDefs(`a`,`b`,`c`):
      echo "got ident defs a: ", a.repr, " b: ", b.repr, " c: ", c.repr
    of ident"m":
      echo "got the ident m"

  testRecCase:
    type Obj[T] {.inheritable.} = object
      name: string
      case isFat: bool
      of true:
        m: array[100_000, T]
      of false:
        m: array[10, T]

  # `|=` adds a custom boolean condition; here only int literals above 5 are
  # collected
  macro testIfCondition(ast: untyped): untyped =
    let literals = nnkBracket.newTree
    ast.matchAstRecursive:
    of `intLit` @ nnkIntLit |= intLit.intVal > 5:
      literals.add intLit

    let literals2 = quote do:
      [6,7,8,9]

    doAssert literals2 == literals

  testIfCondition([1,6,2,7,3,8,4,9,5,0,"123"])