-- session.lua (5.1 KB)
  1. local uv = vim.uv
  2. local MsgpackRpcStream = require('test.client.msgpack_rpc_stream')
  --- Msgpack-RPC client session with an Nvim process (test client).
  ---
  --- Incoming traffic that cannot be handled immediately is queued in
  --- `_pending_messages` as `{ 'request', method, args, response }` or
  --- `{ 'notification', method, args }` entries.
  --- @class test.Session
  --- @field private _pending_messages table[]
  --- @field private _msgpack_rpc_stream test.MsgpackRpcStream
  --- @field private _prepare uv.uv_prepare_t
  --- @field private _timer uv.uv_timer_t
  --- @field private _is_running boolean
  --- @field exec_lua_setup boolean
  --- @field closed boolean? set to true by close()
  --- @field eof_err table? `{ 1, message }`, set once the stream reports EOF
  local Session = {}
  Session.__index = Session

  -- `safe_pcall` must be usable inside coroutines: functions run by
  -- coroutine_exec are called through it and may yield (via request()).
  if package.loaded['jit'] then
    -- luajit pcall is already coroutine safe
    Session.safe_pcall = pcall
  else
    -- PUC Lua 5.1's pcall cannot be yielded across; use coxpcall's version
    Session.safe_pcall = require 'coxpcall'.pcall
  end
  --- Resumes `co` with the given values and drives the yield protocol.
  ---
  --- A coroutine driven by this helper either finishes, or yields a single
  --- function; that function is then called with the coroutine itself and is
  --- responsible for resuming it later (e.g. from an RPC response callback).
  --- An error that killed the coroutine is re-raised in the caller.
  --- @param co thread
  --- @param ... any values delivered to the coroutine
  local function resume(co, ...)
    local ok, yielded = coroutine.resume(co, ...)
    local state = coroutine.status(co)

    if state ~= 'dead' then
      -- Still alive: it must have yielded its continuation function.
      assert(state == 'suspended')
      yielded(co)
      return
    end

    if not ok then
      error(yielded)
    end
  end
  --- Runs `func(...)` in a new coroutine, protected by `Session.safe_pcall`.
  ---
  --- If the last argument is a function, it is treated as a completion
  --- callback. It is not passed to `func`; instead it is called as
  --- `on_complete(status, result, flag)` with the results of the protected
  --- call. The coroutine yields a function that performs this call, so
  --- resume() runs it outside the coroutine. That coroutine is then never
  --- resumed again.
  --- @param func function
  --- @param ... any arguments for `func`, optionally followed by a completion callback
  local function coroutine_exec(func, ...)
    -- Count with select('#') rather than `#args`: the arguments may contain
    -- nils (holes), for which the length operator is unspecified.
    local nargs = select('#', ...)
    local args = { ... }
    local on_complete --- @type function?

    if nargs > 0 and type(args[nargs]) == 'function' then
      -- completion callback
      on_complete = args[nargs]
      args[nargs] = nil
      nargs = nargs - 1
    end

    -- `unpack` is a global in Lua 5.1/LuaJIT; Lua 5.2+ moved it to `table`.
    local unpack_args = unpack or table.unpack

    resume(coroutine.create(function()
      -- Explicit bounds keep nil arguments (including trailing ones) in place.
      local status, result, flag = Session.safe_pcall(func, unpack_args(args, 1, nargs))
      if on_complete then
        coroutine.yield(function()
          -- run the completion callback on the main thread
          on_complete(status, result, flag)
        end)
      end
    end))
  end
  --- Creates a session on top of `stream`.
  ---
  --- Wraps the stream in a MsgpackRpcStream and allocates the libuv prepare
  --- and timer handles used by `_run` for timeouts.
  --- @param stream any transport passed to MsgpackRpcStream.new (TODO: confirm expected type)
  --- @return test.Session
  function Session.new(stream)
    local session = {
      _is_running = false,
      _pending_messages = {},
      _msgpack_rpc_stream = MsgpackRpcStream.new(stream),
      _prepare = uv.new_prepare(),
      _timer = uv.new_timer(),
    }
    return setmetatable(session, Session)
  end
  --- Returns the next incoming message.
  ---
  --- If none is queued, runs the event loop until one arrives.
  --- Messages are `{ 'request', method, args, response }` or
  --- `{ 'notification', method, args }` tables.
  --- Raises an error if called while run() is active.
  --- @param timeout integer? milliseconds to wait before giving up
  --- @return table? message, or nil if the loop stopped without one (timeout, EOF) or the session is closed
  function Session:next_message(timeout)
    if self._is_running then
      error('Event loop already running')
    end
    if #self._pending_messages > 0 then
      return table.remove(self._pending_messages, 1)
    end
    -- if closed, only return pending messages
    if self.closed then
      return nil
    end

    -- Queue whatever arrives and stop the loop so it can be returned below.
    local function on_request(method, args, response)
      table.insert(self._pending_messages, { 'request', method, args, response })
      uv.stop()
    end
    local function on_notification(method, args)
      table.insert(self._pending_messages, { 'notification', method, args })
      uv.stop()
    end

    self:_run(on_request, on_notification, timeout)
    return table.remove(self._pending_messages, 1)
  end
  --- Writes an RPC notification to the stream.
  ---
  --- No response callback is registered for it.
  --- @param method string
  --- @param ... any notification arguments, packed into one array
  function Session:notify(method, ...)
    local params = { ... }
    local stream = self._msgpack_rpc_stream
    stream:write(method, params)
  end
  --- Sends an RPC request and waits for its response.
  ---
  --- While run() is active, the calling coroutine is suspended until the
  --- response arrives. Otherwise, the event loop is spun until it arrives.
  --- @param method string
  --- @param ... any request arguments
  --- @return boolean ok false when the request failed
  --- @return any result_or_err the result on success, the error otherwise
  function Session:request(method, ...)
    local args = { ... }
    local perform = self._is_running and self._yielding_request or self._blocking_request
    local err, result = perform(self, method, args)
    if not err then
      return true, result
    end
    return false, err
  end
  --- Runs the event loop, dispatching incoming traffic to the given callbacks.
  ---
  --- Every callback runs in its own coroutine (via coroutine_exec), so it may
  --- call request(), which yields while the loop is running. Messages queued
  --- earlier are dispatched first.
  --- @param request_cb function? called as request_cb(method, args). Its first two
  ---   return values are passed to response:send(result, flag). If it raises, the
  ---   error is sent as response:send(err, true).
  --- @param notification_cb function? called as notification_cb(method, args)
  --- @param setup_cb function? run once, in a coroutine, before dispatching
  --- @param timeout integer? milliseconds after which the loop is stopped
  function Session:run(request_cb, notification_cb, setup_cb, timeout)
    local function on_request(method, args, response)
      coroutine_exec(request_cb, method, args, function(status, result, flag)
        if status then
          response:send(result, flag)
        else
          response:send(result, true)
        end
      end)
    end

    local function on_notification(method, args)
      coroutine_exec(notification_cb, method, args)
    end

    self._is_running = true
    -- Always clear the flag, even if something below raises; otherwise the
    -- session stays "running": next_message() refuses to work and request()
    -- would try to yield outside a coroutine.
    local ok, err = pcall(function()
      if setup_cb then
        coroutine_exec(setup_cb)
      end
      while #self._pending_messages > 0 do
        local msg = table.remove(self._pending_messages, 1)
        if msg[1] == 'request' then
          on_request(msg[2], msg[3], msg[4])
        else
          on_notification(msg[2], msg[3])
        end
      end
      self:_run(on_request, on_notification, timeout)
    end)
    self._is_running = false

    if not ok then
      -- level 0: rethrow the original message without adding a new position
      error(err, 0)
    end
  end
  --- Asks the libuv event loop to stop, so that a running uv.run() returns.
  function Session.stop(_)
    uv.stop()
  end
  --- Closes the timer and prepare handles (unless already closing) and the RPC stream.
  ---
  --- Marks the session as closed. After this, next_message() only returns
  --- already-queued messages.
  --- @param signal any forwarded to the stream's close()
  function Session:close(signal)
    local function close_handle(handle)
      if not handle:is_closing() then
        handle:close()
      end
    end

    close_handle(self._timer)
    close_handle(self._prepare)
    self._msgpack_rpc_stream:close(signal)
    self.closed = true
  end
  --- Sends a request from inside a coroutine and suspends it until the response arrives.
  ---
  --- The yielded function is invoked by resume() with the coroutine. It writes
  --- the request, and the response callback resumes the coroutine with
  --- `(err, result)`.
  --- @param method string
  --- @param args table
  --- @return any err
  --- @return any result
  function Session:_yielding_request(method, args)
    local function send_then_wait(co)
      local function on_response(err, result)
        resume(co, err, result)
      end
      self._msgpack_rpc_stream:write(method, args, on_response)
    end
    return coroutine.yield(send_then_wait)
  end
  --- Sends a request and spins the event loop until its response arrives.
  ---
  --- Requests and notifications received meanwhile are queued in
  --- `_pending_messages` for later handling by next_message()/run().
  --- @param method string
  --- @param args table
  --- @return any err the response error, or `self.eof_err` when none was set
  --- @return any result
  function Session:_blocking_request(method, args)
    local err, result

    local function enqueue_request(method_, args_, response)
      local queue = self._pending_messages
      queue[#queue + 1] = { 'request', method_, args_, response }
    end
    local function enqueue_notification(method_, args_)
      local queue = self._pending_messages
      queue[#queue + 1] = { 'notification', method_, args_ }
    end
    local function on_response(e, r)
      err, result = e, r
      uv.stop()
    end

    self._msgpack_rpc_stream:write(method, args, on_response)
    self:_run(enqueue_request, enqueue_notification)
    return (err or self.eof_err), result
  end
  --- Spins the event loop with the RPC stream reading.
  ---
  --- Returns when uv.stop() is called (by a callback, the optional timeout, or
  --- end of stream) or when the loop runs out of active handles.
  --- @param request_cb function handler for incoming requests
  --- @param notification_cb function handler for incoming notifications
  --- @param timeout integer? milliseconds; ignored unless it is a number
  function Session:_run(request_cb, notification_cb, timeout)
    if type(timeout) == 'number' then
      -- Start the timer from a one-shot prepare callback, i.e. during the first
      -- iteration of uv.run() below. Presumably this is so the wait is measured
      -- from when the loop is actually running.
      local function arm_timer()
        self._timer:start(timeout, 0, function()
          uv.stop()
        end)
        self._prepare:stop()
      end
      self._prepare:start(arm_timer)
    end

    local function on_eof()
      uv.stop()
      self.eof_err = { 1, 'EOF was received from Nvim. Likely the Nvim process crashed.' }
    end
    self._msgpack_rpc_stream:read_start(request_cb, notification_cb, on_eof)

    uv.run()

    -- Undo everything armed above, however the loop was stopped.
    self._prepare:stop()
    self._timer:stop()
    self._msgpack_rpc_stream:read_stop()
  end
  181. return Session