virtio_transport_common.c 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008
  1. /*
  2. * common code for virtio vsock
  3. *
  4. * Copyright (C) 2013-2015 Red Hat, Inc.
  5. * Author: Asias He <asias@redhat.com>
  6. * Stefan Hajnoczi <stefanha@redhat.com>
  7. *
  8. * This work is licensed under the terms of the GNU GPL, version 2.
  9. */
  10. #include <linux/spinlock.h>
  11. #include <linux/module.h>
  12. #include <linux/ctype.h>
  13. #include <linux/list.h>
  14. #include <linux/virtio.h>
  15. #include <linux/virtio_ids.h>
  16. #include <linux/virtio_config.h>
  17. #include <linux/virtio_vsock.h>
  18. #include <net/sock.h>
  19. #include <net/af_vsock.h>
  20. #define CREATE_TRACE_POINTS
  21. #include <trace/events/vsock_virtio_transport_common.h>
  22. /* How long to wait for graceful shutdown of a connection */
  23. #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
  24. static const struct virtio_transport *virtio_transport_get_ops(void)
  25. {
  26. const struct vsock_transport *t = vsock_core_get_transport();
  27. return container_of(t, struct virtio_transport, transport);
  28. }
/* Allocate a packet and, when @info->msg is supplied, copy @len payload
 * bytes from the user message into a freshly allocated buffer.
 *
 * All header fields are stored little-endian, as required on the wire.
 * Returns the packet on success, NULL on allocation or copy failure
 * (partial state is unwound via the goto labels).
 */
struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
			   size_t len,
			   u32 src_cid,
			   u32 src_port,
			   u32 dst_cid,
			   u32 dst_port)
{
	struct virtio_vsock_pkt *pkt;
	int err;

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	pkt->hdr.type = cpu_to_le16(info->type);
	pkt->hdr.op = cpu_to_le16(info->op);
	pkt->hdr.src_cid = cpu_to_le64(src_cid);
	pkt->hdr.dst_cid = cpu_to_le64(dst_cid);
	pkt->hdr.src_port = cpu_to_le32(src_port);
	pkt->hdr.dst_port = cpu_to_le32(dst_port);
	pkt->hdr.flags = cpu_to_le32(info->flags);
	pkt->len = len;
	pkt->hdr.len = cpu_to_le32(len);
	pkt->reply = info->reply;
	pkt->vsk = info->vsk;

	/* Control packets (no msg) carry no payload buffer */
	if (info->msg && len > 0) {
		pkt->buf = kmalloc(len, GFP_KERNEL);
		if (!pkt->buf)
			goto out_pkt;
		err = memcpy_from_msg(pkt->buf, info->msg, len);
		if (err)
			goto out;
	}

	trace_virtio_transport_alloc_pkt(src_cid, src_port,
					 dst_cid, dst_port,
					 len,
					 info->type,
					 info->op,
					 info->flags);

	return pkt;

out:
	kfree(pkt->buf);
out_pkt:
	kfree(pkt);
	return NULL;
}
EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt);
/* Build and transmit one packet described by @info on behalf of @vsk.
 *
 * The payload may be truncated: first to the default RX buffer size,
 * then to whatever send credit the peer currently grants us.  A zero
 * length OP_RW packet is suppressed entirely (returns 0).  On alloc
 * failure the reserved credit is returned to the pool.
 */
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
					  struct virtio_vsock_pkt_info *info)
{
	u32 src_cid, src_port, dst_cid, dst_port;
	struct virtio_vsock_sock *vvs;
	struct virtio_vsock_pkt *pkt;
	u32 pkt_len = info->pkt_len;

	src_cid = vm_sockets_get_local_cid();
	src_port = vsk->local_addr.svm_port;
	/* remote_cid == 0 means "use the socket's connected peer" */
	if (!info->remote_cid) {
		dst_cid	= vsk->remote_addr.svm_cid;
		dst_port = vsk->remote_addr.svm_port;
	} else {
		dst_cid = info->remote_cid;
		dst_port = info->remote_port;
	}

	vvs = vsk->trans;

	/* we can send less than pkt_len bytes */
	if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
		pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

	/* virtio_transport_get_credit might return less than pkt_len credit */
	pkt_len = virtio_transport_get_credit(vvs, pkt_len);

	/* Do not send zero length OP_RW pkt */
	if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
		return pkt_len;

