query.lua 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063
  1. local api = vim.api
  2. local language = require('vim.treesitter.language')
  3. local memoize = vim.func._memoize
  4. local M = {}
  ---@nodoc
  ---Parsed query, see |vim.treesitter.query.parse()|
  ---
  ---@class vim.treesitter.Query
  ---@field lang string name of the language for this parser
  ---@field captures string[] list of (unique) capture names defined in query
  ---@field info vim.treesitter.QueryInfo contains information used in the query (e.g. captures, predicates, directives)
  ---@field query TSQuery userdata query object
  local Query = {}
  Query.__index = Query

  --- Wraps a low-level query object into a `Query` instance.
  ---
  --- The capture and pattern lists reported by `ts_query:inspect()` are stored in
  --- `info`; `captures` is the very same table as `info.captures`.
  ---
  ---@package
  ---@see vim.treesitter.query.parse
  ---@param lang string
  ---@param ts_query TSQuery
  ---@return vim.treesitter.Query
  function Query.new(lang, ts_query)
    local inspected = ts_query:inspect() ---@type TSQueryInfo
    local info = {
      captures = inspected.captures,
      patterns = inspected.patterns,
    }
    return setmetatable({
      query = ts_query,
      lang = lang,
      info = info,
      captures = info.captures,
    }, Query)
  end
  32. ---@nodoc
  33. ---Information for Query, see |vim.treesitter.query.parse()|
  34. ---@class vim.treesitter.QueryInfo
  35. ---
  36. ---List of (unique) capture names defined in query.
  37. ---@field captures string[]
  38. ---
  39. ---Contains information about predicates and directives.
  40. ---Key is pattern id, and value is list of predicates or directives defined in the pattern.
  41. ---A predicate or directive is a list of (integer|string); integer represents `capture_id`, and
  42. ---string represents (literal) arguments to predicate/directive. See |treesitter-predicates|
  43. ---and |treesitter-directives| for more details.
  44. ---@field patterns table<integer, (integer|string)[][]>
  --- Drops repeated paths, keeping the first occurrence of each in its original order.
  ---
  ---@param files string[]
  ---@return string[]
  local function dedupe_files(files)
    ---@type table<string,boolean>
    local known = {}
    local unique = {}
    for _, file in ipairs(files) do
      if not known[file] then
        known[file] = true
        unique[#unique + 1] = file
      end
    end
    return unique
  end
  --- Opens {filename} for reading, performs a single `file:read()` with
  --- {read_quantifier}, closes the file and returns what was read.
  ---
  --- Raises an error (with the message reported by `io.open`) if the file cannot be opened.
  ---
  ---@param filename string Path of the file to read
  ---@param read_quantifier string Format handed to `file:read()` (e.g. '*a')
  ---@return string? content
  local function safe_read(filename, read_quantifier)
    local handle, open_err = io.open(filename, 'r')
    if not handle then
      error(open_err)
    end
    local data = handle:read(read_quantifier)
    handle:close()
    return data
  end
  --- Appends {ilang} to {base_langs} unless {ilang} is the same as {lang}.
  ---
  ---@param base_langs string[] List that receives {ilang} (modified in place)
  ---@param lang string Language currently being processed
  ---@param ilang string Language to include
  ---@return boolean # true if lang == ilang (nothing is appended), false otherwise
  local function add_included_lang(base_langs, lang, ilang)
    local same_lang = lang == ilang
    if not same_lang then
      base_langs[#base_langs + 1] = ilang
    end
    return same_lang
  end
  --- Gets the list of files used to make up a query
  ---
  --- Every runtime file matching `queries/{lang}/{query_name}.scm` is inspected: its
  --- leading lines that start with `;` are scanned for an `inherits` or `extends`
  --- modeline. The returned list contains, in this order, the files of all inherited
  --- languages, the first file that is not an extension (the base query), and then
  --- all extension files.
  ---
  ---@param lang string Language to get query for
  ---@param query_name string Name of the query to load (e.g., "highlights")
  ---@param is_included? boolean Internal parameter, most of the time left as `nil`
  ---@return string[] query_files List of files to load for given query and language
  function M.get_files(lang, query_name, is_included)
    local runtime_pattern = ('queries/%s/%s.scm'):format(lang, query_name)
    local lang_files = dedupe_files(api.nvim_get_runtime_file(runtime_pattern, true))

    if #lang_files == 0 then
      return {}
    end

    -- The base languages are found by looking at the leading comment lines of every file.
    -- The syntax is the following :
    -- ;+ inherits: ({language},)*{language}
    --
    -- {language} ::= {lang} | ({lang})
    local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
    local EXTENDS_FORMAT = '^;+%s*extends%s*$'

    local base_query ---@type string?
    local extensions = {}
    local base_langs = {} ---@type string[]

    for _, filename in ipairs(lang_files) do
      local file, err = io.open(filename, 'r')
      if not file then
        error(err)
      end

      local extension = false

      for modeline in file:lines() do
        -- Modelines can only appear in the leading block of `;` comment lines.
        if not vim.startswith(modeline, ';') then
          break
        end

        local langlist = modeline:match(MODELINE_FORMAT)
        if langlist then
          ---@diagnostic disable-next-line:param-type-mismatch
          for _, incllang in ipairs(vim.split(langlist, ',', true)) do
            local included = incllang
            local skip = false
            if incllang:match('%(.*%)') then
              -- A parenthesised language is optional: it is only followed when this
              -- call is not itself resolving an inherited language.
              included = incllang:sub(2, #incllang - 1)
              skip = is_included
            end
            -- Inheriting from the language itself marks the file as an extension.
            if not skip and add_included_lang(base_langs, lang, included) then
              extension = true
            end
          end
        elseif modeline:match(EXTENDS_FORMAT) then
          extension = true
        end
      end

      if extension then
        table.insert(extensions, filename)
      elseif base_query == nil then
        -- Only the first non-extension file is kept as the base query.
        base_query = filename
      end

      file:close()
    end

    local query_files = {}
    for _, base_lang in ipairs(base_langs) do
      vim.list_extend(query_files, M.get_files(base_lang, query_name, true))
    end
    -- `{ base_query }` is an empty list when no base query was found.
    vim.list_extend(query_files, { base_query })
    vim.list_extend(query_files, extensions)

    return query_files
  end
  --- Reads every file in {filenames} completely and joins the contents, in list
  --- order, without any separator.
  ---
  ---@param filenames string[]
  ---@return string
  local function read_query_files(filenames)
    local chunks = {}
    for _, filename in ipairs(filenames) do
      chunks[#chunks + 1] = safe_read(filename, '*a')
    end
    return table.concat(chunks, '')
  end
  -- The explicitly set queries from |vim.treesitter.query.set()|
  --
  -- Indexing a language that has no entry yet creates, stores and returns an
  -- empty per-language table, so `explicit_queries[lang][name]` never fails.
  ---@type table<string,table<string,vim.treesitter.Query>>
  local explicit_queries = setmetatable({}, {
    __index = function(store, lang)
      local per_lang = {}
      rawset(store, lang, per_lang)
      return per_lang
    end,
  })
  --- Sets the runtime query named {query_name} for {lang}
  ---
  --- This allows users to override any runtime files and/or configuration
  --- set by plugins.
  ---
  --- The text is parsed immediately and the parsed query is stored for {lang}.
  ---
  ---@param lang string Language to use for the query
  ---@param query_name string Name of the query (e.g., "highlights")
  ---@param text string Query text (unparsed).
  function M.set(lang, query_name, text)
    -- NOTE(review): the getter for runtime queries is wrapped in `memoize`; confirm
    -- whether a result cached before this call can shadow the query stored here.
    local lang_queries = explicit_queries[lang]
    lang_queries[query_name] = M.parse(lang, text)
  end
  --- Returns the runtime query {query_name} for {lang}.
  ---
  --- A query registered explicitly for {lang} and {query_name} takes precedence;
  --- otherwise the query files found for the language are read, concatenated and parsed.
  ---
  ---@param lang string Language to use for the query
  ---@param query_name string Name of the query (e.g. "highlights")
  ---
  ---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found.
  M.get = memoize('concat-2', function(lang, query_name)
    local explicit = explicit_queries[lang][query_name]
    if explicit then
      return explicit
    end

    local query_string = read_query_files(M.get_files(lang, query_name))
    if query_string == '' then
      return nil
    end

    return M.parse(lang, query_string)
  end)
  --- Parse {query} as a string. (If the query is in a file, the caller
  --- should read the contents into a string before calling).
  ---
  --- Returns a `Query` (see |lua-treesitter-query|) object which can be used to
  --- search nodes in the syntax tree for the patterns defined in {query}
  --- using the `iter_captures` and `iter_matches` methods.
  ---
  --- Exposes `info` and `captures` with additional context about {query}.
  ---   - `captures` contains the list of unique capture names defined in {query}.
  ---   - `info.captures` also points to `captures`.
  ---   - `info.patterns` contains information about predicates.
  ---
  ---@param lang string Language to use for the query
  ---@param query string Query in s-expr syntax
  ---
  ---@return vim.treesitter.Query : Parsed query
  ---
  ---@see [vim.treesitter.query.get()]
  M.parse = memoize('concat-2', function(lang, query)
    -- Fails here if the language cannot be added.
    assert(language.add(lang))
    return Query.new(lang, vim._ts_parse_query(lang, query))
  end)
  --- Implementations of predicates that can optionally be prefixed with "any-".
  ---
  --- These functions contain the implementations for each predicate, correctly
  --- handling the "any" vs "all" semantics. They are called from the
  --- predicate_handlers table with the appropriate arguments for each predicate.
  ---
  --- Shared conventions:
  ---   - a capture without nodes makes the predicate succeed;
  ---   - with `any`, the first passing check returns true and the result is false
  ---     when nothing passes;
  ---   - without `any`, the first failing check returns false and the result is
  ---     true when nothing fails.
  local impl = {
    --- Compares the text of each captured node with a literal string or with the
    --- text of another (single-node) capture.
    ---
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['eq'] = function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        local node_text = vim.treesitter.get_node_text(node, source)
        local rhs = predicate[3]

        local expected ---@type string
        if type(rhs) == 'string' then
          -- (#eq? @aa "foo")
          expected = rhs
        else
          -- (#eq? @aa @bb)
          local other = assert(match[rhs])
          assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
          expected = vim.treesitter.get_node_text(other[1], source)
        end

        local equal = expected ~= nil and node_text == expected
        if any then
          if equal then
            return true
          end
        elseif not equal then
          return false
        end
      end

      return not any
    end,

    --- Tests the text of each captured node against a Lua pattern.
    ---
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['lua-match'] = function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      local lua_pattern = predicate[3]
      for _, node in ipairs(nodes) do
        local text = vim.treesitter.get_node_text(node, source)
        local found = string.find(text, lua_pattern) ~= nil
        if any then
          if found then
            return true
          end
        elseif not found then
          return false
        end
      end

      return not any
    end,

    ['match'] = (function()
      local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }

      -- Prepends `\v` unless the pattern is shorter than two characters or already
      -- starts with one of the magic prefixes above.
      local function check_magic(str)
        local keep_as_is = string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)]
        if keep_as_is then
          return str
        end
        return '\\v' .. str
      end

      -- Compiled regexes, created on first use and keyed by the pattern as written.
      local compiled_vim_regexes = setmetatable({}, {
        __index = function(cache, pattern)
          local compiled = vim.regex(check_magic(pattern))
          rawset(cache, pattern, compiled)
          return compiled
        end,
      })

      --- Tests the text of each captured node against a Vim regex.
      ---
      --- @param match table<integer,TSNode[]>
      --- @param source integer|string
      --- @param predicate any[]
      --- @param any boolean
      return function(match, source, predicate, any)
        local nodes = match[predicate[2]]
        if not nodes or #nodes == 0 then
          return true
        end

        for _, node in ipairs(nodes) do
          local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
          local text = vim.treesitter.get_node_text(node, source)
          local found = regex:match_str(text)
          if any then
            if found then
              return true
            end
          elseif not found then
            return false
          end
        end

        return not any
      end
    end)(),

    --- Searches the text of each captured node for every literal argument
    --- (plain search, no pattern matching).
    ---
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['contains'] = function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        local text = vim.treesitter.get_node_text(node, source)
        for arg = 3, #predicate do
          local found = string.find(text, predicate[arg], 1, true)
          if any then
            if found then
              return true
            end
          elseif not found then
            return false
          end
        end
      end

      return not any
    end,
  }
  ---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean

  -- Predicate handlers receive the following arguments
  -- (match, pattern, source, predicate)
  --
  -- `match` maps capture ids to the list of captured nodes and `predicate` is the
  -- predicate as written in the query: its name, followed by its arguments.
  ---@type table<string,TSPredicate>
  local predicate_handlers = {
    ['eq?'] = function(match, _, source, predicate)
      return impl['eq'](match, source, predicate, false)
    end,

    ['any-eq?'] = function(match, _, source, predicate)
      return impl['eq'](match, source, predicate, true)
    end,

    ['lua-match?'] = function(match, _, source, predicate)
      return impl['lua-match'](match, source, predicate, false)
    end,

    ['any-lua-match?'] = function(match, _, source, predicate)
      return impl['lua-match'](match, source, predicate, true)
    end,

    ['match?'] = function(match, _, source, predicate)
      return impl['match'](match, source, predicate, false)
    end,

    ['any-match?'] = function(match, _, source, predicate)
      return impl['match'](match, source, predicate, true)
    end,

    ['contains?'] = function(match, _, source, predicate)
      return impl['contains'](match, source, predicate, false)
    end,

    ['any-contains?'] = function(match, _, source, predicate)
      return impl['contains'](match, source, predicate, true)
    end,

    -- Succeeds if the text of any captured node equals one of the listed strings.
    ['any-of?'] = function(match, _, source, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      -- Since 'predicate' will not be used by callers of this function, use it
      -- to store a string set built from the list of words to check against.
      local string_set = predicate['string_set'] --- @type table<string, boolean>
      if not string_set then
        string_set = {}
        for i = 3, #predicate do
          string_set[predicate[i]] = true
        end
        predicate['string_set'] = string_set
      end

      for _, node in ipairs(nodes) do
        if string_set[vim.treesitter.get_node_text(node, source)] then
          return true
        end
      end
      return false
    end,

    -- Succeeds if any captured node reports an ancestor matching the predicate.
    ['has-ancestor?'] = function(match, _, _, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        if node:__has_ancestor(predicate) then
          return true
        end
      end
      return false
    end,

    -- Succeeds if the direct parent of any captured node has one of the listed types.
    ['has-parent?'] = function(match, _, _, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        -- A node without a parent can never satisfy the predicate; skip it
        -- instead of indexing nil.
        local parent = node:parent()
        if parent then
          local parent_type = parent:type()
          for i = 3, #predicate do
            if predicate[i] == parent_type then
              return true
            end
          end
        end
      end
      return false
    end,
  }

  -- As we provide lua-match? also expose vim-match?
  predicate_handlers['vim-match?'] = predicate_handlers['match?']
  predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
  416. ---@nodoc
  417. ---@class vim.treesitter.query.TSMetadata
  418. ---@field range? Range
  419. ---@field conceal? string
  420. ---@field [integer]? vim.treesitter.query.TSMetadata
  421. ---@field [string]? integer|string
  ---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)

  -- Directives store metadata or perform side effects against a match.
  -- Directives should always end with a `!`.
  -- Directive handlers receive the following arguments
  -- (match, pattern, source, predicate, metadata)
  ---@type table<string,TSDirective>
  local directive_handlers = {
    -- Stores a key/value pair either on the match or on a single capture.
    ['set!'] = function(_, _, _, pred, metadata)
      if #pred >= 3 and type(pred[2]) == 'number' then
        -- (#set! @capture key value)
        local capture_id, key, value = pred[2], pred[3], pred[4]
        if not metadata[capture_id] then
          metadata[capture_id] = {}
        end
        metadata[capture_id][key] = value
      else
        -- (#set! key value)
        local key, value = pred[2], pred[3]
        -- A key given without a value is stored as `true`.
        metadata[key] = value or true
      end
    end,

    -- Shifts the range of a node.
    -- Example: (#offset! @_node 0 1 0 -1)
    ['offset!'] = function(match, _, _, pred, metadata)
      local capture_id = pred[2] --[[@as integer]]
      local nodes = match[capture_id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#offset! does not support captures on multiple nodes')

      local node = nodes[1]

      if not metadata[capture_id] then
        metadata[capture_id] = {}
      end

      -- Always work on a fresh table: a range already stored in the metadata must
      -- stay untouched when the shifted range turns out to be invalid.
      local range ---@type integer[]
      local stored = metadata[capture_id].range
      if stored then
        range = {}
        for i, value in ipairs(stored) do
          range[i] = value
        end
      else
        range = { node:range() }
      end

      local start_row_offset = pred[3] or 0
      local start_col_offset = pred[4] or 0
      local end_row_offset = pred[5] or 0
      local end_col_offset = pred[6] or 0

      range[1] = range[1] + start_row_offset
      range[2] = range[2] + start_col_offset
      range[3] = range[3] + end_row_offset
      range[4] = range[4] + end_col_offset

      -- If this produces an invalid range, we just skip it.
      if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
        metadata[capture_id].range = range
      end
    end,

    -- Transform the content of the node
    -- Example: (#gsub! @_node ".*%.(.*)" "%1")
    ['gsub!'] = function(match, _, bufnr, pred, metadata)
      assert(#pred == 4)

      local id = pred[2]
      assert(type(id) == 'number')

      local nodes = match[id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
      local node = nodes[1]
      -- Metadata already recorded for this capture is passed along when reading the text.
      local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''

      if not metadata[id] then
        metadata[id] = {}
      end

      local pattern, replacement = pred[3], pred[4]
      assert(type(pattern) == 'string')
      assert(type(replacement) == 'string')

      -- Only the substituted string is kept; the replacement count is discarded.
      local replaced = text:gsub(pattern, replacement)
      metadata[id].text = replaced
    end,

    -- Trim whitespace from both sides of the node
    -- Example: (#trim! @fold 1 1 1 1)
    --
    -- The four optional flags enable, in order: trimming of leading blank lines,
    -- leading whitespace columns, trailing blank lines, trailing whitespace columns.
    ['trim!'] = function(match, _, bufnr, pred, metadata)
      local capture_id = pred[2]
      assert(type(capture_id) == 'number')

      local trim_start_lines = pred[3] == '1'
      local trim_start_cols = pred[4] == '1'
      local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility
      local trim_end_cols = pred[6] == '1'

      local nodes = match[capture_id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
      local node = nodes[1]

      local start_row, start_col, end_row, end_col = node:range()

      local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n')
      if end_col == 0 then
        -- get_node_text() will ignore the last line if the node ends at column 0
        node_text[#node_text + 1] = ''
      end

      local end_idx = #node_text
      local start_idx = 1

      if trim_end_lines then
        while end_idx > 0 and node_text[end_idx]:find('^%s*$') do
          end_idx = end_idx - 1
          end_row = end_row - 1
          -- set the end position to the last column of the next line, or 0 if we just trimmed the
          -- last line
          end_col = end_idx > 0 and #node_text[end_idx] or 0
        end
      end
      if trim_end_cols then
        if end_idx == 0 then
          -- Every line was trimmed away: collapse the end onto the start.
          end_row = start_row
          end_col = start_col
        else
          local whitespace_start = node_text[end_idx]:find('(%s*)$')
          -- Columns on the first line of the node are relative to `start_col`.
          end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0)
        end
      end

      if trim_start_lines then
        while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do
          start_idx = start_idx + 1
          start_row = start_row + 1
          start_col = 0
        end
      end
      if trim_start_cols and node_text[start_idx] then
        local _, whitespace_end = node_text[start_idx]:find('^(%s*)')
        whitespace_end = whitespace_end or 0
        start_col = (start_idx == 1 and start_col or 0) + whitespace_end
      end

      -- If this produces an invalid range, we just skip it.
      if start_row < end_row or (start_row == end_row and start_col <= end_col) then
        metadata[capture_id] = metadata[capture_id] or {}
        metadata[capture_id].range = { start_row, start_col, end_row, end_col }
      end
    end,
  }
  --- @class vim.treesitter.query.add_predicate.Opts
  --- @inlinedoc
  ---
  --- Override an existing predicate of the same name
  --- @field force? boolean
  ---
  --- Use the correct implementation of the match table where capture IDs map to
  --- a list of nodes instead of a single node. Defaults to true. This option will
  --- be removed in a future release.
  --- @field all? boolean

  --- Adds a new predicate to be used in queries
  ---
  --- Raises an error if a predicate called {name} already exists and `opts.force`
  --- is not set. With `opts.all == false` the handler is wrapped so that every
  --- capture id maps to the last node captured for it instead of the node list.
  ---
  ---@param name string Name of the predicate, without leading #
  ---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata): boolean?
  ---   - see |vim.treesitter.query.add_directive()| for argument meanings
  ---@param opts? vim.treesitter.query.add_predicate.Opts
  function M.add_predicate(name, handler, opts)
    -- Backward compatibility: old signature had "force" as boolean argument
    if type(opts) == 'boolean' then
      opts = { force = opts }
    end
    opts = opts or {}

    if predicate_handlers[name] and not opts.force then
      error(string.format('Overriding existing predicate %s', name))
    end

    if opts.all == false then
      --- @param match table<integer, TSNode[]>
      predicate_handlers[name] = function(match, ...)
        local last_nodes = {} ---@type table<integer, TSNode>
        for capture_id, nodes in pairs(match) do
          if type(capture_id) == 'number' then
            last_nodes[capture_id] = nodes[#nodes]
          end
        end
        return handler(last_nodes, ...)
      end
      return
    end

    predicate_handlers[name] = handler
  end
  --- Adds a new directive to be used in queries
  ---
  --- Handlers can set match level data by setting directly on the
  --- metadata object `metadata.key = value`. Additionally, handlers
  --- can set node level data by using the capture id on the
  --- metadata table `metadata[capture_id].key = value`
  ---
  --- Raises an error if a directive called {name} already exists and `opts.force`
  --- is not set. Only with `opts.all == false` is the handler wrapped so that every
  --- capture id maps to the last node captured for it instead of the node list.
  ---
  ---@param name string Name of the directive, without leading #
  ---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata)
  ---   - match: A table mapping capture IDs to a list of captured nodes
  ---   - pattern: the index of the matching pattern in the query file
  ---   - predicate: list of strings containing the full directive being called, e.g.
  ---     `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }`
  ---@param opts? vim.treesitter.query.add_predicate.Opts
  function M.add_directive(name, handler, opts)
    -- Backward compatibility: old signature had "force" as boolean argument
    if type(opts) == 'boolean' then
      opts = { force = opts }
    end

    opts = opts or {}

    if directive_handlers[name] and not opts.force then
      error(string.format('Overriding existing directive %s', name))
    end

    -- `all` defaults to true, as documented for the options and as implemented
    -- for predicates: the handler receives the node lists unless the caller
    -- explicitly opts out.
    if opts.all ~= false then
      directive_handlers[name] = handler
    else
      --- @param match table<integer, TSNode[]>
      local function wrapper(match, ...)
        local m = {} ---@type table<integer, TSNode>
        for k, v in pairs(match) do
          if type(k) == 'number' then
            m[k] = v[#v]
          end
        end
        handler(m, ...)
      end
      directive_handlers[name] = wrapper
    end
  end
  --- Lists the currently available directives to use in queries.
  ---
  --- The names are the keys of the directive handler table.
  ---@return string[] : Supported directives.
  function M.list_directives()
    local names = vim.tbl_keys(directive_handlers)
    return names
  end
  --- Lists the currently available predicates to use in queries.
  ---
  --- The names are the keys of the predicate handler table.
  ---@return string[] : Supported predicates.
  function M.list_predicates()
    local names = vim.tbl_keys(predicate_handlers)
    return names
  end
  --- Exclusive or on truthiness.
  ---
  --- When neither argument is truthy, the falsy value of `x or y` is returned
  --- as-is (so it may be `nil` or `false`); otherwise the result is a boolean.
  ---@param x any
  ---@param y any
  ---@return boolean?
  local function xor(x, y)
    local either = x or y
    if not either then
      return either
    end
    return not (x and y)
  end
  --- Tells whether {name} denotes a directive, i.e. whether it ends with `!`.
  ---@param name string
  ---@return boolean
  local function is_directive(name)
    local last_char = string.sub(name, -1)
    return last_char == '!'
  end
  --- Evaluates every predicate attached to the pattern of {match}.
  ---
  --- Directives (names ending in `!`) are skipped here. A `not-` prefix on a
  --- predicate name inverts the result of the underlying handler. Raises an
  --- error if a predicate has no registered handler.
  ---
  ---@private
  ---@param match TSQueryMatch
  ---@param source integer|string
  ---@return boolean # false as soon as one predicate fails, true otherwise
  function Query:match_preds(match, source)
    local _, pattern = match:info()
    local preds = self.info.patterns[pattern]

    if not preds then
      return true
    end

    local captures = match:captures()

    for _, pred in pairs(preds) do
      -- Here we only want to return if a predicate DOES NOT match, and
      -- continue on the other case.
      -- Also, tree-sitter strips the leading # from predicates for us.

      -- Skip over directives... they will get processed after all the predicates.
      if not is_directive(pred[1]) then
        local pred_name = pred[1]
        local is_not = false
        if pred_name:match('^not%-') then
          pred_name = pred_name:sub(5)
          is_not = true
        end

        local handler = predicate_handlers[pred_name]

        if not handler then
          -- `error` does not return, so nothing needs to follow it.
          error(string.format('No handler for %s', pred[1]))
        end

        local pred_matches = handler(captures, pattern, source, pred)

        if not xor(is_not, pred_matches) then
          return false
        end
      end
    end
    return true
  end
  --- Runs every directive attached to the pattern of {match} and collects the
  --- metadata they produce.
  ---
  --- Predicates (names not ending in `!`) are ignored here. Raises an error if a
  --- directive has no registered handler.
  ---
  ---@private
  ---@param match TSQueryMatch
  ---@param source integer|string
  ---@return vim.treesitter.query.TSMetadata metadata
  function Query:apply_directives(match, source)
    ---@type vim.treesitter.query.TSMetadata
    local metadata = {}

    local _, pattern = match:info()
    local preds = self.info.patterns[pattern]

    if preds then
      local captures = match:captures()

      for _, pred in pairs(preds) do
        local name = pred[1]
        if is_directive(name) then
          local handler = directive_handlers[name]
          if not handler then
            error(string.format('No handler for %s', name))
          end
          -- Handlers write their results into the shared `metadata` table.
          handler(captures, pattern, source, pred, metadata)
        end
      end
    end

    return metadata
  end
  --- Returns the start and stop value if set else the node's range.
  ---
  --- A missing {start} is replaced by the first value of `node:start()`; a missing
  --- {stop} is replaced by the first value of `node:end_()` plus 1, so that the
  --- node's last row is included in the search.
  ---@param start integer?
  ---@param stop integer?
  ---@param node TSNode
  ---@return integer, integer
  local function value_or_node_range(start, stop, node)
    local first = start
    if first == nil then
      -- Parentheses keep only the first returned value.
      first = (node:start())
    end

    local last = stop
    if last == nil then
      last = (node:end_()) + 1 -- Make stop inclusive
    end

    return first, last
  end
  --- Hash function for memoizing per-match results: uses only the first value
  --- returned by `match:info()` and ignores the first argument.
  ---
  --- @param match TSQueryMatch
  --- @return integer
  local function match_id_hash(_, match)
    local match_id = match:info()
    return match_id
  end
  --- Iterate over all captures from all matches inside {node}
  ---
  --- {source} is needed if the query contains predicates; then the caller
  --- must ensure to use a freshly parsed tree consistent with the current
  --- text of the buffer (if relevant). A {source} of `0` is replaced by the
  --- current buffer. {start} and {stop} can be used to limit matches inside a
  --- row range (this is typically used with root node as the {node}, i.e., to
  --- get syntax highlight matches in the current viewport). When omitted, the
  --- {start} and {stop} row values are used from the given node.
  ---
  --- The iterator returns four values: a numeric id identifying the capture,
  --- the captured node, metadata from any directives processing the match,
  --- and the match itself. Captures whose match fails its predicates are
  --- skipped and that match is removed from the cursor.
  ---
  --- The iterator accepts an optional `end_line` argument: when a skipped
  --- capture starts after that row, the iterator stops searching and returns
  --- `nil` for the capture id, metadata and match, together with that node.
  ---
  --- The following example shows how to get captures by name:
  ---
  --- ```lua
  --- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do
  ---   local name = query.captures[id] -- name of the capture in the query
  ---   -- typically useful info about the node:
  ---   local type = node:type() -- type of the captured node
  ---   local row1, col1, row2, col2 = node:range() -- range of the capture
  ---   -- ... use the info here ...
  --- end
  --- ```
  ---
  ---@param node TSNode under which the search will occur
  ---@param source (integer|string) Source buffer or string to extract text from
  ---@param start? integer Starting line for the search. Defaults to `node:start()`.
  ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_() + 1`.
  ---@param opts? table Optional keyword arguments:
  ---   - max_start_depth (integer) if non-zero, sets the maximum start depth
  ---     for each match. This is used to prevent traversing too deep into a tree.
  ---   - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
  ---
  ---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch):
  ---        capture id, capture node, metadata, match
  ---
  ---@note Captures are only returned if the query pattern of a specific capture contained predicates.
  function Query:iter_captures(node, source, start, stop, opts)
    if type(source) == 'number' and source == 0 then
      source = api.nvim_get_current_buf()
    end

    start, stop = value_or_node_range(start, stop, node)

    -- Build a fresh options table from the documented keys instead of
    -- forwarding (or writing into) the caller's table. Without {opts} this is
    -- exactly `{ match_limit = 256 }`, the previously hard-coded value.
    local match_limit, max_start_depth = 256, nil
    if opts then
      match_limit = opts.match_limit or match_limit
      max_start_depth = opts.max_start_depth
    end

    local cursor = vim._create_ts_querycursor(node, self.query, start, stop, {
      match_limit = match_limit,
      max_start_depth = max_start_depth,
    })

    -- Directive and predicate evaluation go through memoized wrappers keyed on
    -- the match id (see `match_id_hash`).
    -- NOTE(review): presumably so a match that yields several captures is only
    -- evaluated once; the meaning of the third `memoize` argument is not
    -- visible from here.
    local apply_directives = memoize(match_id_hash, self.apply_directives, true)
    local match_preds = memoize(match_id_hash, self.match_preds, true)

    local function iter(end_line)
      local capture, captured_node, match = cursor:next_capture()

      if not capture then
        return
      end

      if not match_preds(self, match, source) then
        -- Predicates rejected this match: drop it from the cursor.
        local match_id = match:info()
        cursor:remove_match(match_id)

        -- The rejected capture already starts past `end_line`: hand back the
        -- node with a nil capture id rather than searching any further.
        if end_line and captured_node:range() > end_line then
          return nil, captured_node, nil, nil
        end

        return iter(end_line) -- tail call: try next match
      end

      local metadata = apply_directives(self, match, source)

      return capture, captured_node, metadata, match
    end

    return iter
  end
  --- Iterates the matches of self on a given range.
  ---
  --- Iterate over all matches within a {node}. The arguments are the same as for
  --- |Query:iter_captures()| but the iterated values are different: a (1-based)
  --- index of the pattern in the query, a table mapping capture indices to a list
  --- of nodes, and metadata from any directives processing the match.
  --- Matches that fail their predicates are removed from the cursor and skipped.
  ---
  --- Example:
  ---
  --- ```lua
  --- for pattern, match, metadata in query:iter_matches(tree:root(), bufnr, 0, -1) do
  ---   for id, nodes in pairs(match) do
  ---     local name = query.captures[id]
  ---     for _, node in ipairs(nodes) do
  ---       -- `node` was captured by the `name` capture in the match
  ---
  ---       local node_data = metadata[id] -- Node level metadata
  ---       ... use the info here ...
  ---     end
  ---   end
  --- end
  --- ```
  ---
  ---
  ---@param node TSNode under which the search will occur
  ---@param source (integer|string) Source buffer or string to search
  ---@param start? integer Starting line for the search. Defaults to `node:start()`.
  ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_() + 1`.
  ---@param opts? table Optional keyword arguments (when given, the table is
  ---   updated in place: `match_limit` is set to 256 if it was unset):
  ---   - max_start_depth (integer) if non-zero, sets the maximum start depth
  ---     for each match. This is used to prevent traversing too deep into a tree.
  ---   - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
  ---   - all (boolean) When `false` (default `true`), the returned table maps capture IDs to a single
  ---     (last) node instead of the full list of matching nodes. This option is only for backward
  ---     compatibility and will be removed in a future release.
  ---
  ---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata): pattern id, match, metadata
  function Query:iter_matches(node, source, start, stop, opts)
    opts = opts or {}
    opts.match_limit = opts.match_limit or 256

    if type(source) == 'number' and source == 0 then
      source = api.nvim_get_current_buf()
    end

    start, stop = value_or_node_range(start, stop, node)

    local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)

    return function()
      local match = cursor:next_match()

      -- Keep pulling matches until one satisfies its predicates or the cursor
      -- runs dry (in which case nothing is returned).
      while match do
        local match_id, pattern = match:info()

        if self:match_preds(match, source) then
          local metadata = self:apply_directives(match, source)
          local captures = match:captures()

          -- TODO(lewis6991): create a new function that returns {match, metadata}
          if opts.all ~= false then
            return pattern, captures, metadata
          end

          -- Backward-compatibility path, taken only when the caller explicitly
          -- set `all` to `false`: collapse every node list to its last node,
          -- reproducing the old (buggy) match shape. Slow, but opt-in.
          local last_nodes = {} ---@type table<integer, TSNode>
          for capture_id, nodes in pairs(captures or {}) do
            last_nodes[capture_id] = nodes[#nodes]
          end
          return pattern, last_nodes, metadata
        end

        -- Predicates rejected this match: discard it and try the next one.
        cursor:remove_match(match_id)
        match = cursor:next_match()
      end
    end
  end
  858. --- Optional keyword arguments:
  859. --- @class vim.treesitter.query.lint.Opts
  860. --- @inlinedoc
  861. ---
  862. --- Language(s) to use for checking the query.
  863. --- If multiple languages are specified, queries are validated for all of them
  864. --- @field langs? string|string[]
  865. ---
  866. --- Just clear current lint errors
  867. --- @field clear boolean
  --- Lint treesitter queries using installed parser, or clear lint errors.
  ---
  --- Checks the query file in {buf} for errors, using the |treesitter-parsers|
  --- found in runtimepath:
  ---
  ---   - verify that used nodes are valid identifiers in the grammar.
  ---   - verify that predicates and directives are valid.
  ---   - verify that top-level s-expressions are valid.
  ---
  --- The found diagnostics are reported using |diagnostic-api|.
  --- By default, the parser used for verification is determined by the containing folder
  --- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the
  --- `lua` language will be used.
  ---
  --- When `opts.clear` is set, no linting is done; the existing lint errors of
  --- {buf} are cleared instead.
  ---@param buf (integer) Buffer handle
  ---@param opts? vim.treesitter.query.lint.Opts
  function M.lint(buf, opts)
    local clear_only = opts and opts.clear
    local linter = vim.treesitter._query_linter

    if clear_only then
      linter.clear(buf)
      return
    end

    linter.lint(buf, opts)
  end
  --- Omnifunc for completing node names and predicates in treesitter queries.
  --- Forwards both arguments to the query linter's `omnifunc` and returns
  --- whatever it returns.
  ---
  --- Use via
  ---
  --- ```lua
  --- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc'
  --- ```
  ---
  --- @param findstart 0|1
  --- @param base string
  function M.omnifunc(findstart, base)
    local linter = vim.treesitter._query_linter
    return linter.omnifunc(findstart, base)
  end
  --- Opens a live editor to query the buffer you started from.
  ---
  --- Can also be shown with [:EditQuery]().
  ---
  --- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in
  --- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find
  --- example queries at `$VIMRUNTIME/queries/`.
  ---
  --- Raises an error (via `assert`) when `vim.treesitter.dev.edit_query()`
  --- returns a falsy first value.
  ---
  --- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype.
  function M.edit(lang)
    local dev = vim.treesitter.dev
    -- All results are passed straight to `assert`, so a second return value
    -- (if any) becomes the error message.
    assert(dev.edit_query(lang))
  end
  914. return M