123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- # SPDX-License-Identifier: AGPL-3.0-or-later
- # pylint: disable=missing-module-docstring, global-statement
- import asyncio
- import threading
- import concurrent.futures
- from queue import SimpleQueue
- from types import MethodType
- from timeit import default_timer
- from typing import Iterable, NamedTuple, Tuple, List, Dict, Union
- from contextlib import contextmanager
- import httpx
- import anyio
- from .network import get_network, initialize, check_network_configuration # pylint:disable=cyclic-import
- from .client import get_loop
- from .raise_for_httperror import raise_for_httperror
THREADLOCAL = threading.local()
"""Per-thread storage used by this module.

The helpers below read and write the attributes ``total_time``, ``timeout``,
``start_time`` and ``network`` on this object; every thread sees its own values.
"""
def reset_time_for_thread():
    """Set the current thread's accumulated HTTP time back to zero.

    Once this has been called in a thread, ``_record_http_time`` starts adding
    the duration of each request to the counter.
    """
    THREADLOCAL.total_time = 0
def get_time_for_thread():
    """Return the HTTP time accumulated by the current thread.

    ``None`` is returned when ``total_time`` was never set in this thread.
    """
    return getattr(THREADLOCAL, 'total_time', None)
def set_timeout_for_thread(timeout, start_time=None):
    """Store the timeout and the reference start time for the current thread.

    :param timeout: default timeout picked up by ``_get_timeout`` when a
        request does not carry its own ``timeout`` argument.
    :param start_time: reference instant handed out by ``_record_http_time``;
        it is stored as given, including ``None``.
    """
    THREADLOCAL.start_time = start_time
    THREADLOCAL.timeout = timeout
def set_context_network_name(network_name):
    """Make ``get_network(network_name)`` the network of the current thread."""
    selected_network = get_network(network_name)
    THREADLOCAL.network = selected_network
def get_context_network():
    """Return the network to use in the current thread.

    The network stored by :py:obj:`set_context_network_name` is preferred;
    when none is stored (or it is falsy) the value of :py:obj:`get_network`
    called without argument is returned.
    """
    thread_network = getattr(THREADLOCAL, 'network', None)
    if thread_network:
        return thread_network
    return get_network()
@contextmanager
def _record_http_time():
    """Context manager measuring the time spent inside the ``with`` body.

    Yields the thread's ``start_time`` when that attribute exists (even if its
    value is ``None``), otherwise the instant at which the context was entered.
    On exit the elapsed time is added to the thread's ``total_time``, but only
    if that counter exists; see get_time_for_thread() and
    reset_time_for_thread().
    """
    entered_at = default_timer()
    reference_start = getattr(THREADLOCAL, 'start_time', entered_at)
    try:
        yield reference_start
    finally:
        # threads that never initialised the counter are not tracked
        if hasattr(THREADLOCAL, 'total_time'):
            THREADLOCAL.total_time += default_timer() - entered_at
- def _get_timeout(start_time, kwargs):
- # pylint: disable=too-many-branches
- # timeout (httpx)
- if 'timeout' in kwargs:
- timeout = kwargs['timeout']
- else:
- timeout = getattr(THREADLOCAL, 'timeout', None)
- if timeout is not None:
- kwargs['timeout'] = timeout
- # 2 minutes timeout for the requests without timeout
- timeout = timeout or 120
- # adjust actual timeout
- timeout += 0.2 # overhead
- if start_time:
- timeout -= default_timer() - start_time
- return timeout
def request(method, url, **kwargs):
    """same as requests/requests/api.py request(...)

    The coroutine ``network.request(method, url, **kwargs)`` is scheduled on
    the loop returned by ``get_loop()`` and the calling thread blocks until
    the result is available or the timeout computed by ``_get_timeout``
    expires.

    :raises httpx.TimeoutException: when the result is not available in time.
    """
    with _record_http_time() as start_time:
        network = get_context_network()
        timeout = _get_timeout(start_time, kwargs)
        future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
        try:
            return future.result(timeout)
        except concurrent.futures.TimeoutError as e:
            # nobody is going to read the result anymore: cancel the coroutine
            # instead of letting it run to completion on the event loop
            future.cancel()
            raise httpx.TimeoutException('Timeout', request=None) from e