	pkt = virtio_transport_alloc_pkt(info, pkt_len,
					 src_cid, src_port,
					 dst_cid, dst_port);
	if (!pkt) {
		virtio_transport_put_credit(vvs, pkt_len);
		return -ENOMEM;
	}

	virtio_transport_inc_tx_pkt(vvs, pkt);

	return virtio_transport_get_ops()->send_pkt(pkt);
}
/* Account a newly queued RX packet's bytes.
 * Caller holds vvs->rx_lock (see virtio_transport_recv_connected()).
 */
static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes += pkt->len;
}
/* Account a fully consumed RX packet: drop it from rx_bytes and advance
 * fwd_cnt, which is advertised back to the peer as freed credit.
 * Caller holds vvs->rx_lock.
 */
static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
					struct virtio_vsock_pkt *pkt)
{
	vvs->rx_bytes -= pkt->len;
	vvs->fwd_cnt += pkt->len;
}
/* Stamp an outgoing packet with our current credit state (fwd_cnt and
 * buf_alloc) so the peer can update its view of our receive window.
 * Takes tx_lock to read the counters consistently.
 */
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
	spin_lock_bh(&vvs->tx_lock);
	pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
	pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
/* Reserve up to @credit bytes of the peer's receive window.
 *
 * Available credit is the peer's buffer size minus bytes in flight
 * (tx_cnt - peer_fwd_cnt); the reservation is capped at that amount and
 * tx_cnt is advanced by what was actually granted.  Returns the number
 * of bytes reserved, which may be less than @credit (possibly 0).
 */
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	u32 ret;

	spin_lock_bh(&vvs->tx_lock);
	ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
	if (ret > credit)
		ret = credit;
	vvs->tx_cnt += ret;
	spin_unlock_bh(&vvs->tx_lock);

	return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
/* Return @credit bytes previously reserved by
 * virtio_transport_get_credit() that will not be sent after all.
 */
void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
	spin_lock_bh(&vvs->tx_lock);
	vvs->tx_cnt -= credit;
	spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
/* Send a CREDIT_UPDATE packet so the peer learns our current fwd_cnt
 * and buf_alloc (stamped into the header by the send path).
 * @hdr is currently unused; the caller passes NULL.
 */
static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
					       int type,
					       struct virtio_vsock_hdr *hdr)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
		.type = type,
		.vsk = vsk,
	};

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Copy up to @len bytes of queued RX data into @msg.
 *
 * Walks the rx_queue, consuming packets as their payload is drained;
 * partially consumed packets stay at the head with pkt->off advanced.
 * rx_lock is dropped around memcpy_to_msg() because it may sleep; this
 * is safe only because the caller holds the socket lock, so no other
 * thread can dequeue concurrently.
 *
 * Returns the number of bytes copied, or a negative error only when
 * nothing was copied before the failure.  A credit update is sent to
 * the peer on the success path to advertise the freed buffer space.
 */
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
				   struct msghdr *msg,
				   size_t len)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	struct virtio_vsock_pkt *pkt;
	size_t bytes, total = 0;
	int err = -EFAULT;

	spin_lock_bh(&vvs->rx_lock);
	while (total < len && !list_empty(&vvs->rx_queue)) {
		pkt = list_first_entry(&vvs->rx_queue,
				       struct virtio_vsock_pkt, list);

		bytes = len - total;
		if (bytes > pkt->len - pkt->off)
			bytes = pkt->len - pkt->off;

		/* sk_lock is held by caller so no one else can dequeue.
		 * Unlock rx_lock since memcpy_to_msg() may sleep.
		 */
		spin_unlock_bh(&vvs->rx_lock);

		err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
		if (err)
			goto out;

		spin_lock_bh(&vvs->rx_lock);

		total += bytes;
		pkt->off += bytes;
		if (pkt->off == pkt->len) {
			virtio_transport_dec_rx_pkt(vvs, pkt);
			list_del(&pkt->list);
			virtio_transport_free_pkt(pkt);
		}
	}
	spin_unlock_bh(&vvs->rx_lock);

	/* Send a credit pkt to peer */
	virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
					    NULL);

	return total;

out:
	/* Partial success beats reporting the error */
	if (total)
		err = total;
	return err;
}
  201. ssize_t
  202. virtio_transport_stream_dequeue(struct vsock_sock *vsk,
  203. struct msghdr *msg,
  204. size_t len, int flags)
  205. {
  206. if (flags & MSG_PEEK)
  207. return -EOPNOTSUPP;
  208. return virtio_transport_stream_do_dequeue(vsk, msg, len);
  209. }
  210. EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
			       struct msghdr *msg,
			       size_t len, int flags)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
