Qwen_Qwen_2_72B_Instruct.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from __future__ import annotations
  2. import aiohttp
  3. import json
  4. import uuid
  5. import re
  6. from ...typing import AsyncResult, Messages
  7. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  8. from ..helper import format_prompt
  9. from ... import debug
  10. class Qwen_Qwen_2_72B_Instruct(AsyncGeneratorProvider, ProviderModelMixin):
  11. url = "https://qwen-qwen2-72b-instruct.hf.space"
  12. api_endpoint = "https://qwen-qwen2-72b-instruct.hf.space/queue/join?"
  13. working = True
  14. supports_stream = True
  15. supports_system_message = True
  16. supports_message_history = False
  17. default_model = "qwen-qwen2-72b-instruct"
  18. models = [default_model]
  19. model_aliases = {"qwen-2-72b": default_model}
  20. @classmethod
  21. async def create_async_generator(
  22. cls,
  23. model: str,
  24. messages: Messages,
  25. proxy: str = None,
  26. **kwargs
  27. ) -> AsyncResult:
  28. def generate_session_hash():
  29. """Generate a unique session hash."""
  30. return str(uuid.uuid4()).replace('-', '')[:12]
  31. # Generate a unique session hash
  32. session_hash = generate_session_hash()
  33. headers_join = {
  34. 'accept': '*/*',
  35. 'accept-language': 'en-US,en;q=0.9',
  36. 'content-type': 'application/json',
  37. 'origin': f'{cls.url}',
  38. 'referer': f'{cls.url}/',
  39. 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
  40. }
  41. # Prepare the prompt
  42. system_prompt = "\n".join([message["content"] for message in messages if message["role"] == "system"])
  43. messages = [message for message in messages if message["role"] != "system"]
  44. prompt = format_prompt(messages)
  45. payload_join = {
  46. "data": [prompt, [], system_prompt],
  47. "event_data": None,
  48. "fn_index": 0,
  49. "trigger_id": 11,
  50. "session_hash": session_hash
  51. }
  52. async with aiohttp.ClientSession() as session:
  53. # Send join request
  54. async with session.post(cls.api_endpoint, headers=headers_join, json=payload_join) as response:
  55. event_id = (await response.json())['event_id']
  56. # Prepare data stream request
  57. url_data = f'{cls.url}/queue/data'
  58. headers_data = {
  59. 'accept': 'text/event-stream',
  60. 'accept-language': 'en-US,en;q=0.9',
  61. 'referer': f'{cls.url}/',
  62. 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
  63. }
  64. params_data = {
  65. 'session_hash': session_hash
  66. }
  67. # Send data stream request
  68. async with session.get(url_data, headers=headers_data, params=params_data) as response:
  69. full_response = ""
  70. final_full_response = ""
  71. async for line in response.content:
  72. decoded_line = line.decode('utf-8')
  73. if decoded_line.startswith('data: '):
  74. try:
  75. json_data = json.loads(decoded_line[6:])
  76. # Look for generation stages
  77. if json_data.get('msg') == 'process_generating':
  78. if 'output' in json_data and 'data' in json_data['output']:
  79. output_data = json_data['output']['data']
  80. if len(output_data) > 1 and len(output_data[1]) > 0:
  81. for item in output_data[1]:
  82. if isinstance(item, list) and len(item) > 1:
  83. fragment = str(item[1])
  84. # Ignore [0, 1] type fragments and duplicates
  85. if not re.match(r'^\[.*\]$', fragment) and not full_response.endswith(fragment):
  86. full_response += fragment
  87. yield fragment
  88. # Check for completion
  89. if json_data.get('msg') == 'process_completed':
  90. # Final check to ensure we get the complete response
  91. if 'output' in json_data and 'data' in json_data['output']:
  92. output_data = json_data['output']['data']
  93. if len(output_data) > 1 and len(output_data[1]) > 0:
  94. final_full_response = output_data[1][0][1]
  95. # Clean up the final response
  96. if final_full_response.startswith(full_response):
  97. final_full_response = final_full_response[len(full_response):]
  98. # Yield the remaining part of the final response
  99. if final_full_response:
  100. yield final_full_response
  101. break
  102. except json.JSONDecodeError:
  103. debug.log("Could not parse JSON:", decoded_line)