# AllenAI.py
  1. from __future__ import annotations
  2. import json
  3. from uuid import uuid4
  4. from aiohttp import ClientSession
  5. from ..typing import AsyncResult, Messages
  6. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  7. from ..requests.raise_for_status import raise_for_status
  8. from ..providers.response import FinishReason, JsonConversation
  9. from .helper import format_prompt, get_last_user_message
  10. class Conversation(JsonConversation):
  11. parent: str = None
  12. x_anonymous_user_id: str = None
  13. def __init__(self, model: str):
  14. super().__init__() # Ensure parent class is initialized
  15. self.model = model
  16. self.messages = [] # Instance-specific list
  17. if not self.x_anonymous_user_id:
  18. self.x_anonymous_user_id = str(uuid4())
  19. class AllenAI(AsyncGeneratorProvider, ProviderModelMixin):
  20. label = "Ai2 Playground"
  21. url = "https://playground.allenai.org"
  22. login_url = None
  23. api_endpoint = "https://olmo-api.allen.ai/v4/message/stream"
  24. working = True
  25. needs_auth = False
  26. use_nodriver = False
  27. supports_stream = True
  28. supports_system_message = False
  29. supports_message_history = True
  30. default_model = 'tulu3-405b'
  31. models = [
  32. default_model,
  33. 'OLMo-2-1124-13B-Instruct',
  34. 'tulu-3-1-8b',
  35. 'Llama-3-1-Tulu-3-70B',
  36. 'olmoe-0125'
  37. ]
  38. model_aliases = {
  39. "tulu-3-405b": default_model,
  40. "olmo-2-13b": "OLMo-2-1124-13B-Instruct",
  41. "tulu-3-1-8b": "tulu-3-1-8b",
  42. "tulu-3-70b": "Llama-3-1-Tulu-3-70B",
  43. }
  44. @classmethod
  45. async def create_async_generator(
  46. cls,
  47. model: str,
  48. messages: Messages,
  49. proxy: str = None,
  50. host: str = "inferd",
  51. private: bool = True,
  52. top_p: float = None,
  53. temperature: float = None,
  54. conversation: Conversation = None,
  55. return_conversation: bool = False,
  56. **kwargs
  57. ) -> AsyncResult:
  58. prompt = format_prompt(messages) if conversation is None else get_last_user_message(messages)
  59. # Initialize or update conversation
  60. if conversation is None:
  61. conversation = Conversation(model)
  62. # Generate new boundary for each request
  63. boundary = f"----WebKitFormBoundary{uuid4().hex}"
  64. headers = {
  65. "accept": "*/*",
  66. "accept-language": "en-US,en;q=0.9",
  67. "content-type": f"multipart/form-data; boundary={boundary}",
  68. "origin": cls.url,
  69. "referer": f"{cls.url}/",
  70. "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
  71. "x-anonymous-user-id": conversation.x_anonymous_user_id,
  72. }
  73. # Build multipart form data
  74. form_data = [
  75. f'--{boundary}\r\n'
  76. f'Content-Disposition: form-data; name="model"\r\n\r\n{cls.get_model(model)}\r\n',
  77. f'--{boundary}\r\n'
  78. f'Content-Disposition: form-data; name="host"\r\n\r\n{host}\r\n',
  79. f'--{boundary}\r\n'
  80. f'Content-Disposition: form-data; name="content"\r\n\r\n{prompt}\r\n',
  81. f'--{boundary}\r\n'
  82. f'Content-Disposition: form-data; name="private"\r\n\r\n{str(private).lower()}\r\n'
  83. ]
  84. # Add parent if exists in conversation
  85. if conversation.parent:
  86. form_data.append(
  87. f'--{boundary}\r\n'
  88. f'Content-Disposition: form-data; name="parent"\r\n\r\n{conversation.parent}\r\n'
  89. )
  90. # Add optional parameters
  91. if temperature is not None:
  92. form_data.append(
  93. f'--{boundary}\r\n'
  94. f'Content-Disposition: form-data; name="temperature"\r\n\r\n{temperature}\r\n'
  95. )
  96. if top_p is not None:
  97. form_data.append(
  98. f'--{boundary}\r\n'
  99. f'Content-Disposition: form-data; name="top_p"\r\n\r\n{top_p}\r\n'
  100. )
  101. form_data.append(f'--{boundary}--\r\n')
  102. data = "".join(form_data).encode()
  103. async with ClientSession(headers=headers) as session:
  104. async with session.post(
  105. cls.api_endpoint,
  106. data=data,
  107. proxy=proxy,
  108. ) as response:
  109. await raise_for_status(response)
  110. current_parent = None
  111. async for chunk in response.content:
  112. if not chunk:
  113. continue
  114. decoded = chunk.decode(errors="ignore")
  115. for line in decoded.splitlines():
  116. line = line.strip()
  117. if not line:
  118. continue
  119. try:
  120. data = json.loads(line)
  121. except json.JSONDecodeError:
  122. continue
  123. if isinstance(data, dict):
  124. # Update the parental ID
  125. if data.get("children"):
  126. for child in data["children"]:
  127. if child.get("role") == "assistant":
  128. current_parent = child.get("id")
  129. break
  130. # We process content only from the assistant
  131. if "message" in data and data.get("content"):
  132. content = data["content"]
  133. # Skip empty content blocks
  134. if content.strip():
  135. yield content
  136. # Processing the final response
  137. if data.get("final") or data.get("finish_reason") == "stop":
  138. if current_parent:
  139. conversation.parent = current_parent
  140. # Add a message to the story
  141. conversation.messages.extend([
  142. {"role": "user", "content": prompt},
  143. {"role": "assistant", "content": content}
  144. ])
  145. if return_conversation:
  146. yield conversation
  147. yield FinishReason("stop")
  148. return