__init__.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502
from __future__ import annotations

import os
import time
import random
import string
import asyncio
import base64
from typing import Union, AsyncIterator, Iterator, Coroutine, Optional, Any

from ..providers.base_provider import AsyncGeneratorProvider
from ..image import ImageResponse, copy_images, images_dir
from ..typing import Messages, Image, ImageType
from ..providers.types import ProviderType
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData
from ..errors import NoImageResponseError, ModelNotFoundError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import get_running_loop, to_sync_generator, async_generator_to_list
from ..Provider.needs_auth.BingCreateImages import BingCreateImages
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, get_last_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose, to_async_iterator
  23. ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
  24. AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
  25. try:
  26. anext # Python 3.8+
  27. except NameError:
  28. async def anext(aiter):
  29. try:
  30. return await aiter.__anext__()
  31. except StopAsyncIteration:
  32. raise StopIteration
  33. # Synchronous iter_response function
  34. def iter_response(
  35. response: Union[Iterator[Union[str, ResponseType]]],
  36. stream: bool,
  37. response_format: Optional[dict] = None,
  38. max_tokens: Optional[int] = None,
  39. stop: Optional[list[str]] = None
  40. ) -> ChatCompletionResponseType:
  41. content = ""
  42. finish_reason = None
  43. completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
  44. idx = 0
  45. if hasattr(response, '__aiter__'):
  46. response = to_sync_generator(response)
  47. for chunk in response:
  48. if isinstance(chunk, FinishReason):
  49. finish_reason = chunk.reason
  50. break
  51. elif isinstance(chunk, BaseConversation):
  52. yield chunk
  53. continue
  54. elif isinstance(chunk, SynthesizeData):
  55. continue
  56. chunk = str(chunk)
  57. content += chunk
  58. if max_tokens is not None and idx + 1 >= max_tokens:
  59. finish_reason = "length"
  60. first, content, chunk = find_stop(stop, content, chunk if stream else None)
  61. if first != -1:
  62. finish_reason = "stop"
  63. if stream:
  64. yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
  65. if finish_reason is not None:
  66. break
  67. idx += 1
  68. finish_reason = "stop" if finish_reason is None else finish_reason
  69. if stream:
  70. yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
  71. else:
  72. if response_format is not None and "type" in response_format:
  73. if response_format["type"] == "json_object":
  74. content = filter_json(content)
  75. yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
  76. # Synchronous iter_append_model_and_provider function
  77. def iter_append_model_and_provider(response: ChatCompletionResponseType) -> ChatCompletionResponseType:
  78. last_provider = None
  79. for chunk in response:
  80. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  81. last_provider = get_last_provider(True) if last_provider is None else last_provider
  82. chunk.model = last_provider.get("model")
  83. chunk.provider = last_provider.get("name")
  84. yield chunk
  85. async def async_iter_response(
  86. response: AsyncIterator[Union[str, ResponseType]],
  87. stream: bool,
  88. response_format: Optional[dict] = None,
  89. max_tokens: Optional[int] = None,
  90. stop: Optional[list[str]] = None
  91. ) -> AsyncChatCompletionResponseType:
  92. content = ""
  93. finish_reason = None
  94. completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
  95. idx = 0
  96. try:
  97. async for chunk in response:
  98. if isinstance(chunk, FinishReason):
  99. finish_reason = chunk.reason
  100. break
  101. elif isinstance(chunk, BaseConversation):
  102. yield chunk
  103. continue
  104. elif isinstance(chunk, SynthesizeData):
  105. continue
  106. chunk = str(chunk)
  107. content += chunk
  108. idx += 1
  109. if max_tokens is not None and idx >= max_tokens:
  110. finish_reason = "length"
  111. first, content, chunk = find_stop(stop, content, chunk if stream else None)
  112. if first != -1:
  113. finish_reason = "stop"
  114. if stream:
  115. yield ChatCompletionChunk(chunk, None, completion_id, int(time.time()))
  116. if finish_reason is not None:
  117. break
  118. finish_reason = "stop" if finish_reason is None else finish_reason
  119. if stream:
  120. yield ChatCompletionChunk(None, finish_reason, completion_id, int(time.time()))
  121. else:
  122. if response_format is not None and "type" in response_format:
  123. if response_format["type"] == "json_object":
  124. content = filter_json(content)
  125. yield ChatCompletion(content, finish_reason, completion_id, int(time.time()))
  126. finally:
  127. if hasattr(response, 'aclose'):
  128. await safe_aclose(response)
  129. async def async_iter_append_model_and_provider(
  130. response: AsyncChatCompletionResponseType
  131. ) -> AsyncChatCompletionResponseType:
  132. last_provider = None
  133. try:
  134. async for chunk in response:
  135. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  136. last_provider = get_last_provider(True) if last_provider is None else last_provider
  137. chunk.model = last_provider.get("model")
  138. chunk.provider = last_provider.get("name")
  139. yield chunk
  140. finally:
  141. if hasattr(response, 'aclose'):
  142. await safe_aclose(response)
  143. class Client(BaseClient):
  144. def __init__(
  145. self,
  146. provider: Optional[ProviderType] = None,
  147. image_provider: Optional[ImageProvider] = None,
  148. **kwargs
  149. ) -> None:
  150. super().__init__(**kwargs)
  151. self.chat: Chat = Chat(self, provider)
  152. self.images: Images = Images(self, image_provider)
  153. class Completions:
  154. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  155. self.client: Client = client
  156. self.provider: ProviderType = provider
  157. def create(
  158. self,
  159. messages: Messages,
  160. model: str,
  161. provider: Optional[ProviderType] = None,
  162. stream: Optional[bool] = False,
  163. proxy: Optional[str] = None,
  164. response_format: Optional[dict] = None,
  165. max_tokens: Optional[int] = None,
  166. stop: Optional[Union[list[str], str]] = None,
  167. api_key: Optional[str] = None,
  168. ignored: Optional[list[str]] = None,
  169. ignore_working: Optional[bool] = False,
  170. ignore_stream: Optional[bool] = False,
  171. **kwargs
  172. ) -> IterResponse:
  173. model, provider = get_model_and_provider(
  174. model,
  175. self.provider if provider is None else provider,
  176. stream,
  177. ignored,
  178. ignore_working,
  179. ignore_stream,
  180. )
  181. stop = [stop] if isinstance(stop, str) else stop
  182. response = provider.create_completion(
  183. model,
  184. messages,
  185. stream=stream,
  186. **filter_none(
  187. proxy=self.client.proxy if proxy is None else proxy,
  188. max_tokens=max_tokens,
  189. stop=stop,
  190. api_key=self.client.api_key if api_key is None else api_key
  191. ),
  192. **kwargs
  193. )
  194. if asyncio.iscoroutinefunction(provider.create_completion):
  195. # Run the asynchronous function in an event loop
  196. response = asyncio.run(response)
  197. if stream and hasattr(response, '__aiter__'):
  198. # It's an async generator, wrap it into a sync iterator
  199. response = to_sync_generator(response)
  200. elif hasattr(response, '__aiter__'):
  201. # If response is an async generator, collect it into a list
  202. response = asyncio.run(async_generator_to_list(response))
  203. response = iter_response(response, stream, response_format, max_tokens, stop)
  204. response = iter_append_model_and_provider(response)
  205. if stream:
  206. return response
  207. else:
  208. return next(response)
  209. class Chat:
  210. completions: Completions
  211. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  212. self.completions = Completions(client, provider)
  213. class Images:
  214. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  215. self.client: Client = client
  216. self.provider: Optional[ProviderType] = provider
  217. self.models: ImageModels = ImageModels(client)
  218. def generate(
  219. self,
  220. prompt: str,
  221. model: str = None,
  222. provider: Optional[ProviderType] = None,
  223. response_format: str = "url",
  224. proxy: Optional[str] = None,
  225. **kwargs
  226. ) -> ImagesResponse:
  227. """
  228. Synchronous generate method that runs the async_generate method in an event loop.
  229. """
  230. return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
  231. async def async_generate(
  232. self,
  233. prompt: str,
  234. model: Optional[str] = None,
  235. provider: Optional[ProviderType] = None,
  236. response_format: Optional[str] = "url",
  237. proxy: Optional[str] = None,
  238. **kwargs
  239. ) -> ImagesResponse:
  240. if provider is None:
  241. provider_handler = self.models.get(model, provider or self.provider or BingCreateImages)
  242. elif isinstance(provider, str):
  243. provider_handler = convert_to_provider(provider)
  244. else:
  245. provider_handler = provider
  246. if provider_handler is None:
  247. raise ModelNotFoundError(f"Unknown model: {model}")
  248. if isinstance(provider_handler, IterListProvider):
  249. if provider_handler.providers:
  250. provider_handler = provider_handler.providers[0]
  251. else:
  252. raise ModelNotFoundError(f"IterListProvider for model {model} has no providers")
  253. if proxy is None:
  254. proxy = self.client.proxy
  255. response = None
  256. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  257. messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
  258. async for item in provider_handler.create_async_generator(model, messages, prompt=prompt, **kwargs):
  259. if isinstance(item, ImageResponse):
  260. response = item
  261. break
  262. elif hasattr(provider_handler, 'create'):
  263. if asyncio.iscoroutinefunction(provider_handler.create):
  264. response = await provider_handler.create(prompt)
  265. else:
  266. response = provider_handler.create(prompt)
  267. if isinstance(response, str):
  268. response = ImageResponse([response], prompt)
  269. elif hasattr(provider_handler, "create_completion"):
  270. get_running_loop(check_nested=True)
  271. messages = [{"role": "user", "content": f"Generate a image: {prompt}"}]
  272. for item in provider_handler.create_completion(model, messages, prompt=prompt, **kwargs):
  273. if isinstance(item, ImageResponse):
  274. response = item
  275. break
  276. else:
  277. raise ValueError(f"Provider {provider} does not support image generation")
  278. if isinstance(response, ImageResponse):
  279. return await self._process_image_response(
  280. response,
  281. response_format,
  282. proxy,
  283. model,
  284. getattr(provider_handler, "__name__", None)
  285. )
  286. raise NoImageResponseError(f"Unexpected response type: {type(response)}")
  287. def create_variation(
  288. self,
  289. image: Union[str, bytes],
  290. model: str = None,
  291. provider: Optional[ProviderType] = None,
  292. response_format: str = "url",
  293. **kwargs
  294. ) -> ImagesResponse:
  295. return asyncio.run(self.async_create_variation(
  296. image, model, provider, response_format, **kwargs
  297. ))
  298. async def async_create_variation(
  299. self,
  300. image: ImageType,
  301. model: Optional[str] = None,
  302. provider: Optional[ProviderType] = None,
  303. response_format: str = "url",
  304. proxy: Optional[str] = None,
  305. **kwargs
  306. ) -> ImagesResponse:
  307. if provider is None:
  308. provider = self.models.get(model, provider or self.provider or BingCreateImages)
  309. if provider is None:
  310. raise ModelNotFoundError(f"Unknown model: {model}")
  311. if isinstance(provider, str):
  312. provider = convert_to_provider(provider)
  313. if proxy is None:
  314. proxy = self.client.proxy
  315. if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
  316. messages = [{"role": "user", "content": "create a variation of this image"}]
  317. generator = None
  318. try:
  319. generator = provider.create_async_generator(model, messages, image=image, response_format=response_format, proxy=proxy, **kwargs)
  320. async for chunk in generator:
  321. if isinstance(chunk, ImageResponse):
  322. response = chunk
  323. break
  324. finally:
  325. if generator and hasattr(generator, 'aclose'):
  326. await safe_aclose(generator)
  327. elif hasattr(provider, 'create_variation'):
  328. if asyncio.iscoroutinefunction(provider.create_variation):
  329. response = await provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
  330. else:
  331. response = provider.create_variation(image, model=model, response_format=response_format, proxy=proxy, **kwargs)
  332. else:
  333. raise NoImageResponseError(f"Provider {provider} does not support image variation")
  334. if isinstance(response, str):
  335. response = ImageResponse([response])
  336. if isinstance(response, ImageResponse):
  337. return self._process_image_response(response, response_format, proxy, model, getattr(provider, "__name__", None))
  338. raise NoImageResponseError(f"Unexpected response type: {type(response)}")
  339. async def _process_image_response(
  340. self,
  341. response: ImageResponse,
  342. response_format: str,
  343. proxy: str = None,
  344. model: Optional[str] = None,
  345. provider: Optional[str] = None
  346. ) -> list[Image]:
  347. if response_format in ("url", "b64_json"):
  348. images = await copy_images(response.get_list(), response.options.get("cookies"), proxy)
  349. async def process_image_item(image_file: str) -> Image:
  350. if response_format == "b64_json":
  351. with open(os.path.join(images_dir, os.path.basename(image_file)), "rb") as file:
  352. image_data = base64.b64encode(file.read()).decode()
  353. return Image(url=image_file, b64_json=image_data, revised_prompt=response.alt)
  354. return Image(url=image_file, revised_prompt=response.alt)
  355. images = await asyncio.gather(*[process_image_item(image) for image in images])
  356. else:
  357. images = [Image(url=image, revised_prompt=response.alt) for image in response.get_list()]
  358. last_provider = get_last_provider(True)
  359. return ImagesResponse(
  360. images,
  361. model=last_provider.get("model") if model is None else model,
  362. provider=last_provider.get("name") if provider is None else provider
  363. )
  364. class AsyncClient(BaseClient):
  365. def __init__(
  366. self,
  367. provider: Optional[ProviderType] = None,
  368. image_provider: Optional[ImageProvider] = None,
  369. **kwargs
  370. ) -> None:
  371. super().__init__(**kwargs)
  372. self.chat: AsyncChat = AsyncChat(self, provider)
  373. self.images: AsyncImages = AsyncImages(self, image_provider)
  374. class AsyncChat:
  375. completions: AsyncCompletions
  376. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  377. self.completions = AsyncCompletions(client, provider)
  378. class AsyncCompletions:
  379. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  380. self.client: AsyncClient = client
  381. self.provider: ProviderType = provider
  382. def create(
  383. self,
  384. messages: Messages,
  385. model: str,
  386. provider: Optional[ProviderType] = None,
  387. stream: Optional[bool] = False,
  388. proxy: Optional[str] = None,
  389. response_format: Optional[dict] = None,
  390. max_tokens: Optional[int] = None,
  391. stop: Optional[Union[list[str], str]] = None,
  392. api_key: Optional[str] = None,
  393. ignored: Optional[list[str]] = None,
  394. ignore_working: Optional[bool] = False,
  395. ignore_stream: Optional[bool] = False,
  396. **kwargs
  397. ) -> Union[Coroutine[ChatCompletion], AsyncIterator[ChatCompletionChunk, BaseConversation]]:
  398. model, provider = get_model_and_provider(
  399. model,
  400. self.provider if provider is None else provider,
  401. stream,
  402. ignored,
  403. ignore_working,
  404. ignore_stream,
  405. )
  406. stop = [stop] if isinstance(stop, str) else stop
  407. response = provider.create_completion(
  408. model,
  409. messages,
  410. stream=stream,
  411. **filter_none(
  412. proxy=self.client.proxy if proxy is None else proxy,
  413. max_tokens=max_tokens,
  414. stop=stop,
  415. api_key=self.client.api_key if api_key is None else api_key
  416. ),
  417. **kwargs
  418. )
  419. if not isinstance(response, AsyncIterator):
  420. response = to_async_iterator(response)
  421. response = async_iter_response(response, stream, response_format, max_tokens, stop)
  422. response = async_iter_append_model_and_provider(response)
  423. return response if stream else anext(response)
  424. class AsyncImages(Images):
  425. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  426. self.client: AsyncClient = client
  427. self.provider: Optional[ProviderType] = provider
  428. self.models: ImageModels = ImageModels(client)
  429. async def generate(
  430. self,
  431. prompt: str,
  432. model: Optional[str] = None,
  433. provider: Optional[ProviderType] = None,
  434. response_format: str = "url",
  435. **kwargs
  436. ) -> ImagesResponse:
  437. return await self.async_generate(prompt, model, provider, response_format, **kwargs)
  438. async def create_variation(
  439. self,
  440. image: ImageType,
  441. model: str = None,
  442. provider: ProviderType = None,
  443. response_format: str = "url",
  444. **kwargs
  445. ) -> ImagesResponse:
  446. return await self.async_create_variation(
  447. image, model, provider, response_format, **kwargs
  448. )