_range.lua 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  -- Helpers for comparing, converting and unpacking position ranges that are
  -- stored as flat integer lists of length 2, 4 or 6.
  local M = {}
  local api = vim.api

  --- Row-only range; columns are implicitly 0.
  ---@class Range2
  ---@inlinedoc
  ---@field [1] integer start row
  ---@field [2] integer end row

  --- Row/column range.
  ---@class Range4
  ---@inlinedoc
  ---@field [1] integer start row
  ---@field [2] integer start column
  ---@field [3] integer end row
  ---@field [4] integer end column

  --- Row/column range that also carries absolute byte offsets
  --- (line start offset + column).
  ---@class Range6
  ---@inlinedoc
  ---@field [1] integer start row
  ---@field [2] integer start column
  ---@field [3] integer start bytes
  ---@field [4] integer end row
  ---@field [5] integer end column
  ---@field [6] integer end bytes

  ---@alias Range Range2|Range4|Range6
  ---@private
  --- Three-way comparison of position a = (a_row, a_col) against
  --- b = (b_row, b_col): rows decide first, columns break ties.
  ---@param a_row integer
  ---@param a_col integer
  ---@param b_row integer
  ---@param b_col integer
  ---@return integer # 1 when a > b, 0 when a == b, -1 when a < b
  local function cmp_pos(a_row, a_col, b_row, b_col)
    -- On different rows the row alone determines the order.
    if a_row ~= b_row then
      return a_row > b_row and 1 or -1
    end
    -- Same row: compare columns.
    if a_col > b_col then
      return 1
    end
    if a_col < b_col then
      return -1
    end
    return 0
  end
  --- Position comparison predicates built on cmp_pos. Each takes
  --- (a_row, a_col, b_row, b_col) and returns a boolean; calling the table
  --- itself returns the raw -1/0/1 result of cmp_pos.
  M.cmp_pos = {
    lt = function(...)
      return cmp_pos(...) == -1
    end,
    le = function(...)
      return cmp_pos(...) ~= 1
    end,
    gt = function(...)
      return cmp_pos(...) == 1
    end,
    ge = function(...)
      return cmp_pos(...) ~= -1
    end,
    eq = function(...)
      return cmp_pos(...) == 0
    end,
    ne = function(...)
      return cmp_pos(...) ~= 0
    end,
  }

  -- Lua invokes __call with the called table as its first argument. Passing
  -- cmp_pos directly would shift every position argument by one (and compare
  -- a table with a number), so drop that first argument before forwarding.
  setmetatable(M.cmp_pos, {
    __call = function(_, ...)
      return cmp_pos(...)
    end,
  })
  ---@private
  --- Check whether `r` is a table of length 4 or 6 whose array elements
  --- (as visited by ipairs) are all numbers. Note that 2-element tables are
  --- rejected even though Range2 is part of the Range alias.
  ---@param r any
  ---@return boolean
  function M.validate(r)
    if type(r) ~= 'table' then
      return false
    end
    local len = #r
    if len ~= 4 and len ~= 6 then
      return false
    end
    for _, item in ipairs(r --[[@as any[] ]]) do
      if type(item) ~= 'number' then
        return false
      end
    end
    return true
  end
  ---@private
  --- Whether ranges `r1` and `r2` overlap. Ranges that only touch (one ends
  --- exactly where the other starts) do not count as overlapping.
  ---@param r1 Range
  ---@param r2 Range
  ---@return boolean
  function M.intercepts(r1, r2)
    local a_srow, a_scol, a_erow, a_ecol = M.unpack4(r1)
    local b_srow, b_scol, b_erow, b_ecol = M.unpack4(r2)
    local pos = M.cmp_pos
    -- Disjoint when r1 ends at/before r2's start (r1 above r2) or r1 starts
    -- at/after r2's end (r1 below r2).
    local disjoint = pos.le(a_erow, a_ecol, b_srow, b_scol)
      or pos.ge(a_srow, a_scol, b_erow, b_ecol)
    return not disjoint
  end
  ---@private
  --- Normalize any Range form to (start row, start col, end row, end col).
  --- A 2-element range gets column 0 at both ends; a 6-element range has its
  --- byte fields (indices 3 and 6) skipped.
  ---@param r Range
  ---@return integer, integer, integer, integer
  function M.unpack4(r)
    local len = #r
    if len == 2 then
      return r[1], 0, r[2], 0
    elseif len == 6 then
      return r[1], r[2], r[4], r[5]
    end
    return r[1], r[2], r[3], r[4]
  end
  ---@private
  --- Spread a 6-element range into its start (row, col, bytes) followed by
  --- its end (row, col, bytes).
  ---@param r Range6
  ---@return integer, integer, integer, integer, integer, integer
  function M.unpack6(r)
    local srow, scol, sbyte = r[1], r[2], r[3]
    local erow, ecol, ebyte = r[4], r[5], r[6]
    return srow, scol, sbyte, erow, ecol, ebyte
  end
  ---@private
  --- Whether range `r1` fully encloses range `r2`: r1 starts at or before r2
  --- and ends at or after it (shared endpoints count as contained).
  ---@param r1 Range
  ---@param r2 Range
  ---@return boolean whether r1 contains r2
  function M.contains(r1, r2)
    local outer_srow, outer_scol, outer_erow, outer_ecol = M.unpack4(r1)
    local inner_srow, inner_scol, inner_erow, inner_ecol = M.unpack4(r2)
    return M.cmp_pos.le(outer_srow, outer_scol, inner_srow, inner_scol)
      and M.cmp_pos.ge(outer_erow, outer_ecol, inner_erow, inner_ecol)
  end
  --- @private
  --- 0-based byte offset at which row `index` starts in `source`.
  --- Numeric sources are passed to nvim_buf_get_offset(); string sources are
  --- scanned for newline characters. When the string has fewer than `index`
  --- newlines (e.g. `index` is the row just past a final line that has no
  --- trailing newline), the string length is returned rather than nil.
  --- @param source integer|string buffer handle or source text
  --- @param index integer 0-based row
  --- @return integer
  local function get_offset(source, index)
    if index == 0 then
      return 0
    end
    if type(source) == 'number' then
      return api.nvim_buf_get_offset(source, index)
    end
    local byte = 0
    local next_offset = source:gmatch('()\n')
    for _ = 1, index do
      -- `()` captures the 1-based position of the '\n', which equals the
      -- 0-based offset of the first byte of the following row.
      local pos = next_offset()
      if pos == nil then
        -- Ran out of newlines: clamp to the end of the text instead of
        -- returning nil (which callers would then use in arithmetic).
        return #source
      end
      byte = pos
    end
    return byte
  end
  ---@private
  --- Return `range` as a 6-element range. A table that already has 6
  --- elements is returned as-is; otherwise each endpoint's byte value is
  --- computed as get_offset(source, row) + column.
  ---@param source integer|string
  ---@param range Range
  ---@return Range6
  function M.add_bytes(source, range)
    local already_range6 = type(range) == 'table' and #range == 6
    if already_range6 then
      return range --[[@as Range6]]
    end
    local srow, scol, erow, ecol = M.unpack4(range)
    -- TODO(vigoux): proper byte computation here, and account for EOL ?
    local sbyte = get_offset(source, srow) + scol
    local ebyte = get_offset(source, erow) + ecol
    return { srow, scol, sbyte, erow, ecol, ebyte }
  end

  return M