def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
    """send multiple HTTP requests in parallel. Wait for all requests to finish.

    :param request_list: descriptions of the requests to send.
    :return: one item per request, in the same order: the value returned by
        ``network.request`` or the exception raised for that request (a
        ``httpx.TimeoutException`` when it did not finish in time).
    """
    with _record_http_time() as start_time:
        # send the requests
        network = get_context_network()
        loop = get_loop()
        pending = []
        for request_desc in request_list:
            timeout = _get_timeout(start_time, request_desc.kwargs)
            future = asyncio.run_coroutine_threadsafe(
                network.request(request_desc.method, request_desc.url, **request_desc.kwargs), loop
            )
            # keep an absolute deadline: the time spent waiting for the
            # previous futures must not extend the wait for this one
            pending.append((future, default_timer() + timeout))

        # read the responses
        responses = []
        for future, deadline in pending:
            try:
                responses.append(future.result(max(0, deadline - default_timer())))
            except concurrent.futures.TimeoutError:
                # the result is dropped: stop the coroutine on the event loop
                future.cancel()
                responses.append(httpx.TimeoutException('Timeout', request=None))
            except Exception as e:  # pylint: disable=broad-except
                responses.append(e)
        return responses
class _RequestFields(NamedTuple):
    """Tuple layout of :py:class:`Request` (method, url, kwargs)."""

    method: str
    url: str
    kwargs: Dict[str, object]


class Request(_RequestFields):
    """Request description for the multi_requests function

    ``method`` and ``url`` are the HTTP method and the target URL, ``kwargs``
    holds the keyword arguments forwarded with the request.  ``kwargs`` is
    optional; each instance created without it gets its own empty dict.
    """

    __slots__ = ()

    def __new__(cls, method, url, kwargs=None):
        # A literal ``{}`` default would be a single dict shared by every
        # Request built without kwargs, and the dict is written to when the
        # timeout is resolved: one request would leak its timeout to others.
        return super().__new__(cls, method, url, {} if kwargs is None else kwargs)

    @staticmethod
    def get(url, **kwargs):
        """Describe a GET request."""
        return Request('GET', url, kwargs)

    @staticmethod
    def options(url, **kwargs):
        """Describe an OPTIONS request."""
        return Request('OPTIONS', url, kwargs)

    @staticmethod
    def head(url, **kwargs):
        """Describe a HEAD request."""
        return Request('HEAD', url, kwargs)

    @staticmethod
    def post(url, **kwargs):
        """Describe a POST request."""
        return Request('POST', url, kwargs)

    @staticmethod
    def put(url, **kwargs):
        """Describe a PUT request."""
        return Request('PUT', url, kwargs)

    @staticmethod
    def patch(url, **kwargs):
        """Describe a PATCH request."""
        return Request('PATCH', url, kwargs)

    @staticmethod
    def delete(url, **kwargs):
        """Describe a DELETE request."""
        return Request('DELETE', url, kwargs)
def get(url, **kwargs):
    """Call :py:func:`request` with the ``get`` method.

    ``allow_redirects`` defaults to ``True`` when the caller does not set it.
    """
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('get', url, **kwargs)
def options(url, **kwargs):
    """Call :py:func:`request` with the ``options`` method.

    ``allow_redirects`` defaults to ``True`` when the caller does not set it.
    """
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = True
    return request('options', url, **kwargs)
def head(url, **kwargs):
    """Call :py:func:`request` with the ``head`` method.

    ``allow_redirects`` defaults to ``False`` when the caller does not set it.
    """
    if 'allow_redirects' not in kwargs:
        kwargs['allow_redirects'] = False
    return request('head', url, **kwargs)
def post(url, data=None, **kwargs):
    """Call :py:func:`request` with the ``post`` method.

    :param data: forwarded to :py:func:`request` as the ``data`` keyword.
    """
    return request('post', url, data=data, **kwargs)
def put(url, data=None, **kwargs):
    """Call :py:func:`request` with the ``put`` method.

    :param data: forwarded to :py:func:`request` as the ``data`` keyword.
    """
    return request('put', url, data=data, **kwargs)
