client.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. package rpc
  2. import (
  3. "io"
  4. "fmt"
  5. "net"
  6. "time"
  7. "errors"
  8. "kumachan/standalone/rx"
  9. "kumachan/standalone/rpc/kmd"
  10. )
  11. type ClientOptions struct {
  12. Connection net.Conn
  13. DebugOutput io.Writer
  14. ConstructorArgument kmd.Object
  15. InstanceConsumer func(*ClientInstance) rx.Observable
  16. Limits
  17. StreamKmdApi
  18. }
  // Client runs the client side of the RPC protocol for service over
  // opts.Connection and returns an observable for the whole session.
  //
  // The handshake happens in this order:
  //  1. send the service confirmation (the service identifier);
  //  2. send opts.ConstructorArgument;
  //  3. wait for the server's instance-created notification.
  //
  // Then a ClientInstance is handed to opts.InstanceConsumer, and server
  // messages are processed until an error occurs. Any step failing is
  // reported to the connection via conn.Fatal.
  //
  // Reads and writes on the connection use opts.RecvTimeout and
  // opts.SendTimeout. Errors emitted by the connection handler are logged
  // and then re-thrown unchanged.
  func Client(service ServiceInterface, opts *ClientOptions) rx.Observable {
  	var raw_conn = opts.Connection
  	var logger = ClientLogger{
  		LocalAddr:  raw_conn.LocalAddr(),
  		RemoteAddr: raw_conn.RemoteAddr(),
  		Output:     opts.DebugOutput,
  	}
  	// run performs the handshake and the message loop, stopping at the
  	// first error.
  	var run = func(conn *rx.WrappedConnection) error {
  		if err := sendServiceConfirmation(conn, service); err != nil {
  			return err
  		}
  		if err := sendConstructorArgument(conn, service, opts); err != nil {
  			return err
  		}
  		if err := receiveInstanceCreated(conn); err != nil {
  			return err
  		}
  		var instance = createClientInstance(conn, logger, service, opts)
  		consumeClientInstance(instance, conn, opts)
  		return clientProcessMessages(instance, conn, opts)
  	}
  	var timeout = rx.TimeoutPair{
  		ReadTimeout:  opts.RecvTimeout,
  		WriteTimeout: opts.SendTimeout,
  	}
  	return rx.NewConnectionHandler(raw_conn, timeout, func(conn *rx.WrappedConnection) {
  		if err := run(conn); err != nil {
  			conn.Fatal(err)
  		}
  	}).Catch(func(thrown rx.Object) rx.Observable {
  		// rx errors are untyped objects. Use a checked assertion so that a
  		// non-error value is still logged instead of panicking here.
  		var err, is_error = thrown.(error)
  		if !is_error {
  			err = fmt.Errorf("%v", thrown)
  		}
  		logger.LogError(err)
  		return rx.Throw(thrown)
  	})
  }
  54. type ClientInstance struct {
  55. connection *rx.WrappedConnection
  56. requester *rx.Worker
  57. logger ClientLogger
  58. service ServiceInterface
  59. options *ClientOptions
  60. state ClientInstanceState
  61. }
  62. type ClientInstanceState struct {
  63. mutator *rx.Worker
  64. calls map[uint64] Call
  65. nextCallId uint64
  66. }
  67. type Call struct {
  68. sender rx.Sender
  69. retType *kmd.Type
  70. }
  71. func createClientInstance(conn *rx.WrappedConnection, logger ClientLogger, service ServiceInterface, opts *ClientOptions) *ClientInstance {
  72. return &ClientInstance {
  73. connection: conn,
  74. requester: rx.CreateWorker(),
  75. logger: logger,
  76. service: service,
  77. options: opts,
  78. state: ClientInstanceState {
  79. mutator: rx.CreateWorker(),
  80. calls: make(map[uint64] Call),
  81. nextCallId: 0,
  82. },
  83. }
  84. }
  // Call invokes method_name on the remote instance with argument arg and
  // returns an observable of the call's results.
  //
  // It panics if the service interface declares no such method, because that
  // is a programming error by the caller.
  //
  // On subscription, a new call id is registered on the mutator worker. The
  // request is then written on the requester worker: a MSG_CALL header
  // (MSG_CALL_MULTI for multi-value methods) carrying the method name,
  // followed by the encoded argument. Values, errors and completion arrive
  // later through clientProcessMessages and next/error/complete.
  //
  // If writing the request fails, three things happen: the connection is
  // failed, the error is logged, and the same error is delivered to this
  // call's subscriber so it does not wait for a reply that cannot come.
  func (instance *ClientInstance) Call(method_name string, arg kmd.Object) rx.Observable {
  	var method, exists = instance.service.Methods[method_name]
  	if !exists {
  		panic(fmt.Sprintf("rpc: method %q is not declared by the service interface", method_name))
  	}
  	var msg_kind = MSG_CALL
  	if method.MultiValue {
  		msg_kind = MSG_CALL_MULTI
  	}
  	return rx.NewSyncWithSender(func(sender rx.Sender) {
  		instance.state.mutator.Do(func() {
  			var id = instance.state.nextCallId
  			instance.state.nextCallId += 1
  			instance.state.calls[id] = Call{
  				sender:  sender,
  				retType: method.RetType,
  			}
  			instance.requester.Do(func() {
  				var conn = instance.connection
  				var method_name_bin = ([]byte)(method_name)
  				var err = sendMessage(msg_kind, id, method_name_bin, conn)
  				if err == nil {
  					err = sendCallArgument(arg, method, conn, instance.options)
  				}
  				if err != nil {
  					var wrapped = fmt.Errorf("error sending call request: %w", err)
  					conn.Fatal(wrapped)
  					instance.logger.LogError(wrapped)
  					// No reply will ever arrive for this id; fail the call itself.
  					instance.error(id, wrapped)
  				}
  			})
  		})
  	})
  }
  // lookupCall returns the pending call registered under id.
  //
  // If no such entry exists, the server has referenced a call id that is not
  // (or no longer) in the calls map. In that case the connection is failed
  // via Fatal, the error is logged, and ok is false.
  //
  // It reads instance.state.calls, so it should run inside a mutator task,
  // as every caller in this file does.
  func (instance *ClientInstance) lookupCall(id uint64) (Call, bool) {
  	var call, exists = instance.state.calls[id]
  	if !exists {
  		var err = fmt.Errorf("inconsistent server message: call %d does not exist", id)
  		instance.connection.Fatal(err)
  		instance.logger.LogError(err)
  		return Call{}, false
  	}
  	return call, true
  }
  135. func (instance *ClientInstance) getCallReturnValueType(id uint64) *kmd.Type {
  136. var wait = make(chan *kmd.Type)
  137. instance.state.mutator.Do(func() {
  138. var call, ok = instance.lookupCall(id)
  139. if !(ok) { return }
  140. wait <- call.retType
  141. })
  142. return <- wait
  143. }
  144. func (instance *ClientInstance) next(id uint64, value kmd.Object) {
  145. instance.state.mutator.Do(func() {
  146. var call, ok = instance.lookupCall(id)
  147. if !(ok) { return }
  148. call.sender.Next(value)
  149. })
  150. }
  151. func (instance *ClientInstance) error(id uint64, e error) {
  152. instance.state.mutator.Do(func() {
  153. var call, ok = instance.lookupCall(id)
  154. if !(ok) { return }
  155. call.sender.Error(e)
  156. })
  157. }
  158. func (instance *ClientInstance) complete(id uint64) {
  159. instance.state.mutator.Do(func() {
  160. var call, ok = instance.lookupCall(id)
  161. if !(ok) { return }
  162. delete(instance.state.calls, id)
  163. call.sender.Complete()
  164. })
  165. }
  // sendServiceConfirmation writes a MSG_SERVICE message to conn. Its payload
  // is the string form of the service identifier.
  func sendServiceConfirmation(conn io.Writer, service ServiceInterface) error {
  	// All-ones id: presumably a placeholder, since this message does not
  	// belong to any call.
  	const placeholderId = ^uint64(0)
  	var payload = []byte(service.Identifier.String())
  	return sendMessage(MSG_SERVICE, placeholderId, payload, conn)
  }
  170. func sendConstructorArgument(conn io.Writer, service ServiceInterface, opts *ClientOptions) error {
  171. var ctor = service.Constructor
  172. var arg = opts.ConstructorArgument
  173. return sendObject(arg, ctor.ArgType, conn, opts.StreamKmdApi)
  174. }
  // receiveInstanceCreated reads the server's reply to the handshake.
  //
  // It returns:
  //   - nil for MSG_CREATED;
  //   - the deserialized server error for MSG_ERROR;
  //   - an error for any other message kind, or when the read itself fails.
  func receiveInstanceCreated(conn io.Reader) error {
  	kind, _, payload, err := receiveMessage(conn)
  	if err != nil {
  		return fmt.Errorf("failed to receive instance created notification: %w", err)
  	}
  	switch kind {
  	case MSG_CREATED:
  		return nil
  	case MSG_ERROR:
  		return deserializeError(payload)
  	default:
  		return fmt.Errorf("unexpected message kind: %s", kind)
  	}
  }
  189. func consumeClientInstance(instance *ClientInstance, conn *rx.WrappedConnection, opts *ClientOptions) {
  190. var consume = opts.InstanceConsumer(instance)
  191. var consume_and_dispose = consume.WaitComplete().Then(func(_ rx.Object) rx.Observable {
  192. _ = conn.Close()
  193. return rx.Noop()
  194. })
  195. rx.Schedule(consume_and_dispose, conn.Scheduler(), rx.Receiver {
  196. Context: conn.Context(),
  197. })
  198. }
  199. func sendCallArgument(arg kmd.Object, method ServiceMethodInterface, conn *rx.WrappedConnection, opts *ClientOptions) error {
  200. return sendObject(arg, method.ArgType, conn, opts.StreamKmdApi)
  201. }
  // clientProcessMessages reads server messages from conn and dispatches each
  // one to the pending call it names:
  //
  //   - MSG_VALUE: an object of the call's return type follows on the stream.
  //     It is decoded, limited to opts.RecvMaxObjectSize, and passed to next.
  //   - MSG_ERROR: the payload is deserialized and passed to error.
  //   - MSG_COMPLETE: the call is completed.
  //
  // If opts.RecvInterval is non-zero, the loop sleeps that long before each
  // read. It only returns, with an error, when any of these happens:
  //   - reading or decoding fails;
  //   - a value arrives for an unknown call;
  //   - an unknown message kind arrives.
  func clientProcessMessages(instance *ClientInstance, conn *rx.WrappedConnection, opts *ClientOptions) error {
  	var interval = opts.RecvInterval
  	for {
  		if interval != 0 {
  			time.Sleep(interval)
  		}
  		var kind, id, payload, err = receiveMessage(conn)
  		if err != nil {
  			return fmt.Errorf("error receiving server message: %w", err)
  		}
  		switch kind {
  		case MSG_VALUE:
  			var ret_type = instance.getCallReturnValueType(id)
  			if ret_type == nil {
  				// Unknown call id: lookupCall has already failed the
  				// connection. Stop instead of decoding with no type.
  				// NOTE(review): assumes declared return types are never nil.
  				return fmt.Errorf("inconsistent server message: call %d does not exist", id)
  			}
  			var limit = opts.RecvMaxObjectSize
  			var value, recv_err = receiveObject(ret_type, conn, limit, opts.StreamKmdApi)
  			if recv_err != nil {
  				return fmt.Errorf("error receiving value object: %w", recv_err)
  			}
  			instance.next(id, value)
  		case MSG_ERROR:
  			instance.error(id, deserializeError(payload))
  		case MSG_COMPLETE:
  			instance.complete(id)
  		default:
  			return errors.New("unknown message kind: " + string(kind))
  		}
  	}
  }