-- _query_linter.lua
  1. local api = vim.api
  2. local namespace = api.nvim_create_namespace('vim.treesitter.query_linter')
  3. local M = {}
  4. --- @class QueryLinterNormalizedOpts
  5. --- @field langs string[]
  6. --- @field clear boolean
  7. --- @alias vim.treesitter.ParseError {msg: string, range: Range4}
  8. --- Contains language dependent context for the query linter
  9. --- @class QueryLinterLanguageContext
  10. --- @field lang string? Current `lang` of the targeted parser
  11. --- @field parser_info table? Parser info returned by vim.treesitter.language.inspect
  12. --- @field is_first_lang boolean Whether this is the first language of a linter run checking queries for multiple `langs`
  13. --- Adds a diagnostic for node in the query buffer
  14. --- @param diagnostics vim.Diagnostic[]
  15. --- @param range Range4
  16. --- @param lint string
  17. --- @param lang string?
  18. local function add_lint_for_node(diagnostics, range, lint, lang)
  19. local message = lint:gsub('\n', ' ')
  20. diagnostics[#diagnostics + 1] = {
  21. lnum = range[1],
  22. end_lnum = range[3],
  23. col = range[2],
  24. end_col = range[4],
  25. severity = vim.diagnostic.ERROR,
  26. message = message,
  27. source = lang,
  28. }
  29. end
  30. --- Determines the target language of a query file by its path: <lang>/<query_type>.scm
  31. --- @param buf integer
  32. --- @return string?
  33. local function guess_query_lang(buf)
  34. local filename = api.nvim_buf_get_name(buf)
  35. if filename ~= '' then
  36. local resolved_filename = vim.F.npcall(vim.fn.fnamemodify, filename, ':p:h:t')
  37. return resolved_filename and vim.treesitter.language.get_lang(resolved_filename)
  38. end
  39. end
  40. --- @param buf integer
  41. --- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil
  42. --- @return QueryLinterNormalizedOpts
  43. local function normalize_opts(buf, opts)
  44. opts = opts or {}
  45. if not opts.langs then
  46. opts.langs = guess_query_lang(buf)
  47. end
  48. if type(opts.langs) ~= 'table' then
  49. --- @diagnostic disable-next-line:assign-type-mismatch
  50. opts.langs = { opts.langs }
  51. end
  52. --- @cast opts QueryLinterNormalizedOpts
  53. opts.langs = opts.langs or {}
  54. return opts
  55. end
