main.py

from langchain.document_loaders import OnlinePDFLoader
from langchain.vectorstores import Chroma
from langchain.embeddings import GPT4AllEmbeddings
from langchain.prompts import PromptTemplate
from langchain.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
from langchain.text_splitter import RecursiveCharacterTextSplitter
import sys
import os
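

# Context manager that temporarily redirects stdout/stderr to os.devnull,
# hiding the progress noise the embedding and Chroma libraries print while indexing.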
class SuppressStdout:
    def __enter__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        sys.stdout = open(os.devnull, 'w')
        sys.stderr = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        # close both devnull handles before restoring the real streams
        sys.stdout.close()
        sys.stderr.close()
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr

# load the pdf and split it into chunks
loader = OnlinePDFLoader("https://d18rn0p25nwr6d.cloudfront.net/CIK-0001813756/975b3e9b-268e-4798-a9e4-2a9a7c92dc10.pdf")
data = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=0)
all_splits = text_splitter.split_documents(data)
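
# embed the chunks and index them in an in-memory Chroma vector store;
# SuppressStdout keeps the libraries' progress output off the console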
with SuppressStdout():
    vectorstore = Chroma.from_documents(documents=all_splits, embedding=GPT4AllEmbeddings())
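
# simple REPL: answer questions against the indexed PDF until the user types "exit"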
while True:
    query = input("\nQuery: ")
    if query == "exit":
        break
    if query.strip() == "":
        continue

    # Prompt
    template = """Use the following pieces of context to answer the question at the end.
    If you don't know the answer, just say that you don't know, don't try to make up an answer.
    Use three sentences maximum and keep the answer as concise as possible.
    {context}
    Question: {question}
    Helpful Answer:"""
    QA_CHAIN_PROMPT = PromptTemplate(
        input_variables=["context", "question"],
        template=template,
    )
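
    # stream tokens to stdout as they are generated; the model name must match
    # one already pulled into the local Ollama server (e.g. `ollama pull llama3:8b`)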
    llm = Ollama(model="llama3:8b", callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))
    qa_chain = RetrievalQA.from_chain_type(
        llm,
        retriever=vectorstore.as_retriever(),
        chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    )

    result = qa_chain({"query": query})
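
# Typical usage (a sketch, assuming a local Ollama server is running and the
# required Python packages, e.g. langchain, chromadb, gpt4all, are installed):
#   ollama pull llama3:8b
#   python main.py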