test_websockets.py 16 KB


#!/usr/bin/env python3
import os
import sys

import pytest

# Allow direct execution: put the repository root on sys.path *before* any
# project-local import (test.helper, yt_dlp) so those resolve to this
# repository even when the file is not run from the repository root.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import contextlib
import http.client
import http.cookiejar
import http.server
import json
import random
import ssl
import threading

from test.helper import verify_address_availability
from yt_dlp import socks
from yt_dlp.cookies import YoutubeDLCookieJar
from yt_dlp.dependencies import websockets
from yt_dlp.networking import Request
from yt_dlp.networking.exceptions import (
    CertificateVerifyError,
    HTTPError,
    ProxyError,
    RequestError,
    SSLError,
    TransportError,
)
from yt_dlp.utils.networking import HTTPHeaderDict
  28. TEST_DIR = os.path.dirname(os.path.abspath(__file__))
  29. def websocket_handler(websocket):
  30. for message in websocket:
  31. if isinstance(message, bytes):
  32. if message == b'bytes':
  33. return websocket.send('2')
  34. elif isinstance(message, str):
  35. if message == 'headers':
  36. return websocket.send(json.dumps(dict(websocket.request.headers)))
  37. elif message == 'path':
  38. return websocket.send(websocket.request.path)
  39. elif message == 'source_address':
  40. return websocket.send(websocket.remote_address[0])
  41. elif message == 'str':
  42. return websocket.send('1')
  43. return websocket.send(message)
  44. def process_request(self, request):
  45. if request.path.startswith('/gen_'):
  46. status = http.HTTPStatus(int(request.path[5:]))
  47. if 300 <= status.value <= 300:
  48. return websockets.http11.Response(
  49. status.value, status.phrase, websockets.datastructures.Headers([('Location', '/')]), b'')
  50. return self.protocol.reject(status.value, status.phrase)
  51. return self.protocol.accept(request)
  52. def create_websocket_server(**ws_kwargs):
  53. import websockets.sync.server
  54. wsd = websockets.sync.server.serve(
  55. websocket_handler, '127.0.0.1', 0,
  56. process_request=process_request, open_timeout=2, **ws_kwargs)
  57. ws_port = wsd.socket.getsockname()[1]
  58. ws_server_thread = threading.Thread(target=wsd.serve_forever)
  59. ws_server_thread.daemon = True
  60. ws_server_thread.start()
  61. return ws_server_thread, ws_port
  62. def create_ws_websocket_server():
  63. return create_websocket_server()
  64. def create_wss_websocket_server():
  65. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  66. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  67. sslctx.load_cert_chain(certfn, None)
  68. return create_websocket_server(ssl_context=sslctx)
  69. MTLS_CERT_DIR = os.path.join(TEST_DIR, 'testdata', 'certificate')
  70. def create_mtls_wss_websocket_server():
  71. certfn = os.path.join(TEST_DIR, 'testcert.pem')
  72. cacertfn = os.path.join(MTLS_CERT_DIR, 'ca.crt')
  73. sslctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
  74. sslctx.verify_mode = ssl.CERT_REQUIRED
  75. sslctx.load_verify_locations(cafile=cacertfn)
  76. sslctx.load_cert_chain(certfn, None)
  77. return create_websocket_server(ssl_context=sslctx)
  78. def ws_validate_and_send(rh, req):
  79. rh.validate(req)
  80. max_tries = 3
  81. for i in range(max_tries):
  82. try:
  83. return rh.send(req)
  84. except TransportError as e:
  85. if i < (max_tries - 1) and 'connection closed during handshake' in str(e):
  86. # websockets server sometimes hangs on new connections
  87. continue
  88. raise
  89. @pytest.mark.skipif(not websockets, reason='websockets must be installed to test websocket request handlers')
  90. class TestWebsSocketRequestHandlerConformance:
  91. @classmethod
  92. def setup_class(cls):
  93. cls.ws_thread, cls.ws_port = create_ws_websocket_server()
  94. cls.ws_base_url = f'ws://127.0.0.1:{cls.ws_port}'
  95. cls.wss_thread, cls.wss_port = create_wss_websocket_server()
  96. cls.wss_base_url = f'wss://127.0.0.1:{cls.wss_port}'
  97. cls.bad_wss_thread, cls.bad_wss_port = create_websocket_server(ssl_context=ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER))
  98. cls.bad_wss_host = f'wss://127.0.0.1:{cls.bad_wss_port}'
  99. cls.mtls_wss_thread, cls.mtls_wss_port = create_mtls_wss_websocket_server()
  100. cls.mtls_wss_base_url = f'wss://127.0.0.1:{cls.mtls_wss_port}'
  101. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  102. def test_basic_websockets(self, handler):
  103. with handler() as rh:
  104. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  105. assert 'upgrade' in ws.headers
  106. assert ws.status == 101
  107. ws.send('foo')
  108. assert ws.recv() == 'foo'
  109. ws.close()
  110. # https://www.rfc-editor.org/rfc/rfc6455.html#section-5.6
  111. @pytest.mark.parametrize('msg,opcode', [('str', 1), (b'bytes', 2)])
  112. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  113. def test_send_types(self, handler, msg, opcode):
  114. with handler() as rh:
  115. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  116. ws.send(msg)
  117. assert int(ws.recv()) == opcode
  118. ws.close()
  119. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  120. def test_verify_cert(self, handler):
  121. with handler() as rh:
  122. with pytest.raises(CertificateVerifyError):
  123. ws_validate_and_send(rh, Request(self.wss_base_url))
  124. with handler(verify=False) as rh:
  125. ws = ws_validate_and_send(rh, Request(self.wss_base_url))
  126. assert ws.status == 101
  127. ws.close()
  128. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  129. def test_ssl_error(self, handler):
  130. with handler(verify=False) as rh:
  131. with pytest.raises(SSLError, match=r'ssl(?:v3|/tls) alert handshake failure') as exc_info:
  132. ws_validate_and_send(rh, Request(self.bad_wss_host))
  133. assert not issubclass(exc_info.type, CertificateVerifyError)
  134. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  135. @pytest.mark.parametrize('path,expected', [
  136. # Unicode characters should be encoded with uppercase percent-encoding
  137. ('/中文', '/%E4%B8%AD%E6%96%87'),
  138. # don't normalize existing percent encodings
  139. ('/%c7%9f', '/%c7%9f'),
  140. ])
  141. def test_percent_encode(self, handler, path, expected):
  142. with handler() as rh:
  143. ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}{path}'))
  144. ws.send('path')
  145. assert ws.recv() == expected
  146. assert ws.status == 101
  147. ws.close()
  148. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  149. def test_remove_dot_segments(self, handler):
  150. with handler() as rh:
  151. # This isn't a comprehensive test,
  152. # but it should be enough to check whether the handler is removing dot segments
  153. ws = ws_validate_and_send(rh, Request(f'{self.ws_base_url}/a/b/./../../test'))
  154. assert ws.status == 101
  155. ws.send('path')
  156. assert ws.recv() == '/test'
  157. ws.close()
  158. # We are restricted to known HTTP status codes in http.HTTPStatus
  159. # Redirects are not supported for websockets
  160. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  161. @pytest.mark.parametrize('status', (200, 204, 301, 302, 303, 400, 500, 511))
  162. def test_raise_http_error(self, handler, status):
  163. with handler() as rh:
  164. with pytest.raises(HTTPError) as exc_info:
  165. ws_validate_and_send(rh, Request(f'{self.ws_base_url}/gen_{status}'))
  166. assert exc_info.value.status == status
  167. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  168. @pytest.mark.parametrize('params,extensions', [
  169. ({'timeout': sys.float_info.min}, {}),
  170. ({}, {'timeout': sys.float_info.min}),
  171. ])
  172. def test_timeout(self, handler, params, extensions):
  173. with handler(**params) as rh:
  174. with pytest.raises(TransportError):
  175. ws_validate_and_send(rh, Request(self.ws_base_url, extensions=extensions))
  176. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  177. def test_cookies(self, handler):
  178. cookiejar = YoutubeDLCookieJar()
  179. cookiejar.set_cookie(http.cookiejar.Cookie(
  180. version=0, name='test', value='ytdlp', port=None, port_specified=False,
  181. domain='127.0.0.1', domain_specified=True, domain_initial_dot=False, path='/',
  182. path_specified=True, secure=False, expires=None, discard=False, comment=None,
  183. comment_url=None, rest={}))
  184. with handler(cookiejar=cookiejar) as rh:
  185. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  186. ws.send('headers')
  187. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  188. ws.close()
  189. with handler() as rh:
  190. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  191. ws.send('headers')
  192. assert 'cookie' not in json.loads(ws.recv())
  193. ws.close()
  194. ws = ws_validate_and_send(rh, Request(self.ws_base_url, extensions={'cookiejar': cookiejar}))
  195. ws.send('headers')
  196. assert json.loads(ws.recv())['cookie'] == 'test=ytdlp'
  197. ws.close()
  198. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  199. def test_source_address(self, handler):
  200. source_address = f'127.0.0.{random.randint(5, 255)}'
  201. verify_address_availability(source_address)
  202. with handler(source_address=source_address) as rh:
  203. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  204. ws.send('source_address')
  205. assert source_address == ws.recv()
  206. ws.close()
  207. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  208. def test_response_url(self, handler):
  209. with handler() as rh:
  210. url = f'{self.ws_base_url}/something'
  211. ws = ws_validate_and_send(rh, Request(url))
  212. assert ws.url == url
  213. ws.close()
  214. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  215. def test_request_headers(self, handler):
  216. with handler(headers=HTTPHeaderDict({'test1': 'test', 'test2': 'test2'})) as rh:
  217. # Global Headers
  218. ws = ws_validate_and_send(rh, Request(self.ws_base_url))
  219. ws.send('headers')
  220. headers = HTTPHeaderDict(json.loads(ws.recv()))
  221. assert headers['test1'] == 'test'
  222. ws.close()
  223. # Per request headers, merged with global
  224. ws = ws_validate_and_send(rh, Request(
  225. self.ws_base_url, headers={'test2': 'changed', 'test3': 'test3'}))
  226. ws.send('headers')
  227. headers = HTTPHeaderDict(json.loads(ws.recv()))
  228. assert headers['test1'] == 'test'
  229. assert headers['test2'] == 'changed'
  230. assert headers['test3'] == 'test3'
  231. ws.close()
  232. @pytest.mark.parametrize('client_cert', (
  233. {'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithkey.crt')},
  234. {
  235. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
  236. 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'client.key'),
  237. },
  238. {
  239. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'clientwithencryptedkey.crt'),
  240. 'client_certificate_password': 'foobar',
  241. },
  242. {
  243. 'client_certificate': os.path.join(MTLS_CERT_DIR, 'client.crt'),
  244. 'client_certificate_key': os.path.join(MTLS_CERT_DIR, 'clientencrypted.key'),
  245. 'client_certificate_password': 'foobar',
  246. }
  247. ))
  248. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  249. def test_mtls(self, handler, client_cert):
  250. with handler(
  251. # Disable client-side validation of unacceptable self-signed testcert.pem
  252. # The test is of a check on the server side, so unaffected
  253. verify=False,
  254. client_cert=client_cert
  255. ) as rh:
  256. ws_validate_and_send(rh, Request(self.mtls_wss_base_url)).close()
  257. def create_fake_ws_connection(raised):
  258. import websockets.sync.client
  259. class FakeWsConnection(websockets.sync.client.ClientConnection):
  260. def __init__(self, *args, **kwargs):
  261. class FakeResponse:
  262. body = b''
  263. headers = {}
  264. status_code = 101
  265. reason_phrase = 'test'
  266. self.response = FakeResponse()
  267. def send(self, *args, **kwargs):
  268. raise raised()
  269. def recv(self, *args, **kwargs):
  270. raise raised()
  271. def close(self, *args, **kwargs):
  272. return
  273. return FakeWsConnection()
  274. @pytest.mark.parametrize('handler', ['Websockets'], indirect=True)
  275. class TestWebsocketsRequestHandler:
  276. @pytest.mark.parametrize('raised,expected', [
  277. # https://websockets.readthedocs.io/en/stable/reference/exceptions.html
  278. (lambda: websockets.exceptions.InvalidURI(msg='test', uri='test://'), RequestError),
  279. # Requires a response object. Should be covered by HTTP error tests.
  280. # (lambda: websockets.exceptions.InvalidStatus(), TransportError),
  281. (lambda: websockets.exceptions.InvalidHandshake(), TransportError),
  282. # These are subclasses of InvalidHandshake
  283. (lambda: websockets.exceptions.InvalidHeader(name='test'), TransportError),
  284. (lambda: websockets.exceptions.NegotiationError(), TransportError),
  285. # Catch-all
  286. (lambda: websockets.exceptions.WebSocketException(), TransportError),
  287. (lambda: TimeoutError(), TransportError),
  288. # These may be raised by our create_connection implementation, which should also be caught
  289. (lambda: OSError(), TransportError),
  290. (lambda: ssl.SSLError(), SSLError),
  291. (lambda: ssl.SSLCertVerificationError(), CertificateVerifyError),
  292. (lambda: socks.ProxyError(), ProxyError),
  293. ])
  294. def test_request_error_mapping(self, handler, monkeypatch, raised, expected):
  295. import websockets.sync.client
  296. import yt_dlp.networking._websockets
  297. with handler() as rh:
  298. def fake_connect(*args, **kwargs):
  299. raise raised()
  300. monkeypatch.setattr(yt_dlp.networking._websockets, 'create_connection', lambda *args, **kwargs: None)
  301. monkeypatch.setattr(websockets.sync.client, 'connect', fake_connect)
  302. with pytest.raises(expected) as exc_info:
  303. rh.send(Request('ws://fake-url'))
  304. assert exc_info.type is expected
  305. @pytest.mark.parametrize('raised,expected,match', [
  306. # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.send
  307. (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
  308. (lambda: RuntimeError(), TransportError, None),
  309. (lambda: TimeoutError(), TransportError, None),
  310. (lambda: TypeError(), RequestError, None),
  311. (lambda: socks.ProxyError(), ProxyError, None),
  312. # Catch-all
  313. (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
  314. ])
  315. def test_ws_send_error_mapping(self, handler, monkeypatch, raised, expected, match):
  316. from yt_dlp.networking._websockets import WebsocketsResponseAdapter
  317. ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
  318. with pytest.raises(expected, match=match) as exc_info:
  319. ws.send('test')
  320. assert exc_info.type is expected
  321. @pytest.mark.parametrize('raised,expected,match', [
  322. # https://websockets.readthedocs.io/en/stable/reference/sync/client.html#websockets.sync.client.ClientConnection.recv
  323. (lambda: websockets.exceptions.ConnectionClosed(None, None), TransportError, None),
  324. (lambda: RuntimeError(), TransportError, None),
  325. (lambda: TimeoutError(), TransportError, None),
  326. (lambda: socks.ProxyError(), ProxyError, None),
  327. # Catch-all
  328. (lambda: websockets.exceptions.WebSocketException(), TransportError, None),
  329. ])
  330. def test_ws_recv_error_mapping(self, handler, monkeypatch, raised, expected, match):
  331. from yt_dlp.networking._websockets import WebsocketsResponseAdapter
  332. ws = WebsocketsResponseAdapter(create_fake_ws_connection(raised), url='ws://fake-url')
  333. with pytest.raises(expected, match=match) as exc_info:
  334. ws.recv()
  335. assert exc_info.type is expected