/* Number of payload bytes currently queued for reading, sampled under
 * rx_lock.
 */
s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->rx_lock);
	bytes = vvs->rx_bytes;
	spin_unlock_bh(&vvs->rx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
  229. static s64 virtio_transport_has_space(struct vsock_sock *vsk)
  230. {
  231. struct virtio_vsock_sock *vvs = vsk->trans;
  232. s64 bytes;
  233. bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
  234. if (bytes < 0)
  235. bytes = 0;
  236. return bytes;
  237. }
/* Locked wrapper around virtio_transport_has_space(): free send credit
 * sampled under tx_lock.
 */
s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
	struct virtio_vsock_sock *vvs = vsk->trans;
	s64 bytes;

	spin_lock_bh(&vvs->tx_lock);
	bytes = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);

	return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
/* Allocate and initialize the per-socket transport state (vsk->trans).
 *
 * A child socket created by accept() inherits buffer-size settings and
 * the peer credit snapshot from its parent @psk; a fresh socket starts
 * from the compile-time defaults.  Returns 0 or -ENOMEM.
 */
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
				    struct vsock_sock *psk)
{
	struct virtio_vsock_sock *vvs;

	vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
	if (!vvs)
		return -ENOMEM;

	vsk->trans = vvs;
	vvs->vsk = vsk;
	if (psk) {
		struct virtio_vsock_sock *ptrans = psk->trans;

		vvs->buf_size	= ptrans->buf_size;
		vvs->buf_size_min = ptrans->buf_size_min;
		vvs->buf_size_max = ptrans->buf_size_max;
		vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
	} else {
		vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
		vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
		vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
	}

	/* buf_alloc is what we advertise to the peer as our RX window */
	vvs->buf_alloc = vvs->buf_size;

	spin_lock_init(&vvs->rx_lock);
	spin_lock_init(&vvs->tx_lock);
	INIT_LIST_HEAD(&vvs->rx_queue);

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
  275. u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
  276. {
  277. struct virtio_vsock_sock *vvs = vsk->trans;
  278. return vvs->buf_size;
  279. }
  280. EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
  281. u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
  282. {
  283. struct virtio_vsock_sock *vvs = vsk->trans;
  284. return vvs->buf_size_min;
  285. }
  286. EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
  287. u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
  288. {
  289. struct virtio_vsock_sock *vvs = vsk->trans;
  290. return vvs->buf_size_max;
  291. }
  292. EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
/* Set the socket's buffer size, capped at the global maximum.
 *
 * The min/max bounds are widened so the new value always stays inside
 * [buf_size_min, buf_size_max].  buf_alloc (the RX window we advertise
 * to the peer) follows the new size.
 */
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size_min)
		vvs->buf_size_min = val;
	if (val > vvs->buf_size_max)
		vvs->buf_size_max = val;
	vvs->buf_size = val;
	vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
/* Raise the socket's minimum buffer size (capped at the global max).
 * buf_size itself is pulled up if it would fall below the new minimum.
 */
void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val > vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
/* Lower the socket's maximum buffer size (capped at the global max).
 * buf_size itself is pulled down if it would exceed the new maximum.
 */
void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
	struct virtio_vsock_sock *vvs = vsk->trans;

	if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
		val = VIRTIO_VSOCK_MAX_BUF_SIZE;
	if (val < vvs->buf_size)
		vvs->buf_size = val;
	vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
  326. int
  327. virtio_transport_notify_poll_in(struct vsock_sock *vsk,
  328. size_t target,
  329. bool *data_ready_now)
  330. {
  331. if (vsock_stream_has_data(vsk))
  332. *data_ready_now = true;
  333. else
  334. *data_ready_now = false;
  335. return 0;
  336. }
  337. EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
/* Poll-out callback: writable when there is free send credit.
 *
 * NOTE(review): when free_space < 0 the output flag is left untouched;
 * virtio_transport_has_space() clamps at 0 so that branch should be
 * unreachable — confirm callers pre-initialize *space_avail_now.
 */