def patch(url, data=None, **kwargs):
    """Call :py:func:`request` with the ``patch`` method.

    :param data: forwarded to :py:func:`request` as the ``data`` keyword.
    """
    return request('patch', url, data=data, **kwargs)
def delete(url, **kwargs):
    """Call :py:func:`request` with the ``delete`` method."""
    return request('delete', url, **kwargs)
async def stream_chunk_to_queue(network, queue, method, url, **kwargs):
    """Feed *queue* with a streamed response and its raw chunks.

    The items are queued in this order: the response object returned by
    ``network.stream(method, url, **kwargs)``, every non-empty chunk of
    ``response.aiter_raw(65536)``, the exception if one was raised, and
    finally ``None`` which tells the consumer that nothing else will come.
    """
    try:
        async with await network.stream(method, url, **kwargs) as response:
            queue.put(response)
            # aiter_raw: access the raw bytes on the response without applying any HTTP content decoding
            # https://www.python-httpx.org/quickstart/#streaming-responses
            async for raw_chunk in response.aiter_raw(65536):
                if not raw_chunk:
                    continue
                queue.put(raw_chunk)
    except (httpx.StreamClosed, anyio.ClosedResourceError):
        # The response was queued before the exception, which was raised
        # while iterating aiter_raw.  Nothing to do here: the finally block
        # queues None so that the consuming generator can stop.
        pass
    except Exception as exc:  # pylint: disable=broad-except
        # Broad except to avoid this scenario:
        #   network.stream(method, url, **kwargs) raises
        #   -> the exception is not caught here
        #   -> only None is queued (in finally)
        #   -> the consumer of the queue has nothing to return
        queue.put(exc)
    finally:
        queue.put(None)
def _stream_generator(method, url, **kwargs):
    """Yield what ``stream_chunk_to_queue`` queues for the given request.

    The coroutine is scheduled on the loop returned by ``get_loop()``; its
    items are yielded in order until ``None`` is read.  A queued exception is
    raised instead of being yielded.  When the queue is exhausted, the result
    of the coroutine is collected with ``future.result()``.
    """
    queue = SimpleQueue()
    producer = asyncio.run_coroutine_threadsafe(
        stream_chunk_to_queue(get_context_network(), queue, method, url, **kwargs), get_loop()
    )

    # yield the response, then the chunks
    while True:
        item = queue.get()
        if item is None:
            break
        if isinstance(item, Exception):
            raise item
        yield item

    producer.result()
def _close_response_method(self):
    """Replacement for the ``close`` method of a streamed response.

    Schedules ``self.aclose()`` on the loop returned by ``get_loop()``, then
    consumes what is left of ``self._generator``.
    """
    asyncio.run_coroutine_threadsafe(self.aclose(), get_loop())
    # Reach the end of self._generator (_stream_generator) to avoid a memory
    # leak.  It makes sure that:
    # * the httpx response is closed (see the stream_chunk_to_queue function)
    # * future.result() is called in _stream_generator
    remaining = self._generator  # pylint: disable=protected-access
    for _ in remaining:
        pass
def stream(method, url, **kwargs) -> Tuple[httpx.Response, Iterable[bytes]]:
    """Replace httpx.stream.

    Usage:
    response, stream = poolrequests.stream(...)
    for chunk in stream:
        ...

    The first item of the underlying generator is the response, the following
    ones are the chunks.  The returned response gets a ``close`` method that
    also exhausts the generator.

    Using httpx.Client.stream instead would require writing the
    httpx.HTTPTransport counterpart of the httpx.AsyncHTTPTransport used by
    this package.
    """
    chunks = _stream_generator(method, url, **kwargs)

    # the first item produced by the generator is the response
    first_item = next(chunks)  # pylint: disable=stop-iteration-return
    if isinstance(first_item, Exception):
        raise first_item

    # let response.close() stop the stream and drain the generator
    first_item._generator = chunks  # pylint: disable=protected-access
    first_item.close = MethodType(_close_response_method, first_item)

    return first_item, chunks
|