__init__.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
  1. from __future__ import annotations
  2. import os
  3. import time
  4. import random
  5. import string
  6. import asyncio
  7. import aiohttp
  8. import base64
  9. from typing import Union, AsyncIterator, Iterator, Awaitable, Optional
  10. from ..image import ImageResponse, copy_images
  11. from ..typing import Messages, ImageType
  12. from ..providers.types import ProviderType, BaseRetryProvider
  13. from ..providers.response import ResponseType, FinishReason, BaseConversation, SynthesizeData, ToolCalls, Usage
  14. from ..errors import NoImageResponseError
  15. from ..providers.retry_provider import IterListProvider
  16. from ..providers.asyncio import to_sync_generator
  17. from ..Provider.needs_auth import BingCreateImages, OpenaiAccount
  18. from ..tools.run_tools import async_iter_run_tools, iter_run_tools
  19. from .stubs import ChatCompletion, ChatCompletionChunk, Image, ImagesResponse
  20. from .image_models import ImageModels
  21. from .types import IterResponse, ImageProvider, Client as BaseClient
  22. from .service import get_model_and_provider, convert_to_provider
  23. from .helper import find_stop, filter_json, filter_none, safe_aclose
  24. from .. import debug
# Type aliases for the values yielded by chat-completion iterators:
# finished completions, streamed chunks, or conversation state forwarded to the caller.
ChatCompletionResponseType = Iterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
AsyncChatCompletionResponseType = AsyncIterator[Union[ChatCompletion, ChatCompletionChunk, BaseConversation]]
try:
    anext  # Built-in since Python 3.10; define a fallback on older interpreters.
except NameError:
    async def anext(aiter):
        """Return the next item from an async iterator (fallback for Python < 3.10)."""
        try:
            return await aiter.__anext__()
        except StopAsyncIteration:
            # NOTE(review): converting to StopIteration inside a coroutine is unusual
            # (PEP 479 semantics) — confirm callers rely on this exact behavior.
            raise StopIteration
# Synchronous iter_response function
def iter_response(
    response: Union[Iterator[Union[str, ResponseType]]],
    stream: bool,
    response_format: Optional[dict] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None
) -> ChatCompletionResponseType:
    """Consume a provider response iterator and yield OpenAI-style completion objects.

    In streaming mode yields one ChatCompletionChunk per text chunk plus a final
    chunk carrying the finish reason and usage; otherwise accumulates all text and
    yields a single ChatCompletion. BaseConversation items are forwarded as-is.
    """
    content = ""
    finish_reason = None
    tool_calls = None
    usage = None
    # Random 28-char id shared by every chunk of this completion.
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0
    if hasattr(response, '__aiter__'):
        # Provider handed back an async iterator; bridge it to a sync generator.
        response = to_sync_generator(response)
    for chunk in response:
        if isinstance(chunk, FinishReason):
            finish_reason = chunk.reason
            break
        elif isinstance(chunk, ToolCalls):
            tool_calls = chunk.get_list()
            continue
        elif isinstance(chunk, Usage):
            usage = chunk
            continue
        elif isinstance(chunk, BaseConversation):
            # Conversation state is passed through without counting as content.
            yield chunk
            continue
        elif isinstance(chunk, SynthesizeData) or not chunk:
            # Skip synthesized audio data and empty chunks.
            continue
        chunk = str(chunk)
        content += chunk
        if max_tokens is not None and idx + 1 >= max_tokens:
            finish_reason = "length"
        # Trim content at the first stop sequence, if any was hit.
        first, content, chunk = find_stop(stop, content, chunk if stream else None)
        if first != -1:
            finish_reason = "stop"
        if stream:
            yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))
        if finish_reason is not None:
            break
        idx += 1
    if usage is None:
        # Fall back to counting yielded chunks when the provider reported no usage.
        usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
    finish_reason = "stop" if finish_reason is None else finish_reason
    if stream:
        # Terminal chunk: no content, only the finish reason and usage.
        yield ChatCompletionChunk.model_construct(
            None, finish_reason, completion_id, int(time.time()),
            usage=usage.get_dict()
        )
    else:
        if response_format is not None and "type" in response_format:
            if response_format["type"] == "json_object":
                # Strip any surrounding prose so only the JSON payload remains.
                content = filter_json(content)
        yield ChatCompletion.model_construct(
            content, finish_reason, completion_id, int(time.time()),
            usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
        )
  94. # Synchronous iter_append_model_and_provider function
  95. def iter_append_model_and_provider(response: ChatCompletionResponseType, last_model: str, last_provider: ProviderType) -> ChatCompletionResponseType:
  96. if isinstance(last_provider, BaseRetryProvider):
  97. last_provider = last_provider.last_provider
  98. for chunk in response:
  99. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  100. if last_provider is not None:
  101. chunk.model = getattr(last_provider, "last_model", last_model)
  102. chunk.provider = last_provider.__name__
  103. yield chunk
