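"""07-4-mpt-7b-complete.py: RAG question-answering CLI over local documents.

`import_pdfs` builds a FAISS index from the files in a directory;
`question` runs a conversational retrieval loop against that index.
"""
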
import typer
# 0xVs
# SOURCE: https://huggingface.co/TheBloke/MPT-7B-Instruct-GGML/discussions/2
from dotenv import load_dotenv
from rich import print
from rich.prompt import Prompt
from langchain import HuggingFaceHub
from langchain.chains import ConversationalRetrievalChain
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS

# Only needed by the commented-out alternatives below; uncomment together with them.
# from ctransformers.langchain import CTransformers
# from langchain.document_loaders import PDFPlumberLoader
# from langchain.embeddings import HuggingFaceInstructEmbeddings
# from langchain.text_splitter import RecursiveCharacterTextSplitter

load_dotenv()  # HuggingFaceHub reads HUGGINGFACEHUB_API_TOKEN from the environment

app = typer.Typer()
device = "cpu"  # only used by the commented-out HuggingFaceInstructEmbeddings path
@app.command()
def import_pdfs(dir: str, embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"):
    """Split the documents in DIR, embed them, and save a local FAISS index."""
    # Despite the command name, the active loader reads .txt files; swap in the
    # PDFPlumberLoader line to ingest PDFs instead.
    # loader = DirectoryLoader(dir, glob="./*.pdf", loader_cls=PDFPlumberLoader, show_progress=True)
    loader = DirectoryLoader(dir, glob="./*.txt", show_progress=True)
    documents = loader.load()

    # text_splitter = RecursiveCharacterTextSplitter(separator="\n", chunk_size=512, chunk_overlap=0)
    text_splitter = CharacterTextSplitter(separator="\n", chunk_size=1512, chunk_overlap=0)
    docs = text_splitter.split_documents(documents)

    # embeddings = HuggingFaceInstructEmbeddings(model_name=embedding_model,
    #                                            model_kwargs={"device": device})
    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    db = FAISS.from_documents(docs, embeddings)
    db.save_local("faiss_index")
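
# Quick index sanity check (a sketch, not part of the CLI; assumes an index was
# just written by import_pdfs with the same default embedding model):
#   db = FAISS.load_local("faiss_index", HuggingFaceEmbeddings(
#       model_name="sentence-transformers/all-MiniLM-L6-v2"))
#   print(db.similarity_search("test query", k=1))
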
@app.command()
def question(model_path: str = "./models/mpt-7b-instruct.ggmlv3.fp16.bin",
model_type='mpt',
embedding_model="sentence-transformers/all-MiniLM-L6-v2",
search_breadth : int = 5, threads : int = 6, temperature : float = 0.4):
embeddings = HuggingFaceEmbeddings()
# embeddings = HuggingFaceInstructEmbeddings(model_name=embedding_model,
# model_kwargs={"device": device})
# config = {'temperature': temperature, 'threads' : threads}
# llm = CTransformers(model=model_path, model_type=model_type, config=config)
llm=HuggingFaceHub(repo_id="tiiuae/falcon-40b-instruct", model_kwargs={"temperature":0.1, "max_length":64})
db = FAISS.load_local("faiss_index", embeddings)
retriever = db.as_retriever(search_kwargs={"k": search_breadth})
memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
qa = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever,
memory=memory, return_source_documents=True, verbose=True)
    while True:
        query = Prompt.ask("\n[bright_yellow]Question[/bright_yellow] ")
        res = qa({"question": query})
        print("[spring_green4]" + res["answer"] + "[/spring_green4]")
        if "source_documents" in res:
            print("\n[italic grey46]References[/italic grey46]:")
            for ref in res["source_documents"]:
                print("> [grey19]" + ref.metadata["source"] + "[/grey19]")

if __name__ == "__main__":
    app()
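
# Example invocations (typer exposes the functions above as kebab-case
# subcommands; the paths and flag values here are illustrative):
#   python 07-4-mpt-7b-complete.py import-pdfs ./docs
#   python 07-4-mpt-7b-complete.py question --search-breadth 5 --temperature 0.4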