int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
				 size_t target,
				 bool *space_avail_now)
{
	s64 free_space;

	free_space = vsock_stream_has_space(vsk);
	if (free_space > 0)
		*space_avail_now = true;
	else if (free_space == 0)
		*space_avail_now = false;

	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
/* No-op: this transport keeps no per-receive notify state. */
int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
/* No-op: nothing to do before blocking in recv. */
int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
/* No-op: nothing to do before dequeueing received data. */
int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
	size_t target, struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
/* No-op: nothing to do after dequeueing received data. */
int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
	size_t target, ssize_t copied, bool data_read,
	struct vsock_transport_recv_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
/* No-op: this transport keeps no per-send notify state. */
int virtio_transport_notify_send_init(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
/* No-op: nothing to do before blocking in send. */
int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
/* No-op: nothing to do before enqueueing outgoing data. */
int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
	struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
/* No-op: nothing to do after enqueueing outgoing data. */
int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
	ssize_t written, struct vsock_transport_send_notify_data *data)
{
	return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
  401. u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
  402. {
  403. struct virtio_vsock_sock *vvs = vsk->trans;
  404. return vvs->buf_size;
  405. }
  406. EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
/* virtio streams are always considered active. */
bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
/* Stream connections are permitted to/from any CID and port. */
bool virtio_transport_stream_allow(u32 cid, u32 port)
{
	return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
/* Datagram sockets are not supported by the virtio transport. */
int virtio_transport_dgram_bind(struct vsock_sock *vsk,
				struct sockaddr_vm *addr)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
/* Datagram traffic is never allowed (dgram is unsupported). */
bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
	return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
  428. int virtio_transport_connect(struct vsock_sock *vsk)
  429. {
  430. struct virtio_vsock_pkt_info info = {
  431. .op = VIRTIO_VSOCK_OP_REQUEST,
  432. .type = VIRTIO_VSOCK_TYPE_STREAM,
  433. .vsk = vsk,
  434. };
  435. return virtio_transport_send_pkt_info(vsk, &info);
  436. }
  437. EXPORT_SYMBOL_GPL(virtio_transport_connect);
  438. int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
  439. {
  440. struct virtio_vsock_pkt_info info = {
  441. .op = VIRTIO_VSOCK_OP_SHUTDOWN,
  442. .type = VIRTIO_VSOCK_TYPE_STREAM,
  443. .flags = (mode & RCV_SHUTDOWN ?
  444. VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
  445. (mode & SEND_SHUTDOWN ?
  446. VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
  447. .vsk = vsk,
  448. };
  449. return virtio_transport_send_pkt_info(vsk, &info);
  450. }
  451. EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
/* Datagram sockets are not supported by the virtio transport. */
int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
			       struct sockaddr_vm *remote_addr,
			       struct msghdr *msg,
			       size_t dgram_len)
{
	return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
  461. ssize_t
  462. virtio_transport_stream_enqueue(struct vsock_sock *vsk,
  463. struct msghdr *msg,
  464. size_t len)
  465. {
  466. struct virtio_vsock_pkt_info info = {
  467. .op = VIRTIO_VSOCK_OP_RW,
  468. .type = VIRTIO_VSOCK_TYPE_STREAM,
  469. .msg = msg,
  470. .pkt_len = len,
  471. .vsk = vsk,
  472. };
  473. return virtio_transport_send_pkt_info(vsk, &info);
  474. }
  475. EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
  476. void virtio_transport_destruct(struct vsock_sock *vsk)
  477. {
  478. struct virtio_vsock_sock *vvs = vsk->trans;
  479. kfree(vvs);
  480. }
  481. EXPORT_SYMBOL_GPL(virtio_transport_destruct);
/* Send an OP_RST for this socket.
 *
 * @pkt is the packet being responded to, or NULL when the reset is
 * locally originated; replying to a packet that is itself a RST is
 * suppressed to avoid RST ping-pong.
 */
static int virtio_transport_reset(struct vsock_sock *vsk,
				  struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = VIRTIO_VSOCK_TYPE_STREAM,
		.reply = !!pkt,
		.vsk = vsk,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	return virtio_transport_send_pkt_info(vsk, &info);
}
/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 *
 * The reply RST swaps @pkt's src/dst addresses and reuses its type.
 * Never replies to a RST (avoids RST storms).  Returns 0, -ENOMEM, or
 * -ENOTCONN when no transport is registered.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
	const struct virtio_transport *t;
	struct virtio_vsock_pkt *reply;
	struct virtio_vsock_pkt_info info = {
		.op = VIRTIO_VSOCK_OP_RST,
		.type = le16_to_cpu(pkt->hdr.type),
		.reply = true,
	};

	/* Send RST only if the original pkt is not a RST pkt */
	if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
		return 0;

	reply = virtio_transport_alloc_pkt(&info, 0,
					   le64_to_cpu(pkt->hdr.dst_cid),
					   le32_to_cpu(pkt->hdr.dst_port),
					   le64_to_cpu(pkt->hdr.src_cid),
					   le32_to_cpu(pkt->hdr.src_port));
	if (!reply)
		return -ENOMEM;

	t = virtio_transport_get_ops();
	if (!t) {
		virtio_transport_free_pkt(reply);
		return -ENOTCONN;
	}

	return t->send_pkt(reply);
}
/* Block (interruptibly) for up to @timeout jiffies waiting for the
 * close handshake to complete, i.e. for SOCK_DONE to be set.  Used to
 * honor SO_LINGER.  Caller holds the socket lock; sk_wait_event drops
 * and retakes it around scheduling.
 */
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
	if (timeout) {
		DEFINE_WAIT(wait);

		do {
			prepare_to_wait(sk_sleep(sk), &wait,
					TASK_INTERRUPTIBLE);
			if (sk_wait_event(sk, &timeout,
					  sock_flag(sk, SOCK_DONE)))
				break;
		} while (!signal_pending(current) && timeout);

		finish_wait(sk_sleep(sk), &wait);
	}
}
/* Complete the close of a connection: mark the socket done, record the
 * peer as fully shut down, and wake waiters.
 *
 * If a close timeout was scheduled, tear it down (only cancelling the
 * delayed work when @cancel_timeout allows it — the timeout handler
 * itself calls in with false) and drop the reference taken when the
 * work was scheduled.
 */
static void virtio_transport_do_close(struct vsock_sock *vsk,
				      bool cancel_timeout)
{
	struct sock *sk = sk_vsock(vsk);

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	if (vsock_stream_has_data(vsk) <= 0)
		sk->sk_state = SS_DISCONNECTING;
	sk->sk_state_change(sk);

	if (vsk->close_work_scheduled &&
	    (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
		vsk->close_work_scheduled = false;

		vsock_remove_sock(vsk);

		/* Release refcnt obtained when we scheduled the timeout */
		sock_put(sk);
	}
}
/* Delayed-work handler: force-close a connection whose graceful
 * shutdown did not finish within VSOCK_CLOSE_TIMEOUT.  Sends a RST and
 * finalizes the close unless the peer already completed it (SOCK_DONE).
 */
static void virtio_transport_close_timeout(struct work_struct *work)
{
	struct vsock_sock *vsk =
		container_of(work, struct vsock_sock, close_work.work);
	struct sock *sk = sk_vsock(vsk);

	sock_hold(sk);
	lock_sock(sk);

	if (!sock_flag(sk, SOCK_DONE)) {
		(void)virtio_transport_reset(vsk, NULL);

		virtio_transport_do_close(vsk, false);
	}

	vsk->close_work_scheduled = false;

	release_sock(sk);
	sock_put(sk);
}
/* User context, vsk->sk is locked.
 *
 * Initiate (or finish) closing a stream socket.  Returns true when the
 * socket can be removed immediately; false when a graceful-shutdown
 * timeout was scheduled and removal must wait for the handshake or the
 * delayed work.  In the false case a sock reference is held for the
 * scheduled work.
 */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;

	if (!(sk->sk_state == SS_CONNECTED ||
	      sk->sk_state == SS_DISCONNECTING))
		return true;

	/* Already received SHUTDOWN from peer, reply with RST */
	if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
		(void)virtio_transport_reset(vsk, NULL);
		return true;
	}

	if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
		(void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

	/* Honor SO_LINGER unless the process is exiting */
	if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
		virtio_transport_wait_close(sk, sk->sk_lingertime);

	if (sock_flag(sk, SOCK_DONE)) {
		return true;
	}

	sock_hold(sk);
	INIT_DELAYED_WORK(&vsk->close_work,
			  virtio_transport_close_timeout);
	vsk->close_work_scheduled = true;
	schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
	return false;
}
/* Socket release callback.  Stream sockets go through the graceful
 * close path, which may defer removal to the close-timeout work; all
 * other sockets are removed immediately.
 */
void virtio_transport_release(struct vsock_sock *vsk)
{
	struct sock *sk = &vsk->sk;
	bool remove_sock = true;

	lock_sock(sk);
	if (sk->sk_type == SOCK_STREAM)
		remove_sock = virtio_transport_close(vsk);
	release_sock(sk);

	if (remove_sock)
		vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);
/* Handle a packet for a socket in SS_CONNECTING.
 *
 * RESPONSE completes the handshake; INVALID is ignored; RST and any
 * unexpected op tear the connection down, setting sk_err (ECONNRESET
 * resp. EPROTO) and replying with a RST where appropriate.
 */
static int
virtio_transport_recv_connecting(struct sock *sk,
				 struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	int err;
	int skerr;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RESPONSE:
		sk->sk_state = SS_CONNECTED;
		sk->sk_socket->state = SS_CONNECTED;
		vsock_insert_connected(vsk);
		sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_INVALID:
		break;
	case VIRTIO_VSOCK_OP_RST:
		skerr = ECONNRESET;
		err = 0;
		goto destroy;
	default:
		skerr = EPROTO;
		err = -EINVAL;
		goto destroy;
	}
	return 0;

destroy:
	virtio_transport_reset(vsk, pkt);
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = skerr;
	sk->sk_error_report(sk);
	return err;
}
/* Handle a packet for a socket in SS_CONNECTED.
 *
 * Packet ownership: OP_RW packets are queued on rx_queue and returned
 * without being freed; every other op falls through to
 * virtio_transport_free_pkt() at the end.
 */
static int
virtio_transport_recv_connected(struct sock *sk,
				struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	int err = 0;

	switch (le16_to_cpu(pkt->hdr.op)) {
	case VIRTIO_VSOCK_OP_RW:
		pkt->len = le32_to_cpu(pkt->hdr.len);
		pkt->off = 0;

		spin_lock_bh(&vvs->rx_lock);
		virtio_transport_inc_rx_pkt(vvs, pkt);
		list_add_tail(&pkt->list, &vvs->rx_queue);
		spin_unlock_bh(&vvs->rx_lock);

		sk->sk_data_ready(sk);
		return err;
	case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
		/* Credit fields were already absorbed by
		 * virtio_transport_space_update(); just wake writers.
		 */
		sk->sk_write_space(sk);
		break;
	case VIRTIO_VSOCK_OP_SHUTDOWN:
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
			vsk->peer_shutdown |= RCV_SHUTDOWN;
		if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
			vsk->peer_shutdown |= SEND_SHUTDOWN;
		/* Only disconnect once both directions are shut and all
		 * pending RX data has been drained.
		 */
		if (vsk->peer_shutdown == SHUTDOWN_MASK &&
		    vsock_stream_has_data(vsk) <= 0)
			sk->sk_state = SS_DISCONNECTING;
		if (le32_to_cpu(pkt->hdr.flags))
			sk->sk_state_change(sk);
		break;
	case VIRTIO_VSOCK_OP_RST:
		virtio_transport_do_close(vsk, true);
		break;
	default:
		err = -EINVAL;
		break;
	}

	virtio_transport_free_pkt(pkt);
	return err;
}
  683. static void
  684. virtio_transport_recv_disconnecting(struct sock *sk,
  685. struct virtio_vsock_pkt *pkt)
  686. {
  687. struct vsock_sock *vsk = vsock_sk(sk);
  688. if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
  689. virtio_transport_do_close(vsk, true);
  690. }
  691. static int
  692. virtio_transport_send_response(struct vsock_sock *vsk,
  693. struct virtio_vsock_pkt *pkt)
  694. {
  695. struct virtio_vsock_pkt_info info = {
  696. .op = VIRTIO_VSOCK_OP_RESPONSE,
  697. .type = VIRTIO_VSOCK_TYPE_STREAM,
  698. .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
  699. .remote_port = le32_to_cpu(pkt->hdr.src_port),
  700. .reply = true,
  701. .vsk = vsk,
  702. };
  703. return virtio_transport_send_pkt_info(vsk, &info);
  704. }
/* Handle server socket: process a packet arriving on a listening
 * socket.  Only OP_REQUEST is accepted; anything else, a full accept
 * queue, or child allocation failure gets a RST.  On success a child
 * socket is created in SS_CONNECTED, inserted into the connected
 * table, queued for accept(), and answered with OP_RESPONSE.
 */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct vsock_sock *vchild;
	struct sock *child;

	if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
		virtio_transport_reset(vsk, pkt);
		return -EINVAL;
	}

	if (sk_acceptq_is_full(sk)) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
			       sk->sk_type, 0);
	if (!child) {
		virtio_transport_reset(vsk, pkt);
		return -ENOMEM;
	}

	sk->sk_ack_backlog++;

	lock_sock_nested(child, SINGLE_DEPTH_NESTING);

	child->sk_state = SS_CONNECTED;

	vchild = vsock_sk(child);
	/* The child's local address is this packet's destination, its
	 * remote address the packet's source.
	 */
	vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));
	vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));

	vsock_insert_connected(vchild);
	vsock_enqueue_accept(sk, child);
	virtio_transport_send_response(vchild, pkt);

	release_sock(child);

	sk->sk_data_ready(sk);
	return 0;
}
/* Absorb the peer credit fields (buf_alloc, fwd_cnt) carried by every
 * received packet into our TX state.  Returns whether send space is
 * now available, so the caller can wake writers.
 */