async def async_iter_response(
    response: AsyncIterator[Union[str, ResponseType]],
    stream: bool,
    response_format: Optional[dict] = None,
    max_tokens: Optional[int] = None,
    stop: Optional[list[str]] = None
) -> AsyncChatCompletionResponseType:
    """Async counterpart of iter_response.

    Consumes an async provider iterator and yields OpenAI-style completion objects
    (chunks when streaming, a single ChatCompletion otherwise). Always closes the
    underlying iterator on exit.
    """
    content = ""
    finish_reason = None
    # Random 28-char id shared by every chunk of this completion.
    completion_id = ''.join(random.choices(string.ascii_letters + string.digits, k=28))
    idx = 0
    tool_calls = None
    usage = None
    try:
        async for chunk in response:
            if isinstance(chunk, FinishReason):
                finish_reason = chunk.reason
                break
            elif isinstance(chunk, BaseConversation):
                # Conversation state is passed through without counting as content.
                yield chunk
                continue
            elif isinstance(chunk, ToolCalls):
                tool_calls = chunk.get_list()
                continue
            elif isinstance(chunk, Usage):
                usage = chunk
                continue
            elif isinstance(chunk, SynthesizeData) or not chunk:
                # Skip synthesized audio data and empty chunks.
                continue
            chunk = str(chunk)
            content += chunk
            idx += 1
            if max_tokens is not None and idx >= max_tokens:
                finish_reason = "length"
            # Trim content at the first stop sequence, if any was hit.
            first, content, chunk = find_stop(stop, content, chunk if stream else None)
            if first != -1:
                finish_reason = "stop"
            if stream:
                yield ChatCompletionChunk.model_construct(chunk, None, completion_id, int(time.time()))
            if finish_reason is not None:
                break
        finish_reason = "stop" if finish_reason is None else finish_reason
        if usage is None:
            # Fall back to counting chunks when the provider reported no usage.
            usage = Usage(prompt_tokens=0, completion_tokens=idx, total_tokens=idx)
        if stream:
            # Terminal chunk: no content, only the finish reason and usage.
            yield ChatCompletionChunk.model_construct(
                None, finish_reason, completion_id, int(time.time()),
                usage=usage.get_dict()
            )
        else:
            if response_format is not None and "type" in response_format:
                if response_format["type"] == "json_object":
                    # Strip any surrounding prose so only the JSON payload remains.
                    content = filter_json(content)
            yield ChatCompletion.model_construct(
                content, finish_reason, completion_id, int(time.time()),
                usage=usage.get_dict(), **filter_none(tool_calls=tool_calls)
            )
    finally:
        # Ensure provider-side async generators are closed even on early exit.
        await safe_aclose(response)
  163. async def async_iter_append_model_and_provider(
  164. response: AsyncChatCompletionResponseType,
  165. last_model: str,
  166. last_provider: ProviderType
  167. ) -> AsyncChatCompletionResponseType:
  168. last_provider = None
  169. try:
  170. if isinstance(last_provider, BaseRetryProvider):
  171. if last_provider is not None:
  172. last_provider = last_provider.last_provider
  173. async for chunk in response:
  174. if isinstance(chunk, (ChatCompletion, ChatCompletionChunk)):
  175. if last_provider is not None:
  176. chunk.model = getattr(last_provider, "last_model", last_model)
  177. chunk.provider = last_provider.__name__
  178. yield chunk
  179. finally:
  180. await safe_aclose(response)
  181. class Client(BaseClient):
  182. def __init__(
  183. self,
  184. provider: Optional[ProviderType] = None,
  185. image_provider: Optional[ImageProvider] = None,
  186. **kwargs
  187. ) -> None:
  188. super().__init__(**kwargs)
  189. self.chat: Chat = Chat(self, provider)
  190. self.images: Images = Images(self, image_provider)