-- Query (in the `query` language) run over the buffer's query trees by M.lint.
-- Captures:
--   @toplevel  every top-level pattern; lint_match re-parses its text against
--              the target language to surface language-specific errors
--   @error     ERROR nodes of the query grammar, reported as syntax errors
--   @node.named, @node.anonymous, @field, @predicate.name, @predicate.type
--              not consumed by lint_match in this file
--              (NOTE(review): possibly unused; confirm before relying on them)
local lint_query = [[;; query
(program [(named_node) (anonymous_node) (list) (grouping)] @toplevel)
(named_node
name: _ @node.named)
(anonymous_node
name: _ @node.anonymous)
(field_definition
name: (identifier) @field)
(predicate
name: (identifier) @predicate.name
type: (predicate_type) @predicate.type)
(ERROR) @error
]]
  69. --- @param err string
  70. --- @param node TSNode
  71. --- @return vim.treesitter.ParseError
  72. local function get_error_entry(err, node)
  73. local start_line, start_col = node:range()
  74. local line_offset, col_offset, msg = err:gmatch('.-:%d+: Query error at (%d+):(%d+)%. ([^:]+)')() ---@type string, string, string
  75. start_line, start_col =
  76. start_line + tonumber(line_offset) - 1, start_col + tonumber(col_offset) - 1
  77. local end_line, end_col = start_line, start_col
  78. if msg:match('^Invalid syntax') or msg:match('^Impossible') then
  79. -- Use the length of the underlined node
  80. local underlined = vim.split(err, '\n')[2]
  81. end_col = end_col + #underlined
  82. elseif msg:match('^Invalid') then
  83. -- Use the length of the problematic type/capture/field
  84. end_col = end_col + #(msg:match('"([^"]+)"') or '')
  85. end
  86. return {
  87. msg = msg,
  88. range = { start_line, start_col, end_line, end_col },
  89. }
  90. end
  91. --- @param node TSNode
  92. --- @param buf integer
  93. --- @param lang string
  94. local function hash_parse(node, buf, lang)
  95. return tostring(node:id()) .. tostring(buf) .. tostring(vim.b[buf].changedtick) .. lang
  96. end
  97. --- @param node TSNode
  98. --- @param buf integer
  99. --- @param lang string
  100. --- @return vim.treesitter.ParseError?
  101. local parse = vim.func._memoize(hash_parse, function(node, buf, lang)
  102. local query_text = vim.treesitter.get_node_text(node, buf)
  103. local ok, err = pcall(vim.treesitter.query.parse, lang, query_text) ---@type boolean|vim.treesitter.ParseError, string|vim.treesitter.Query
  104. if not ok and type(err) == 'string' then
  105. return get_error_entry(err, node)
  106. end
  107. end)
  108. --- @param buf integer
  109. --- @param match table<integer,TSNode[]>
  110. --- @param query vim.treesitter.Query
  111. --- @param lang_context QueryLinterLanguageContext
  112. --- @param diagnostics vim.Diagnostic[]
  113. local function lint_match(buf, match, query, lang_context, diagnostics)
  114. local lang = lang_context.lang
  115. local parser_info = lang_context.parser_info
  116. for id, nodes in pairs(match) do
  117. for _, node in ipairs(nodes) do
  118. local cap_id = query.captures[id]
  119. -- perform language-independent checks only for first lang
  120. if lang_context.is_first_lang and cap_id == 'error' then
  121. local node_text = vim.treesitter.get_node_text(node, buf):gsub('\n', ' ')
  122. add_lint_for_node(diagnostics, { node:range() }, 'Syntax error: ' .. node_text)
  123. end
  124. -- other checks rely on Neovim parser introspection
  125. if lang and parser_info and cap_id == 'toplevel' then
  126. local err = parse(node, buf, lang)
  127. if err then
  128. add_lint_for_node(diagnostics, err.range, err.msg, lang)
  129. end
  130. end
  131. end
  132. end
  133. end
  134. --- @private
  135. --- @param buf integer Buffer to lint
  136. --- @param opts vim.treesitter.query.lint.Opts|QueryLinterNormalizedOpts|nil Options for linting
  137. function M.lint(buf, opts)
  138. if buf == 0 then
  139. buf = api.nvim_get_current_buf()
  140. end
  141. local diagnostics = {}
  142. local query = vim.treesitter.query.parse('query', lint_query)
  143. opts = normalize_opts(buf, opts)
  144. -- perform at least one iteration even with no langs to perform language independent checks
  145. for i = 1, math.max(1, #opts.langs) do
  146. local lang = opts.langs[i]
  147. --- @type (table|nil)
  148. local parser_info = vim.F.npcall(vim.treesitter.language.inspect, lang)
  149. local lang_context = {
  150. lang = lang,
  151. parser_info = parser_info,
  152. is_first_lang = i == 1,
  153. }
  154. local parser = assert(vim.treesitter.get_parser(buf, nil, { error = false }))
  155. parser:parse()
  156. parser:for_each_tree(function(tree, ltree)
  157. if ltree:lang() == 'query' then
  158. for _, match, _ in query:iter_matches(tree:root(), buf, 0, -1) do
  159. lint_match(buf, match, query, lang_context, diagnostics)
  160. end
  161. end
  162. end)
  163. end
  164. vim.diagnostic.set(namespace, buf, diagnostics)
  165. end
  166. --- @private
  167. --- @param buf integer
  168. function M.clear(buf)
  169. vim.diagnostic.reset(namespace, buf)
  170. end
  171. --- @private
  172. --- @param findstart 0|1
  173. --- @param base string
  174. function M.omnifunc(findstart, base)
  175. if findstart == 1 then
  176. local result =
  177. api.nvim_get_current_line():sub(1, api.nvim_win_get_cursor(0)[2]):find('["#%-%w]*$')
  178. return result - 1
  179. end
  180. local buf = api.nvim_get_current_buf()
  181. local query_lang = guess_query_lang(buf)
  182. local ok, parser_info = pcall(vim.treesitter.language.inspect, query_lang)
  183. if not ok then
  184. return -2
  185. end
  186. local items = {}
  187. for _, f in pairs(parser_info.fields) do
  188. if f:find(base, 1, true) then
  189. table.insert(items, f .. ':')
  190. end
  191. end
  192. for _, p in pairs(vim.treesitter.query.list_predicates()) do
  193. local text = '#' .. p
  194. local found = text:find(base, 1, true)
  195. if found and found <= 2 then -- with or without '#'
  196. table.insert(items, text)
  197. end
  198. text = '#not-' .. p
  199. found = text:find(base, 1, true)
  200. if found and found <= 2 then -- with or without '#'
  201. table.insert(items, text)
  202. end
  203. end
  204. for _, p in pairs(vim.treesitter.query.list_directives()) do
  205. local text = '#' .. p
  206. local found = text:find(base, 1, true)
  207. if found and found <= 2 then -- with or without '#'
  208. table.insert(items, text)
  209. end
  210. end
  211. for text, named in
  212. pairs(parser_info.symbols --[[@as table<string, boolean>]])
  213. do
  214. if not named then
  215. text = string.format('%q', text:sub(2, -2)):gsub('\n', 'n') ---@type string
  216. end
  217. if text:find(base, 1, true) then
  218. table.insert(items, text)
  219. end
  220. end
  221. return { words = items, refresh = 'always' }
  222. end
  223. return M