pat.lua 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. --[========================================================================[--
  2. Lua pattern matching library (minus string.gsub) ported to Lua.
  3. Copyright © 1994–2018 Lua.org, PUC-Rio.
  4. Copyright © 2019 Pedro Gimeno Fortea.
  5. Permission is hereby granted, free of charge, to any person obtaining a copy
  6. of this software and associated documentation files (the "Software"), to deal
  7. in the Software without restriction, including without limitation the rights
  8. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. copies of the Software, and to permit persons to whom the Software is
  10. furnished to do so, subject to the following conditions:
  11. The above copyright notice and this permission notice shall be included in all
  12. copies or substantial portions of the Software.
  13. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  14. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  15. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  16. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  17. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  18. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  19. SOFTWARE.
  20. --]========================================================================]--
  -- Capture bookkeeping: the maximum number of captures a pattern may open,
  -- and the negative sentinels stored in a capture's `len` field while it is
  -- still open (CAP_UNFINISHED) or when it is a position capture (CAP_POSITION).
  local LUA_MAXCAPTURES = 32
  local CAP_UNFINISHED, CAP_POSITION = -1, -2
  -- Byte value of the pattern escape character '%' (37).
  local L_ESC = ("%"):byte()
  25. -- only used for string.find(x, y, z, true)
  26. -- local SPECIALS = '^$*+?.([%-'
  -- Wrap `x` in single quotes; used when building error messages.
  local function LUA_QL(x)
    local quote = "'"
    return quote .. x .. quote
  end
  -- Translate the digit byte `l` of a back-reference (%1..%9) into a capture
  -- index and validate it.
  -- Returns the 1-based capture index; raises "invalid capture index" when the
  -- index is below 1, above the current capture level, or refers to a capture
  -- that is still open.
  local function check_capture(ms, l)
    local index = l - 48 -- 48='0'
    local usable = index >= 1
      and index <= ms.level
      and ms.capture[index].len ~= CAP_UNFINISHED
    if usable then
      return index
    end
    return error("invalid capture index")
  end
  -- Find the innermost capture that is still open, i.e. the one a ')' closes.
  -- Scans from the current capture level downwards and returns the first level
  -- whose length is CAP_UNFINISHED; raises "invalid pattern capture" if none is.
  local function capture_to_close(ms)
    for level = ms.level, 1, -1 do
      if ms.capture[level].len == CAP_UNFINISHED then
        return level
      end
    end
    return error("invalid pattern capture")
  end
  -- Return the index just past the single pattern item that starts at `p`.
  -- An item is an escape pair (%x), a bracket set ([...]) or a single byte.
  -- A missing byte or a NUL byte is treated as the end of the pattern.
  -- Raises "malformed pattern" when the pattern ends right after '%' or a set
  -- is never closed by ']'.
  local function classend(pat, p)
    local first = pat:byte(p)
    local pos = p + 1

    if first == L_ESC then
      if (pat:byte(pos) or 0) == 0 then
        return error("malformed pattern (ends with " .. LUA_QL('%') .. ")")
      end
      return pos + 1
    end

    if first ~= 91 then -- 91='['; anything else is a one-byte item
      return pos
    end

    if pat:byte(pos) == 94 then -- 94='^', set negation
      pos = pos + 1
    end
    repeat -- look for the closing ']'
      local cur = pat:byte(pos) or 0
      pos = pos + 1
      if cur == 0 then
        return error("malformed pattern (missing " .. LUA_QL(']') .. ")")
      end
      -- skip the byte following an escape, so that "%]" does not close the set
      if cur == L_ESC and (pat:byte(pos) or 0) ~= 0 then
        pos = pos + 1
      end
    until pat:byte(pos) == 93 -- 93=']'
    return pos + 1
  end
  -- Membership tests for the character classes, keyed by the byte of the
  -- upper-case class letter. Each test receives the subject byte.
  local CLASS_TESTS = {
    [65] = function(c) -- 'A': letters
      return c >= 65 and c <= 90 or c >= 97 and c <= 122
    end,
    [67] = function(c) -- 'C': control characters
      return c < 32 or c == 127
    end,
    [68] = function(c) -- 'D': digits
      return c >= 48 and c <= 57
    end,
    [76] = function(c) -- 'L': lower-case letters
      return c >= 97 and c <= 122
    end,
    [80] = function(c) -- 'P': punctuation
      return c >= 33 and c <= 47 or c >= 58 and c <= 64
        or c >= 91 and c <= 96 or c >= 123 and c <= 126
    end,
    [83] = function(c) -- 'S': HT, LF..CR and space
      return c == 9 or c >= 10 and c <= 13 or c == 32
    end,
    [85] = function(c) -- 'U': upper-case letters
      return c >= 65 and c <= 90
    end,
    [87] = function(c) -- 'W': letters and digits
      return c >= 65 and c <= 90 or c >= 97 and c <= 122
        or c >= 48 and c <= 57
    end,
    [88] = function(c) -- 'X': hexadecimal digits
      return c >= 65 and c <= 70 or c >= 97 and c <= 102
        or c >= 48 and c <= 57
    end,
    [90] = function(c) -- 'Z': the NUL byte
      return c == 0
    end,
  }

  -- Test subject byte `c` against the class letter byte `cl` (the byte that
  -- follows '%' in a pattern).
  -- A lower-case class letter matches members of the class and an upper-case
  -- one matches non-members. Any byte that is not a class letter is compared
  -- literally with `c`.
  local function match_class(c, cl)
    -- clear bit 5 (value 32) to fold a lower-case letter onto its upper-case form
    local upper = cl
    if cl % 64 >= 32 then
      upper = cl - 32
    end
    local test = CLASS_TESTS[upper]
    if not test then
      return c == cl
    end
    local negate = cl == upper -- true when the class letter was upper-case
    return negate ~= test(c)
  end
  -- Test subject byte `c` against the bracket set that starts at `pat[p]`
  -- (the '[') and whose closing ']' is at index `ec`.
  -- Handles a leading '^' (negation), escaped classes (%x) and ranges (a-z).
  -- Returns true when `c` belongs to the (possibly negated) set.
  local function matchbracketclass(c, pat, p, ec)
    local pos = p + 1
    local on_hit = true -- value returned when `c` is found in the set
    if pat:byte(pos) == 94 then -- 94='^'
      on_hit = false
      pos = pos + 1
    end
    while pos < ec do
      local item = pat:byte(pos)
      if item == L_ESC then
        pos = pos + 1
        if match_class(c, pat:byte(pos)) then
          return on_hit
        end
      elseif pat:byte(pos + 1) == 45 and pos + 2 < ec then -- 45='-'
        -- range: `item` is the lower bound, the byte after '-' the upper bound
        local upper_bound = pat:byte(pos + 2)
        pos = pos + 2
        if item <= c and c <= upper_bound then
          return on_hit
        end
      elseif item == c then
        return on_hit
      end
      pos = pos + 1
    end
    return not on_hit
  end
  -- Test subject byte `c` against the single pattern item starting at
  -- `pat[p]`; `ep` is the index just past that item.
  -- '.' matches any byte, '%x' delegates to match_class, '[...]' delegates to
  -- matchbracketclass, and any other byte is compared literally.
  local function singlematch(c, pat, p, ep)
    local item = pat:byte(p)
    if item == L_ESC then
      return match_class(c, pat:byte(p + 1))
    elseif item == 91 then -- 91='['
      return matchbracketclass(c, pat, p, ep - 1)
    end
    return item == 46 or item == c -- 46='.'
  end
  -- Forward declaration: the helpers defined below call `match` recursively,
  -- and the function itself is assigned to this local further down the file.
  local match
  -- Match a balanced run for the %bxy pattern item.
  -- `pat[p]` is the opening delimiter and `pat[p + 1]` the closing one.
  -- Returns the subject index just past the closing delimiter that balances
  -- the opening one at `str[s]`, or false when `str[s]` is not the opening
  -- delimiter or the run is never closed before `ms.src_end`.
  -- Raises "unbalanced pattern" when either delimiter is missing or NUL.
  local function matchbalance(ms, str, s, pat, p)
    local open = pat:byte(p)
    local close = pat:byte(p + 1)
    if (open or 0) == 0 or (close or 0) == 0 then
      return error("unbalanced pattern")
    end
    if str:byte(s) ~= open then
      return false
    end
    local depth = 1
    for i = s + 1, ms.src_end - 1 do
      local ch = str:byte(i)
      -- the closing delimiter is tested first, so %b with identical
      -- delimiters stops at the next occurrence
      if ch == close then
        depth = depth - 1
        if depth == 0 then
          return i + 1
        end
      elseif ch == open then
        depth = depth + 1
      end
    end
    return false
  end
  -- Greedy repetition ('*', and '+' after its first match).
  -- Consumes as many bytes matching the item at `pat[p]` as possible, then
  -- tries to match the rest of the pattern (from `ep + 1`), giving back one
  -- byte at a time. Returns the result of the first successful match of the
  -- rest, or false.
  local function max_expand(ms, str, s, pat, p, ep)
    local furthest = s
    while furthest < ms.src_end
      and singlematch(str:byte(furthest), pat, p, ep)
    do
      furthest = furthest + 1
    end
    -- try the longest expansion first, then back off down to zero repetitions
    for pos = furthest, s, -1 do
      local res = match(ms, str, pos, pat, ep + 1)
      if res then
        return res
      end
    end
    return false
  end
  -- Lazy repetition ('-').
  -- Tries to match the rest of the pattern (from `ep + 1`) first and consumes
  -- one more byte matching the item at `pat[p]` only when that fails.
  -- Returns the result of the first successful match of the rest, or false
  -- once the item can no longer be extended.
  local function min_expand(ms, str, s, pat, p, ep)
    local pos = s
    while true do
      local res = match(ms, str, pos, pat, ep + 1)
      if res then
        return res
      end
      local can_extend = pos < ms.src_end
        and singlematch(str:byte(pos), pat, p, ep)
      if not can_extend then
        return false
      end
      pos = pos + 1
    end
  end
  -- Open a new capture at subject index `s` and continue matching at `pat[p]`.
  -- `what` is the initial length marker stored in the capture (CAP_UNFINISHED
  -- or CAP_POSITION at the call sites in this file).
  -- Returns the result of the continued match; the capture level is rolled
  -- back when that match fails. Raises "too many captures" beyond
  -- LUA_MAXCAPTURES.
  local function start_capture(ms, str, s, pat, p, what)
    local level = ms.level + 1
    if level > LUA_MAXCAPTURES then
      return error("too many captures")
    end
    -- capture records are created lazily and reused by later attempts
    local slot = ms.capture[level]
    if not slot then
      slot = {}
      ms.capture[level] = slot
    end
    slot.init = s
    slot.len = what
    ms.level = level
    local res = match(ms, str, s, pat, p)
    if not res then
      ms.level = ms.level - 1 -- undo the capture
    end
    return res
  end
  -- Close the innermost open capture at subject index `s` and continue
  -- matching at `pat[p]`.
  -- Returns the result of the continued match; the capture is marked as open
  -- again when that match fails.
  local function end_capture(ms, str, s, pat, p)
    local cap = ms.capture[capture_to_close(ms)]
    cap.len = s - cap.init
    local res = match(ms, str, s, pat, p)
    if not res then
      cap.len = CAP_UNFINISHED -- undo the close
    end
    return res
  end
  -- Compare two `len`-byte slices of `str`, starting at `start1` and `start2`.
  -- Returns true when every pair of corresponding bytes is equal (always true
  -- for a length of zero or less), false at the first difference.
  local function substrcomp(str, start1, start2, len)
    local shift = start2 - start1
    local i = start1
    local last = start1 + len - 1
    while i <= last do
      if str:byte(i) ~= str:byte(i + shift) then
        return false
      end
      i = i + 1
    end
    return true
  end
  -- Match a back-reference (%1..%9) at subject index `s`.
  -- `l` is the byte of the digit that follows '%'; it is validated by
  -- check_capture, which raises on an invalid or still-open capture.
  -- Returns the subject index just past the repeated text when the subject at
  -- `s` repeats the text of the referenced capture, otherwise false.
  local function match_capture(ms, str, s, l)
    l = check_capture(ms, l)
    local cap = ms.capture[l]
    local len = cap.len
    -- A position capture "()" stores the negative CAP_POSITION marker as its
    -- length. It has no text to compare, and without this guard the function
    -- would return `s + len`, moving the subject index backwards. The C
    -- implementation keeps the length in an unsigned type, so such a
    -- back-reference never matches there either.
    if len < 0 then
      return false
    end
    if ms.src_end - s >= len and substrcomp(str, cap.init, s, len) then
      return s + len
    end
    return false
  end
  -- this is local already, don't add 'local' or it will fail
  --
  -- Core recursive matcher: try to match the pattern `pat`, starting at byte
  -- index `p`, against the subject `str`, starting at byte index `s`.
  --   ms  - match state (fields used: level, capture, src_end)
  --   str - subject string
  --   s   - current index into the subject
  --   pat - pattern string
  --   p   - current index into the pattern
  -- Returns the subject index just past the matched text, or false when the
  -- pattern does not match at `s`. Raises an error on malformed patterns.
  -- A missing pattern byte (index past the end) is handled like a NUL byte,
  -- which marks the end of the pattern.
  match = function (ms, str, s, pat, p)
    local cc = pat:byte(p) or 0
    if cc == 40 then -- 40='('
      -- start capture
      if pat:byte(p + 1) == 41 then -- 41=')'
        -- position capture
        return start_capture(ms, str, s, pat, p + 2, CAP_POSITION)
      end
      return start_capture(ms, str, s, pat, p + 1, CAP_UNFINISHED)
    end
    if cc == 41 then -- 41=')'
      -- end capture
      return end_capture(ms, str, s, pat, p + 1)
    end
    if cc == L_ESC then
      -- `or 0`: when the pattern ends with a lone '%' there is no next byte.
      -- Treating it as NUL lets control reach classend below, which raises
      -- the proper "malformed pattern" error instead of a nil comparison.
      cc = pat:byte(p + 1) or 0
      if cc == 98 then -- 98='b'
        -- balanced string
        s = matchbalance(ms, str, s, pat, p + 2)
        if not s then
          return false
        end
        return match(ms, str, s, pat, p + 4)
      end
      if cc == 102 then -- 102='f'
        -- frontier
        p = p + 2
        if pat:byte(p) ~= 91 then -- 91='['
          return error("missing " .. LUA_QL('[') .. " after "
            .. LUA_QL('%f') .. " in pattern")
        end
        local ep = classend(pat, p)
        -- a missing byte before the start or at the end of the subject
        -- counts as NUL
        local previous = str:byte(s - 1) or 0
        if matchbracketclass(previous, pat, p, ep - 1)
          or not matchbracketclass(str:byte(s) or 0, pat, p, ep - 1)
        then
          return false
        end
        return match(ms, str, s, pat, ep)
      end
      if cc >= 48 and cc <= 57 then -- 48='0', 57='9'
        -- capture results (%0-%9)?
        s = match_capture(ms, str, s, cc)
        if not s then
          return false
        end
        return match(ms, str, s, pat, p + 2)
      end
      cc = -1 -- don't match anything else until default
    end
    if cc == 0 then -- end of pattern
      return s
    end
    if cc == 36 then -- 36='$'
      if (pat:byte(p + 1) or 0) == 0 then -- is the `$' the last char in pattern?
        return s == ms.src_end and s -- check end of string
      end
      -- else fall through to default
    end
    -- default: a single pattern item, optionally followed by a repetition suffix
    local ep = classend(pat, p)
    local m = s < ms.src_end and singlematch(str:byte(s), pat, p, ep)
    cc = pat:byte(ep) or 0
    if cc == 63 then -- 63='?'
      -- optional: try with the item first, then without it
      if m then
        local res = match(ms, str, s + 1, pat, ep + 1)
        if res then
          return res
        end
      end
      return match(ms, str, s, pat, ep + 1)
    end
    if cc == 42 then -- 42='*'
      -- 0 or more repetitions
      return max_expand(ms, str, s, pat, p, ep)
    end
    if cc == 43 then -- 43='+'
      -- 1 or more repetitions
      return m and max_expand(ms, str, s + 1, pat, p, ep)
    end
    if cc == 45 then -- 45='-'
      -- 0 or more repetitions (minimum)
      return min_expand(ms, str, s, pat, p, ep)
    end
    return m and match(ms, str, s + 1, pat, ep)
  end
  326. -- lmemfind not implemented - only used for string.find(x, y, z, true)
  -- Append the value of capture number `i` to `ms.captures`.
  -- When `i` is above the capture level, only index 1 is accepted and the
  -- whole match `str[s .. e-1]` is appended; any other index raises
  -- "invalid capture index". Otherwise a position capture appends its start
  -- index and a regular capture appends its text; an open capture raises
  -- "unfinished capture".
  local function push_onecapture(ms, i, str, s, e)
    if i > ms.level then
      if i ~= 1 then
        return error("invalid capture index")
      end
      ms.captures[#ms.captures + 1] = str:sub(s, e - 1)
      return
    end

    local cap = ms.capture[i]
    local len = cap.len
    if len == CAP_UNFINISHED then
      return error("unfinished capture")
    end
    local value
    if len == CAP_POSITION then
      value = cap.init
    else
      value = str:sub(cap.init, cap.init + len - 1)
    end
    ms.captures[#ms.captures + 1] = value
  end
  -- Append every capture of the current match to `ms.captures`.
  -- When the pattern has no captures and `s` is truthy, the whole match
  -- `str[s .. e-1]` is appended instead; callers pass a false `s` to
  -- suppress that.
  local function push_captures(ms, str, s, e)
    local count = ms.level
    if count == 0 and s then
      count = 1
    end
    local i = 1
    while i <= count do
      push_onecapture(ms, i, str, s, e)
      i = i + 1
    end
  end
  -- Convert a possibly negative string position into an absolute one.
  -- Negative positions count back from the end of a string of length `len`;
  -- the result is never smaller than 1.
  local function posrelat(pos, len)
    local absolute = pos
    if absolute < 0 then
      absolute = absolute + len + 1
    end
    if absolute >= 1 then
      return absolute
    end
    return 1
  end
  -- The built-in string.find, used for plain (non-pattern) searches.
  local string_find = string.find
  -- `unpack` is a global up to Lua 5.1 (and in LuaJIT) but lives at
  -- table.unpack from Lua 5.2 on. Binding it to a local here lets the
  -- functions defined below call `unpack` on either family of interpreters.
  local unpack = table.unpack or unpack
  -- Set of bytes that make a pattern "special", i.e. not a plain string:
  -- the magic characters ^ $ * + ? . ( [ % - and the NUL byte.
  local SPECIALS = { [0] = true }
  do
    local magic = "^$*+?.([%-"
    for i = 1, #magic do
      SPECIALS[magic:byte(i)] = true
    end
  end
  -- Tell whether `s` can be searched for as a plain string.
  -- Scans up to the first byte found in SPECIALS: returns true when that byte
  -- is NUL or when no such byte exists, false when it is a magic character.
  local function specials_free(s)
    for i = 1, #s do
      local b = s:byte(i)
      if SPECIALS[b] then
        return b == 0 -- stop at first NUL
      end
    end
    return true
  end
  -- Shared implementation of find and match.
  --   find     - true for find semantics (return start and end indices
  --              followed by the captures), false for match semantics
  --              (return the captures, or the whole match if there are none)
  --   str      - subject string
  --   pat      - pattern
  --   init     - optional start position; negative values count from the
  --              end and the result is clamped to [1, #str + 1]
  --   explicit - in find mode, a truthy value requests a plain search
  -- Returns nil when nothing matches.
  local function str_find_aux(find, str, pat, init, explicit)
    local len = #str
    local start = posrelat(init or 1, len)
    if start > len + 1 then
      start = len + 1
    elseif start < 1 then
      start = 1
    end

    -- plain search: requested explicitly, or the pattern has no magic characters
    if find and (explicit or specials_free(pat)) then
      return string_find(str, pat, start, true)
    end

    local anchored = pat:byte(1) == 94 -- 94='^'
    local p = anchored and 2 or 1
    local ms = {captures = {}, capture = {}, src_end = len + 1}
    local s1 = start
    while true do
      ms.level = 0
      local res = match(ms, str, s1, pat, p)
      if res then
        if find then
          push_captures(ms, str, false, 0)
          return s1, res - 1, unpack(ms.captures)
        end
        push_captures(ms, str, s1, res)
        return unpack(ms.captures)
      end
      s1 = s1 + 1
      -- an anchored pattern is tried only once; otherwise try every start
      -- position up to and including the end of the subject
      if anchored or s1 > ms.src_end then
        return nil
      end
    end
  end
  -- find(str, pat [, init [, explicit]]): run str_find_aux in find mode and
  -- return whatever it returns.
  local function str_find(str, pat, init, explicit)
    local find_mode = true
    return str_find_aux(find_mode, str, pat, init, explicit)
  end
  -- match(str, pat [, init]): run str_find_aux in match mode and return
  -- whatever it returns.
  local function str_match(str, pat, init)
    local find_mode = false
    return str_find_aux(find_mode, str, pat, init)
  end
  -- gmatch(str, pat): return an iterator function that, on each call, yields
  -- the captures of the next match of `pat` in `str` (or the whole match if
  -- the pattern has no captures), and returns nothing once no match is left.
  -- The pattern is always matched from its first byte at each candidate
  -- position; a leading '^' gets no special treatment here.
  local function gmatch(str, pat)
    local ms = {captures = {}, capture = {}, src_init = str, src_end = #str + 1}
    local next_start = 1

    return function()
      -- drop the captures produced by the previous call
      local found = ms.captures
      for i = #found, 1, -1 do
        found[i] = nil
      end

      local src = next_start
      while src <= ms.src_end do
        ms.level = 0
        local e = match(ms, str, src, pat, 1)
        if e then
          if e == src then
            next_start = e + 1 -- empty match? go at least one position
          else
            next_start = e
          end
          push_captures(ms, str, src, e)
          return unpack(found)
        end
        src = src + 1
      end
    end
  end
  -- Module table; `gfind` is exported as an alias of `gmatch`.
  local exports = {
    find = str_find,
    match = str_match,
    gmatch = gmatch,
  }
  exports.gfind = exports.gmatch
  return exports