class Completions:
    """Synchronous ``client.chat.completions`` implementation."""

    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        # Parent client (supplies proxy/api_key defaults) and default provider.
        self.client: Client = client
        self.provider: ProviderType = provider

    def create(
        self,
        messages: Messages,
        model: str,
        provider: Optional[ProviderType] = None,
        stream: Optional[bool] = False,
        proxy: Optional[str] = None,
        image: Optional[ImageType] = None,
        image_name: Optional[str] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[list[str], str]] = None,
        api_key: Optional[str] = None,
        ignored: Optional[list[str]] = None,
        ignore_working: Optional[bool] = False,
        ignore_stream: Optional[bool] = False,
        **kwargs
    ) -> ChatCompletion:
        """Create a chat completion for `messages`.

        Returns the first ChatCompletion when `stream` is falsy, otherwise the
        chunk iterator itself (despite the declared return type).
        NOTE(review): `ignored` is accepted but not used anywhere in this method.
        """
        model, provider = get_model_and_provider(
            model,
            self.provider if provider is None else provider,
            stream,
            ignore_working,
            ignore_stream,
        )
        # Normalize a single stop string into a list.
        stop = [stop] if isinstance(stop, str) else stop
        if image is not None:
            kwargs["images"] = [(image, image_name)]
        if ignore_stream:
            kwargs["ignore_stream"] = True
        response = iter_run_tools(
            provider.get_create_function(),
            model,
            messages,
            stream=stream,
            # Client-level defaults apply when the call does not override them;
            # filter_none drops unset options entirely.
            **filter_none(
                proxy=self.client.proxy if proxy is None else proxy,
                max_tokens=max_tokens,
                stop=stop,
                api_key=self.client.api_key if api_key is None else api_key
            ),
            **kwargs
        )
        if not hasattr(response, '__iter__'):
            # A provider may return a single value; wrap it so iter_response can consume it.
            response = [response]
        response = iter_response(response, stream, response_format, max_tokens, stop)
        response = iter_append_model_and_provider(response, model, provider)
        if stream:
            return response
        else:
            return next(response)

    def stream(
        self,
        messages: Messages,
        model: str,
        **kwargs
    ) -> IterResponse:
        """Shortcut for ``create(..., stream=True)``."""
        return self.create(messages, model, stream=True, **kwargs)
  253. class Chat:
  254. completions: Completions
  255. def __init__(self, client: Client, provider: Optional[ProviderType] = None):
  256. self.completions = Completions(client, provider)
