123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198 |
- local api = vim.api
- local M = {}
- ---@class Range2
- ---@inlinedoc
- ---@field [1] integer start row
- ---@field [2] integer end row
- ---@class Range4
- ---@inlinedoc
- ---@field [1] integer start row
- ---@field [2] integer start column
- ---@field [3] integer end row
- ---@field [4] integer end column
- ---@class Range6
- ---@inlinedoc
- ---@field [1] integer start row
- ---@field [2] integer start column
- ---@field [3] integer start bytes
- ---@field [4] integer end row
- ---@field [5] integer end column
- ---@field [6] integer end bytes
- ---@alias Range Range2|Range4|Range6
- ---@private
- ---@param a_row integer
- ---@param a_col integer
- ---@param b_row integer
- ---@param b_col integer
- ---@return integer
- --- 1: a > b
- --- 0: a == b
- --- -1: a < b
- local function cmp_pos(a_row, a_col, b_row, b_col)
- if a_row == b_row then
- if a_col > b_col then
- return 1
- elseif a_col < b_col then
- return -1
- else
- return 0
- end
- elseif a_row > b_row then
- return 1
- end
- return -1
- end
- M.cmp_pos = {
- lt = function(...)
- return cmp_pos(...) == -1
- end,
- le = function(...)
- return cmp_pos(...) ~= 1
- end,
- gt = function(...)
- return cmp_pos(...) == 1
- end,
- ge = function(...)
- return cmp_pos(...) ~= -1
- end,
- eq = function(...)
- return cmp_pos(...) == 0
- end,
- ne = function(...)
- return cmp_pos(...) ~= 0
- end,
- }
- setmetatable(M.cmp_pos, { __call = cmp_pos })
- ---@private
- ---Check if a variable is a valid range object
- ---@param r any
- ---@return boolean
- function M.validate(r)
- if type(r) ~= 'table' or #r ~= 6 and #r ~= 4 then
- return false
- end
- for _, e in
- ipairs(r --[[@as any[] ]])
- do
- if type(e) ~= 'number' then
- return false
- end
- end
- return true
- end
- ---@private
- ---@param r1 Range
- ---@param r2 Range
- ---@return boolean
- function M.intercepts(r1, r2)
- local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
- local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
- -- r1 is above r2
- if M.cmp_pos.le(erow_1, ecol_1, srow_2, scol_2) then
- return false
- end
- -- r1 is below r2
- if M.cmp_pos.ge(srow_1, scol_1, erow_2, ecol_2) then
- return false
- end
- return true
- end
- ---@private
- ---@param r Range
- ---@return integer, integer, integer, integer
- function M.unpack4(r)
- if #r == 2 then
- return r[1], 0, r[2], 0
- end
- local off_1 = #r == 6 and 1 or 0
- return r[1], r[2], r[3 + off_1], r[4 + off_1]
- end
- ---@private
- ---@param r Range6
- ---@return integer, integer, integer, integer, integer, integer
- function M.unpack6(r)
- return r[1], r[2], r[3], r[4], r[5], r[6]
- end
- ---@private
- ---@param r1 Range
- ---@param r2 Range
- ---@return boolean whether r1 contains r2
- function M.contains(r1, r2)
- local srow_1, scol_1, erow_1, ecol_1 = M.unpack4(r1)
- local srow_2, scol_2, erow_2, ecol_2 = M.unpack4(r2)
- -- start doesn't fit
- if M.cmp_pos.gt(srow_1, scol_1, srow_2, scol_2) then
- return false
- end
- -- end doesn't fit
- if M.cmp_pos.lt(erow_1, ecol_1, erow_2, ecol_2) then
- return false
- end
- return true
- end
- --- @private
- --- @param source integer|string
- --- @param index integer
- --- @return integer
- local function get_offset(source, index)
- if index == 0 then
- return 0
- end
- if type(source) == 'number' then
- return api.nvim_buf_get_offset(source, index)
- end
- local byte = 0
- local next_offset = source:gmatch('()\n')
- local line = 1
- while line <= index do
- byte = next_offset() --[[@as integer]]
- line = line + 1
- end
- return byte
- end
- ---@private
- ---@param source integer|string
- ---@param range Range
- ---@return Range6
- function M.add_bytes(source, range)
- if type(range) == 'table' and #range == 6 then
- return range --[[@as Range6]]
- end
- local start_row, start_col, end_row, end_col = M.unpack4(range)
- -- TODO(vigoux): proper byte computation here, and account for EOL ?
- local start_byte = get_offset(source, start_row) + start_col
- local end_byte = get_offset(source, end_row) + end_col
- return { start_row, start_col, start_byte, end_row, end_col, end_byte }
- end
- return M
|