import os
import sys

from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.document_loaders import OnlinePDFLoader
from langchain.embeddings import GPT4AllEmbeddings
from langchain.llms import Ollama
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
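
# A local RAG pipeline: load a PDF, split it into chunks, embed the chunks
# into a Chroma vector store, and answer questions about the document with a
# local Llama 3 model served by Ollama. Assumed setup: `ollama pull llama3:8b`
# plus the langchain, chromadb, gpt4all, and unstructured Python packages.
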
class SuppressStdout:
    """Temporarily silence stdout and stderr (e.g. noisy loader warnings)."""

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        sys.stdout = open(os.devnull, 'w')
        sys.stderr = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close both devnull handles, then restore the real streams.
        sys.stdout.close()
        sys.stderr.close()
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr

# Load the PDF and split it into chunks
loader = OnlinePDFLoader("https://d18rn0p25nwr6d.cloudfront.net/CIK-0001813756/975b3e9b-268e-4798-a9e4-2a9a7c92dc10.pdf")
data = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
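
# chunk_size=500 with no overlap keeps each chunk small enough to stuff several
# into the prompt; raising chunk_overlap can help preserve context that would
# otherwise be cut at chunk boundaries.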

# Embed the chunks and index them in a local Chroma vector store
with SuppressStdout():
    vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings())
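
# Optional: persist the index so re-runs skip re-embedding. A minimal sketch
# using Chroma's persist_directory kwarg; "./chroma_db" is a hypothetical path:
#   vectorstore = Chroma.from_documents(documents=all_splits,
#                                       embedding=GPT4AllEmbeddings(),
#                                       persist_directory="./chroma_db")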

while True:
    query = input("\nQuery: ")
    if query == "exit":
        break
    if query.strip() == "":
        continue

    # Prompt: ground the answer in retrieved context and discourage guessing
    template = """Use the following pieces of context to answer the question at the end.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Use three sentences maximum and keep the answer as concise as possible.
    {context}
    Question: {question}
    Helpful Answer:"""
    QA_CHAIN_PROMPT = PromptTemplate(
        input_variables=["context", "question"],
        template=template,
    )
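
    # RetrievalQA fills {context} with the retrieved chunks and {question}
    # with the user's query before calling the model below.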
    llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    )
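
    # The streaming callback already prints tokens to the terminal as they are
    # generated, so the answer appears without explicitly printing `result`.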
    result = qa_chain({"query": query})