patterns.nim 10 KB


  1. #
  2. #
  3. # The Nim Compiler
  4. # (c) Copyright 2012 Andreas Rumpf
  5. #
  6. # See the file "copying.txt", included in this
  7. # distribution, for details about the copyright.
  8. #
  9. ## This module implements the pattern matching features for term rewriting
  10. ## macro support.
  11. import
  12. ast, types, semdata, sigmatch, idents, aliases, parampatterns, trees
  13. when defined(nimPreviewSlimSystem):
  14. import std/assertions
  15. type
  16. TPatternContext = object
  17. owner: PSym
  18. mapping: seq[PNode] # maps formal parameters to nodes
  19. formals: int
  20. c: PContext
  21. subMatch: bool # subnode matches are special
  22. mappingIsFull: bool
  23. PPatternContext = var TPatternContext
  24. proc getLazy(c: PPatternContext, sym: PSym): PNode =
  25. if c.mappingIsFull:
  26. result = c.mapping[sym.position]
  27. else:
  28. result = nil
  29. proc putLazy(c: PPatternContext, sym: PSym, n: PNode) =
  30. if not c.mappingIsFull:
  31. newSeq(c.mapping, c.formals)
  32. c.mappingIsFull = true
  33. c.mapping[sym.position] = n
# Forward declaration: `matches` is defined further down, but the helpers in
# between (e.g. `matchChoice`, `matchNested`) call it, and it calls them back.
proc matches(c: PPatternContext, p, n: PNode): bool
  35. proc canonKind(n: PNode): TNodeKind =
  36. ## nodekind canonicalization for pattern matching
  37. result = n.kind
  38. case result
  39. of nkCallKinds: result = nkCall
  40. of nkStrLit..nkTripleStrLit: result = nkStrLit
  41. of nkFastAsgn, nkSinkAsgn: result = nkAsgn
  42. else: discard
  43. proc sameKinds(a, b: PNode): bool {.inline.} =
  44. result = a.kind == b.kind or a.canonKind == b.canonKind
  45. proc sameTrees*(a, b: PNode): bool =
  46. if sameKinds(a, b):
  47. case a.kind
  48. of nkSym: result = a.sym == b.sym
  49. of nkIdent: result = a.ident.id == b.ident.id
  50. of nkCharLit..nkInt64Lit: result = a.intVal == b.intVal
  51. of nkFloatLit..nkFloat64Lit: result = a.floatVal == b.floatVal
  52. of nkStrLit..nkTripleStrLit: result = a.strVal == b.strVal
  53. of nkEmpty, nkNilLit: result = true
  54. of nkType: result = sameTypeOrNil(a.typ, b.typ)
  55. else:
  56. if a.len == b.len:
  57. for i in 0..<a.len:
  58. if not sameTrees(a[i], b[i]): return
  59. result = true
  60. else:
  61. result = false
  62. else:
  63. result = false
  64. proc inSymChoice(sc, x: PNode): bool =
  65. if sc.kind == nkClosedSymChoice:
  66. result = false
  67. for i in 0..<sc.len:
  68. if sc[i].sym == x.sym: return true
  69. elif sc.kind == nkOpenSymChoice:
  70. # same name suffices for open sym choices!
  71. result = sc[0].sym.name.id == x.sym.name.id
  72. else:
  73. result = false
  74. proc checkTypes(c: PPatternContext, p: PSym, n: PNode): bool =
  75. # check param constraints first here as this is quite optimized:
  76. if p.constraint != nil:
  77. result = matchNodeKinds(p.constraint, n)
  78. if not result: return
  79. if isNil(n.typ):
  80. result = p.typ.kind in {tyVoid, tyTyped}
  81. else:
  82. result = sigmatch.argtypeMatches(c.c, p.typ, n.typ, fromHlo = true)
  83. proc isPatternParam(c: PPatternContext, p: PNode): bool {.inline.} =
  84. result = p.kind == nkSym and p.sym.kind == skParam and p.sym.owner == c.owner
  85. proc matchChoice(c: PPatternContext, p, n: PNode): bool =
  86. result = false
  87. for i in 1..<p.len:
  88. if matches(c, p[i], n): return true
  89. proc bindOrCheck(c: PPatternContext, param: PSym, n: PNode): bool =
  90. var pp = getLazy(c, param)
  91. if pp != nil:
  92. # check if we got the same pattern (already unified):
  93. result = sameTrees(pp, n) #matches(c, pp, n)
  94. elif n.kind == nkArgList or checkTypes(c, param, n):
  95. putLazy(c, param, n)
  96. result = true
  97. else:
  98. result = false
  99. proc gather(c: PPatternContext, param: PSym, n: PNode) =
  100. var pp = getLazy(c, param)
  101. if pp != nil and pp.kind == nkArgList:
  102. pp.add(n)
  103. else:
  104. pp = newNodeI(nkArgList, n.info, 1)
  105. pp[0] = n
  106. putLazy(c, param, pp)
