test_concurrency.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # Licensed under the Apache License: http://www.apache.org/licenses/LICENSE-2.0
  2. # For details: https://bitbucket.org/ned/coveragepy/src/default/NOTICE.txt
  3. """Tests for concurrency libraries."""
  4. import multiprocessing
  5. import threading
  6. import coverage
  7. from coverage import env
  8. from coverage.files import abs_file
  9. from tests.coveragetest import CoverageTest
  10. # These libraries aren't always available, we'll skip tests if they aren't.
  11. try:
  12. import eventlet
  13. except ImportError:
  14. eventlet = None
  15. try:
  16. import gevent
  17. except ImportError:
  18. gevent = None
  19. import greenlet
  20. def measurable_line(l):
  21. """Is this a line of code coverage will measure?
  22. Not blank, not a comment, and not "else"
  23. """
  24. l = l.strip()
  25. if not l:
  26. return False
  27. if l.startswith('#'):
  28. return False
  29. if l.startswith('else:'):
  30. return False
  31. return True
  32. def line_count(s):
  33. """How many measurable lines are in `s`?"""
  34. return len(list(filter(measurable_line, s.splitlines())))
  35. def print_simple_annotation(code, linenos):
  36. """Print the lines in `code` with X for each line number in `linenos`."""
  37. for lineno, line in enumerate(code.splitlines(), start=1):
  38. print(" {0} {1}".format("X" if lineno in linenos else " ", line))
  39. class LineCountTest(CoverageTest):
  40. """Test the helpers here."""
  41. run_in_temp_dir = False
  42. def test_line_count(self):
  43. CODE = """
  44. # Hey there!
  45. x = 1
  46. if x:
  47. print("hello")
  48. else:
  49. print("bye")
  50. print("done")
  51. """
  52. self.assertEqual(line_count(CODE), 5)
# The code common to all the concurrency models.
#
# SUM_RANGE_Q defines three pieces:
#   * a Producer thread that puts 0..limit-1 on a queue, then a None sentinel;
#   * a Consumer thread that adds up queue items until it sees None, then
#     puts the total on a result queue;
#   * sum_range(limit), which runs the pair and returns that total.
# It relies on `threading` and `queue` being defined by an import prefix
# (THREAD, EVENTLET or GEVENT below).
# The tests join fragments with `+`, so all fragments use the same four-space
# indentation inside their strings. This assumes make_file dedents the joined
# text; confirm against CoverageTest.make_file.
SUM_RANGE_Q = """
    # Above this will be imports defining queue and threading.
    class Producer(threading.Thread):
        def __init__(self, limit, q):
            threading.Thread.__init__(self)
            self.limit = limit
            self.q = q
        def run(self):
            for i in range(self.limit):
                self.q.put(i)
            self.q.put(None)
    class Consumer(threading.Thread):
        def __init__(self, q, qresult):
            threading.Thread.__init__(self)
            self.q = q
            self.qresult = qresult
        def run(self):
            sum = 0
            while True:
                i = self.q.get()
                if i is None:
                    break
                sum += i
            self.qresult.put(sum)
    def sum_range(limit):
        q = queue.Queue()
        qresult = queue.Queue()
        c = Consumer(q, qresult)
        p = Producer(limit, q)
        c.start()
        p.start()
        p.join()
        c.join()
        return qresult.get()
    # Below this will be something using sum_range.
    """

# Prints sum_range's result. The tests fill in `{QLIMIT}` with str.format.
PRINT_SUM_RANGE = """
    print(sum_range({QLIMIT}))
    """
# Import the things to use threads.
# Each prefix below defines the `threading` and `queue` names that
# SUM_RANGE_Q uses, each backed by a different concurrency library.
# Python 2 spells the standard queue module "Queue".
if env.PY2:
    THREAD = """
    import threading
    import Queue as queue
    """
else:
    THREAD = """
    import threading
    import queue
    """

