CohereForAI.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from __future__ import annotations
  2. import json
  3. import uuid
  4. from aiohttp import ClientSession, FormData
  5. from ...typing import AsyncResult, Messages
  6. from ...requests import raise_for_status
  7. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  8. from ..helper import format_prompt
  9. from ...providers.response import JsonConversation, TitleGeneration
  10. class CohereForAI(AsyncGeneratorProvider, ProviderModelMixin):
  11. url = "https://cohereforai-c4ai-command.hf.space"
  12. conversation_url = f"{url}/conversation"
  13. working = True
  14. default_model = "command-r-plus-08-2024"
  15. models = [
  16. default_model,
  17. "command-r-08-2024",
  18. "command-r-plus",
  19. "command-r",
  20. "command-r7b-12-2024",
  21. ]
  22. model_aliases = {
  23. "command-r-plus": "command-r-plus-08-2024",
  24. "command-r": "command-r-08-2024",
  25. "command-r7b": "command-r7b-12-2024",
  26. }
  27. @classmethod
  28. async def create_async_generator(
  29. cls, model: str, messages: Messages,
  30. api_key: str = None,
  31. proxy: str = None,
  32. conversation: JsonConversation = None,
  33. return_conversation: bool = False,
  34. **kwargs
  35. ) -> AsyncResult:
  36. model = cls.get_model(model)
  37. headers = {
  38. "Origin": cls.url,
  39. "User-Agent": "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:133.0) Gecko/20100101 Firefox/133.0",
  40. "Accept": "*/*",
  41. "Accept-Language": "en-US,en;q=0.5",
  42. "Referer": "https://cohereforai-c4ai-command.hf.space/",
  43. "Sec-Fetch-Dest": "empty",
  44. "Sec-Fetch-Mode": "cors",
  45. "Sec-Fetch-Site": "same-origin",
  46. "Priority": "u=4",
  47. }
  48. if api_key is not None:
  49. headers["Authorization"] = f"Bearer {api_key}"
  50. async with ClientSession(
  51. headers=headers,
  52. cookies=None if conversation is None else conversation.cookies
  53. ) as session:
  54. system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
  55. messages = [message for message in messages if message["role"] != "system"]
  56. inputs = format_prompt(messages) if conversation is None else messages[-1]["content"]
  57. if conversation is None or conversation.model != model or conversation.preprompt != system_prompt:
  58. data = {"model": model, "preprompt": system_prompt}
  59. async with session.post(cls.conversation_url, json=data, proxy=proxy) as response:
  60. await raise_for_status(response)
  61. conversation = JsonConversation(
  62. **await response.json(),
  63. **data,
  64. cookies={n: c.value for n, c in response.cookies.items()}
  65. )
  66. if return_conversation:
  67. yield conversation
  68. async with session.get(f"{cls.conversation_url}/{conversation.conversationId}/__data.json?x-sveltekit-invalidated=11", proxy=proxy) as response:
  69. await raise_for_status(response)
  70. node = json.loads((await response.text()).splitlines()[0])["nodes"][1]
  71. if node["type"] == "error":
  72. raise RuntimeError(node["error"])
  73. data = node["data"]
  74. message_id = data[data[data[data[0]["messages"]][-1]]["id"]]
  75. data = FormData()
  76. inputs = messages[-1]["content"]
  77. data.add_field(
  78. "data",
  79. json.dumps({"inputs": inputs, "id": message_id, "is_retry": False, "is_continue": False, "web_search": False, "tools": []}),
  80. content_type="application/json"
  81. )
  82. async with session.post(f"{cls.conversation_url}/{conversation.conversationId}", data=data, proxy=proxy) as response:
  83. await raise_for_status(response)
  84. async for chunk in response.content:
  85. try:
  86. data = json.loads(chunk)
  87. except (json.JSONDecodeError) as e:
  88. raise RuntimeError(f"Failed to read response: {chunk.decode(errors='replace')}", e)
  89. if data["type"] == "stream":
  90. yield data["token"].replace("\u0000", "")
  91. elif data["type"] == "title":
  92. yield TitleGeneration(data["title"])
  93. elif data["type"] == "finalAnswer":
  94. break