from __future__ import annotations

import os

from gpt4all import GPT4All

from .models import get_models
from ..typing import Messages

MODEL_LIST: dict[str, dict] | None = None

def find_model_dir(model_file: str) -> str:
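    """Return the directory that contains ``model_file``.

    Checks the project-level ``models`` directory first, then the legacy
    directory next to this module, then walks the working directory; falls
    back to the project-level directory if the file is found nowhere.
    """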
    local_dir = os.path.dirname(os.path.abspath(__file__))
    project_dir = os.path.dirname(os.path.dirname(local_dir))

    new_model_dir = os.path.join(project_dir, "models")
    new_model_file = os.path.join(new_model_dir, model_file)
    if os.path.isfile(new_model_file):
        return new_model_dir

    old_model_dir = os.path.join(local_dir, "models")
    old_model_file = os.path.join(old_model_dir, model_file)
    if os.path.isfile(old_model_file):
        return old_model_dir

    working_dir = "./"
    for root, dirs, files in os.walk(working_dir):
        if model_file in files:
            return root

    return new_model_dir


class LocalProvider:
    @staticmethod
    def create_completion(model: str, messages: Messages, stream: bool = False, **kwargs):
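        """Generate a completion from a local GPT4All model.

        Yields tokens one at a time when ``stream`` is true; otherwise yields
        the full response as a single string.
        """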
        global MODEL_LIST
        if MODEL_LIST is None:
            MODEL_LIST = get_models()
        if model not in MODEL_LIST:
            raise ValueError(f'Model "{model}" not found / not yet implemented')
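
        # Resolve the model's on-disk file name and the directory that holds it,
        # offering an interactive download when the file is missing.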
        model_config = MODEL_LIST[model]
        model_file = model_config["path"]
        model_dir = find_model_dir(model_file)
        if not os.path.isfile(os.path.join(model_dir, model_file)):
            print(f'Model file "models/{model_file}" not found.')
            download = input(f"Do you want to download {model_file}? [y/n]: ")
            if download in ["y", "Y"]:
                GPT4All.download_model(model_file, model_dir)
            else:
                raise ValueError(f'Model "{model_file}" not found.')

        model = GPT4All(
            model_name=model_file,
            # n_threads=8,
            verbose=False,
            allow_download=False,
            model_path=model_dir,
        )
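
        # Collect any system messages; fall back to a generic assistant prompt
        # when the conversation does not supply one.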
        system_message = "\n".join(
            message["content"] for message in messages if message["role"] == "system"
        )
        if not system_message:
            system_message = "A chat between a curious user and an artificial intelligence assistant."

        prompt_template = "USER: {0}\nASSISTANT: "
        conversation = "\n".join(
            f"{message['role'].upper()}: {message['content']}"
            for message in messages
            if message["role"] != "system"
        ) + "\nASSISTANT: "
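
        # gpt4all invokes this callback for every generated token; returning
        # False stops generation, so the stream is cut off as soon as the
        # model starts to fabricate the next "USER" turn.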
        def should_not_stop(token_id: int, token: str) -> bool:
            return "USER" not in token

        with model.chat_session(system_message, prompt_template):
            if stream:
                for token in model.generate(conversation, streaming=True, callback=should_not_stop):
                    yield token
            else:
                yield model.generate(conversation, callback=should_not_stop)
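

# Minimal usage sketch. The model key "mistral-7b" is a placeholder; any key
# present in the mapping returned by get_models() works. Both branches of
# create_completion yield, so the call is iterated either way. Because this
# module uses relative imports, run it as a package module, e.g.
# `python -m <package>.provider`, rather than as a bare script.
if __name__ == "__main__":
    example_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello in one sentence."},
    ]
    for token in LocalProvider.create_completion("mistral-7b", example_messages, stream=True):
        print(token, end="", flush=True)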