1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
import base64
import hashlib
import json
import socket
import struct
from threading import Lock, Thread
from typing import Any, Callable
class Server:
    """Minimal single-client WebSocket server (RFC 6455) exchanging JSON text messages.

    After :meth:`start`, a daemon thread accepts one connection, completes the
    opening handshake and then hands every complete incoming text message,
    decoded with :func:`json.loads`, to ``message_handler``. :meth:`send`
    pushes JSON messages to the connected client.
    """

    # GUID appended to the client's key to derive Sec-WebSocket-Accept
    # (RFC 6455, section 1.3).
    _GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

    # Frame opcodes (RFC 6455, section 5.2).
    _OP_CONTINUATION = 0x0
    _OP_TEXT = 0x1
    _OP_CLOSE = 0x8
    _OP_PING = 0x9
    _OP_PONG = 0xA

    def __init__(self, message_handler: Callable[[Any], Any]):
        """Create the (not yet bound) listening socket.

        Args:
            message_handler: Called on the server thread with each incoming
                message after ``json.loads`` (a dict, list, ... -- not the raw
                string).
        """
        self._message_handler = message_handler
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Published only after the handshake, so send() never writes frames
        # into a half-finished HTTP exchange.
        self._connection = None
        self._active = True
        self._thread = None
        # Serializes frame writes: send() may be called from any thread while
        # the server thread answers pings and close requests.
        self._send_lock = Lock()

    def start(self, host: str, port: int):
        """Bind to ``(host, port)`` and serve one client on a daemon thread."""
        self._socket.bind((host, port))
        self._socket.listen(1)
        self._thread = Thread(target=self._handle, daemon=True)
        self._thread.start()

    def stop(self):
        """Ask the server thread to finish and close its sockets.

        Once a client is connected the flag is re-checked after every 1-second
        receive timeout; before that the thread waits in ``accept`` (up to 10 s).
        """
        self._active = False

    def send(self, msg: dict):
        """Send ``msg`` to the client as a single JSON text frame.

        Does nothing when no client is connected (before the handshake has
        completed or after the connection has ended).
        """
        if self._connection is None:
            return
        payload = json.dumps(msg, ensure_ascii=False).encode()
        self._send_frame(self._OP_TEXT, payload)

    def _send_frame(self, opcode: int, payload: bytes):
        """Write one frame to the current connection, if any, under the send lock."""
        with self._send_lock:
            if self._connection is None:
                return
            # sendall, unlike send, keeps writing until every byte is sent.
            self._connection.sendall(self._encode_frame(opcode, payload))

    @staticmethod
    def _encode_frame(opcode: int, payload: bytes) -> bytes:
        """Build a final (FIN=1), unmasked frame; servers must not mask frames.

        The payload length uses the 7-bit, 16-bit (marker 126) or 64-bit
        (marker 127) form, whichever is the smallest that fits.
        """
        length = len(payload)
        header = bytearray([0x80 | opcode])
        if length <= 125:
            header.append(length)
        elif length <= 0xFFFF:
            header.append(126)
            header += struct.pack(">H", length)
        else:
            header.append(127)
            header += struct.pack(">Q", length)
        return bytes(header) + payload

    @staticmethod
    def _decode_frame(buffer):
        """Decode the frame at the start of ``buffer`` (bytes or bytearray).

        Returns:
            ``(fin, opcode, payload, consumed)`` with ``payload`` already
            unmasked and ``consumed`` the total size of the frame in bytes, or
            ``None`` if ``buffer`` does not yet hold a complete frame.
        """
        if len(buffer) < 2:
            return None
        fin = bool(buffer[0] & 0x80)
        opcode = buffer[0] & 0x0F
        masked = bool(buffer[1] & 0x80)
        length = buffer[1] & 0x7F
        index = 2
        if length == 126:
            if len(buffer) < 4:
                return None
            (length,) = struct.unpack_from(">H", buffer, 2)
            index = 4
        elif length == 127:
            if len(buffer) < 10:
                return None
            (length,) = struct.unpack_from(">Q", buffer, 2)
            index = 10
        mask = b""
        if masked:
            mask = bytes(buffer[index:index + 4])
            index += 4
        end = index + length
        # Also covers a truncated masking key, since end >= index + 4 then.
        if len(buffer) < end:
            return None
        payload = bytes(buffer[index:end])
        if masked:
            payload = bytes(byte ^ mask[i % 4] for i, byte in enumerate(payload))
        return fin, opcode, payload, end

    def _handle(self):
        """Server-thread entry point: accept one client, handshake, serve, clean up."""
        connection = None
        try:
            self._socket.settimeout(10)
            connection, _ = self._socket.accept()
            connection.settimeout(3)
            self._handle_handshake(connection)
            # Short timeout so the receive loop re-checks ``_active`` every second.
            connection.settimeout(1)
            self._connection = connection
            self._serve(connection)
        except socket.timeout:
            # No client within 10 s, or the handshake/connection stalled.
            pass
        finally:
            # Unpublish under the lock so no send() is mid-write on a closed socket.
            with self._send_lock:
                self._connection = None
            self._socket.close()
            if connection is not None:
                connection.close()

    def _serve(self, connection):
        """Dispatch frames from ``connection`` until the peer closes, sends an
        unsupported frame, or :meth:`stop` is called.

        Bytes are buffered because TCP may split one frame across several
        ``recv`` calls or deliver several frames in a single one.
        """
        buffer = bytearray()
        fragments = None  # payload pieces of a fragmented text message
        while self._active:
            try:
                chunk = connection.recv(65536)
            except socket.timeout:
                continue  # idle: re-check ``_active``
            if not chunk:
                return  # peer closed the TCP connection
            buffer += chunk
            while True:
                frame = self._decode_frame(buffer)
                if frame is None:
                    break  # incomplete frame: wait for more bytes
                fin, opcode, payload, consumed = frame
                del buffer[:consumed]
                if opcode == self._OP_CLOSE:
                    try:
                        # Echo the status code to complete the closing handshake.
                        self._send_frame(self._OP_CLOSE, payload[:2])
                    except OSError:
                        pass  # best effort: the peer may already be gone
                    return
                if opcode == self._OP_PING:
                    self._send_frame(self._OP_PONG, payload)
                    continue
                if opcode == self._OP_PONG:
                    continue
                if opcode == self._OP_TEXT and fragments is None:
                    fragments = [payload]
                elif opcode == self._OP_CONTINUATION and fragments is not None:
                    fragments.append(payload)
                else:
                    return  # binary or out-of-sequence frame: unsupported, end session
                if fin:
                    message = b"".join(fragments)
                    fragments = None
                    self._message_handler(json.loads(message.decode()))

    def _handle_handshake(self, connection):
        """Answer the client's HTTP Upgrade request with ``101 Switching Protocols``.

        Raises:
            ConnectionError: if the peer disconnects before the request is
                complete or the request carries no ``Sec-WebSocket-Key``.
        """
        request = b""
        # The request may arrive in several TCP segments; read up to the blank line.
        while b"\r\n\r\n" not in request:
            chunk = connection.recv(65536)
            if not chunk:
                raise ConnectionError("connection closed during WebSocket handshake")
            request += chunk
        head = request.split(b"\r\n\r\n", 1)[0]
        headers = {}
        for line in head.split(b"\r\n")[1:]:  # skip the request line
            name, _, value = line.partition(b":")
            # HTTP header names are case-insensitive.
            headers[name.strip().lower()] = value.strip()
        key = headers.get(b"sec-websocket-key")
        if key is None:
            raise ConnectionError("WebSocket handshake is missing Sec-WebSocket-Key")
        connection.sendall(
            b"HTTP/1.1 101 Switching Protocols\r\n"
            b"Upgrade: websocket\r\n"
            b"Connection: Upgrade\r\n"
            b"Sec-WebSocket-Accept: " + self._accept_key(key) + b"\r\n\r\n"
        )

    @classmethod
    def _accept_key(cls, key: bytes) -> bytes:
        """Derive the Sec-WebSocket-Accept value for a client's Sec-WebSocket-Key."""
        return base64.b64encode(hashlib.sha1(key + cls._GUID).digest())
|