123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371 |
- #! /usr/bin/env python
- # -*- coding: utf-8 -*-
- # SPDX-FileCopyrightText: Copyright (C) 2022-2023 MH3SP Server Project
- # SPDX-License-Identifier: AGPL-3.0-or-later
- """Monster Hunter PAT Server module."""
- import multiprocessing
- import random
- import socket
- import struct
- import threading
- from mh.time_utils import Timer
- from other.utils import wii_ssl_wrap_socket
- try:
- # Python 3
- import queue
- import selectors
- except ImportError:
- # Python 2
- import Queue as queue
- import externals.selectors2 as selectors
- try:
- from typing import List, Tuple # noqa: F401
- except ImportError:
- pass
class BasicPatHandler(object):
    """Base handler for one PAT client connection.

    A PAT packet is an 8-byte big-endian header (``>HHI``: payload size,
    sequence number, packet id) followed by ``size`` payload bytes.
    Subclasses customise behaviour through the ``on_*`` hooks.
    """

    def __init__(self, socket, client_address, server):
        # type: (socket.socket, Tuple[str, int], BasicPatServer) -> None
        self.socket = socket
        self.client_address = client_address
        self.server = server
        self.finished = False
        # Not acquired anywhere in this class; presumably used by subclasses
        # to serialise access to the connection -- confirm before relying on
        # it.
        self.rw = threading.Lock()
        self.setup()

    def fileno(self):
        # type: () -> int
        """Return the socket file descriptor (lets a selector watch us)."""
        return self.socket.fileno()

    def setup(self):
        """Create the socket file objects, then call on_init().

        ``rfile`` uses default buffering; ``wfile`` is unbuffered, so every
        write goes straight to the socket (see _write_all).
        """
        self.rfile = self.socket.makefile('rb', -1)
        self.wfile = self.socket.makefile('wb', 0)
        self.on_init()

    def on_init(self):
        """Called after setup."""
        pass

    def on_exception(self, e):
        # type: (Exception) -> None
        """Called when an exception occurred during recv/write."""
        pass

    def on_recv(self):
        """Read one packet from the socket.

        ** This method would be called by the server thread

        Returns:
            A ``(packet_id, data, seq)`` tuple, or None when the connection
            can no longer be read. In that case the handler has been
            finished so that ``is_finished()`` reports True.
        """
        header = self.rfile.read(8)
        if len(header) < 8:
            # An empty read means the peer closed the socket; a short read
            # means it closed mid-header and the stream can't be framed any
            # more. Finish the handler so the owner can reap it instead of
            # polling a dead socket forever.
            self.finish()
            return None
        return self.recv_packet(header)

    def on_packet(self, data):
        """Called when there is a packet to be handled.

        ** This method would be called from a worker thread (Not Thread Safe)
        """

    def recv_packet(self, header):
        """Receive PAT packet.

        Args:
            header: the 8-byte packet header already read from ``rfile``.

        Returns:
            ``(packet_id, data, seq)``, or None if the connection ended
            before the whole payload arrived (the handler is then finished).
        """
        size, seq, packet_id = struct.unpack(">HHI", header)
        data = self.rfile.read(size)
        if len(data) < size:
            # Truncated payload: the peer went away mid-packet. Never hand a
            # partial packet to on_packet.
            self.finish()
            return None
        return packet_id, data, seq

    def send_packet(self, packet_id=0, data=b'', seq=0):
        """Send PAT packet.

        The header and payload are sent as a single buffer and short writes
        are retried, so the complete packet reaches the socket.
        """
        header = struct.pack(">HHI", len(data), seq, packet_id)
        self._write_all(header + data)

    def _write_all(self, buf):
        """Write every byte of ``buf`` to ``wfile``.

        ``wfile`` is unbuffered, and an unbuffered socket file's write() may
        accept fewer bytes than it was given, so loop until drained.
        """
        while buf:
            written = self.wfile.write(buf)
            if written is None:
                # Python 2 socket file objects write everything and return
                # None. NOTE(review): on Python 3 None would mean "would
                # block", which only happens for non-blocking sockets.
                break
            if written <= 0:
                raise IOError("Socket write made no progress")
            buf = buf[written:]

    def on_tick(self):
        """Called every time the server ticks.

        ** Currently executed from the server thread
        """
        pass

    def on_finish(self):
        """Called before finish."""
        pass

    def is_finished(self):
        """Return True once finish() has run."""
        return self.finished

    def finish(self):
        """Dispose of the handler.

        Runs on_finish() (errors ignored), marks the handler finished and
        closes both file objects. Later calls are no-ops. The socket itself
        is not closed here.
        """
        if self.finished:
            return
        try:
            self.on_finish()
        except Exception:
            pass
        self.finished = True
        try:
            self.wfile.close()
        except Exception:
            pass
        try:
            self.rfile.close()
        except Exception:
            pass
