query.lua 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182
  1. --- @brief This Lua |treesitter-query| interface allows you to create queries and use them to parse
  2. --- text. See |vim.treesitter.query.parse()| for a working example.
  3. local api = vim.api
  4. local language = require('vim.treesitter.language')
  local M = {}

  -- Memoization helper used to cache M.get() and M.parse().
  local memoize = vim.func._memoize

  -- Modeline naming inherited languages, e.g. `; inherits: c,(cpp)`; a language written in
  -- parentheses is optional.
  local MODELINE_FORMAT = '^;+%s*inherits%s*:?%s*([a-z_,()]+)%s*$'

  -- Modeline marking a query as an extension of the language's own base query: `; extends`.
  local EXTENDS_FORMAT = '^;+%s*extends%s*$'
  --- Tells directives apart from predicates: directive names end in `!`.
  ---@param name string predicate/directive name (without the leading `#`)
  ---@return boolean
  local function is_directive(name)
    local last_char = string.sub(name, -1)
    return last_char == '!'
  end
  ---@nodoc
  ---@class vim.treesitter.query.ProcessedPredicate
  ---@field [1] string predicate name, with any `not-` prefix removed
  ---@field [2] boolean should match (false when the name carried a `not-` prefix)
  ---@field [3] (integer|string)[] the original predicate

  ---@alias vim.treesitter.query.ProcessedDirective (integer|string)[]

  ---@nodoc
  ---@class vim.treesitter.query.ProcessedPattern
  ---@field predicates vim.treesitter.query.ProcessedPredicate[]
  ---@field directives vim.treesitter.query.ProcessedDirective[]

  --- Sorts the predicate/directive lists of every query pattern into predicates and
  --- directives, and detects whether any pattern requests combined injections.
  ---
  ---@param patterns table<integer, (integer|string)[][]> pattern id -> its predicates/directives
  ---@return table<integer, vim.treesitter.query.ProcessedPattern>
  ---@return boolean has_combined true if some pattern contains exactly `(#set! injection.combined)`
  local function process_patterns(patterns)
    local result = {} ---@type table<integer, vim.treesitter.query.ProcessedPattern>
    local combined = false

    for pattern_id, entries in pairs(patterns) do
      local processed = { predicates = {}, directives = {} }

      for _, entry in ipairs(entries) do
        -- tree-sitter has already stripped the leading `#` from the name.
        local name = entry[1]
        ---@cast name string
        if not is_directive(name) then
          local negated = name:match('^not%-') ~= nil
          if negated then
            name = name:sub(5)
          end
          processed.predicates[#processed.predicates + 1] = { name, not negated, entry }
        else
          processed.directives[#processed.directives + 1] = entry
          combined = combined or vim.deep_equal(entry, { 'set!', 'injection.combined' })
        end
      end

      result[pattern_id] = processed
    end

    return result, combined
  end
  ---@nodoc
  ---Parsed query, see |vim.treesitter.query.parse()|
  ---
  ---@class vim.treesitter.Query
  ---@field lang string parser language name
  ---@field captures string[] list of (unique) capture names defined in query
  ---@field info vim.treesitter.QueryInfo query context (e.g. captures, predicates, directives)
  ---@field query TSQuery userdata query object
  ---@field has_combined_injections boolean whether the query contains combined injections
  ---@field private _processed_patterns table<integer, vim.treesitter.query.ProcessedPattern>
  local Query = {}
  Query.__index = Query

  --- Wraps a compiled TSQuery, recording its captures and its predicates/directives
  --- (pre-split per pattern) for later matching.
  ---@package
  ---@see vim.treesitter.query.parse
  ---@param lang string
  ---@param ts_query TSQuery
  ---@return vim.treesitter.Query
  function Query.new(lang, ts_query)
    local inspected = ts_query:inspect() ---@type TSQueryInfo
    local info = {
      captures = inspected.captures,
      patterns = inspected.patterns,
    }
    local processed, has_combined = process_patterns(info.patterns)

    return setmetatable({
      query = ts_query,
      lang = lang,
      info = info,
      -- Shortcut to `info.captures`.
      captures = info.captures,
      _processed_patterns = processed,
      has_combined_injections = has_combined,
    }, Query)
  end

  ---@nodoc
  ---Information for Query, see |vim.treesitter.query.parse()|
  ---@class vim.treesitter.QueryInfo
  ---
  ---List of (unique) capture names defined in query.
  ---@field captures string[]
  ---
  ---Contains information about predicates and directives.
  ---Key is pattern id, and value is list of predicates or directives defined in the pattern.
  ---A predicate or directive is a list of (integer|string); integer represents `capture_id`, and
  ---string represents (literal) arguments to predicate/directive. See |treesitter-predicates|
  ---and |treesitter-directives| for more details.
  ---@field patterns table<integer, (integer|string)[][]>
  --- Drops repeated paths, keeping the first occurrence of each in its original order.
  ---@param files string[]
  ---@return string[]
  local function dedupe_files(files)
    local unique = {} ---@type string[]
    local already = {} ---@type table<string,boolean>
    for _, file_path in ipairs(files) do
      if not already[file_path] then
        already[file_path] = true
        unique[#unique + 1] = file_path
      end
    end
    return unique
  end
  --- Reads {filename} with the given `file:read()` format (e.g. `'*a'` for the whole file).
  --- Raises the `io.open` error message if the file cannot be opened.
  ---@param filename string
  ---@param read_quantifier string
  ---@return string?
  local function safe_read(filename, read_quantifier)
    local handle, open_err = io.open(filename, 'r')
    if handle == nil then
      error(open_err)
    end
    local data = handle:read(read_quantifier)
    handle:close()
    return data
  end
  --- Appends {ilang} to {base_langs} unless it is the language {lang} itself.
  ---
  ---@param base_langs string[] list to append to (modified in place)
  ---@param lang string language whose query is being assembled
  ---@param ilang string language named in an `inherits` modeline
  ---@return boolean is_self true if {ilang} equals {lang} (nothing is appended then)
  local function add_included_lang(base_langs, lang, ilang)
    local is_self = lang == ilang
    if not is_self then
      base_langs[#base_langs + 1] = ilang
    end
    return is_self
  end
  --- Scans the leading `;` comment lines of an open query file for modelines.
  ---
  --- Syntax: `;+ inherits: ({language},)*{language}` where `{language} ::= {lang} | ({lang})`.
  --- Every inherited language other than {lang} is appended to {base_langs}; parenthesized
  --- (optional) languages are ignored when {is_included} is set.
  ---
  ---@param file file* handle positioned at the start of the file
  ---@param lang string language the file belongs to
  ---@param is_included? boolean
  ---@param base_langs string[] receives the inherited languages (modified in place)
  ---@return boolean is_extension true if the file has `; extends` or inherits {lang} itself
  local function scan_query_modelines(file, lang, is_included, base_langs)
    local is_extension = false
    for line in file:lines() do
      -- Modelines are only recognized in the leading run of comment lines.
      if not vim.startswith(line, ';') then
        break
      end
      local langlist = line:match(MODELINE_FORMAT)
      if langlist then
        for _, incllang in ipairs(vim.split(langlist, ',', { plain = true })) do
          local included_name ---@type string?
          if incllang:match('%(.*%)') then
            -- Optional languages are only followed for the top-level query.
            if not is_included then
              included_name = incllang:sub(2, #incllang - 1)
            end
          else
            included_name = incllang
          end
          if included_name and add_included_lang(base_langs, lang, included_name) then
            is_extension = true
          end
        end
      elseif line:match(EXTENDS_FORMAT) then
        is_extension = true
      end
    end
    return is_extension
  end

  --- Gets the list of files used to make up a query
  ---
  --- The result lists, in order: the files of every inherited language, the first file of
  --- {lang} that is not an extension (the base query), then all extension files of {lang}.
  ---
  ---@param lang string Language to get query for
  ---@param query_name string Name of the query to load (e.g., "highlights")
  ---@param is_included? boolean Internal parameter, most of the time left as `nil`
  ---@return string[] query_files List of files to load for given query and language
  function M.get_files(lang, query_name, is_included)
    local query_path = string.format('queries/%s/%s.scm', lang, query_name)
    local lang_files = dedupe_files(api.nvim_get_runtime_file(query_path, true))
    if #lang_files == 0 then
      return {}
    end

    local base_query = nil ---@type string?
    local extensions = {} ---@type string[]
    local base_langs = {} ---@type string[]

    for _, filename in ipairs(lang_files) do
      local file, err = io.open(filename, 'r')
      if not file then
        error(err)
      end
      if scan_query_modelines(file, lang, is_included, base_langs) then
        extensions[#extensions + 1] = filename
      elseif base_query == nil then
        base_query = filename
      end
      file:close()
    end

    local query_files = {} ---@type string[]
    for _, base_lang in ipairs(base_langs) do
      vim.list_extend(query_files, M.get_files(base_lang, query_name, true))
    end
    if base_query ~= nil then
      query_files[#query_files + 1] = base_query
    end
    vim.list_extend(query_files, extensions)
    return query_files
  end
  --- Reads the given query files and concatenates their contents, in order.
  ---
  --- A newline is appended to any non-empty file that does not already end in one. Without
  --- it, a file ending in a `;` comment line would comment out the first line of the next
  --- file. Empty files contribute nothing, so a list of empty files still yields ''.
  ---
  ---@param filenames string[]
  ---@return string
  local function read_query_files(filenames)
    local contents = {} ---@type string[]
    for _, filename in ipairs(filenames) do
      local content = safe_read(filename, '*a') or ''
      if content ~= '' and content:sub(-1) ~= '\n' then
        content = content .. '\n'
      end
      contents[#contents + 1] = content
    end
    return table.concat(contents)
  end
  -- Query strings set explicitly via |vim.treesitter.query.set()|, stored as
  -- explicit_queries[lang][query_name]. Indexing an unknown language creates and stores an
  -- empty table for it, so the two-level lookup never indexes nil.
  ---@type table<string,table<string,string>>
  local explicit_queries = setmetatable({}, {
    __index = function(by_lang, lang)
      local queries = {} ---@type table<string,string>
      rawset(by_lang, lang, queries)
      return queries
    end,
  })
  --- Sets the runtime query named {query_name} for {lang}
  ---
  --- This allows users to override or extend any runtime files and/or configuration
  --- set by plugins.
  ---
  --- By default {text} replaces the runtime query files of {lang}. Modelines in its leading
  --- `;` comment lines change that:
  ---   - `; extends` keeps the runtime query of {lang} and appends {text} to it;
  ---   - `; inherits: {lang1},{lang2}` prepends the runtime queries of those other languages
  ---     (listing {lang} itself there has no effect).
  ---
  --- For example, you could enable spellchecking of `C` identifiers with the
  --- following code:
  --- ```lua
  --- vim.treesitter.query.set(
  ---   'c',
  ---   'highlights',
  ---   [[;extends
  ---   (identifier) @spell]])
  --- ```
  ---
  ---@param lang string Language to use for the query
  ---@param query_name string Name of the query (e.g., "highlights")
  ---@param text string Query text (unparsed).
  function M.set(lang, query_name, text)
    -- Forget the memoized M.get() result so the new text is used on the next lookup.
    --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
    M.get:clear(lang, query_name)
    explicit_queries[lang][query_name] = text
  end
  --- Collects the languages whose runtime files are prepended to an explicitly set query.
  ---
  --- Only the leading `;` lines of {text} are examined:
  ---   - `; inherits: a,(b)` adds every listed language except {lang} itself (parentheses
  ---     marking optional languages are stripped; optional languages are always included);
  ---   - `; extends` adds {lang}.
  ---@param lang string
  ---@param text string explicit query text
  ---@return string[]
  local function explicit_query_base_langs(lang, text)
    local base_langs = {} ---@type string[]
    for line in text:gmatch('([^\n]*)\n?') do
      if not vim.startswith(line, ';') then
        break
      end
      local lang_list = line:match(MODELINE_FORMAT)
      if lang_list then
        for _, incl_lang in ipairs(vim.split(lang_list, ',')) do
          if incl_lang:match('%(.*%)') then
            incl_lang = incl_lang:sub(2, #incl_lang - 1)
          end
          add_included_lang(base_langs, lang, incl_lang)
        end
      elseif line:match(EXTENDS_FORMAT) then
        base_langs[#base_langs + 1] = lang
      end
    end
    return base_langs
  end

  --- Returns the runtime query {query_name} for {lang}.
  ---
  --- A query registered with |vim.treesitter.query.set()| takes precedence over the query
  --- files found on 'runtimepath'. Results are memoized per ({lang}, {query_name}).
  ---
  ---@param lang string Language to use for the query
  ---@param query_name string Name of the query (e.g. "highlights")
  ---
  ---@return vim.treesitter.Query? : Parsed query. `nil` if no query files are found.
  M.get = memoize('concat-2', function(lang, query_name)
    local explicit = explicit_queries[lang][query_name]
    local query_string ---@type string

    if explicit then
      local query_files = {} ---@type string[]
      for _, base_lang in ipairs(explicit_query_base_langs(lang, explicit)) do
        vim.list_extend(query_files, M.get_files(base_lang, query_name, true))
      end
      query_string = read_query_files(query_files) .. explicit
    else
      query_string = read_query_files(M.get_files(lang, query_name))
    end

    if #query_string == 0 then
      return nil
    end
    return M.parse(lang, query_string)
  end, false)
  -- Query files are looked up on 'runtimepath', so changing that option invalidates every
  -- memoized M.get() result.
  do
    local cache_reset_group =
      api.nvim_create_augroup('nvim.treesitter.query_cache_reset', { clear = true })
    api.nvim_create_autocmd('OptionSet', {
      group = cache_reset_group,
      pattern = { 'runtimepath' },
      callback = function()
        --- @diagnostic disable-next-line: undefined-field LuaLS bad at generics
        M.get:clear()
      end,
    })
  end
  --- Parses a {query} string and returns a `Query` object (|lua-treesitter-query|), which can be used
  --- to search the tree for the query patterns (via |Query:iter_captures()|, |Query:iter_matches()|),
  --- or inspect the query via these fields:
  ---   - `captures`: a list of unique capture names defined in the query (alias: `info.captures`).
  ---   - `info.patterns`: information about predicates.
  ---
  --- Results are memoized by ({lang}, {query}).
  ---
  --- Example:
  --- ```lua
  --- local query = vim.treesitter.query.parse('vimdoc', [[
  ---   ; query
  ---   ((h1) @str
  ---     (#trim! @str 1 1 1 1))
  --- ]])
  --- local tree = vim.treesitter.get_parser():parse()[1]
  --- for id, node, metadata in query:iter_captures(tree:root(), 0) do
  ---    -- Print the node name and source text.
  ---    vim.print({node:type(), vim.treesitter.get_node_text(node, vim.api.nvim_get_current_buf())})
  --- end
  --- ```
  ---
  ---@param lang string Language to use for the query
  ---@param query string Query text, in s-expr syntax
  ---
  ---@return vim.treesitter.Query : Parsed query
  ---
  ---@see [vim.treesitter.query.get()]
  M.parse = memoize('concat-2', function(lang, query)
    -- Register {lang} first; raises if that fails.
    assert(language.add(lang))
    return Query.new(lang, vim._ts_parse_query(lang, query))
  end, false)
  --- Implementations of predicates that can optionally be prefixed with "any-".
  ---
  --- Each function checks the nodes captured by `predicate[2]`:
  ---   - With `any`, it succeeds as soon as one node satisfies the predicate.
  ---   - Otherwise, it fails as soon as one node does not satisfy it.
  --- A capture that matched no nodes always succeeds. These functions are called from the
  --- predicate_handlers table with the appropriate arguments for each predicate.
  local impl = {
    --- `(#eq? @a "text")` or `(#eq? @a @b)`: node text equals a literal or another capture.
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['eq'] = function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      -- The comparison text is the same for every node, so resolve it once.
      local str ---@type string
      if type(predicate[3]) == 'string' then
        -- (#eq? @aa "foo")
        str = predicate[3]
      else
        -- (#eq? @aa @bb)
        local other = assert(match[predicate[3]])
        assert(#other == 1, '#eq? does not support comparison with captures on multiple nodes')
        str = vim.treesitter.get_node_text(other[1], source)
      end

      for _, node in ipairs(nodes) do
        local node_text = vim.treesitter.get_node_text(node, source)
        local res = str ~= nil and node_text == str
        if any and res then
          return true
        elseif not any and not res then
          return false
        end
      end

      return not any
    end,

    --- `(#lua-match? @a "pattern")`: node text matches a Lua pattern.
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['lua-match'] = function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      local regex = predicate[3]
      for _, node in ipairs(nodes) do
        local res = string.find(vim.treesitter.get_node_text(node, source), regex) ~= nil
        if any and res then
          return true
        elseif not any and not res then
          return false
        end
      end

      return not any
    end,

    --- `(#match? @a "regex")`: node text matches a Vim regex (very magic unless the regex
    --- starts with its own \v, \m, \M or \V flag).
    ['match'] = (function()
      local magic_prefixes = { ['\\v'] = true, ['\\m'] = true, ['\\M'] = true, ['\\V'] = true }
      local function check_magic(str)
        if string.len(str) < 2 or magic_prefixes[string.sub(str, 1, 2)] then
          return str
        end
        return '\\v' .. str
      end

      -- Regexes are compiled on first use and cached by their source text.
      local compiled_vim_regexes = setmetatable({}, {
        __index = function(t, pattern)
          local res = vim.regex(check_magic(pattern))
          rawset(t, pattern, res)
          return res
        end,
      })

      --- @param match table<integer,TSNode[]>
      --- @param source integer|string
      --- @param predicate any[]
      --- @param any boolean
      return function(match, source, predicate, any)
        local nodes = match[predicate[2]]
        if not nodes or #nodes == 0 then
          return true
        end

        local regex = compiled_vim_regexes[predicate[3]] ---@type vim.regex
        for _, node in ipairs(nodes) do
          local res = regex:match_str(vim.treesitter.get_node_text(node, source))
          if any and res then
            return true
          elseif not any and not res then
            return false
          end
        end

        return not any
      end
    end)(),

    --- `(#contains? @a "s1" "s2" ...)`: plain substring search of each string in the node
    --- text. Without `any`, every node must contain every string; with `any`, one string found
    --- in one node suffices.
    --- @param match table<integer,TSNode[]>
    --- @param source integer|string
    --- @param predicate any[]
    --- @param any boolean
    ['contains'] = function(match, source, predicate, any)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        local node_text = vim.treesitter.get_node_text(node, source)

        for i = 3, #predicate do
          local res = string.find(node_text, predicate[i], 1, true)
          if any and res then
            return true
          elseif not any and not res then
            return false
          end
        end
      end

      return not any
    end,
  }
  ---@alias TSPredicate fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[]): boolean

  -- Predicate handlers receive the following arguments:
  -- (match, pattern, source, predicate)
  -- and return whether the match satisfies the predicate. A capture that matched no nodes
  -- satisfies every built-in predicate.
  ---@type table<string,TSPredicate>
  local predicate_handlers = {
    ['eq?'] = function(match, _, source, predicate)
      return impl['eq'](match, source, predicate, false)
    end,

    ['any-eq?'] = function(match, _, source, predicate)
      return impl['eq'](match, source, predicate, true)
    end,

    ['lua-match?'] = function(match, _, source, predicate)
      return impl['lua-match'](match, source, predicate, false)
    end,

    ['any-lua-match?'] = function(match, _, source, predicate)
      return impl['lua-match'](match, source, predicate, true)
    end,

    ['match?'] = function(match, _, source, predicate)
      return impl['match'](match, source, predicate, false)
    end,

    ['any-match?'] = function(match, _, source, predicate)
      return impl['match'](match, source, predicate, true)
    end,

    ['contains?'] = function(match, _, source, predicate)
      return impl['contains'](match, source, predicate, false)
    end,

    ['any-contains?'] = function(match, _, source, predicate)
      return impl['contains'](match, source, predicate, true)
    end,

    -- True if the text of some captured node equals one of the listed strings.
    ['any-of?'] = function(match, _, source, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      -- Since 'predicate' will not be used by callers of this function, use it
      -- to store a string set built from the list of words to check against.
      local string_set = predicate['string_set'] --- @type table<string, boolean>
      if not string_set then
        string_set = {}
        for i = 3, #predicate do
          string_set[predicate[i]] = true
        end
        predicate['string_set'] = string_set
      end

      for _, node in ipairs(nodes) do
        if string_set[vim.treesitter.get_node_text(node, source)] then
          return true
        end
      end

      return false
    end,

    ['has-ancestor?'] = function(match, _, _, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        if node:__has_ancestor(predicate) then
          return true
        end
      end

      return false
    end,

    -- True if some captured node's direct parent has one of the listed types.
    ['has-parent?'] = function(match, _, _, predicate)
      local nodes = match[predicate[2]]
      if not nodes or #nodes == 0 then
        return true
      end

      for _, node in ipairs(nodes) do
        -- The root node has no parent, so it can never satisfy #has-parent?.
        local parent = node:parent()
        if parent then
          local parent_type = parent:type()
          for i = 3, #predicate do
            if predicate[i] == parent_type then
              return true
            end
          end
        end
      end

      return false
    end,
  }

  -- As we provide lua-match? also expose vim-match?
  predicate_handlers['vim-match?'] = predicate_handlers['match?']
  predicate_handlers['any-vim-match?'] = predicate_handlers['any-match?']
  ---@nodoc
  ---@class vim.treesitter.query.TSMetadata
  ---@field range? Range
  ---@field conceal? string
  ---@field [integer]? vim.treesitter.query.TSMetadata
  ---@field [string]? integer|string

  ---@alias TSDirective fun(match: table<integer,TSNode[]>, _, _, predicate: (string|integer)[], metadata: vim.treesitter.query.TSMetadata)

  -- Directives store metadata or perform side effects against a match.
  -- Directives should always end with a `!`.
  -- Directive handlers receive the following arguments:
  -- (match, pattern, bufnr, predicate, metadata)
  ---@type table<string,TSDirective>
  local directive_handlers = {
    ['set!'] = function(_, _, _, pred, metadata)
      if #pred >= 3 and type(pred[2]) == 'number' then
        -- (#set! @capture key value)
        local capture_id, key, value = pred[2], pred[3], pred[4]
        if not metadata[capture_id] then
          metadata[capture_id] = {}
        end
        metadata[capture_id][key] = value
      else
        -- (#set! key value)
        local key, value = pred[2], pred[3]
        metadata[key] = value or true
      end
    end,

    -- Shifts the range of a node.
    -- Example: (#offset! @_node 0 1 0 -1)
    ['offset!'] = function(match, _, _, pred, metadata)
      local capture_id = pred[2] --[[@as integer]]
      local nodes = match[capture_id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#offset! does not support captures on multiple nodes')

      local node = nodes[1]

      if not metadata[capture_id] then
        metadata[capture_id] = {}
      end

      -- Work on a copy: a range already stored (e.g. by an earlier directive) must stay
      -- untouched if the shifted range turns out to be invalid and is discarded.
      local current = metadata[capture_id].range
      local range ---@type integer[]
      if current then
        range = { current[1], current[2], current[3], current[4] }
      else
        range = { node:range() }
      end

      local start_row_offset = pred[3] or 0
      local start_col_offset = pred[4] or 0
      local end_row_offset = pred[5] or 0
      local end_col_offset = pred[6] or 0

      range[1] = range[1] + start_row_offset
      range[2] = range[2] + start_col_offset
      range[3] = range[3] + end_row_offset
      range[4] = range[4] + end_col_offset

      -- If this produces an invalid range, we just skip it.
      if range[1] < range[3] or (range[1] == range[3] and range[2] <= range[4]) then
        metadata[capture_id].range = range
      end
    end,

    -- Transform the content of the node
    -- Example: (#gsub! @_node ".*%.(.*)" "%1")
    ['gsub!'] = function(match, _, bufnr, pred, metadata)
      assert(#pred == 4)

      local id = pred[2]
      assert(type(id) == 'number')

      local nodes = match[id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#gsub! does not support captures on multiple nodes')
      local node = nodes[1]
      local text = vim.treesitter.get_node_text(node, bufnr, { metadata = metadata[id] }) or ''

      if not metadata[id] then
        metadata[id] = {}
      end

      local pattern, replacement = pred[3], pred[4]
      assert(type(pattern) == 'string')
      assert(type(replacement) == 'string')

      metadata[id].text = text:gsub(pattern, replacement)
    end,

    -- Trim whitespace from both sides of the node
    -- Example: (#trim! @fold 1 1 1 1)
    ['trim!'] = function(match, _, bufnr, pred, metadata)
      local capture_id = pred[2]
      assert(type(capture_id) == 'number')

      local trim_start_lines = pred[3] == '1'
      local trim_start_cols = pred[4] == '1'
      local trim_end_lines = pred[5] == '1' or not pred[3] -- default true for backwards compatibility
      local trim_end_cols = pred[6] == '1'

      local nodes = match[capture_id]
      if not nodes or #nodes == 0 then
        return
      end
      assert(#nodes == 1, '#trim! does not support captures on multiple nodes')
      local node = nodes[1]

      local start_row, start_col, end_row, end_col = node:range()

      -- The first element of node_text starts at start_col in the buffer; the others start
      -- at column 0.
      local node_text = vim.split(vim.treesitter.get_node_text(node, bufnr), '\n')
      if end_col == 0 then
        -- get_node_text() will ignore the last line if the node ends at column 0
        node_text[#node_text + 1] = ''
      end

      local end_idx = #node_text
      local start_idx = 1

      if trim_end_lines then
        while end_idx > 0 and node_text[end_idx]:find('^%s*$') do
          end_idx = end_idx - 1
          end_row = end_row - 1
          -- Set the end position to the end of the new last line (offset by start_col when
          -- that is the node's first line), or 0 if we just trimmed the last line.
          if end_idx > 0 then
            end_col = #node_text[end_idx] + (end_idx == 1 and start_col or 0)
          else
            end_col = 0
          end
        end
      end
      if trim_end_cols then
        if end_idx == 0 then
          end_row = start_row
          end_col = start_col
        else
          local whitespace_start = node_text[end_idx]:find('(%s*)$')
          end_col = (whitespace_start - 1) + (end_idx == 1 and start_col or 0)
        end
      end

      if trim_start_lines then
        while start_idx <= end_idx and node_text[start_idx]:find('^%s*$') do
          start_idx = start_idx + 1
          start_row = start_row + 1
          start_col = 0
        end
      end
      if trim_start_cols and node_text[start_idx] then
        local _, whitespace_end = node_text[start_idx]:find('^(%s*)')
        whitespace_end = whitespace_end or 0
        start_col = (start_idx == 1 and start_col or 0) + whitespace_end
      end

      -- If this produces an invalid range, we just skip it.
      if start_row < end_row or (start_row == end_row and start_col <= end_col) then
        metadata[capture_id] = metadata[capture_id] or {}
        metadata[capture_id].range = { start_row, start_col, end_row, end_col }
      end
    end,
  }
  --- @class vim.treesitter.query.add_predicate.Opts
  --- @inlinedoc
  ---
  --- Override an existing predicate of the same name
  --- @field force? boolean
  ---
  --- Use the correct implementation of the match table where capture IDs map to
  --- a list of nodes instead of a single node. Defaults to true. This option will
  --- be removed in a future release.
  --- @field all? boolean

  --- Adds a new predicate to be used in queries
  ---
  ---@param name string Name of the predicate, without leading #
  ---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata): boolean?
  ---   - see |vim.treesitter.query.add_directive()| for argument meanings
  ---@param opts? vim.treesitter.query.add_predicate.Opts
  function M.add_predicate(name, handler, opts)
    -- Older releases accepted `force` as a plain boolean third argument.
    if type(opts) == 'boolean' then
      opts = { force = opts }
    elseif opts == nil then
      opts = {}
    end

    if predicate_handlers[name] and not opts.force then
      error(string.format('Overriding existing predicate %s', name))
    end

    if opts.all == false then
      -- Legacy handlers expect every numeric capture ID to map to a single node: pass
      -- them the last node captured for each ID.
      --- @param match table<integer, TSNode[]>
      predicate_handlers[name] = function(match, ...)
        local single = {} ---@type table<integer, TSNode>
        for capture_id, capture_nodes in pairs(match) do
          if type(capture_id) == 'number' then
            single[capture_id] = capture_nodes[#capture_nodes]
          end
        end
        return handler(single, ...)
      end
    else
      predicate_handlers[name] = handler
    end
  end
  --- Adds a new directive to be used in queries
  ---
  --- Handlers can set match level data by setting directly on the
  --- metadata object `metadata.key = value`. Additionally, handlers
  --- can set node level data by using the capture id on the
  --- metadata table `metadata[capture_id].key = value`
  ---
  ---@param name string Name of the directive, without leading #
  ---@param handler fun(match: table<integer,TSNode[]>, pattern: integer, source: integer|string, predicate: any[], metadata: vim.treesitter.query.TSMetadata)
  ---   - match: A table mapping capture IDs to a list of captured nodes
  ---   - pattern: the index of the matching pattern in the query file
  ---   - source: buffer number or source string the nodes' text is read from
  ---   - predicate: list of the directive's arguments, starting with its name (without the
  ---     leading `#`); captures appear as capture IDs. E.g. `(node (#set! conceal "-"))`
  ---     would get the predicate `{ "set!", "conceal", "-" }`
  ---   - metadata: metadata table of the match, to be modified in place
  ---@param opts? vim.treesitter.query.add_predicate.Opts
  function M.add_directive(name, handler, opts)
    -- Backward compatibility: old signature had "force" as boolean argument
    if type(opts) == 'boolean' then
      opts = { force = opts }
    end
    opts = opts or {}

    if directive_handlers[name] and not opts.force then
      error(string.format('Overriding existing directive %s', name))
    end

    -- `all` defaults to true, as documented for these options and as in M.add_predicate().
    if opts.all ~= false then
      directive_handlers[name] = handler
    else
      -- Legacy handlers expect each numeric capture ID to map to a single node (the last one).
      --- @param match table<integer, TSNode[]>
      local function wrapper(match, ...)
        local m = {} ---@type table<integer, TSNode>
        for k, v in pairs(match) do
          if type(k) == 'number' then
            m[k] = v[#v]
          end
        end
        handler(m, ...)
      end
      directive_handlers[name] = wrapper
    end
  end
  739. --- Lists the currently available directives to use in queries.
  740. ---@return string[] : Supported directives.
  741. function M.list_directives()
  742. return vim.tbl_keys(directive_handlers)
  743. end
  744. --- Lists the currently available predicates to use in queries.
  745. ---@return string[] : Supported predicates.
  746. function M.list_predicates()
  747. return vim.tbl_keys(predicate_handlers)
  748. end
  749. ---@private
  750. ---@param pattern_i integer
  751. ---@param predicates vim.treesitter.query.ProcessedPredicate[]
  752. ---@param captures table<integer, TSNode[]>
  753. ---@param source integer|string
  754. ---@return boolean whether the predicates match
  755. function Query:_match_predicates(predicates, pattern_i, captures, source)
  756. for _, predicate in ipairs(predicates) do
  757. local processed_name = predicate[1]
  758. local should_match = predicate[2]
  759. local orig_predicate = predicate[3]
  760. local handler = predicate_handlers[processed_name]
  761. if not handler then
  762. error(string.format('No handler for %s', orig_predicate[1]))
  763. return false
  764. end
  765. local does_match = handler(captures, pattern_i, source, orig_predicate)
  766. if does_match ~= should_match then
  767. return false
  768. end
  769. end
  770. return true
  771. end
  772. ---@private
  773. ---@param pattern_i integer
  774. ---@param directives vim.treesitter.query.ProcessedDirective[]
  775. ---@param source integer|string
  776. ---@param captures table<integer, TSNode[]>
  777. ---@return vim.treesitter.query.TSMetadata metadata
  778. function Query:_apply_directives(directives, pattern_i, captures, source)
  779. ---@type vim.treesitter.query.TSMetadata
  780. local metadata = {}
  781. for _, directive in pairs(directives) do
  782. local handler = directive_handlers[directive[1]]
  783. if not handler then
  784. error(string.format('No handler for %s', directive[1]))
  785. end
  786. handler(captures, pattern_i, source, directive, metadata)
  787. end
  788. return metadata
  789. end
  790. --- Returns the start and stop value if set else the node's range.
  791. -- When the node's range is used, the stop is incremented by 1
  792. -- to make the search inclusive.
  793. ---@param start integer?
  794. ---@param stop integer?
  795. ---@param node TSNode
  796. ---@return integer, integer
  797. local function value_or_node_range(start, stop, node)
  798. if start == nil then
  799. start = node:start()
  800. end
  801. if stop == nil then
  802. stop = node:end_() + 1 -- Make stop inclusive
  803. end
  804. return start, stop
  805. end
--- Iterates over all captures from all matches in {node}.
---
--- {source} is required if the query contains predicates; then the caller
--- must ensure to use a freshly parsed tree consistent with the current
--- text of the buffer (if relevant). {start} and {stop} can be used to limit
--- matches inside a row range (this is typically used with root node
--- as the {node}, i.e., to get syntax highlight matches in the current
--- viewport). When omitted, the {start} and {stop} row values are used from the given node.
---
--- The iterator returns five values:
--- 1. the numeric id identifying the capture
--- 2. the captured node
--- 3. metadata from any directives processing the match
--- 4. the match itself
--- 5. the tree of the captured node (a copy that stays valid while iterating)
---
--- Example: how to get captures by name:
--- ```lua
--- for id, node, metadata, match in query:iter_captures(tree:root(), bufnr, first, last) do
---   local name = query.captures[id] -- name of the capture in the query
---   -- typically useful info about the node:
---   local type = node:type() -- type of the captured node
---   local row1, col1, row2, col2 = node:range() -- range of the capture
---   -- ... use the info here ...
--- end
--- ```
---
---@param node TSNode under which the search will occur
---@param source (integer|string) Source buffer or string to extract text from
---@param start? integer Starting line for the search. Defaults to `node:start()`.
---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_() + 1`, so the node's last row is searched.
---
---@return (fun(end_line: integer|nil): integer, TSNode, vim.treesitter.query.TSMetadata, TSQueryMatch, TSTree):
---        capture id, capture node, metadata, match, tree
---
---@note Captures whose pattern has predicates are only returned when all of those predicates match.
function Query:iter_captures(node, source, start, stop)
  -- Buffer 0 means "current buffer"; resolve it once, up front.
  if type(source) == 'number' and source == 0 then
    source = api.nvim_get_current_buf()
  end

  start, stop = value_or_node_range(start, stop, node)

  -- Copy the tree to ensure it is valid during the entire lifetime of the iterator
  local tree = node:tree():copy()
  local cursor = vim._create_ts_querycursor(node, self.query, start, stop, { match_limit = 256 })

  -- Metadata is computed once per match id and reused for that match's later captures.
  -- For faster checks that a match is not in the cache.
  local highest_cached_match_id = -1
  ---@type table<integer, vim.treesitter.query.TSMetadata>
  local match_cache = {}

  --- @param end_line integer|nil Optional row bound used when a match is rejected (see below).
  local function iter(end_line)
    local capture, captured_node, match = cursor:next_capture()

    if not capture then
      return
    end

    local match_id, pattern_i = match:info()

    --- @type vim.treesitter.query.TSMetadata
    local metadata
    -- Ids above the highest cached id cannot be in the cache; skip the lookup for them.
    if match_id <= highest_cached_match_id then
      metadata = match_cache[match_id]
    end

    if not metadata then
      metadata = {}

      local processed_pattern = self._processed_patterns[pattern_i]
      if processed_pattern then
        local captures = match:captures()

        local predicates = processed_pattern.predicates
        if not self:_match_predicates(predicates, pattern_i, captures, source) then
          -- Predicates failed: drop the match from the cursor (presumably so its
          -- other captures are not yielded either).
          cursor:remove_match(match_id)
          -- If the rejected capture starts past {end_line}, yield a nil capture (which
          -- ends a generic `for`) plus the node, instead of scanning further.
          if end_line and captured_node:range() > end_line then
            return nil, captured_node, nil, nil
          end
          return iter(end_line) -- tail call: try next match
        end

        local directives = processed_pattern.directives
        metadata = self:_apply_directives(directives, pattern_i, captures, source)
      end

      highest_cached_match_id = math.max(highest_cached_match_id, match_id)
      match_cache[match_id] = metadata
    end

    return capture, captured_node, metadata, match, tree
  end
  return iter
end
  887. --- Iterates the matches of self on a given range.
  888. ---
  889. --- Iterate over all matches within a {node}. The arguments are the same as for
  890. --- |Query:iter_captures()| but the iterated values are different: an (1-based)
  891. --- index of the pattern in the query, a table mapping capture indices to a list
  892. --- of nodes, and metadata from any directives processing the match.
  893. ---
  894. --- Example:
  895. ---
  896. --- ```lua
  897. --- for pattern, match, metadata in cquery:iter_matches(tree:root(), bufnr, 0, -1) do
  898. --- for id, nodes in pairs(match) do
  899. --- local name = query.captures[id]
  900. --- for _, node in ipairs(nodes) do
  901. --- -- `node` was captured by the `name` capture in the match
  902. ---
  903. --- local node_data = metadata[id] -- Node level metadata
  904. --- -- ... use the info here ...
  905. --- end
  906. --- end
  907. --- end
  908. --- ```
  909. ---
  910. ---
  911. ---@param node TSNode under which the search will occur
  912. ---@param source (integer|string) Source buffer or string to search
  913. ---@param start? integer Starting line for the search. Defaults to `node:start()`.
  914. ---@param stop? integer Stopping line for the search (end-exclusive). Defaults to `node:end_()`.
  915. ---@param opts? table Optional keyword arguments:
  916. --- - max_start_depth (integer) if non-zero, sets the maximum start depth
  917. --- for each match. This is used to prevent traversing too deep into a tree.
  918. --- - match_limit (integer) Set the maximum number of in-progress matches (Default: 256).
  919. --- - all (boolean) When `false` (default `true`), the returned table maps capture IDs to a single
  920. --- (last) node instead of the full list of matching nodes. This option is only for backward
  921. --- compatibility and will be removed in a future release.
  922. ---
  923. ---@return (fun(): integer, table<integer, TSNode[]>, vim.treesitter.query.TSMetadata, TSTree): pattern id, match, metadata, tree
  924. function Query:iter_matches(node, source, start, stop, opts)
  925. opts = opts or {}
  926. opts.match_limit = opts.match_limit or 256
  927. if type(source) == 'number' and source == 0 then
  928. source = api.nvim_get_current_buf()
  929. end
  930. start, stop = value_or_node_range(start, stop, node)
  931. -- Copy the tree to ensure it is valid during the entire lifetime of the iterator
  932. local tree = node:tree():copy()
  933. local cursor = vim._create_ts_querycursor(node, self.query, start, stop, opts)
  934. local function iter()
  935. local match = cursor:next_match()
  936. if not match then
  937. return
  938. end
  939. local match_id, pattern_i = match:info()
  940. local processed_pattern = self._processed_patterns[pattern_i]
  941. local captures = match:captures()
  942. --- @type vim.treesitter.query.TSMetadata
  943. local metadata = {}
  944. if processed_pattern then
  945. local predicates = processed_pattern.predicates
  946. if not self:_match_predicates(predicates, pattern_i, captures, source) then
  947. cursor:remove_match(match_id)
  948. return iter() -- tail call: try next match
  949. end
  950. local directives = processed_pattern.directives
  951. metadata = self:_apply_directives(directives, pattern_i, captures, source)
  952. end
  953. if opts.all == false then
  954. -- Convert the match table into the old buggy version for backward
  955. -- compatibility. This is slow, but we only do it when the caller explicitly opted into it by
  956. -- setting `all` to `false`.
  957. local old_match = {} ---@type table<integer, TSNode>
  958. for k, v in pairs(captures or {}) do
  959. old_match[k] = v[#v]
  960. end
  961. return pattern_i, old_match, metadata
  962. end
  963. -- TODO(lewis6991): create a new function that returns {match, metadata}
  964. return pattern_i, captures, metadata, tree
  965. end
  966. return iter
  967. end
--- Optional keyword arguments:
--- @class vim.treesitter.query.lint.Opts
--- @inlinedoc
---
--- Language(s) to use for checking the query.
--- If multiple languages are specified, queries are validated for all of them.
--- @field langs? string|string[]
---
--- When true, only clear the current lint errors instead of linting.
--- @field clear? boolean
  978. --- Lint treesitter queries using installed parser, or clear lint errors.
  979. ---
  980. --- Use |treesitter-parsers| in runtimepath to check the query file in {buf} for errors:
  981. ---
  982. --- - verify that used nodes are valid identifiers in the grammar.
  983. --- - verify that predicates and directives are valid.
  984. --- - verify that top-level s-expressions are valid.
  985. ---
  986. --- The found diagnostics are reported using |diagnostic-api|.
  987. --- By default, the parser used for verification is determined by the containing folder
  988. --- of the query file, e.g., if the path ends in `/lua/highlights.scm`, the parser for the
  989. --- `lua` language will be used.
  990. ---@param buf (integer) Buffer handle
  991. ---@param opts? vim.treesitter.query.lint.Opts
  992. function M.lint(buf, opts)
  993. if opts and opts.clear then
  994. vim.treesitter._query_linter.clear(buf)
  995. else
  996. vim.treesitter._query_linter.lint(buf, opts)
  997. end
  998. end
  999. --- Omnifunc for completing node names and predicates in treesitter queries.
  1000. ---
  1001. --- Use via
  1002. ---
  1003. --- ```lua
  1004. --- vim.bo.omnifunc = 'v:lua.vim.treesitter.query.omnifunc'
  1005. --- ```
  1006. ---
  1007. --- @param findstart 0|1
  1008. --- @param base string
  1009. function M.omnifunc(findstart, base)
  1010. return vim.treesitter._query_linter.omnifunc(findstart, base)
  1011. end
  1012. --- Opens a live editor to query the buffer you started from.
  1013. ---
  1014. --- Can also be shown with [:EditQuery]().
  1015. ---
  1016. --- If you move the cursor to a capture name ("@foo"), text matching the capture is highlighted in
  1017. --- the source buffer. The query editor is a scratch buffer, use `:write` to save it. You can find
  1018. --- example queries at `$VIMRUNTIME/queries/`.
  1019. ---
  1020. --- @param lang? string language to open the query editor for. If omitted, inferred from the current buffer's filetype.
  1021. function M.edit(lang)
  1022. assert(vim.treesitter.dev.edit_query(lang))
  1023. end
-- Export the module table.
return M