# Import the things to use eventlet: its green threading module and its queue.
EVENTLET = """
    import eventlet.green.threading as threading
    import eventlet.queue as queue
    """

# Import the things to use gevent. monkey.patch_thread() runs before
# `import threading`, presumably so that threading picks up gevent's
# patched primitives.
GEVENT = """
    from gevent import monkey
    monkey.patch_thread()
    import threading
    import gevent.queue as queue
    """
# Uncomplicated code that doesn't use any of the concurrency stuff, to test
# the simple case under each of the regimes.
# It prints the total of range(QLIMIT), the same value the sum_range programs
# print, so the default expected output in ConcurrencyTest.try_some_code
# applies to it too.
SIMPLE = """
    total = 0
    for i in range({QLIMIT}):
        total += i
    print(total)
    """
  124. def cant_trace_msg(concurrency, the_module):
  125. """What might coverage.py say about a concurrency setting and imported module?"""
  126. # In the concurrency choices, "multiprocessing" doesn't count, so remove it.
  127. if "multiprocessing" in concurrency:
  128. parts = concurrency.split(",")
  129. parts.remove("multiprocessing")
  130. concurrency = ",".join(parts)
  131. if the_module is None:
  132. # We don't even have the underlying module installed, we expect
  133. # coverage to alert us to this fact.
  134. expected_out = (
  135. "Couldn't trace with concurrency=%s, "
  136. "the module isn't installed.\n" % concurrency
  137. )
  138. elif env.C_TRACER or concurrency == "thread" or concurrency == "":
  139. expected_out = None
  140. else:
  141. expected_out = (
  142. "Can't support concurrency=%s with PyTracer, "
  143. "only threads are supported\n" % concurrency
  144. )
  145. return expected_out
  146. class ConcurrencyTest(CoverageTest):
  147. """Tests of the concurrency support in coverage.py."""
  148. QLIMIT = 1000
  149. def try_some_code(self, code, concurrency, the_module, expected_out=None):
  150. """Run some concurrency testing code and see that it was all covered.
  151. `code` is the Python code to execute. `concurrency` is the name of
  152. the concurrency regime to test it under. `the_module` is the imported
  153. module that must be available for this to work at all. `expected_out`
  154. is the text we expect the code to produce.
  155. """
  156. self.make_file("try_it.py", code)
  157. cmd = "coverage run --concurrency=%s try_it.py" % concurrency
  158. out = self.run_command(cmd)
  159. expected_cant_trace = cant_trace_msg(concurrency, the_module)
  160. if expected_cant_trace is not None:
  161. self.assertEqual(out, expected_cant_trace)
  162. else:
  163. # We can fully measure the code if we are using the C tracer, which
  164. # can support all the concurrency, or if we are using threads.
  165. if expected_out is None:
  166. expected_out = "%d\n" % (sum(range(self.QLIMIT)))
  167. print(code)
  168. self.assertEqual(out, expected_out)
  169. # Read the coverage file and see that try_it.py has all its lines
  170. # executed.
  171. data = coverage.CoverageData()
  172. data.read_file(".coverage")
  173. # If the test fails, it's helpful to see this info:
  174. fname = abs_file("try_it.py")
  175. linenos = data.lines(fname)
  176. print("{0}: {1}".format(len(linenos), linenos))
  177. print_simple_annotation(code, linenos)
  178. lines = line_count(code)
  179. self.assertEqual(data.line_counts()['try_it.py'], lines)
  180. def test_threads(self):
  181. code = (THREAD + SUM_RANGE_Q + PRINT_SUM_RANGE).format(QLIMIT=self.QLIMIT)
  182. self.try_some_code(code, "thread", threading)
  183. def test_threads_simple_code(self):
  184. code = SIMPLE.format(QLIMIT=self.QLIMIT)
  185. self.try_some_code(code, "thread", threading)
  186. def test_eventlet(self):
  187. code = (EVENTLET + SUM_RANGE_Q + PRINT_SUM_RANGE).format(QLIMIT=self.QLIMIT)
  188. self.try_some_code(code, "eventlet", eventlet)
  189. def test_eventlet_simple_code(self):
  190. code = SIMPLE.format(QLIMIT=self.QLIMIT)
  191. self.try_some_code(code, "eventlet", eventlet)
  192. def test_gevent(self):
  193. code = (GEVENT + SUM_RANGE_Q + PRINT_SUM_RANGE).format(QLIMIT=self.QLIMIT)
  194. self.try_some_code(code, "gevent", gevent)
  195. def test_gevent_simple_code(self):
  196. code = SIMPLE.format(QLIMIT=self.QLIMIT)
  197. self.try_some_code(code, "gevent", gevent)
  198. def test_greenlet(self):
  199. GREENLET = """\
  200. from greenlet import greenlet
  201. def test1(x, y):
  202. z = gr2.switch(x+y)
  203. print(z)
  204. def test2(u):
  205. print(u)
  206. gr1.switch(42)
  207. gr1 = greenlet(test1)
  208. gr2 = greenlet(test2)
  209. gr1.switch("hello", " world")
  210. """
  211. self.try_some_code(GREENLET, "greenlet", greenlet, "hello world\n42\n")
  212. def test_greenlet_simple_code(self):
  213. code = SIMPLE.format(QLIMIT=self.QLIMIT)
  214. self.try_some_code(code, "greenlet", greenlet)
  215. def test_bug_330(self):
  216. BUG_330 = """\
  217. from weakref import WeakKeyDictionary
  218. import eventlet
  219. def do():
  220. eventlet.sleep(.01)
  221. gts = WeakKeyDictionary()
  222. for _ in range(100):
  223. gts[eventlet.spawn(do)] = True
  224. eventlet.sleep(.005)
  225. eventlet.sleep(.1)
  226. print(len(gts))
  227. """
  228. self.try_some_code(BUG_330, "eventlet", eventlet, "0\n")