static bool virtio_transport_space_update(struct sock *sk,
					  struct virtio_vsock_pkt *pkt)
{
	struct vsock_sock *vsk = vsock_sk(sk);
	struct virtio_vsock_sock *vvs = vsk->trans;
	bool space_available;

	/* buf_alloc and fwd_cnt is always included in the hdr */
	spin_lock_bh(&vvs->tx_lock);
	vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
	vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
	space_available = virtio_transport_has_space(vsk);
	spin_unlock_bh(&vvs->tx_lock);
	return space_available;
}
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 *
 * Main RX dispatch: look the target socket up in the connected table,
 * falling back to the bound table, and hand the packet to the handler
 * for the socket's state.
 *
 * Packet ownership: every path frees @pkt except SS_CONNECTED, where
 * virtio_transport_recv_connected() consumes it (it may queue OP_RW
 * payloads on the socket's rx_queue).  Unknown types and unmatched
 * destinations are answered with a RST.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
	struct sockaddr_vm src, dst;
	struct vsock_sock *vsk;
	struct sock *sk;
	bool space_available;

	vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
			le32_to_cpu(pkt->hdr.src_port));
	vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
			le32_to_cpu(pkt->hdr.dst_port));

	trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
					dst.svm_cid, dst.svm_port,
					le32_to_cpu(pkt->hdr.len),
					le16_to_cpu(pkt->hdr.type),
					le16_to_cpu(pkt->hdr.op),
					le32_to_cpu(pkt->hdr.flags),
					le32_to_cpu(pkt->hdr.buf_alloc),
					le32_to_cpu(pkt->hdr.fwd_cnt));

	if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
		(void)virtio_transport_reset_no_sock(pkt);
		goto free_pkt;
	}

	/* The socket must be in connected or bound table
	 * otherwise send reset back
	 */
	sk = vsock_find_connected_socket(&src, &dst);
	if (!sk) {
		sk = vsock_find_bound_socket(&dst);
		if (!sk) {
			(void)virtio_transport_reset_no_sock(pkt);
			goto free_pkt;
		}
	}

	vsk = vsock_sk(sk);

	space_available = virtio_transport_space_update(sk, pkt);

	lock_sock(sk);

	/* Update CID in case it has changed after a transport reset event */
	vsk->local_addr.svm_cid = dst.svm_cid;

	if (space_available)
		sk->sk_write_space(sk);

	switch (sk->sk_state) {
	case VSOCK_SS_LISTEN:
		virtio_transport_recv_listen(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTING:
		virtio_transport_recv_connecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	case SS_CONNECTED:
		virtio_transport_recv_connected(sk, pkt);
		break;
	case SS_DISCONNECTING:
		virtio_transport_recv_disconnecting(sk, pkt);
		virtio_transport_free_pkt(pkt);
		break;
	default:
		virtio_transport_free_pkt(pkt);
		break;
	}
	release_sock(sk);

	/* Release refcnt obtained when we fetched this socket out of the
	 * bound or connected list.
	 */
	sock_put(sk);
	return;

free_pkt:
	virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
/* Free a packet and its payload buffer (kfree(NULL) is a no-op for
 * control packets that never allocated one).
 */
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
	kfree(pkt->buf);
	kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
  834. MODULE_LICENSE("GPL v2");
  835. MODULE_AUTHOR("Asias He");
  836. MODULE_DESCRIPTION("common code for virtio vsock");