stubs.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. from __future__ import annotations
  2. import os
  3. from typing import Optional, List
  4. from time import time
  5. from ..image import extract_data_uri
  6. from ..image.copy_images import get_media_dir
  7. from ..client.helper import filter_markdown
  8. from .helper import filter_none
  9. try:
  10. from pydantic import BaseModel, field_serializer
  11. except ImportError:
  12. class BaseModel():
  13. @classmethod
  14. def model_construct(cls, **data):
  15. new = cls()
  16. for key, value in data.items():
  17. setattr(new, key, value)
  18. return new
  19. class field_serializer():
  20. def __init__(self, field_name):
  21. self.field_name = field_name
  22. def __call__(self, *args, **kwargs):
  23. return args[0]
  24. class BaseModel(BaseModel):
  25. @classmethod
  26. def model_construct(cls, **data):
  27. if hasattr(super(), "model_construct"):
  28. return super().model_construct(**data)
  29. return cls.construct(**data)
  30. class TokenDetails(BaseModel):
  31. cached_tokens: int
  32. class UsageModel(BaseModel):
  33. prompt_tokens: int
  34. completion_tokens: int
  35. total_tokens: int
  36. prompt_tokens_details: TokenDetails
  37. completion_tokens_details: TokenDetails
  38. @classmethod
  39. def model_construct(cls, prompt_tokens=0, completion_tokens=0, total_tokens=0, prompt_tokens_details=None, completion_tokens_details=None, **kwargs):
  40. return super().model_construct(
  41. prompt_tokens=prompt_tokens,
  42. completion_tokens=completion_tokens,
  43. total_tokens=total_tokens,
  44. prompt_tokens_details=TokenDetails.model_construct(**prompt_tokens_details if prompt_tokens_details else {"cached_tokens": 0}),
  45. completion_tokens_details=TokenDetails.model_construct(**completion_tokens_details if completion_tokens_details else {}),
  46. **kwargs
  47. )
  48. class ToolFunctionModel(BaseModel):
  49. name: str
  50. arguments: str
  51. class ToolCallModel(BaseModel):
  52. id: str
  53. type: str
  54. function: ToolFunctionModel
  55. @classmethod
  56. def model_construct(cls, function=None, **kwargs):
  57. return super().model_construct(
  58. **kwargs,
  59. function=ToolFunctionModel.model_construct(**function),
  60. )
  61. class ChatCompletionChunk(BaseModel):
  62. id: str
  63. object: str
  64. created: int
  65. model: str
  66. provider: Optional[str]
  67. choices: List[ChatCompletionDeltaChoice]
  68. usage: UsageModel
  69. conversation: dict
  70. @classmethod
  71. def model_construct(
  72. cls,
  73. content: str,
  74. finish_reason: str,
  75. completion_id: str = None,
  76. created: int = None,
  77. usage: UsageModel = None,
  78. conversation: dict = None
  79. ):
  80. return super().model_construct(
  81. id=f"chatcmpl-{completion_id}" if completion_id else None,
  82. object="chat.completion.chunk",
  83. created=created,
  84. model=None,
  85. provider=None,
  86. choices=[ChatCompletionDeltaChoice.model_construct(
  87. ChatCompletionDelta.model_construct(content),
  88. finish_reason
  89. )],
  90. **filter_none(usage=usage, conversation=conversation)
  91. )
  92. @field_serializer('conversation')
  93. def serialize_conversation(self, conversation: dict):
  94. if hasattr(conversation, "get_dict"):
  95. return conversation.get_dict()
  96. return conversation
  97. class ChatCompletionMessage(BaseModel):
  98. role: str
  99. content: str
  100. tool_calls: list[ToolCallModel] = None
  101. @classmethod
  102. def model_construct(cls, content: str, tool_calls: list = None):
  103. return super().model_construct(role="assistant", content=content, **filter_none(tool_calls=tool_calls))
  104. @field_serializer('content')
  105. def serialize_content(self, content: str):
  106. return str(content)
  107. def save(self, filepath: str, allowed_types = None):
  108. if hasattr(self.content, "data"):
  109. os.rename(self.content.data.replace("/media", get_media_dir()), filepath)
  110. return
  111. if self.content.startswith("data:"):
  112. with open(filepath, "wb") as f:
  113. f.write(extract_data_uri(self.content))
  114. return
  115. content = filter_markdown(self.content, allowed_types)
  116. if content is not None:
  117. with open(filepath, "w") as f:
  118. f.write(content)
  119. class ChatCompletionChoice(BaseModel):
  120. index: int
  121. message: ChatCompletionMessage
  122. finish_reason: str
  123. @classmethod
  124. def model_construct(cls, message: ChatCompletionMessage, finish_reason: str):
  125. return super().model_construct(index=0, message=message, finish_reason=finish_reason)
  126. class ChatCompletion(BaseModel):
  127. id: str
  128. object: str
  129. created: int
  130. model: str
  131. provider: Optional[str]
  132. choices: list[ChatCompletionChoice]
  133. usage: UsageModel
  134. conversation: dict
  135. @classmethod
  136. def model_construct(
  137. cls,
  138. content: str,
  139. finish_reason: str,
  140. completion_id: str = None,
  141. created: int = None,
  142. tool_calls: list[ToolCallModel] = None,
  143. usage: UsageModel = None,
  144. conversation: dict = None
  145. ):
  146. return super().model_construct(
  147. id=f"chatcmpl-{completion_id}" if completion_id else None,
  148. object="chat.completion",
  149. created=created,
  150. model=None,
  151. provider=None,
  152. choices=[ChatCompletionChoice.model_construct(
  153. ChatCompletionMessage.model_construct(content, tool_calls),
  154. finish_reason,
  155. )],
  156. **filter_none(usage=usage, conversation=conversation)
  157. )
  158. @field_serializer('conversation')
  159. def serialize_conversation(self, conversation: dict):
  160. if hasattr(conversation, "get_dict"):
  161. return conversation.get_dict()
  162. return conversation
  163. class ChatCompletionDelta(BaseModel):
  164. role: str
  165. content: Optional[str]
  166. @classmethod
  167. def model_construct(cls, content: Optional[str]):
  168. return super().model_construct(role="assistant", content=content)
  169. @field_serializer('content')
  170. def serialize_content(self, content: Optional[str]):
  171. if content is None:
  172. return ""
  173. return str(content)
  174. class ChatCompletionDeltaChoice(BaseModel):
  175. index: int
  176. delta: ChatCompletionDelta
  177. finish_reason: Optional[str]
  178. @classmethod
  179. def model_construct(cls, delta: ChatCompletionDelta, finish_reason: Optional[str]):
  180. return super().model_construct(index=0, delta=delta, finish_reason=finish_reason)
  181. class Image(BaseModel):
  182. url: Optional[str]
  183. b64_json: Optional[str]
  184. revised_prompt: Optional[str]
  185. @classmethod
  186. def model_construct(cls, url: str = None, b64_json: str = None, revised_prompt: str = None):
  187. return super().model_construct(**filter_none(
  188. url=url,
  189. b64_json=b64_json,
  190. revised_prompt=revised_prompt
  191. ))
  192. def save(self, path: str):
  193. if self.url is not None and self.url.startswith("/media/"):
  194. os.rename(self.url.replace("/media", get_media_dir()), path)
  195. class ImagesResponse(BaseModel):
  196. data: List[Image]
  197. model: str
  198. provider: str
  199. created: int
  200. @classmethod
  201. def model_construct(cls, data: List[Image], created: int = None, model: str = None, provider: str = None):
  202. if created is None:
  203. created = int(time())
  204. return super().model_construct(
  205. data=data,
  206. model=model,
  207. provider=provider,
  208. created=created
  209. )