# DeepInfra.py
  1. from __future__ import annotations
  2. import requests
  3. from ...typing import AsyncResult, Messages
  4. from .OpenaiAPI import OpenaiAPI
  5. from ...requests import StreamSession, raise_for_status
  6. from ...image import ImageResponse
  7. class DeepInfra(OpenaiAPI):
  8. label = "DeepInfra"
  9. url = "https://deepinfra.com"
  10. login_url = "https://deepinfra.com/dash/api_keys"
  11. working = True
  12. api_base = "https://api.deepinfra.com/v1/openai",
  13. needs_auth = True
  14. supports_stream = True
  15. supports_message_history = True
  16. default_model = "meta-llama/Meta-Llama-3.1-70B-Instruct"
  17. default_image_model = ''
  18. image_models = [default_image_model]
  19. @classmethod
  20. def get_models(cls, **kwargs):
  21. if not cls.models:
  22. url = 'https://api.deepinfra.com/models/featured'
  23. models = requests.get(url).json()
  24. cls.models = [model['model_name'] for model in models if model["type"] == "text-generation"]
  25. cls.image_models = [model['model_name'] for model in models if model["reported_type"] == "text-to-image"]
  26. return cls.models
  27. @classmethod
  28. def create_async_generator(
  29. cls,
  30. model: str,
  31. messages: Messages,
  32. stream: bool = True,
  33. temperature: float = 0.7,
  34. max_tokens: int = 1028,
  35. prompt: str = None,
  36. **kwargs
  37. ) -> AsyncResult:
  38. headers = {
  39. 'Accept-Encoding': 'gzip, deflate, br',
  40. 'Accept-Language': 'en-US',
  41. 'Origin': 'https://deepinfra.com',
  42. 'Referer': 'https://deepinfra.com/',
  43. 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36',
  44. 'X-Deepinfra-Source': 'web-embed',
  45. }
  46. # Check if the model is an image model
  47. if model in cls.image_models:
  48. return cls.create_image_generator(messages[-1]["content"] if prompt is None else prompt, model, headers=headers, **kwargs)
  49. # Text generation
  50. return super().create_async_generator(
  51. model, messages,
  52. stream=stream,
  53. temperature=temperature,
  54. max_tokens=max_tokens,
  55. headers=headers,
  56. **kwargs
  57. )
  58. @classmethod
  59. async def create_image_generator(
  60. cls,
  61. prompt: str,
  62. model: str,
  63. api_key: str = None,
  64. api_base: str = "https://api.deepinfra.com/v1/inference",
  65. proxy: str = None,
  66. timeout: int = 180,
  67. headers: dict = None,
  68. extra_data: dict = {},
  69. **kwargs
  70. ) -> AsyncResult:
  71. if api_key is not None and headers is not None:
  72. headers["Authorization"] = f"Bearer {api_key}"
  73. async with StreamSession(
  74. proxies={"all": proxy},
  75. headers=headers,
  76. timeout=timeout
  77. ) as session:
  78. model = cls.get_model(model)
  79. data = {"prompt": prompt, **extra_data}
  80. data = {"input": data} if model == cls.default_image_model else data
  81. async with session.post(f"{api_base.rstrip('/')}/{model}", json=data) as response:
  82. await raise_for_status(response)
  83. data = await response.json()
  84. images = data.get("output", data.get("images", data.get("image_url")))
  85. if not images:
  86. raise RuntimeError(f"Response: {data}")
  87. images = images[0] if len(images) == 1 else images
  88. yield ImageResponse(images, prompt)