rewrite.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886
  1. """Rewrite assertion AST to produce nice error messages"""
  2. import ast
  3. import errno
  4. import itertools
  5. import imp
  6. import marshal
  7. import os
  8. import re
  9. import struct
  10. import sys
  11. import types
  12. import py
  13. from _pytest.assertion import util
  14. # pytest caches rewritten pycs in __pycache__.
  15. if hasattr(imp, "get_tag"):
  16. PYTEST_TAG = imp.get_tag() + "-PYTEST"
  17. else:
  18. if hasattr(sys, "pypy_version_info"):
  19. impl = "pypy"
  20. elif sys.platform == "java":
  21. impl = "jython"
  22. else:
  23. impl = "cpython"
  24. ver = sys.version_info
  25. PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
  26. del ver, impl
  27. PYC_EXT = ".py" + (__debug__ and "c" or "o")
  28. PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
  29. REWRITE_NEWLINES = sys.version_info[:2] != (2, 7) and sys.version_info < (3, 2)
  30. ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3
  31. if sys.version_info >= (3,5):
  32. ast_Call = ast.Call
  33. else:
  34. ast_Call = lambda a,b,c: ast.Call(a, b, c, None, None)
  35. class AssertionRewritingHook(object):
  36. """PEP302 Import hook which rewrites asserts."""
  37. def __init__(self):
  38. self.session = None
  39. self.modules = {}
  40. self._register_with_pkg_resources()
  41. def set_session(self, session):
  42. self.fnpats = session.config.getini("python_files")
  43. self.session = session
  44. def find_module(self, name, path=None):
  45. if self.session is None:
  46. return None
  47. sess = self.session
  48. state = sess.config._assertstate
  49. state.trace("find_module called for: %s" % name)
  50. names = name.rsplit(".", 1)
  51. lastname = names[-1]
  52. pth = None
  53. if path is not None:
  54. # Starting with Python 3.3, path is a _NamespacePath(), which
  55. # causes problems if not converted to list.
  56. path = list(path)
  57. if len(path) == 1:
  58. pth = path[0]
  59. if pth is None:
  60. try:
  61. fd, fn, desc = imp.find_module(lastname, path)
  62. except ImportError:
  63. return None
  64. if fd is not None:
  65. fd.close()
  66. tp = desc[2]
  67. if tp == imp.PY_COMPILED:
  68. if hasattr(imp, "source_from_cache"):
  69. fn = imp.source_from_cache(fn)
  70. else:
  71. fn = fn[:-1]
  72. elif tp != imp.PY_SOURCE:
  73. # Don't know what this is.
  74. return None
  75. else:
  76. fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
  77. fn_pypath = py.path.local(fn)
  78. # Is this a test file?
  79. if not sess.isinitpath(fn):
  80. # We have to be very careful here because imports in this code can
  81. # trigger a cycle.
  82. self.session = None
  83. try:
  84. for pat in self.fnpats:
  85. if fn_pypath.fnmatch(pat):
  86. state.trace("matched test file %r" % (fn,))
  87. break
  88. else:
  89. return None
  90. finally:
  91. self.session = sess
  92. else:
  93. state.trace("matched test file (was specified on cmdline): %r" %
  94. (fn,))
  95. # The requested module looks like a test file, so rewrite it. This is
  96. # the most magical part of the process: load the source, rewrite the
  97. # asserts, and load the rewritten source. We also cache the rewritten
  98. # module code in a special pyc. We must be aware of the possibility of
  99. # concurrent pytest processes rewriting and loading pycs. To avoid
  100. # tricky race conditions, we maintain the following invariant: The
  101. # cached pyc is always a complete, valid pyc. Operations on it must be
  102. # atomic. POSIX's atomic rename comes in handy.
  103. write = not sys.dont_write_bytecode
  104. cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
  105. if write:
  106. try:
  107. os.mkdir(cache_dir)
  108. except OSError:
  109. e = sys.exc_info()[1].errno
  110. if e == errno.EEXIST:
  111. # Either the __pycache__ directory already exists (the
  112. # common case) or it's blocked by a non-dir node. In the
  113. # latter case, we'll ignore it in _write_pyc.
  114. pass
  115. elif e in [errno.ENOENT, errno.ENOTDIR]:
  116. # One of the path components was not a directory, likely
  117. # because we're in a zip file.
  118. write = False
  119. elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
  120. state.trace("read only directory: %r" % fn_pypath.dirname)
  121. write = False
  122. else:
  123. raise
  124. cache_name = fn_pypath.basename[:-3] + PYC_TAIL
  125. pyc = os.path.join(cache_dir, cache_name)
  126. # Notice that even if we're in a read-only directory, I'm going
  127. # to check for a cached pyc. This may not be optimal...
  128. co = _read_pyc(fn_pypath, pyc, state.trace)
  129. if co is None:
  130. state.trace("rewriting %r" % (fn,))
  131. source_stat, co = _rewrite_test(state, fn_pypath)
  132. if co is None:
  133. # Probably a SyntaxError in the test.
  134. return None
  135. if write:
  136. _make_rewritten_pyc(state, source_stat, pyc, co)
  137. else:
  138. state.trace("found cached rewritten pyc for %r" % (fn,))
  139. self.modules[name] = co, pyc
  140. return self
  141. def load_module(self, name):
  142. # If there is an existing module object named 'fullname' in
  143. # sys.modules, the loader must use that existing module. (Otherwise,
  144. # the reload() builtin will not work correctly.)
  145. if name in sys.modules:
  146. return sys.modules[name]
  147. co, pyc = self.modules.pop(name)
  148. # I wish I could just call imp.load_compiled here, but __file__ has to
  149. # be set properly. In Python 3.2+, this all would be handled correctly
  150. # by load_compiled.
  151. mod = sys.modules[name] = imp.new_module(name)
  152. try:
  153. mod.__file__ = co.co_filename
  154. # Normally, this attribute is 3.2+.
  155. mod.__cached__ = pyc
  156. mod.__loader__ = self
  157. py.builtin.exec_(co, mod.__dict__)
  158. except:
  159. del sys.modules[name]
  160. raise
  161. return sys.modules[name]
  162. def is_package(self, name):
  163. try:
  164. fd, fn, desc = imp.find_module(name)
  165. except ImportError:
  166. return False
  167. if fd is not None:
  168. fd.close()
  169. tp = desc[2]
  170. return tp == imp.PKG_DIRECTORY
  171. @classmethod
  172. def _register_with_pkg_resources(cls):
  173. """
  174. Ensure package resources can be loaded from this loader. May be called
  175. multiple times, as the operation is idempotent.
  176. """
  177. try:
  178. import pkg_resources
  179. # access an attribute in case a deferred importer is present
  180. pkg_resources.__name__
  181. except ImportError:
  182. return
  183. # Since pytest tests are always located in the file system, the
  184. # DefaultProvider is appropriate.
  185. pkg_resources.register_loader_type(cls, pkg_resources.DefaultProvider)
  186. def get_data(self, pathname):
  187. """Optional PEP302 get_data API.
  188. """
  189. with open(pathname, 'rb') as f:
  190. return f.read()
  191. def _write_pyc(state, co, source_stat, pyc):
  192. # Technically, we don't have to have the same pyc format as
  193. # (C)Python, since these "pycs" should never be seen by builtin
  194. # import. However, there's little reason deviate, and I hope
  195. # sometime to be able to use imp.load_compiled to load them. (See
  196. # the comment in load_module above.)
  197. try:
  198. fp = open(pyc, "wb")
  199. except IOError:
  200. err = sys.exc_info()[1].errno
  201. state.trace("error writing pyc file at %s: errno=%s" %(pyc, err))
  202. # we ignore any failure to write the cache file
  203. # there are many reasons, permission-denied, __pycache__ being a
  204. # file etc.
  205. return False
  206. try:
  207. fp.write(imp.get_magic())
  208. mtime = int(source_stat.mtime)
  209. size = source_stat.size & 0xFFFFFFFF
  210. fp.write(struct.pack("<ll", mtime, size))
  211. marshal.dump(co, fp)
  212. finally:
  213. fp.close()
  214. return True
  215. RN = "\r\n".encode("utf-8")
  216. N = "\n".encode("utf-8")
  217. cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
  218. BOM_UTF8 = '\xef\xbb\xbf'
  219. def _rewrite_test(state, fn):
  220. """Try to read and rewrite *fn* and return the code object."""
  221. try:
  222. stat = fn.stat()
  223. source = fn.read("rb")
  224. except EnvironmentError:
  225. return None, None
  226. if ASCII_IS_DEFAULT_ENCODING:
  227. # ASCII is the default encoding in Python 2. Without a coding
  228. # declaration, Python 2 will complain about any bytes in the file
  229. # outside the ASCII range. Sadly, this behavior does not extend to
  230. # compile() or ast.parse(), which prefer to interpret the bytes as
  231. # latin-1. (At least they properly handle explicit coding cookies.) To
  232. # preserve this error behavior, we could force ast.parse() to use ASCII
  233. # as the encoding by inserting a coding cookie. Unfortunately, that
  234. # messes up line numbers. Thus, we have to check ourselves if anything
  235. # is outside the ASCII range in the case no encoding is explicitly
  236. # declared. For more context, see issue #269. Yay for Python 3 which
  237. # gets this right.
  238. end1 = source.find("\n")
  239. end2 = source.find("\n", end1 + 1)
  240. if (not source.startswith(BOM_UTF8) and
  241. cookie_re.match(source[0:end1]) is None and
  242. cookie_re.match(source[end1 + 1:end2]) is None):
  243. if hasattr(state, "_indecode"):
  244. # encodings imported us again, so don't rewrite.
  245. return None, None
  246. state._indecode = True
  247. try:
  248. try:
  249. source.decode("ascii")
  250. except UnicodeDecodeError:
  251. # Let it fail in real import.
  252. return None, None
  253. finally:
  254. del state._indecode
  255. # On Python versions which are not 2.7 and less than or equal to 3.1, the
  256. # parser expects *nix newlines.
  257. if REWRITE_NEWLINES:
  258. source = source.replace(RN, N) + N
  259. try:
  260. tree = ast.parse(source)
  261. except SyntaxError:
  262. # Let this pop up again in the real import.
  263. state.trace("failed to parse: %r" % (fn,))
  264. return None, None
  265. rewrite_asserts(tree)
  266. try:
  267. co = compile(tree, fn.strpath, "exec")
  268. except SyntaxError:
  269. # It's possible that this error is from some bug in the
  270. # assertion rewriting, but I don't know of a fast way to tell.
  271. state.trace("failed to compile: %r" % (fn,))
  272. return None, None
  273. return stat, co
  274. def _make_rewritten_pyc(state, source_stat, pyc, co):
  275. """Try to dump rewritten code to *pyc*."""
  276. if sys.platform.startswith("win"):
  277. # Windows grants exclusive access to open files and doesn't have atomic
  278. # rename, so just write into the final file.
  279. _write_pyc(state, co, source_stat, pyc)
  280. else:
  281. # When not on windows, assume rename is atomic. Dump the code object
  282. # into a file specific to this process and atomically replace it.
  283. proc_pyc = pyc + "." + str(os.getpid())
  284. if _write_pyc(state, co, source_stat, proc_pyc):
  285. os.rename(proc_pyc, pyc)
  286. def _read_pyc(source, pyc, trace=lambda x: None):
  287. """Possibly read a pytest pyc containing rewritten code.
  288. Return rewritten code if successful or None if not.
  289. """
  290. try:
  291. fp = open(pyc, "rb")
  292. except IOError:
  293. return None
  294. with fp:
  295. try:
  296. mtime = int(source.mtime())
  297. size = source.size()
  298. data = fp.read(12)
  299. except EnvironmentError as e:
  300. trace('_read_pyc(%s): EnvironmentError %s' % (source, e))
  301. return None
  302. # Check for invalid or out of date pyc file.
  303. if (len(data) != 12 or data[:4] != imp.get_magic() or
  304. struct.unpack("<ll", data[4:]) != (mtime, size)):
  305. trace('_read_pyc(%s): invalid or out of date pyc' % source)
  306. return None
  307. try:
  308. co = marshal.load(fp)
  309. except Exception as e:
  310. trace('_read_pyc(%s): marshal.load error %s' % (source, e))
  311. return None
  312. if not isinstance(co, types.CodeType):
  313. trace('_read_pyc(%s): not a code object' % source)
  314. return None
  315. return co
  316. def rewrite_asserts(mod):
  317. """Rewrite the assert statements in mod."""
  318. AssertionRewriter().run(mod)
  319. def _saferepr(obj):
  320. """Get a safe repr of an object for assertion error messages.
  321. The assertion formatting (util.format_explanation()) requires
  322. newlines to be escaped since they are a special character for it.
  323. Normally assertion.util.format_explanation() does this but for a
  324. custom repr it is possible to contain one of the special escape
  325. sequences, especially '\n{' and '\n}' are likely to be present in
  326. JSON reprs.
  327. """
  328. repr = py.io.saferepr(obj)
  329. if py.builtin._istext(repr):
  330. t = py.builtin.text
  331. else:
  332. t = py.builtin.bytes
  333. return repr.replace(t("\n"), t("\\n"))
  334. from _pytest.assertion.util import format_explanation as _format_explanation # noqa
  335. def _format_assertmsg(obj):
  336. """Format the custom assertion message given.
  337. For strings this simply replaces newlines with '\n~' so that
  338. util.format_explanation() will preserve them instead of escaping
  339. newlines. For other objects py.io.saferepr() is used first.
  340. """
  341. # reprlib appears to have a bug which means that if a string
  342. # contains a newline it gets escaped, however if an object has a
  343. # .__repr__() which contains newlines it does not get escaped.
  344. # However in either case we want to preserve the newline.
  345. if py.builtin._istext(obj) or py.builtin._isbytes(obj):
  346. s = obj
  347. is_repr = False
  348. else:
  349. s = py.io.saferepr(obj)
  350. is_repr = True
  351. if py.builtin._istext(s):
  352. t = py.builtin.text
  353. else:
  354. t = py.builtin.bytes
  355. s = s.replace(t("\n"), t("\n~")).replace(t("%"), t("%%"))
  356. if is_repr:
  357. s = s.replace(t("\\n"), t("\n~"))
  358. return s
  359. def _should_repr_global_name(obj):
  360. return not hasattr(obj, "__name__") and not py.builtin.callable(obj)
  361. def _format_boolop(explanations, is_or):
  362. explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
  363. if py.builtin._istext(explanation):
  364. t = py.builtin.text
  365. else:
  366. t = py.builtin.bytes
  367. return explanation.replace(t('%'), t('%%'))
  368. def _call_reprcompare(ops, results, expls, each_obj):
  369. for i, res, expl in zip(range(len(ops)), results, expls):
  370. try:
  371. done = not res
  372. except Exception:
  373. done = True
  374. if done:
  375. break
  376. if util._reprcompare is not None:
  377. custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
  378. if custom is not None:
  379. return custom
  380. return expl
  381. unary_map = {
  382. ast.Not: "not %s",
  383. ast.Invert: "~%s",
  384. ast.USub: "-%s",
  385. ast.UAdd: "+%s"
  386. }
  387. binop_map = {
  388. ast.BitOr: "|",
  389. ast.BitXor: "^",
  390. ast.BitAnd: "&",
  391. ast.LShift: "<<",
  392. ast.RShift: ">>",
  393. ast.Add: "+",
  394. ast.Sub: "-",
  395. ast.Mult: "*",
  396. ast.Div: "/",
  397. ast.FloorDiv: "//",
  398. ast.Mod: "%%", # escaped for string formatting
  399. ast.Eq: "==",
  400. ast.NotEq: "!=",
  401. ast.Lt: "<",
  402. ast.LtE: "<=",
  403. ast.Gt: ">",
  404. ast.GtE: ">=",
  405. ast.Pow: "**",
  406. ast.Is: "is",
  407. ast.IsNot: "is not",
  408. ast.In: "in",
  409. ast.NotIn: "not in"
  410. }
  411. # Python 3.5+ compatibility
  412. try:
  413. binop_map[ast.MatMult] = "@"
  414. except AttributeError:
  415. pass
  416. # Python 3.4+ compatibility
  417. if hasattr(ast, "NameConstant"):
  418. _NameConstant = ast.NameConstant
  419. else:
  420. def _NameConstant(c):
  421. return ast.Name(str(c), ast.Load())
  422. def set_location(node, lineno, col_offset):
  423. """Set node location information recursively."""
  424. def _fix(node, lineno, col_offset):
  425. if "lineno" in node._attributes:
  426. node.lineno = lineno
  427. if "col_offset" in node._attributes:
  428. node.col_offset = col_offset
  429. for child in ast.iter_child_nodes(node):
  430. _fix(child, lineno, col_offset)
  431. _fix(node, lineno, col_offset)
  432. return node
  433. class AssertionRewriter(ast.NodeVisitor):
  434. """Assertion rewriting implementation.
  435. The main entrypoint is to call .run() with an ast.Module instance,
  436. this will then find all the assert statements and re-write them to
  437. provide intermediate values and a detailed assertion error. See
  438. http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
  439. for an overview of how this works.
  440. The entry point here is .run() which will iterate over all the
  441. statements in an ast.Module and for each ast.Assert statement it
  442. finds call .visit() with it. Then .visit_Assert() takes over and
  443. is responsible for creating new ast statements to replace the
  444. original assert statement: it re-writes the test of an assertion
  445. to provide intermediate values and replace it with an if statement
  446. which raises an assertion error with a detailed explanation in
  447. case the expression is false.
  448. For this .visit_Assert() uses the visitor pattern to visit all the
  449. AST nodes of the ast.Assert.test field, each visit call returning
  450. an AST node and the corresponding explanation string. During this
  451. state is kept in several instance attributes:
  452. :statements: All the AST statements which will replace the assert
  453. statement.
  454. :variables: This is populated by .variable() with each variable
  455. used by the statements so that they can all be set to None at
  456. the end of the statements.
  457. :variable_counter: Counter to create new unique variables needed
  458. by statements. Variables are created using .variable() and
  459. have the form of "@py_assert0".
  460. :on_failure: The AST statements which will be executed if the
  461. assertion test fails. This is the code which will construct
  462. the failure message and raises the AssertionError.
  463. :explanation_specifiers: A dict filled by .explanation_param()
  464. with %-formatting placeholders and their corresponding
  465. expressions to use in the building of an assertion message.
  466. This is used by .pop_format_context() to build a message.
  467. :stack: A stack of the explanation_specifiers dicts maintained by
  468. .push_format_context() and .pop_format_context() which allows
  469. to build another %-formatted string while already building one.
  470. This state is reset on every new assert statement visited and used
  471. by the other visitors.
  472. """
  473. def run(self, mod):
  474. """Find all assert statements in *mod* and rewrite them."""
  475. if not mod.body:
  476. # Nothing to do.
  477. return
  478. # Insert some special imports at the top of the module but after any
  479. # docstrings and __future__ imports.
  480. aliases = [ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
  481. ast.alias("_pytest.assertion.rewrite", "@pytest_ar")]
  482. expect_docstring = True
  483. pos = 0
  484. lineno = 0
  485. for item in mod.body:
  486. if (expect_docstring and isinstance(item, ast.Expr) and
  487. isinstance(item.value, ast.Str)):
  488. doc = item.value.s
  489. if "PYTEST_DONT_REWRITE" in doc:
  490. # The module has disabled assertion rewriting.
  491. return
  492. lineno += len(doc) - 1
  493. expect_docstring = False
  494. elif (not isinstance(item, ast.ImportFrom) or item.level > 0 or
  495. item.module != "__future__"):
  496. lineno = item.lineno
  497. break
  498. pos += 1
  499. imports = [ast.Import([alias], lineno=lineno, col_offset=0)
  500. for alias in aliases]
  501. mod.body[pos:pos] = imports
  502. # Collect asserts.
  503. nodes = [mod]
  504. while nodes:
  505. node = nodes.pop()
  506. for name, field in ast.iter_fields(node):
  507. if isinstance(field, list):
  508. new = []
  509. for i, child in enumerate(field):
  510. if isinstance(child, ast.Assert):
  511. # Transform assert.
  512. new.extend(self.visit(child))
  513. else:
  514. new.append(child)
  515. if isinstance(child, ast.AST):
  516. nodes.append(child)
  517. setattr(node, name, new)
  518. elif (isinstance(field, ast.AST) and
  519. # Don't recurse into expressions as they can't contain
  520. # asserts.
  521. not isinstance(field, ast.expr)):
  522. nodes.append(field)
  523. def variable(self):
  524. """Get a new variable."""
  525. # Use a character invalid in python identifiers to avoid clashing.
  526. name = "@py_assert" + str(next(self.variable_counter))
  527. self.variables.append(name)
  528. return name
  529. def assign(self, expr):
  530. """Give *expr* a name."""
  531. name = self.variable()
  532. self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
  533. return ast.Name(name, ast.Load())
  534. def display(self, expr):
  535. """Call py.io.saferepr on the expression."""
  536. return self.helper("saferepr", expr)
  537. def helper(self, name, *args):
  538. """Call a helper in this module."""
  539. py_name = ast.Name("@pytest_ar", ast.Load())
  540. attr = ast.Attribute(py_name, "_" + name, ast.Load())
  541. return ast_Call(attr, list(args), [])
  542. def builtin(self, name):
  543. """Return the builtin called *name*."""
  544. builtin_name = ast.Name("@py_builtins", ast.Load())
  545. return ast.Attribute(builtin_name, name, ast.Load())
  546. def explanation_param(self, expr):
  547. """Return a new named %-formatting placeholder for expr.
  548. This creates a %-formatting placeholder for expr in the
  549. current formatting context, e.g. ``%(py0)s``. The placeholder
  550. and expr are placed in the current format context so that it
  551. can be used on the next call to .pop_format_context().
  552. """
  553. specifier = "py" + str(next(self.variable_counter))
  554. self.explanation_specifiers[specifier] = expr
  555. return "%(" + specifier + ")s"
  556. def push_format_context(self):
  557. """Create a new formatting context.
  558. The format context is used for when an explanation wants to
  559. have a variable value formatted in the assertion message. In
  560. this case the value required can be added using
  561. .explanation_param(). Finally .pop_format_context() is used
  562. to format a string of %-formatted values as added by
  563. .explanation_param().
  564. """
  565. self.explanation_specifiers = {}
  566. self.stack.append(self.explanation_specifiers)
  567. def pop_format_context(self, expl_expr):
  568. """Format the %-formatted string with current format context.
  569. The expl_expr should be an ast.Str instance constructed from
  570. the %-placeholders created by .explanation_param(). This will
  571. add the required code to format said string to .on_failure and
  572. return the ast.Name instance of the formatted string.
  573. """
  574. current = self.stack.pop()
  575. if self.stack:
  576. self.explanation_specifiers = self.stack[-1]
  577. keys = [ast.Str(key) for key in current.keys()]
  578. format_dict = ast.Dict(keys, list(current.values()))
  579. form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
  580. name = "@py_format" + str(next(self.variable_counter))
  581. self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
  582. return ast.Name(name, ast.Load())
  583. def generic_visit(self, node):
  584. """Handle expressions we don't have custom code for."""
  585. assert isinstance(node, ast.expr)
  586. res = self.assign(node)
  587. return res, self.explanation_param(self.display(res))
  588. def visit_Assert(self, assert_):
  589. """Return the AST statements to replace the ast.Assert instance.
  590. This re-writes the test of an assertion to provide
  591. intermediate values and replace it with an if statement which
  592. raises an assertion error with a detailed explanation in case
  593. the expression is false.
  594. """
  595. self.statements = []
  596. self.variables = []
  597. self.variable_counter = itertools.count()
  598. self.stack = []
  599. self.on_failure = []
  600. self.push_format_context()
  601. # Rewrite assert into a bunch of statements.
  602. top_condition, explanation = self.visit(assert_.test)
  603. # Create failure message.
  604. body = self.on_failure
  605. negation = ast.UnaryOp(ast.Not(), top_condition)
  606. self.statements.append(ast.If(negation, body, []))
  607. if assert_.msg:
  608. assertmsg = self.helper('format_assertmsg', assert_.msg)
  609. explanation = "\n>assert " + explanation
  610. else:
  611. assertmsg = ast.Str("")
  612. explanation = "assert " + explanation
  613. template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
  614. msg = self.pop_format_context(template)
  615. fmt = self.helper("format_explanation", msg)
  616. err_name = ast.Name("AssertionError", ast.Load())
  617. exc = ast_Call(err_name, [fmt], [])
  618. if sys.version_info[0] >= 3:
  619. raise_ = ast.Raise(exc, None)
  620. else:
  621. raise_ = ast.Raise(exc, None, None)
  622. body.append(raise_)
  623. # Clear temporary variables by setting them to None.
  624. if self.variables:
  625. variables = [ast.Name(name, ast.Store())
  626. for name in self.variables]
  627. clear = ast.Assign(variables, _NameConstant(None))
  628. self.statements.append(clear)
  629. # Fix line numbers.
  630. for stmt in self.statements:
  631. set_location(stmt, assert_.lineno, assert_.col_offset)
  632. return self.statements
  633. def visit_Name(self, name):
  634. # Display the repr of the name if it's a local variable or
  635. # _should_repr_global_name() thinks it's acceptable.
  636. locs = ast_Call(self.builtin("locals"), [], [])
  637. inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
  638. dorepr = self.helper("should_repr_global_name", name)
  639. test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
  640. expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
  641. return name, self.explanation_param(expr)
  642. def visit_BoolOp(self, boolop):
  643. res_var = self.variable()
  644. expl_list = self.assign(ast.List([], ast.Load()))
  645. app = ast.Attribute(expl_list, "append", ast.Load())
  646. is_or = int(isinstance(boolop.op, ast.Or))
  647. body = save = self.statements
  648. fail_save = self.on_failure
  649. levels = len(boolop.values) - 1
  650. self.push_format_context()
  651. # Process each operand, short-circuting if needed.
  652. for i, v in enumerate(boolop.values):
  653. if i:
  654. fail_inner = []
  655. # cond is set in a prior loop iteration below
  656. self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
  657. self.on_failure = fail_inner
  658. self.push_format_context()
  659. res, expl = self.visit(v)
  660. body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
  661. expl_format = self.pop_format_context(ast.Str(expl))
  662. call = ast_Call(app, [expl_format], [])
  663. self.on_failure.append(ast.Expr(call))
  664. if i < levels:
  665. cond = res
  666. if is_or:
  667. cond = ast.UnaryOp(ast.Not(), cond)
  668. inner = []
  669. self.statements.append(ast.If(cond, inner, []))
  670. self.statements = body = inner
  671. self.statements = save
  672. self.on_failure = fail_save
  673. expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
  674. expl = self.pop_format_context(expl_template)
  675. return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
  676. def visit_UnaryOp(self, unary):
  677. pattern = unary_map[unary.op.__class__]
  678. operand_res, operand_expl = self.visit(unary.operand)
  679. res = self.assign(ast.UnaryOp(unary.op, operand_res))
  680. return res, pattern % (operand_expl,)
  681. def visit_BinOp(self, binop):
  682. symbol = binop_map[binop.op.__class__]
  683. left_expr, left_expl = self.visit(binop.left)
  684. right_expr, right_expl = self.visit(binop.right)
  685. explanation = "(%s %s %s)" % (left_expl, symbol, right_expl)
  686. res = self.assign(ast.BinOp(left_expr, binop.op, right_expr))
  687. return res, explanation
  688. def visit_Call_35(self, call):
  689. """
  690. visit `ast.Call` nodes on Python3.5 and after
  691. """
  692. new_func, func_expl = self.visit(call.func)
  693. arg_expls = []
  694. new_args = []
  695. new_kwargs = []
  696. for arg in call.args:
  697. res, expl = self.visit(arg)
  698. arg_expls.append(expl)
  699. new_args.append(res)
  700. for keyword in call.keywords:
  701. res, expl = self.visit(keyword.value)
  702. new_kwargs.append(ast.keyword(keyword.arg, res))
  703. if keyword.arg:
  704. arg_expls.append(keyword.arg + "=" + expl)
  705. else: ## **args have `arg` keywords with an .arg of None
  706. arg_expls.append("**" + expl)
  707. expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
  708. new_call = ast.Call(new_func, new_args, new_kwargs)
  709. res = self.assign(new_call)
  710. res_expl = self.explanation_param(self.display(res))
  711. outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
  712. return res, outer_expl
  713. def visit_Starred(self, starred):
  714. # From Python 3.5, a Starred node can appear in a function call
  715. res, expl = self.visit(starred.value)
  716. return starred, '*' + expl
  717. def visit_Call_legacy(self, call):
  718. """
  719. visit `ast.Call nodes on 3.4 and below`
  720. """
  721. new_func, func_expl = self.visit(call.func)
  722. arg_expls = []
  723. new_args = []
  724. new_kwargs = []
  725. new_star = new_kwarg = None
  726. for arg in call.args:
  727. res, expl = self.visit(arg)
  728. new_args.append(res)
  729. arg_expls.append(expl)
  730. for keyword in call.keywords:
  731. res, expl = self.visit(keyword.value)
  732. new_kwargs.append(ast.keyword(keyword.arg, res))
  733. arg_expls.append(keyword.arg + "=" + expl)
  734. if call.starargs:
  735. new_star, expl = self.visit(call.starargs)
  736. arg_expls.append("*" + expl)
  737. if call.kwargs:
  738. new_kwarg, expl = self.visit(call.kwargs)
  739. arg_expls.append("**" + expl)
  740. expl = "%s(%s)" % (func_expl, ', '.join(arg_expls))
  741. new_call = ast.Call(new_func, new_args, new_kwargs,
  742. new_star, new_kwarg)
  743. res = self.assign(new_call)
  744. res_expl = self.explanation_param(self.display(res))
  745. outer_expl = "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
  746. return res, outer_expl
  747. # ast.Call signature changed on 3.5,
  748. # conditionally change which methods is named
  749. # visit_Call depending on Python version
  750. if sys.version_info >= (3, 5):
  751. visit_Call = visit_Call_35
  752. else:
  753. visit_Call = visit_Call_legacy
  754. def visit_Attribute(self, attr):
  755. if not isinstance(attr.ctx, ast.Load):
  756. return self.generic_visit(attr)
  757. value, value_expl = self.visit(attr.value)
  758. res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
  759. res_expl = self.explanation_param(self.display(res))
  760. pat = "%s\n{%s = %s.%s\n}"
  761. expl = pat % (res_expl, res_expl, value_expl, attr.attr)
  762. return res, expl
  763. def visit_Compare(self, comp):
  764. self.push_format_context()
  765. left_res, left_expl = self.visit(comp.left)
  766. res_variables = [self.variable() for i in range(len(comp.ops))]
  767. load_names = [ast.Name(v, ast.Load()) for v in res_variables]
  768. store_names = [ast.Name(v, ast.Store()) for v in res_variables]
  769. it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
  770. expls = []
  771. syms = []
  772. results = [left_res]
  773. for i, op, next_operand in it:
  774. next_res, next_expl = self.visit(next_operand)
  775. results.append(next_res)
  776. sym = binop_map[op.__class__]
  777. syms.append(ast.Str(sym))
  778. expl = "%s %s %s" % (left_expl, sym, next_expl)
  779. expls.append(ast.Str(expl))
  780. res_expr = ast.Compare(left_res, [op], [next_res])
  781. self.statements.append(ast.Assign([store_names[i]], res_expr))
  782. left_res, left_expl = next_res, next_expl
  783. # Use pytest.assertion.util._reprcompare if that's available.
  784. expl_call = self.helper("call_reprcompare",
  785. ast.Tuple(syms, ast.Load()),
  786. ast.Tuple(load_names, ast.Load()),
  787. ast.Tuple(expls, ast.Load()),
  788. ast.Tuple(results, ast.Load()))
  789. if len(comp.ops) > 1:
  790. res = ast.BoolOp(ast.And(), load_names)
  791. else:
  792. res = load_names[0]
  793. return res, self.explanation_param(self.pop_format_context(expl_call))