# run_tools.py
  1. from __future__ import annotations
  2. import re
  3. import json
  4. import asyncio
  5. import time
  6. from pathlib import Path
  7. from typing import Optional, Callable, AsyncIterator
  8. from ..typing import Messages
  9. from ..providers.helper import filter_none
  10. from ..providers.asyncio import to_async_iterator
  11. from ..providers.response import Reasoning
  12. from ..providers.types import ProviderType
  13. from ..cookies import get_cookies_dir
  14. from .web_search import do_search, get_search_message
  15. from .files import read_bucket, get_bucket_dir
  16. from .. import debug
# Appended to the last message after bucket contents were inlined (see the
# bucket_tool branches below) so the model cites the injected sources with
# [[domain]](url) markers.
BUCKET_INSTRUCTIONS = """
Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
"""
  20. def validate_arguments(data: dict) -> dict:
  21. if "arguments" in data:
  22. if isinstance(data["arguments"], str):
  23. data["arguments"] = json.loads(data["arguments"])
  24. if not isinstance(data["arguments"], dict):
  25. raise ValueError("Tool function arguments must be a dictionary or a json string")
  26. else:
  27. return filter_none(**data["arguments"])
  28. else:
  29. return {}
  30. def get_api_key_file(cls) -> Path:
  31. return Path(get_cookies_dir()) / f"api_key_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
async def async_iter_run_tools(provider: ProviderType, model: str, messages: Messages, tool_calls: Optional[list] = None, **kwargs):
    """Pre-process *messages* for web search, auth and tool calls, then stream
    the provider's async response chunks.

    Supported tool calls (``type == "function"``):
      - ``search_tool``: run ``do_search`` on the last message's content.
      - ``continue``: append a "carry on" user message built from the last line.
      - ``bucket_tool``: inline bucket contents referenced in the messages.

    Yields:
        Every chunk produced by the provider's async create function.
    """
    # Handle web_search from kwargs
    web_search = kwargs.get('web_search')
    if web_search:
        try:
            # NOTE(review): shallow copy — the message dicts themselves are
            # shared, so the caller's last message dict is still mutated below;
            # confirm this is intended.
            messages = messages.copy()
            # The literal string "true" means: derive the query from the prompt.
            web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
            messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
        except Exception as e:
            # Best effort: log and continue with the unmodified prompt.
            debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
            # Keep web_search in kwargs for provider native support
            pass
    # Read api_key from config file
    if getattr(provider, "needs_auth", False) and "api_key" not in kwargs:
        auth_file = get_api_key_file(provider)
        if auth_file.exists():
            with auth_file.open("r") as f:
                auth_result = json.load(f)
            if "api_key" in auth_result:
                kwargs["api_key"] = auth_result["api_key"]
    if tool_calls is not None:
        for tool in tool_calls:
            if tool.get("type") == "function":
                if tool.get("function", {}).get("name") == "search_tool":
                    tool["function"]["arguments"] = validate_arguments(tool["function"])
                    messages = messages.copy()
                    messages[-1]["content"] = await do_search(
                        messages[-1]["content"],
                        **tool["function"]["arguments"]
                    )
                elif tool.get("function", {}).get("name") == "continue":
                    # NOTE(review): the sync variant (iter_run_tools) matches
                    # "continue_tool" here — confirm which name clients send.
                    last_line = messages[-1]["content"].strip().splitlines()[-1]
                    content = f"Carry on from this point:\n{last_line}"
                    messages.append({"role": "user", "content": content})
                elif tool.get("function", {}).get("name") == "bucket_tool":
                    # Replace every {"bucket_id":"..."} placeholder with the
                    # referenced bucket's content.
                    def on_bucket(match):
                        return "".join(read_bucket(get_bucket_dir(match.group(1))))
                    has_bucket = False
                    for message in messages:
                        if "content" in message and isinstance(message["content"], str):
                            new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
                            if new_message_content != message["content"]:
                                has_bucket = True
                                message["content"] = new_message_content
                    if has_bucket and isinstance(messages[-1]["content"], str):
                        # Ask the model to cite the injected sources.
                        messages[-1]["content"] += BUCKET_INSTRUCTIONS
    create_function = provider.get_async_create_function()
    response = to_async_iterator(create_function(model=model, messages=messages, **kwargs))
    async for chunk in response:
        yield chunk
  82. def iter_run_tools(
  83. iter_callback: Callable,
  84. model: str,
  85. messages: Messages,
  86. provider: Optional[str] = None,
  87. tool_calls: Optional[list] = None,
  88. **kwargs
  89. ) -> AsyncIterator:
  90. # Handle web_search from kwargs
  91. web_search = kwargs.get('web_search')
  92. if web_search:
  93. try:
  94. messages = messages.copy()
  95. web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
  96. messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
  97. except Exception as e:
  98. debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  99. # Keep web_search in kwargs for provider native support
  100. pass
  101. # Read api_key from config file
  102. if provider is not None and provider.needs_auth and "api_key" not in kwargs:
  103. auth_file = get_api_key_file(provider)
  104. if auth_file.exists():
  105. with auth_file.open("r") as f:
  106. auth_result = json.load(f)
  107. if "api_key" in auth_result:
  108. kwargs["api_key"] = auth_result["api_key"]
  109. if tool_calls is not None:
  110. for tool in tool_calls:
  111. if tool.get("type") == "function":
  112. if tool.get("function", {}).get("name") == "search_tool":
  113. tool["function"]["arguments"] = validate_arguments(tool["function"])
  114. messages[-1]["content"] = get_search_message(
  115. messages[-1]["content"],
  116. raise_search_exceptions=True,
  117. **tool["function"]["arguments"]
  118. )
  119. elif tool.get("function", {}).get("name") == "continue_tool":
  120. if provider not in ("OpenaiAccount", "HuggingFace"):
  121. last_line = messages[-1]["content"].strip().splitlines()[-1]
  122. content = f"Carry on from this point:\n{last_line}"
  123. messages.append({"role": "user", "content": content})
  124. else:
  125. # Enable provider native continue
  126. if "action" not in kwargs:
  127. kwargs["action"] = "continue"
  128. elif tool.get("function", {}).get("name") == "bucket_tool":
  129. def on_bucket(match):
  130. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  131. has_bucket = False
  132. for message in messages:
  133. if "content" in message and isinstance(message["content"], str):
  134. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  135. if new_message_content != message["content"]:
  136. has_bucket = True
  137. message["content"] = new_message_content
  138. if has_bucket and isinstance(messages[-1]["content"], str):
  139. messages[-1]["content"] += BUCKET_INSTRUCTIONS
  140. is_thinking = 0
  141. for chunk in iter_callback(model=model, messages=messages, provider=provider, **kwargs):
  142. if not isinstance(chunk, str):
  143. yield chunk
  144. continue
  145. if "<think>" in chunk:
  146. chunk = chunk.split("<think>", 1)
  147. yield chunk[0]
  148. yield Reasoning(is_thinking="<think>")
  149. yield Reasoning(chunk[1])
  150. yield Reasoning(None, "Is thinking...")
  151. is_thinking = time.time()
  152. if "</think>" in chunk:
  153. chunk = chunk.split("</think>", 1)
  154. yield Reasoning(chunk[0])
  155. yield Reasoning(is_thinking="</think>")
  156. yield Reasoning(None, f"Finished in {round(time.time()-is_thinking, 2)} seconds")
  157. yield chunk[1]
  158. is_thinking = 0
  159. elif is_thinking:
  160. yield Reasoning(chunk)
  161. else:
  162. yield chunk