DeepInfraChat.py

from __future__ import annotations

import json

from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ...requests.raise_for_status import raise_for_status
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin):
    """Async generator provider for DeepInfra's OpenAI-compatible chat API."""

    url = "https://deepinfra.com/chat"
    api_endpoint = "https://api.deepinfra.com/v1/openai/chat/completions"

    working = False
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = 'meta-llama/Llama-3.3-70B-Instruct-Turbo'
    models = [
        'meta-llama/Llama-3.3-70B-Instruct',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
        default_model,
        'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo',
        'Qwen/QwQ-32B-Preview',
        'microsoft/WizardLM-2-8x22B',
        'Qwen/Qwen2.5-72B-Instruct',
        'Qwen/Qwen2.5-Coder-32B-Instruct',
        'nvidia/Llama-3.1-Nemotron-70B-Instruct',
    ]
    model_aliases = {
        # Short names resolved to full DeepInfra model ids by get_model().
        "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "qwq-32b": "Qwen/QwQ-32B-Preview",
        "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
        "qwen-2-72b": "Qwen/Qwen2.5-72B-Instruct",
        "qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
        "nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct",
    }
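
    # Alias resolution: get_model() is inherited from ProviderModelMixin.
    # A minimal sketch of the expected behavior (an assumption about the
    # mixin, shown for illustration; it is not defined in this file):
    #
    #     DeepInfraChat.get_model("qwq-32b")
    #     # -> "Qwen/QwQ-32B-Preview"
    #     DeepInfraChat.get_model("")   # unknown/empty names fall back
    #     # -> "meta-llama/Llama-3.3-70B-Instruct-Turbo" (default_model)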
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)
        headers = {
            'Accept-Language': 'en-US,en;q=0.9',
            'Content-Type': 'application/json',
            'Origin': 'https://deepinfra.com',
            'Referer': 'https://deepinfra.com/',
            'X-Deepinfra-Source': 'web-page',
            'accept': 'text/event-stream',
        }
        async with ClientSession(headers=headers) as session:
            data = {
                "model": model,
                "messages": messages,
                "stream": True
            }
            async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                await raise_for_status(response)
                # aiohttp's StreamReader iterates the body one line at a time,
                # which matches the newline-delimited SSE frames the API sends.
                async for chunk in response.content:
                    if chunk:
                        chunk_text = chunk.decode(errors="ignore")
                        try:
                            # Streaming frame: "data: {...}" carrying a delta payload.
                            if chunk_text.startswith("data: "):
                                if chunk_text.strip() == "data: [DONE]":
                                    continue
                                chunk_data = json.loads(chunk_text[6:])
                                content = chunk_data["choices"][0]["delta"].get("content")
                                if content:
                                    yield content
                            # Non-streaming fallback: a single JSON completion object.
                            else:
                                chunk_data = json.loads(chunk_text)
                                content = chunk_data["choices"][0]["message"].get("content")
                                if content:
                                    yield content
                        except (json.JSONDecodeError, KeyError):
                            # Skip keep-alives and partial or malformed frames.
                            continue
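
# Usage sketch, kept commented out because the relative imports above only
# resolve when this module lives inside its package. Assumes the endpoint is
# reachable and the provider is enabled (working = True); illustrative only.
#
#     import asyncio
#
#     async def main() -> None:
#         messages = [{"role": "user", "content": "Say hello."}]
#         async for token in DeepInfraChat.create_async_generator(
#             model="llama-3.1-8b",  # alias, resolved to the full model id
#             messages=messages,
#         ):
#             print(token, end="", flush=True)
#
#     asyncio.run(main())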