query.lua 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076
  1. --- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse
  2. --- text. See |vim.treesitter.query.parse()| for a working example.
  3. local api = vim.api
  4. local language = require('vim.treesitter.language')
  5. local memoize = vim.func._memoize
  6. local M = {}
  ---@nodoc
  ---Parsed query, see |vim.treesitter.query.parse()|
  ---
  ---@class vim.treesitter.Query
  ---@field lang string parser language name
  ---@field captures string[] list of (unique) capture names defined in query
  ---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
  ---@field query TSQuery userdata query object
  local Query = {}
  Query.__index = Query

  --- Wraps a low-level query object in a `Query` instance.
  ---
  --- The capture names and per-pattern predicate data are read once from
  --- `ts_query:inspect()`. The `captures` field is the very same table as
  --- `info.captures` (an alias, not a copy).
  ---
  ---@package
  ---@see vim.treesitter.query.parse
  ---@param lang string
  ---@param ts_query TSQuery
  ---@return vim.treesitter.Query
  function Query.new(lang, ts_query)
    local inspected = ts_query:inspect() ---@type TSQueryInfo
    local info = {
      captures = inspected.captures,
      patterns = inspected.patterns,
    }
    return setmetatable({
      query = ts_query,
      lang = lang,
      info = info,
      captures = info.captures,
    }, Query)
  end
  34. ---@nodoc
  35. ---Information for Query, see |vim.treesitter.query.parse()|
  36. ---@class vim.treesitter.QueryInfo
  37. ---
  38. ---List of (unique) capture names defined in query.
  39. ---@field captures string[]
  40. ---
  41. ---Contains information about predicates and directives.
  42. ---Key is pattern id, and value is list of predicates or directives defined in the pattern.
  43. ---A predicate or directive is a list of (integer|string); integer represents `capture_id`, and
  44. ---string represents (literal) arguments to predicate/directive. See |treesitter-predicates|
  45. ---and |treesitter-directives| for more details.
  46. ---@field patterns table<integer, (integer|string)[][]>
  --- Drops repeated paths from {files}, keeping the first occurrence of each
  --- and preserving the original order.
  ---
  ---@param files string[]
  ---@return string[]
  local function dedupe_files(files)
    ---@type table<string,boolean>
    local already_added = {}
    local unique = {}
    for _, file in ipairs(files) do
      if already_added[file] == nil then
        already_added[file] = true
        unique[#unique + 1] = file
      end
    end
    return unique
  end
  --- Opens {filename} for reading, performs a single `read()` with the given
  --- format and closes the file again.
  ---
  --- Raises an error carrying the message from `io.open()` when the file
  --- cannot be opened.
  ---
  ---@param filename string
  ---@param read_quantifier string Format handed to `file:read()` (e.g. '*a')
  ---@return string? content Whatever `file:read()` returned
  local function safe_read(filename, read_quantifier)
    local handle, open_err = io.open(filename, 'r')
    if handle == nil then
      error(open_err)
    end
    local data = handle:read(read_quantifier)
    handle:close()
    return data
  end
  --- Appends {ilang} to {base_langs} unless it is the same language as {lang}.
  ---
  ---@param base_langs string[] List that receives {ilang} (modified in place)
  ---@param lang string
  ---@param ilang string
  ---@return boolean same `true` if lang == ilang (nothing is appended), `false` otherwise
  local function add_included_lang(base_langs, lang, ilang)
    local same = lang == ilang
    if not same then
      base_langs[#base_langs + 1] = ilang
    end
    return same
  end
  --- Gets the list of files used to make up a query
  ---
  --- Every runtime file matching `queries/{lang}/{query_name}.scm` is scanned:
  --- its leading comment lines (those starting with `;`) may carry an
  --- `inherits:` modeline naming other languages, or an `extends` modeline.
  --- The result is ordered as: files of the inherited languages, then the base
  --- query (the first file that is not an extension), then all extension files.
  ---
  ---@param lang string Language to get query for
  ---@param query_name string Name of the query to load (e.g., "highlights")
  ---@param is_included? boolean Internal parameter, most of the time left as `nil`.
  ---        When truthy, optional (parenthesized) inherited languages are skipped.
  ---@return string[] query_files List of files to load for given query and language
  function M.get_files(lang, query_name, is_included)
    local query_path = string.format('queries/%s/%s.scm', lang, query_name)
    local lang_files = dedupe_files(api.nvim_get_runtime_file(query_path, true))

    if #lang_files == 0 then
      return {}
    end

    local base_query = nil ---@type string?
    local extensions = {}

    local base_langs = {} ---@type string[]

    -- Now get the base languages by looking at the first line of every file
    -- The syntax is the following :
    -- ;+ inherits: ({language},)*{language}
    --
    -- {language} ::= {lang} | ({lang})
    local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'
    local EXTENDS_FORMAT = '^;+%s*extends%s*$'

    for _, filename in ipairs(lang_files) do
      local file, err = io.open(filename, 'r')
      if not file then
        error(err)
      end

      local extension = false

      for modeline in file:lines() do
        -- Modelines can only appear in the leading block of comment lines.
        if not vim.startswith(modeline, ';') then
          break
        end

        local langlist = modeline:match(MODELINE_FORMAT)
        if langlist then
          -- Use the options-table form of vim.split(); the bare boolean third
          -- argument is only a backward-compatibility path.
          for _, incllang in ipairs(vim.split(langlist, ',', { plain = true })) do
            local is_optional = incllang:match('%(.*%)')

            if is_optional then
              if not is_included then
                -- Strip the surrounding parentheses before recording the language.
                if add_included_lang(base_langs, lang, incllang:sub(2, #incllang - 1)) then
                  extension = true
                end
              end
            else
              if add_included_lang(base_langs, lang, incllang) then
                extension = true
              end
            end
          end
        elseif modeline:match(EXTENDS_FORMAT) then
          extension = true
        end
      end

      if extension then
        table.insert(extensions, filename)
      elseif base_query == nil then
        -- Only the first non-extension file is used as the base query.
        base_query = filename
      end
      file:close()
    end

    local query_files = {}
    for _, base_lang in ipairs(base_langs) do
      local base_files = M.get_files(base_lang, query_name, true)
      vim.list_extend(query_files, base_files)
    end
    -- `{ base_query }` is an empty list when no base query was found.
    vim.list_extend(query_files, { base_query })
    vim.list_extend(query_files, extensions)

    return query_files
  end
  --- Reads each file in {filenames} in full and joins the contents, in list
  --- order, without any separator.
  ---
  ---@param filenames string[]
  ---@return string
  local function read_query_files(filenames)
    local parts = {}
    for idx, path in ipairs(filenames) do
      parts[idx] = safe_read(path, '*a')
    end
    return table.concat(parts)
  end
  -- The explicitly set queries from |vim.treesitter.query.set()|, indexed as
  -- `explicit_queries[lang][query_name]`. Looking up a language that has no
  -- entry yet creates and stores an empty table for it.
  ---@type table<string,table<string,vim.treesitter.Query>>
  local explicit_queries = {}
  setmetatable(explicit_queries, {
    __index = function(registry, lang)
      local per_lang = {}
      rawset(registry, lang, per_lang)
      return per_lang
    end,
  })
  --- Sets the runtime query named {query_name} for {lang}
  ---
  --- This allows users to override any runtime files and/or configuration
  --- set by plugins. The text is parsed with |vim.treesitter.query.parse()|
  --- and the resulting query is stored for {lang} under {query_name}.
  ---
  ---@param lang string Language to use for the query
  ---@param query_name string Name of the query (e.g., "highlights")
  ---@param text string Query text (unparsed).
  function M.set(lang, query_name, text)
    local lang_queries = explicit_queries[lang]
    lang_queries[query_name] = M.parse(lang, text)
  end
  --- Loads, concatenates and parses the runtime files of query {query_name}
  --- for {lang}.
  ---
  --- Only this file-based lookup is memoized. Explicit queries are checked
  --- outside of it (see M.get below) so that a result cached here can never
  --- hide a query registered later through |vim.treesitter.query.set()|.
  ---
  ---@param lang string
  ---@param query_name string
  ---@return vim.treesitter.Query? : Parsed query. `nil` if the files yield no text.
  local get_runtime_query = memoize('concat-2', function(lang, query_name)
    local query_files = M.get_files(lang, query_name)
    local query_string = read_query_files(query_files)

    if #query_string == 0 then
      return nil
    end

    return M.parse(lang, query_string)
  end)

  --- Returns the runtime query {query_name} for {lang}.
  ---
  --- A query set explicitly with |vim.treesitter.query.set()| takes precedence
  --- over the runtime files, regardless of whether the runtime query was
  --- already loaded before the explicit query was set.
  ---
  ---@param lang string Language to use for the query
  ---@param query_name string Name of the query (e.g. "highlights")
  ---
  ---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found.
  function M.get(lang, query_name)
    local explicit = explicit_queries[lang][query_name]
    if explicit then
      return explicit
    end

    return get_runtime_query(lang, query_name)
  end
  --- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used
  --- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|),
  --- or inspect the query via these fields:
  ---   - `captures`: a list of unique capture names defined in the query (alias: `info.captures`).
  ---   - `info.patterns`: information about predicates.
  ---
  --- The call asserts that `language.add(lang)` returns a truthy value before
  --- the query text is handed to the low-level parser.
  ---
  --- Example (select the code then run `:'<,'>lua` to try it):
  --- ```lua
  --- local query = vim.treesitter.query.parse('vimdoc', [[
  ---   ; query
  ---   ((h1) @str
  ---     (#trim! @str 1 1 1 1))
  --- ]])
  --- local tree = vim.treesitter.get_parser():parse()[1]
  --- for id, node, metadata in query:iter_captures(tree:root(), 0) do
  ---    -- Print the node name and source text.
  ---    vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())})
  --- end
  --- ```
  ---
  ---@param lang string Language to use for the query
  ---@param query string Query text, in s-expr syntax
  ---
  ---@return vim.treesitter.Query : Parsed query
  ---
  ---@see [vim.treesitter.query.get()]
  M.parse = memoize('concat-2', function(lang, query)
    assert(language.add(lang))
    return Query.new(lang, vim._ts_parse_query(lang, query))
  end)
  --- Shared bodies of the predicates that also exist in an "any-" flavour.
  ---
  --- Each function takes a trailing `any` flag: when it is true the predicate
  --- succeeds as soon as one captured node satisfies the check, otherwise it
  --- fails as soon as one node does not. A capture without nodes always
  --- yields true. The functions are invoked from the predicate_handlers table.
  local impl = {
    --- Text equality against a literal string or against another capture.
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['eq'] = function(match, source, predicate, any)
      local captured = match[predicate[2]]
      if not captured or #captured == 0 then
        return true
      end

      for _, node in ipairs(captured) do
        local node_text = vim.treesitter.get_node_text(node, source)
        local rhs = predicate[3]

        local expected ---@type string
        if type(rhs) == 'string' then
          -- (#eq? @aa "foo")
          expected = rhs
        else
          -- (#eq? @aa @bb)
          local other = assert(match[rhs])
          assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
          expected = vim.treesitter.get_node_text(other[1], source)
        end

        local same = expected ~= nil and node_text == expected
        if any then
          if same then
            return true
          end
        elseif not same then
          return false
        end
      end

      return not any
    end,

    --- Lua pattern search (`string.find`) on the node text.
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['lua-match'] = function(match, source, predicate, any)
      local captured = match[predicate[2]]
      if not captured or #captured == 0 then
        return true
      end

      local pattern = predicate[3]
      for _, node in ipairs(captured) do
        local found = string.find(vim.treesitter.get_node_text(node, source), pattern) ~= nil
        if any then
          if found then
            return true
          end
        elseif not found then
          return false
        end
      end

      return not any
    end,

    ['match'] = (function()
      local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }

      -- Prepends '\v' to patterns of length >= 2 that do not already start
      -- with one of the magic prefixes above; other patterns pass through.
      local function check_magic(str)
        local needs_prefix = string.len(str) >= 2 and not magic_prefixes[string.sub(str, 1, 2)]
        return needs_prefix and ('\\v' .. str) or str
      end

      -- Cache of compiled regexes keyed by the pattern text as written in the
      -- query; a missing key is compiled on first access and then stored.
      local compiled_vim_regexes = setmetatable({}, {
        __index = function(cache, pattern)
          local compiled = vim.regex(check_magic(pattern))
          rawset(cache, pattern, compiled)
          return compiled
        end,
      })

      --- Vim regex match (`regex:match_str`) on the node text.
      --- @param match table<integer,TSNode[]>
      --- @param source integer|string
      --- @param predicate any[]
      --- @param any boolean
      return function(match, source, predicate, any)
        local captured = match[predicate[2]]
        if not captured or #captured == 0 then
          return true
        end

        local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
        for _, node in ipairs(captured) do
          local hit = regex:match_str(vim.treesitter.get_node_text(node, source))
          if any then
            if hit then
              return true
            end
          elseif not hit then
            return false
          end
        end

        return not any
      end
    end)(),

    --- Plain (non-pattern) substring search for every argument from index 3 on.
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['contains'] = function(match, source, predicate, any)
      local captured = match[predicate[2]]
      if not captured or #captured == 0 then
        return true
      end

      local last_arg = #predicate
      for _, node in ipairs(captured) do
        local node_text = vim.treesitter.get_node_text(node, source)

        for arg = 3, last_arg do
          local at = string.find(node_text, predicate[arg], 1, true)
          if any then
            if at then
              return true
            end
          elseif not at then
            return false
          end
        end
      end

      return not any
    end,
  }
  ---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean

  -- Predicate handlers, keyed by predicate name (without the leading `#`).
  -- Each handler receives the following arguments:
  --    (match, pattern, source, predicate)
  -- where `predicate[1]` is the predicate name, `predicate[2]` the capture id
  -- and the remaining entries are the arguments written in the query.
  -- Every built-in handler returns true when the capture has no nodes.
  ---@type table<string,TSPredicate>
  local predicate_handlers = {
    ['eq?'] = function(match, _, source, predicate)
      return impl['eq'](match, source, predicate, false)
    end,

    ['any-eq?'] = function(match, _, source, predicate)
      return impl['eq'](match, source, predicate, true)
    end,

    ['lua-match?'] = function(match, _, source, predicate)
      return impl['lua-match'](match, source, predicate, false)
    end,

    ['any-lua-match?'] = function(match, _, source, predicate)
      return impl['lua-match'](match, source, predicate, true)
    end,

    ['match?'] = function(match, _, source, predicate)
      return impl['match'](match, source, predicate, false)
    end,

    ['any-match?'] = function(match, _, source, predicate)
      return impl['match'](match, source, predicate, true)
    end,

    ['contains?'] = function(match, _, source, predicate)
      return impl['contains'](match, source, predicate, false)
    end,

    ['any-contains?'] = function(match, _, source, predicate)
      return impl['contains'](match, source, predicate, true)
    end,

    -- True when the text of at least one captured node equals one of the
    -- listed words.
    ['any-of?'] = function(match, _, source, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      -- Since 'predicate' will not be used by callers of this function, use it
      -- to store a string set built from the list of words to check against.
      -- The set is built once, on first use, and reused afterwards.
      local string_set = predicate['string_set'] --- @type table<string, boolean>
      if not string_set then
        string_set = {}
        for i = 3, #predicate do
          string_set[predicate[i]] = true
        end
        predicate['string_set'] = string_set
      end

      for _, node in ipairs(nodes) do
        if string_set[vim.treesitter.get_node_text(node, source)] then
          return true
        end
      end

      return false
    end,

    -- True when `node:__has_ancestor(predicate)` is truthy for at least one
    -- captured node; the whole predicate list is handed to that method.
    ['has-ancestor?'] = function(match, _, _, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        if node:__has_ancestor(predicate) then
          return true
        end
      end
      return false
    end,

    -- True when the direct parent of at least one captured node has one of the
    -- listed types.
    ['has-parent?'] = function(match, _, _, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        -- A node without a parent (e.g. the root) cannot match any type;
        -- guard against indexing nil instead of raising.
        local parent = node:parent()
        if parent then
          local parent_type = parent:type()
          for i = 3, #predicate do
            if predicate[i] == parent_type then
              return true
            end
          end
        end
      end
      return false
    end,
  }

  -- As we provide lua-match? also expose vim-match?
  predicate_handlers['vim-match?'] = predicate_handlers['match?']
  predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
  ---@nodoc
  ---@class vim.treesitter.query.TSMetadata
  ---@field range? Range
  ---@field conceal? string
  ---@field [integer]? vim.treesitter.query.TSMetadata
  ---@field [string]? integer|string

  ---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)

  -- Directives store metadata or perform side effects against a match.
  -- Directives should always end with a `!`.
  -- Directive handler receive the following arguments
  -- (match, pattern, bufnr, predicate, metadata)
  ---@type table<string,TSDirective>
  local directive_handlers = {
    -- Stores a value in the metadata, either for the whole match or for one
    -- capture.
    ['set!'] = function(_, _, _, pred, metadata)
      if #pred >= 3 and type(pred[2]) == 'number' then
        -- (#set! @capture key value)
        local capture_id, key, value = pred[2], pred[3], pred[4]
        if not metadata[capture_id] then
          metadata[capture_id] = {}
        end
        metadata[capture_id][key] = value
      else
        -- (#set! key value)
        local key, value = pred[2], pred[3]
        -- A key given without a value is stored as `true`.
        metadata[key] = value or true
      end
    end,

    -- Shifts the range of a node.
    -- Example: (#offset! @_node 0 1 0 -1)
    ['offset!'] = function(match, _, _, pred, metadata)
      local capture_id = pred[2] --[[@as integer]]
      local nodes = match[capture_id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#offset! does not support captures on multiple nodes')

      local node = nodes[1]

      if not metadata[capture_id] then
        metadata[capture_id] = {}
      end

      -- Start from a range already recorded for this capture, if any,
      -- otherwise from the node's own range.
      local current = metadata[capture_id].range or { node:range() }

      -- Apply the offsets to a copy: `current` may be the table stored in the
      -- metadata, and it must stay untouched when the shifted range turns out
      -- to be invalid and is skipped below.
      local range = {}
      for i = 1, #current do
        range[i] = current[i]
      end

      local start_row_offset = pred[3] or 0
      local start_col_offset = pred[4] or 0
      local end_row_offset = pred[5] or 0
      local end_col_offset = pred[6] or 0

      range[1] = range[1] + start_row_offset
      range[2] = range[2] + start_col_offset
      range[3] = range[3] + end_row_offset
      range[4] = range[4] + end_col_offset

      -- If this produces an invalid range, we just skip it.
      if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
        metadata[capture_id].range = range
      end
    end,

    -- Transform the content of the node
    -- Example: (#gsub! @_node ".*%.(.*)" "%1")
    ['gsub!'] = function(match, _, bufnr, pred, metadata)
      assert(#pred == 4)

      local id = pred[2]
      assert(type(id) == 'number')

      local nodes = match[id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
      local node = nodes[1]
      local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''

      if not metadata[id] then
        metadata[id] = {}
      end

      local pattern, replacement = pred[3], pred[4]
      assert(type(pattern) == 'string')
      assert(type(replacement) == 'string')

      -- Only the first result of gsub() (the new string) is stored; the
      -- substitution count is discarded by the single-target assignment.
      metadata[id].text = text:gsub(pattern, replacement)
    end,

    -- Trim whitespace from both sides of the node
    -- Example: (#trim! @fold 1 1 1 1)
    -- The four optional arguments enable, in order, trimming of: blank
    -- leading lines, leading whitespace columns, blank trailing lines and
    -- trailing whitespace columns. An argument enables its step when it is
    -- the string '1'.
    ['trim!'] = function(match, _, bufnr, pred, metadata)
      local capture_id = pred[2]
      assert(type(capture_id) == 'number')

      local trim_start_lines = pred[3] == '1'
      local trim_start_cols = pred[4] == '1'
      local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility
      local trim_end_cols = pred[6] == '1'

      local nodes = match[capture_id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
      local node = nodes[1]

      local start_row, start_col, end_row, end_col = node:range()

      local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n')
      if end_col == 0 then
        -- get_node_text() will ignore the last line if the node ends at column 0
        node_text[#node_text + 1] = ''
      end

      local end_idx = #node_text
      local start_idx = 1

      if trim_end_lines then
        while end_idx > 0 and node_text[end_idx]:find('^%s*$') do
          end_idx = end_idx - 1
          end_row = end_row - 1
          -- set the end position to the last column of the next line, or 0 if we just trimmed the
          -- last line
          end_col = end_idx > 0 and #node_text[end_idx] or 0
        end
      end
      if trim_end_cols then
        if end_idx == 0 then
          -- Every line was blank: collapse the end onto the start position.
          end_row = start_row
          end_col = start_col
        else
          local whitespace_start = node_text[end_idx]:find('(%s*)$')
          -- Columns on the node's first line are relative to start_col.
          end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0)
        end
      end

      if trim_start_lines then
        while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do
          start_idx = start_idx + 1
          start_row = start_row + 1
          start_col = 0
        end
      end
      if trim_start_cols and node_text[start_idx] then
        local _, whitespace_end = node_text[start_idx]:find('^(%s*)')
        whitespace_end = whitespace_end or 0
        start_col = (start_idx == 1 and start_col or 0) + whitespace_end
      end

      -- If this produces an invalid range, we just skip it.
      if start_row < end_row or (start_row == end_row and start_col <= end_col) then
        metadata[capture_id] = metadata[capture_id] or {}
        metadata[capture_id].range = { start_row, start_col, end_row, end_col }
      end
    end,
  }
  --- @class vim.treesitter.query.add_predicate.Opts
  --- @inlinedoc
  ---
  --- Override an existing predicate of the same name
  --- @field force? boolean
  ---
  --- Use the correct implementation of the match table where capture IDs map to
  --- a list of nodes instead of a single node. Defaults to true. This option will
  --- be removed in a future release.
  --- @field all? boolean

  --- Adds a new predicate to be used in queries
  ---
  --- Registering a name that already has a handler raises an error unless
  --- `opts.force` is set. With `opts.all == false` the handler is wrapped so
  --- that it receives, for every numeric capture id, only the last node of
  --- that capture instead of the list of nodes.
  ---
  ---@param name string Name of the predicate, without leading #
  ---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata): boolean?
  ---   - see |vim.treesitter.query.add_directive()| for argument meanings
  ---@param opts? vim.treesitter.query.add_predicate.Opts
  function M.add_predicate(name, handler, opts)
    if type(opts) == 'boolean' then
      -- Backward compatibility: old signature had "force" as boolean argument
      opts = { force = opts }
    elseif opts == nil then
      opts = {}
    end

    if not opts.force and predicate_handlers[name] then
      error(string.format('Overriding existing predicate %s', name))
    end

    local registered = handler
    if opts.all == false then
      --- @param match table<integer, TSNode[]>
      registered = function(match, ...)
        local last_nodes = {} ---@type table<integer, TSNode>
        for id, nodes in pairs(match) do
          if type(id) == 'number' then
            last_nodes[id] = nodes[#nodes]
          end
        end
        return handler(last_nodes, ...)
      end
    end

    predicate_handlers[name] = registered
  end
  --- Adds a new directive to be used in queries
  ---
  --- Handlers can set match level data by setting directly on the
  --- metadata object `metadata.key = value`. Additionally, handlers
  --- can set node level data by using the capture id on the
  --- metadata table `metadata[capture_id].key = value`
  ---
  --- Registering a name that already has a handler raises an error unless
  --- `opts.force` is set. As for |vim.treesitter.query.add_predicate()|,
  --- `opts.all` defaults to true; only an explicit `all = false` wraps the
  --- handler so that it receives the last node of each capture instead of
  --- the list of nodes.
  ---
  ---@param name string Name of the directive, without leading #
  ---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata)
  ---   - match: A table mapping capture IDs to a list of captured nodes
  ---   - pattern: the index of the matching pattern in the query file
  ---   - predicate: list of strings containing the full directive being called, e.g.
  ---     `(node (#set! conceal "-"))` would get the predicate `{ "#set!", "conceal", "-" }`
  ---@param opts? vim.treesitter.query.add_predicate.Opts
  function M.add_directive(name, handler, opts)
    -- Backward compatibility: old signature had "force" as boolean argument
    if type(opts) == 'boolean' then
      opts = { force = opts }
    end

    opts = opts or {}

    if directive_handlers[name] and not opts.force then
      error(string.format('Overriding existing directive %s', name))
    end

    -- `all` defaults to true (see the Opts documentation), matching
    -- add_predicate(): only an explicit `false` selects the legacy wrapper.
    if opts.all ~= false then
      directive_handlers[name] = handler
    else
      --- @param match table<integer, TSNode[]>
      local function wrapper(match, ...)
        local m = {} ---@type table<integer, TSNode>
        for k, v in pairs(match) do
          -- Only numeric keys are capture ids; skip anything else, as the
          -- predicate wrapper does.
          if type(k) == 'number' then
            m[k] = v[#v]
          end
        end
        handler(m, ...)
      end
      directive_handlers[name] = wrapper
    end
  end
  --- Lists the currently available directives to use in queries.
  --- The names are the keys of the directive handler table.
  ---@return string[] : Supported directives.
  function M.list_directives()
    local names = vim.tbl_keys(directive_handlers)
    return names
  end
  --- Lists the currently available predicates to use in queries.
  --- The names are the keys of the predicate handler table.
  ---@return string[] : Supported predicates.
  function M.list_predicates()
    local names = vim.tbl_keys(predicate_handlers)
    return names
  end
  --- Exclusive or on truthiness: true when exactly one argument is truthy,
  --- false when both are. When neither is truthy the (falsy) value of
  --- `x or y` is returned as is.
  local function xor(x, y)
    local either = x or y
    if not either then
      return either
    end
    return not (x and y)
  end
  --- Tells whether {name} ends with `!`, the marker that distinguishes
  --- directives from predicates.
  local function is_directive(name)
    return string.match(name, '!$') ~= nil
  end
  --- Checks whether {match} satisfies every predicate of its pattern.
  ---
  --- Directives (names ending in `!`) are skipped here; they are handled by
  --- apply_directives(). A `not-` prefix on a predicate name inverts the
  --- result of the underlying handler. A predicate without a registered
  --- handler raises an error.
  ---
  ---@private
  ---@param match TSQueryMatch
  ---@param source integer|string
  ---@return boolean matches `false` as soon as one predicate fails, `true` otherwise
  function Query:match_preds(match, source)
    local _, pattern = match:info()
    local preds = self.info.patterns[pattern]

    if not preds then
      return true
    end

    local captures = match:captures()

    -- `preds` is a list, so walk it in order.
    for _, pred in ipairs(preds) do
      -- Return only when a predicate DOES NOT match; otherwise keep checking.
      -- Also, tree-sitter strips the leading # from predicates for us.
      -- Skip over directives... they will get processed after all the predicates.
      if not is_directive(pred[1]) then
        local is_not = false
        local pred_name = pred[1]
        if pred_name:match('^not%-') then
          pred_name = pred_name:sub(5)
          is_not = true
        end

        local handler = predicate_handlers[pred_name]

        if not handler then
          error(string.format('No handler for %s', pred[1]))
        end

        local pred_matches = handler(captures, pattern, source, pred)

        if not xor(is_not, pred_matches) then
          return false
        end
      end
    end
    return true
  end
  --- Runs every directive of the pattern that produced {match} and collects
  --- the metadata they write.
  ---
  --- Directives are applied in the order in which they are listed for the
  --- pattern, so that a later directive sees what an earlier one stored.
  --- A directive without a registered handler raises an error.
  ---
  ---@private
  ---@param match TSQueryMatch
  ---@param source integer|string Passed through to each directive handler
  ---@return vim.treesitter.query.TSMetadata metadata
  function Query:apply_directives(match, source)
    ---@type vim.treesitter.query.TSMetadata
    local metadata = {}
    local _, pattern = match:info()
    local preds = self.info.patterns[pattern]

    if not preds then
      return metadata
    end

    local captures = match:captures()

    -- ipairs() guarantees list order; pairs() gives no ordering guarantee and
    -- directives share the same metadata table.
    for _, pred in ipairs(preds) do
      if is_directive(pred[1]) then
        local handler = directive_handlers[pred[1]]

        if not handler then
          error(string.format('No handler for %s', pred[1]))
        end

        handler(captures, pattern, source, pred, metadata)
      end
    end

    return metadata
  end
  --- Returns {start} and {stop} when given, falling back to values taken from
  --- {node} for whichever one is `nil`.
  --- The fallback for {start} is the first value of `node:start()`; the
  --- fallback for {stop} is `node:end_()` plus 1, so that the search includes
  --- the row the node ends on.
  ---@param start integer?
  ---@param stop integer?
  ---@param node TSNode
  ---@return integer, integer
  local function value_or_node_range(start, stop, node)
    local first = start
    if first == nil then
      first = node:start()
    end

    local last = stop
    if last == nil then
      last = node:end_() + 1 -- Make stop inclusive
    end

    return first, last
  end
  --- Returns only the first value of `match:info()` (presumably the match id,
  --- going by this function's name). The first argument is ignored.
  --- @param match TSQueryMatch
  --- @return integer
  local function match_id_hash(_, match)
    local id = match:info()
    return id
  end
  737. --- Iterates over all captures from all matches in {node}.
  738. ---
  739. --- {source} is required if the query contains predicates; then the caller
  740. --- must ensure to use a freshly parsed tree consistent with the current
  741. --- text of the buffer (if relevant). {start} and {stop} can be used to limit
  742. --- matches inside a row range (this is typically used with root node
  743. --- as the {node}, i.e., to get syntax highlight matches in the current
  744. --- viewport). When omitted, the {start} and {stop} row values are used from the given node.
  745. ---
  746. --- The iterator returns four values:
  747. --- 1. the numeric id identifying the capture
  748. --- 2. the captured node
  749. --- 3. metadata from any directives processing the match
  750. --- 4. the match itself
  751. ---
  752. --- Example: how to get captures by name:
  753. --- ```lua
  754. --- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do
  755. --- local name = query.captures[id] -- name of the capture in the query
  756. --- -- typically useful info about the node:
  757. --- local type = node:type() -- type of the captured node
  758. --- local row1, col1, row2, col2 = node:range() -- range of the capture
  759. --- -- ... use the info here ...
  760. --- end
  761. --- ```
  762. ---
  763. ---@param node TSNode under which the search will occur
  764. ---@param source (integer|string) Source buffer or string to extract text from
  765. ---@param start? integer Starting line for the search. Defaults to `node:start()`.
  766. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
  767. ---
  768. ---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch):
  769. --- capture id, capture node, metadata, match
  770. ---
  771. ---@note Captures are only returned if the query pattern of a specific capture contained predicates.
  772. function Query:iter_captures(node, source, start, stop)
  773. if type(source) == 'number' and source == 0 then
  774. source = api.nvim_get_current_buf()
  775. end
  776. start, stop = value_or_node_range(start, stop, node)
  777. local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })
  778. local apply_directives = memoize(match_id_hash, self.apply_directives, false)
  779. local match_preds = memoize(match_id_hash, self.match_preds, false)
  780. local function iter(end_line)
  781. local capture, captured_node, match = cursor:next_capture()
  782. if not capture then
  783. return
  784. end
  785. if not match_preds(self, match, source) then
  786. local match_id = match:info()
  787. cursor:remove_match(match_id)
  788. if end_line and captured_node:range() > end_line then
  789. return nil, captured_node, nil, nil
  790. end
  791. return iter(end_line) -- tail call: try next match
  792. end
  793. local metadata = apply_directives(self, match, source)
  794. return capture, captured_node, metadata, match
  795. end
  796. return iter
  797. end
  798. --- Iterates the matches of self on a given range.
  799. ---
  800. --- Iterate over all matches within a {node}. The arguments are the same as for
  801. --- |Query:iter_captures()| but the iterated values are different: an (1-based)
  802. --- index of the pattern in the query, a table mapping capture indices to a list
  803. --- of nodes, and metadata from any directives processing the match.
  804. ---
  805. --- Example:
  806. ---
  807. --- ```lua
  808. --- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1) do
  809. --- for id, nodes in pairs(match) do
  810. --- local name = query.captures[id]
  811. --- for _, node in ipairs(nodes) do
  812. --- -- `node` was captured by the `name` capture in the match
  813. ---
  814. --- local node_data = metadata[id] -- Node level metadata
  815. --- ... use the info here ...
  816. --- end
  817. --- end
  818. --- end
  819. --- ```
  820. ---
  821. ---
  822. ---@param node TSNode under which the search will occur
  823. ---@param source (integer|string) Source buffer or string to search
  824. ---@param start? integer Starting line for the search. Defaults to `node:start()`.
  825. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
  826. ---@param opts? table Optional keyword arguments:
  827. --- - max_start_depth (integer) if non-zero, sets the maximum start depth
  828. --- for each match. This is used to prevent traversing too deep into a tree.
  829. --- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
  830. --- - all (boolean) When `false` (default `true`), the returned table maps capture IDs to a single
  831. --- (last) node instead of the full list of matching nodes. This option is only for backward
  832. --- compatibility and will be removed in a future release.
  833. ---
  834. ---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata): pattern id, match, metadata
  835. function Query:iter_matches(node, source, start, stop, opts)
  836. opts = opts or {}
  837. opts.match_limit = opts.match_limit or 256
  838. if type(source) == 'number' and source == 0 then
  839. source = api.nvim_get_current_buf()
  840. end
  841. start, stop = value_or_node_range(start, stop, node)
  842. local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
  843. local function iter()
  844. local match = cursor:next_match()
  845. if not match then
  846. return
  847. end
  848. local match_id, pattern = match:info()
  849. if not self:match_preds(match, source) then
  850. cursor:remove_match(match_id)
  851. return iter() -- tail call: try next match
  852. end
  853. local metadata = self:apply_directives(match, source)
  854. local captures = match:captures()
  855. if opts.all == false then
  856. -- Convert the match table into the old buggy version for backward
  857. -- compatibility. This is slow, but we only do it when the caller explicitly opted into it by
  858. -- setting `all` to `false`.
  859. local old_match = {} ---@type table<integer, TSNode>
  860. for k, v in pairs(captures or {}) do
  861. old_match[k] = v[#v]
  862. end
  863. return pattern, old_match, metadata
  864. end
  865. -- TODO(lewis6991): create a new function that returns {match, metadata}
  866. return pattern, captures, metadata
  867. end
  868. return iter
  869. end
  870. --- Optional keyword arguments:
  871. --- @class vim.treesitter.query.lint.Opts
  872. --- @inlinedoc
  873. ---
  874. --- Language(s) to use for checking the query.
  875. --- If multiple languages are specified, queries are validated for all of them
  876. --- @field langs? string|string[]
  877. ---
--- Only clear the current lint errors instead of linting.
--- @field clear? boolean
  880. --- Lint treesitter queries using installed parser, or clear lint errors.
  881. ---
  882. --- Use |treesitter-parsers| in runtimepath to check the query file in {buf} for errors:
  883. ---
  884. --- - verify that used nodes are valid identifiers in the grammar.
  885. --- - verify that predicates and directives are valid.
  886. --- - verify that top-level s-expressions are valid.
  887. ---
  888. --- The found diagnostics are reported using |diagnostic-api|.
  889. --- By default, the parser used for verification is determined by the containing folder
  890. --- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the
  891. --- `lua` language will be used.
  892. ---@param buf (integer) Buffer handle
  893. ---@param opts? vim.treesitter.query.lint.Opts
  894. function M.lint(buf, opts)
  895. if opts and opts.clear then
  896. vim.treesitter._query_linter.clear(buf)
  897. else
  898. vim.treesitter._query_linter.lint(buf, opts)
  899. end
  900. end
  901. --- Omnifunc for completing node names and predicates in treesitter queries.
  902. ---
  903. --- Use via
  904. ---
  905. --- ```lua
  906. --- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc'
  907. --- ```
  908. ---
  909. --- @param findstart 0|1
  910. --- @param base string
  911. function M.omnifunc(findstart, base)
  912. return vim.treesitter._query_linter.omnifunc(findstart, base)
  913. end
  914. --- Opens a live editor to query the buffer you started from.
  915. ---
  916. --- Can also be shown with [:EditQuery]().
  917. ---
  918. --- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in
  919. --- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find
  920. --- example queries at `$VIMRUNTIME/queries/`.
  921. ---
  922. --- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype.
  923. function M.edit(lang)
  924. assert(vim.treesitter.dev.edit_query(lang))
  925. end
  926. return M