# Cloudflare.py
from __future__ import annotations

import asyncio
import json
import uuid
from typing import AsyncGenerator

import cloudscraper
from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt
  11. class Cloudflare(AsyncGeneratorProvider, ProviderModelMixin):
  12. label = "Cloudflare AI"
  13. url = "https://playground.ai.cloudflare.com"
  14. api_endpoint = "https://playground.ai.cloudflare.com/api/inference"
  15. working = True
  16. supports_stream = True
  17. supports_system_message = True
  18. supports_message_history = True
  19. default_model = '@cf/meta/llama-3.1-8b-instruct-awq'
  20. models = [
  21. '@cf/meta/llama-2-7b-chat-fp16',
  22. '@cf/meta/llama-2-7b-chat-int8',
  23. '@cf/meta/llama-3-8b-instruct',
  24. '@cf/meta/llama-3-8b-instruct-awq',
  25. '@hf/meta-llama/meta-llama-3-8b-instruct',
  26. default_model,
  27. '@cf/meta/llama-3.1-8b-instruct-fp8',
  28. '@cf/meta/llama-3.2-1b-instruct',
  29. '@hf/mistral/mistral-7b-instruct-v0.2',
  30. '@cf/qwen/qwen1.5-7b-chat-awq',
  31. '@cf/defog/sqlcoder-7b-2',
  32. ]
  33. model_aliases = {
  34. "llama-2-7b": "@cf/meta/llama-2-7b-chat-fp16",
  35. "llama-2-7b": "@cf/meta/llama-2-7b-chat-int8",
  36. "llama-3-8b": "@cf/meta/llama-3-8b-instruct",
  37. "llama-3-8b": "@cf/meta/llama-3-8b-instruct-awq",
  38. "llama-3-8b": "@hf/meta-llama/meta-llama-3-8b-instruct",
  39. "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-awq",
  40. "llama-3.1-8b": "@cf/meta/llama-3.1-8b-instruct-fp8",
  41. "llama-3.2-1b": "@cf/meta/llama-3.2-1b-instruct",
  42. "qwen-1.5-7b": "@cf/qwen/qwen1.5-7b-chat-awq",
  43. #"sqlcoder-7b": "@cf/defog/sqlcoder-7b-2",
  44. }
  45. @classmethod
  46. def get_model(cls, model: str) -> str:
  47. if model in cls.models:
  48. return model
  49. elif model in cls.model_aliases:
  50. return cls.model_aliases[model]
  51. else:
  52. return cls.default_model
  53. @classmethod
  54. async def create_async_generator(
  55. cls,
  56. model: str,
  57. messages: Messages,
  58. proxy: str = None,
  59. max_tokens: int = 2048,
  60. **kwargs
  61. ) -> AsyncResult:
  62. model = cls.get_model(model)
  63. headers = {
  64. 'Accept': 'text/event-stream',
  65. 'Accept-Language': 'en-US,en;q=0.9',
  66. 'Cache-Control': 'no-cache',
  67. 'Content-Type': 'application/json',
  68. 'Origin': cls.url,
  69. 'Pragma': 'no-cache',
  70. 'Referer': f'{cls.url}/',
  71. 'Sec-Ch-Ua': '"Chromium";v="129", "Not=A?Brand";v="8"',
  72. 'Sec-Ch-Ua-Mobile': '?0',
  73. 'Sec-Ch-Ua-Platform': '"Linux"',
  74. 'Sec-Fetch-Dest': 'empty',
  75. 'Sec-Fetch-Mode': 'cors',
  76. 'Sec-Fetch-Site': 'same-origin',
  77. 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36',
  78. }
  79. cookies = {
  80. '__cf_bm': uuid.uuid4().hex,
  81. }
  82. scraper = cloudscraper.create_scraper()
  83. data = {
  84. "messages": [
  85. {"role": "user", "content": format_prompt(messages)}
  86. ],
  87. "lora": None,
  88. "model": model,
  89. "max_tokens": max_tokens,
  90. "stream": True
  91. }
  92. max_retries = 3
  93. full_response = ""
  94. for attempt in range(max_retries):
  95. try:
  96. response = scraper.post(
  97. cls.api_endpoint,
  98. headers=headers,
  99. cookies=cookies,
  100. json=data,
  101. stream=True,
  102. proxies={'http': proxy, 'https': proxy} if proxy else None
  103. )
  104. if response.status_code == 403:
  105. await asyncio.sleep(2 ** attempt)
  106. continue
  107. response.raise_for_status()
  108. for line in response.iter_lines():
  109. if line.startswith(b'data: '):
  110. if line == b'data: [DONE]':
  111. if full_response:
  112. yield full_response
  113. break
  114. try:
  115. content = json.loads(line[6:].decode('utf-8'))
  116. if 'response' in content and content['response'] != '</s>':
  117. yield content['response']
  118. except Exception:
  119. continue
  120. break
  121. except Exception as e:
  122. if attempt == max_retries - 1:
  123. raise