wrapnils.nim 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. ## This module allows evaluating expressions safely against the following conditions:
  2. ## * nil dereferences
  3. ## * field accesses with incorrect discriminant in case objects
  4. ##
  5. ## `default(T)` is returned in those cases when evaluating an expression of type `T`.
  6. ## This simplifies code by reducing need for if-else branches.
  7. ##
  8. ## Note: experimental module, unstable API.
  9. #[
  10. TODO:
  11. consider handling indexing operations, eg:
  12. doAssert ?.default(seq[int])[3] == default(int)
  13. ]#
  14. import std/macros
  15. runnableExamples:
  16. type Foo = ref object
  17. x1: string
  18. x2: Foo
  19. x3: ref int
  20. var f: Foo
  21. assert ?.f.x2.x1 == "" # returns default value since `f` is nil
  22. var f2 = Foo(x1: "a")
  23. f2.x2 = f2
  24. assert ?.f2.x1 == "a" # same as f2.x1 (no nil LHS in this chain)
  25. assert ?.Foo(x1: "a").x1 == "a" # can use constructor inside
  26. # when you know a sub-expression doesn't involve a `nil` (e.g. `f2.x2.x2`),
  27. # you can scope it as follows:
  28. assert ?.(f2.x2.x2).x3[] == 0
  29. assert (?.f2.x2.x2).x3 == nil # this terminates ?. early
  30. runnableExamples:
  31. # ?. also allows case object
  32. type B = object
  33. b0: int
  34. case cond: bool
  35. of false: discard
  36. of true:
  37. b1: float
  38. var b = B(cond: false, b0: 3)
  39. doAssertRaises(FieldDefect): discard b.b1 # wrong discriminant
  40. doAssert ?.b.b1 == 0.0 # safe
  41. b = B(cond: true, b1: 4.5)
  42. doAssert ?.b.b1 == 4.5
  43. # lvalue semantics are preserved:
  44. if (let p = ?.b.b1.addr; p != nil): p[] = 4.7
  45. doAssert b.b1 == 4.7
  46. proc finalize(n: NimNode, lhs: NimNode, level: int): NimNode =
  47. if level == 0:
  48. result = quote: `lhs` = `n`
  49. else:
  50. result = quote: (var `lhs` = `n`)
  51. proc process(n: NimNode, lhs: NimNode, label: NimNode, level: int): NimNode =
  52. result = nil
  53. var n = n.copyNimTree
  54. var it = n
  55. let addr2 = bindSym"addr"
  56. var old: tuple[n: NimNode, index: int] = (nil, 0)
  57. while true:
  58. if it.len == 0:
  59. result = finalize(n, lhs, level)
  60. break
  61. elif it.kind == nnkCheckedFieldExpr:
  62. let dot = it[0]
  63. let obj = dot[0]
  64. let objRef = quote do: `addr2`(`obj`)
  65. # avoids a copy and preserves lvalue semantics, see tests
  66. let check = it[1]
  67. let okSet = check[1]
  68. let kind1 = check[2]
  69. let tmp = genSym(nskVar, "tmpCase")
  70. let body = process(objRef, tmp, label, level + 1)
  71. let tmp3 = nnkDerefExpr.newTree(tmp)
  72. it[0][0] = tmp3
  73. let dot2 = nnkDotExpr.newTree(@[tmp, dot[1]])
  74. if old.n != nil: old.n[old.index] = dot2
  75. else: n = dot2
  76. let assgn = finalize(n, lhs, level)
  77. result = quote do:
  78. `body`
  79. if `tmp3`.`kind1` notin `okSet`: break `label`
  80. `assgn`
  81. break
  82. elif it.kind in {nnkHiddenDeref, nnkDerefExpr}:
  83. let tmp = genSym(nskVar, "tmp")
  84. let body = process(it[0], tmp, label, level + 1)
  85. it[0] = tmp
  86. let assgn = finalize(n, lhs, level)
  87. result = quote do:
  88. `body`
  89. if `tmp` == nil: break `label`
  90. `assgn`
  91. break
  92. elif it.kind == nnkCall: # consider extending to `nnkCallKinds`
  93. # `copyNimTree` needed to avoid `typ = nil` issues
  94. old = (it, 1)
  95. it = it[1].copyNimTree
  96. else:
  97. old = (it, 0)
  98. it = it[0]
  99. macro `?.`*(a: typed): auto =
  100. ## Transforms `a` into an expression that can be safely evaluated even in
  101. ## presence of intermediate nil pointers/references, in which case a default
  102. ## value is produced.
  103. let lhs = genSym(nskVar, "lhs")
  104. let label = genSym(nskLabel, "label")
  105. let body = process(a, lhs, label, 0)
  106. result = quote do:
  107. var `lhs`: type(`a`) = default(type(`a`))
  108. block `label`:
  109. `body`
  110. `lhs`
  111. # the code below is not needed for `?.`
  112. from std/options import Option, isSome, get, option, unsafeGet, UnpackDefect
  113. macro `??.`*(a: typed): Option =
  114. ## Same as `?.` but returns an `Option`.
  115. runnableExamples:
  116. import std/options
  117. type Foo = ref object
  118. x1: ref int
  119. x2: int
  120. # `?.` can't distinguish between a valid vs invalid default value, but `??.` can:
  121. var f1 = Foo(x1: int.new, x2: 2)
  122. doAssert (??.f1.x1[]).get == 0 # not enough to tell when the chain was valid.
  123. doAssert (??.f1.x1[]).isSome # a nil didn't occur in the chain
  124. doAssert (??.f1.x2).get == 2
  125. var f2: Foo
  126. doAssert not (??.f2.x1[]).isSome # f2 was nil
  127. doAssertRaises(UnpackDefect): discard (??.f2.x1[]).get
  128. doAssert ?.f2.x1[] == 0 # in contrast, this returns default(int)
  129. let lhs = genSym(nskVar, "lhs")
  130. let lhs2 = genSym(nskVar, "lhs")
  131. let label = genSym(nskLabel, "label")
  132. let body = process(a, lhs2, label, 0)
  133. result = quote do:
  134. var `lhs`: Option[type(`a`)] = default(Option[type(`a`)])
  135. block `label`:
  136. var `lhs2`: type(`a`) = default(type(`a`))
  137. `body`
  138. `lhs` = option(`lhs2`)
  139. `lhs`
  140. template fakeDot*(a: Option, b): untyped =
  141. ## See top-level example.
  142. let a1 = a # to avoid double evaluations
  143. type T = Option[typeof(unsafeGet(a1).b)]
  144. if isSome(a1):
  145. let a2 = unsafeGet(a1)
  146. when typeof(a2) is ref|ptr:
  147. if a2 == nil:
  148. default(T)
  149. else:
  150. option(a2.b)
  151. else:
  152. option(a2.b)
  153. else:
  154. # nil is "sticky"; this is needed, see tests
  155. default(T)
  156. # xxx this should but doesn't work: func `[]`*[T, I](a: Option[T], i: I): Option {.inline.} =
  157. func `[]`*[T, I](a: Option[T], i: I): auto {.inline.} =
  158. ## See top-level example.
  159. if isSome(a):
  160. # correctly will raise IndexDefect if a is valid but wraps an empty container
  161. result = option(a.unsafeGet[i])
  162. func `[]`*[U](a: Option[U]): auto {.inline.} =
  163. ## See top-level example.
  164. if isSome(a):
  165. let a2 = a.unsafeGet
  166. if a2 != nil:
  167. result = option(a2[])
  168. when false:
  169. # xxx: expose a way to do this directly in std/options, e.g.: `getAsIs`
  170. proc safeGet[T](a: Option[T]): T {.inline.} =
  171. get(a, default(T))