TraversalClient.cpp 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. // This file is public domain, in case it's useful to anyone. -comex
  2. #include "Common/Timer.h"
  3. #include "Common/TraversalClient.h"
  4. static void GetRandomishBytes(u8* buf, size_t size)
  5. {
  6. // We don't need high quality random numbers (which might not be available),
  7. // just non-repeating numbers!
  8. static std::mt19937 prng(enet_time_get());
  9. static std::uniform_int_distribution<unsigned int> u8_distribution(0, 255);
  10. for (size_t i = 0; i < size; i++)
  11. buf[i] = u8_distribution(prng);
  12. }
  13. TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
  14. : m_NetHost(netHost)
  15. , m_Client(nullptr)
  16. , m_FailureReason(0)
  17. , m_ConnectRequestId(0)
  18. , m_PendingConnect(false)
  19. , m_Server(server)
  20. , m_port(port)
  21. , m_PingTime(0)
  22. {
  23. netHost->intercept = TraversalClient::InterceptCallback;
  24. Reset();
  25. ReconnectToServer();
  26. }
  27. TraversalClient::~TraversalClient()
  28. {
  29. }
  30. void TraversalClient::ReconnectToServer()
  31. {
  32. if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
  33. {
  34. OnFailure(BadHost);
  35. return;
  36. }
  37. m_ServerAddress.port = m_port;
  38. m_State = Connecting;
  39. TraversalPacket hello = {};
  40. hello.type = TraversalPacketHelloFromClient;
  41. hello.helloFromClient.protoVersion = TraversalProtoVersion;
  42. SendTraversalPacket(hello);
  43. if (m_Client)
  44. m_Client->OnTraversalStateChanged();
  45. }
  46. static ENetAddress MakeENetAddress(TraversalInetAddress* address)
  47. {
  48. ENetAddress eaddr;
  49. if (address->isIPV6)
  50. {
  51. eaddr.port = 0; // no support yet :(
  52. }
  53. else
  54. {
  55. eaddr.host = address->address[0];
  56. eaddr.port = ntohs(address->port);
  57. }
  58. return eaddr;
  59. }
  60. void TraversalClient::ConnectToClient(const std::string& host)
  61. {
  62. if (host.size() > sizeof(TraversalHostId))
  63. {
  64. PanicAlert("host too long");
  65. return;
  66. }
  67. TraversalPacket packet = {};
  68. packet.type = TraversalPacketConnectPlease;
  69. memcpy(packet.connectPlease.hostId.data(), host.c_str(), host.size());
  70. m_ConnectRequestId = SendTraversalPacket(packet);
  71. m_PendingConnect = true;
  72. }
  73. bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
  74. {
  75. if (from->host == m_ServerAddress.host &&
  76. from->port == m_ServerAddress.port)
  77. {
  78. if (size < sizeof(TraversalPacket))
  79. {
  80. ERROR_LOG(NETPLAY, "Received too-short traversal packet.");
  81. }
  82. else
  83. {
  84. HandleServerPacket((TraversalPacket*) data);
  85. return true;
  86. }
  87. }
  88. return false;
  89. }
  90. //--Temporary until more of the old netplay branch is moved over
  91. void TraversalClient::Update()
  92. {
  93. ENetEvent netEvent;
  94. if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
  95. {
  96. switch (netEvent.type)
  97. {
  98. case ENET_EVENT_TYPE_RECEIVE:
  99. TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
  100. enet_packet_destroy(netEvent.packet);
  101. break;
  102. default:
  103. break;
  104. }
  105. }
  106. HandleResends();
  107. }
  108. void TraversalClient::HandleServerPacket(TraversalPacket* packet)
  109. {
  110. u8 ok = 1;
  111. switch (packet->type)
  112. {
  113. case TraversalPacketAck:
  114. if (!packet->ack.ok)
  115. {
  116. OnFailure(ServerForgotAboutUs);
  117. break;
  118. }
  119. for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
  120. {
  121. if (it->packet.requestId == packet->requestId)
  122. {
  123. m_OutgoingTraversalPackets.erase(it);
  124. break;
  125. }
  126. }
  127. break;
  128. case TraversalPacketHelloFromServer:
  129. if (m_State != Connecting)
  130. break;
  131. if (!packet->helloFromServer.ok)
  132. {
  133. OnFailure(VersionTooOld);
  134. break;
  135. }
  136. m_HostId = packet->helloFromServer.yourHostId;
  137. m_State = Connected;
  138. if (m_Client)
  139. m_Client->OnTraversalStateChanged();
  140. break;
  141. case TraversalPacketPleaseSendPacket:
  142. {
  143. // security is overrated.
  144. ENetAddress addr = MakeENetAddress(&packet->pleaseSendPacket.address);
  145. if (addr.port != 0)
  146. {
  147. char message[] = "Hello from Dolphin Netplay...";
  148. ENetBuffer buf;
  149. buf.data = message;
  150. buf.dataLength = sizeof(message) - 1;
  151. enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
  152. }
  153. else
  154. {
  155. // invalid IPV6
  156. ok = 0;
  157. }
  158. break;
  159. }
  160. case TraversalPacketConnectReady:
  161. case TraversalPacketConnectFailed:
  162. {
  163. if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
  164. break;
  165. m_PendingConnect = false;
  166. if (!m_Client)
  167. break;
  168. if (packet->type == TraversalPacketConnectReady)
  169. m_Client->OnConnectReady(MakeENetAddress(&packet->connectReady.address));
  170. else
  171. m_Client->OnConnectFailed(packet->connectFailed.reason);
  172. break;
  173. }
  174. default:
  175. WARN_LOG(NETPLAY, "Received unknown packet with type %d", packet->type);
  176. break;
  177. }
  178. if (packet->type != TraversalPacketAck)
  179. {
  180. TraversalPacket ack = {};
  181. ack.type = TraversalPacketAck;
  182. ack.requestId = packet->requestId;
  183. ack.ack.ok = ok;
  184. ENetBuffer buf;
  185. buf.data = &ack;
  186. buf.dataLength = sizeof(ack);
  187. if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
  188. OnFailure(SocketSendError);
  189. }
  190. }
  191. void TraversalClient::OnFailure(FailureReason reason)
  192. {
  193. m_State = Failure;
  194. m_FailureReason = reason;
  195. switch (reason)
  196. {
  197. case TraversalClient::BadHost:
  198. PanicAlertT("Couldn't look up central server %s", m_Server.c_str());
  199. break;
  200. case TraversalClient::VersionTooOld:
  201. PanicAlertT("Dolphin too old for traversal server");
  202. break;
  203. case TraversalClient::ServerForgotAboutUs:
  204. PanicAlertT("Disconnected from traversal server");
  205. break;
  206. case TraversalClient::SocketSendError:
  207. PanicAlertT("Socket error sending to traversal server");
  208. break;
  209. case TraversalClient::ResendTimeout:
  210. PanicAlertT("Timeout connecting to traversal server");
  211. break;
  212. }
  213. if (m_Client)
  214. m_Client->OnTraversalStateChanged();
  215. }
  216. void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
  217. {
  218. info->sendTime = enet_time_get();
  219. info->tries++;
  220. ENetBuffer buf;
  221. buf.data = &info->packet;
  222. buf.dataLength = sizeof(info->packet);
  223. if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
  224. OnFailure(SocketSendError);
  225. }
  226. void TraversalClient::HandleResends()
  227. {
  228. enet_uint32 now = enet_time_get();
  229. for (auto& tpi : m_OutgoingTraversalPackets)
  230. {
  231. if (now - tpi.sendTime >= (u32) (300 * tpi.tries))
  232. {
  233. if (tpi.tries >= 5)
  234. {
  235. OnFailure(ResendTimeout);
  236. m_OutgoingTraversalPackets.clear();
  237. break;
  238. }
  239. else
  240. {
  241. ResendPacket(&tpi);
  242. }
  243. }
  244. }
  245. HandlePing();
  246. }
  247. void TraversalClient::HandlePing()
  248. {
  249. enet_uint32 now = enet_time_get();
  250. if (m_State == Connected && now - m_PingTime >= 500)
  251. {
  252. TraversalPacket ping = {};
  253. ping.type = TraversalPacketPing;
  254. ping.ping.hostId = m_HostId;
  255. SendTraversalPacket(ping);
  256. m_PingTime = now;
  257. }
  258. }
  259. TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
  260. {
  261. OutgoingTraversalPacketInfo info;
  262. info.packet = packet;
  263. GetRandomishBytes((u8*) &info.packet.requestId, sizeof(info.packet.requestId));
  264. info.tries = 0;
  265. m_OutgoingTraversalPackets.push_back(info);
  266. ResendPacket(&m_OutgoingTraversalPackets.back());
  267. return info.packet.requestId;
  268. }
  269. void TraversalClient::Reset()
  270. {
  271. m_PendingConnect = false;
  272. m_Client = nullptr;
  273. }
  274. int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
  275. {
  276. auto traversalClient = g_TraversalClient.get();
  277. if (traversalClient->TestPacket(host->receivedData, host->receivedDataLength, &host->receivedAddress)
  278. || (host->receivedDataLength == 1 && host->receivedData[0] == 0))
  279. {
  280. event->type = (ENetEventType)42;
  281. return 1;
  282. }
  283. return 0;
  284. }
  285. std::unique_ptr<TraversalClient> g_TraversalClient;
  286. std::unique_ptr<ENetHost> g_MainNetHost;
  287. // The settings at the previous TraversalClient reset - notably, we
  288. // need to know not just what port it's on, but whether it was
  289. // explicitly requested.
  290. static std::string g_OldServer;
  291. static u16 g_OldPort;
  292. bool EnsureTraversalClient(const std::string& server, u16 port)
  293. {
  294. if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer || port != g_OldPort)
  295. {
  296. g_OldServer = server;
  297. g_OldPort = port ;
  298. ENetAddress addr = { ENET_HOST_ANY, 0 };
  299. ENetHost* host = enet_host_create(
  300. &addr, // address
  301. 50, // peerCount
  302. 1, // channelLimit
  303. 0, // incomingBandwidth
  304. 0); // outgoingBandwidth
  305. if (!host)
  306. {
  307. g_MainNetHost.reset();
  308. return false;
  309. }
  310. g_MainNetHost.reset(host);
  311. g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, port));
  312. }
  313. return true;
  314. }
  315. void ReleaseTraversalClient()
  316. {
  317. if (!g_TraversalClient)
  318. return;
  319. g_TraversalClient.release();
  320. g_MainNetHost.release();
  321. }