proc matchNested(c: PPatternContext, p, n: PNode, rpn: bool): bool =
  ## Matches the nesting patterns ``op * param`` (`rpn` = false) and
  ## ``op ** param`` (`rpn` = true) against `n`.
  ## In `p`, `p[1]` is the operator pattern and `p[2]` the parameter that
  ## receives the collected operands as an `nkArgList`. `n` must be a call
  ## whose callee matches `p[1]`. Its arguments are walked recursively,
  ## descending into nested calls of the same operator. Every leaf operand
  ## accepted by `checkTypes` for `p[2]` is appended to the list. With `rpn`,
  ## each matched operator node is appended after its operands, which gives
  ## reverse Polish order.
  proc matchStarAux(c: PPatternContext, op, n, arglist: PNode,
                    rpn: bool): bool =
    # Recursive worker. `op` is the whole pattern node (`op[1]` is the
    # operator). Operands are accumulated into `arglist` in visit order.
    result = true
    if n.kind in nkCallKinds and matches(c, op[1], n[0]):
      # nested application of the same operator: recurse into its arguments
      for i in 1..<n.len:
        if not matchStarAux(c, op, n[i], arglist, rpn): return false
      if rpn: arglist.add(n[0])
    elif n.kind == nkHiddenStdConv and n[1].kind == nkBracket:
      # a hidden conversion wrapping an array literal (the same shape that
      # `matches` unpacks for varargs): visit every element
      let n = n[1]
      for i in 0..<n.len:
        if not matchStarAux(c, op, n[i], arglist, rpn): return false
    elif checkTypes(c, p[2].sym, n):
      # leaf operand accepted by the parameter's constraint/type
      arglist.add(n)
    else:
      result = false

  if n.kind notin nkCallKinds: return false
  if matches(c, p[1], n[0]):
    var arglist = newNodeI(nkArgList, n.info)
    if matchStarAux(c, p, n, arglist, rpn):
      # bind (or re-check an earlier binding of) the parameter to the list
      result = bindOrCheck(c, p[2].sym, arglist)
    else:
      result = false
  else:
    result = false
  133. proc matches(c: PPatternContext, p, n: PNode): bool =
  134. let n = skipHidden(n)
  135. if nfNoRewrite in n.flags:
  136. result = false
  137. elif isPatternParam(c, p):
  138. result = bindOrCheck(c, p.sym, n)
  139. elif n.kind == nkSym and p.kind == nkIdent:
  140. result = p.ident.id == n.sym.name.id
  141. elif n.kind == nkSym and inSymChoice(p, n):
  142. result = true
  143. elif n.kind == nkSym and n.sym.kind == skConst:
  144. # try both:
  145. if p.kind == nkSym: result = p.sym == n.sym
  146. elif matches(c, p, n.sym.astdef): result = true
  147. else: result = false
  148. elif p.kind == nkPattern:
  149. # pattern operators: | *
  150. let opr = p[0].ident.s
  151. case opr
  152. of "|": result = matchChoice(c, p, n)
  153. of "*": result = matchNested(c, p, n, rpn=false)
  154. of "**": result = matchNested(c, p, n, rpn=true)
  155. of "~": result = not matches(c, p[1], n)
  156. else:
  157. result = false
  158. doAssert(false, "invalid pattern")
  159. # template {add(a, `&` * b)}(a: string{noalias}, b: varargs[string]) =
  160. # a.add(b)
  161. elif p.kind == nkCurlyExpr:
  162. if p[1].kind == nkPrefix:
  163. if matches(c, p[0], n):
  164. gather(c, p[1][1].sym, n)
  165. result = true
  166. else:
  167. result = false
  168. else:
  169. assert isPatternParam(c, p[1])
  170. if matches(c, p[0], n):
  171. result = bindOrCheck(c, p[1].sym, n)
  172. else:
  173. result = false
  174. elif sameKinds(p, n):
  175. case p.kind
  176. of nkSym: result = p.sym == n.sym
  177. of nkIdent: result = p.ident.id == n.ident.id
  178. of nkCharLit..nkInt64Lit: result = p.intVal == n.intVal
  179. of nkFloatLit..nkFloat64Lit: result = p.floatVal == n.floatVal
  180. of nkStrLit..nkTripleStrLit: result = p.strVal == n.strVal
  181. of nkEmpty, nkNilLit, nkType:
  182. result = true
  183. else:
  184. # special rule for p(X) ~ f(...); this also works for stuff like
  185. # partial case statements, etc! - Not really ... :-/
  186. result = false
  187. let v = lastSon(p)
  188. if isPatternParam(c, v) and v.sym.typ.kind == tyVarargs:
  189. var arglist: PNode
  190. if p.len <= n.len:
  191. for i in 0..<p.len - 1:
  192. if not matches(c, p[i], n[i]): return
  193. if p.len == n.len and lastSon(n).kind == nkHiddenStdConv and
  194. lastSon(n)[1].kind == nkBracket:
  195. # unpack varargs:
  196. let n = lastSon(n)[1]
  197. arglist = newNodeI(nkArgList, n.info, n.len)
  198. for i in 0..<n.len: arglist[i] = n[i]
  199. else:
  200. arglist = newNodeI(nkArgList, n.info, n.len - p.len + 1)
  201. # f(1, 2, 3)
  202. # p(X)
  203. for i in 0..n.len - p.len:
  204. arglist[i] = n[i + p.len - 1]
  205. return bindOrCheck(c, v.sym, arglist)
  206. elif p.len-1 == n.len:
  207. for i in 0..<p.len - 1:
  208. if not matches(c, p[i], n[i]): return
  209. arglist = newNodeI(nkArgList, n.info)
  210. return bindOrCheck(c, v.sym, arglist)
  211. if p.len == n.len:
  212. for i in 0..<p.len:
  213. if not matches(c, p[i], n[i]): return
  214. result = true
  215. else:
  216. result = false
  217. proc matchStmtList(c: PPatternContext, p, n: PNode): PNode =
  218. proc matchRange(c: PPatternContext, p, n: PNode, i: int): bool =
  219. for j in 0..<p.len:
  220. if not matches(c, p[j], n[i+j]):
  221. # we need to undo any bindings:
  222. c.mapping = @[]
  223. c.mappingIsFull = false
  224. return false
  225. result = true
  226. if p.kind == nkStmtList and n.kind == p.kind and p.len < n.len:
  227. result = nil
  228. let n = flattenStmts(n)
  229. # no need to flatten 'p' here as that has already been done
  230. for i in 0..n.len - p.len:
  231. if matchRange(c, p, n, i):
  232. c.subMatch = true
  233. result = newNodeI(nkStmtList, n.info, 3)
  234. result[0] = extractRange(nkStmtList, n, 0, i-1)
  235. result[1] = extractRange(nkStmtList, n, i, i+p.len-1)
  236. result[2] = extractRange(nkStmtList, n, i+p.len, n.len-1)
  237. break
  238. elif matches(c, p, n):
  239. result = n
  240. else:
  241. result = nil
  242. proc aliasAnalysisRequested(params: PNode): bool =
  243. result = false
  244. if params.len >= 2:
  245. for i in 1..<params.len:
  246. let param = params[i].sym
  247. if whichAlias(param) != aqNone: return true
  248. proc addToArgList(result, n: PNode) =
  249. if n.typ != nil and n.typ.kind != tyTyped:
  250. if n.kind != nkArgList: result.add(n)
  251. else:
  252. for i in 0..<n.len: result.add(n[i])
  253. proc applyRule*(c: PContext, s: PSym, n: PNode): PNode =
  254. ## returns a tree to semcheck if the rule triggered; nil otherwise
  255. var ctx: TPatternContext
  256. ctx.owner = s
  257. ctx.c = c
  258. ctx.formals = s.typ.len-1
  259. var m = matchStmtList(ctx, s.ast[patternPos], n)
  260. if isNil(m): return nil
  261. # each parameter should have been bound; we simply setup a call and
  262. # let semantic checking deal with the rest :-)
  263. result = newNodeI(nkCall, n.info)
  264. result.add(newSymNode(s, n.info))
  265. let params = s.typ.n
  266. let requiresAA = aliasAnalysisRequested(params)
  267. var args: PNode =
  268. if requiresAA:
  269. newNodeI(nkArgList, n.info)
  270. else:
  271. nil
  272. for i in 1..<params.len:
  273. let param = params[i].sym
  274. let x = getLazy(ctx, param)
  275. # couldn't bind parameter:
  276. if isNil(x): return nil
  277. result.add(x)
  278. if requiresAA: addToArgList(args, x)
  279. # perform alias analysis here:
  280. if requiresAA:
  281. for i in 1..<params.len:
  282. var rs = result[i]
  283. let param = params[i].sym
  284. case whichAlias(param)
  285. of aqNone: discard
  286. of aqShouldAlias:
  287. # it suffices that it aliases for sure with *some* other param:
  288. var ok = false
  289. for arg in items(args):
  290. if arg != rs and aliases.isPartOf(rs, arg) == arYes:
  291. ok = true
  292. break
  293. # constraint not fulfilled:
  294. if not ok: return nil
  295. of aqNoAlias:
  296. # it MUST not alias with any other param:
  297. var ok = true
  298. for arg in items(args):
  299. if arg != rs and aliases.isPartOf(rs, arg) != arNo:
  300. ok = false
  301. break
  302. # constraint not fulfilled:
  303. if not ok: return nil
  304. markUsed(c, n.info, s)
  305. if ctx.subMatch:
  306. assert m.len == 3
  307. m[1] = result
  308. result = m