class Images:
    """Synchronous ``client.images`` implementation (generation and variation)."""

    def __init__(self, client: Client, provider: Optional[ProviderType] = None):
        # Parent client (supplies proxy default), default provider, and
        # the model-name -> provider lookup table.
        self.client: Client = client
        self.provider: Optional[ProviderType] = provider
        self.models: ImageModels = ImageModels(client)

    def generate(
        self,
        prompt: str,
        model: str = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """
        Synchronous generate method that runs the async_generate method in an event loop.
        """
        return asyncio.run(self.async_generate(prompt, model, provider, response_format, proxy, **kwargs))

    async def get_provider_handler(self, model: Optional[str], provider: Optional[ImageProvider], default: ImageProvider) -> ImageProvider:
        """Resolve which provider to use.

        Priority: explicit `provider` argument > instance default > model lookup
        in self.models > `default`. A string provider is resolved by name.
        """
        if provider is None:
            provider_handler = self.provider
            if provider_handler is None:
                provider_handler = self.models.get(model, default)
        elif isinstance(provider, str):
            # Provider given by name; resolve to the provider class.
            provider_handler = convert_to_provider(provider)
        else:
            provider_handler = provider
        if provider_handler is None:
            return default
        return provider_handler

    async def async_generate(
        self,
        prompt: str,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """Generate images for `prompt`.

        For retry-list providers each sub-provider is tried in order until one
        returns an ImageResponse. Raises the last provider error, or
        NoImageResponseError when nothing usable came back.
        """
        provider_handler = await self.get_provider_handler(model, provider, BingCreateImages)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            # Try each wrapped provider until one yields an image response.
            for provider in provider_handler.providers:
                try:
                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
                except Exception as e:
                    # Remember the last failure so it can be re-raised if all fail.
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
        if isinstance(response, ImageResponse):
            return await self._process_image_response(
                response,
                model,
                provider_name,
                response_format,
                proxy
            )
        if response is None:
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _generate_image_response(
        self,
        provider_handler,
        provider_name,
        model: str,
        prompt: str,
        prompt_prefix: str = "Generate a image: ",
        **kwargs
    ) -> ImageResponse:
        """Invoke one provider and return its first ImageResponse, or None if none.

        Prefers the provider's async generator API, falling back to the sync
        completion API; raises ValueError when neither exists.
        """
        messages = [{"role": "user", "content": f"{prompt_prefix}{prompt}"}]
        response = None
        if hasattr(provider_handler, "create_async_generator"):
            async for item in provider_handler.create_async_generator(
                model,
                messages,
                stream=True,
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, ImageResponse):
                    response = item
                    break
        elif hasattr(provider_handler, "create_completion"):
            for item in provider_handler.create_completion(
                model,
                messages,
                True,
                prompt=prompt,
                **kwargs
            ):
                if isinstance(item, ImageResponse):
                    response = item
                    break
        else:
            raise ValueError(f"Provider {provider_name} does not support image generation")
        return response

    def create_variation(
        self,
        image: ImageType,
        model: str = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """Synchronous wrapper that runs async_create_variation in a fresh event loop."""
        return asyncio.run(self.async_create_variation(
            image, model, provider, response_format, **kwargs
        ))

    async def async_create_variation(
        self,
        image: ImageType,
        model: Optional[str] = None,
        provider: Optional[ProviderType] = None,
        response_format: Optional[str] = None,
        proxy: Optional[str] = None,
        **kwargs
    ) -> ImagesResponse:
        """Ask a provider for a variation of `image` using a fixed text prompt.

        Same provider-retry and error semantics as async_generate.
        """
        provider_handler = await self.get_provider_handler(model, provider, OpenaiAccount)
        provider_name = provider_handler.__name__ if hasattr(provider_handler, "__name__") else type(provider_handler).__name__
        if proxy is None:
            proxy = self.client.proxy
        prompt = "create a variation of this image"
        if image is not None:
            # The source image is forwarded to the provider as an attachment.
            kwargs["images"] = [(image, None)]
        error = None
        response = None
        if isinstance(provider_handler, IterListProvider):
            for provider in provider_handler.providers:
                try:
                    response = await self._generate_image_response(provider, provider.__name__, model, prompt, **kwargs)
                    if response is not None:
                        provider_name = provider.__name__
                        break
                except Exception as e:
                    error = e
                    debug.log(f"Image provider {provider.__name__}: {e}")
        else:
            response = await self._generate_image_response(provider_handler, provider_name, model, prompt, **kwargs)
        if isinstance(response, ImageResponse):
            return await self._process_image_response(response, model, provider_name, response_format, proxy)
        if response is None:
            if error is not None:
                raise error
            raise NoImageResponseError(f"No image response from {provider_name}")
        raise NoImageResponseError(f"Unexpected response type: {type(response)}")

    async def _process_image_response(
        self,
        response: ImageResponse,
        model: str,
        provider: str,
        response_format: Optional[str] = None,
        proxy: str = None
    ) -> ImagesResponse:
        """Convert an ImageResponse into an OpenAI-style ImagesResponse.

        response_format "url" returns remote URLs as-is, "b64_json" inlines the
        image bytes as base64, and None (default) downloads the images locally
        and serves them from the /images path.
        """
        if response_format == "url":
            # Return original URLs without saving locally
            images = [Image.model_construct(url=image, revised_prompt=response.alt) for image in response.get_list()]
        elif response_format == "b64_json":
            # Convert URLs directly to base64 without saving
            async def get_b64_from_url(url: str) -> Image:
                async with aiohttp.ClientSession() as session:
                    async with session.get(url, proxy=proxy) as resp:
                        # NOTE(review): a non-200 status falls through and produces a
                        # None entry in the result list — confirm whether that should
                        # raise instead.
                        if resp.status == 200:
                            image_data = await resp.read()
                            b64_data = base64.b64encode(image_data).decode()
                            return Image.model_construct(b64_json=b64_data, revised_prompt=response.alt)
            images = await asyncio.gather(*[get_b64_from_url(image) for image in response.get_list()])
        else:
            # Save locally for None (default) case
            images = await copy_images(response.get_list(), response.get("cookies"), proxy)
            images = [Image.model_construct(url=f"/images/{os.path.basename(image)}", revised_prompt=response.alt) for image in images]
        return ImagesResponse.model_construct(
            created=int(time.time()),
            data=images,
            model=model,
            provider=provider
        )
  442. class AsyncClient(BaseClient):
  443. def __init__(
  444. self,
  445. provider: Optional[ProviderType] = None,
  446. image_provider: Optional[ImageProvider] = None,
  447. **kwargs
  448. ) -> None:
  449. super().__init__(**kwargs)
  450. self.chat: AsyncChat = AsyncChat(self, provider)
  451. self.images: AsyncImages = AsyncImages(self, image_provider)
  452. class AsyncChat:
  453. completions: AsyncCompletions
  454. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  455. self.completions = AsyncCompletions(client, provider)
class AsyncCompletions:
    """Asynchronous ``client.chat.completions`` implementation."""

    def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
        # Parent client (supplies proxy/api_key defaults) and default provider.
        self.client: AsyncClient = client
        self.provider: ProviderType = provider

    def create(
        self,
        messages: Messages,
        model: str,
        provider: Optional[ProviderType] = None,
        stream: Optional[bool] = False,
        proxy: Optional[str] = None,
        image: Optional[ImageType] = None,
        image_name: Optional[str] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        stop: Optional[Union[list[str], str]] = None,
        api_key: Optional[str] = None,
        ignored: Optional[list[str]] = None,
        ignore_working: Optional[bool] = False,
        ignore_stream: Optional[bool] = False,
        **kwargs
    ) -> Awaitable[ChatCompletion]:
        """Create a chat completion.

        Not itself a coroutine: returns the async chunk iterator when `stream`
        is truthy, otherwise an awaitable resolving to the first ChatCompletion.
        NOTE(review): `ignored` is accepted but not used anywhere in this method.
        """
        model, provider = get_model_and_provider(
            model,
            self.provider if provider is None else provider,
            stream,
            ignore_working,
            ignore_stream,
        )
        # Normalize a single stop string into a list.
        stop = [stop] if isinstance(stop, str) else stop
        if image is not None:
            kwargs["images"] = [(image, image_name)]
        if ignore_stream:
            kwargs["ignore_stream"] = True
        response = async_iter_run_tools(
            provider.get_async_create_function(),
            model,
            messages,
            stream=stream,
            # Client-level defaults apply when the call does not override them;
            # filter_none drops unset options entirely.
            **filter_none(
                proxy=self.client.proxy if proxy is None else proxy,
                max_tokens=max_tokens,
                stop=stop,
                api_key=self.client.api_key if api_key is None else api_key
            ),
            **kwargs
        )
        response = async_iter_response(response, stream, response_format, max_tokens, stop)
        response = async_iter_append_model_and_provider(response, model, provider)
        return response if stream else anext(response)

    def stream(
        self,
        messages: Messages,
        model: str,
        **kwargs
    ) -> AsyncIterator[Union[ChatCompletionChunk, BaseConversation]]:
        """Shortcut for ``create(..., stream=True)``."""
        return self.create(messages, model, stream=True, **kwargs)
  513. class AsyncImages(Images):
  514. def __init__(self, client: AsyncClient, provider: Optional[ProviderType] = None):
  515. self.client: AsyncClient = client
  516. self.provider: Optional[ProviderType] = provider
  517. self.models: ImageModels = ImageModels(client)
  518. async def generate(
  519. self,
  520. prompt: str,
  521. model: Optional[str] = None,
  522. provider: Optional[ProviderType] = None,
  523. response_format: Optional[str] = None,
  524. **kwargs
  525. ) -> ImagesResponse:
  526. return await self.async_generate(prompt, model, provider, response_format, **kwargs)
  527. async def create_variation(
  528. self,
  529. image: ImageType,
  530. model: str = None,
  531. provider: ProviderType = None,
  532. response_format: Optional[str] = None,
  533. **kwargs
  534. ) -> ImagesResponse:
  535. return await self.async_create_variation(
  536. image, model, provider, response_format, **kwargs
  537. )