run_tools.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from __future__ import annotations
  2. import re
  3. import json
  4. import asyncio
  5. from typing import Optional, Callable, AsyncIterator
  6. from ..typing import Messages
  7. from ..providers.helper import filter_none
  8. from ..providers.asyncio import to_async_iterator
  9. from .web_search import do_search, get_search_message
  10. from .files import read_bucket, get_bucket_dir
  11. from .. import debug
# Prompt suffix appended to the last message's content by the "bucket_tool"
# branches below, once at least one {"bucket_id":"..."} marker has been
# expanded. It asks the model to cite its sources in a fixed link notation.
BUCKET_INSTRUCTIONS = """
Instruction: Make sure to add the sources of cites using [[domain]](Url) notation after the reference. Example: [[a-z0-9.]](http://example.com)
"""
  15. def validate_arguments(data: dict) -> dict:
  16. if "arguments" in data:
  17. if isinstance(data["arguments"], str):
  18. data["arguments"] = json.loads(data["arguments"])
  19. if not isinstance(data["arguments"], dict):
  20. raise ValueError("Tool function arguments must be a dictionary or a json string")
  21. else:
  22. return filter_none(**data["arguments"])
  23. else:
  24. return {}
  25. async def async_iter_run_tools(async_iter_callback, model, messages, tool_calls: Optional[list] = None, **kwargs):
  26. if tool_calls is not None:
  27. for tool in tool_calls:
  28. if tool.get("type") == "function":
  29. if tool.get("function", {}).get("name") == "search_tool":
  30. tool["function"]["arguments"] = validate_arguments(tool["function"])
  31. messages = messages.copy()
  32. messages[-1]["content"] = await do_search(
  33. messages[-1]["content"],
  34. **tool["function"]["arguments"]
  35. )
  36. elif tool.get("function", {}).get("name") == "continue":
  37. last_line = messages[-1]["content"].strip().splitlines()[-1]
  38. content = f"Carry on from this point:\n{last_line}"
  39. messages.append({"role": "user", "content": content})
  40. elif tool.get("function", {}).get("name") == "bucket_tool":
  41. def on_bucket(match):
  42. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  43. has_bucket = False
  44. for message in messages:
  45. if "content" in message and isinstance(message["content"], str):
  46. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  47. if new_message_content != message["content"]:
  48. has_bucket = True
  49. message["content"] = new_message_content
  50. if has_bucket and isinstance(messages[-1]["content"], str):
  51. messages[-1]["content"] += BUCKET_INSTRUCTIONS
  52. response = to_async_iterator(async_iter_callback(model=model, messages=messages, **kwargs))
  53. async for chunk in response:
  54. yield chunk
  55. def iter_run_tools(
  56. iter_callback: Callable,
  57. model: str,
  58. messages: Messages,
  59. provider: Optional[str] = None,
  60. tool_calls: Optional[list] = None,
  61. **kwargs
  62. ) -> AsyncIterator:
  63. if tool_calls is not None:
  64. for tool in tool_calls:
  65. if tool.get("type") == "function":
  66. if tool.get("function", {}).get("name") == "search_tool":
  67. tool["function"]["arguments"] = validate_arguments(tool["function"])
  68. messages[-1]["content"] = get_search_message(
  69. messages[-1]["content"],
  70. raise_search_exceptions=True,
  71. **tool["function"]["arguments"]
  72. )
  73. elif tool.get("function", {}).get("name") == "safe_search_tool":
  74. tool["function"]["arguments"] = validate_arguments(tool["function"])
  75. try:
  76. messages[-1]["content"] = asyncio.run(do_search(messages[-1]["content"], **tool["function"]["arguments"]))
  77. except Exception as e:
  78. debug.log(f"Couldn't do web search: {e.__class__.__name__}: {e}")
  79. # Enable provider native web search
  80. kwargs["web_search"] = True
  81. elif tool.get("function", {}).get("name") == "continue_tool":
  82. if provider not in ("OpenaiAccount", "HuggingFace"):
  83. last_line = messages[-1]["content"].strip().splitlines()[-1]
  84. content = f"Carry on from this point:\n{last_line}"
  85. messages.append({"role": "user", "content": content})
  86. else:
  87. # Enable provider native continue
  88. if "action" not in kwargs:
  89. kwargs["action"] = "continue"
  90. elif tool.get("function", {}).get("name") == "bucket_tool":
  91. def on_bucket(match):
  92. return "".join(read_bucket(get_bucket_dir(match.group(1))))
  93. has_bucket = False
  94. for message in messages:
  95. if "content" in message and isinstance(message["content"], str):
  96. new_message_content = re.sub(r'{"bucket_id":"([^"]*)"}', on_bucket, message["content"])
  97. if new_message_content != message["content"]:
  98. has_bucket = True
  99. message["content"] = new_message_content
  100. if has_bucket and isinstance(messages[-1]["content"], str):
  101. messages[-1]["content"] += BUCKET_INSTRUCTIONS
  102. return iter_callback(model=model, messages=messages, provider=provider, **kwargs)