text_utils.lua 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. local fmt = string.format
  2. --- @class nvim.text_utils.MDNode
  3. --- @field [integer] nvim.text_utils.MDNode
  4. --- @field type string
  5. --- @field text? string
  6. local INDENTATION = 4
  7. local NBSP = string.char(160)
  8. local M = {}
  9. local function contains(t, xs)
  10. return vim.tbl_contains(xs, t)
  11. end
  12. --- @param txt string
  13. --- @param srow integer
  14. --- @param scol integer
  15. --- @param erow? integer
  16. --- @param ecol? integer
  17. --- @return string
  18. local function slice_text(txt, srow, scol, erow, ecol)
  19. local lines = vim.split(txt, '\n')
  20. if srow == erow then
  21. return lines[srow + 1]:sub(scol + 1, ecol)
  22. end
  23. if erow then
  24. -- Trim the end
  25. for _ = erow + 2, #lines do
  26. table.remove(lines, #lines)
  27. end
  28. end
  29. -- Trim the start
  30. for _ = 1, srow do
  31. table.remove(lines, 1)
  32. end
  33. lines[1] = lines[1]:sub(scol + 1)
  34. lines[#lines] = lines[#lines]:sub(1, ecol)
  35. return table.concat(lines, '\n')
  36. end
  37. --- @param text string
  38. --- @return nvim.text_utils.MDNode
  39. local function parse_md_inline(text)
  40. local parser = vim.treesitter.languagetree.new(text, 'markdown_inline')
  41. local root = parser:parse(true)[1]:root()
  42. --- @param node TSNode
  43. --- @return nvim.text_utils.MDNode?
  44. local function extract(node)
  45. local ntype = node:type()
  46. if ntype:match('^%p$') then
  47. return
  48. end
  49. --- @type table<any,any>
  50. local ret = { type = ntype }
  51. ret.text = vim.treesitter.get_node_text(node, text)
  52. local row, col = 0, 0
  53. for child, child_field in node:iter_children() do
  54. local e = extract(child)
  55. if e and ntype == 'inline' then
  56. local srow, scol = child:start()
  57. if (srow == row and scol > col) or srow > row then
  58. local t = slice_text(ret.text, row, col, srow, scol)
  59. if t and t ~= '' then
  60. table.insert(ret, { type = 'text', j = true, text = t })
  61. end
  62. end
  63. row, col = child:end_()
  64. end
  65. if child_field then
  66. ret[child_field] = e
  67. else
  68. table.insert(ret, e)
  69. end
  70. end
  71. if ntype == 'inline' and (row > 0 or col > 0) then
  72. local t = slice_text(ret.text, row, col)
  73. if t and t ~= '' then
  74. table.insert(ret, { type = 'text', text = t })
  75. end
  76. end
  77. return ret
  78. end
  79. return extract(root) or {}
  80. end
  81. --- @param text string
  82. --- @return nvim.text_utils.MDNode
  83. local function parse_md(text)
  84. local parser = vim.treesitter.languagetree.new(text, 'markdown', {
  85. injections = { markdown = '' },
  86. })
  87. local root = parser:parse(true)[1]:root()
  88. local EXCLUDE_TEXT_TYPE = {
  89. list = true,
  90. list_item = true,
  91. section = true,
  92. document = true,
  93. fenced_code_block = true,
  94. fenced_code_block_delimiter = true,
  95. }
  96. --- @param node TSNode
  97. --- @return nvim.text_utils.MDNode?
  98. local function extract(node)
  99. local ntype = node:type()
  100. if ntype:match('^%p$') or contains(ntype, { 'block_continuation' }) then
  101. return
  102. end
  103. --- @type table<any,any>
  104. local ret = { type = ntype }
  105. if not EXCLUDE_TEXT_TYPE[ntype] then
  106. ret.text = vim.treesitter.get_node_text(node, text)
  107. end
  108. if ntype == 'inline' then
  109. ret = parse_md_inline(ret.text)
  110. end
  111. for child, child_field in node:iter_children() do
  112. local e = extract(child)
  113. if child_field then
  114. ret[child_field] = e
  115. else
  116. table.insert(ret, e)
  117. end
  118. end
  119. return ret
  120. end
  121. return extract(root) or {}
  122. end
  123. --- @param x string
  124. --- @param start_indent integer
  125. --- @param indent integer
  126. --- @param text_width integer
  127. --- @return string
  128. function M.wrap(x, start_indent, indent, text_width)
  129. local words = vim.split(vim.trim(x), '%s+')
  130. local parts = { string.rep(' ', start_indent) } --- @type string[]
  131. local count = indent
  132. for i, w in ipairs(words) do
  133. if count > indent and count + #w > text_width - 1 then
  134. parts[#parts + 1] = '\n'
  135. parts[#parts + 1] = string.rep(' ', indent)
  136. count = indent
  137. elseif i ~= 1 then
  138. parts[#parts + 1] = ' '
  139. count = count + 1
  140. end
  141. count = count + #w
  142. parts[#parts + 1] = w
  143. end
  144. return (table.concat(parts):gsub('%s+\n', '\n'):gsub('\n+$', ''))
  145. end
--- Renders an MDNode tree (as produced by `parse_md`) into a list of vimdoc
--- string fragments; the caller concatenates them.
---
--- Rendering by node type:
---   - `text`, `backslash_escape`, `list_marker_dot`: text emitted as-is
---   - `inline_link`: `*<first child text>*`
---   - `shortcut_link`: `|<first child text>|`, or verbatim when it is `<...>`
---   - `emphasis`: first and last character stripped
---   - `code_span`: kept in backticks, inner spaces replaced by `NBSP`
---   - `list_marker_minus` / `list_marker_star`: `• `
---   - `inline` / `paragraph`: children rendered and wrapped with `M.wrap`
---   - `fenced_code_block`: `>` + info string, contents, then `<`
---   - `html_block`: a leading `<pre>help` and trailing `</pre>` are stripped
---   - `html_tag`: raises an error
---   - any other type: raises an error if it carries `text`, otherwise its
---     children are rendered (separated by newlines, except around `list`s)
--- @param node nvim.text_utils.MDNode
--- @param start_indent integer Indentation of the first line produced.
--- @param indent integer Indentation of subsequent (wrapped) lines.
--- @param text_width integer
--- @param level integer Recursion depth (0 at the root).
--- @param is_list? boolean NOTE(review): only forwarded to children (and set
---   for `list_item` children); never read for a decision in this function.
--- @return string[]
local function render_md(node, start_indent, indent, text_width, level, is_list)
  local parts = {} --- @type string[]
  -- For debugging: set to true to wrap each node's output in <type>...</type>.
  local add_tag = false
  -- local add_tag = true
  local ntype = node.type
  if add_tag then
    parts[#parts + 1] = '<' .. ntype .. '>'
  end
  if ntype == 'text' then
    parts[#parts + 1] = node.text
  elseif ntype == 'html_tag' then
    error('html_tag: ' .. node.text)
  elseif ntype == 'inline_link' then
    vim.list_extend(parts, { '*', node[1].text, '*' })
  elseif ntype == 'shortcut_link' then
    if node[1].text:find('^<.*>$') then
      parts[#parts + 1] = node[1].text
    else
      vim.list_extend(parts, { '|', node[1].text, '|' })
    end
  elseif ntype == 'backslash_escape' then
    parts[#parts + 1] = node.text
  elseif ntype == 'emphasis' then
    -- Drop the surrounding emphasis delimiters.
    parts[#parts + 1] = node.text:sub(2, -2)
  elseif ntype == 'code_span' then
    -- Spaces become NBSP placeholders so M.wrap cannot break inside the
    -- span; md_to_vimdoc turns them back into spaces. (gsub's second return
    -- value is discarded because it is not the last table element.)
    vim.list_extend(parts, { '`', node.text:sub(2, -2):gsub(' ', NBSP), '`' })
  elseif ntype == 'inline' then
    if #node == 0 then
      -- No child nodes: the whole text is plain text.
      local text = assert(node.text)
      parts[#parts + 1] = M.wrap(text, start_indent, indent, text_width)
    else
      for _, child in ipairs(node) do
        vim.list_extend(parts, render_md(child, start_indent, indent, text_width, level + 1))
      end
    end
  elseif ntype == 'paragraph' then
    -- Render all children first, then wrap the paragraph as a whole.
    local pparts = {}
    for _, child in ipairs(node) do
      vim.list_extend(pparts, render_md(child, start_indent, indent, text_width, level + 1))
    end
    parts[#parts + 1] = M.wrap(table.concat(pparts), start_indent, indent, text_width)
    parts[#parts + 1] = '\n'
  elseif ntype == 'code_fence_content' then
    -- Strip the trailing newline/whitespace, then indent each non-empty line.
    local lines = vim.split(node.text:gsub('\n%s*$', ''), '\n')
    local cindent = indent + INDENTATION
    if level > 3 then
      -- The tree-sitter markdown parser doesn't parse the code blocks indents
      -- correctly in lists. Fudge it!
      lines[1] = ' ' .. lines[1] -- ¯\_(ツ)_/¯
      cindent = indent - level
      local _, initial_indent = lines[1]:find('^%s*')
      initial_indent = initial_indent + cindent
      if initial_indent < indent then
        cindent = indent - INDENTATION
      end
    end
    for _, l in ipairs(lines) do
      if #l > 0 then
        parts[#parts + 1] = string.rep(' ', cindent)
        parts[#parts + 1] = l
      end
      parts[#parts + 1] = '\n'
    end
  elseif ntype == 'fenced_code_block' then
    -- Opening marker, followed by the info string (language) if present.
    parts[#parts + 1] = '>'
    for _, child in ipairs(node) do
      if child.type == 'info_string' then
        parts[#parts + 1] = child.text
        break
      end
    end
    parts[#parts + 1] = '\n'
    for _, child in ipairs(node) do
      if child.type ~= 'info_string' then
        vim.list_extend(parts, render_md(child, start_indent, indent, text_width, level + 1))
      end
    end
    -- Closing marker.
    parts[#parts + 1] = '<\n'
  elseif ntype == 'html_block' then
    -- Pass through the contents of a `<pre>help ... </pre>` block verbatim.
    local text = node.text:gsub('^<pre>help', '')
    text = text:gsub('</pre>%s*$', '')
    parts[#parts + 1] = text
  elseif ntype == 'list_marker_dot' then
    parts[#parts + 1] = node.text
  elseif contains(ntype, { 'list_marker_minus', 'list_marker_star' }) then
    parts[#parts + 1] = '• '
  elseif ntype == 'list_item' then
    parts[#parts + 1] = string.rep(' ', indent)
    -- Width of the rendered marker: 3 for ordered markers (assumes a
    -- single-digit number, e.g. "1. "), 2 for "• ".
    local offset = node[1].type == 'list_marker_dot' and 3 or 2
    for i, child in ipairs(node) do
      -- The marker and first content child continue the current line, so
      -- they get no leading indentation; later children are indented past
      -- the marker.
      local sindent = i <= 2 and 0 or (indent + offset)
      vim.list_extend(
        parts,
        render_md(child, sindent, indent + offset, text_width, level + 1, true)
      )
    end
  else
    -- Generic container: anything carrying its own text is unsupported.
    if node.text then
      error(fmt('cannot render:\n%s', vim.inspect(node)))
    end
    for i, child in ipairs(node) do
      local start_indent0 = i == 1 and start_indent or indent
      vim.list_extend(
        parts,
        render_md(child, start_indent0, indent, text_width, level + 1, is_list)
      )
      -- Separate siblings with a newline, except inside a list and before a list.
      if ntype ~= 'list' and i ~= #node then
        if (node[i + 1] or {}).type ~= 'list' then
          parts[#parts + 1] = '\n'
        end
      end
    end
  end
  if add_tag then
    parts[#parts + 1] = '</' .. ntype .. '>'
  end
  return parts
end
  271. --- @param text_width integer
  272. local function align_tags(text_width)
  273. --- @param line string
  274. --- @return string
  275. return function(line)
  276. local tag_pat = '%s*(%*.+%*)%s*$'
  277. local tags = {}
  278. for m in line:gmatch(tag_pat) do
  279. table.insert(tags, m)
  280. end
  281. if #tags > 0 then
  282. line = line:gsub(tag_pat, '')
  283. local tags_str = ' ' .. table.concat(tags, ' ')
  284. --- @type integer
  285. local conceal_offset = select(2, tags_str:gsub('%*', '')) - 2
  286. local pad = string.rep(' ', text_width - #line - #tags_str + conceal_offset)
  287. return line .. pad .. tags_str
  288. end
  289. return line
  290. end
  291. end
  292. --- @param text string
  293. --- @param start_indent integer
  294. --- @param indent integer
  295. --- @param is_list? boolean
  296. --- @return string
  297. function M.md_to_vimdoc(text, start_indent, indent, text_width, is_list)
  298. -- Add an extra newline so the parser can properly capture ending ```
  299. local parsed = parse_md(text .. '\n')
  300. local ret = render_md(parsed, start_indent, indent, text_width, 0, is_list)
  301. local lines = vim.split(table.concat(ret):gsub(NBSP, ' '), '\n')
  302. lines = vim.tbl_map(align_tags(text_width), lines)
  303. local s = table.concat(lines, '\n')
  304. -- Reduce whitespace in code-blocks
  305. s = s:gsub('\n+%s*>([a-z]+)\n', ' >%1\n')
  306. s = s:gsub('\n+%s*>\n?\n', ' >\n')
  307. return s
  308. end
  309. return M