_fold.lua 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. local ts = vim.treesitter
  2. local Range = require('vim.treesitter._range')
  3. local api = vim.api
  4. ---Treesitter folding is done in two steps:
  5. ---(1) compute the fold levels with the syntax tree and cache the result (`compute_folds_levels`)
  6. ---(2) evaluate foldexpr for each window, which reads from the cache (`foldupdate`)
  7. ---@class TS.FoldInfo
  8. ---
  9. ---@field levels string[] the cached foldexpr result for each line
  10. ---@field levels0 integer[] the cached raw fold levels
  11. ---
  12. ---The range edited since the last invocation of the callback scheduled in on_bytes.
  13. ---Should compute fold levels in this range.
  14. ---@field on_bytes_range? Range2
  15. ---
  16. ---The range on which to evaluate foldexpr.
  17. ---When in insert mode, the evaluation is deferred to InsertLeave.
  18. ---@field foldupdate_range? Range2
  19. ---
  20. ---The treesitter parser associated with this buffer.
  21. ---@field parser? vim.treesitter.LanguageTree
  22. local FoldInfo = {}
  23. FoldInfo.__index = FoldInfo
  24. ---@private
  25. ---@param bufnr integer
  26. function FoldInfo.new(bufnr)
  27. return setmetatable({
  28. levels0 = {},
  29. levels = {},
  30. parser = ts.get_parser(bufnr, nil, { error = false }),
  31. }, FoldInfo)
  32. end
  33. ---@package
  34. ---@param srow integer
  35. ---@param erow integer 0-indexed, exclusive
  36. function FoldInfo:remove_range(srow, erow)
  37. vim._list_remove(self.levels, srow + 1, erow)
  38. vim._list_remove(self.levels0, srow + 1, erow)
  39. end
  40. ---@package
  41. ---@param srow integer
  42. ---@param erow integer 0-indexed, exclusive
  43. function FoldInfo:add_range(srow, erow)
  44. vim._list_insert(self.levels, srow + 1, erow, -1)
  45. vim._list_insert(self.levels0, srow + 1, erow, -1)
  46. end
  47. ---@param range Range2
  48. ---@param srow integer
  49. ---@param erow_old integer
  50. ---@param erow_new integer 0-indexed, exclusive
  51. local function edit_range(range, srow, erow_old, erow_new)
  52. range[1] = math.min(srow, range[1])
  53. if erow_old <= range[2] then
  54. range[2] = range[2] + (erow_new - erow_old)
  55. end
  56. range[2] = math.max(range[2], erow_new)
  57. end
  58. -- TODO(lewis6991): Setup a decor provider so injections folds can be parsed
  59. -- as the window is redrawn
  60. ---@param bufnr integer
  61. ---@param info TS.FoldInfo
  62. ---@param srow integer?
  63. ---@param erow integer? 0-indexed, exclusive
  64. ---@param callback function?
  65. local function compute_folds_levels(bufnr, info, srow, erow, callback)
  66. srow = srow or 0
  67. erow = erow or api.nvim_buf_line_count(bufnr)
  68. local parser = info.parser
  69. if not parser then
  70. return
  71. end
  72. parser:parse(nil, function(_, trees)
  73. if not trees then
  74. return
  75. end
  76. local enter_counts = {} ---@type table<integer, integer>
  77. local leave_counts = {} ---@type table<integer, integer>
  78. local prev_start = -1
  79. local prev_stop = -1
  80. parser:for_each_tree(function(tree, ltree)
  81. local query = ts.query.get(ltree:lang(), 'folds')
  82. if not query then
  83. return
  84. end
  85. -- Collect folds starting from srow - 1, because we should first subtract the folds that end at
  86. -- srow - 1 from the level of srow - 1 to get accurate level of srow.
  87. for _, match, metadata in query:iter_matches(tree:root(), bufnr, math.max(srow - 1, 0), erow) do
  88. for id, nodes in pairs(match) do
  89. if query.captures[id] == 'fold' then
  90. local range = ts.get_range(nodes[1], bufnr, metadata[id])
  91. local start, _, stop, stop_col = Range.unpack4(range)
  92. if #nodes > 1 then
  93. -- assumes nodes are ordered by range
  94. local end_range = ts.get_range(nodes[#nodes], bufnr, metadata[id])
  95. local _, _, end_stop, end_stop_col = Range.unpack4(end_range)
  96. stop = end_stop
  97. stop_col = end_stop_col
  98. end
  99. if stop_col == 0 then
  100. stop = stop - 1
  101. end
  102. local fold_length = stop - start + 1
  103. -- Fold only multiline nodes that are not exactly the same as previously met folds
  104. -- Checking against just the previously found fold is sufficient if nodes
  105. -- are returned in preorder or postorder when traversing tree
  106. if
  107. fold_length > vim.wo.foldminlines and not (start == prev_start and stop == prev_stop)
  108. then
  109. enter_counts[start + 1] = (enter_counts[start + 1] or 0) + 1
  110. leave_counts[stop + 1] = (leave_counts[stop + 1] or 0) + 1
  111. prev_start = start
  112. prev_stop = stop
  113. end
  114. end
  115. end
  116. end
  117. end)
  118. local nestmax = vim.wo.foldnestmax
  119. local level0_prev = info.levels0[srow] or 0
  120. local leave_prev = leave_counts[srow] or 0
  121. -- We now have the list of fold opening and closing, fill the gaps and mark where fold start
  122. for lnum = srow + 1, erow do
  123. local enter_line = enter_counts[lnum] or 0
  124. local leave_line = leave_counts[lnum] or 0
  125. local level0 = level0_prev - leave_prev + enter_line
  126. -- Determine if it's the start/end of a fold
  127. -- NB: vim's fold-expr interface does not have a mechanism to indicate that
  128. -- two (or more) folds start at this line, so it cannot distinguish between
  129. -- ( \n ( \n )) \n (( \n ) \n )
  130. -- versus
  131. -- ( \n ( \n ) \n ( \n ) \n )
  132. -- Both are represented by ['>1', '>2', '2', '>2', '2', '1'], and
  133. -- vim interprets as the second case.
  134. -- If it did have such a mechanism, (clamped - clamped_prev)
  135. -- would be the correct number of starts to pass on.
  136. local adjusted = level0 ---@type integer
  137. local prefix = ''
  138. if enter_line > 0 then
  139. prefix = '>'
  140. if leave_line > 0 then
  141. -- If this line ends a fold f1 and starts a fold f2, then move f1's end to the previous line
  142. -- so that f2 gets the correct level on this line. This may reduce the size of f1 below
  143. -- foldminlines, but we don't handle it for simplicity.
  144. adjusted = level0 - leave_line
  145. leave_line = 0
  146. end
  147. end
  148. -- Clamp at foldnestmax.
  149. local clamped = adjusted
  150. if adjusted > nestmax then
  151. prefix = ''
  152. clamped = nestmax
  153. end
  154. -- Record the "real" level, so that it can be used as "base" of later compute_folds_levels().
  155. info.levels0[lnum] = adjusted
  156. info.levels[lnum] = prefix .. tostring(clamped)
  157. leave_prev = leave_line
  158. level0_prev = adjusted
  159. end
  160. if callback then
  161. callback()
  162. end
  163. end)
  164. end
  165. local M = {}
  166. ---@type table<integer,TS.FoldInfo>
  167. local foldinfos = {}
  168. local group = api.nvim_create_augroup('nvim.treesitter.fold', {})
  169. --- Update the folds in the windows that contain the buffer and use expr foldmethod (assuming that
  170. --- the user doesn't use different foldexpr for the same buffer).
  171. ---
  172. --- Nvim usually automatically updates folds when text changes, but it doesn't work here because
  173. --- FoldInfo update is scheduled. So we do it manually.
  174. ---@package
  175. ---@param srow integer
  176. ---@param erow integer 0-indexed, exclusive
  177. function FoldInfo:foldupdate(bufnr, srow, erow)
  178. if self.foldupdate_range then
  179. edit_range(self.foldupdate_range, srow, erow, erow)
  180. else
  181. self.foldupdate_range = { srow, erow }
  182. end
  183. if api.nvim_get_mode().mode:match('^i') then
  184. -- foldUpdate() is guarded in insert mode. So update folds on InsertLeave
  185. if #(api.nvim_get_autocmds({
  186. group = group,
  187. buffer = bufnr,
  188. })) > 0 then
  189. return
  190. end
  191. api.nvim_create_autocmd('InsertLeave', {
  192. group = group,
  193. buffer = bufnr,
  194. once = true,
  195. callback = function()
  196. self:do_foldupdate(bufnr)
  197. end,
  198. })
  199. return
  200. end
  201. self:do_foldupdate(bufnr)
  202. end
  203. ---@package
  204. function FoldInfo:do_foldupdate(bufnr)
  205. -- InsertLeave is not executed when <C-C> is used for exiting the insert mode, leaving
  206. -- do_foldupdate untouched. If another execution of foldupdate consumes foldupdate_range, the
  207. -- InsertLeave do_foldupdate gets nil foldupdate_range. In that case, skip the update. This is
  208. -- correct because the update that consumed the range must have incorporated the range that
  209. -- InsertLeave meant to update.
  210. if not self.foldupdate_range then
  211. return
  212. end
  213. local srow, erow = self.foldupdate_range[1], self.foldupdate_range[2]
  214. self.foldupdate_range = nil
  215. for _, win in ipairs(vim.fn.win_findbuf(bufnr)) do
  216. if vim.wo[win].foldmethod == 'expr' then
  217. vim._foldupdate(win, srow, erow)
  218. end
  219. end
  220. end
  221. --- Schedule a function only if bufnr is loaded.
  222. --- We schedule fold level computation for the following reasons:
  223. --- * queries seem to use the old buffer state in on_bytes for some unknown reason;
  224. --- * to avoid textlock;
  225. --- * to avoid infinite recursion:
  226. --- compute_folds_levels → parse → _do_callback → on_changedtree → compute_folds_levels.
  227. ---@param bufnr integer
  228. ---@param fn function
  229. local function schedule_if_loaded(bufnr, fn)
  230. vim.schedule(function()
  231. if not api.nvim_buf_is_loaded(bufnr) then
  232. return
  233. end
  234. fn()
  235. end)
  236. end
  237. ---@param bufnr integer
  238. ---@param foldinfo TS.FoldInfo
  239. ---@param tree_changes Range4[]
  240. local function on_changedtree(bufnr, foldinfo, tree_changes)
  241. schedule_if_loaded(bufnr, function()
  242. local srow_upd, erow_upd ---@type integer?, integer?
  243. local max_erow = api.nvim_buf_line_count(bufnr)
  244. -- TODO(ribru17): Replace this with a proper .all() awaiter once #19624 is resolved
  245. local iterations = 0
  246. for _, change in ipairs(tree_changes) do
  247. local srow, _, erow, ecol = Range.unpack4(change)
  248. -- If a parser doesn't have any ranges explicitly set, treesitter will
  249. -- return a range with end_row and end_bytes with a value of UINT32_MAX,
  250. -- so clip end_row to the max buffer line.
  251. -- TODO(lewis6991): Handle this generally
  252. if erow > max_erow then
  253. erow = max_erow
  254. elseif ecol > 0 then
  255. erow = erow + 1
  256. end
  257. -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
  258. srow = math.max(srow - vim.wo.foldminlines, 0)
  259. srow_upd = srow_upd and math.min(srow_upd, srow) or srow
  260. erow_upd = erow_upd and math.max(erow_upd, erow) or erow
  261. compute_folds_levels(bufnr, foldinfo, srow, erow, function()
  262. iterations = iterations + 1
  263. if iterations == #tree_changes then
  264. foldinfo:foldupdate(bufnr, srow_upd, erow_upd)
  265. end
  266. end)
  267. end
  268. end)
  269. end
  270. ---@param bufnr integer
  271. ---@param foldinfo TS.FoldInfo
  272. ---@param start_row integer
  273. ---@param old_row integer
  274. ---@param old_col integer
  275. ---@param new_row integer
  276. ---@param new_col integer
  277. local function on_bytes(bufnr, foldinfo, start_row, start_col, old_row, old_col, new_row, new_col)
  278. -- extend the end to fully include the range
  279. local end_row_old = start_row + old_row + 1
  280. local end_row_new = start_row + new_row + 1
  281. if new_row ~= old_row then
  282. -- foldexpr can be evaluated before the scheduled callback is invoked. So it may observe the
  283. -- outdated levels, which may spuriously open the folds that didn't change. So we should shift
  284. -- folds as accurately as possible. For this to be perfectly accurate, we should track the
  285. -- actual TSNodes that account for each fold, and compare the node's range with the edited
  286. -- range. But for simplicity, we just check whether the start row is completely removed (e.g.,
  287. -- `dd`) or shifted (e.g., `o`).
  288. if new_row < old_row then
  289. if start_col == 0 and new_row == 0 and new_col == 0 then
  290. foldinfo:remove_range(start_row, start_row + (end_row_old - end_row_new))
  291. else
  292. foldinfo:remove_range(end_row_new, end_row_old)
  293. end
  294. else
  295. if start_col == 0 and old_row == 0 and old_col == 0 then
  296. foldinfo:add_range(start_row, start_row + (end_row_new - end_row_old))
  297. else
  298. foldinfo:add_range(end_row_old, end_row_new)
  299. end
  300. end
  301. if foldinfo.on_bytes_range then
  302. edit_range(foldinfo.on_bytes_range, start_row, end_row_old, end_row_new)
  303. else
  304. foldinfo.on_bytes_range = { start_row, end_row_new }
  305. end
  306. if foldinfo.foldupdate_range then
  307. edit_range(foldinfo.foldupdate_range, start_row, end_row_old, end_row_new)
  308. end
  309. -- This callback must not use on_bytes arguments, because they can be outdated when the callback
  310. -- is invoked. For example, `J` with non-zero count triggers multiple on_bytes before executing
  311. -- the scheduled callback. So we accumulate the edited ranges in `on_bytes_range`.
  312. schedule_if_loaded(bufnr, function()
  313. if not foldinfo.on_bytes_range then
  314. return
  315. end
  316. local srow, erow = foldinfo.on_bytes_range[1], foldinfo.on_bytes_range[2]
  317. foldinfo.on_bytes_range = nil
  318. -- Start from `srow - foldminlines`, because this edit may have shrunken the fold below limit.
  319. srow = math.max(srow - vim.wo.foldminlines, 0)
  320. compute_folds_levels(bufnr, foldinfo, srow, erow, function()
  321. foldinfo:foldupdate(bufnr, srow, erow)
  322. end)
  323. end)
  324. end
  325. end
  326. ---@param lnum integer|nil
  327. ---@return string
  328. function M.foldexpr(lnum)
  329. lnum = lnum or vim.v.lnum
  330. local bufnr = api.nvim_get_current_buf()
  331. if not foldinfos[bufnr] then
  332. foldinfos[bufnr] = FoldInfo.new(bufnr)
  333. api.nvim_create_autocmd('BufUnload', {
  334. buffer = bufnr,
  335. once = true,
  336. callback = function()
  337. foldinfos[bufnr] = nil
  338. end,
  339. })
  340. local parser = foldinfos[bufnr].parser
  341. if not parser then
  342. return '0'
  343. end
  344. compute_folds_levels(bufnr, foldinfos[bufnr])
  345. parser:register_cbs({
  346. on_changedtree = function(tree_changes)
  347. on_changedtree(bufnr, foldinfos[bufnr], tree_changes)
  348. end,
  349. on_bytes = function(_, _, start_row, start_col, _, old_row, old_col, _, new_row, new_col, _)
  350. on_bytes(bufnr, foldinfos[bufnr], start_row, start_col, old_row, old_col, new_row, new_col)
  351. end,
  352. on_detach = function()
  353. foldinfos[bufnr] = nil
  354. end,
  355. })
  356. end
  357. return foldinfos[bufnr].levels[lnum] or '0'
  358. end
  359. api.nvim_create_autocmd('OptionSet', {
  360. pattern = { 'foldminlines', 'foldnestmax' },
  361. desc = 'Refresh treesitter folds',
  362. callback = function()
  363. local buf = api.nvim_get_current_buf()
  364. local bufs = vim.v.option_type == 'global' and vim.tbl_keys(foldinfos)
  365. or foldinfos[buf] and { buf }
  366. or {}
  367. for _, bufnr in ipairs(bufs) do
  368. foldinfos[bufnr] = FoldInfo.new(bufnr)
  369. api.nvim_buf_call(bufnr, function()
  370. compute_folds_levels(bufnr, foldinfos[bufnr], nil, nil, function()
  371. foldinfos[bufnr]:foldupdate(bufnr, 0, api.nvim_buf_line_count(bufnr))
  372. end)
  373. end)
  374. end
  375. end,
  376. })
  377. return M