run_tools.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. from __future__ import annotations
  2. import re
  3. import json
  4. import asyncio
  5. import time
  6. from pathlib import Path
  7. from typing import Optional, Callable, AsyncIterator, Iterator, Dict, Any, Tuple, List, Union
  8. from ..typing import Messages
  9. from ..providers.helper import filter_none
  10. from ..providers.asyncio import to_async_iterator
  11. from ..providers.response import Reasoning, FinishReason, Sources
  12. from ..providers.types import ProviderType
  13. from ..cookies import get_cookies_dir
  14. from .web_search import do_search, get_search_message
  15. from .files import read_bucket, get_bucket_dir
  16. from .. import debug
  17. # Constants
  18. BUCKET_INSTRUCTIONS = """
  19. Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
  20. """
  21. TOOL_NAMES = {
  22. "SEARCH": "search_tool",
  23. "CONTINUE": "continue_tool",
  24. "BUCKET": "bucket_tool"
  25. }
  26. class ToolHandler:
  27. """Handles processing of different tool types"""
  28. @staticmethod
  29. def validate_arguments(data: dict) -> dict:
  30. """Validate and parse tool arguments"""
  31. if "arguments" in data:
  32. if isinstance(data["arguments"], str):
  33. data["arguments"] = json.loads(data["arguments"])
  34. if not isinstance(data["arguments"], dict):
  35. raise ValueError("Tool function arguments must be a dictionary or a json string")
  36. else:
  37. return filter_none(**data["arguments"])
  38. else:
  39. return {}
  40. @staticmethod
  41. async def process_search_tool(messages: Messages, tool: dict) -> Messages:
  42. """Process search tool requests"""
  43. messages = messages.copy()
  44. args = ToolHandler.validate_arguments(tool["function"])
  45. messages[-1]["content"] = await do_search(
  46. messages[-1]["content"],
  47. **args
  48. )
  49. return messages
  50. @staticmethod
  51. def process_continue_tool(messages: Messages, tool: dict, provider: Any) -> Tuple[Messages, Dict[str, Any]]:
  52. """Process continue tool requests"""
  53. kwargs = {}
  54. if provider not in ("OpenaiAccount", "HuggingFaceAPI"):
  55. messages = messages.copy()
  56. last_line = messages[-1]["content"].strip().splitlines()[-1]
  57. content = f"Carry on from this point:\n{last_line}"
  58. messages.append({"role": "user", "content": content})
  59. else:
  60. # Enable provider native continue
  61. kwargs["action"] = "continue"
  62. return messages, kwargs
  63. @staticmethod
  64. def process_bucket_tool(messages: Messages, tool: dict) -> Messages:
  65. """Process bucket tool requests"""
  66. messages = messages.copy()
  67. def on_bucket(match):
  68. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  69. has_bucket = False
  70. for message in messages:
  71. if "content" in message and isinstance(message["content"], str):
  72. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  73. if new_message_content != message["content"]:
  74. has_bucket = True
  75. message["content"] = new_message_content
  76. last_message_content = messages[-1]["content"]
  77. if has_bucket and isinstance(last_message_content, str):
  78. if "\nSource: " in last_message_content:
  79. messages[-1]["content"] = last_message_content + BUCKET_INSTRUCTIONS
  80. return messages
  81. @staticmethod
  82. async def process_tools(messages: Messages, tool_calls: List[dict], provider: Any) -> Tuple[Messages, Dict[str, Any]]:
  83. """Process all tool calls and return updated messages and kwargs"""
  84. if not tool_calls:
  85. return messages, {}
  86. extra_kwargs = {}
  87. messages = messages.copy()
  88. for tool in tool_calls:
  89. if tool.get("type") != "function":
  90. continue
  91. function_name = tool.get("function", {}).get("name")
  92. if function_name == TOOL_NAMES["SEARCH"]:
  93. messages = await ToolHandler.process_search_tool(messages, tool)
  94. elif function_name == TOOL_NAMES["CONTINUE"]:
  95. messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider)
  96. extra_kwargs.update(kwargs)
  97. elif function_name == TOOL_NAMES["BUCKET"]:
  98. messages = ToolHandler.process_bucket_tool(messages, tool)
  99. return messages, extra_kwargs
  100. class AuthManager:
  101. """Handles API key management"""
  102. @staticmethod
  103. def get_api_key_file(cls) -> Path:
  104. """Get the path to the API key file for a provider"""
  105. return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
  106. @staticmethod
  107. def load_api_key(provider: Any) -> Optional[str]:
  108. """Load API key from config file if needed"""
  109. if not getattr(provider, "needs_auth", False):
  110. return None
  111. auth_file = AuthManager.get_api_key_file(provider)
  112. try:
  113. if auth_file.exists():
  114. with auth_file.open("r") as f:
  115. auth_result = json.load(f)
  116. return auth_result.get("api_key")
  117. except (json.JSONDecodeError, PermissionError, FileNotFoundError) as e:
  118. debug.error(f"Failed to load API key: {e.__class__.__name__}: {e}")
  119. return None
