Janus_Pro_7B.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. from __future__ import annotations
  2. import json
  3. import uuid
  4. import re
  5. import random
  6. from datetime import datetime, timezone, timedelta
  7. import urllib.parse
  8. from ...typing import AsyncResult, Messages, Cookies, ImagesType
  9. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  10. from ..helper import format_prompt, format_image_prompt
  11. from ...providers.response import JsonConversation, ImageResponse, Reasoning
  12. from ...requests.aiohttp import StreamSession, StreamResponse, FormData
  13. from ...requests.raise_for_status import raise_for_status
  14. from ...image import to_bytes, is_accepted_format
  15. from ...cookies import get_cookies
  16. from ...errors import ResponseError
  17. from ... import debug
  18. from .raise_for_status import raise_for_status
  19. class Janus_Pro_7B(AsyncGeneratorProvider, ProviderModelMixin):
  20. space = "deepseek-ai/Janus-Pro-7B"
  21. url = f"https://huggingface.co/spaces/{space}"
  22. api_url = "https://deepseek-ai-janus-pro-7b.hf.space"
  23. referer = f"{api_url}?__theme=light"
  24. working = True
  25. supports_stream = True
  26. supports_system_message = True
  27. supports_message_history = True
  28. default_model = "janus-pro-7b"
  29. default_image_model = "janus-pro-7b-image"
  30. default_vision_model = default_model
  31. models = [default_model, default_image_model]
  32. image_models = [default_image_model]
  33. @classmethod
  34. def run(cls, method: str, session: StreamSession, prompt: str, conversation: JsonConversation, image: dict = None, seed: int = 0):
  35. headers = {
  36. "content-type": "application/json",
  37. "x-zerogpu-token": conversation.zerogpu_token,
  38. "x-zerogpu-uuid": conversation.zerogpu_uuid,
  39. "referer": cls.referer,
  40. }
  41. if method == "post":
  42. return session.post(f"{cls.api_url}/gradio_api/queue/join?__theme=light", **{
  43. "headers": {k: v for k, v in headers.items() if v is not None},
  44. "json": {"data":[image,prompt,seed,0.95,0.1],"event_data":None,"fn_index":2,"trigger_id":10,"session_hash":conversation.session_hash},
  45. })
  46. elif method == "image":
  47. return session.post(f"{cls.api_url}/gradio_api/queue/join?__theme=light", **{
  48. "headers": {k: v for k, v in headers.items() if v is not None},
  49. "json": {"data":[prompt,seed,5,1],"event_data":None,"fn_index":3,"trigger_id":20,"session_hash":conversation.session_hash},
  50. })
  51. return session.get(f"{cls.api_url}/gradio_api/queue/data?session_hash={conversation.session_hash}", **{
  52. "headers": {
  53. "accept": "text/event-stream",
  54. "content-type": "application/json",
  55. "referer": cls.referer,
  56. }
  57. })
  58. @classmethod
  59. async def create_async_generator(
  60. cls,
  61. model: str,
  62. messages: Messages,
  63. images: ImagesType = None,
  64. prompt: str = None,
  65. proxy: str = None,
  66. cookies: Cookies = None,
  67. api_key: str = None,
  68. zerogpu_uuid: str = "[object Object]",
  69. return_conversation: bool = False,
  70. conversation: JsonConversation = None,
  71. seed: int = None,
  72. **kwargs
  73. ) -> AsyncResult:
  74. method = "post"
  75. if model == cls.default_image_model or prompt is not None:
  76. method = "image"
  77. prompt = format_prompt(messages) if prompt is None and conversation is None else prompt
  78. prompt = format_image_prompt(messages, prompt)
  79. if seed is None:
  80. seed = random.randint(1000, 999999)
  81. session_hash = uuid.uuid4().hex if conversation is None else getattr(conversation, "session_hash", uuid.uuid4().hex)
  82. async with StreamSession(proxy=proxy, impersonate="chrome") as session:
  83. if api_key is None:
  84. zerogpu_uuid, api_key = await get_zerogpu_token(cls.space, session, conversation, cookies)
  85. if conversation is None or not hasattr(conversation, "session_hash"):
  86. conversation = JsonConversation(session_hash=session_hash, zerogpu_token=api_key, zerogpu_uuid=zerogpu_uuid)
  87. else:
  88. conversation.zerogpu_token = api_key
  89. if return_conversation:
  90. yield conversation
  91. if images is not None:
  92. data = FormData()
  93. for i in range(len(images)):
  94. images[i] = (to_bytes(images[i][0]), images[i][1])
  95. for image, image_name in images:
  96. data.add_field(f"files", image, filename=image_name)
  97. async with session.post(f"{cls.api_url}/gradio_api/upload", params={"upload_id": session_hash}, data=data) as response:
  98. await raise_for_status(response)
  99. image_files = await response.json()
  100. images = [{
  101. "path": image_file,
  102. "url": f"{cls.api_url}/gradio_api/file={image_file}",
  103. "orig_name": images[i][1],
  104. "size": len(images[i][0]),
  105. "mime_type": is_accepted_format(images[i][0]),
  106. "meta": {
  107. "_type": "gradio.FileData"
  108. }
  109. } for i, image_file in enumerate(image_files)]
  110. async with cls.run(method, session, prompt, conversation, None if images is None else images.pop(), seed) as response:
  111. await raise_for_status(response)
  112. async with cls.run("get", session, prompt, conversation, None, seed) as response:
  113. response: StreamResponse = response
  114. counter = 3
  115. async for line in response.iter_lines():
  116. decoded_line = line.decode(errors="replace")
  117. if decoded_line.startswith('data: '):
  118. try:
  119. json_data = json.loads(decoded_line[6:])
  120. if json_data.get('msg') == 'log':
  121. yield Reasoning(status=json_data["log"])
  122. if json_data.get('msg') == 'progress':
  123. if 'progress_data' in json_data:
  124. if json_data['progress_data']:
  125. progress = json_data['progress_data'][0]
  126. yield Reasoning(status=f"{progress['desc']} {progress['index']}/{progress['length']}")
  127. else:
  128. yield Reasoning(status=f"Generating")
  129. elif json_data.get('msg') == 'heartbeat':
  130. yield Reasoning(status=f"Generating{''.join(['.' for i in range(counter)])}")
  131. counter += 1
  132. elif json_data.get('msg') == 'process_completed':
  133. if 'output' in json_data and 'error' in json_data['output']:
  134. json_data['output']['error'] = json_data['output']['error'].split(" <a ")[0]
  135. raise ResponseError("Missing images input" if json_data['output']['error'] and "AttributeError" in json_data['output']['error'] else json_data['output']['error'])
  136. if 'output' in json_data and 'data' in json_data['output']:
  137. yield Reasoning(status="Finished")
  138. if "image" in json_data['output']['data'][0][0]:
  139. yield ImageResponse([image["image"]["url"] for image in json_data['output']['data'][0]], prompt)
  140. else:
  141. yield json_data['output']['data'][0]
  142. break
  143. except json.JSONDecodeError:
  144. debug.log("Could not parse JSON:", decoded_line)
  145. async def get_zerogpu_token(space: str, session: StreamSession, conversation: JsonConversation, cookies: Cookies = None):
  146. zerogpu_uuid = None if conversation is None else getattr(conversation, "zerogpu_uuid", None)
  147. zerogpu_token = "[object Object]"
  148. cookies = get_cookies("huggingface.co", raise_requirements_error=False) if cookies is None else cookies
  149. if zerogpu_uuid is None:
  150. async with session.get(f"https://huggingface.co/spaces/{space}", cookies=cookies) as response:
  151. match = re.search(r"&quot;token&quot;:&quot;([^&]+?)&quot;", await response.text())
  152. if match:
  153. zerogpu_token = match.group(1)
  154. match = re.search(r"&quot;sessionUuid&quot;:&quot;([^&]+?)&quot;", await response.text())
  155. if match:
  156. zerogpu_uuid = match.group(1)
  157. if cookies:
  158. # Get current UTC time + 10 minutes
  159. dt = (datetime.now(timezone.utc) + timedelta(minutes=10)).isoformat(timespec='milliseconds')
  160. encoded_dt = urllib.parse.quote(dt)
  161. async with session.get(f"https://huggingface.co/api/spaces/{space}/jwt?expiration={encoded_dt}&include_pro_status=true", cookies=cookies) as response:
  162. response_data = (await response.json())
  163. if "token" in response_data:
  164. zerogpu_token = response_data["token"]
  165. return zerogpu_uuid, zerogpu_token