123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357 |
- // This file is public domain, in case it's useful to anyone. -comex
- #include "Common/Timer.h"
- #include "Common/TraversalClient.h"
- static void GetRandomishBytes(u8* buf, size_t size)
- {
- // We don't need high quality random numbers (which might not be available),
- // just non-repeating numbers!
- static std::mt19937 prng(enet_time_get());
- static std::uniform_int_distribution<unsigned int> u8_distribution(0, 255);
- for (size_t i = 0; i < size; i++)
- buf[i] = u8_distribution(prng);
- }
- TraversalClient::TraversalClient(ENetHost* netHost, const std::string& server, const u16 port)
- : m_NetHost(netHost)
- , m_Client(nullptr)
- , m_FailureReason(0)
- , m_ConnectRequestId(0)
- , m_PendingConnect(false)
- , m_Server(server)
- , m_port(port)
- , m_PingTime(0)
- {
- netHost->intercept = TraversalClient::InterceptCallback;
- Reset();
- ReconnectToServer();
- }
- TraversalClient::~TraversalClient()
- {
- }
- void TraversalClient::ReconnectToServer()
- {
- if (enet_address_set_host(&m_ServerAddress, m_Server.c_str()))
- {
- OnFailure(BadHost);
- return;
- }
- m_ServerAddress.port = m_port;
- m_State = Connecting;
- TraversalPacket hello = {};
- hello.type = TraversalPacketHelloFromClient;
- hello.helloFromClient.protoVersion = TraversalProtoVersion;
- SendTraversalPacket(hello);
- if (m_Client)
- m_Client->OnTraversalStateChanged();
- }
- static ENetAddress MakeENetAddress(TraversalInetAddress* address)
- {
- ENetAddress eaddr;
- if (address->isIPV6)
- {
- eaddr.port = 0; // no support yet :(
- }
- else
- {
- eaddr.host = address->address[0];
- eaddr.port = ntohs(address->port);
- }
- return eaddr;
- }
- void TraversalClient::ConnectToClient(const std::string& host)
- {
- if (host.size() > sizeof(TraversalHostId))
- {
- PanicAlert("host too long");
- return;
- }
- TraversalPacket packet = {};
- packet.type = TraversalPacketConnectPlease;
- memcpy(packet.connectPlease.hostId.data(), host.c_str(), host.size());
- m_ConnectRequestId = SendTraversalPacket(packet);
- m_PendingConnect = true;
- }
- bool TraversalClient::TestPacket(u8* data, size_t size, ENetAddress* from)
- {
- if (from->host == m_ServerAddress.host &&
- from->port == m_ServerAddress.port)
- {
- if (size < sizeof(TraversalPacket))
- {
- ERROR_LOG(NETPLAY, "Received too-short traversal packet.");
- }
- else
- {
- HandleServerPacket((TraversalPacket*) data);
- return true;
- }
- }
- return false;
- }
- //--Temporary until more of the old netplay branch is moved over
- void TraversalClient::Update()
- {
- ENetEvent netEvent;
- if (enet_host_service(m_NetHost, &netEvent, 4) > 0)
- {
- switch (netEvent.type)
- {
- case ENET_EVENT_TYPE_RECEIVE:
- TestPacket(netEvent.packet->data, netEvent.packet->dataLength, &netEvent.peer->address);
- enet_packet_destroy(netEvent.packet);
- break;
- default:
- break;
- }
- }
- HandleResends();
- }
- void TraversalClient::HandleServerPacket(TraversalPacket* packet)
- {
- u8 ok = 1;
- switch (packet->type)
- {
- case TraversalPacketAck:
- if (!packet->ack.ok)
- {
- OnFailure(ServerForgotAboutUs);
- break;
- }
- for (auto it = m_OutgoingTraversalPackets.begin(); it != m_OutgoingTraversalPackets.end(); ++it)
- {
- if (it->packet.requestId == packet->requestId)
- {
- m_OutgoingTraversalPackets.erase(it);
- break;
- }
- }
- break;
- case TraversalPacketHelloFromServer:
- if (m_State != Connecting)
- break;
- if (!packet->helloFromServer.ok)
- {
- OnFailure(VersionTooOld);
- break;
- }
- m_HostId = packet->helloFromServer.yourHostId;
- m_State = Connected;
- if (m_Client)
- m_Client->OnTraversalStateChanged();
- break;
- case TraversalPacketPleaseSendPacket:
- {
- // security is overrated.
- ENetAddress addr = MakeENetAddress(&packet->pleaseSendPacket.address);
- if (addr.port != 0)
- {
- char message[] = "Hello from Dolphin Netplay...";
- ENetBuffer buf;
- buf.data = message;
- buf.dataLength = sizeof(message) - 1;
- enet_socket_send(m_NetHost->socket, &addr, &buf, 1);
- }
- else
- {
- // invalid IPV6
- ok = 0;
- }
- break;
- }
- case TraversalPacketConnectReady:
- case TraversalPacketConnectFailed:
- {
- if (!m_PendingConnect || packet->connectReady.requestId != m_ConnectRequestId)
- break;
- m_PendingConnect = false;
- if (!m_Client)
- break;
- if (packet->type == TraversalPacketConnectReady)
- m_Client->OnConnectReady(MakeENetAddress(&packet->connectReady.address));
- else
- m_Client->OnConnectFailed(packet->connectFailed.reason);
- break;
- }
- default:
- WARN_LOG(NETPLAY, "Received unknown packet with type %d", packet->type);
- break;
- }
- if (packet->type != TraversalPacketAck)
- {
- TraversalPacket ack = {};
- ack.type = TraversalPacketAck;
- ack.requestId = packet->requestId;
- ack.ack.ok = ok;
- ENetBuffer buf;
- buf.data = &ack;
- buf.dataLength = sizeof(ack);
- if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
- OnFailure(SocketSendError);
- }
- }
- void TraversalClient::OnFailure(FailureReason reason)
- {
- m_State = Failure;
- m_FailureReason = reason;
- switch (reason)
- {
- case TraversalClient::BadHost:
- PanicAlertT("Couldn't look up central server %s", m_Server.c_str());
- break;
- case TraversalClient::VersionTooOld:
- PanicAlertT("Dolphin too old for traversal server");
- break;
- case TraversalClient::ServerForgotAboutUs:
- PanicAlertT("Disconnected from traversal server");
- break;
- case TraversalClient::SocketSendError:
- PanicAlertT("Socket error sending to traversal server");
- break;
- case TraversalClient::ResendTimeout:
- PanicAlertT("Timeout connecting to traversal server");
- break;
- }
- if (m_Client)
- m_Client->OnTraversalStateChanged();
- }
- void TraversalClient::ResendPacket(OutgoingTraversalPacketInfo* info)
- {
- info->sendTime = enet_time_get();
- info->tries++;
- ENetBuffer buf;
- buf.data = &info->packet;
- buf.dataLength = sizeof(info->packet);
- if (enet_socket_send(m_NetHost->socket, &m_ServerAddress, &buf, 1) == -1)
- OnFailure(SocketSendError);
- }
- void TraversalClient::HandleResends()
- {
- enet_uint32 now = enet_time_get();
- for (auto& tpi : m_OutgoingTraversalPackets)
- {
- if (now - tpi.sendTime >= (u32) (300 * tpi.tries))
- {
- if (tpi.tries >= 5)
- {
- OnFailure(ResendTimeout);
- m_OutgoingTraversalPackets.clear();
- break;
- }
- else
- {
- ResendPacket(&tpi);
- }
- }
- }
- HandlePing();
- }
- void TraversalClient::HandlePing()
- {
- enet_uint32 now = enet_time_get();
- if (m_State == Connected && now - m_PingTime >= 500)
- {
- TraversalPacket ping = {};
- ping.type = TraversalPacketPing;
- ping.ping.hostId = m_HostId;
- SendTraversalPacket(ping);
- m_PingTime = now;
- }
- }
- TraversalRequestId TraversalClient::SendTraversalPacket(const TraversalPacket& packet)
- {
- OutgoingTraversalPacketInfo info;
- info.packet = packet;
- GetRandomishBytes((u8*) &info.packet.requestId, sizeof(info.packet.requestId));
- info.tries = 0;
- m_OutgoingTraversalPackets.push_back(info);
- ResendPacket(&m_OutgoingTraversalPackets.back());
- return info.packet.requestId;
- }
- void TraversalClient::Reset()
- {
- m_PendingConnect = false;
- m_Client = nullptr;
- }
- int ENET_CALLBACK TraversalClient::InterceptCallback(ENetHost* host, ENetEvent* event)
- {
- auto traversalClient = g_TraversalClient.get();
- if (traversalClient->TestPacket(host->receivedData, host->receivedDataLength, &host->receivedAddress)
- || (host->receivedDataLength == 1 && host->receivedData[0] == 0))
- {
- event->type = (ENetEventType)42;
- return 1;
- }
- return 0;
- }
- std::unique_ptr<TraversalClient> g_TraversalClient;
- std::unique_ptr<ENetHost> g_MainNetHost;
- // The settings at the previous TraversalClient reset - notably, we
- // need to know not just what port it's on, but whether it was
- // explicitly requested.
- static std::string g_OldServer;
- static u16 g_OldPort;
- bool EnsureTraversalClient(const std::string& server, u16 port)
- {
- if (!g_MainNetHost || !g_TraversalClient || server != g_OldServer || port != g_OldPort)
- {
- g_OldServer = server;
- g_OldPort = port ;
- ENetAddress addr = { ENET_HOST_ANY, 0 };
- ENetHost* host = enet_host_create(
- &addr, // address
- 50, // peerCount
- 1, // channelLimit
- 0, // incomingBandwidth
- 0); // outgoingBandwidth
- if (!host)
- {
- g_MainNetHost.reset();
- return false;
- }
- g_MainNetHost.reset(host);
- g_TraversalClient.reset(new TraversalClient(g_MainNetHost.get(), server, port));
- }
- return true;
- }
- void ReleaseTraversalClient()
- {
- if (!g_TraversalClient)
- return;
- g_TraversalClient.release();
- g_MainNetHost.release();
- }
|