class BasicPatServer(object):
    """Selector-driven TCP server dispatching PAT packets to worker threads.

    The serving thread accepts connections, reads packets through each
    handler's ``on_recv`` and calls ``on_tick`` on every handler roughly
    once per second. Parsed packets are queued to a pool of worker threads;
    each handler is pinned to one worker queue chosen when it is accepted.
    Subclasses are expected to provide ``info``/``error`` logging methods.
    """
    socket_queue_size = 5  # Backlog passed to listen() in server_activate.
    address_family = socket.AF_INET
    socket_type = socket.SOCK_STREAM

    def __init__(self, server_address, RequestHandlerClass, max_threads,
                 bind_and_activate=True, ssl_cert=None, ssl_key=None):
        # type: (Tuple[str, int], BasicPatHandler, int, bool, str|None, str|None) -> None
        """Constructor. May be extended, do not override.

        Args:
            server_address: ``(host, port)`` to bind.
            RequestHandlerClass: handler class instantiated per connection.
            max_threads: number of worker threads; a falsy value means one
                per CPU.
            bind_and_activate: bind and listen immediately.
            ssl_cert, ssl_key: when both are set, accepted sockets are
                wrapped with SSL.
        """
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__is_shut_down.set()
        self.__shutdown_request = False
        self.socket = socket.socket(self.address_family, self.socket_type)
        self._random = random.SystemRandom()  # type: random.SystemRandom
        self.handlers = []  # type: List[BasicPatHandler]
        self.worker_threads = []  # type: List[threading.Thread]
        self.worker_queues = []  # type: List[queue.Queue]
        self.selector = selectors.DefaultSelector()
        self.max_threads = max_threads or multiprocessing.cpu_count()
        self.ssl_cert = ssl_cert
        self.ssl_key = ssl_key
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except Exception:
                self.close()
                raise

    def server_bind(self):
        """Bind the listening socket and record the actual bound address."""
        self.socket.bind(self.server_address)
        self.server_address = self.socket.getsockname()

    def server_activate(self):
        """Start listening, using ``socket_queue_size`` as the backlog."""
        self.socket.listen(self.socket_queue_size)

    def fileno(self):
        """Return server socket file descriptor.

        Interface required by selector.
        """
        return self.socket.fileno()

    def initialize_workers(self):
        """Initialize workers queues/threads.

        This needs to be deferred, otherwise the close method might try to
        join threads that aren't started yet when an error occurs early.
        """
        for n in range(self.max_threads):
            thread_queue = queue.Queue()
            thread = threading.Thread(
                target=self._worker_target,
                args=(thread_queue,),
                name="{}.Worker-{}".format(self.__class__.__name__, n)
            )
            self.worker_queues.append(thread_queue)
            self.worker_threads.append(thread)
            thread.start()

    def serve_forever(self):
        """Run the accept/read/tick loop until close() requests shutdown.

        Blocks the calling thread. Workers are started first, and the
        selector is closed when the loop exits.
        """
        self.__is_shut_down.clear()
        try:
            self.initialize_workers()
            with self.selector as selector:
                selector.register(self, selectors.EVENT_READ)
                write_watch = Timer()
                write_timeout = 1  # Seconds
                while not self.__shutdown_request:
                    ready = selector.select(write_timeout)
                    if self.__shutdown_request:
                        break
                    for (key, event) in ready:
                        selected = key.fileobj
                        if selected == self:
                            self.accept_new_connection()
                        else:
                            assert event == selectors.EVENT_READ
                            try:
                                packet = selected.on_recv()
                                if packet is None:
                                    if selected.is_finished():
                                        self.remove_handler(selected)
                                    continue
                                self._queue_work(selected, packet, event)
                            except Exception as e:
                                selected.on_exception(e)
                                if selected.is_finished():
                                    self.remove_handler(selected)
                    if write_watch.elapsed() >= write_timeout:
                        try:
                            # Iterate over a snapshot: remove_handler()
                            # mutates self.handlers, which would otherwise
                            # skip the handler after each removed one.
                            for handler in list(self.handlers):
                                try:
                                    handler.on_tick()
                                except Exception as e:
                                    handler.on_exception(e)
                                    if handler.is_finished():
                                        self.remove_handler(handler)
                        finally:
                            write_watch.restart()
        finally:
            self.__is_shut_down.set()

    def _worker_target(self, work_queue):
        # type: (queue.Queue) -> None
        """Worker loop: hand queued packets to their handlers' on_packet.

        Exits when shutdown is requested or when it dequeues the
        ``(None, None, None)`` sentinel posted by close().
        """
        while not self.__shutdown_request:
            # A blocking get() without timeout never raises queue.Empty.
            handler, packet, event = work_queue.get(block=True)
            if self.__shutdown_request or handler is None:
                break
            if handler.is_finished():
                continue
            assert event == selectors.EVENT_READ
            try:
                handler.on_packet(packet)
            except Exception as e:
                handler.on_exception(e)
                if handler.is_finished():
                    self.remove_handler(handler)

    def accept_new_connection(self):
        # type: () -> None
        """Accept a pending client, build its handler and register it.

        Accept/setup failures are reported through ``self.error`` and the
        connection is dropped (its socket closed).
        """
        try:
            client_socket, client_address = self.socket.accept()
        except Exception as e:
            self.error('Error accepting connection (1). {}'.format(e))
            return

        try:
            # TODO: Find a cleaner way to process ill-formed packets.
            # Currently, they get stuck on `packet = selected.on_recv()`,
            # thus blocking the `serve_forever` method.
            client_socket.settimeout(2.0)
            # TODO: Ensure this is the correct way to fix the server not
            # accepting SSL connection anymore.
            #
            # See https://stackoverflow.com/a/68214507
            if self.ssl_cert and self.ssl_key:
                client_socket = wii_ssl_wrap_socket(
                    client_socket, self.ssl_cert, self.ssl_key
                )
            handler = self.RequestHandlerClass(client_socket, client_address,
                                               self)
        except Exception as e:
            self.error('Error accepting connection (2). {}'.format(e))
            # Otherwise the accepted socket (raw, or the SSL wrapper if the
            # handler constructor failed) would leak.
            try:
                client_socket.close()
            except Exception:
                pass
            return

        # Pin the handler to one worker so its packets are processed in order.
        handler.__worker_thread = \
            self._random.randint(0, len(self.worker_queues) - 1)
        self.selector.register(handler, selectors.EVENT_READ)
        self.handlers.append(handler)

    def _queue_work(self, handler, work_data, event):
        # type: (BasicPatHandler, any, int) -> None
        """Queue a received packet on the handler's assigned worker."""
        if handler.is_finished():
            return
        thread_queue = self.worker_queues[handler.__worker_thread]
        thread_queue.put((handler, work_data, event), block=True)

    def remove_handler(self, handler):
        # type: (BasicPatHandler) -> None
        """Best-effort teardown: forget, unregister, finish, close socket."""
        try:
            self.handlers.remove(handler)
        except Exception:
            pass
        try:
            self.selector.unregister(handler)
        except Exception:
            pass
        try:
            handler.finish()
        except Exception:
            pass
        try:
            handler.socket.close()
        except Exception:
            pass

    def close(self):
        """Called to clean-up the server.

        May be overridden. Requests shutdown and waits for serve_forever to
        exit, finishes and closes every handler, stops the workers and
        releases the selector.
        """
        self.__shutdown_request = True
        self.socket.close()
        self.__is_shut_down.wait()
        for handler in list(self.handlers):
            try:
                handler.finish()
            except Exception:
                pass
            try:
                handler.socket.close()
            except Exception:
                pass
        self.handlers = []
        for q in self.worker_queues:
            q.put((None, None, None), block=True)
        for t in self.worker_threads:
            t.join()
        self.worker_queues = []
        self.worker_threads = []
        if self.selector is not None:
            # serve_forever closes it on exit; this covers the case where it
            # never ran. Closing twice is harmless.
            try:
                self.selector.close()
            except Exception:
                pass
        self.selector = None
        self.__shutdown_request = False
        self.info('Server Closed')
|