DeepInfraChat.py

from __future__ import annotations

import json

from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin

class DeepInfraChat(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://deepinfra.com/chat"
    api_endpoint = "https://api.deepinfra.com/v1/openai/chat/completions"
    working = True
    needs_auth = False
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = 'meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo'
    models = [
        'meta-llama/Llama-3.3-70B-Instruct',
        'meta-llama/Meta-Llama-3.1-8B-Instruct',
        'meta-llama/Llama-3.3-70B-Instruct-Turbo',
        default_model,
        'Qwen/QwQ-32B-Preview',
        'microsoft/WizardLM-2-8x22B',
        'Qwen/Qwen2.5-72B-Instruct',
        'Qwen/Qwen2.5-Coder-32B-Instruct',
        'nvidia/Llama-3.1-Nemotron-70B-Instruct',
    ]
    model_aliases = {
        "llama-3.3-70b": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
        "llama-3.1-8b": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "llama-3.1-70b": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
        "qwq-32b": "Qwen/QwQ-32B-Preview",
        "wizardlm-2-8x22b": "microsoft/WizardLM-2-8x22B",
        "qwen-2-72b": "Qwen/Qwen2.5-72B-Instruct",
        "qwen-2.5-coder-32b": "Qwen/Qwen2.5-Coder-32B-Instruct",
        "nemotron-70b": "nvidia/Llama-3.1-Nemotron-70B-Instruct",
    }
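
    # The short keys in model_aliases are what callers pass as "model";
    # cls.get_model() in create_async_generator below resolves them to the
    # full DeepInfra model ids listed in "models".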
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)
        headers = {
            'Accept': 'text/event-stream',
            'Accept-Language': 'en-US,en;q=0.9',
            'Content-Type': 'application/json',
            'Origin': 'https://deepinfra.com',
            'Referer': 'https://deepinfra.com/',
            'X-Deepinfra-Source': 'web-page',
        }
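        # The endpoint speaks the OpenAI chat-completions protocol; with
        # "stream": True it answers as server-sent events, one "data: {json}"
        # line per content delta, which the loop below parses.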
        async with ClientSession(headers=headers) as session:
            data = {
                "model": model,
                "messages": messages,
                "stream": True
            }
            async with session.post(cls.api_endpoint, json=data, proxy=proxy) as response:
                await raise_for_status(response)
                async for chunk in response.content:
                    if chunk:
                        chunk_text = chunk.decode(errors="ignore")
                        try:
                            # Streamed responses arrive as SSE lines: "data: {json}"
                            if chunk_text.startswith("data: "):
                                if chunk_text.strip() == "data: [DONE]":
                                    continue
                                chunk_data = json.loads(chunk_text[6:])
                                content = chunk_data["choices"][0]["delta"].get("content")
                                if content:
                                    yield content
                            # Fallback: a plain, non-streamed JSON completion body
                            else:
                                chunk_data = json.loads(chunk_text)
                                content = chunk_data["choices"][0]["message"].get("content")
                                if content:
                                    yield content
                        except (json.JSONDecodeError, KeyError):
                            continue
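
# Usage sketch (illustrative, not part of the original provider): drives the
# streaming generator end to end. It assumes the file sits in its usual
# package layout so the relative imports resolve (run it as a module, e.g.
# with python -m), and that "messages" uses OpenAI-style role/content dicts.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        messages = [{"role": "user", "content": "Say hello in one sentence."}]
        async for token in DeepInfraChat.create_async_generator(
            DeepInfraChat.default_model, messages
        ):
            print(token, end="", flush=True)
        print()

    asyncio.run(_demo())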