luaunit.lua 127 KB


  1. --[[
  2. luaunit.lua
  3. Description: A unit testing framework
  4. Homepage: https://github.com/bluebird75/luaunit
  5. Development by Philippe Fremy <phil@freehackers.org>
  6. Based on initial work of Ryu, Gwang (http://www.gpgstudy.com/gpgiki/LuaUnit)
  7. License: BSD License, see LICENSE.txt
  8. ]]--
  9. require("math")
  10. local M={}
  11. -- private exported functions (for testing)
  12. M.private = {}
  13. M.VERSION='3.4'
  14. M._VERSION=M.VERSION -- For LuaUnit v2 compatibility
  15. -- a version which distinguish between regular Lua and LuaJit
  16. M._LUAVERSION = (jit and jit.version) or _VERSION
  17. --[[ Some people like assertEquals( actual, expected ) and some people prefer
  18. assertEquals( expected, actual ).
  19. ]]--
  20. M.ORDER_ACTUAL_EXPECTED = true
  21. M.PRINT_TABLE_REF_IN_ERROR_MSG = false
  22. M.LINE_LENGTH = 80
  23. M.TABLE_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items
  24. M.LIST_DIFF_ANALYSIS_THRESHOLD = 10 -- display deep analysis for more than 10 items
  25. -- this setting allow to remove entries from the stack-trace, for
  26. -- example to hide a call to a framework which would be calling luaunit
  27. M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE = 0
  28. --[[ EPS is meant to help with Lua's floating point math in simple corner
  29. cases like almostEquals(1.1-0.1, 1), which may not work as-is (e.g. on numbers
  30. with rational binary representation) if the user doesn't provide some explicit
  31. error margin.
  32. The default margin used by almostEquals() in such cases is EPS; and since
  33. Lua may be compiled with different numeric precisions (single vs. double), we
  34. try to select a useful default for it dynamically. Note: If the initial value
  35. is not acceptable, it can be changed by the user to better suit specific needs.
  36. See also: https://en.wikipedia.org/wiki/Machine_epsilon
  37. ]]
  38. M.EPS = 2^-52 -- = machine epsilon for "double", ~2.22E-16
  39. if math.abs(1.1 - 1 - 0.1) > M.EPS then
  40. -- rounding error is above EPS, assume single precision
  41. M.EPS = 2^-23 -- = machine epsilon for "float", ~1.19E-07
  42. end
  43. -- set this to false to debug luaunit
  44. local STRIP_LUAUNIT_FROM_STACKTRACE = true
  45. M.VERBOSITY_DEFAULT = 10
  46. M.VERBOSITY_LOW = 1
  47. M.VERBOSITY_QUIET = 0
  48. M.VERBOSITY_VERBOSE = 20
  49. M.DEFAULT_DEEP_ANALYSIS = nil
  50. M.FORCE_DEEP_ANALYSIS = true
  51. M.DISABLE_DEEP_ANALYSIS = false
  52. -- set EXPORT_ASSERT_TO_GLOBALS to have all asserts visible as global values
  53. -- EXPORT_ASSERT_TO_GLOBALS = true
  54. -- we need to keep a copy of the script args before it is overriden
  55. local cmdline_argv = rawget(_G, "arg")
  56. M.FAILURE_PREFIX = 'LuaUnit test FAILURE: ' -- prefix string for failed tests
  57. M.SUCCESS_PREFIX = 'LuaUnit test SUCCESS: ' -- prefix string for successful tests finished early
  58. M.SKIP_PREFIX = 'LuaUnit test SKIP: ' -- prefix string for skipped tests
  59. M.USAGE=[[Usage: lua <your_test_suite.lua> [options] [testname1 [testname2] ... ]
  60. Options:
  61. -h, --help: Print this help
  62. --version: Print version information
  63. -v, --verbose: Increase verbosity
  64. -q, --quiet: Set verbosity to minimum
  65. -e, --error: Stop on first error
  66. -f, --failure: Stop on first failure or error
  67. -s, --shuffle: Shuffle tests before running them
  68. -o, --output OUTPUT: Set output type to OUTPUT
  69. Possible values: text, tap, junit, nil
  70. -n, --name NAME: For junit only, mandatory name of xml file
  71. -r, --repeat NUM: Execute all tests NUM times, e.g. to trig the JIT
  72. -p, --pattern PATTERN: Execute all test names matching the Lua PATTERN
  73. May be repeated to include several patterns
  74. Make sure you escape magic chars like +? with %
  75. -x, --exclude PATTERN: Exclude all test names matching the Lua PATTERN
  76. May be repeated to exclude several patterns
  77. Make sure you escape magic chars like +? with %
  78. testname1, testname2, ... : tests to run in the form of testFunction,
  79. TestClass or TestClass.testMethod
  80. You may also control LuaUnit options with the following environment variables:
  81. * LUAUNIT_OUTPUT: same as --output
  82. * LUAUNIT_JUNIT_FNAME: same as --name ]]
  83. ----------------------------------------------------------------
  84. --
  85. -- general utility functions
  86. --
  87. ----------------------------------------------------------------
  88. --[[ Note on catching exit
I have seen the case where, while running a big suite of test cases, one of them
performed an os.exit(0), making the outside world think that the full test suite
had been executed successfully.
This is an attempt to mitigate this problem: we override os.exit() to NOT let a test
exit the framework while we are running. When we are not running, it behaves normally.
  94. ]]
  95. M.oldOsExit = os.exit
  96. os.exit = function(...)
  97. if M.LuaUnit and #M.LuaUnit.instances ~= 0 then
  98. local msg = [[You are trying to exit but there is still a running instance of LuaUnit.
  99. LuaUnit expects to run until the end before exiting with a complete status of successful/failed tests.
  100. To force exit LuaUnit while running, please call before os.exit (assuming lu is the luaunit module loaded):
  101. lu.unregisterCurrentSuite()
  102. ]]
  103. M.private.error_fmt(2, msg)
  104. end
  105. M.oldOsExit(...)
  106. end
  107. local function pcall_or_abort(func, ...)
  108. -- unpack is a global function for Lua 5.1, otherwise use table.unpack
  109. local unpack = rawget(_G, "unpack") or table.unpack
  110. local result = {pcall(func, ...)}
  111. if not result[1] then
  112. -- an error occurred
  113. print(result[2]) -- error message
  114. print()
  115. print(M.USAGE)
  116. os.exit(-1)
  117. end
  118. return unpack(result, 2)
  119. end
  120. local crossTypeOrdering = {
  121. number = 1, boolean = 2, string = 3, table = 4, other = 5
  122. }
  123. local crossTypeComparison = {
  124. number = function(a, b) return a < b end,
  125. string = function(a, b) return a < b end,
  126. other = function(a, b) return tostring(a) < tostring(b) end,
  127. }
  128. local function crossTypeSort(a, b)
  129. local type_a, type_b = type(a), type(b)
  130. if type_a == type_b then
  131. local func = crossTypeComparison[type_a] or crossTypeComparison.other
  132. return func(a, b)
  133. end
  134. type_a = crossTypeOrdering[type_a] or crossTypeOrdering.other
  135. type_b = crossTypeOrdering[type_b] or crossTypeOrdering.other
  136. return type_a < type_b
  137. end
  138. local function __genSortedIndex( t )
  139. -- Returns a sequence consisting of t's keys, sorted.
  140. local sortedIndex = {}
  141. for key,_ in pairs(t) do
  142. table.insert(sortedIndex, key)
  143. end
  144. table.sort(sortedIndex, crossTypeSort)
  145. return sortedIndex
  146. end
  147. M.private.__genSortedIndex = __genSortedIndex
  148. local function sortedNext(state, control)
  149. -- Equivalent of the next() function of table iteration, but returns the
  150. -- keys in sorted order (see __genSortedIndex and crossTypeSort).
  151. -- The state is a temporary variable during iteration and contains the
  152. -- sorted key table (state.sortedIdx). It also stores the last index (into
  153. -- the keys) used by the iteration, to find the next one quickly.
  154. local key
  155. --print("sortedNext: control = "..tostring(control) )
  156. if control == nil then
  157. -- start of iteration
  158. state.count = #state.sortedIdx
  159. state.lastIdx = 1
  160. key = state.sortedIdx[1]
  161. return key, state.t[key]
  162. end
  163. -- normally, we expect the control variable to match the last key used
  164. if control ~= state.sortedIdx[state.lastIdx] then
  165. -- strange, we have to find the next value by ourselves
  166. -- the key table is sorted in crossTypeSort() order! -> use bisection
  167. local lower, upper = 1, state.count
  168. repeat
  169. state.lastIdx = math.modf((lower + upper) / 2)
  170. key = state.sortedIdx[state.lastIdx]
  171. if key == control then
  172. break -- key found (and thus prev index)
  173. end
  174. if crossTypeSort(key, control) then
  175. -- key < control, continue search "right" (towards upper bound)
  176. lower = state.lastIdx + 1
  177. else
  178. -- key > control, continue search "left" (towards lower bound)
  179. upper = state.lastIdx - 1
  180. end
  181. until lower > upper
  182. if lower > upper then -- only true if the key wasn't found, ...
  183. state.lastIdx = state.count -- ... so ensure no match in code below
  184. end
  185. end
  186. -- proceed by retrieving the next value (or nil) from the sorted keys
  187. state.lastIdx = state.lastIdx + 1
  188. key = state.sortedIdx[state.lastIdx]
  189. if key then
  190. return key, state.t[key]
  191. end
  192. -- getting here means returning `nil`, which will end the iteration
  193. end
  194. local function sortedPairs(tbl)
  195. -- Equivalent of the pairs() function on tables. Allows to iterate in
  196. -- sorted order. As required by "generic for" loops, this will return the
  197. -- iterator (function), an "invariant state", and the initial control value.
  198. -- (see http://www.lua.org/pil/7.2.html)
  199. return sortedNext, {t = tbl, sortedIdx = __genSortedIndex(tbl)}, nil
  200. end
  201. M.private.sortedPairs = sortedPairs
  202. -- seed the random with a strongly varying seed
  203. math.randomseed(math.floor(os.clock()*1E11))
  204. local function randomizeTable( t )
  205. -- randomize the item orders of the table t
  206. for i = #t, 2, -1 do
  207. local j = math.random(i)
  208. if i ~= j then
  209. t[i], t[j] = t[j], t[i]
  210. end
  211. end
  212. end
  213. M.private.randomizeTable = randomizeTable
  214. local function strsplit(delimiter, text)
  215. -- Split text into a list consisting of the strings in text, separated
  216. -- by strings matching delimiter (which may _NOT_ be a pattern).
  217. -- Example: strsplit(", ", "Anna, Bob, Charlie, Dolores")
  218. if delimiter == "" or delimiter == nil then -- this would result in endless loops
  219. error("delimiter is nil or empty string!")
  220. end
  221. if text == nil then
  222. return nil
  223. end
  224. local list, pos, first, last = {}, 1
  225. while true do
  226. first, last = text:find(delimiter, pos, true)
  227. if first then -- found?
  228. table.insert(list, text:sub(pos, first - 1))
  229. pos = last + 1
  230. else
  231. table.insert(list, text:sub(pos))
  232. break
  233. end
  234. end
  235. return list
  236. end
  237. M.private.strsplit = strsplit
  238. local function hasNewLine( s )
  239. -- return true if s has a newline
  240. return (string.find(s, '\n', 1, true) ~= nil)
  241. end
  242. M.private.hasNewLine = hasNewLine
  243. local function prefixString( prefix, s )
  244. -- Prefix all the lines of s with prefix
  245. return prefix .. string.gsub(s, '\n', '\n' .. prefix)
  246. end
  247. M.private.prefixString = prefixString
  248. local function strMatch(s, pattern, start, final )
  249. -- return true if s matches completely the pattern from index start to index end
  250. -- return false in every other cases
  251. -- if start is nil, matches from the beginning of the string
  252. -- if final is nil, matches to the end of the string
  253. start = start or 1
  254. final = final or string.len(s)
  255. local foundStart, foundEnd = string.find(s, pattern, start, false)
  256. return foundStart == start and foundEnd == final
  257. end
  258. M.private.strMatch = strMatch
  259. local function patternFilter(patterns, expr)
  260. -- Run `expr` through the inclusion and exclusion rules defined in patterns
  261. -- and return true if expr shall be included, false for excluded.
  262. -- Inclusion pattern are defined as normal patterns, exclusions
  263. -- patterns start with `!` and are followed by a normal pattern
  264. -- result: nil = UNKNOWN (not matched yet), true = ACCEPT, false = REJECT
  265. -- default: true if no explicit "include" is found, set to false otherwise
  266. local default, result = true, nil
  267. if patterns ~= nil then
  268. for _, pattern in ipairs(patterns) do
  269. local exclude = pattern:sub(1,1) == '!'
  270. if exclude then
  271. pattern = pattern:sub(2)
  272. else
  273. -- at least one include pattern specified, a match is required
  274. default = false
  275. end
  276. -- print('pattern: ',pattern)
  277. -- print('exclude: ',exclude)
  278. -- print('default: ',default)
  279. if string.find(expr, pattern) then
  280. -- set result to false when excluding, true otherwise
  281. result = not exclude
  282. end
  283. end
  284. end
  285. if result ~= nil then
  286. return result
  287. end
  288. return default
  289. end
  290. M.private.patternFilter = patternFilter
  291. local function xmlEscape( s )
  292. -- Return s escaped for XML attributes
  293. -- escapes table:
  294. -- " &quot;
  295. -- ' &apos;
  296. -- < &lt;
  297. -- > &gt;
  298. -- & &amp;
  299. return string.gsub( s, '.', {
  300. ['&'] = "&amp;",
  301. ['"'] = "&quot;",
  302. ["'"] = "&apos;",
  303. ['<'] = "&lt;",
  304. ['>'] = "&gt;",
  305. } )
  306. end
  307. M.private.xmlEscape = xmlEscape
  308. local function xmlCDataEscape( s )
  309. -- Return s escaped for CData section, escapes: "]]>"
  310. return string.gsub( s, ']]>', ']]&gt;' )
  311. end
  312. M.private.xmlCDataEscape = xmlCDataEscape
  313. local function lstrip( s )
  314. --[[Return s with all leading white spaces and tabs removed]]
  315. local idx = 0
  316. while idx < s:len() do
  317. idx = idx + 1
  318. local c = s:sub(idx,idx)
  319. if c ~= ' ' and c ~= '\t' then
  320. break
  321. end
  322. end
  323. return s:sub(idx)
  324. end
  325. M.private.lstrip = lstrip
  326. local function extractFileLineInfo( s )
  327. --[[ From a string in the form "(leading spaces) dir1/dir2\dir3\file.lua:linenb: msg"
  328. Return the "file.lua:linenb" information
  329. ]]
  330. local s2 = lstrip(s)
  331. local firstColon = s2:find(':', 1, true)
  332. if firstColon == nil then
  333. -- string is not in the format file:line:
  334. return s
  335. end
  336. local secondColon = s2:find(':', firstColon+1, true)
  337. if secondColon == nil then
  338. -- string is not in the format file:line:
  339. return s
  340. end
  341. return s2:sub(1, secondColon-1)
  342. end
  343. M.private.extractFileLineInfo = extractFileLineInfo
  344. local function stripLuaunitTrace2( stackTrace, errMsg )
  345. --[[
  346. -- Example of a traceback:
  347. <<stack traceback:
  348. example_with_luaunit.lua:130: in function 'test2_withFailure'
  349. ./luaunit.lua:1449: in function <./luaunit.lua:1449>
  350. [C]: in function 'xpcall'
  351. ./luaunit.lua:1449: in function 'protectedCall'
  352. ./luaunit.lua:1508: in function 'execOneFunction'
  353. ./luaunit.lua:1596: in function 'runSuiteByInstances'
  354. ./luaunit.lua:1660: in function 'runSuiteByNames'
  355. ./luaunit.lua:1736: in function 'runSuite'
  356. example_with_luaunit.lua:140: in main chunk
  357. [C]: in ?>>
  358. error message: <<example_with_luaunit.lua:130: expected 2, got 1>>
  359. Other example:
  360. <<stack traceback:
  361. ./luaunit.lua:545: in function 'assertEquals'
  362. example_with_luaunit.lua:58: in function 'TestToto.test7'
  363. ./luaunit.lua:1517: in function <./luaunit.lua:1517>
  364. [C]: in function 'xpcall'
  365. ./luaunit.lua:1517: in function 'protectedCall'
  366. ./luaunit.lua:1578: in function 'execOneFunction'
  367. ./luaunit.lua:1677: in function 'runSuiteByInstances'
  368. ./luaunit.lua:1730: in function 'runSuiteByNames'
  369. ./luaunit.lua:1806: in function 'runSuite'
  370. example_with_luaunit.lua:140: in main chunk
  371. [C]: in ?>>
  372. error message: <<example_with_luaunit.lua:58: expected 2, got 1>>
  373. <<stack traceback:
  374. luaunit2/example_with_luaunit.lua:124: in function 'test1_withFailure'
  375. luaunit2/luaunit.lua:1532: in function <luaunit2/luaunit.lua:1532>
  376. [C]: in function 'xpcall'
  377. luaunit2/luaunit.lua:1532: in function 'protectedCall'
  378. luaunit2/luaunit.lua:1591: in function 'execOneFunction'
  379. luaunit2/luaunit.lua:1679: in function 'runSuiteByInstances'
  380. luaunit2/luaunit.lua:1743: in function 'runSuiteByNames'
  381. luaunit2/luaunit.lua:1819: in function 'runSuite'
  382. luaunit2/example_with_luaunit.lua:140: in main chunk
  383. [C]: in ?>>
  384. error message: <<luaunit2/example_with_luaunit.lua:124: expected 2, got 1>>
  385. -- first line is "stack traceback": KEEP
  386. -- next line may be luaunit line: REMOVE
  387. -- next lines are call in the program under testOk: REMOVE
  388. -- next lines are calls from luaunit to call the program under test: KEEP
  389. -- Strategy:
  390. -- keep first line
  391. -- remove lines that are part of luaunit
  392. -- kepp lines until we hit a luaunit line
  393. The strategy for stripping is:
  394. * keep first line "stack traceback:"
  395. * part1:
  396. * analyse all lines of the stack from bottom to top of the stack (first line to last line)
  397. * extract the "file:line:" part of the line
  398. * compare it with the "file:line" part of the error message
  399. * if it does not match strip the line
  400. * if it matches, keep the line and move to part 2
  401. * part2:
  402. * anything NOT starting with luaunit.lua is the interesting part of the stack trace
  403. * anything starting again with luaunit.lua is part of the test launcher and should be stripped out
  404. ]]
  405. local function isLuaunitInternalLine( s )
  406. -- return true if line of stack trace comes from inside luaunit
  407. return s:find('[/\\]luaunit%.lua:%d+: ') ~= nil
  408. end
  409. -- print( '<<'..stackTrace..'>>' )
  410. local t = strsplit( '\n', stackTrace )
  411. -- print( prettystr(t) )
  412. local idx = 2
  413. local errMsgFileLine = extractFileLineInfo(errMsg)
  414. -- print('emfi="'..errMsgFileLine..'"')
  415. -- remove lines that are still part of luaunit
  416. while t[idx] and extractFileLineInfo(t[idx]) ~= errMsgFileLine do
  417. -- print('Removing : '..t[idx] )
  418. table.remove(t, idx)
  419. end
  420. -- keep lines until we hit luaunit again
  421. while t[idx] and (not isLuaunitInternalLine(t[idx])) do
  422. -- print('Keeping : '..t[idx] )
  423. idx = idx + 1
  424. end
  425. -- remove remaining luaunit lines
  426. while t[idx] do
  427. -- print('Removing2 : '..t[idx] )
  428. table.remove(t, idx)
  429. end
  430. -- print( prettystr(t) )
  431. return table.concat( t, '\n')
  432. end
  433. M.private.stripLuaunitTrace2 = stripLuaunitTrace2
  434. local function prettystr_sub(v, indentLevel, printTableRefs, cycleDetectTable )
  435. local type_v = type(v)
  436. if "string" == type_v then
  437. -- use clever delimiters according to content:
  438. -- enclose with single quotes if string contains ", but no '
  439. if v:find('"', 1, true) and not v:find("'", 1, true) then
  440. return "'" .. v .. "'"
  441. end
  442. -- use double quotes otherwise, escape embedded "
  443. return '"' .. v:gsub('"', '\\"') .. '"'
  444. elseif "table" == type_v then
  445. --if v.__class__ then
  446. -- return string.gsub( tostring(v), 'table', v.__class__ )
  447. --end
  448. return M.private._table_tostring(v, indentLevel, printTableRefs, cycleDetectTable)
  449. elseif "number" == type_v then
  450. -- eliminate differences in formatting between various Lua versions
  451. if v ~= v then
  452. return "#NaN" -- "not a number"
  453. end
  454. if v == math.huge then
  455. return "#Inf" -- "infinite"
  456. end
  457. if v == -math.huge then
  458. return "-#Inf"
  459. end
  460. if _VERSION == "Lua 5.3" then
  461. local i = math.tointeger(v)
  462. if i then
  463. return tostring(i)
  464. end
  465. end
  466. end
  467. return tostring(v)
  468. end
  469. local function prettystr( v )
  470. --[[ Pretty string conversion, to display the full content of a variable of any type.
  471. * string are enclosed with " by default, or with ' if string contains a "
  472. * tables are expanded to show their full content, with indentation in case of nested tables
  473. ]]--
  474. local cycleDetectTable = {}
  475. local s = prettystr_sub(v, 1, M.PRINT_TABLE_REF_IN_ERROR_MSG, cycleDetectTable)
  476. if cycleDetectTable.detected and not M.PRINT_TABLE_REF_IN_ERROR_MSG then
  477. -- some table contain recursive references,
  478. -- so we must recompute the value by including all table references
  479. -- else the result looks like crap
  480. cycleDetectTable = {}
  481. s = prettystr_sub(v, 1, true, cycleDetectTable)
  482. end
  483. return s
  484. end
  485. M.prettystr = prettystr
  486. function M.adjust_err_msg_with_iter( err_msg, iter_msg )
  487. --[[ Adjust the error message err_msg: trim the FAILURE_PREFIX or SUCCESS_PREFIX information if needed,
  488. add the iteration message if any and return the result.
  489. err_msg: string, error message captured with pcall
  490. iter_msg: a string describing the current iteration ("iteration N") or nil
  491. if there is no iteration in this test.
  492. Returns: (new_err_msg, test_status)
  493. new_err_msg: string, adjusted error message, or nil in case of success
  494. test_status: M.NodeStatus.FAIL, SUCCESS or ERROR according to the information
  495. contained in the error message.
  496. ]]
  497. if iter_msg then
  498. iter_msg = iter_msg..', '
  499. else
  500. iter_msg = ''
  501. end
  502. local RE_FILE_LINE = '.*:%d+: '
  503. -- error message is not necessarily a string,
  504. -- so convert the value to string with prettystr()
  505. if type( err_msg ) ~= 'string' then
  506. err_msg = prettystr( err_msg )
  507. end
  508. if (err_msg:find( M.SUCCESS_PREFIX ) == 1) or err_msg:match( '('..RE_FILE_LINE..')' .. M.SUCCESS_PREFIX .. ".*" ) then
  509. -- test finished early with success()
  510. return nil, M.NodeStatus.SUCCESS
  511. end
  512. if (err_msg:find( M.SKIP_PREFIX ) == 1) or (err_msg:match( '('..RE_FILE_LINE..')' .. M.SKIP_PREFIX .. ".*" ) ~= nil) then
  513. -- substitute prefix by iteration message
  514. err_msg = err_msg:gsub('.*'..M.SKIP_PREFIX, iter_msg, 1)
  515. -- print("failure detected")
  516. return err_msg, M.NodeStatus.SKIP
  517. end
  518. if (err_msg:find( M.FAILURE_PREFIX ) == 1) or (err_msg:match( '('..RE_FILE_LINE..')' .. M.FAILURE_PREFIX .. ".*" ) ~= nil) then
  519. -- substitute prefix by iteration message
  520. err_msg = err_msg:gsub(M.FAILURE_PREFIX, iter_msg, 1)
  521. -- print("failure detected")
  522. return err_msg, M.NodeStatus.FAIL
  523. end
  524. -- print("error detected")
  525. -- regular error, not a failure
  526. if iter_msg then
  527. local match
  528. -- "./test\\test_luaunit.lua:2241: some error msg
  529. match = err_msg:match( '(.*:%d+: ).*' )
  530. if match then
  531. err_msg = err_msg:gsub( match, match .. iter_msg )
  532. else
  533. -- no file:line: infromation, just add the iteration info at the beginning of the line
  534. err_msg = iter_msg .. err_msg
  535. end
  536. end
  537. return err_msg, M.NodeStatus.ERROR
  538. end
  539. local function tryMismatchFormatting( table_a, table_b, doDeepAnalysis, margin )
  540. --[[
  541. Prepares a nice error message when comparing tables, performing a deeper
  542. analysis.
  543. Arguments:
  544. * table_a, table_b: tables to be compared
  545. * doDeepAnalysis:
  546. M.DEFAULT_DEEP_ANALYSIS: (the default if not specified) perform deep analysis only for big lists and big dictionnaries
  547. M.FORCE_DEEP_ANALYSIS : always perform deep analysis
  548. M.DISABLE_DEEP_ANALYSIS: never perform deep analysis
  549. * margin: supplied only for almost equality
  550. Returns: {success, result}
  551. * success: false if deep analysis could not be performed
  552. in this case, just use standard assertion message
  553. * result: if success is true, a multi-line string with deep analysis of the two lists
  554. ]]
  555. -- check if table_a & table_b are suitable for deep analysis
  556. if type(table_a) ~= 'table' or type(table_b) ~= 'table' then
  557. return false
  558. end
  559. if doDeepAnalysis == M.DISABLE_DEEP_ANALYSIS then
  560. return false
  561. end
  562. local len_a, len_b, isPureList = #table_a, #table_b, true
  563. for k1, v1 in pairs(table_a) do
  564. if type(k1) ~= 'number' or k1 > len_a then
  565. -- this table a mapping
  566. isPureList = false
  567. break
  568. end
  569. end
  570. if isPureList then
  571. for k2, v2 in pairs(table_b) do
  572. if type(k2) ~= 'number' or k2 > len_b then
  573. -- this table a mapping
  574. isPureList = false
  575. break
  576. end
  577. end
  578. end
  579. if isPureList and math.min(len_a, len_b) < M.LIST_DIFF_ANALYSIS_THRESHOLD then
  580. if not (doDeepAnalysis == M.FORCE_DEEP_ANALYSIS) then
  581. return false
  582. end
  583. end
  584. if isPureList then
  585. return M.private.mismatchFormattingPureList( table_a, table_b, margin )
  586. else
  587. -- only work on mapping for the moment
  588. -- return M.private.mismatchFormattingMapping( table_a, table_b, doDeepAnalysis )
  589. return false
  590. end
  591. end
  592. M.private.tryMismatchFormatting = tryMismatchFormatting
  593. local function getTaTbDescr()
  594. if not M.ORDER_ACTUAL_EXPECTED then
  595. return 'expected', 'actual'
  596. end
  597. return 'actual', 'expected'
  598. end
  599. local function extendWithStrFmt( res, ... )
  600. table.insert( res, string.format( ... ) )
  601. end
  602. local function mismatchFormattingMapping( table_a, table_b, doDeepAnalysis )
  603. --[[
  604. Prepares a nice error message when comparing tables which are not pure lists, performing a deeper
  605. analysis.
  606. Returns: {success, result}
  607. * success: false if deep analysis could not be performed
  608. in this case, just use standard assertion message
  609. * result: if success is true, a multi-line string with deep analysis of the two lists
  610. ]]
  611. -- disable for the moment
  612. --[[
  613. local result = {}
  614. local descrTa, descrTb = getTaTbDescr()
  615. local keysCommon = {}
  616. local keysOnlyTa = {}
  617. local keysOnlyTb = {}
  618. local keysDiffTaTb = {}
  619. local k, v
  620. for k,v in pairs( table_a ) do
  621. if is_equal( v, table_b[k] ) then
  622. table.insert( keysCommon, k )
  623. else
  624. if table_b[k] == nil then
  625. table.insert( keysOnlyTa, k )
  626. else
  627. table.insert( keysDiffTaTb, k )
  628. end
  629. end
  630. end
  631. for k,v in pairs( table_b ) do
  632. if not is_equal( v, table_a[k] ) and table_a[k] == nil then
  633. table.insert( keysOnlyTb, k )
  634. end
  635. end
  636. local len_a = #keysCommon + #keysDiffTaTb + #keysOnlyTa
  637. local len_b = #keysCommon + #keysDiffTaTb + #keysOnlyTb
  638. local limited_display = (len_a < 5 or len_b < 5)
  639. if math.min(len_a, len_b) < M.TABLE_DIFF_ANALYSIS_THRESHOLD then
  640. return false
  641. end
  642. if not limited_display then
  643. if len_a == len_b then
  644. extendWithStrFmt( result, 'Table A (%s) and B (%s) both have %d items', descrTa, descrTb, len_a )
  645. else
  646. extendWithStrFmt( result, 'Table A (%s) has %d items and table B (%s) has %d items', descrTa, len_a, descrTb, len_b )
  647. end
  648. if #keysCommon == 0 and #keysDiffTaTb == 0 then
  649. table.insert( result, 'Table A and B have no keys in common, they are totally different')
  650. else
  651. local s_other = 'other '
  652. if #keysCommon then
  653. extendWithStrFmt( result, 'Table A and B have %d identical items', #keysCommon )
  654. else
  655. table.insert( result, 'Table A and B have no identical items' )
  656. s_other = ''
  657. end
  658. if #keysDiffTaTb ~= 0 then
  659. result[#result] = string.format( '%s and %d items differing present in both tables', result[#result], #keysDiffTaTb)
  660. else
  661. result[#result] = string.format( '%s and no %sitems differing present in both tables', result[#result], s_other, #keysDiffTaTb)
  662. end
  663. end
  664. extendWithStrFmt( result, 'Table A has %d keys not present in table B and table B has %d keys not present in table A', #keysOnlyTa, #keysOnlyTb )
  665. end
  666. local function keytostring(k)
  667. if "string" == type(k) and k:match("^[_%a][_%w]*$") then
  668. return k
  669. end
  670. return prettystr(k)
  671. end
  672. if #keysDiffTaTb ~= 0 then
  673. table.insert( result, 'Items differing in A and B:')
  674. for k,v in sortedPairs( keysDiffTaTb ) do
  675. extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) )
  676. extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) )
  677. end
  678. end
  679. if #keysOnlyTa ~= 0 then
  680. table.insert( result, 'Items only in table A:' )
  681. for k,v in sortedPairs( keysOnlyTa ) do
  682. extendWithStrFmt( result, ' - A[%s]: %s', keytostring(v), prettystr(table_a[v]) )
  683. end
  684. end
  685. if #keysOnlyTb ~= 0 then
  686. table.insert( result, 'Items only in table B:' )
  687. for k,v in sortedPairs( keysOnlyTb ) do
  688. extendWithStrFmt( result, ' + B[%s]: %s', keytostring(v), prettystr(table_b[v]) )
  689. end
  690. end
  691. if #keysCommon ~= 0 then
  692. table.insert( result, 'Items common to A and B:')
  693. for k,v in sortedPairs( keysCommon ) do
  694. extendWithStrFmt( result, ' = A and B [%s]: %s', keytostring(v), prettystr(table_a[v]) )
  695. end
  696. end
  697. return true, table.concat( result, '\n')
  698. ]]
  699. end
  700. M.private.mismatchFormattingMapping = mismatchFormattingMapping
  701. local function mismatchFormattingPureList( table_a, table_b, margin )
  702. --[[
  703. Prepares a nice error message when comparing tables which are lists, performing a deeper
  704. analysis.
  705. margin is supplied only for almost equality
  706. Returns: {success, result}
  707. * success: false if deep analysis could not be performed
  708. in this case, just use standard assertion message
  709. * result: if success is true, a multi-line string with deep analysis of the two lists
  710. ]]
  711. local result, descrTa, descrTb = {}, getTaTbDescr()
  712. local len_a, len_b, refa, refb = #table_a, #table_b, '', ''
  713. if M.PRINT_TABLE_REF_IN_ERROR_MSG then
  714. refa, refb = string.format( '<%s> ', M.private.table_ref(table_a)), string.format('<%s> ', M.private.table_ref(table_b) )
  715. end
  716. local longest, shortest = math.max(len_a, len_b), math.min(len_a, len_b)
  717. local deltalv = longest - shortest
  718. local commonUntil = shortest
  719. for i = 1, shortest do
  720. if not M.private.is_table_equals(table_a[i], table_b[i], margin) then
  721. commonUntil = i - 1
  722. break
  723. end
  724. end
  725. local commonBackTo = shortest - 1
  726. for i = 0, shortest - 1 do
  727. if not M.private.is_table_equals(table_a[len_a-i], table_b[len_b-i], margin) then
  728. commonBackTo = i - 1
  729. break
  730. end
  731. end
  732. table.insert( result, 'List difference analysis:' )
  733. if len_a == len_b then
  734. -- TODO: handle expected/actual naming
  735. extendWithStrFmt( result, '* lists %sA (%s) and %sB (%s) have the same size', refa, descrTa, refb, descrTb )
  736. else
  737. extendWithStrFmt( result, '* list sizes differ: list %sA (%s) has %d items, list %sB (%s) has %d items', refa, descrTa, len_a, refb, descrTb, len_b )
  738. end
  739. extendWithStrFmt( result, '* lists A and B start differing at index %d', commonUntil+1 )
  740. if commonBackTo >= 0 then
  741. if deltalv > 0 then
  742. extendWithStrFmt( result, '* lists A and B are equal again from index %d for A, %d for B', len_a-commonBackTo, len_b-commonBackTo )
  743. else
  744. extendWithStrFmt( result, '* lists A and B are equal again from index %d', len_a-commonBackTo )
  745. end
  746. end
  747. local function insertABValue(ai, bi)
  748. bi = bi or ai
  749. if M.private.is_table_equals( table_a[ai], table_b[bi], margin) then
  750. return extendWithStrFmt( result, ' = A[%d], B[%d]: %s', ai, bi, prettystr(table_a[ai]) )
  751. else
  752. extendWithStrFmt( result, ' - A[%d]: %s', ai, prettystr(table_a[ai]))
  753. extendWithStrFmt( result, ' + B[%d]: %s', bi, prettystr(table_b[bi]))
  754. end
  755. end
  756. -- common parts to list A & B, at the beginning
  757. if commonUntil > 0 then
  758. table.insert( result, '* Common parts:' )
  759. for i = 1, commonUntil do
  760. insertABValue( i )
  761. end
  762. end
  763. -- diffing parts to list A & B
  764. if commonUntil < shortest - commonBackTo - 1 then
  765. table.insert( result, '* Differing parts:' )
  766. for i = commonUntil + 1, shortest - commonBackTo - 1 do
  767. insertABValue( i )
  768. end
  769. end
  770. -- display indexes of one list, with no match on other list
  771. if shortest - commonBackTo <= longest - commonBackTo - 1 then
  772. table.insert( result, '* Present only in one list:' )
  773. for i = shortest - commonBackTo, longest - commonBackTo - 1 do
  774. if len_a > len_b then
  775. extendWithStrFmt( result, ' - A[%d]: %s', i, prettystr(table_a[i]) )
  776. -- table.insert( result, '+ (no matching B index)')
  777. else
  778. -- table.insert( result, '- no matching A index')
  779. extendWithStrFmt( result, ' + B[%d]: %s', i, prettystr(table_b[i]) )
  780. end
  781. end
  782. end
  783. -- common parts to list A & B, at the end
  784. if commonBackTo >= 0 then
  785. table.insert( result, '* Common parts at the end of the lists' )
  786. for i = longest - commonBackTo, longest do
  787. if len_a > len_b then
  788. insertABValue( i, i-deltalv )
  789. else
  790. insertABValue( i-deltalv, i )
  791. end
  792. end
  793. end
  794. return true, table.concat( result, '\n')
  795. end
-- exposed through M.private so the test suite can exercise it directly
M.private.mismatchFormattingPureList = mismatchFormattingPureList
  797. local function prettystrPairs(value1, value2, suffix_a, suffix_b)
  798. --[[
  799. This function helps with the recurring task of constructing the "expected
  800. vs. actual" error messages. It takes two arbitrary values and formats
  801. corresponding strings with prettystr().
  802. To keep the (possibly complex) output more readable in case the resulting
  803. strings contain line breaks, they get automatically prefixed with additional
  804. newlines. Both suffixes are optional (default to empty strings), and get
  805. appended to the "value1" string. "suffix_a" is used if line breaks were
  806. encountered, "suffix_b" otherwise.
  807. Returns the two formatted strings (including padding/newlines).
  808. ]]
  809. local str1, str2 = prettystr(value1), prettystr(value2)
  810. if hasNewLine(str1) or hasNewLine(str2) then
  811. -- line break(s) detected, add padding
  812. return "\n" .. str1 .. (suffix_a or ""), "\n" .. str2
  813. end
  814. return str1 .. (suffix_b or ""), str2
  815. end
-- exposed through M.private so the test suite can exercise it directly
M.private.prettystrPairs = prettystrPairs
  817. local UNKNOWN_REF = 'table 00-unknown ref'
  818. local ref_generator = { value=1, [UNKNOWN_REF]=0 }
  819. local function table_ref( t )
  820. -- return the default tostring() for tables, with the table ID, even if the table has a metatable
  821. -- with the __tostring converter
  822. local ref = ''
  823. local mt = getmetatable( t )
  824. if mt == nil then
  825. ref = tostring(t)
  826. else
  827. local success, result
  828. success, result = pcall(setmetatable, t, nil)
  829. if not success then
  830. -- protected table, if __tostring is defined, we can
  831. -- not get the reference. And we can not know in advance.
  832. ref = tostring(t)
  833. if not ref:match( 'table: 0?x?[%x]+' ) then
  834. return UNKNOWN_REF
  835. end
  836. else
  837. ref = tostring(t)
  838. setmetatable( t, mt )
  839. end
  840. end
  841. -- strip the "table: " part
  842. ref = ref:sub(8)
  843. if ref ~= UNKNOWN_REF and ref_generator[ref] == nil then
  844. -- Create a new reference number
  845. ref_generator[ref] = ref_generator.value
  846. ref_generator.value = ref_generator.value+1
  847. end
  848. if M.PRINT_TABLE_REF_IN_ERROR_MSG then
  849. return string.format('table %02d-%s', ref_generator[ref], ref)
  850. else
  851. return string.format('table %02d', ref_generator[ref])
  852. end
  853. end
-- exposed through M.private so the test suite can exercise it directly
M.private.table_ref = table_ref
  855. local TABLE_TOSTRING_SEP = ", "
  856. local TABLE_TOSTRING_SEP_LEN = string.len(TABLE_TOSTRING_SEP)
  857. local function _table_tostring( tbl, indentLevel, printTableRefs, cycleDetectTable )
  858. printTableRefs = printTableRefs or M.PRINT_TABLE_REF_IN_ERROR_MSG
  859. cycleDetectTable = cycleDetectTable or {}
  860. cycleDetectTable[tbl] = true
  861. local result, dispOnMultLines = {}, false
  862. -- like prettystr but do not enclose with "" if the string is just alphanumerical
  863. -- this is better for displaying table keys who are often simple strings
  864. local function keytostring(k)
  865. if "string" == type(k) and k:match("^[_%a][_%w]*$") then
  866. return k
  867. end
  868. return prettystr_sub(k, indentLevel+1, printTableRefs, cycleDetectTable)
  869. end
  870. local mt = getmetatable( tbl )
  871. if mt and mt.__tostring then
  872. -- if table has a __tostring() function in its metatable, use it to display the table
  873. -- else, compute a regular table
  874. result = tostring(tbl)
  875. if type(result) ~= 'string' then
  876. return string.format( '<invalid tostring() result: "%s" >', prettystr(result) )
  877. end
  878. result = strsplit( '\n', result )
  879. return M.private._table_tostring_format_multiline_string( result, indentLevel )
  880. else
  881. -- no metatable, compute the table representation
  882. local entry, count, seq_index = nil, 0, 1
  883. for k, v in sortedPairs( tbl ) do
  884. -- key part
  885. if k == seq_index then
  886. -- for the sequential part of tables, we'll skip the "<key>=" output
  887. entry = ''
  888. seq_index = seq_index + 1
  889. elseif cycleDetectTable[k] then
  890. -- recursion in the key detected
  891. cycleDetectTable.detected = true
  892. entry = "<"..table_ref(k)..">="
  893. else
  894. entry = keytostring(k) .. "="
  895. end
  896. -- value part
  897. if cycleDetectTable[v] then
  898. -- recursion in the value detected!
  899. cycleDetectTable.detected = true
  900. entry = entry .. "<"..table_ref(v)..">"
  901. else
  902. entry = entry ..
  903. prettystr_sub( v, indentLevel+1, printTableRefs, cycleDetectTable )
  904. end
  905. count = count + 1
  906. result[count] = entry
  907. end
  908. return M.private._table_tostring_format_result( tbl, result, indentLevel, printTableRefs )
  909. end
  910. end
-- exposed through M.private (also used by the test suite)
M.private._table_tostring = _table_tostring -- prettystr_sub() needs it
  912. local function _table_tostring_format_multiline_string( tbl_str, indentLevel )
  913. local indentString = '\n'..string.rep(" ", indentLevel - 1)
  914. return table.concat( tbl_str, indentString )
  915. end
-- exposed through M.private: _table_tostring() calls it through this field
M.private._table_tostring_format_multiline_string = _table_tostring_format_multiline_string
  917. local function _table_tostring_format_result( tbl, result, indentLevel, printTableRefs )
  918. -- final function called in _table_to_string() to format the resulting list of
  919. -- string describing the table.
  920. local dispOnMultLines = false
  921. -- set dispOnMultLines to true if the maximum LINE_LENGTH would be exceeded with the values
  922. local totalLength = 0
  923. for k, v in ipairs( result ) do
  924. totalLength = totalLength + string.len( v )
  925. if totalLength >= M.LINE_LENGTH then
  926. dispOnMultLines = true
  927. break
  928. end
  929. end
  930. -- set dispOnMultLines to true if the max LINE_LENGTH would be exceeded
  931. -- with the values and the separators.
  932. if not dispOnMultLines then
  933. -- adjust with length of separator(s):
  934. -- two items need 1 sep, three items two seps, ... plus len of '{}'
  935. if #result > 0 then
  936. totalLength = totalLength + TABLE_TOSTRING_SEP_LEN * (#result - 1)
  937. end
  938. dispOnMultLines = (totalLength + 2 >= M.LINE_LENGTH)
  939. end
  940. -- now reformat the result table (currently holding element strings)
  941. if dispOnMultLines then
  942. local indentString = string.rep(" ", indentLevel - 1)
  943. result = {
  944. "{\n ",
  945. indentString,
  946. table.concat(result, ",\n " .. indentString),
  947. "\n",
  948. indentString,
  949. "}"
  950. }
  951. else
  952. result = {"{", table.concat(result, TABLE_TOSTRING_SEP), "}"}
  953. end
  954. if printTableRefs then
  955. table.insert(result, 1, "<"..table_ref(tbl).."> ") -- prepend table ref
  956. end
  957. return table.concat(result)
  958. end
-- exposed through M.private: _table_tostring() calls it through this field
M.private._table_tostring_format_result = _table_tostring_format_result -- prettystr_sub() needs it
  960. local function table_findkeyof(t, element)
  961. -- Return the key k of the given element in table t, so that t[k] == element
  962. -- (or `nil` if element is not present within t). Note that we use our
  963. -- 'general' is_equal comparison for matching, so this function should
  964. -- handle table-type elements gracefully and consistently.
  965. if type(t) == "table" then
  966. for k, v in pairs(t) do
  967. if M.private.is_table_equals(v, element) then
  968. return k
  969. end
  970. end
  971. end
  972. return nil
  973. end
  974. local function _is_table_items_equals(actual, expected )
  975. local type_a, type_e = type(actual), type(expected)
  976. if type_a ~= type_e then
  977. return false
  978. elseif (type_a == 'table') --[[and (type_e == 'table')]] then
  979. for k, v in pairs(actual) do
  980. if table_findkeyof(expected, v) == nil then
  981. return false -- v not contained in expected
  982. end
  983. end
  984. for k, v in pairs(expected) do
  985. if table_findkeyof(actual, v) == nil then
  986. return false -- v not contained in actual
  987. end
  988. end
  989. return true
  990. elseif actual ~= expected then
  991. return false
  992. end
  993. return true
  994. end
  995. --[[
  996. This is a specialized metatable to help with the bookkeeping of recursions
  997. in _is_table_equals(). It provides an __index table that implements utility
  998. functions for easier management of the table. The "cached" method queries
  999. the state of a specific (actual,expected) pair; and the "store" method sets
  1000. this state to the given value. The state of pairs not "seen" / visited is
  1001. assumed to be `nil`.
  1002. ]]
  1003. local _recursion_cache_MT = {
  1004. __index = {
  1005. -- Return the cached value for an (actual,expected) pair (or `nil`)
  1006. cached = function(t, actual, expected)
  1007. local subtable = t[actual] or {}
  1008. return subtable[expected]
  1009. end,
  1010. -- Store cached value for a specific (actual,expected) pair.
  1011. -- Returns the value, so it's easy to use for a "tailcall" (return ...).
  1012. store = function(t, actual, expected, value, asymmetric)
  1013. local subtable = t[actual]
  1014. if not subtable then
  1015. subtable = {}
  1016. t[actual] = subtable
  1017. end
  1018. subtable[expected] = value
  1019. -- Unless explicitly marked "asymmetric": Consider the recursion
  1020. -- on (expected,actual) to be equivalent to (actual,expected) by
  1021. -- default, and thus cache the value for both.
  1022. if not asymmetric then
  1023. t:store(expected, actual, value, true)
  1024. end
  1025. return value
  1026. end
  1027. }
  1028. }
  1029. local function _is_table_equals(actual, expected, cycleDetectTable, marginForAlmostEqual)
  1030. --[[Returns true if both table are equal.
  1031. If argument marginForAlmostEqual is suppied, number comparison is done using alomstEqual instead
  1032. of strict equality.
  1033. cycleDetectTable is an internal argument used during recursion on tables.
  1034. ]]
  1035. --print('_is_table_equals( \n '..prettystr(actual)..'\n , '..prettystr(expected)..
  1036. -- '\n , '..prettystr(cycleDetectTable)..'\n , '..prettystr(marginForAlmostEqual)..' )')
  1037. local type_a, type_e = type(actual), type(expected)
  1038. if type_a ~= type_e then
  1039. return false -- different types won't match
  1040. end
  1041. if type_a == 'number' then
  1042. if marginForAlmostEqual ~= nil then
  1043. return M.almostEquals(actual, expected, marginForAlmostEqual)
  1044. else
  1045. return actual == expected
  1046. end
  1047. elseif type_a ~= 'table' then
  1048. -- other types compare directly
  1049. return actual == expected
  1050. end
  1051. cycleDetectTable = cycleDetectTable or { actual={}, expected={} }
  1052. if cycleDetectTable.actual[ actual ] then
  1053. -- oh, we hit a cycle in actual
  1054. if cycleDetectTable.expected[ expected ] then
  1055. -- uh, we hit a cycle at the same time in expected
  1056. -- so the two tables have similar structure
  1057. return true
  1058. end
  1059. -- cycle was hit only in actual, the structure differs from expected
  1060. return false
  1061. end
  1062. if cycleDetectTable.expected[ expected ] then
  1063. -- no cycle in actual, but cycle in expected
  1064. -- the structure differ
  1065. return false
  1066. end
  1067. -- at this point, no table cycle detected, we are
  1068. -- seeing this table for the first time
  1069. -- mark the cycle detection
  1070. cycleDetectTable.actual[ actual ] = true
  1071. cycleDetectTable.expected[ expected ] = true
  1072. local actualKeysMatched = {}
  1073. for k, v in pairs(actual) do
  1074. actualKeysMatched[k] = true -- Keep track of matched keys
  1075. if not _is_table_equals(v, expected[k], cycleDetectTable, marginForAlmostEqual) then
  1076. -- table differs on this key
  1077. -- clear the cycle detection before returning
  1078. cycleDetectTable.actual[ actual ] = nil
  1079. cycleDetectTable.expected[ expected ] = nil
  1080. return false
  1081. end
  1082. end
  1083. for k, v in pairs(expected) do
  1084. if not actualKeysMatched[k] then
  1085. -- Found a key that we did not see in "actual" -> mismatch
  1086. -- clear the cycle detection before returning
  1087. cycleDetectTable.actual[ actual ] = nil
  1088. cycleDetectTable.expected[ expected ] = nil
  1089. return false
  1090. end
  1091. -- Otherwise actual[k] was already matched against v = expected[k].
  1092. end
  1093. -- all key match, we have a match !
  1094. cycleDetectTable.actual[ actual ] = nil
  1095. cycleDetectTable.expected[ expected ] = nil
  1096. return true
  1097. end
-- exposed through M.private so the test suite can exercise it directly
M.private._is_table_equals = _is_table_equals
  1099. local function failure(main_msg, extra_msg_or_nil, level)
  1100. -- raise an error indicating a test failure
  1101. -- for error() compatibility we adjust "level" here (by +1), to report the
  1102. -- calling context
  1103. local msg
  1104. if type(extra_msg_or_nil) == 'string' and extra_msg_or_nil:len() > 0 then
  1105. msg = extra_msg_or_nil .. '\n' .. main_msg
  1106. else
  1107. msg = main_msg
  1108. end
  1109. error(M.FAILURE_PREFIX .. msg, (level or 1) + 1 + M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE)
  1110. end
  1111. local function is_table_equals(actual, expected, marginForAlmostEqual)
  1112. return _is_table_equals(actual, expected, nil, marginForAlmostEqual)
  1113. end
-- exposed through M.private (table_findkeyof() and the tests use this field)
M.private.is_table_equals = is_table_equals
  1115. local function fail_fmt(level, extra_msg_or_nil, ...)
  1116. -- failure with printf-style formatted message and given error level
  1117. failure(string.format(...), extra_msg_or_nil, (level or 1) + 1)
  1118. end
-- exposed through M.private so the test suite can exercise it directly
M.private.fail_fmt = fail_fmt
  1120. local function error_fmt(level, ...)
  1121. -- printf-style error()
  1122. error(string.format(...), (level or 1) + 1 + M.STRIP_EXTRA_ENTRIES_IN_STACK_TRACE)
  1123. end
-- exposed through M.private so the test suite can exercise it directly
M.private.error_fmt = error_fmt
  1125. ----------------------------------------------------------------
  1126. --
  1127. -- assertions
  1128. --
  1129. ----------------------------------------------------------------
  1130. local function errorMsgEquality(actual, expected, doDeepAnalysis, margin)
  1131. -- margin is supplied only for almost equal verification
  1132. if not M.ORDER_ACTUAL_EXPECTED then
  1133. expected, actual = actual, expected
  1134. end
  1135. if type(expected) == 'string' or type(expected) == 'table' then
  1136. local strExpected, strActual = prettystrPairs(expected, actual)
  1137. local result = string.format("expected: %s\nactual: %s", strExpected, strActual)
  1138. if margin then
  1139. result = result .. '\nwere not equal by the margin of: '..prettystr(margin)
  1140. end
  1141. -- extend with mismatch analysis if possible:
  1142. local success, mismatchResult
  1143. success, mismatchResult = tryMismatchFormatting( actual, expected, doDeepAnalysis, margin )
  1144. if success then
  1145. result = table.concat( { result, mismatchResult }, '\n' )
  1146. end
  1147. return result
  1148. end
  1149. return string.format("expected: %s, actual: %s",
  1150. prettystr(expected), prettystr(actual))
  1151. end
  1152. function M.assertError(f, ...)
  1153. -- assert that calling f with the arguments will raise an error
  1154. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1155. if pcall( f, ... ) then
  1156. failure( "Expected an error when calling function but no error generated", nil, 2 )
  1157. end
  1158. end
  1159. function M.fail( msg )
  1160. -- stops a test due to a failure
  1161. failure( msg, nil, 2 )
  1162. end
  1163. function M.failIf( cond, msg )
  1164. -- Fails a test with "msg" if condition is true
  1165. if cond then
  1166. failure( msg, nil, 2 )
  1167. end
  1168. end
  1169. function M.skip(msg)
  1170. -- skip a running test
  1171. error_fmt(2, M.SKIP_PREFIX .. msg)
  1172. end
  1173. function M.skipIf( cond, msg )
  1174. -- skip a running test if condition is met
  1175. if cond then
  1176. error_fmt(2, M.SKIP_PREFIX .. msg)
  1177. end
  1178. end
  1179. function M.runOnlyIf( cond, msg )
  1180. -- continue a running test if condition is met, else skip it
  1181. if not cond then
  1182. error_fmt(2, M.SKIP_PREFIX .. prettystr(msg))
  1183. end
  1184. end
  1185. function M.success()
  1186. -- stops a test with a success
  1187. error_fmt(2, M.SUCCESS_PREFIX)
  1188. end
  1189. function M.successIf( cond )
  1190. -- stops a test with a success if condition is met
  1191. if cond then
  1192. error_fmt(2, M.SUCCESS_PREFIX)
  1193. end
  1194. end
  1195. ------------------------------------------------------------------
  1196. -- Equality assertions
  1197. ------------------------------------------------------------------
  1198. function M.assertEquals(actual, expected, extra_msg_or_nil, doDeepAnalysis)
  1199. if type(actual) == 'table' and type(expected) == 'table' then
  1200. if not is_table_equals(actual, expected) then
  1201. failure( errorMsgEquality(actual, expected, doDeepAnalysis), extra_msg_or_nil, 2 )
  1202. end
  1203. elseif type(actual) ~= type(expected) then
  1204. failure( errorMsgEquality(actual, expected), extra_msg_or_nil, 2 )
  1205. elseif actual ~= expected then
  1206. failure( errorMsgEquality(actual, expected), extra_msg_or_nil, 2 )
  1207. end
  1208. end
  1209. function M.almostEquals( actual, expected, margin )
  1210. if type(actual) ~= 'number' or type(expected) ~= 'number' or type(margin) ~= 'number' then
  1211. error_fmt(3, 'almostEquals: must supply only number arguments.\nArguments supplied: %s, %s, %s',
  1212. prettystr(actual), prettystr(expected), prettystr(margin))
  1213. end
  1214. if margin < 0 then
  1215. error_fmt(3, 'almostEquals: margin must not be negative, current value is ' .. margin)
  1216. end
  1217. return math.abs(expected - actual) <= margin
  1218. end
  1219. function M.assertAlmostEquals( actual, expected, margin, extra_msg_or_nil )
  1220. -- check that two floats are close by margin
  1221. margin = margin or M.EPS
  1222. if type(margin) ~= 'number' then
  1223. error_fmt(2, 'almostEquals: margin must be a number, not %s', prettystr(margin))
  1224. end
  1225. if type(actual) == 'table' and type(expected) == 'table' then
  1226. -- handle almost equals for table
  1227. if not is_table_equals(actual, expected, margin) then
  1228. failure( errorMsgEquality(actual, expected, nil, margin), extra_msg_or_nil, 2 )
  1229. end
  1230. elseif type(actual) == 'number' and type(expected) == 'number' and type(margin) == 'number' then
  1231. if not M.almostEquals(actual, expected, margin) then
  1232. if not M.ORDER_ACTUAL_EXPECTED then
  1233. expected, actual = actual, expected
  1234. end
  1235. local delta = math.abs(actual - expected)
  1236. fail_fmt(2, extra_msg_or_nil, 'Values are not almost equal\n' ..
  1237. 'Actual: %s, expected: %s, delta %s above margin of %s',
  1238. actual, expected, delta, margin)
  1239. end
  1240. else
  1241. error_fmt(3, 'almostEquals: must supply only number or table arguments.\nArguments supplied: %s, %s, %s',
  1242. prettystr(actual), prettystr(expected), prettystr(margin))
  1243. end
  1244. end
  1245. function M.assertNotEquals(actual, expected, extra_msg_or_nil)
  1246. if type(actual) ~= type(expected) then
  1247. return
  1248. end
  1249. if type(actual) == 'table' and type(expected) == 'table' then
  1250. if not is_table_equals(actual, expected) then
  1251. return
  1252. end
  1253. elseif actual ~= expected then
  1254. return
  1255. end
  1256. fail_fmt(2, extra_msg_or_nil, 'Received the not expected value: %s', prettystr(actual))
  1257. end
  1258. function M.assertNotAlmostEquals( actual, expected, margin, extra_msg_or_nil )
  1259. -- check that two floats are not close by margin
  1260. margin = margin or M.EPS
  1261. if M.almostEquals(actual, expected, margin) then
  1262. if not M.ORDER_ACTUAL_EXPECTED then
  1263. expected, actual = actual, expected
  1264. end
  1265. local delta = math.abs(actual - expected)
  1266. fail_fmt(2, extra_msg_or_nil, 'Values are almost equal\nActual: %s, expected: %s' ..
  1267. ', delta %s below margin of %s',
  1268. actual, expected, delta, margin)
  1269. end
  1270. end
  1271. function M.assertItemsEquals(actual, expected, extra_msg_or_nil)
  1272. -- checks that the items of table expected
  1273. -- are contained in table actual. Warning, this function
  1274. -- is at least O(n^2)
  1275. if not _is_table_items_equals(actual, expected ) then
  1276. expected, actual = prettystrPairs(expected, actual)
  1277. fail_fmt(2, extra_msg_or_nil, 'Content of the tables are not identical:\nExpected: %s\nActual: %s',
  1278. expected, actual)
  1279. end
  1280. end
  1281. ------------------------------------------------------------------
  1282. -- String assertion
  1283. ------------------------------------------------------------------
  1284. function M.assertStrContains( str, sub, isPattern, extra_msg_or_nil )
  1285. -- this relies on lua string.find function
  1286. -- a string always contains the empty string
  1287. -- assert( type(str) == 'string', 'Argument 1 of assertStrContains() should be a string.' ) )
  1288. -- assert( type(sub) == 'string', 'Argument 2 of assertStrContains() should be a string.' ) )
  1289. if not string.find(str, sub, 1, not isPattern) then
  1290. sub, str = prettystrPairs(sub, str, '\n')
  1291. fail_fmt(2, extra_msg_or_nil, 'Could not find %s %s in string %s',
  1292. isPattern and 'pattern' or 'substring', sub, str)
  1293. end
  1294. end
  1295. function M.assertStrIContains( str, sub, extra_msg_or_nil )
  1296. -- this relies on lua string.find function
  1297. -- a string always contains the empty string
  1298. if not string.find(str:lower(), sub:lower(), 1, true) then
  1299. sub, str = prettystrPairs(sub, str, '\n')
  1300. fail_fmt(2, extra_msg_or_nil, 'Could not find (case insensitively) substring %s in string %s',
  1301. sub, str)
  1302. end
  1303. end
  1304. function M.assertNotStrContains( str, sub, isPattern, extra_msg_or_nil )
  1305. -- this relies on lua string.find function
  1306. -- a string always contains the empty string
  1307. if string.find(str, sub, 1, not isPattern) then
  1308. sub, str = prettystrPairs(sub, str, '\n')
  1309. fail_fmt(2, extra_msg_or_nil, 'Found the not expected %s %s in string %s',
  1310. isPattern and 'pattern' or 'substring', sub, str)
  1311. end
  1312. end
  1313. function M.assertNotStrIContains( str, sub, extra_msg_or_nil )
  1314. -- this relies on lua string.find function
  1315. -- a string always contains the empty string
  1316. if string.find(str:lower(), sub:lower(), 1, true) then
  1317. sub, str = prettystrPairs(sub, str, '\n')
  1318. fail_fmt(2, extra_msg_or_nil, 'Found (case insensitively) the not expected substring %s in string %s',
  1319. sub, str)
  1320. end
  1321. end
  1322. function M.assertStrMatches( str, pattern, start, final, extra_msg_or_nil )
  1323. -- Verify a full match for the string
  1324. if not strMatch( str, pattern, start, final ) then
  1325. pattern, str = prettystrPairs(pattern, str, '\n')
  1326. fail_fmt(2, extra_msg_or_nil, 'Could not match pattern %s with string %s',
  1327. pattern, str)
  1328. end
  1329. end
  1330. local function _assertErrorMsgEquals( stripFileAndLine, expectedMsg, func, ... )
  1331. local no_error, error_msg = pcall( func, ... )
  1332. if no_error then
  1333. failure( 'No error generated when calling function but expected error: '..M.prettystr(expectedMsg), nil, 3 )
  1334. end
  1335. if type(expectedMsg) == "string" and type(error_msg) ~= "string" then
  1336. -- table are converted to string automatically
  1337. error_msg = tostring(error_msg)
  1338. end
  1339. local differ = false
  1340. if stripFileAndLine then
  1341. if error_msg:gsub("^.+:%d+: ", "") ~= expectedMsg then
  1342. differ = true
  1343. end
  1344. else
  1345. if error_msg ~= expectedMsg then
  1346. local tr = type(error_msg)
  1347. local te = type(expectedMsg)
  1348. if te == 'table' then
  1349. if tr ~= 'table' then
  1350. differ = true
  1351. else
  1352. local ok = pcall(M.assertItemsEquals, error_msg, expectedMsg)
  1353. if not ok then
  1354. differ = true
  1355. end
  1356. end
  1357. else
  1358. differ = true
  1359. end
  1360. end
  1361. end
  1362. if differ then
  1363. error_msg, expectedMsg = prettystrPairs(error_msg, expectedMsg)
  1364. fail_fmt(3, nil, 'Error message expected: %s\nError message received: %s\n',
  1365. expectedMsg, error_msg)
  1366. end
  1367. end
  1368. function M.assertErrorMsgEquals( expectedMsg, func, ... )
  1369. -- assert that calling f with the arguments will raise an error
  1370. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1371. _assertErrorMsgEquals(false, expectedMsg, func, ...)
  1372. end
  1373. function M.assertErrorMsgContentEquals(expectedMsg, func, ...)
  1374. _assertErrorMsgEquals(true, expectedMsg, func, ...)
  1375. end
  1376. function M.assertErrorMsgContains( partialMsg, func, ... )
  1377. -- assert that calling f with the arguments will raise an error
  1378. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1379. local no_error, error_msg = pcall( func, ... )
  1380. if no_error then
  1381. failure( 'No error generated when calling function but expected error containing: '..prettystr(partialMsg), nil, 2 )
  1382. end
  1383. if type(error_msg) ~= "string" then
  1384. error_msg = tostring(error_msg)
  1385. end
  1386. if not string.find( error_msg, partialMsg, nil, true ) then
  1387. error_msg, partialMsg = prettystrPairs(error_msg, partialMsg)
  1388. fail_fmt(2, nil, 'Error message does not contain: %s\nError message received: %s\n',
  1389. partialMsg, error_msg)
  1390. end
  1391. end
  1392. function M.assertErrorMsgMatches( expectedMsg, func, ... )
  1393. -- assert that calling f with the arguments will raise an error
  1394. -- example: assertError( f, 1, 2 ) => f(1,2) should generate an error
  1395. local no_error, error_msg = pcall( func, ... )
  1396. if no_error then
  1397. failure( 'No error generated when calling function but expected error matching: "'..expectedMsg..'"', nil, 2 )
  1398. end
  1399. if type(error_msg) ~= "string" then
  1400. error_msg = tostring(error_msg)
  1401. end
  1402. if not strMatch( error_msg, expectedMsg ) then
  1403. expectedMsg, error_msg = prettystrPairs(expectedMsg, error_msg)
  1404. fail_fmt(2, nil, 'Error message does not match pattern: %s\nError message received: %s\n',
  1405. expectedMsg, error_msg)
  1406. end
  1407. end
  1408. ------------------------------------------------------------------
  1409. -- Type assertions
  1410. ------------------------------------------------------------------
  1411. function M.assertEvalToTrue(value, extra_msg_or_nil)
  1412. if not value then
  1413. failure("expected: a value evaluating to true, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1414. end
  1415. end
  1416. function M.assertEvalToFalse(value, extra_msg_or_nil)
  1417. if value then
  1418. failure("expected: false or nil, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1419. end
  1420. end
  1421. function M.assertIsTrue(value, extra_msg_or_nil)
  1422. if value ~= true then
  1423. failure("expected: true, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1424. end
  1425. end
  1426. function M.assertNotIsTrue(value, extra_msg_or_nil)
  1427. if value == true then
  1428. failure("expected: not true, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1429. end
  1430. end
  1431. function M.assertIsFalse(value, extra_msg_or_nil)
  1432. if value ~= false then
  1433. failure("expected: false, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1434. end
  1435. end
  1436. function M.assertNotIsFalse(value, extra_msg_or_nil)
  1437. if value == false then
  1438. failure("expected: not false, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1439. end
  1440. end
  1441. function M.assertIsNil(value, extra_msg_or_nil)
  1442. if value ~= nil then
  1443. failure("expected: nil, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  1444. end
  1445. end
  1446. function M.assertNotIsNil(value, extra_msg_or_nil)
  1447. if value == nil then
  1448. failure("expected: not nil, actual: nil", extra_msg_or_nil, 2)
  1449. end
  1450. end
  1451. --[[
  1452. Add type assertion functions to the module table M. Each of these functions
  1453. takes a single parameter "value", and checks that its Lua type matches the
  1454. expected string (derived from the function name):
  1455. M.assertIsXxx(value) -> ensure that type(value) conforms to "xxx"
  1456. ]]
  1457. for _, funcName in ipairs(
  1458. {'assertIsNumber', 'assertIsString', 'assertIsTable', 'assertIsBoolean',
  1459. 'assertIsFunction', 'assertIsUserdata', 'assertIsThread'}
  1460. ) do
  1461. local typeExpected = funcName:match("^assertIs([A-Z]%a*)$")
  1462. -- Lua type() always returns lowercase, also make sure the match() succeeded
  1463. typeExpected = typeExpected and typeExpected:lower()
  1464. or error("bad function name '"..funcName.."' for type assertion")
  1465. M[funcName] = function(value, extra_msg_or_nil)
  1466. if type(value) ~= typeExpected then
  1467. if type(value) == 'nil' then
  1468. fail_fmt(2, extra_msg_or_nil, 'expected: a %s value, actual: nil',
  1469. typeExpected, type(value), prettystrPairs(value))
  1470. else
  1471. fail_fmt(2, extra_msg_or_nil, 'expected: a %s value, actual: type %s, value %s',
  1472. typeExpected, type(value), prettystrPairs(value))
  1473. end
  1474. end
  1475. end
  1476. end
  --[[
  Add shortcuts for checking the type of a value without raising a failure
  (LuaUnit v2 compatibility). For each type name Xxx two aliases of the same
  predicate are published:
      M.isXxx(value), M.is_xxx(value) -> true if type(value) == "xxx"
  ]]
  do
      local typeNames = {'Number', 'String', 'Table', 'Boolean',
                         'Function', 'Userdata', 'Thread', 'Nil' }
      for i = 1, #typeNames do
          local camelName = typeNames[i]
          local lowerName = camelName:lower()
          local predicate = function(value)
              return type(value) == lowerName
          end
          M['is'..camelName] = predicate
          M['is_'..lowerName] = predicate
      end
  end
  --[[
  Add negated type assertion functions to the module table M. Each generated
  function has the signature (value, extra_msg_or_nil) and fails when
  type(value) equals the type name derived from the function name:
      M.assertNotIsXxx(value) -> ensure that type(value) is not "xxx"
  ]]
  do
      local negatedNames = {'assertNotIsNumber', 'assertNotIsString', 'assertNotIsTable',
                            'assertNotIsBoolean', 'assertNotIsFunction', 'assertNotIsUserdata',
                            'assertNotIsThread'}
      for _, assertName in ipairs(negatedNames) do
          local suffix = assertName:match("^assertNotIs([A-Z]%a*)$")
          if not suffix then
              error("bad function name '"..assertName.."' for type assertion")
          end
          -- type() always reports lowercase names
          local forbiddenType = suffix:lower()
          M[assertName] = function(value, extra_msg_or_nil)
              if type(value) ~= forbiddenType then
                  return
              end
              fail_fmt(2, extra_msg_or_nil, 'expected: not a %s type, actual: value %s',
                       forbiddenType, prettystrPairs(value))
          end
      end
  end
  -- Assert that `actual` and `expected` are the same value according to ==
  -- (for tables: the same reference). While rendering the failure message,
  -- M.PRINT_TABLE_REF_IN_ERROR_MSG is forced to true so that two distinct but
  -- identical-looking tables can be told apart; the setting is restored
  -- before failing.
  function M.assertIs(actual, expected, extra_msg_or_nil)
      if actual == expected then
          return
      end
      if not M.ORDER_ACTUAL_EXPECTED then
          actual, expected = expected, actual
      end
      local savedRefFlag = M.PRINT_TABLE_REF_IN_ERROR_MSG
      M.PRINT_TABLE_REF_IN_ERROR_MSG = true
      local strExpected, strActual = prettystrPairs(expected, actual, '\n', '')
      M.PRINT_TABLE_REF_IN_ERROR_MSG = savedRefFlag
      fail_fmt(2, extra_msg_or_nil, 'expected and actual object should not be different\nExpected: %s\nReceived: %s',
               strExpected, strActual)
  end
  -- Assert that `actual` and `expected` are different values according to ==
  -- (for tables: different references). The offending value is rendered with
  -- table references forced on; M.PRINT_TABLE_REF_IN_ERROR_MSG is restored
  -- before failing.
  function M.assertNotIs(actual, expected, extra_msg_or_nil)
      if actual ~= expected then
          return
      end
      local savedRefFlag = M.PRINT_TABLE_REF_IN_ERROR_MSG
      M.PRINT_TABLE_REF_IN_ERROR_MSG = true
      local shown = expected
      if not M.ORDER_ACTUAL_EXPECTED then
          shown = actual
      end
      local strShown = prettystrPairs(shown)
      M.PRINT_TABLE_REF_IN_ERROR_MSG = savedRefFlag
      fail_fmt(2, extra_msg_or_nil, 'expected and actual object should be different: %s', strShown )
  end
  1540. ------------------------------------------------------------------
  1541. -- Scientific assertions
  1542. ------------------------------------------------------------------
  -- Assert that `value` is a NaN number (NaN is the only number that is not
  -- equal to itself).
  function M.assertIsNaN(value, extra_msg_or_nil)
      local isNaN = type(value) == "number" and value ~= value
      if not isNaN then
          failure("expected: NaN, actual: " ..prettystr(value), extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is not NaN; non-number values always pass.
  function M.assertNotIsNaN(value, extra_msg_or_nil)
      if type(value) ~= "number" or value == value then
          return
      end
      failure("expected: not NaN, actual: NaN", extra_msg_or_nil, 2)
  end
  -- Assert that `value` is an infinite number, of either sign.
  function M.assertIsInf(value, extra_msg_or_nil)
      local isInfinite = type(value) == "number"
                         and (value == math.huge or value == -math.huge)
      if not isInfinite then
          failure("expected: #Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is positive infinity.
  function M.assertIsPlusInf(value, extra_msg_or_nil)
      if type(value) == "number" and value == math.huge then
          return
      end
      failure("expected: #Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  end
  -- Assert that `value` is negative infinity.
  function M.assertIsMinusInf(value, extra_msg_or_nil)
      if type(value) == "number" and value == -math.huge then
          return
      end
      failure("expected: -#Inf, actual: " ..prettystr(value), extra_msg_or_nil, 2)
  end
  -- Assert that `value` is not positive infinity; non-numbers always pass.
  function M.assertNotIsPlusInf(value, extra_msg_or_nil)
      local isPlusInf = type(value) == "number" and value == math.huge
      if isPlusInf then
          failure("expected: not #Inf, actual: #Inf", extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is not negative infinity; non-numbers always pass.
  function M.assertNotIsMinusInf(value, extra_msg_or_nil)
      local isMinusInf = type(value) == "number" and value == -math.huge
      if isMinusInf then
          failure("expected: not -#Inf, actual: -#Inf", extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is not infinite (of either sign); non-numbers always pass.
  function M.assertNotIsInf(value, extra_msg_or_nil)
      if type(value) ~= "number" then
          return
      end
      if value == math.huge or value == -math.huge then
          failure("expected: not infinity, actual: " .. prettystr(value), extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is the number +0.0.
  -- Because -0.0 == 0, the sign of a zero is recovered from 1/value, which is
  -- +inf for +0.0 and -inf for -0.0.
  function M.assertIsPlusZero(value, extra_msg_or_nil)
      if type(value) ~= 'number' or value ~= 0 then
          failure("expected: +0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2)
      elseif 1/value == -math.huge then
          -- more precise error diagnosis
          failure("expected: +0.0, actual: -0.0", extra_msg_or_nil, 2)
      elseif 1/value ~= math.huge then
          -- defensive: a number equal to 0 should always give 1/value == +/-inf
          failure("expected: +0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is the number -0.0.
  -- Because -0.0 == 0, the sign of a zero is recovered from 1/value, which is
  -- -inf for -0.0 and +inf for +0.0.
  function M.assertIsMinusZero(value, extra_msg_or_nil)
      if type(value) ~= 'number' or value ~= 0 then
          failure("expected: -0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2)
      elseif 1/value == math.huge then
          -- more precise error diagnosis
          failure("expected: -0.0, actual: +0.0", extra_msg_or_nil, 2)
      elseif 1/value ~= -math.huge then
          -- defensive: a number equal to 0 should always give 1/value == +/-inf
          failure("expected: -0.0, actual: " ..prettystr(value), extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is not +0.0. Non-numbers, -0.0 and every non-zero
  -- number pass.
  -- The sign of a zero is recovered from 1/value (+inf for +0.0). The explicit
  -- `value == 0` test is required: for tiny positive subnormals (e.g. 1e-310)
  -- 1/value also overflows to +inf and must not be mistaken for +0.0.
  function M.assertNotIsPlusZero(value, extra_msg_or_nil)
      if type(value) == 'number' and value == 0 and 1/value == math.huge then
          failure("expected: not +0.0, actual: +0.0", extra_msg_or_nil, 2)
      end
  end
  -- Assert that `value` is not -0.0. Non-numbers, +0.0 and every non-zero
  -- number pass.
  -- The sign of a zero is recovered from 1/value (-inf for -0.0). The explicit
  -- `value == 0` test is required: for tiny negative subnormals (e.g. -1e-310)
  -- 1/value also overflows to -inf and must not be mistaken for -0.0.
  function M.assertNotIsMinusZero(value, extra_msg_or_nil)
      if type(value) == 'number' and value == 0 and 1/value == -math.huge then
          failure("expected: not -0.0, actual: -0.0", extra_msg_or_nil, 2)
      end
  end
  -- Assert that table `t` contains `expected` as one of its elements; the
  -- search is delegated to table_findkeyof().
  function M.assertTableContains(t, expected, extra_msg_or_nil)
      if table_findkeyof(t, expected) ~= nil then
          return
      end
      local strTable, strExpected = prettystrPairs(t, expected)
      fail_fmt(2, extra_msg_or_nil, 'Table %s does NOT contain the expected element %s',
               strTable, strExpected)
  end
  -- Assert that table `t` does not contain `expected`; on failure the message
  -- also reports the key under which the element was found.
  function M.assertNotTableContains(t, expected, extra_msg_or_nil)
      local foundKey = table_findkeyof(t, expected)
      if foundKey == nil then
          return
      end
      local strTable, strExpected = prettystrPairs(t, expected)
      fail_fmt(2, extra_msg_or_nil, 'Table %s DOES contain the unwanted element %s (at key %s)',
               strTable, strExpected, prettystr(foundKey))
  end
  1636. ----------------------------------------------------------------
  1637. -- Compatibility layer
  1638. ----------------------------------------------------------------
  -- Compatibility shim for LuaUnit v2.x.
  -- LuaUnit <= 2.1 needed this call to add test functions to the global test
  -- suite. Tests are now found directly by the discovery process, so this
  -- function does nothing except print a notice on stderr.
  function M.wrapFunctions()
      local notice = [[Use of WrapFunctions() is no longer needed.
Just prefix your test function names with "test" or "Test" and they
will be picked up and run by LuaUnit.
]]
      io.stderr:write(notice)
  end
  -- Every assertion alias published by LuaUnit.
  -- Each entry is { official function name, alias }; the loop that follows
  -- copies M[official] into M[alias].
  local list_of_funcs = {
      -- general assertions
      { 'assertEquals'            , 'assert_equals' },
      { 'assertItemsEquals'       , 'assert_items_equals' },
      { 'assertNotEquals'         , 'assert_not_equals' },
      { 'assertAlmostEquals'      , 'assert_almost_equals' },
      { 'assertNotAlmostEquals'   , 'assert_not_almost_equals' },
      { 'assertEvalToTrue'        , 'assert_eval_to_true' },
      { 'assertEvalToFalse'       , 'assert_eval_to_false' },
      { 'assertStrContains'       , 'assert_str_contains' },
      { 'assertStrIContains'      , 'assert_str_icontains' },
      { 'assertNotStrContains'    , 'assert_not_str_contains' },
      { 'assertNotStrIContains'   , 'assert_not_str_icontains' },
      { 'assertStrMatches'        , 'assert_str_matches' },
      { 'assertError'             , 'assert_error' },
      { 'assertErrorMsgEquals'    , 'assert_error_msg_equals' },
      { 'assertErrorMsgContains'  , 'assert_error_msg_contains' },
      { 'assertErrorMsgMatches'   , 'assert_error_msg_matches' },
      { 'assertErrorMsgContentEquals', 'assert_error_msg_content_equals' },
      { 'assertIs'                , 'assert_is' },
      { 'assertNotIs'             , 'assert_not_is' },
      { 'assertTableContains'     , 'assert_table_contains' },
      { 'assertNotTableContains'  , 'assert_not_table_contains' },
      { 'wrapFunctions'           , 'WrapFunctions' },
      { 'wrapFunctions'           , 'wrap_functions' },

      -- type assertions: assertIsXXX -> assert_is_xxx
      { 'assertIsNumber'          , 'assert_is_number' },
      { 'assertIsString'          , 'assert_is_string' },
      { 'assertIsTable'           , 'assert_is_table' },
      { 'assertIsBoolean'         , 'assert_is_boolean' },
      { 'assertIsNil'             , 'assert_is_nil' },
      { 'assertIsTrue'            , 'assert_is_true' },
      { 'assertIsFalse'           , 'assert_is_false' },
      { 'assertIsNaN'             , 'assert_is_nan' },
      { 'assertIsInf'             , 'assert_is_inf' },
      { 'assertIsPlusInf'         , 'assert_is_plus_inf' },
      { 'assertIsMinusInf'        , 'assert_is_minus_inf' },
      { 'assertIsPlusZero'        , 'assert_is_plus_zero' },
      { 'assertIsMinusZero'       , 'assert_is_minus_zero' },
      { 'assertIsFunction'        , 'assert_is_function' },
      { 'assertIsThread'          , 'assert_is_thread' },
      { 'assertIsUserdata'        , 'assert_is_userdata' },

      -- type assertions: assertIsXXX -> assertXxx
      { 'assertIsNumber'          , 'assertNumber' },
      { 'assertIsString'          , 'assertString' },
      { 'assertIsTable'           , 'assertTable' },
      { 'assertIsBoolean'         , 'assertBoolean' },
      { 'assertIsNil'             , 'assertNil' },
      { 'assertIsTrue'            , 'assertTrue' },
      { 'assertIsFalse'           , 'assertFalse' },
      { 'assertIsNaN'             , 'assertNaN' },
      { 'assertIsInf'             , 'assertInf' },
      { 'assertIsPlusInf'         , 'assertPlusInf' },
      { 'assertIsMinusInf'        , 'assertMinusInf' },
      { 'assertIsPlusZero'        , 'assertPlusZero' },
      { 'assertIsMinusZero'       , 'assertMinusZero'},
      { 'assertIsFunction'        , 'assertFunction' },
      { 'assertIsThread'          , 'assertThread' },
      { 'assertIsUserdata'        , 'assertUserdata' },

      -- type assertions: assertIsXXX -> assert_xxx (luaunit v2 compat)
      { 'assertIsNumber'          , 'assert_number' },
      { 'assertIsString'          , 'assert_string' },
      { 'assertIsTable'           , 'assert_table' },
      { 'assertIsBoolean'         , 'assert_boolean' },
      { 'assertIsNil'             , 'assert_nil' },
      { 'assertIsTrue'            , 'assert_true' },
      { 'assertIsFalse'           , 'assert_false' },
      { 'assertIsNaN'             , 'assert_nan' },
      { 'assertIsInf'             , 'assert_inf' },
      { 'assertIsPlusInf'         , 'assert_plus_inf' },
      { 'assertIsMinusInf'        , 'assert_minus_inf' },
      { 'assertIsPlusZero'        , 'assert_plus_zero' },
      { 'assertIsMinusZero'       , 'assert_minus_zero' },
      { 'assertIsFunction'        , 'assert_function' },
      { 'assertIsThread'          , 'assert_thread' },
      { 'assertIsUserdata'        , 'assert_userdata' },

      -- type assertions: assertNotIsXXX -> assert_not_is_xxx
      -- (the plus/minus inf/zero entries previously mapped to the assert_not_xxx
      -- names, duplicating the section further below, so the assert_not_is_*
      -- aliases were never created)
      { 'assertNotIsNumber'       , 'assert_not_is_number' },
      { 'assertNotIsString'       , 'assert_not_is_string' },
      { 'assertNotIsTable'        , 'assert_not_is_table' },
      { 'assertNotIsBoolean'      , 'assert_not_is_boolean' },
      { 'assertNotIsNil'          , 'assert_not_is_nil' },
      { 'assertNotIsTrue'         , 'assert_not_is_true' },
      { 'assertNotIsFalse'        , 'assert_not_is_false' },
      { 'assertNotIsNaN'          , 'assert_not_is_nan' },
      { 'assertNotIsInf'          , 'assert_not_is_inf' },
      { 'assertNotIsPlusInf'      , 'assert_not_is_plus_inf' },
      { 'assertNotIsMinusInf'     , 'assert_not_is_minus_inf' },
      { 'assertNotIsPlusZero'     , 'assert_not_is_plus_zero' },
      { 'assertNotIsMinusZero'    , 'assert_not_is_minus_zero' },
      { 'assertNotIsFunction'     , 'assert_not_is_function' },
      { 'assertNotIsThread'       , 'assert_not_is_thread' },
      { 'assertNotIsUserdata'     , 'assert_not_is_userdata' },

      -- type assertions: assertNotIsXXX -> assertNotXxx (luaunit v2 compat)
      { 'assertNotIsNumber'       , 'assertNotNumber' },
      { 'assertNotIsString'       , 'assertNotString' },
      { 'assertNotIsTable'        , 'assertNotTable' },
      { 'assertNotIsBoolean'      , 'assertNotBoolean' },
      { 'assertNotIsNil'          , 'assertNotNil' },
      { 'assertNotIsTrue'         , 'assertNotTrue' },
      { 'assertNotIsFalse'        , 'assertNotFalse' },
      { 'assertNotIsNaN'          , 'assertNotNaN' },
      { 'assertNotIsInf'          , 'assertNotInf' },
      { 'assertNotIsPlusInf'      , 'assertNotPlusInf' },
      { 'assertNotIsMinusInf'     , 'assertNotMinusInf' },
      { 'assertNotIsPlusZero'     , 'assertNotPlusZero' },
      { 'assertNotIsMinusZero'    , 'assertNotMinusZero' },
      { 'assertNotIsFunction'     , 'assertNotFunction' },
      { 'assertNotIsThread'       , 'assertNotThread' },
      { 'assertNotIsUserdata'     , 'assertNotUserdata' },

      -- type assertions: assertNotIsXXX -> assert_not_xxx
      { 'assertNotIsNumber'       , 'assert_not_number' },
      { 'assertNotIsString'       , 'assert_not_string' },
      { 'assertNotIsTable'        , 'assert_not_table' },
      { 'assertNotIsBoolean'      , 'assert_not_boolean' },
      { 'assertNotIsNil'          , 'assert_not_nil' },
      { 'assertNotIsTrue'         , 'assert_not_true' },
      { 'assertNotIsFalse'        , 'assert_not_false' },
      { 'assertNotIsNaN'          , 'assert_not_nan' },
      { 'assertNotIsInf'          , 'assert_not_inf' },
      { 'assertNotIsPlusInf'      , 'assert_not_plus_inf' },
      { 'assertNotIsMinusInf'     , 'assert_not_minus_inf' },
      { 'assertNotIsPlusZero'     , 'assert_not_plus_zero' },
      { 'assertNotIsMinusZero'    , 'assert_not_minus_zero' },
      { 'assertNotIsFunction'     , 'assert_not_function' },
      { 'assertNotIsThread'       , 'assert_not_thread' },
      { 'assertNotIsUserdata'     , 'assert_not_userdata' },

      -- all assertions with Coroutine duplicate Thread assertions
      { 'assertIsThread'          , 'assertIsCoroutine' },
      { 'assertIsThread'          , 'assertCoroutine' },
      { 'assertIsThread'          , 'assert_is_coroutine' },
      { 'assertIsThread'          , 'assert_coroutine' },
      { 'assertNotIsThread'       , 'assertNotIsCoroutine' },
      { 'assertNotIsThread'       , 'assertNotCoroutine' },
      { 'assertNotIsThread'       , 'assert_not_is_coroutine' },
      { 'assertNotIsThread'       , 'assert_not_coroutine' },
  }
  -- Publish every alias of list_of_funcs in M. When the global
  -- EXPORT_ASSERT_TO_GLOBALS is set, both the official name and the alias are
  -- also exposed as globals.
  for i = 1, #list_of_funcs do
      local officialName, aliasName = list_of_funcs[i][1], list_of_funcs[i][2]
      local fn = M[officialName]
      M[aliasName] = fn
      if EXPORT_ASSERT_TO_GLOBALS then
          _G[officialName] = fn
          _G[aliasName] = fn
      end
  end
  1797. ----------------------------------------------------------------
  1798. --
  1799. -- Outputters
  1800. --
  1801. ----------------------------------------------------------------
  -- Base ("abstract") class shared by all outputters. Every hook defined here
  -- is empty; derived outputters override the hooks they need.
  -- The class/instance scheme follows http://www.lua.org/pil/16.2.html
  local genericOutput = { __class__ = 'genericOutput' } -- class
  local genericOutput_MT = { __index = genericOutput } -- metatable
  M.genericOutput = genericOutput -- published so that custom outputters may derive from it

  -- Build an outputter instance.
  -- runner: the object driving the output (usually a LuaUnit instance). When
  --         nil, only the verbosity is set; derived classes call new() this
  --         way to create their class table.
  -- default_verbosity: verbosity used when the runner provides none.
  function genericOutput.new(runner, default_verbosity)
      local t = { runner = runner }
      if not runner then
          t.verbosity = default_verbosity
          return setmetatable( t, genericOutput_MT )
      end
      t.result = runner.result
      t.verbosity = runner.verbosity or default_verbosity
      t.fname = runner.fname
      return setmetatable( t, genericOutput_MT )
  end

  -- Hook: called once, when the test suite starts.
  function genericOutput:startSuite()
  end

  -- Hook: called each time a new test class starts.
  function genericOutput:startClass(className)
  end

  -- Hook: called each time a test starts, right before its setUp(). The
  -- status node of the running test is already available in
  -- self.result.currentNode.
  function genericOutput:startTest(testName)
  end

  -- Hook: called as soon as a test is found in error or failed (some
  -- outputters also handle skipped tests here). It is not called for
  -- successful tests: a test is successful by default and needs no update.
  function genericOutput:updateStatus(node)
  end

  -- Hook: called when a test is finished, after its tearDown().
  function genericOutput:endTest(node)
  end

  -- Hook: called when all tests of a class have run, before moving on to the
  -- next class or to the end of the execution.
  function genericOutput:endClass()
  end

  -- Hook: called at the end of the test suite execution.
  function genericOutput:endSuite()
  end
  1844. ----------------------------------------------------------------
  1845. -- class TapOutput
  1846. ----------------------------------------------------------------
  local TapOutput = genericOutput.new() -- derived class
  local TapOutput_MT = { __index = TapOutput } -- metatable
  TapOutput.__class__ = 'TapOutput'

  -- Outputter emitting TAP (Test Anything Protocol) on stdout.
  -- For a good reference of the format, see http://testanything.org/tap-specification.html
  function TapOutput.new(runner)
      local instance = genericOutput.new(runner, M.VERBOSITY_LOW)
      return setmetatable( instance, TapOutput_MT )
  end

  -- Print the TAP plan ("1..N", N being the number of selected tests), then
  -- the start date as a TAP comment.
  function TapOutput:startSuite()
      local res = self.result
      print("1.."..res.selectedCount)
      print('# Started on '..res.startDate)
  end

  -- Announce each class, except the pseudo-class that groups plain test functions.
  function TapOutput:startClass(className)
      if className == '[TestFunctions]' then
          return
      end
      print('# Starting class: '..className)
  end

  -- A skipped test is reported as "ok <n> # SKIP <msg>". Any other status
  -- reported here becomes "not ok <n> <name>", followed by the message when
  -- verbosity is above LOW and, for failures and errors, the stack trace when
  -- verbosity is above DEFAULT. Both are emitted as '#' comment lines.
  function TapOutput:updateStatus( node )
      if node:isSkipped() then
          io.stdout:write("ok ", self.result.currentTestNumber, "\t# SKIP ", node.msg, "\n" )
          return
      end
      io.stdout:write("not ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
      if self.verbosity > M.VERBOSITY_LOW then
          print( prefixString( '# ', node.msg ) )
      end
      local failedOrErrored = node:isFailure() or node:isError()
      if failedOrErrored and self.verbosity > M.VERBOSITY_DEFAULT then
          print( prefixString( '# ', node.stackTrace ) )
      end
  end

  -- Successful tests get their "ok" line here; tests that did not succeed
  -- were reported by updateStatus().
  function TapOutput:endTest( node )
      if not node:isSuccess() then
          return
      end
      io.stdout:write("ok ", self.result.currentTestNumber, "\t", node.testName, "\n")
  end

  -- Print the final status line as a TAP comment and return the number of
  -- tests that did not succeed.
  function TapOutput:endSuite()
      print( '# '..M.LuaUnit.statusLine( self.result ) )
      return self.result.notSuccessCount
  end
  1886. -- class TapOutput end
  1887. ----------------------------------------------------------------
  1888. -- class JUnitOutput
  1889. ----------------------------------------------------------------
  1890. -- See directory junitxml for more information about the junit format
  local JUnitOutput = genericOutput.new() -- derived class
  local JUnitOutput_MT = { __index = JUnitOutput } -- metatable
  JUnitOutput.__class__ = 'JUnitOutput'

  -- Outputter writing a JUnit-style XML report to the file named by the
  -- runner's `fname` (--name option), and progress comments to stdout.
  function JUnitOutput.new(runner)
      local t = genericOutput.new(runner, M.VERBOSITY_LOW)
      t.testList = {}
      return setmetatable( t, JUnitOutput_MT )
  end

  -- Open the XML file right away so that a missing or unwritable file name is
  -- reported before any test runs. ".xml" is appended when missing.
  -- Raises an error if no file name was given or the file cannot be opened.
  function JUnitOutput:startSuite()
      -- open xml file early to deal with errors
      if self.fname == nil then
          error('With Junit, an output filename must be supplied with --name!')
      end
      if string.sub(self.fname,-4) ~= '.xml' then
          self.fname = self.fname..'.xml'
      end
      self.fd = io.open(self.fname, "w")
      if self.fd == nil then
          error("Could not open file for writing: "..self.fname)
      end
      print('# XML output to '..self.fname)
      print('# Started on '..self.result.startDate)
  end

  -- Announce each class, except the pseudo-class that groups plain test functions.
  function JUnitOutput:startClass(className)
      if className ~= '[TestFunctions]' then
          print('# Starting class: '..className)
      end
  end

  function JUnitOutput:startTest(testName)
      print('# Starting test: '..testName)
  end

  -- Print the message of a failed or errored test as '#' comment lines.
  function JUnitOutput:updateStatus( node )
      local prefix = '# '
      -- prefixString() puts `prefix` in front of the message (presumably of
      -- each of its lines). The first prefix is stripped because the
      -- "Failure:"/"Error:" label already starts the line. Stripping exactly
      -- #prefix characters keeps the first character of the message intact.
      if node:isFailure() then
          print( '# Failure: ' .. prefixString( prefix, node.msg ):sub(#prefix + 1) )
          -- print('# ' .. node.stackTrace)
      elseif node:isError() then
          print( '# Error: ' .. prefixString( prefix, node.msg ):sub(#prefix + 1) )
          -- print('# ' .. node.stackTrace)
      end
  end

  -- Print the status line, write the whole XML report and close the file.
  -- Returns the number of tests that did not succeed.
  function JUnitOutput:endSuite()
      -- Escape a value for use inside a double-quoted XML attribute. tostring()
      -- first, so a non-string name cannot break the report generation.
      local escapes = {
          ['&'] = '&amp;',
          ['<'] = '&lt;',
          ['>'] = '&gt;',
          ['"'] = '&quot;',
          ["'"] = '&apos;',
      }
      local function xmlAttr(value)
          return (string.gsub(tostring(value), '[&<>"\']', escapes))
      end

      print( '# '..M.LuaUnit.statusLine(self.result))

      -- XML file writing
      local fd, result = self.fd, self.result
      fd:write('<?xml version="1.0" encoding="UTF-8" ?>\n')
      fd:write('<testsuites>\n')
      fd:write(string.format(
          ' <testsuite name="LuaUnit" id="00001" package="" hostname="localhost" tests="%d" timestamp="%s" time="%0.3f" errors="%d" failures="%d" skipped="%d">\n',
          result.runCount, xmlAttr(result.startIsodate), result.duration, result.errorCount, result.failureCount, result.skippedCount ))
      fd:write(" <properties>\n")
      fd:write(string.format(' <property name="Lua Version" value="%s"/>\n', xmlAttr(_VERSION) ) )
      fd:write(string.format(' <property name="LuaUnit Version" value="%s"/>\n', xmlAttr(M.VERSION)) )
      -- XXX please include system name and version if possible
      fd:write(" </properties>\n")
      for _, node in ipairs(result.allTests) do
          -- class and test names come from user code and may contain XML special characters
          fd:write(string.format(' <testcase classname="%s" name="%s" time="%0.3f">\n',
              xmlAttr(node.className), xmlAttr(node.testName), node.duration ) )
          if node:isNotSuccess() then
              fd:write(node:statusXML())
          end
          fd:write(' </testcase>\n')
      end
      -- Next two lines are needed to validate junit ANT xsd, but really not useful in general:
      fd:write(' <system-out/>\n')
      fd:write(' <system-err/>\n')
      fd:write(' </testsuite>\n')
      fd:write('</testsuites>\n')
      fd:close()
      return result.notSuccessCount
  end
  -- class JUnitOutput end
  1961. ----------------------------------------------------------------
  1962. -- class TextOutput
  1963. ----------------------------------------------------------------
  1964. --[[ Example of other unit-tests suite text output
  1965. -- Python Non verbose:
  1966. For each test: . or F or E
  1967. If some failed tests:
  1968. ==============
  1969. ERROR / FAILURE: TestName (testfile.testclass)
  1970. ---------
  1971. Stack trace
  1972. then --------------
  1973. then "Ran x tests in 0.000s"
  1974. then OK or FAILED (failures=1, error=1)
  1975. -- Python Verbose:
  1976. testname (filename.classname) ... ok
  1977. testname (filename.classname) ... FAIL
  1978. testname (filename.classname) ... ERROR
  1979. then --------------
  1980. then "Ran x tests in 0.000s"
  1981. then OK or FAILED (failures=1, error=1)
  1982. -- Ruby:
  1983. Started
  1984. .
  1985. Finished in 0.002695 seconds.
  1986. 1 tests, 2 assertions, 0 failures, 0 errors
  1987. -- Ruby:
  1988. >> ruby tc_simple_number2.rb
  1989. Loaded suite tc_simple_number2
  1990. Started
  1991. F..
  1992. Finished in 0.038617 seconds.
  1993. 1) Failure:
  1994. test_failure(TestSimpleNumber) [tc_simple_number2.rb:16]:
  1995. Adding doesn't work.
  1996. <3> expected but was
  1997. <4>.
  1998. 3 tests, 4 assertions, 1 failures, 0 errors
  1999. -- Java Junit
  2000. .......F.
  2001. Time: 0,003
  2002. There was 1 failure:
  2003. 1) testCapacity(junit.samples.VectorTest)junit.framework.AssertionFailedError
  2004. at junit.samples.VectorTest.testCapacity(VectorTest.java:87)
  2005. at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
  2006. at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
  2007. at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
  2008. FAILURES!!!
  2009. Tests run: 8, Failures: 1, Errors: 0
  2010. -- Maven
  2011. # mvn test
  2012. -------------------------------------------------------
  2013. T E S T S
  2014. -------------------------------------------------------
  2015. Running math.AdditionTest
  2016. Tests run: 2, Failures: 1, Errors: 0, Skipped: 0, Time elapsed:
  2017. 0.03 sec <<< FAILURE!
  2018. Results :
  2019. Failed tests:
  2020. testLireSymbole(math.AdditionTest)
  2021. Tests run: 2, Failures: 1, Errors: 0, Skipped: 0
  2022. -- LuaUnit
  2023. ---- non verbose
  2024. * display . or F or E when running tests
  2025. ---- verbose
  2026. * display test name + ok/fail
  2027. ----
  2028. * blank line
  2029. * number) ERROR or FAILURE: TestName
  2030. Stack trace
  2031. * blank line
  2032. * number) ERROR or FAILURE: TestName
  2033. Stack trace
  2034. then --------------
  2035. then "Ran x tests in 0.000s (%d not selected, %d skipped)"
  2036. then OK or FAILED (failures=1, error=1)
  2037. ]]
  local TextOutput = genericOutput.new() -- derived class
  local TextOutput_MT = { __index = TextOutput } -- metatable
  TextOutput.__class__ = 'TextOutput'

  -- Default human-readable outputter.
  -- At default verbosity it prints one character per test: '.' on success,
  -- otherwise the first letter of the test status. In verbose mode it prints
  -- one line per test with its outcome. At the end it lists errored and
  -- failed tests and prints a status line.
  function TextOutput.new(runner)
      local instance = genericOutput.new(runner, M.VERBOSITY_DEFAULT)
      instance.errorList = {}
      return setmetatable( instance, TextOutput_MT )
  end

  -- Verbose mode only: print the start date of the suite.
  function TextOutput:startSuite()
      if self.verbosity <= M.VERBOSITY_DEFAULT then
          return
      end
      print( 'Started on '.. self.result.startDate )
  end

  -- Verbose mode only: print the name of the test about to run. endTest()
  -- completes the same line with the outcome.
  function TextOutput:startTest(testName)
      if self.verbosity <= M.VERBOSITY_DEFAULT then
          return
      end
      io.stdout:write( " ", self.result.currentNode.testName, " ... " )
  end

  -- Report the outcome of a finished test.
  function TextOutput:endTest( node )
      local verbose = self.verbosity > M.VERBOSITY_DEFAULT
      if node:isSuccess() then
          if verbose then
              io.stdout:write("Ok\n")
          else
              io.stdout:write(".")
              io.stdout:flush()
          end
      elseif verbose then
          print( node.status )
          print( node.msg )
          -- NOTE: the stack trace is not printed here; when to show it is still an open question
      else
          -- only the first character of the status: E, F or S
          io.stdout:write(string.sub(node.status, 1, 1))
          io.stdout:flush()
      end
  end

  -- Print one numbered entry of the final report: "<index>) <test name>",
  -- then the message, the stack trace and a blank line.
  function TextOutput:displayOneFailedTest( index, fail )
      local header = index..") "..fail.testName
      print( header )
      print( fail.msg )
      print( fail.stackTrace )
      print()
  end

  -- Print the section listing tests that ended in error, if there are any.
  function TextOutput:displayErroredTests()
      local errored = self.result.errorTests
      if #errored == 0 then
          return
      end
      print("Tests with errors:")
      print("------------------")
      for position, test in ipairs(errored) do
          self:displayOneFailedTest(position, test)
      end
  end

  -- Print the section listing failed tests, if there are any.
  function TextOutput:displayFailedTests()
      local failed = self.result.failedTests
      if #failed == 0 then
          return
      end
      print("Failed tests:")
      print("-------------")
      for position, test in ipairs(failed) do
          self:displayOneFailedTest(position, test)
      end
  end

  -- Print the error/failure details, the status line and "OK" when every
  -- test succeeded.
  function TextOutput:endSuite()
      if self.verbosity <= M.VERBOSITY_DEFAULT then
          print()
      else
          print("=========================================================")
      end
      self:displayErroredTests()
      self:displayFailedTests()
      print( M.LuaUnit.statusLine( self.result ) )
      if self.result.notSuccessCount == 0 then
          print('OK')
      end
  end
  2118. -- class TextOutput end
  2119. ----------------------------------------------------------------
  2120. -- class NilOutput
  2121. ----------------------------------------------------------------
  -- A function that ignores its arguments and returns itself, so any chain of
  -- calls on it (f(a)(b)...) stays a harmless no-op.
  local nopCallable
  nopCallable = function()
      return nopCallable
  end
  -- Outputter that discards everything. Instances use nopCallable as their
  -- __index function, so any method looked up on them is nopCallable and
  -- every hook call does nothing. The runner argument is ignored.
  local NilOutput = { __class__ = 'NilOutput' } -- class (name was misspelled 'NilOuptut')
  local NilOutput_MT = { __index = nopCallable } -- metatable
  function NilOutput.new(runner)
      return setmetatable( { __class__ = 'NilOutput' }, NilOutput_MT )
  end
  2131. ----------------------------------------------------------------
  2132. --
  2133. -- class LuaUnit
  2134. --
  2135. ----------------------------------------------------------------
  -- The LuaUnit runner class. Class-level defaults: text output and default
  -- verbosity.
  M.LuaUnit = {
      __class__ = 'LuaUnit',
      outputType = TextOutput,
      verbosity = M.VERBOSITY_DEFAULT,
      instances = {}
  }
  local LuaUnit_MT = { __index = M.LuaUnit }

  -- When assertions are exported to globals, the runner class is too.
  if EXPORT_ASSERT_TO_GLOBALS then
      LuaUnit = M.LuaUnit
  end

  -- Create a runner instance inheriting the class defaults through LuaUnit_MT.
  function M.LuaUnit.new()
      local runner = {}
      setmetatable( runner, LuaUnit_MT )
      return runner
  end
  2150. -----------------[[ Utility methods ]]---------------------
  -- Return aObject when it is a function; return nothing otherwise.
  function M.LuaUnit.asFunction(aObject)
      if type(aObject) ~= 'function' then
          return
      end
      return aObject
  end
  --[[
  Split a name of the form "class.method" at its first '.' and return
  className, methodName. When the name contains no '.', return nil followed
  by the unchanged name, so a non-nil className tells that the name refers to
  a class method (this replaces the older isClassMethod() test).
  ]]
  function M.LuaUnit.splitClassMethod(someName)
      local dotPos = string.find(someName, '.', 1, true)
      if not dotPos then
          return nil, someName
      end
      return someName:sub(1, dotPos - 1), someName:sub(dotPos + 1)
  end
  -- Return true if `s` is the name of a test method. Default rule: the name
  -- starts with "test" in any letter case ("test", "Test", ...).
  function M.LuaUnit.isMethodTestName( s )
      local head = string.sub(s, 1, 4)
      return string.lower(head) == 'test'
  end
  -- Return true if `s` is the name of a test (function or class). Default
  -- rule: the name starts with "test" in any letter case ("test", "Test", ...).
  function M.LuaUnit.isTestName( s )
      local head = string.sub(s, 1, 4)
      return string.lower(head) == 'test'
  end
  -- Return the sorted list of global names (string keys of _G) accepted by
  -- M.LuaUnit.isTestName.
  function M.LuaUnit.collectTests()
      local found = {}
      for name in pairs(_G) do
          if type(name) == "string" and M.LuaUnit.isTestName( name ) then
              found[#found + 1] = name
          end
      end
      table.sort( found )
      return found
  end
  2193. function M.LuaUnit.parseCmdLine( cmdLine )
  2194. -- parse the command line
  2195. -- Supported command line parameters:
  2196. -- --verbose, -v: increase verbosity
  2197. -- --quiet, -q: silence output
  2198. -- --error, -e: treat errors as fatal (quit program)
  2199. -- --output, -o, + name: select output type
  2200. -- --pattern, -p, + pattern: run test matching pattern, may be repeated
  2201. -- --exclude, -x, + pattern: run test not matching pattern, may be repeated
  2202. -- --shuffle, -s, : shuffle tests before reunning them
  2203. -- --name, -n, + fname: name of output file for junit, default to stdout
  2204. -- --repeat, -r, + num: number of times to execute each test
  2205. -- [testnames, ...]: run selected test names
  2206. --
  2207. -- Returns a table with the following fields:
  2208. -- verbosity: nil, M.VERBOSITY_DEFAULT, M.VERBOSITY_QUIET, M.VERBOSITY_VERBOSE
  2209. -- output: nil, 'tap', 'junit', 'text', 'nil'
  2210. -- testNames: nil or a list of test names to run
  2211. -- exeRepeat: num or 1
  2212. -- pattern: nil or a list of patterns
  2213. -- exclude: nil or a list of patterns
  2214. local result, state = {}, nil
  2215. local SET_OUTPUT = 1
  2216. local SET_PATTERN = 2
  2217. local SET_EXCLUDE = 3
  2218. local SET_FNAME = 4
  2219. local SET_REPEAT = 5
  2220. if cmdLine == nil then
  2221. return result
  2222. end
  2223. local function parseOption( option )
  2224. if option == '--help' or option == '-h' then
  2225. result['help'] = true
  2226. return
  2227. elseif option == '--version' then
  2228. result['version'] = true
  2229. return
  2230. elseif option == '--verbose' or option == '-v' then
  2231. result['verbosity'] = M.VERBOSITY_VERBOSE
  2232. return
  2233. elseif option == '--quiet' or option == '-q' then
  2234. result['verbosity'] = M.VERBOSITY_QUIET
  2235. return
  2236. elseif option == '--error' or option == '-e' then
  2237. result['quitOnError'] = true
  2238. return
  2239. elseif option == '--failure' or option == '-f' then
  2240. result['quitOnFailure'] = true
  2241. return
  2242. elseif option == '--shuffle' or option == '-s' then
  2243. result['shuffle'] = true
  2244. return
  2245. elseif option == '--output' or option == '-o' then
  2246. state = SET_OUTPUT
  2247. return state
  2248. elseif option == '--name' or option == '-n' then
  2249. state = SET_FNAME
  2250. return state
  2251. elseif option == '--repeat' or option == '-r' then
  2252. state = SET_REPEAT
  2253. return state
  2254. elseif option == '--pattern' or option == '-p' then
  2255. state = SET_PATTERN
  2256. return state
  2257. elseif option == '--exclude' or option == '-x' then
  2258. state = SET_EXCLUDE
  2259. return state
  2260. end
  2261. error('Unknown option: '..option,3)
  2262. end
  2263. local function setArg( cmdArg, state )
  2264. if state == SET_OUTPUT then
  2265. result['output'] = cmdArg
  2266. return
  2267. elseif state == SET_FNAME then
  2268. result['fname'] = cmdArg
  2269. return
  2270. elseif state == SET_REPEAT then
  2271. result['exeRepeat'] = tonumber(cmdArg)
  2272. or error('Malformed -r argument: '..cmdArg)
  2273. return
  2274. elseif state == SET_PATTERN then
  2275. if result['pattern'] then
  2276. table.insert( result['pattern'], cmdArg )
  2277. else
  2278. result['pattern'] = { cmdArg }
  2279. end
  2280. return
  2281. elseif state == SET_EXCLUDE then
  2282. local notArg = '!'..cmdArg
  2283. if result['pattern'] then
  2284. table.insert( result['pattern'], notArg )
  2285. else
  2286. result['pattern'] = { notArg }
  2287. end
  2288. return
  2289. end
  2290. error('Unknown parse state: '.. state)
  2291. end
  2292. for i, cmdArg in ipairs(cmdLine) do
  2293. if state ~= nil then
  2294. setArg( cmdArg, state, result )
  2295. state = nil
  2296. else
  2297. if cmdArg:sub(1,1) == '-' then
  2298. state = parseOption( cmdArg )
  2299. else
  2300. if result['testNames'] then
  2301. table.insert( result['testNames'], cmdArg )
  2302. else
  2303. result['testNames'] = { cmdArg }
  2304. end
  2305. end
  2306. end
  2307. end
  2308. if result['help'] then
  2309. M.LuaUnit.help()
  2310. end
  2311. if result['version'] then
  2312. M.LuaUnit.version()
  2313. end
  2314. if state ~= nil then
  2315. error('Missing argument after '..cmdLine[ #cmdLine ],2 )
  2316. end
  2317. return result
  2318. end
  2319. function M.LuaUnit.help()
  2320. print(M.USAGE)
  2321. os.exit(0)
  2322. end
  2323. function M.LuaUnit.version()
  2324. print('LuaUnit v'..M.VERSION..' by Philippe Fremy <phil@freehackers.org>')
  2325. os.exit(0)
  2326. end
  2327. ----------------------------------------------------------------
  2328. -- class NodeStatus
  2329. ----------------------------------------------------------------
  2330. local NodeStatus = { __class__ = 'NodeStatus' } -- class
  2331. local NodeStatus_MT = { __index = NodeStatus } -- metatable
  2332. M.NodeStatus = NodeStatus
  2333. -- values of status
  2334. NodeStatus.SUCCESS = 'SUCCESS'
  2335. NodeStatus.SKIP = 'SKIP'
  2336. NodeStatus.FAIL = 'FAIL'
  2337. NodeStatus.ERROR = 'ERROR'
  2338. function NodeStatus.new( number, testName, className )
  2339. -- default constructor, test are PASS by default
  2340. local t = { number = number, testName = testName, className = className }
  2341. setmetatable( t, NodeStatus_MT )
  2342. t:success()
  2343. return t
  2344. end
  2345. function NodeStatus:success()
  2346. self.status = self.SUCCESS
  2347. -- useless because lua does this for us, but it helps me remembering the relevant field names
  2348. self.msg = nil
  2349. self.stackTrace = nil
  2350. end
  2351. function NodeStatus:skip(msg)
  2352. self.status = self.SKIP
  2353. self.msg = msg
  2354. self.stackTrace = nil
  2355. end
  2356. function NodeStatus:fail(msg, stackTrace)
  2357. self.status = self.FAIL
  2358. self.msg = msg
  2359. self.stackTrace = stackTrace
  2360. end
  2361. function NodeStatus:error(msg, stackTrace)
  2362. self.status = self.ERROR
  2363. self.msg = msg
  2364. self.stackTrace = stackTrace
  2365. end
  2366. function NodeStatus:isSuccess()
  2367. return self.status == NodeStatus.SUCCESS
  2368. end
  2369. function NodeStatus:isNotSuccess()
  2370. -- Return true if node is either failure or error or skip
  2371. return (self.status == NodeStatus.FAIL or self.status == NodeStatus.ERROR or self.status == NodeStatus.SKIP)
  2372. end
  2373. function NodeStatus:isSkipped()
  2374. return self.status == NodeStatus.SKIP
  2375. end
  2376. function NodeStatus:isFailure()
  2377. return self.status == NodeStatus.FAIL
  2378. end
  2379. function NodeStatus:isError()
  2380. return self.status == NodeStatus.ERROR
  2381. end
  2382. function NodeStatus:statusXML()
  2383. if self:isError() then
  2384. return table.concat(
  2385. {' <error type="', xmlEscape(self.msg), '">\n',
  2386. ' <![CDATA[', xmlCDataEscape(self.stackTrace),
  2387. ']]></error>\n'})
  2388. elseif self:isFailure() then
  2389. return table.concat(
  2390. {' <failure type="', xmlEscape(self.msg), '">\n',
  2391. ' <![CDATA[', xmlCDataEscape(self.stackTrace),
  2392. ']]></failure>\n'})
  2393. elseif self:isSkipped() then
  2394. return table.concat({' <skipped>', xmlEscape(self.msg),'</skipped>\n' } )
  2395. end
  2396. return ' <passed/>\n' -- (not XSD-compliant! normally shouldn't get here)
  2397. end
  2398. --------------[[ Output methods ]]-------------------------
  2399. local function conditional_plural(number, singular)
  2400. -- returns a grammatically well-formed string "%d <singular/plural>"
  2401. local suffix = ''
  2402. if number ~= 1 then -- use plural
  2403. suffix = (singular:sub(-2) == 'ss') and 'es' or 's'
  2404. end
  2405. return string.format('%d %s%s', number, singular, suffix)
  2406. end
  2407. function M.LuaUnit.statusLine(result)
  2408. -- return status line string according to results
  2409. local s = {
  2410. string.format('Ran %d tests in %0.3f seconds',
  2411. result.runCount, result.duration),
  2412. conditional_plural(result.successCount, 'success'),
  2413. }
  2414. if result.notSuccessCount > 0 then
  2415. if result.failureCount > 0 then
  2416. table.insert(s, conditional_plural(result.failureCount, 'failure'))
  2417. end
  2418. if result.errorCount > 0 then
  2419. table.insert(s, conditional_plural(result.errorCount, 'error'))
  2420. end
  2421. else
  2422. table.insert(s, '0 failures')
  2423. end
  2424. if result.skippedCount > 0 then
  2425. table.insert(s, string.format("%d skipped", result.skippedCount))
  2426. end
  2427. if result.nonSelectedCount > 0 then
  2428. table.insert(s, string.format("%d non-selected", result.nonSelectedCount))
  2429. end
  2430. return table.concat(s, ', ')
  2431. end
  2432. function M.LuaUnit:startSuite(selectedCount, nonSelectedCount)
  2433. self.result = {
  2434. selectedCount = selectedCount,
  2435. nonSelectedCount = nonSelectedCount,
  2436. successCount = 0,
  2437. runCount = 0,
  2438. currentTestNumber = 0,
  2439. currentClassName = "",
  2440. currentNode = nil,
  2441. suiteStarted = true,
  2442. startTime = os.clock(),
  2443. startDate = os.date(os.getenv('LUAUNIT_DATEFMT')),
  2444. startIsodate = os.date('%Y-%m-%dT%H:%M:%S'),
  2445. patternIncludeFilter = self.patternIncludeFilter,
  2446. -- list of test node status
  2447. allTests = {},
  2448. failedTests = {},
  2449. errorTests = {},
  2450. skippedTests = {},
  2451. failureCount = 0,
  2452. errorCount = 0,
  2453. notSuccessCount = 0,
  2454. skippedCount = 0,
  2455. }
  2456. self.outputType = self.outputType or TextOutput
  2457. self.output = self.outputType.new(self)
  2458. self.output:startSuite()
  2459. end
  2460. function M.LuaUnit:startClass( className, classInstance )
  2461. self.result.currentClassName = className
  2462. self.output:startClass( className )
  2463. self:setupClass( className, classInstance )
  2464. end
  2465. function M.LuaUnit:startTest( testName )
  2466. self.result.currentTestNumber = self.result.currentTestNumber + 1
  2467. self.result.runCount = self.result.runCount + 1
  2468. self.result.currentNode = NodeStatus.new(
  2469. self.result.currentTestNumber,
  2470. testName,
  2471. self.result.currentClassName
  2472. )
  2473. self.result.currentNode.startTime = os.clock()
  2474. table.insert( self.result.allTests, self.result.currentNode )
  2475. self.output:startTest( testName )
  2476. end
  2477. function M.LuaUnit:updateStatus( err )
  2478. -- "err" is expected to be a table / result from protectedCall()
  2479. if err.status == NodeStatus.SUCCESS then
  2480. return
  2481. end
  2482. local node = self.result.currentNode
  2483. --[[ As a first approach, we will report only one error or one failure for one test.
  2484. However, we can have the case where the test is in failure, and the teardown is in error.
  2485. In such case, it's a good idea to report both a failure and an error in the test suite. This is
  2486. what Python unittest does for example. However, it mixes up counts so need to be handled carefully: for
  2487. example, there could be more (failures + errors) count that tests. What happens to the current node ?
  2488. We will do this more intelligent version later.
  2489. ]]
  2490. -- if the node is already in failure/error, just don't report the new error (see above)
  2491. if node.status ~= NodeStatus.SUCCESS then
  2492. return
  2493. end
  2494. if err.status == NodeStatus.FAIL then
  2495. node:fail( err.msg, err.trace )
  2496. table.insert( self.result.failedTests, node )
  2497. elseif err.status == NodeStatus.ERROR then
  2498. node:error( err.msg, err.trace )
  2499. table.insert( self.result.errorTests, node )
  2500. elseif err.status == NodeStatus.SKIP then
  2501. node:skip( err.msg )
  2502. table.insert( self.result.skippedTests, node )
  2503. else
  2504. error('No such status: ' .. prettystr(err.status))
  2505. end
  2506. self.output:updateStatus( node )
  2507. end
  2508. function M.LuaUnit:endTest()
  2509. local node = self.result.currentNode
  2510. -- print( 'endTest() '..prettystr(node))
  2511. -- print( 'endTest() '..prettystr(node:isNotSuccess()))
  2512. node.duration = os.clock() - node.startTime
  2513. node.startTime = nil
  2514. self.output:endTest( node )
  2515. if node:isSuccess() then
  2516. self.result.successCount = self.result.successCount + 1
  2517. elseif node:isError() then
  2518. if self.quitOnError or self.quitOnFailure then
  2519. -- Runtime error - abort test execution as requested by
  2520. -- "--error" option. This is done by setting a special
  2521. -- flag that gets handled in internalRunSuiteByInstances().
  2522. print("\nERROR during LuaUnit test execution:\n" .. node.msg)
  2523. self.result.aborted = true
  2524. end
  2525. elseif node:isFailure() then
  2526. if self.quitOnFailure then
  2527. -- Failure - abort test execution as requested by
  2528. -- "--failure" option. This is done by setting a special
  2529. -- flag that gets handled in internalRunSuiteByInstances().
  2530. print("\nFailure during LuaUnit test execution:\n" .. node.msg)
  2531. self.result.aborted = true
  2532. end
  2533. elseif node:isSkipped() then
  2534. self.result.runCount = self.result.runCount - 1
  2535. else
  2536. error('No such node status: ' .. prettystr(node.status))
  2537. end
  2538. self.result.currentNode = nil
  2539. end
  2540. function M.LuaUnit:endClass()
  2541. self:teardownClass( self.lastClassName, self.lastClassInstance )
  2542. self.output:endClass()
  2543. end
  2544. function M.LuaUnit:endSuite()
  2545. if self.result.suiteStarted == false then
  2546. error('LuaUnit:endSuite() -- suite was already ended' )
  2547. end
  2548. self.result.duration = os.clock()-self.result.startTime
  2549. self.result.suiteStarted = false
  2550. -- Expose test counts for outputter's endSuite(). This could be managed
  2551. -- internally instead by using the length of the lists of failed tests
  2552. -- but unit tests rely on these fields being present.
  2553. self.result.failureCount = #self.result.failedTests
  2554. self.result.errorCount = #self.result.errorTests
  2555. self.result.notSuccessCount = self.result.failureCount + self.result.errorCount
  2556. self.result.skippedCount = #self.result.skippedTests
  2557. self.output:endSuite()
  2558. end
  2559. function M.LuaUnit:setOutputType(outputType, fname)
  2560. -- Configures LuaUnit runner output
  2561. -- outputType is one of: NIL, TAP, JUNIT, TEXT
  2562. -- when outputType is junit, the additional argument fname is used to set the name of junit output file
  2563. -- for other formats, fname is ignored
  2564. if outputType:upper() == "NIL" then
  2565. self.outputType = NilOutput
  2566. return
  2567. end
  2568. if outputType:upper() == "TAP" then
  2569. self.outputType = TapOutput
  2570. return
  2571. end
  2572. if outputType:upper() == "JUNIT" then
  2573. self.outputType = JUnitOutput
  2574. if fname then
  2575. self.fname = fname
  2576. end
  2577. return
  2578. end
  2579. if outputType:upper() == "TEXT" then
  2580. self.outputType = TextOutput
  2581. return
  2582. end
  2583. error( 'No such format: '..outputType,2)
  2584. end
  2585. --------------[[ Runner ]]-----------------
  2586. function M.LuaUnit:protectedCall(classInstance, methodInstance, prettyFuncName)
  2587. -- if classInstance is nil, this is just a function call
  2588. -- else, it's method of a class being called.
  2589. local function err_handler(e)
  2590. -- transform error into a table, adding the traceback information
  2591. return {
  2592. status = NodeStatus.ERROR,
  2593. msg = e,
  2594. trace = string.sub(debug.traceback("", 1), 2)
  2595. }
  2596. end
  2597. local ok, err
  2598. if classInstance then
  2599. -- stupid Lua < 5.2 does not allow xpcall with arguments so let's use a workaround
  2600. ok, err = xpcall( function () methodInstance(classInstance) end, err_handler )
  2601. else
  2602. ok, err = xpcall( function () methodInstance() end, err_handler )
  2603. end
  2604. if ok then
  2605. return {status = NodeStatus.SUCCESS}
  2606. end
  2607. -- print('ok="'..prettystr(ok)..'" err="'..prettystr(err)..'"')
  2608. local iter_msg
  2609. iter_msg = self.exeRepeat and 'iteration '..self.currentCount
  2610. err.msg, err.status = M.adjust_err_msg_with_iter( err.msg, iter_msg )
  2611. if err.status == NodeStatus.SUCCESS or err.status == NodeStatus.SKIP then
  2612. err.trace = nil
  2613. return err
  2614. end
  2615. -- reformat / improve the stack trace
  2616. if prettyFuncName then -- we do have the real method name
  2617. err.trace = err.trace:gsub("in (%a+) 'methodInstance'", "in %1 '"..prettyFuncName.."'")
  2618. end
  2619. if STRIP_LUAUNIT_FROM_STACKTRACE then
  2620. err.trace = stripLuaunitTrace2(err.trace, err.msg)
  2621. end
  2622. return err -- return the error "object" (table)
  2623. end
  2624. function M.LuaUnit:execOneFunction(className, methodName, classInstance, methodInstance)
  2625. -- When executing a test function, className and classInstance must be nil
  2626. -- When executing a class method, all parameters must be set
  2627. if type(methodInstance) ~= 'function' then
  2628. self:unregisterSuite()
  2629. error( tostring(methodName)..' must be a function, not '..type(methodInstance))
  2630. end
  2631. local prettyFuncName
  2632. if className == nil then
  2633. className = '[TestFunctions]'
  2634. prettyFuncName = methodName
  2635. else
  2636. prettyFuncName = className..'.'..methodName
  2637. end
  2638. if self.lastClassName ~= className then
  2639. if self.lastClassName ~= nil then
  2640. self:endClass()
  2641. end
  2642. self:startClass( className, classInstance )
  2643. self.lastClassName = className
  2644. self.lastClassInstance = classInstance
  2645. end
  2646. self:startTest(prettyFuncName)
  2647. local node = self.result.currentNode
  2648. for iter_n = 1, self.exeRepeat or 1 do
  2649. if node:isNotSuccess() then
  2650. break
  2651. end
  2652. self.currentCount = iter_n
  2653. -- run setUp first (if any)
  2654. if classInstance then
  2655. local func = self.asFunction( classInstance.setUp ) or
  2656. self.asFunction( classInstance.Setup ) or
  2657. self.asFunction( classInstance.setup ) or
  2658. self.asFunction( classInstance.SetUp )
  2659. if func then
  2660. self:updateStatus(self:protectedCall(classInstance, func, className..'.setUp'))
  2661. end
  2662. end
  2663. -- run testMethod()
  2664. if node:isSuccess() then
  2665. self:updateStatus(self:protectedCall(classInstance, methodInstance, prettyFuncName))
  2666. end
  2667. -- lastly, run tearDown (if any)
  2668. if classInstance then
  2669. local func = self.asFunction( classInstance.tearDown ) or
  2670. self.asFunction( classInstance.TearDown ) or
  2671. self.asFunction( classInstance.teardown ) or
  2672. self.asFunction( classInstance.Teardown )
  2673. if func then
  2674. self:updateStatus(self:protectedCall(classInstance, func, className..'.tearDown'))
  2675. end
  2676. end
  2677. end
  2678. self:endTest()
  2679. end
  2680. function M.LuaUnit.expandOneClass( result, className, classInstance )
  2681. --[[
  2682. Input: a list of { name, instance }, a class name, a class instance
  2683. Ouptut: modify result to add all test method instance in the form:
  2684. { className.methodName, classInstance }
  2685. ]]
  2686. for methodName, methodInstance in sortedPairs(classInstance) do
  2687. if M.LuaUnit.asFunction(methodInstance) and M.LuaUnit.isMethodTestName( methodName ) then
  2688. table.insert( result, { className..'.'..methodName, classInstance } )
  2689. end
  2690. end
  2691. end
  2692. function M.LuaUnit.expandClasses( listOfNameAndInst )
  2693. --[[
  2694. -- expand all classes (provided as {className, classInstance}) to a list of {className.methodName, classInstance}
  2695. -- functions and methods remain untouched
  2696. Input: a list of { name, instance }
  2697. Output:
  2698. * { function name, function instance } : do nothing
  2699. * { class.method name, class instance }: do nothing
  2700. * { class name, class instance } : add all method names in the form of (className.methodName, classInstance)
  2701. ]]
  2702. local result = {}
  2703. for i,v in ipairs( listOfNameAndInst ) do
  2704. local name, instance = v[1], v[2]
  2705. if M.LuaUnit.asFunction(instance) then
  2706. table.insert( result, { name, instance } )
  2707. else
  2708. if type(instance) ~= 'table' then
  2709. error( 'Instance must be a table or a function, not a '..type(instance)..' with value '..prettystr(instance))
  2710. end
  2711. local className, methodName = M.LuaUnit.splitClassMethod( name )
  2712. if className then
  2713. local methodInstance = instance[methodName]
  2714. if methodInstance == nil then
  2715. error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
  2716. end
  2717. table.insert( result, { name, instance } )
  2718. else
  2719. M.LuaUnit.expandOneClass( result, name, instance )
  2720. end
  2721. end
  2722. end
  2723. return result
  2724. end
  2725. function M.LuaUnit.applyPatternFilter( patternIncFilter, listOfNameAndInst )
  2726. local included, excluded = {}, {}
  2727. for i, v in ipairs( listOfNameAndInst ) do
  2728. -- local name, instance = v[1], v[2]
  2729. if patternFilter( patternIncFilter, v[1] ) then
  2730. table.insert( included, v )
  2731. else
  2732. table.insert( excluded, v )
  2733. end
  2734. end
  2735. return included, excluded
  2736. end
  2737. local function getKeyInListWithGlobalFallback( key, listOfNameAndInst )
  2738. local result = nil
  2739. for i,v in ipairs( listOfNameAndInst ) do
  2740. if(listOfNameAndInst[i][1] == key) then
  2741. result = listOfNameAndInst[i][2]
  2742. break
  2743. end
  2744. end
  2745. if(not M.LuaUnit.asFunction( result ) ) then
  2746. result = _G[key]
  2747. end
  2748. return result
  2749. end
  2750. function M.LuaUnit:setupSuite( listOfNameAndInst )
  2751. local setupSuite = getKeyInListWithGlobalFallback("setupSuite", listOfNameAndInst)
  2752. if self.asFunction( setupSuite ) then
  2753. self:updateStatus( self:protectedCall( nil, setupSuite, 'setupSuite' ) )
  2754. end
  2755. end
  2756. function M.LuaUnit:teardownSuite(listOfNameAndInst)
  2757. local teardownSuite = getKeyInListWithGlobalFallback("teardownSuite", listOfNameAndInst)
  2758. if self.asFunction( teardownSuite ) then
  2759. self:updateStatus( self:protectedCall( nil, teardownSuite, 'teardownSuite') )
  2760. end
  2761. end
  2762. function M.LuaUnit:setupClass( className, instance )
  2763. if type( instance ) == 'table' and self.asFunction( instance.setupClass ) then
  2764. self:updateStatus( self:protectedCall( instance, instance.setupClass, className..'.setupClass' ) )
  2765. end
  2766. end
  2767. function M.LuaUnit:teardownClass( className, instance )
  2768. if type( instance ) == 'table' and self.asFunction( instance.teardownClass ) then
  2769. self:updateStatus( self:protectedCall( instance, instance.teardownClass, className..'.teardownClass' ) )
  2770. end
  2771. end
  2772. function M.LuaUnit:internalRunSuiteByInstances( listOfNameAndInst )
  2773. --[[ Run an explicit list of tests. Each item of the list must be one of:
  2774. * { function name, function instance }
  2775. * { class name, class instance }
  2776. * { class.method name, class instance }
  2777. This function is internal to LuaUnit. The official API to perform this action is runSuiteByInstances()
  2778. ]]
  2779. local expandedList = self.expandClasses( listOfNameAndInst )
  2780. if self.shuffle then
  2781. randomizeTable( expandedList )
  2782. end
  2783. local filteredList, filteredOutList = self.applyPatternFilter(
  2784. self.patternIncludeFilter, expandedList )
  2785. self:startSuite( #filteredList, #filteredOutList )
  2786. self:setupSuite( listOfNameAndInst )
  2787. for i,v in ipairs( filteredList ) do
  2788. local name, instance = v[1], v[2]
  2789. if M.LuaUnit.asFunction(instance) then
  2790. self:execOneFunction( nil, name, nil, instance )
  2791. else
  2792. -- expandClasses() should have already taken care of sanitizing the input
  2793. assert( type(instance) == 'table' )
  2794. local className, methodName = M.LuaUnit.splitClassMethod( name )
  2795. assert( className ~= nil )
  2796. local methodInstance = instance[methodName]
  2797. assert(methodInstance ~= nil)
  2798. self:execOneFunction( className, methodName, instance, methodInstance )
  2799. end
  2800. if self.result.aborted then
  2801. break -- "--error" or "--failure" option triggered
  2802. end
  2803. end
  2804. if self.lastClassName ~= nil then
  2805. self:endClass()
  2806. end
  2807. self:teardownSuite( listOfNameAndInst )
  2808. self:endSuite()
  2809. if self.result.aborted then
  2810. print("LuaUnit ABORTED (as requested by --error or --failure option)")
  2811. self:unregisterSuite()
  2812. os.exit(-2)
  2813. end
  2814. end
  2815. function M.LuaUnit:internalRunSuiteByNames( listOfName )
  2816. --[[ Run LuaUnit with a list of generic names, coming either from command-line or from global
  2817. namespace analysis. Convert the list into a list of (name, valid instances (table or function))
  2818. and calls internalRunSuiteByInstances.
  2819. ]]
  2820. local instanceName, instance
  2821. local listOfNameAndInst = {}
  2822. for i,name in ipairs( listOfName ) do
  2823. local className, methodName = M.LuaUnit.splitClassMethod( name )
  2824. if className then
  2825. instanceName = className
  2826. instance = _G[instanceName]
  2827. if instance == nil then
  2828. self:unregisterSuite()
  2829. error( "No such name in global space: "..instanceName )
  2830. end
  2831. if type(instance) ~= 'table' then
  2832. self:unregisterSuite()
  2833. error( 'Instance of '..instanceName..' must be a table, not '..type(instance))
  2834. end
  2835. local methodInstance = instance[methodName]
  2836. if methodInstance == nil then
  2837. self:unregisterSuite()
  2838. error( "Could not find method in class "..tostring(className).." for method "..tostring(methodName) )
  2839. end
  2840. else
  2841. -- for functions and classes
  2842. instanceName = name
  2843. instance = _G[instanceName]
  2844. end
  2845. if instance == nil then
  2846. self:unregisterSuite()
  2847. error( "No such name in global space: "..instanceName )
  2848. end
  2849. if (type(instance) ~= 'table' and type(instance) ~= 'function') then
  2850. self:unregisterSuite()
  2851. error( 'Name must match a function or a table: '..instanceName )
  2852. end
  2853. table.insert( listOfNameAndInst, { name, instance } )
  2854. end
  2855. self:internalRunSuiteByInstances( listOfNameAndInst )
  2856. end
  2857. function M.LuaUnit.run(...)
  2858. -- Run some specific test classes.
  2859. -- If no arguments are passed, run the class names specified on the
  2860. -- command line. If no class name is specified on the command line
  2861. -- run all classes whose name starts with 'Test'
  2862. --
  2863. -- If arguments are passed, they must be strings of the class names
  2864. -- that you want to run or generic command line arguments (-o, -p, -v, ...)
  2865. local runner = M.LuaUnit.new()
  2866. return runner:runSuite(...)
  2867. end
  2868. function M.LuaUnit:registerSuite()
  2869. -- register the current instance into our global array of instances
  2870. -- print('-> Register suite')
  2871. M.LuaUnit.instances[ #M.LuaUnit.instances+1 ] = self
  2872. end
  2873. function M.unregisterCurrentSuite()
  2874. -- force unregister the last registered suite
  2875. table.remove(M.LuaUnit.instances, #M.LuaUnit.instances)
  2876. end
  2877. function M.LuaUnit:unregisterSuite()
  2878. -- print('<- Unregister suite')
  2879. -- remove our current instqances from the global array of instances
  2880. local instanceIdx = nil
  2881. for i, instance in ipairs(M.LuaUnit.instances) do
  2882. if instance == self then
  2883. instanceIdx = i
  2884. break
  2885. end
  2886. end
  2887. if instanceIdx ~= nil then
  2888. table.remove(M.LuaUnit.instances, instanceIdx)
  2889. -- print('Unregister done')
  2890. end
  2891. end
  2892. function M.LuaUnit:initFromArguments( ... )
  2893. --[[Parses all arguments from either command-line or direct call and set internal
  2894. flags of LuaUnit runner according to it.
  2895. Return the list of names which were possibly passed on the command-line or as arguments
  2896. ]]
  2897. local args = {...}
  2898. if type(args[1]) == 'table' and args[1].__class__ == 'LuaUnit' then
  2899. -- run was called with the syntax M.LuaUnit:runSuite()
  2900. -- we support both M.LuaUnit.run() and M.LuaUnit:run()
  2901. -- strip out the first argument self to make it a command-line argument list
  2902. table.remove(args,1)
  2903. end
  2904. if #args == 0 then
  2905. args = cmdline_argv
  2906. end
  2907. local options = pcall_or_abort( M.LuaUnit.parseCmdLine, args )
  2908. -- We expect these option fields to be either `nil` or contain
  2909. -- valid values, so it's safe to always copy them directly.
  2910. self.verbosity = options.verbosity
  2911. self.quitOnError = options.quitOnError
  2912. self.quitOnFailure = options.quitOnFailure
  2913. self.exeRepeat = options.exeRepeat
  2914. self.patternIncludeFilter = options.pattern
  2915. self.shuffle = options.shuffle
  2916. options.output = options.output or os.getenv('LUAUNIT_OUTPUT')
  2917. options.fname = options.fname or os.getenv('LUAUNIT_JUNIT_FNAME')
  2918. if options.output then
  2919. if options.output:lower() == 'junit' and options.fname == nil then
  2920. print('With junit output, a filename must be supplied with -n or --name')
  2921. os.exit(-1)
  2922. end
  2923. pcall_or_abort(self.setOutputType, self, options.output, options.fname)
  2924. end
  2925. return options.testNames
  2926. end
  2927. function M.LuaUnit:runSuite( ... )
  2928. testNames = self:initFromArguments(...)
  2929. self:registerSuite()
  2930. self:internalRunSuiteByNames( testNames or M.LuaUnit.collectTests() )
  2931. self:unregisterSuite()
  2932. return self.result.notSuccessCount
  2933. end
  2934. function M.LuaUnit:runSuiteByInstances( listOfNameAndInst )
  2935. --[[
  2936. Run all test functions or tables provided as input.
  2937. Input: a list of { name, instance }
  2938. instance can either be a function or a table containing test functions starting with the prefix "test"
  2939. return the number of failures and errors, 0 meaning success
  2940. ]]
  2941. -- parse the command-line arguments
  2942. testNames = self:initFromArguments()
  2943. self:registerSuite()
  2944. self:internalRunSuiteByInstances( listOfNameAndInst )
  2945. self:unregisterSuite()
  2946. return self.result.notSuccessCount
  2947. end
  2948. -- class LuaUnit
  2949. -- For compatbility with LuaUnit v2
  2950. M.run = M.LuaUnit.run
  2951. M.Run = M.LuaUnit.run
  2952. function M:setVerbosity( verbosity )
  2953. -- set the verbosity value (as integer)
  2954. M.LuaUnit.verbosity = verbosity
  2955. end
  2956. M.set_verbosity = M.setVerbosity
  2957. M.SetVerbosity = M.setVerbosity
  2958. return M