# server.py - minimal single-client WebSocket server.

import base64
import hashlib
import json
import socket
import struct
from threading import Thread
from typing import Any, Callable, Optional, Tuple


  8. class Server:
  9. def __init__(self, message_handler: Callable[[str], Any]):
  10. self._message_handler = message_handler
  11. self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  12. self._connection = None
  13. self._active = True
  14. def start(self, host: str, port: int):
  15. self._socket.bind((host, port))
  16. self._socket.listen(1)
  17. self._thread = Thread(target=self._handle, daemon=True)
  18. self._thread.start()
  19. def stop(self):
  20. self._active = False
  21. def send(self, msg: dict):
  22. if not self._connection:
  23. return
  24. payload = json.dumps(msg, ensure_ascii=False).encode()
  25. payload_length = len(payload)
  26. data = bytearray([0b1000_0001])
  27. if payload_length <= 125:
  28. data.append(payload_length)
  29. data.extend(payload)
  30. else:
  31. data.append(126)
  32. data.extend(struct.pack(">H", ))
  33. data.extend(payload)
  34. self._connection.send(data)
  35. def _handle(self):
  36. try:
  37. self._socket.settimeout(10)
  38. self._connection, _ = self._socket.accept()
  39. self._connection.settimeout(3)
  40. self._handle_handshake(self._connection)
  41. self._connection.settimeout(1)
  42. while self._active:
  43. try:
  44. msg = self._connection.recv(65536)
  45. code = msg[0] & 0b0000_1111
  46. if code != 1:
  47. break
  48. payload_index = 6
  49. payload_length = msg[1] & 0b0111_1111
  50. if payload_length == 126:
  51. payload_index = 8
  52. payload_length = msg[2] << 8 & msg[3]
  53. msg = bytes(value ^ msg[payload_index-4:payload_index][index % 4] for index, value in enumerate(msg[payload_index:]))
  54. self._message_handler(json.loads(msg.decode()))
  55. except socket.timeout:
  56. pass
  57. except socket.timeout:
  58. pass
  59. finally:
  60. self._socket.close()
  61. if self._connection:
  62. self._connection.close()
  63. def _handle_handshake(self, connection):
  64. req = connection.recv(65536)
  65. headers = {}
  66. for line in req.split(b"\r\n")[1:]:
  67. if line == b"":
  68. break
  69. parts = line.split(b": ", 1)
  70. headers[parts[0]] = parts[1]
  71. accept = base64.b64encode(hashlib.sha1(headers[b"Sec-WebSocket-Key"] + b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11").digest())
  72. connection.send(b"HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + accept + b"\r\n\r\n")