# Definitions of work(x). One of these is placed before MULTI_CODE, whose
# pool workers call work().
#
# SQUARE_OR_CUBE_WORK returns x*x for odd x and x*x*x for even x, so which
# lines run depends on the inputs each subprocess receives.
SQUARE_OR_CUBE_WORK = """
    def work(x):
        # Use different lines in different subprocesses.
        if x % 2:
            y = x*x
        else:
            y = x*x*x
        return y
    """

# SUM_RANGE_WORK returns sum_range((x+1)*100), so the same program must also
# contain SUM_RANGE_Q and an import prefix defining `threading` and `queue`.
SUM_RANGE_WORK = """
    def work(x):
        return sum_range((x+1)*100)
    """
# Driver program, run as multi.py. It expects work() to be defined by a
# fragment placed before it. It does the following:
#   * selects the multiprocessing start method if one is given on the
#     command line;
#   * maps process_worker_main over range({UPTO}) with a pool of {NPROCS}
#     processes;
#   * prints the number of distinct worker pids and the total of the results.
# `{NPROCS}` and `{UPTO}` are filled in with str.format by the tests.
MULTI_CODE = """
    # Above this will be a defintion of work().
    import multiprocessing
    import os
    import time
    import sys
    def process_worker_main(args):
        # Need to pause, or the tasks go too quick, and some processes
        # in the pool don't get any work, and then don't record data.
        time.sleep(0.02)
        ret = work(*args)
        return os.getpid(), ret
    if __name__ == "__main__": # pragma: no branch
        # This if is on a single line so we can get 100% coverage
        # even if we have no arguments.
        if len(sys.argv) > 1: multiprocessing.set_start_method(sys.argv[1])
        pool = multiprocessing.Pool({NPROCS})
        inputs = [(x,) for x in range({UPTO})]
        outputs = pool.imap_unordered(process_worker_main, inputs)
        pids = set()
        total = 0
        for pid, sq in outputs:
            pids.add(pid)
            total += sq
        print("%d pids, total = %d" % (len(pids), total))
        pool.close()
        pool.join()
    """
  270. class MultiprocessingTest(CoverageTest):
  271. """Test support of the multiprocessing module."""
  272. def try_multiprocessing_code(
  273. self, code, expected_out, the_module, concurrency="multiprocessing"
  274. ):
  275. """Run code using multiprocessing, it should produce `expected_out`."""
  276. self.make_file("multi.py", code)
  277. self.make_file(".coveragerc", """\
  278. [run]
  279. concurrency = %s
  280. """ % concurrency)
  281. if env.PYVERSION >= (3, 4):
  282. start_methods = ['fork', 'spawn']
  283. else:
  284. start_methods = ['']
  285. for start_method in start_methods:
  286. if start_method and start_method not in multiprocessing.get_all_start_methods():
  287. continue
  288. out = self.run_command("coverage run multi.py %s" % (start_method,))
  289. expected_cant_trace = cant_trace_msg(concurrency, the_module)
  290. if expected_cant_trace is not None:
  291. self.assertEqual(out, expected_cant_trace)
  292. else:
  293. self.assertEqual(out.rstrip(), expected_out)
  294. out = self.run_command("coverage combine")
  295. self.assertEqual(out, "")
  296. out = self.run_command("coverage report -m")
  297. last_line = self.squeezed_lines(out)[-1]
  298. self.assertRegex(last_line, r"multi.py \d+ 0 100%")
  299. def test_multiprocessing(self):
  300. nprocs = 3
  301. upto = 30
  302. code = (SQUARE_OR_CUBE_WORK + MULTI_CODE).format(NPROCS=nprocs, UPTO=upto)
  303. total = sum(x*x if x%2 else x*x*x for x in range(upto))
  304. expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
  305. self.try_multiprocessing_code(code, expected_out, threading)
  306. def test_multiprocessing_and_gevent(self):
  307. nprocs = 3
  308. upto = 30
  309. code = (
  310. SUM_RANGE_WORK + EVENTLET + SUM_RANGE_Q + MULTI_CODE
  311. ).format(NPROCS=nprocs, UPTO=upto)
  312. total = sum(sum(range((x + 1) * 100)) for x in range(upto))
  313. expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
  314. self.try_multiprocessing_code(
  315. code, expected_out, eventlet, concurrency="multiprocessing,eventlet"
  316. )
  317. def try_multiprocessing_code_with_branching(self, code, expected_out):
  318. """Run code using multiprocessing, it should produce `expected_out`."""
  319. self.make_file("multi.py", code)
  320. self.make_file("multi.rc", """\
  321. [run]
  322. concurrency = multiprocessing
  323. branch = True
  324. """)
  325. if env.PYVERSION >= (3, 4):
  326. start_methods = ['fork', 'spawn']
  327. else:
  328. start_methods = ['']
  329. for start_method in start_methods:
  330. if start_method and start_method not in multiprocessing.get_all_start_methods():
  331. continue
  332. out = self.run_command("coverage run --rcfile=multi.rc multi.py %s" % (start_method,))
  333. self.assertEqual(out.rstrip(), expected_out)
  334. out = self.run_command("coverage combine")
  335. self.assertEqual(out, "")
  336. out = self.run_command("coverage report -m")
  337. last_line = self.squeezed_lines(out)[-1]
  338. self.assertRegex(last_line, r"multi.py \d+ 0 \d+ 0 100%")
  339. def test_multiprocessing_with_branching(self):
  340. nprocs = 3
  341. upto = 30
  342. code = (SQUARE_OR_CUBE_WORK + MULTI_CODE).format(NPROCS=nprocs, UPTO=upto)
  343. total = sum(x*x if x%2 else x*x*x for x in range(upto))
  344. expected_out = "{nprocs} pids, total = {total}".format(nprocs=nprocs, total=total)
  345. self.try_multiprocessing_code_with_branching(code, expected_out)