run_tools.py

from __future__ import annotations

import re
import json
import asyncio
import time
from pathlib import Path
from typing import Optional, Callable, AsyncIterator, Iterator, Dict, Any, Tuple, List, Union

from ..typing import Messages
from ..providers.helper import filter_none
from ..providers.asyncio import to_async_iterator
from ..providers.response import Reasoning, FinishReason, Sources
from ..providers.types import ProviderType
from ..cookies import get_cookies_dir
from .web_search import do_search, get_search_message
from .files import read_bucket, get_bucket_dir
from .. import debug

# Constants
BUCKET_INSTRUCTIONS = """
Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
"""

TOOL_NAMES = {
    "SEARCH": "search_tool",
    "CONTINUE": "continue_tool",
    "BUCKET": "bucket_tool"
}
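
# A sketch of the tool-call payload shape this module consumes (field values are
# illustrative; the structure follows the OpenAI-style tool call format read by
# ToolHandler.process_tools below):
#
#     {
#         "type": "function",
#         "function": {
#             "name": "search_tool",
#             "arguments": "{\"query\": \"example search\", \"max_results\": 5}"
#         }
#     }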

class ToolHandler:
    """Handles processing of different tool types"""

    @staticmethod
    def validate_arguments(data: dict) -> dict:
        """Validate and parse tool arguments"""
        if "arguments" in data:
            if isinstance(data["arguments"], str):
                data["arguments"] = json.loads(data["arguments"])
            if not isinstance(data["arguments"], dict):
                raise ValueError("Tool function arguments must be a dictionary or a json string")
            else:
                return filter_none(**data["arguments"])
        else:
            return {}
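
    # Usage sketch for validate_arguments (assumes filter_none drops None-valued
    # entries, such as keys decoded from JSON null):
    #
    #     ToolHandler.validate_arguments({"arguments": '{"query": "rust", "limit": null}'})
    #     # -> {"query": "rust"}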

    @staticmethod
    async def process_search_tool(messages: Messages, tool: dict) -> Tuple[Messages, Sources]:
        """Process search tool requests"""
        messages = messages.copy()
        args = ToolHandler.validate_arguments(tool["function"])
        messages[-1]["content"], sources = await do_search(
            messages[-1]["content"],
            **args
        )
        return messages, sources

    @staticmethod
    def process_continue_tool(messages: Messages, tool: dict, provider: Any) -> Tuple[Messages, Dict[str, Any]]:
        """Process continue tool requests"""
        kwargs = {}
        if provider not in ("OpenaiAccount", "HuggingFaceAPI"):
            messages = messages.copy()
            last_line = messages[-1]["content"].strip().splitlines()[-1]
            content = f"Carry on from this point:\n{last_line}"
            messages.append({"role": "user", "content": content})
        else:
            # Enable provider native continue
            kwargs["action"] = "continue"
        return messages, kwargs
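
    # Illustration of the two continue strategies: for most providers the last
    # line of the previous answer is echoed back as a new user message, e.g.
    # {"role": "user", "content": "Carry on from this point:\n<last line>"};
    # for providers with native continue support only kwargs["action"] is set.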

    @staticmethod
    def process_bucket_tool(messages: Messages, tool: dict) -> Messages:
        """Process bucket tool requests"""
        messages = messages.copy()

        def on_bucket(match):
            return "".join(read_bucket(get_bucket_dir(match.group(1))))

        has_bucket = False
        for message in messages:
            if "content" in message and isinstance(message["content"], str):
                new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
                if new_message_content != message["content"]:
                    has_bucket = True
                    message["content"] = new_message_content

        last_message_content = messages[-1]["content"]
        if has_bucket and isinstance(last_message_content, str):
            if "\nSource: " in last_message_content:
                messages[-1]["content"] = last_message_content + BUCKET_INSTRUCTIONS
        return messages
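
    # Bucket placeholders are inline JSON markers in message content; a sketch of
    # the substitution (bucket id and contents are illustrative):
    #
    #     'See the attached file: {"bucket_id":"abc123"}'
    #     # -> 'See the attached file: <contents read from the abc123 bucket dir>'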

    @staticmethod
    async def process_tools(messages: Messages, tool_calls: List[dict], provider: Any) -> Tuple[Messages, Optional[Sources], Dict[str, Any]]:
        """Process all tool calls and return updated messages, sources, and extra kwargs"""
        if not tool_calls:
            return messages, None, {}

        extra_kwargs = {}
        messages = messages.copy()
        sources = None

        for tool in tool_calls:
            if tool.get("type") != "function":
                continue
            function_name = tool.get("function", {}).get("name")
            if function_name == TOOL_NAMES["SEARCH"]:
                messages, sources = await ToolHandler.process_search_tool(messages, tool)
            elif function_name == TOOL_NAMES["CONTINUE"]:
                messages, kwargs = ToolHandler.process_continue_tool(messages, tool, provider)
                extra_kwargs.update(kwargs)
            elif function_name == TOOL_NAMES["BUCKET"]:
                messages = ToolHandler.process_bucket_tool(messages, tool)

        return messages, sources, extra_kwargs

class AuthManager:
    """Handles API key management"""

    @staticmethod
    def get_api_key_file(cls) -> Path:
        """Get the path to the API key file for a provider"""
        return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"

    @staticmethod
    def load_api_key(provider: Any) -> Optional[str]:
        """Load API key from config file if needed"""
        if not getattr(provider, "needs_auth", False):
            return None
        auth_file = AuthManager.get_api_key_file(provider)
        try:
            if auth_file.exists():
                with auth_file.open("r") as f:
                    auth_result = json.load(f)
                return auth_result.get("api_key")
        except (json.JSONDecodeError, PermissionError, FileNotFoundError) as e:
            debug.error(f"Failed to load API key: {e.__class__.__name__}: {e}")
        return None
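
    # The key file is plain JSON stored under the cookies directory, e.g.
    # <cookies_dir>/api_key_<ProviderName>.json; a minimal sketch of its shape
    # (the value shown is a placeholder):
    #
    #     {"api_key": "sk-..."}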

class ThinkingProcessor:
    """Processes thinking chunks"""

    @staticmethod
    def process_thinking_chunk(chunk: str, start_time: float = 0) -> Tuple[float, List[Union[str, Reasoning]]]:
        """Process a thinking chunk and return timing and results."""
        results = []

        # Handle non-thinking chunk
        if not start_time and "<think>" not in chunk and "</think>" not in chunk:
            return 0, [chunk]

        # Handle thinking start
        if "<think>" in chunk and "`<think>`" not in chunk:
            before_think, *after = chunk.split("<think>", 1)
            if before_think:
                results.append(before_think)
            results.append(Reasoning(status="🤔 Is thinking...", is_thinking="<think>"))
            if after:
                if "</think>" in after[0]:
                    after, *after_end = after[0].split("</think>", 1)
                    results.append(Reasoning(after))
                    results.append(Reasoning(status="Finished", is_thinking="</think>"))
                    if after_end:
                        results.append(after_end[0])
                    return 0, results
                else:
                    results.append(Reasoning(after[0]))
            return time.time(), results

        # Handle thinking end
        if "</think>" in chunk:
            before_end, *after = chunk.split("</think>", 1)
            if before_end:
                results.append(Reasoning(before_end))
            thinking_duration = time.time() - start_time if start_time > 0 else 0
            status = f"Thought for {thinking_duration:.2f}s" if thinking_duration > 1 else "Finished"
            results.append(Reasoning(status=status, is_thinking="</think>"))
            # Make sure to handle text after the closing tag
            if after and after[0].strip():
                results.append(after[0])
            return 0, results

        # Handle ongoing thinking
        if start_time:
            return start_time, [Reasoning(chunk)]
        return start_time, [chunk]
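
# Usage sketch for ThinkingProcessor: feed streamed chunks through the processor,
# carrying the returned start time between calls (chunk contents are illustrative;
# Reasoning objects are emitted for <think> content and plain strings otherwise):
#
#     start = 0
#     for chunk in ["Hello <think>step one, ", "step two</think> done"]:
#         start, results = ThinkingProcessor.process_thinking_chunk(chunk, start)
#         for result in results:
#             print(result)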

async def perform_web_search(messages: Messages, web_search_param: Any) -> Tuple[Messages, Optional[Sources]]:
    """Perform web search and return updated messages and sources"""
    messages = messages.copy()
    sources = None
    if not web_search_param:
        return messages, sources
    try:
        search_query = web_search_param if isinstance(web_search_param, str) and web_search_param != "true" else None
        messages[-1]["content"], sources = await do_search(messages[-1]["content"], search_query)
    except Exception as e:
        debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")
    return messages, sources

async def async_iter_run_tools(
    provider: ProviderType,
    model: str,
    messages: Messages,
    tool_calls: Optional[List[dict]] = None,
    **kwargs
) -> AsyncIterator:
    """Asynchronously run tools and yield results"""
    # Process web search
    sources = None
    web_search = kwargs.get('web_search')
    if web_search:
        messages, sources = await perform_web_search(messages, web_search)

    # Get API key if needed
    api_key = AuthManager.load_api_key(provider)
    if api_key and "api_key" not in kwargs:
        kwargs["api_key"] = api_key

    # Process tool calls; keep earlier web-search sources unless the search tool produced new ones
    if tool_calls:
        messages, tool_sources, extra_kwargs = await ToolHandler.process_tools(messages, tool_calls, provider)
        if tool_sources is not None:
            sources = tool_sources
        kwargs.update(extra_kwargs)

    # Generate response
    create_function = provider.get_async_create_function()
    response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
    async for chunk in response:
        yield chunk

    # Yield sources if available
    if sources:
        yield sources
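
# Usage sketch for the async path (provider and model names are hypothetical):
#
#     async for chunk in async_iter_run_tools(
#         SomeProvider, "example-model", messages,
#         tool_calls=[...], web_search=True
#     ):
#         print(chunk)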

def iter_run_tools(
    iter_callback: Callable,
    model: str,
    messages: Messages,
    provider: Optional[str] = None,
    tool_calls: Optional[List[dict]] = None,
    **kwargs
) -> Iterator:
    """Run tools synchronously and yield results"""
    # Process web search
    web_search = kwargs.get('web_search')
    sources = None
    if web_search:
        try:
            messages = messages.copy()
            search_query = web_search if isinstance(web_search, str) and web_search != "true" else None
            # Note: asyncio.run inside a sync function is not ideal, but it preserves the existing pattern
            messages[-1]["content"], sources = asyncio.run(do_search(messages[-1]["content"], search_query))
        except Exception as e:
            debug.error(f"Couldn't do web search: {e.__class__.__name__}: {e}")

    # Get API key if needed
    if provider is not None and getattr(provider, "needs_auth", False) and "api_key" not in kwargs:
        api_key = AuthManager.load_api_key(provider)
        if api_key:
            kwargs["api_key"] = api_key

    # Process tool calls
    if tool_calls:
        for tool in tool_calls:
            if tool.get("type") == "function":
                function_name = tool.get("function", {}).get("name")
                if function_name == TOOL_NAMES["SEARCH"]:
                    tool["function"]["arguments"] = ToolHandler.validate_arguments(tool["function"])
                    messages[-1]["content"] = get_search_message(
                        messages[-1]["content"],
                        raise_search_exceptions=True,
                        **tool["function"]["arguments"]
                    )
                elif function_name == TOOL_NAMES["CONTINUE"]:
                    if provider not in ("OpenaiAccount", "HuggingFace"):
                        last_line = messages[-1]["content"].strip().splitlines()[-1]
                        content = f"Carry on from this point:\n{last_line}"
                        messages.append({"role": "user", "content": content})
                    else:
                        # Enable provider native continue
                        kwargs["action"] = "continue"
                elif function_name == TOOL_NAMES["BUCKET"]:
                    def on_bucket(match):
                        return "".join(read_bucket(get_bucket_dir(match.group(1))))

                    has_bucket = False
                    for message in messages:
                        if "content" in message and isinstance(message["content"], str):
                            new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
                            if new_message_content != message["content"]:
                                has_bucket = True
                                message["content"] = new_message_content

                    last_message = messages[-1]["content"]
                    if has_bucket and isinstance(last_message, str):
                        if "\nSource: " in last_message:
                            messages[-1]["content"] = last_message + BUCKET_INSTRUCTIONS

    # Process response chunks
    thinking_start_time = 0
    processor = ThinkingProcessor()

    for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
        if isinstance(chunk, FinishReason):
            if sources is not None:
                yield sources
                sources = None
            yield chunk
            continue
        elif isinstance(chunk, Sources):
            sources = None
        if not isinstance(chunk, str):
            yield chunk
            continue
        thinking_start_time, results = processor.process_thinking_chunk(chunk, thinking_start_time)
        for result in results:
            yield result

    if sources is not None:
        yield sources
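
# Usage sketch for the sync path (assumes iter_callback is a provider's create
# function; provider and model names are hypothetical):
#
#     for chunk in iter_run_tools(
#         SomeProvider.create_completion, "example-model", messages,
#         web_search="latest python release"
#     ):
#         print(chunk)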