DeepInfra.py

from __future__ import annotations

import requests

from ...typing import AsyncResult, Messages
from ...requests import StreamSession, raise_for_status
from ...image import ImageResponse
from ..template import OpenaiTemplate

class DeepInfra(OpenaiTemplate):
    url = "https://deepinfra.com"
    login_url = "https://deepinfra.com/dash/api_keys"
    api_base = "https://api.deepinfra.com/v1/openai"
    working = True
    needs_auth = True
    default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
    default_image_model = "stabilityai/sd3.5"

    @classmethod
    def get_models(cls, **kwargs):
        if not cls.models:
            url = 'https://api.deepinfra.com/models/featured'
            response = requests.get(url)
            models = response.json()

            cls.models = []
            cls.image_models = []
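
            # Each featured-model entry is assumed to look roughly like
            # {"model_name": "...", "type": "...", "reported_type": "..."};
            # only these three fields are read below.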
            for model in models:
                if model["type"] == "text-generation":
                    cls.models.append(model['model_name'])
                elif model["reported_type"] == "text-to-image":
                    cls.image_models.append(model['model_name'])

            cls.models.extend(cls.image_models)

        return cls.models

    @classmethod
    def get_image_models(cls, **kwargs):
        if not cls.image_models:
            cls.get_models()
        return cls.image_models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        prompt: str = None,
        temperature: float = 0.7,
        max_tokens: int = 1028,
        **kwargs
    ) -> AsyncResult:
        if model in cls.get_image_models():
            yield cls.create_async_image(
                messages[-1]["content"] if prompt is None else prompt,
                model,
                **kwargs
            )
            return

        headers = {
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'en-US',
            'Origin': 'https://deepinfra.com',
            'Referer': 'https://deepinfra.com/',
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
            'X-Deepinfra-Source': 'web-embed',
        }
        async for chunk in super().create_async_generator(
            model, messages,
            stream=stream,
            temperature=temperature,
            max_tokens=max_tokens,
            headers=headers,
            **kwargs
        ):
            yield chunk
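
    # Text models flow through the inherited OpenAI-compatible endpoint at api_base;
    # image models are routed to create_async_image, which posts directly to the
    # DeepInfra inference API instead.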
    @classmethod
    async def create_async_image(
        cls,
        prompt: str,
        model: str,
        api_key: str = None,
        api_base: str = "https://api.deepinfra.com/v1/inference",
        proxy: str = None,
        timeout: int = 180,
        extra_data: dict = {},
        **kwargs
    ) -> ImageResponse:
        headers = {
            'Accept-Encoding': 'gzip, deflate, br',
            'Accept-Language': 'en-US',
            'Connection': 'keep-alive',
            'Origin': 'https://deepinfra.com',
            'Referer': 'https://deepinfra.com/',
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-site',
            'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
            'X-Deepinfra-Source': 'web-embed',
            'sec-ch-ua': '"Google Chrome";v="119", "Chromium";v="119", "Not?A_Brand";v="24"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        async with StreamSession(
            proxies={"all": proxy},
            headers=headers,
            timeout=timeout
        ) as session:
            model = cls.get_model(model)
            data = {"prompt": prompt, **extra_data}
            data = {"input": data} if model == cls.default_model else data
            async with session.post(f"{api_base.rstrip('/')}/{model}", json=data) as response:
                await raise_for_status(response)
                data = await response.json()
                images = data.get("output", data.get("images", data.get("image_url")))
                if not images:
                    raise RuntimeError(f"Response: {data}")
                images = images[0] if len(images) == 1 else images
                return ImageResponse(images, prompt)
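
# A minimal usage sketch (assumed, not part of the original module): driving the
# provider directly with asyncio. The g4f.Provider import path and the api_key
# keyword forwarded via **kwargs are assumptions; kept as comments because this
# module uses relative imports and is not meant to be run as a script.
#
#   import asyncio
#   from g4f.Provider import DeepInfra
#
#   async def main():
#       async for chunk in DeepInfra.create_async_generator(
#           model=DeepInfra.default_model,
#           messages=[{"role": "user", "content": "Hello"}],
#           stream=True,
#           api_key="<your DeepInfra API key>",  # needs_auth = True
#       ):
#           print(chunk, end="")
#
#   asyncio.run(main())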