tool_support.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from __future__ import annotations
  2. import json
  3. from ..typing import AsyncResult, Messages, MediaListType
  4. from ..client.service import get_model_and_provider
  5. from ..client.helper import filter_json
  6. from .base_provider import AsyncGeneratorProvider
  7. from .response import ToolCalls, FinishReason, Usage
  8. class ToolSupportProvider(AsyncGeneratorProvider):
  9. working = True
  10. @classmethod
  11. async def create_async_generator(
  12. cls,
  13. model: str,
  14. messages: Messages,
  15. stream: bool = True,
  16. media: MediaListType = None,
  17. tools: list[str] = None,
  18. response_format: dict = None,
  19. **kwargs
  20. ) -> AsyncResult:
  21. provider = None
  22. if ":" in model:
  23. provider, model = model.split(":", 1)
  24. model, provider = get_model_and_provider(
  25. model, provider,
  26. stream, logging=False,
  27. has_images=media is not None
  28. )
  29. if tools is not None:
  30. if len(tools) > 1:
  31. raise ValueError("Only one tool is supported.")
  32. if response_format is None:
  33. response_format = {"type": "json"}
  34. tools = tools.pop()
  35. lines = ["Respone in JSON format."]
  36. properties = tools["function"]["parameters"]["properties"]
  37. properties = {key: value["type"] for key, value in properties.items()}
  38. lines.append(f"Response format: {json.dumps(properties, indent=2)}")
  39. messages = [{"role": "user", "content": "\n".join(lines)}] + messages
  40. finish = None
  41. chunks = []
  42. has_usage = False
  43. async for chunk in provider.get_async_create_function()(
  44. model,
  45. messages,
  46. stream=stream,
  47. media=media,
  48. response_format=response_format,
  49. **kwargs
  50. ):
  51. if isinstance(chunk, str):
  52. chunks.append(chunk)
  53. elif isinstance(chunk, Usage):
  54. yield chunk
  55. has_usage = True
  56. elif isinstance(chunk, FinishReason):
  57. finish = chunk
  58. break
  59. else:
  60. yield chunk
  61. if not has_usage:
  62. yield Usage(completion_tokens=len(chunks), total_tokens=len(chunks))
  63. chunks = "".join(chunks)
  64. if tools is not None:
  65. yield ToolCalls([{
  66. "id": "",
  67. "type": "function",
  68. "function": {
  69. "name": tools["function"]["name"],
  70. "arguments": filter_json(chunks)
  71. }
  72. }])
  73. yield chunks
  74. if finish is not None:
  75. yield finish