from __future__ import annotations

import os
import time
import random
import string
import asyncio
import aiohttp
import base64
from typing import Union, AsyncIterator, Iterator, Awaitable, Optional

from ..image import ImageResponse, copy_images
from ..typing import Messages, ImageType
from ..providers.types import ProviderType, BaseRetryProvider
from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
from ..errors import NoImageResponseError
from ..providers.retry_provider import IterListProvider
from ..providers.asyncio import to_sync_generator
from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
from ..tools.run_tools import async_iter_run_tools, iter_run_tools
from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
from .image_models import ImageModels
from .types import IterResponse, ImageProvider, Client as BaseClient
from .service import get_model_and_provider, convert_to_provider
from .helper import find_stop, filter_json, filter_none, safe_aclose
from .. import debug

ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]

try:
    anext  # Built-in since Python 3.10
except NameError:
    async def anext(aiter):
        try:
            return await aiter.__anext__()
        except StopAsyncIteration:
            raise StopIteration

# Synchronous iter_response function
def iter_response(
    response: Iterator[Union[str, ResponseType]],
    stream: bool,
    response_format: Optional[dict] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None
) -> ChatCompletionResponseType:
    content = ""
    finish_reason = None
    tool_calls = None
    usage = None
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0

    if hasattr(response, '__aiter__'):
        response = to_sync_generator(response)

    for chunk in response:
        if isinstance(chunk, FinishReason):
            finish_reason = chunk.reason
            break
        elif isinstance(chunk, ToolCalls):
            tool_calls = chunk.get_list()
            continue
        elif isinstance(chunk, Usage):
            usage = chunk
            continue
        elif isinstance(chunk, BaseConversation):
            yield chunk
            continue
        elif isinstance(chunk, SynthesizeData) or not chunk:
            continue
        elif isinstance(chunk, Exception):
            continue

        chunk = str(chunk)
        content += chunk

        if max_tokens is not None and idx + 1 >= max_tokens:
            finish_reason = "length"

        first, content, chunk = find_stop(stop, content, chunk if stream else None)

        if first != -1:
            finish_reason = "stop"

        if stream:
            yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))

        if finish_reason is not None:
            break

        idx += 1

    if usage is None:
        usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)

    finish_reason = "stop" if finish_reason is None else finish_reason

    if stream:
        yield ChatCompletionChunk.model_construct(
            None, finish_reason, completion_id, int(time.time()),
            usage=usage.get_dict()
        )
    else:
        if response_format is not None and "type" in response_format:
            if response_format["type"] == "json_object":
                content = filter_json(content)
        yield ChatCompletion.model_construct(
            content, finish_reason, completion_id, int(time.time()),
            usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
        )
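
# A minimal sketch of driving iter_response by hand. The chunks below are
# illustrative, not real provider output, and the result fields assume the
# OpenAI-style stubs from .stubs:
#
#     def fake_stream():
#         yield "Hello, "
#         yield "world!"
#         yield FinishReason("stop")
#
#     completion = next(iter_response(fake_stream(), stream=False))
#     print(completion.choices[0].message.content)  # -> "Hello, world!"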

# Synchronous iter_append_model_and_provider function
def iter_append_model_and_provider(response: ChatCompletionResponseType, last_model: str, last_provider: ProviderType) -> ChatCompletionResponseType:
    if isinstance(last_provider, BaseRetryProvider):
        last_provider = last_provider.last_provider
    for chunk in response:
        if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
            if last_provider is not None:
                chunk.model = getattr(last_provider, "last_model", last_model)
                chunk.provider = last_provider.__name__
        yield chunk

async def async_iter_response(
    response: AsyncIterator[Union[str, ResponseType]],
    stream: bool,
    response_format: Optional[dict] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None
) -> AsyncChatCompletionResponseType:
    content = ""
    finish_reason = None
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0
    tool_calls = None
    usage = None

    try:
        async for chunk in response:
            if isinstance(chunk, FinishReason):
                finish_reason = chunk.reason
                break
            elif isinstance(chunk, BaseConversation):
                yield chunk
                continue
            elif isinstance(chunk, ToolCalls):
                tool_calls = chunk.get_list()
                continue
            elif isinstance(chunk, Usage):
                usage = chunk
                continue
            elif isinstance(chunk, SynthesizeData) or not chunk:
                continue
            elif isinstance(chunk, Exception):
                continue

            chunk = str(chunk)
            content += chunk
            idx += 1

            if max_tokens is not None and idx >= max_tokens:
                finish_reason = "length"

            first, content, chunk = find_stop(stop, content, chunk if stream else None)

            if first != -1:
                finish_reason = "stop"

            if stream:
                yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))

            if finish_reason is not None:
                break

        finish_reason = "stop" if finish_reason is None else finish_reason

        if usage is None:
            usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)

        if stream:
            yield ChatCompletionChunk.model_construct(
                None, finish_reason, completion_id, int(time.time()),
                usage=usage.get_dict()
            )
        else:
            if response_format is not None and "type" in response_format:
                if response_format["type"] == "json_object":
                    content = filter_json(content)
            yield ChatCompletion.model_construct(
                content, finish_reason, completion_id, int(time.time()),
                usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
            )
    finally:
        await safe_aclose(response)

async def async_iter_append_model_and_provider(
    response: AsyncChatCompletionResponseType,
    last_model: str,
    last_provider: ProviderType
) -> AsyncChatCompletionResponseType:
    try:
        if isinstance(last_provider, BaseRetryProvider):
            last_provider = last_provider.last_provider
        async for chunk in response:
            if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
                if last_provider is not None:
                    chunk.model = getattr(last_provider, "last_model", last_model)
                    chunk.provider = last_provider.__name__
            yield chunk
    finally:
        await safe_aclose(response)

class Client(BaseClient):
    def __init__(
        self,
        provider: Optional[ProviderType] = None,
        image_provider: Optional[ImageProvider] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.chat: Chat = Chat(self, provider)
        self.images: Images = Images(self, image_provider)
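
# A minimal synchronous usage sketch (the model name is illustrative):
#
#     client = Client()
#     completion = client.chat.completions.create(
#         messages=[{"role": "user", "content": "Hello"}],
#         model="gpt-4o-mini",
#     )
#     print(completion.choices[0].message.content)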

class Completions:
    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        self.client: Client = client
        self.provider: ProviderType = provider

    def create(
        self,
        messages: Messages,
        model: str,
        provider: Optional[ProviderType] = None,
        stream: Optional[bool] = False,
        proxy: Optional[str] = None,
        image: Optional[ImageType] = None,
        image_name: Optional[str] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[list[str], str]] = None,
        api_key: Optional[str] = None,
        ignore_working: Optional[bool] = False,
        ignore_stream: Optional[bool] = False,
        **kwargs
    ) -> ChatCompletion:
        if image is not None:
            kwargs["images"] = [(image, image_name)]
        model, provider = get_model_and_provider(
            model,
            self.provider if provider is None else provider,
            stream,
            ignore_working,
            ignore_stream,
            has_images="images" in kwargs
        )
        stop = [stop] if isinstance(stop, str) else stop
        if ignore_stream:
            kwargs["ignore_stream"] = True

        response = iter_run_tools(
            provider.get_create_function(),
            model,
            messages,
            stream=stream,
            **filter_none(
                proxy=self.client.proxy if proxy is None else proxy,
                max_tokens=max_tokens,
                stop=stop,
                api_key=self.client.api_key if api_key is None else api_key
            ),
            **kwargs
        )
        response = iter_response(response, stream, response_format, max_tokens, stop)
        response = iter_append_model_and_provider(response, model, provider)

        if stream:
            return response
        else:
            return next(response)

    def stream(
        self,
        messages: Messages,
        model: str,
        **kwargs
    ) -> IterResponse:
        return self.create(messages, model, stream=True, **kwargs)
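
# Streaming yields ChatCompletionChunk objects instead of one ChatCompletion.
# A sketch, assuming the OpenAI-style chunk fields from .stubs:
#
#     for chunk in client.chat.completions.stream(messages, "gpt-4o-mini"):
#         if chunk.choices[0].delta.content:
#             print(chunk.choices[0].delta.content, end="")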

class Chat:
    completions: Completions

    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        self.completions = Completions(client, provider)

class Images:
    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        self.client: Client = client
        self.provider: Optional[ProviderType] = provider
        self.models: ImageModels = ImageModels(client)

    def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """
        Synchronous generate method that runs the async_generate method in an event loop.
        """
        return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))
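
    # A usage sketch (model name and prompt are illustrative):
    #
    #     client = Client()
    #     response = client.images.generate("a red fox", model="flux", response_format="url")
    #     print(response.data[0].url)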

    async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
        """Resolve the provider to use: explicit argument, then instance provider, then model registry, then default."""
        if provider is None:
            provider_handler = self.provider
            if provider_handler is None:
                provider_handler = self.models.get(model, default)
        elif isinstance(provider, str):
            provider_handler = convert_to_provider(provider)
        else:
            provider_handler = provider
        if provider_handler is None:
            return default
        return provider_handler

    async def async_generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy

        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            for provider in provider_handler.providers:
                try:
                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

        if isinstance(response, ImageResponse):
            return await self._process_image_response(
                response,
                model,
                provider_name,
                response_format,
                proxy
            )
        if response is None:
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _generate_image_response(
        self,
        provider_handler,
        provider_name,
        model: str,
        prompt: str,
        prompt_prefix: str = "Generate an image: ",
        **kwargs
    ) -> ImageResponse:
        messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
        response = None
        if hasattr(provider_handler, "create_async_generator"):
            async for item in provider_handler.create_async_generator(
                model,
                messages,
                stream=True,
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, ImageResponse):
                    response = item
                    break
        elif hasattr(provider_handler, "create_completion"):
            for item in provider_handler.create_completion(
                model,
                messages,
                True,
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, ImageResponse):
                    response = item
                    break
        else:
            raise ValueError(f"Provider {provider_name} does not support image generation")
        return response

    def create_variation(
        self,
        image: ImageType,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        return asyncio.run(self.async_create_variation(
            image, model, provider, response_format, **kwargs
        ))

    async def async_create_variation(
        self,
        image: ImageType,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
        prompt = "create a variation of this image"
        if image is not None:
            kwargs["images"] = [(image, None)]

        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            for provider in provider_handler.providers:
                try:
                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)

        if isinstance(response, ImageResponse):
            return await self._process_image_response(response, model, provider_name, response_format, proxy)
        if response is None:
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _process_image_response(
        self,
        response: ImageResponse,
        model: str,
        provider: str,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None
    ) -> ImagesResponse:
        if response_format == "url":
            # Return original URLs without saving locally
            images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()]
        elif response_format == "b64_json":
            # Convert URLs directly to base64 without saving
            async def get_b64_from_url(url: str) -> Image:
                async with aiohttp.ClientSession(cookies=response.get("cookies")) as session:
                    async with session.get(url, proxy=proxy) as resp:
                        if resp.status == 200:
                            image_data = await resp.read()
                            b64_data = base64.b64encode(image_data).decode()
                            return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
                        # Note: non-200 responses fall through and produce a None entry
            images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
        else:
            # Save locally for None (default) case
            images = await copy_images(response.get_list(), response.get("cookies"), proxy)
            images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
        return ImagesResponse.model_construct(
            created=int(time.time()),
            data=images,
            model=model,
            provider=provider
        )
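
    # response_format recap for _process_image_response (see branches above):
    #   "url"      -> pass through the provider URLs unchanged
    #   "b64_json" -> download each image and inline it as base64
    #   None       -> copy images locally and serve them under /images/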

class AsyncClient(BaseClient):
    def __init__(
        self,
        provider: Optional[ProviderType] = None,
        image_provider: Optional[ImageProvider] = None,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.chat: AsyncChat = AsyncChat(self, provider)
        self.images: AsyncImages = AsyncImages(self, image_provider)
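
# An async usage sketch (the model name is illustrative):
#
#     async def main():
#         client = AsyncClient()
#         completion = await client.chat.completions.create(
#             messages=[{"role": "user", "content": "Hello"}],
#             model="gpt-4o-mini",
#         )
#         print(completion.choices[0].message.content)
#
#     asyncio.run(main())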

class AsyncChat:
    completions: AsyncCompletions

    def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
        self.completions = AsyncCompletions(client, provider)

class AsyncCompletions:
    def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
        self.client: AsyncClient = client
        self.provider: ProviderType = provider

    def create(
        self,
        messages: Messages,
        model: str,
        provider: Optional[ProviderType] = None,
        stream: Optional[bool] = False,
        proxy: Optional[str] = None,
        image: Optional[ImageType] = None,
        image_name: Optional[str] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[list[str], str]] = None,
        api_key: Optional[str] = None,
        ignore_working: Optional[bool] = False,
        ignore_stream: Optional[bool] = False,
        **kwargs
    ) -> Awaitable[ChatCompletion]:
        if image is not None:
            kwargs["images"] = [(image, image_name)]
        model, provider = get_model_and_provider(
            model,
            self.provider if provider is None else provider,
            stream,
            ignore_working,
            ignore_stream,
            has_images="images" in kwargs,
        )
        stop = [stop] if isinstance(stop, str) else stop
        if ignore_stream:
            kwargs["ignore_stream"] = True

        response = async_iter_run_tools(
            provider,
            model,
            messages,
            stream=stream,
            **filter_none(
                proxy=self.client.proxy if proxy is None else proxy,
                max_tokens=max_tokens,
                stop=stop,
                api_key=self.client.api_key if api_key is None else api_key
            ),
            **kwargs
        )
        response = async_iter_response(response, stream, response_format, max_tokens, stop)
        response = async_iter_append_model_and_provider(response, model, provider)

        if stream:
            return response
        else:
            return anext(response)

    def stream(
        self,
        messages: Messages,
        model: str,
        **kwargs
    ) -> AsyncIterator[Union[ChatCompletionChunk, BaseConversation]]:
        return self.create(messages, model, stream=True, **kwargs)
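
# Async streaming sketch: stream() returns the async generator directly, so it
# can be consumed with `async for` (chunk fields per the OpenAI-style stubs):
#
#     async for chunk in client.chat.completions.stream(messages, "gpt-4o-mini"):
#         if chunk.choices[0].delta.content:
#             print(chunk.choices[0].delta.content, end="")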

class AsyncImages(Images):
    def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
        self.client: AsyncClient = client
        self.provider: Optional[ProviderType] = provider
        self.models: ImageModels = ImageModels(client)

    async def generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        return await self.async_generate(prompt, model, provider, response_format, **kwargs)

    async def create_variation(
        self,
        image: ImageType,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        return await self.async_create_variation(
            image, model, provider, response_format, **kwargs
        )
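
# Async image generation sketch (model name and prompt are illustrative):
#
#     async def main():
#         client = AsyncClient()
#         response = await client.images.generate("a red fox", model="flux")
#         print(response.data[0].url)
#
#     asyncio.run(main())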