class ThinkingProcessor:
    """Processes thinking chunks.

    Splits streamed text around <think>...</think> markers, converting the
    inner text into Reasoning objects and tracking when thinking started so
    the closing status can report the elapsed time.
    """

    @staticmethod
    def process_thinking_chunk(chunk: str, start_time: float = 0) -> Tuple[float, List[Union[str, Reasoning]]]:
        """Process a thinking chunk and return timing and results.

        Args:
            chunk: The streamed text fragment to inspect.
            start_time: Wall-clock time when a <think> section opened, or 0
                when not currently inside one.

        Returns:
            A (new_start_time, results) pair: new_start_time is 0 when no
            thinking is in progress afterwards, otherwise the (possibly new)
            start timestamp; results mixes plain strings and Reasoning items
            in the order they should be yielded.
        """
        results = []

        # Handle non-thinking chunk: fast path, pass the text through untouched.
        if not start_time and "<think>" not in chunk and "</think>" not in chunk:
            return 0, [chunk]

        # Handle thinking start. A backtick-quoted `<think>` is treated as
        # literal text, not a marker (no matching escape exists for </think>
        # below — presumably intentional; verify).
        if "<think>" in chunk and "`<think>`" not in chunk:
            before_think, *after = chunk.split("<think>", 1)
            if before_think:
                results.append(before_think)
            results.append(Reasoning(status="🤔 Is thinking...", is_thinking="<think>"))
            if after:
                if "</think>" in after[0]:
                    # The section also closes inside this same chunk:
                    # after is rebound to the text before </think>.
                    after, *after_end = after[0].split("</think>", 1)
                    results.append(Reasoning(after))
                    results.append(Reasoning(status="Finished", is_thinking="</think>"))
                    if after_end:
                        results.append(after_end[0])
                    return 0, results
                else:
                    results.append(Reasoning(after[0]))
            # Section is still open: report now as the thinking start time.
            return time.time(), results

        # Handle thinking end.
        if "</think>" in chunk:
            before_end, *after = chunk.split("</think>", 1)
            if before_end:
                results.append(Reasoning(before_end))
            # Only report a duration when a start was recorded and it exceeds 1s.
            thinking_duration = time.time() - start_time if start_time > 0 else 0
            status = f"Thought for {thinking_duration:.2f}s" if thinking_duration > 1 else "Finished"
            results.append(Reasoning(status=status, is_thinking="</think>"))
            # Make sure to handle text after the closing tag
            if after and after[0].strip():
                results.append(after[0])
            return 0, results

        # Handle ongoing thinking: wrap the whole chunk as reasoning text.
        if start_time:
            return start_time, [Reasoning(chunk)]
        return start_time, [chunk]
  162. async def perform_web_search(messages: Messages, web_search_param: Any) -> Tuple[Messages, Optional[Sources]]:
  163. """Perform web search and return updated messages and sources"""
  164. messages = messages.copy()
  165. sources = None
  166. if not web_search_param:
  167. return messages, sources
  168. try:
  169. search_query = web_search_param if isinstance(web_search_param, str) and web_search_param != "true" else None
  170. messages[-1]["content"], sources = await do_search(messages[-1]["content"], search_query)
  171. except Exception as e:
  172. debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  173. return messages, sources
  174. async def async_iter_run_tools(
  175. provider: ProviderType,
  176. model: str,
  177. messages: Messages,
  178. tool_calls: Optional[List[dict]] = None,
  179. **kwargs
  180. ) -> AsyncIterator:
  181. """Asynchronously run tools and yield results"""
  182. # Process web search
  183. sources = None
  184. web_search = kwargs.get('web_search')
  185. if web_search:
  186. messages, sources = await perform_web_search(messages, web_search)
  187. # Get API key if needed
  188. api_key = AuthManager.load_api_key(provider)
  189. if api_key and "api_key" not in kwargs:
  190. kwargs["api_key"] = api_key
  191. # Process tool calls
  192. if tool_calls:
  193. messages, extra_kwargs = await ToolHandler.process_tools(messages, tool_calls, provider)
  194. kwargs.update(extra_kwargs)
  195. # Generate response
  196. create_function = provider.get_async_create_function()
  197. response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
  198. async for chunk in response:
  199. yield chunk
  200. # Yield sources if available
  201. if sources:
  202. yield sources
  203. def iter_run_tools(
  204. iter_callback: Callable,
  205. model: str,
  206. messages: Messages,
  207. provider: Optional[str] = None,
  208. tool_calls: Optional[List[dict]] = None,
  209. **kwargs
  210. ) -> Iterator:
  211. """Run tools synchronously and yield results"""
  212. # Process web search
  213. web_search = kwargs.get('web_search')
  214. sources = None
  215. if web_search:
  216. try:
  217. messages = messages.copy()
  218. search_query = web_search if isinstance(web_search, str) and web_search != "true" else None
  219. # Note: Using asyncio.run inside sync function is not ideal, but maintaining original pattern
  220. messages[-1]["content"], sources = asyncio.run(do_search(messages[-1]["content"], search_query))
  221. except Exception as e:
  222. debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  223. # Get API key if needed
  224. if provider is not None and getattr(provider, "needs_auth", False) and "api_key" not in kwargs:
  225. api_key = AuthManager.load_api_key(provider)
  226. if api_key:
  227. kwargs["api_key"] = api_key
  228. # Process tool calls
  229. if tool_calls:
  230. for tool in tool_calls:
  231. if tool.get("type") == "function":
  232. function_name = tool.get("function", {}).get("name")
  233. if function_name == TOOL_NAMES["SEARCH"]:
  234. tool["function"]["arguments"] = ToolHandler.validate_arguments(tool["function"])
  235. messages[-1]["content"] = get_search_message(
  236. messages[-1]["content"],
  237. raise_search_exceptions=True,
  238. **tool["function"]["arguments"]
  239. )
  240. elif function_name == TOOL_NAMES["CONTINUE"]:
  241. if provider not in ("OpenaiAccount", "HuggingFace"):
  242. last_line = messages[-1]["content"].strip().splitlines()[-1]
  243. content = f"Carry on from this point:\n{last_line}"
  244. messages.append({"role": "user", "content": content})
  245. else:
  246. # Enable provider native continue
  247. kwargs["action"] = "continue"
  248. elif function_name == TOOL_NAMES["BUCKET"]:
  249. def on_bucket(match):
  250. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  251. has_bucket = False
  252. for message in messages:
  253. if "content" in message and isinstance(message["content"], str):
  254. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  255. if new_message_content != message["content"]:
  256. has_bucket = True
  257. message["content"] = new_message_content
  258. last_message = messages[-1]["content"]
  259. if has_bucket and isinstance(last_message, str):
  260. if "\nSource: " in last_message:
  261. messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS
  262. # Process response chunks
  263. thinking_start_time = 0
  264. processor = ThinkingProcessor()
  265. for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
  266. if isinstance(chunk, FinishReason):
  267. if sources is not None:
  268. yield sources
  269. sources = None
  270. yield chunk
  271. continue
  272. elif isinstance(chunk, Sources):
  273. sources = None
  274. if not isinstance(chunk, str):
  275. yield chunk
  276. continue
  277. thinking_start_time, results = processor.process_thinking_chunk(chunk, thinking_start_time)
  278. for result in results:
  279. yield result
  280. if sources is not None:
  281. yield sources