CohereForAI.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. from __future__ import annotations
  2. import json
  3. from aiohttp import ClientSession, FormData
  4. from ...typing import AsyncResult, Messages
  5. from ...requests import raise_for_status
  6. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  7. from ..helper import format_prompt
  8. from ...providers.response import JsonConversation, TitleGeneration
  9. class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
  10. url = "https://cohereforai-c4ai-command.hf.space"
  11. conversation_url = f"{url}/conversation"
  12. working = True
  13. default_model = "command-r-plus-08-2024"
  14. models = [
  15. default_model,
  16. "command-r-08-2024",
  17. "command-r-plus",
  18. "command-r",
  19. "command-r7b-12-2024",
  20. ]
  21. model_aliases = {
  22. "command-r-plus": "command-r-plus-08-2024",
  23. "command-r": "command-r-08-2024",
  24. "command-r7b": "command-r7b-12-2024",
  25. }
  26. @classmethod
  27. async def create_async_generator(
  28. cls, model: str, messages: Messages,
  29. api_key: str = None,
  30. proxy: str = None,
  31. conversation: JsonConversation = None,
  32. return_conversation: bool = False,
  33. **kwargs
  34. ) -> AsyncResult:
  35. model = cls.get_model(model)
  36. headers = {
  37. "Origin": cls.url,
  38. "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:133.0) Gecko/20100101 Firefox/133.0",
  39. "Accept": "*/*",
  40. "Accept-Language": "en-US,en;q=0.5",
  41. "Referer": "https://cohereforai-c4ai-command.hf.space/",
  42. "Sec-Fetch-Dest": "empty",
  43. "Sec-Fetch-Mode": "cors",
  44. "Sec-Fetch-Site": "same-origin",
  45. "Priority": "u=4",
  46. }
  47. if api_key is not None:
  48. headers["Authorization"] = f"Bearer {api_key}"
  49. async with ClientSession(
  50. headers=headers,
  51. cookies=None if conversation is None else conversation.cookies
  52. ) as session:
  53. system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
  54. messages = [message for message in messages if message["role"] != "system"]
  55. inputs = format_prompt(messages) if conversation is None else messages[-1]["content"]
  56. if conversation is None or conversation.model != model or conversation.preprompt != system_prompt:
  57. data = {"model": model, "preprompt": system_prompt}
  58. async with session.post(cls.conversation_url, json=data, proxy=proxy) as response:
  59. await raise_for_status(response)
  60. conversation = JsonConversation(
  61. **await response.json(),
  62. **data,
  63. cookies={n: c.value for n, c in response.cookies.items()}
  64. )
  65. if return_conversation:
  66. yield conversation
  67. async with session.get(f"{cls.conversation_url}/{conversation.conversationId}/__data.json?x-sveltekit-invalidated=11", proxy=proxy) as response:
  68. await raise_for_status(response)
  69. node = json.loads((await response.text()).splitlines()[0])["nodes"][1]
  70. if node["type"] == "error":
  71. raise RuntimeError(node["error"])
  72. data = node["data"]
  73. message_id = data[data[data[data[0]["messages"]][-1]]["id"]]
  74. data = FormData()
  75. inputs = messages[-1]["content"]
  76. data.add_field(
  77. "data",
  78. json.dumps({"inputs": inputs, "id": message_id, "is_retry": False, "is_continue": False, "web_search": False, "tools": []}),
  79. content_type="application/json"
  80. )
  81. async with session.post(f"{cls.conversation_url}/{conversation.conversationId}", data=data, proxy=proxy) as response:
  82. await raise_for_status(response)
  83. async for chunk in response.content:
  84. try:
  85. data = json.loads(chunk)
  86. except (json.JSONDecodeError) as e:
  87. raise RuntimeError(f"Failed to read response: {chunk.decode(errors='replace')}", e)
  88. if data["type"] == "stream":
  89. yield data["token"].replace("\u0000", "")
  90. elif data["type"] == "title":
  91. yield TitleGeneration(data["title"])
  92. elif data["type"] == "finalAnswer":
  93. break