run_tools.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from __future__ import annotations
  2. import re
  3. import json
  4. import asyncio
  5. from typing import Optional, Callable, AsyncIterator
  6. from ..typing import Messages
  7. from ..providers.helper import filter_none
  8. from ..providers.asyncio import to_async_iterator
  9. from .web_search import do_search, get_search_message
  10. from .files import read_bucket, get_bucket_dir
  11. from .. import debug
# Citation instruction appended to the last message's content by the
# "bucket_tool" handling once at least one {"bucket_id":"..."} reference
# has been replaced with bucket contents.
BUCKET_INSTRUCTIONS = """
Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
"""
  15. def validate_arguments(data: dict) -> dict:
  16. if "arguments" in data:
  17. if isinstance(data["arguments"], str):
  18. data["arguments"] = json.loads(data["arguments"])
  19. if not isinstance(data["arguments"], dict):
  20. raise ValueError("Tool function arguments must be a dictionary or a json string")
  21. else:
  22. return filter_none(**data["arguments"])
  23. else:
  24. return {}
  25. async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
  26. # Handle web_search from kwargs
  27. web_search = kwargs.get('web_search')
  28. if web_search:
  29. try:
  30. messages = messages.copy()
  31. web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
  32. messages[-1]["content"] = await do_search(messages[-1]["content"], web_search)
  33. except Exception as e:
  34. debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  35. # Keep web_search in kwargs for provider native support
  36. pass
  37. if tool_calls is not None:
  38. for tool in tool_calls:
  39. if tool.get("type") == "function":
  40. if tool.get("function", {}).get("name") == "search_tool":
  41. tool["function"]["arguments"] = validate_arguments(tool["function"])
  42. messages = messages.copy()
  43. messages[-1]["content"] = await do_search(
  44. messages[-1]["content"],
  45. **tool["function"]["arguments"]
  46. )
  47. elif tool.get("function", {}).get("name") == "continue":
  48. last_line = messages[-1]["content"].strip().splitlines()[-1]
  49. content = f"Carry on from this point:\n{last_line}"
  50. messages.append({"role": "user", "content": content})
  51. elif tool.get("function", {}).get("name") == "bucket_tool":
  52. def on_bucket(match):
  53. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  54. has_bucket = False
  55. for message in messages:
  56. if "content" in message and isinstance(message["content"], str):
  57. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  58. if new_message_content != message["content"]:
  59. has_bucket = True
  60. message["content"] = new_message_content
  61. if has_bucket and isinstance(messages[-1]["content"], str):
  62. messages[-1]["content"] += BUCKET_INSTRUCTIONS
  63. response = to_async_iterator(async_iter_callback(model=model, messages=messages, **kwargs))
  64. async for chunk in response:
  65. yield chunk
  66. def iter_run_tools(
  67. iter_callback: Callable,
  68. model: str,
  69. messages: Messages,
  70. provider: Optional[str] = None,
  71. tool_calls: Optional[list] = None,
  72. **kwargs
  73. ) -> AsyncIterator:
  74. # Handle web_search from kwargs
  75. web_search = kwargs.get('web_search')
  76. if web_search:
  77. try:
  78. messages = messages.copy()
  79. web_search = web_search if isinstance(web_search, str) and web_search != "true" else None
  80. messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], web_search))
  81. except Exception as e:
  82. debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  83. # Keep web_search in kwargs for provider native support
  84. pass
  85. if tool_calls is not None:
  86. for tool in tool_calls:
  87. if tool.get("type") == "function":
  88. if tool.get("function", {}).get("name") == "search_tool":
  89. tool["function"]["arguments"] = validate_arguments(tool["function"])
  90. messages[-1]["content"] = get_search_message(
  91. messages[-1]["content"],
  92. raise_search_exceptions=True,
  93. **tool["function"]["arguments"]
  94. )
  95. elif tool.get("function", {}).get("name") == "continue_tool":
  96. if provider not in ("OpenaiAccount", "HuggingFace"):
  97. last_line = messages[-1]["content"].strip().splitlines()[-1]
  98. content = f"Carry on from this point:\n{last_line}"
  99. messages.append({"role": "user", "content": content})
  100. else:
  101. # Enable provider native continue
  102. if "action" not in kwargs:
  103. kwargs["action"] = "continue"
  104. elif tool.get("function", {}).get("name") == "bucket_tool":
  105. def on_bucket(match):
  106. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  107. has_bucket = False
  108. for message in messages:
  109. if "content" in message and isinstance(message["content"], str):
  110. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  111. if new_message_content != message["content"]:
  112. has_bucket = True
  113. message["content"] = new_message_content
  114. if has_bucket and isinstance(messages[-1]["content"], str):
  115. messages[-1]["content"] += BUCKET_INSTRUCTIONS
  116. return iter_callback(model=model, messages=messages, provider=provider, **kwargs)