Cerebras.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from __future__ import annotations
  2. from aiohttp import ClientSession
  3. from .OpenaiAPI import OpenaiAPI
  4. from ...typing import AsyncResult, Messages, Cookies
  5. from ...requests.raise_for_status import raise_for_status
  6. from ...cookies import get_cookies
  7. class Cerebras(OpenaiAPI):
  8. label = "Cerebras Inference"
  9. url = "https://inference.cerebras.ai/"
  10. login_url = "https://cloud.cerebras.ai"
  11. api_base = "https://api.cerebras.ai/v1"
  12. working = True
  13. default_model = "llama3.1-70b"
  14. models = [
  15. default_model,
  16. "llama3.1-8b",
  17. "llama-3.3-70b",
  18. "deepseek-r1-distill-llama-70b"
  19. ]
  20. model_aliases = {"llama-3.1-70b": default_model, "llama-3.1-8b": "llama3.1-8b", "deepseek-r1": "deepseek-r1-distill-llama-70b"}
  21. @classmethod
  22. async def create_async_generator(
  23. cls,
  24. model: str,
  25. messages: Messages,
  26. api_key: str = None,
  27. cookies: Cookies = None,
  28. **kwargs
  29. ) -> AsyncResult:
  30. if api_key is None:
  31. if cookies is None:
  32. cookies = get_cookies(".cerebras.ai")
  33. async with ClientSession(cookies=cookies) as session:
  34. async with session.get("https://inference.cerebras.ai/api/auth/session") as response:
  35. await raise_for_status(response)
  36. data = await response.json()
  37. if data:
  38. api_key = data.get("user", {}).get("demoApiKey")
  39. async for chunk in super().create_async_generator(
  40. model, messages,
  41. impersonate="chrome",
  42. api_key=api_key,
  43. headers={
  44. "User-Agent": "ex/JS 1.5.0",
  45. },
  46. **kwargs
  47. ):
  48. yield chunk