OpenaiAPI.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. from __future__ import annotations
  2. import json
  3. import requests
  4. from ..helper import filter_none
  5. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
  6. from ...typing import Union, Optional, AsyncResult, Messages, ImagesType
  7. from ...requests import StreamSession, raise_for_status
  8. from ...providers.response import FinishReason, ToolCalls, Usage
  9. from ...errors import MissingAuthError, ResponseError
  10. from ...image import to_data_uri
  11. from ... import debug
  class OpenaiAPI(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin):
      """Chat provider for the OpenAI HTTP API (or an API with the same shape).

      Everything here is reached through class/static methods; the class
      attributes below act as the provider's configuration and are always
      read through ``cls``.
      """

      label = "OpenAI API"
      url = "https://platform.openai.com"
      login_url = "https://platform.openai.com/settings/organization/api-keys"
      # Base URL used whenever a caller does not pass ``api_base``.
      api_base = "https://api.openai.com/v1"
      working = True
      # When true, create_async_generator refuses to run without an api_key.
      needs_auth = True
      supports_message_history = True
      supports_system_message = True
      default_model = ""
      # Used as the model list when the remote ``/models`` lookup fails.
      fallback_models = []

      @classmethod
      def get_models(cls, api_key: Optional[str] = None, api_base: Optional[str] = None) -> list[str]:
          """Return the available model ids, fetching them on first use.

          The result is stored on ``cls.models`` and reused for as long as that
          list is non-empty; an empty list triggers a new lookup.

          Args:
              api_key: Bearer token sent with the lookup, if given.
              api_base: Base URL to query; defaults to ``cls.api_base``.

          Returns:
              The sorted ids from ``GET {api_base}/models``, or
              ``cls.fallback_models`` if the lookup fails for any reason
              (the failure is logged through ``debug.log``, not raised).
          """
          if not cls.models:
              try:
                  headers = {}
                  if api_base is None:
                      api_base = cls.api_base
                  if api_key is not None:
                      headers["authorization"] = f"Bearer {api_key}"
                  # Strip a trailing slash, as create_async_generator does, so
                  # both "…/v1" and "…/v1/" produce the same URL.
                  response = requests.get(f"{api_base.rstrip('/')}/models", headers=headers)
                  raise_for_status(response)
                  payload = response.json()
                  cls.models = sorted(model.get("id") for model in payload.get("data"))
              except Exception as e:
                  # Deliberate best effort: listing models must never break the
                  # caller, so any failure degrades to the fallback list.
                  debug.log(e)
                  cls.models = cls.fallback_models
          return cls.models

      @classmethod
      async def create_async_generator(
          cls,
          model: str,
          messages: Messages,
          proxy: Optional[str] = None,
          timeout: int = 120,
          images: ImagesType = None,
          api_key: Optional[str] = None,
          api_endpoint: Optional[str] = None,
          api_base: Optional[str] = None,
          temperature: Optional[float] = None,
          max_tokens: Optional[int] = None,
          top_p: Optional[float] = None,
          stop: Union[str, list[str]] = None,
          stream: bool = False,
          headers: Optional[dict] = None,
          impersonate: Optional[str] = None,
          tools: Optional[list] = None,
          extra_data: Optional[dict] = None,
          **kwargs
      ) -> AsyncResult:
          """Send a chat-completions request and yield the response pieces.

          Args:
              model: Model name; resolved through ``cls.get_model``.
              messages: Conversation to send. Never modified by this method.
              proxy, timeout, impersonate: Passed through to ``StreamSession``.
              images: Optional iterable of ``(image, name)`` pairs; each image is
                  attached to the last message as a data-URI ``image_url`` part.
              api_key: Bearer token; required when ``cls.needs_auth`` is true.
              api_endpoint: Full URL to POST to; defaults to
                  ``{api_base}/chat/completions``.
              api_base: Base URL; defaults to ``cls.api_base``.
              temperature, max_tokens, top_p, stop, tools: Sent in the request
                  body under the same names.
              stream: Request a streamed response and parse ``data: `` lines.
              headers: Extra HTTP headers; they override the defaults.
              extra_data: Additional request-body fields (``None`` means none).
              **kwargs: Accepted and ignored.

          Yields:
              Text (``str``), and where present ``ToolCalls``, ``Usage`` and
              ``FinishReason`` objects.

          Raises:
              MissingAuthError: ``cls.needs_auth`` is true and no ``api_key``
                  was given.
              Whatever ``raise_for_status`` / ``cls.raise_error`` raise for a
              failed request — those helpers are defined outside this file.
          """
          if cls.needs_auth and api_key is None:
              raise MissingAuthError('Add a "api_key"')
          if api_base is None:
              api_base = cls.api_base
          if images is not None:
              if not model and hasattr(cls, "default_vision_model"):
                  model = cls.default_vision_model
              # Build a new list and a new last message instead of editing the
              # caller's objects: if the same ``messages`` are reused for another
              # call, the content must not get wrapped a second time.
              last_message = messages[-1]
              messages = [
                  *messages[:-1],
                  {
                      **last_message,
                      "content": [
                          *[{
                              "type": "image_url",
                              "image_url": {"url": to_data_uri(image)}
                          } for image, _ in images],
                          {
                              "type": "text",
                              "text": last_message["content"]
                          }
                      ]
                  }
              ]
          async with StreamSession(
              proxy=proxy,
              headers=cls.get_headers(stream, api_key, headers),
              timeout=timeout,
              impersonate=impersonate,
          ) as session:
              # filter_none is a project helper; going by its name it drops the
              # arguments left at None — confirm against its definition.
              request_body = filter_none(
                  messages=messages,
                  model=cls.get_model(model, api_key=api_key, api_base=api_base),
                  temperature=temperature,
                  max_tokens=max_tokens,
                  top_p=top_p,
                  stop=stop,
                  stream=stream,
                  tools=tools,
                  **(extra_data or {})
              )
              if api_endpoint is None:
                  api_endpoint = f"{api_base.rstrip('/')}/chat/completions"
              async with session.post(api_endpoint, json=request_body) as response:
                  await raise_for_status(response)
                  if not stream:
                      data = await response.json()
                      cls.raise_error(data)
                      choice = data["choices"][0]
                      if "content" in choice["message"] and choice["message"]["content"]:
                          yield choice["message"]["content"].strip()
                      elif "tool_calls" in choice["message"]:
                          yield ToolCalls(choice["message"]["tool_calls"])
                      if "usage" in data:
                          yield Usage(**data["usage"])
                      finish = cls.read_finish_reason(choice)
                      if finish is not None:
                          yield finish
                  else:
                      # True until the first non-empty piece of text has been
                      # yielded; leading whitespace is trimmed up to that point.
                      first = True
                      async for line in response.iter_lines():
                          if not line.startswith(b"data: "):
                              continue
                          chunk = line[6:]  # drop the b"data: " prefix
                          if chunk == b"[DONE]":
                              break
                          event = json.loads(chunk)
                          cls.raise_error(event)
                          choices = event.get("choices")
                          if not choices:
                              # Chunks may arrive without any choice (for example
                              # a usage-only chunk); indexing [0] would fail.
                              usage = event.get("usage")
                              if usage:
                                  yield Usage(**usage)
                              continue
                          choice = choices[0]
                          delta = (choice.get("delta") or {}).get("content")
                          if delta:
                              if first:
                                  delta = delta.lstrip()
                              if delta:
                                  first = False
                                  yield delta
                          finish = cls.read_finish_reason(choice)
                          if finish is not None:
                              yield finish

      @staticmethod
      def read_finish_reason(choice: dict) -> Optional[FinishReason]:
          """Wrap ``choice["finish_reason"]`` in a FinishReason.

          Returns None when the key is absent or its value is None.
          """
          reason = choice.get("finish_reason")
          if reason is not None:
              return FinishReason(reason)
          return None

      @classmethod
      def get_headers(cls, stream: bool, api_key: Optional[str] = None, headers: Optional[dict] = None) -> dict:
          """Build the HTTP headers for a request.

          Args:
              stream: Selects ``text/event-stream`` over ``application/json``
                  for the ``Accept`` header.
              api_key: Added as ``Authorization: Bearer …`` when not None.
              headers: Extra headers, applied last so they override the rest.

          Returns:
              A new dict; the ``headers`` argument is not modified.
          """
          result = {
              "Accept": "text/event-stream" if stream else "application/json",
              "Content-Type": "application/json",
          }
          if api_key is not None:
              result["Authorization"] = f"Bearer {api_key}"
          if headers is not None:
              result.update(headers)
          return result
  148. }