Koala.py

from __future__ import annotations

import json
from typing import AsyncGenerator, Optional, List, Dict, Union, Any

from aiohttp import ClientSession, BaseConnector, ClientResponse

from ...typing import Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_random_string, get_connector
from ...requests import raise_for_status


class Koala(AsyncGeneratorProvider, ProviderModelMixin):
    """Async provider for koala.sh's server-sent-events chat endpoint."""

    url = "https://koala.sh/chat"
    api_endpoint = "https://koala.sh/api/gpt/"
    working = False  # provider is currently marked as non-functional
    supports_message_history = True
    default_model = "gpt-4o-mini"

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: Optional[str] = None,
        connector: Optional[BaseConnector] = None,
        **kwargs: Any
    ) -> AsyncGenerator[Dict[str, Union[str, int, float, List[Dict[str, Any]], None]], None]:
        if not model:
            model = cls.default_model

        headers = {
            "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:122.0) Gecko/20100101 Firefox/122.0",
            "Accept": "text/event-stream",
            "Accept-Language": "de,en-US;q=0.7,en;q=0.3",
            "Accept-Encoding": "gzip, deflate, br",
            "Referer": cls.url,
            "Flag-Real-Time-Data": "false",
            "Visitor-ID": get_random_string(20),  # random per-session visitor token
            "Origin": "https://koala.sh",
            "Alt-Used": "koala.sh",
            "Sec-Fetch-Dest": "empty",
            "Sec-Fetch-Mode": "cors",
            "Sec-Fetch-Site": "same-origin",
            "TE": "trailers",
        }

        async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
            # The API takes the latest message as "input"; any system messages
            # are appended to it, since the endpoint has no system-role field.
            input_text = messages[-1]["content"]
            system_messages = " ".join(
                message["content"] for message in messages if message["role"] == "system"
            )
            if system_messages:
                input_text += f" {system_messages}"

            # Prior turns are split into parallel user/assistant history lists.
            data = {
                "input": input_text,
                "inputHistory": [
                    message["content"]
                    for message in messages[:-1]
                    if message["role"] == "user"
                ],
                "outputHistory": [
                    message["content"]
                    for message in messages
                    if message["role"] == "assistant"
                ],
                "model": model,
            }

            async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                await raise_for_status(response)
                async for chunk in cls._parse_event_stream(response):
                    yield chunk

    @staticmethod
    async def _parse_event_stream(response: ClientResponse) -> AsyncGenerator[Dict[str, Any], None]:
        # Each SSE line looks like b"data: <json>"; strip the 6-byte
        # "data: " prefix and decode the JSON payload.
        async for chunk in response.content:
            if chunk.startswith(b"data: "):
                yield json.loads(chunk[6:])
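

# Minimal usage sketch. Assumptions: this module sits in its usual
# g4f-style package so the relative imports above resolve, and Messages is
# a list of {"role": ..., "content": ...} dicts, as the access patterns in
# create_async_generator imply. Note that working = False marks the
# provider as currently non-functional, so the live request may fail.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        messages = [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Explain server-sent events in one line."},
        ]
        # Each yielded chunk is the decoded JSON payload of one
        # "data: ..." line from the event stream.
        async for chunk in Koala.create_async_generator(Koala.default_model, messages):
            print(chunk)

    asyncio.run(_demo())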