from __future__ import annotations

import asyncio
import json

from ..typing import AsyncResult, Messages, Cookies
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin, get_running_loop
from ..requests import Session, StreamSession, get_args_from_nodriver, raise_for_status, merge_cookies
from ..requests import DEFAULT_HEADERS, has_nodriver, has_curl_cffi
from ..providers.response import FinishReason
from ..errors import ResponseStatusError, ModelNotFoundError

class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Cloudflare AI"
    url = "https://playground.ai.cloudflare.com"
    api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
    models_url = "https://playground.ai.cloudflare.com/api/models"
    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True
    default_model = "@cf/meta/llama-3.1-8b-instruct"
    # Short alias -> Cloudflare model id. A dict literal keeps only the last
    # value for a repeated key, so each alias maps to exactly one model here.
    model_aliases = {
        "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
        "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
        "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
        "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
        "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
    }
    # Cached request arguments (headers and cookies) shared across calls.
    _args: dict | None = None

    @classmethod
    def get_models(cls) -> list[str]:
        # Fetch the model list from the playground once and cache it.
        if not cls.models:
            if cls._args is None:
                if has_nodriver:
                    # Drive a real browser to clear the Cloudflare challenge
                    # and capture working headers and cookies.
                    get_running_loop(check_nested=True)
                    args = get_args_from_nodriver(cls.url)
                    cls._args = asyncio.run(args)
                elif not has_curl_cffi:
                    # Without nodriver or curl_cffi there is no way to pass
                    # the challenge; return whatever is cached.
                    return cls.models
                else:
                    cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
            with Session(**cls._args) as session:
                response = session.get(cls.models_url)
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    raise_for_status(response)
                except ResponseStatusError:
                    return cls.models
                json_data = response.json()
                cls.models = [model.get("name") for model in json_data.get("models", [])]
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        max_tokens: int = 2048,
        cookies: Cookies = None,
        timeout: int = 300,
        **kwargs
    ) -> AsyncResult:
        if cls._args is None:
            if has_nodriver:
                cls._args = await get_args_from_nodriver(cls.url, proxy, timeout, cookies)
            else:
                cls._args = {"headers": DEFAULT_HEADERS, "cookies": {}}
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            # Unknown names are passed through unchanged and left for the API
            # to accept or reject.
            pass
        data = {
            "messages": messages,
            "lora": None,
            "model": model,
            "max_tokens": max_tokens,
            "stream": True
        }
        async with StreamSession(**cls._args) as session:
            async with session.post(
                cls.api_endpoint,
                json=data,
            ) as response:
                cls._args["cookies"] = merge_cookies(cls._args["cookies"], response)
                try:
                    await raise_for_status(response)
                except ResponseStatusError:
                    # The cached headers/cookies are likely stale; drop them so
                    # the next call re-establishes a session.
                    cls._args = None
                    raise
                reason = None
                # The endpoint streams server-sent events: `data: {json}` lines
                # terminated by `data: [DONE]`.
                async for line in response.iter_lines():
                    if line.startswith(b'data: '):
                        if line == b'data: [DONE]':
                            break
                        try:
                            content = json.loads(line[6:].decode())
                            if content.get("response") and content.get("response") != '</s>':
                                yield content['response']
                                reason = "max_tokens"
                            elif content.get("response") == '':
                                reason = "stop"
                        except Exception:
                            continue
                if reason is not None:
                    yield FinishReason(reason)
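
# A minimal usage sketch (illustrative, with an assumed package layout):
# because of the relative imports above, this module must run with package
# context, e.g. `python -m <package>.Cloudflare`. Messages use the plain
# role/content dict shape expected by the Messages type.
if __name__ == "__main__":
    async def _demo():
        async for chunk in Cloudflare.create_async_generator(
            model="llama-3.1-8b",
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
        ):
            if isinstance(chunk, FinishReason):
                break
            print(chunk, end="", flush=True)
        print()

    asyncio.run(_demo())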