session.lua 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
  1. ---
  2. --- Nvim msgpack-RPC protocol session. Manages requests/notifications/responses.
  3. ---
  4. local uv = vim.uv
  5. local RpcStream = require('test.client.rpc_stream')
  6. --- Nvim msgpack-RPC protocol session. Manages requests/notifications/responses.
  7. ---
  8. --- @class test.Session
  9. --- @field private _pending_messages string[] Requests/notifications received from the remote end.
  10. --- @field private _rpc_stream test.RpcStream
  11. --- @field private _prepare uv.uv_prepare_t
  12. --- @field private _timer uv.uv_timer_t
  13. --- @field private _is_running boolean true during `Session:run()` scope.
  14. --- @field exec_lua_setup boolean
  15. local Session = {}
  16. Session.__index = Session
  17. if package.loaded['jit'] then
  18. -- luajit pcall is already coroutine safe
  19. Session.safe_pcall = pcall
  20. else
  21. Session.safe_pcall = require 'coxpcall'.pcall
  22. end
  23. local function resume(co, ...)
  24. local status, result = coroutine.resume(co, ...)
  25. if coroutine.status(co) == 'dead' then
  26. if not status then
  27. error(result)
  28. end
  29. return
  30. end
  31. assert(coroutine.status(co) == 'suspended')
  32. result(co)
  33. end
  34. local function coroutine_exec(func, ...)
  35. local args = { ... }
  36. local on_complete --- @type function?
  37. if #args > 0 and type(args[#args]) == 'function' then
  38. -- completion callback
  39. on_complete = table.remove(args)
  40. end
  41. resume(coroutine.create(function()
  42. local status, result, flag = Session.safe_pcall(func, unpack(args))
  43. if on_complete then
  44. coroutine.yield(function()
  45. -- run the completion callback on the main thread
  46. on_complete(status, result, flag)
  47. end)
  48. end
  49. end))
  50. end
  51. --- Creates a new msgpack-RPC session.
  52. function Session.new(stream)
  53. return setmetatable({
  54. _rpc_stream = RpcStream.new(stream),
  55. _pending_messages = {},
  56. _prepare = uv.new_prepare(),
  57. _timer = uv.new_timer(),
  58. _is_running = false,
  59. }, Session)
  60. end
  61. --- @param timeout integer?
  62. --- @return string?
  63. function Session:next_message(timeout)
  64. local function on_request(method, args, response)
  65. table.insert(self._pending_messages, { 'request', method, args, response })
  66. uv.stop()
  67. end
  68. local function on_notification(method, args)
  69. table.insert(self._pending_messages, { 'notification', method, args })
  70. uv.stop()
  71. end
  72. if self._is_running then
  73. error('Event loop already running')
  74. end
  75. if #self._pending_messages > 0 then
  76. return table.remove(self._pending_messages, 1)
  77. end
  78. -- if closed, only return pending messages
  79. if self.closed then
  80. return nil
  81. end
  82. self:_run(on_request, on_notification, timeout)
  83. return table.remove(self._pending_messages, 1)
  84. end
  85. --- Sends a notification to the RPC endpoint.
  86. function Session:notify(method, ...)
  87. self._rpc_stream:write(method, { ... })
  88. end
  89. --- Sends a request to the RPC endpoint.
  90. ---
  91. --- @param method string
  92. --- @param ... any
  93. --- @return boolean, table
  94. function Session:request(method, ...)
  95. local args = { ... }
  96. local err, result
  97. if self._is_running then
  98. err, result = self:_yielding_request(method, args)
  99. else
  100. err, result = self:_blocking_request(method, args)
  101. end
  102. if err then
  103. return false, err
  104. end
  105. return true, result
  106. end
  107. --- Processes incoming RPC requests/notifications until exhausted.
  108. ---
  109. --- TODO(justinmk): luaclient2 avoids this via uvutil.cb_wait() + uvutil.add_idle_call()?
  110. ---
  111. --- @param request_cb function Handles requests from the sever to the local end.
  112. --- @param notification_cb function Handles notifications from the sever to the local end.
  113. --- @param setup_cb function
  114. --- @param timeout number
  115. function Session:run(request_cb, notification_cb, setup_cb, timeout)
  116. --- Handles an incoming request.
  117. local function on_request(method, args, response)
  118. coroutine_exec(request_cb, method, args, function(status, result, flag)
  119. if status then
  120. response:send(result, flag)
  121. else
  122. response:send(result, true)
  123. end
  124. end)
  125. end
  126. --- Handles an incoming notification.
  127. local function on_notification(method, args)
  128. coroutine_exec(notification_cb, method, args)
  129. end
  130. self._is_running = true
  131. if setup_cb then
  132. coroutine_exec(setup_cb)
  133. end
  134. while #self._pending_messages > 0 do
  135. local msg = table.remove(self._pending_messages, 1)
  136. if msg[1] == 'request' then
  137. on_request(msg[2], msg[3], msg[4])
  138. else
  139. on_notification(msg[2], msg[3])
  140. end
  141. end
  142. self:_run(on_request, on_notification, timeout)
  143. self._is_running = false
  144. end
  145. function Session:stop()
  146. uv.stop()
  147. end
  148. function Session:close(signal)
  149. if not self._timer:is_closing() then
  150. self._timer:close()
  151. end
  152. if not self._prepare:is_closing() then
  153. self._prepare:close()
  154. end
  155. self._rpc_stream:close(signal)
  156. self.closed = true
  157. end
  158. --- Sends a request to the RPC endpoint, without blocking (schedules a coroutine).
  159. function Session:_yielding_request(method, args)
  160. return coroutine.yield(function(co)
  161. self._rpc_stream:write(method, args, function(err, result)
  162. resume(co, err, result)
  163. end)
  164. end)
  165. end
  166. --- Sends a request to the RPC endpoint, and blocks (polls event loop) until a response is received.
  167. function Session:_blocking_request(method, args)
  168. local err, result
  169. -- Invoked when a request is received from the remote end.
  170. local function on_request(method_, args_, response)
  171. table.insert(self._pending_messages, { 'request', method_, args_, response })
  172. end
  173. -- Invoked when a notification is received from the remote end.
  174. local function on_notification(method_, args_)
  175. table.insert(self._pending_messages, { 'notification', method_, args_ })
  176. end
  177. self._rpc_stream:write(method, args, function(e, r)
  178. err = e
  179. result = r
  180. uv.stop()
  181. end)
  182. -- Poll for incoming requests/notifications received from the remote end.
  183. self:_run(on_request, on_notification)
  184. return (err or self.eof_err), result
  185. end
  186. --- Polls for incoming requests/notifications received from the remote end.
  187. function Session:_run(request_cb, notification_cb, timeout)
  188. if type(timeout) == 'number' then
  189. self._prepare:start(function()
  190. self._timer:start(timeout, 0, function()
  191. uv.stop()
  192. end)
  193. self._prepare:stop()
  194. end)
  195. end
  196. self._rpc_stream:read_start(request_cb, notification_cb, function()
  197. uv.stop()
  198. --- @diagnostic disable-next-line: invisible
  199. local stderr = self._rpc_stream._stream.stderr --[[@as string?]]
  200. -- See if `ProcStream.stderr` has anything useful.
  201. stderr = '' ~= ((stderr or ''):match('^%s*(.*%S)') or '') and ' stderr:\n' .. stderr or ''
  202. self.eof_err = { 1, 'EOF was received from Nvim. Likely the Nvim process crashed.' .. stderr }
  203. end)
  204. uv.run()
  205. self._prepare:stop()
  206. self._timer:stop()
  207. self._rpc_stream:read_stop()
  208. end
  209. --- Nvim msgpack-RPC session.
  210. return Session