# SPDX-License-Identifier: AGPL-3.0-or-later
# pylint: disable=missing-module-docstring, global-statement

import asyncio
import threading
import concurrent.futures
from queue import SimpleQueue
from types import MethodType
from timeit import default_timer
from typing import Iterable, NamedTuple, Tuple, List, Dict, Union
from contextlib import contextmanager

import httpx
import anyio

from searx.extended_types import SXNG_Response

from .network import get_network, initialize, check_network_configuration  # pylint:disable=cyclic-import
from .client import get_loop
from .raise_for_httperror import raise_for_httperror


THREADLOCAL = threading.local()
"""Thread-local storage holding per-thread values (total_time, timeout, start_time, network)."""


def reset_time_for_thread():
    THREADLOCAL.total_time = 0


def get_time_for_thread():
    """Returns the thread's total HTTP time, or None if reset_time_for_thread() was never called."""
    return THREADLOCAL.__dict__.get('total_time')


def set_timeout_for_thread(timeout, start_time=None):
    THREADLOCAL.timeout = timeout
    THREADLOCAL.start_time = start_time


def set_context_network_name(network_name):
    THREADLOCAL.network = get_network(network_name)


def get_context_network():
    """If set, return the thread's network.

    If unset, return the value from :py:obj:`get_network`.
    """
    return THREADLOCAL.__dict__.get('network') or get_network()
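
# A minimal sketch (not part of the upstream module) of how these thread-local
# helpers are typically combined by a caller running requests in a worker
# thread; the timeout value and URL are illustrative assumptions:
#
#   reset_time_for_thread()                       # start accounting HTTP time for this thread
#   set_timeout_for_thread(3.0, default_timer())  # per-thread timeout budget and start time
#   response = get('https://example.org')         # uses the thread's network and timeout
#   spent = get_time_for_thread()                 # seconds spent inside HTTP calls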


@contextmanager
def _record_http_time():
    # pylint: disable=too-many-branches
    time_before_request = default_timer()
    start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
    try:
        yield start_time
    finally:
        # update total_time.
        # See get_time_for_thread() and reset_time_for_thread()
        if hasattr(THREADLOCAL, 'total_time'):
            time_after_request = default_timer()
            THREADLOCAL.total_time += time_after_request - time_before_request


def _get_timeout(start_time, kwargs):
    # pylint: disable=too-many-branches

    # timeout (httpx)
    if 'timeout' in kwargs:
        timeout = kwargs['timeout']
    else:
        timeout = getattr(THREADLOCAL, 'timeout', None)
        if timeout is not None:
            kwargs['timeout'] = timeout

    # 2 minute timeout for requests without an explicit timeout
    timeout = timeout or 120

    # adjust the actual timeout: add a small overhead, then subtract the time already elapsed
    timeout += 0.2  # overhead
    if start_time:
        timeout -= default_timer() - start_time

    return timeout
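
# Worked example (illustrative, not in the upstream source) of the timeout
# arithmetic above: with no explicit 'timeout' kwarg, a thread timeout of 10.0s
# set 3.0s ago puts 10.0 into kwargs['timeout'] for httpx, while the value
# returned for future.result() is 10.0 + 0.2 - 3.0 = 7.2s, keeping the thread
# roughly within its original 10s budget.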


def request(method, url, **kwargs) -> SXNG_Response:
    """same as requests/requests/api.py request(...)"""
    with _record_http_time() as start_time:
        network = get_context_network()
        timeout = _get_timeout(start_time, kwargs)
        future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
        try:
            return future.result(timeout)
        except concurrent.futures.TimeoutError as e:
            raise httpx.TimeoutException('Timeout', request=None) from e
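
# Minimal usage sketch (assumption: called from a worker thread while the
# networking loop from .client is running); the URL is a placeholder:
#
#   response = request('GET', 'https://example.org', timeout=5.0)
#   response.raise_for_status()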


def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
    """send multiple HTTP requests in parallel. Wait for all requests to finish."""
    with _record_http_time() as start_time:
        # send the requests
        network = get_context_network()
        loop = get_loop()
        future_list = []
        for request_desc in request_list:
            timeout = _get_timeout(start_time, request_desc.kwargs)
            future = asyncio.run_coroutine_threadsafe(
                network.request(request_desc.method, request_desc.url, **request_desc.kwargs), loop
            )
            future_list.append((future, timeout))

        # read the responses
        responses = []
        for future, timeout in future_list:
            try:
                responses.append(future.result(timeout))
            except concurrent.futures.TimeoutError:
                responses.append(httpx.TimeoutException('Timeout', request=None))
            except Exception as e:  # pylint: disable=broad-except
                responses.append(e)
        return responses


class Request(NamedTuple):
    """Request description for the multi_requests function"""

    method: str
    url: str
    kwargs: Dict[str, str] = {}

    @staticmethod
    def get(url, **kwargs):
        return Request('GET', url, kwargs)

    @staticmethod
    def options(url, **kwargs):
        return Request('OPTIONS', url, kwargs)

    @staticmethod
    def head(url, **kwargs):
        return Request('HEAD', url, kwargs)

    @staticmethod
    def post(url, **kwargs):
        return Request('POST', url, kwargs)

    @staticmethod
    def put(url, **kwargs):
        return Request('PUT', url, kwargs)

    @staticmethod
    def patch(url, **kwargs):
        return Request('PATCH', url, kwargs)

    @staticmethod
    def delete(url, **kwargs):
        return Request('DELETE', url, kwargs)
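
# Illustrative sketch (not part of the upstream module) combining the Request
# helpers with multi_requests(); URLs and payloads are placeholder assumptions:
#
#   results = multi_requests([
#       Request.get('https://example.org/a', timeout=3.0),
#       Request.post('https://example.org/b', data={'q': 'test'}),
#   ])
#   for result in results:
#       if isinstance(result, Exception):
#           continue  # a failed request is returned as its exception
#       print(result.status_code)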


def get(url, **kwargs) -> SXNG_Response:
    kwargs.setdefault('allow_redirects', True)
    return request('get', url, **kwargs)


def options(url, **kwargs) -> SXNG_Response:
    kwargs.setdefault('allow_redirects', True)
    return request('options', url, **kwargs)


def head(url, **kwargs) -> SXNG_Response:
    kwargs.setdefault('allow_redirects', False)
    return request('head', url, **kwargs)


def post(url, data=None, **kwargs) -> SXNG_Response:
    return request('post', url, data=data, **kwargs)


def put(url, data=None, **kwargs) -> SXNG_Response:
    return request('put', url, data=data, **kwargs)


def patch(url, data=None, **kwargs) -> SXNG_Response:
    return request('patch', url, data=data, **kwargs)


def delete(url, **kwargs) -> SXNG_Response:
    return request('delete', url, **kwargs)
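
# Caller-side sketch (an assumption about typical usage of these wrappers, not
# code from this module); the URL is a placeholder:
#
#   resp = get('https://example.org/search?q=test')  # follows redirects by default
#   raise_for_httperror(resp)                        # re-exported above for callers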


async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    try:
        async with await network.stream(method, url, **kwargs) as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for chunk in response.aiter_raw(65536):
                if len(chunk) > 0:
                    queue.put(chunk)
    except (httpx.StreamClosed, anyio.ClosedResourceError):
        # the response was queued before the exception.
        # the exception was raised on aiter_raw.
        # we do nothing here: in the finally block, None will be queued
        # so the stream(method, url, **kwargs) generator can stop
        pass
    except Exception as e:  # pylint: disable=broad-except
        # broad except to avoid this scenario:
        # exception in network.stream(method, url, **kwargs)
        # -> the exception is not caught here
        # -> queue None (in finally)
        # -> stream(method, url, **kwargs) below has nothing to return
        queue.put(e)
    finally:
        queue.put(None)
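
# Queue protocol between stream_chunk_to_queue (producer, running on the event
# loop) and _stream_generator (consumer, running in the calling thread), as the
# code above implements it: first the httpx response object, then raw byte
# chunks, an Exception if something went wrong, and finally a None sentinel to
# stop the consumer.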


def _stream_generator(method, url, **kwargs):
    queue = SimpleQueue()
    network = get_context_network()
    future = asyncio.run_coroutine_threadsafe(stream_chunk_to_queue(network, queue, method, url, **kwargs), get_loop())

    # yield chunks
    obj_or_exception = queue.get()
    while obj_or_exception is not None:
        if isinstance(obj_or_exception, Exception):
            raise obj_or_exception
        yield obj_or_exception
        obj_or_exception = queue.get()
    future.result()


def _close_response_method(self):
    asyncio.run_coroutine_threadsafe(self.aclose(), get_loop())
    # exhaust self._generator ( _stream_generator ) to avoid a memory leak.
    # it makes sure that:
    # * the httpx response is closed (see the stream_chunk_to_queue function)
    # * future.result() is called in _stream_generator
    for _ in self._generator:  # pylint: disable=protected-access
        continue


def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
    """Replace httpx.stream.

    Usage:
    response, stream = network.stream(...)
    for chunk in stream:
        ...

    httpx.Client.stream would require writing an httpx.HTTPTransport version of
    the httpx.AsyncHTTPTransport declared above.
    """
    generator = _stream_generator(method, url, **kwargs)

    # yield response
    response = next(generator)  # pylint: disable=stop-iteration-return
    if isinstance(response, Exception):
        raise response

    response._generator = generator  # pylint: disable=protected-access
    response.close = MethodType(_close_response_method, response)

    return response, generator
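
# Consumer-side sketch (an assumption about typical usage, expanding the
# docstring above); the URL is a placeholder and handle() is a hypothetical
# callback:
#
#   response, chunks = stream('GET', 'https://example.org/large-file')
#   try:
#       for chunk in chunks:
#           handle(chunk)     # process raw bytes as they arrive
#   finally:
#       response.close()      # closes the httpx response and drains the generator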