# Databricks notebook source
# MAGIC %md
# MAGIC # Building a document store
# MAGIC We will now build out a larger document store, persist it, and use it for retrieval.
# COMMAND ----------
# DBTITLE 1,Extra Libs to install
# MAGIC %pip install pypdf sentence_transformers pymupdf ctransformers
# COMMAND ----------
dbutils.library.restartPython()
# COMMAND ----------
import glob
import re
import os
import chromadb
from chromadb.config import Settings
from langchain.text_splitter import CharacterTextSplitter
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
# COMMAND ----------
# DBTITLE 1,Setup dbfs folder paths
# MAGIC %run ./utils
# COMMAND ----------
# MAGIC %md
# MAGIC # Create Document Store
# MAGIC The document store has to be created first.
# MAGIC We need an index over our source documents, and we will manage it ourselves with ChromaDB.
# COMMAND ----------
# for class
#source_docs = glob.glob('/dbfs/bootcamp_data/pdf_data/*.pdf')
source_docs = glob.glob(dbfs_source_docs + '/*.pdf')
collection_name = 'arxiv_articles'

# We will use the default HuggingFaceEmbeddings for now
# (a quick sanity check of these embeddings follows in the next cell)
embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-mpnet-base-v2',
                                   model_kwargs={'device': 'cpu'})

# helper kept for reference - Chroma can also take an embedding function directly
def embed_fn(text):
    hfe = HuggingFaceEmbeddings()
    return hfe.embed_documents(text)

# setup the Chroma client with persistence
client = chromadb.Client(Settings(chroma_db_impl="duckdb+parquet",
                                  persist_directory=linux_vector_store_directory))

rebuild = True
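# COMMAND ----------
# DBTITLE 1,Quick embedding sanity check
# A minimal sketch, not part of the original flow: embed a throwaway question and check the
# vector dimensionality to confirm the sentence-transformers model loaded correctly.
# all-mpnet-base-v2 should produce 768-dimensional vectors; the question text is illustrative only.
sample_vector = embeddings.embed_query("What is a large language model?")
print(f"Embedding dimension: {len(sample_vector)}")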
# COMMAND ----------
# MAGIC %md
# MAGIC # Build ChromaDB
# MAGIC See chroma docs for more information
# COMMAND ----------
if rebuild:
    dbutils.fs.rm(f'dbfs:{linux_vector_store_directory}', True)
# COMMAND ----------
# Initiate the ChromaDB collection
# Create collection. get_collection, get_or_create_collection, delete_collection are also available!
# The collection is also where we could set a custom embedding function (embedding_function=embed_fn);
# here we pass precomputed embeddings instead.
collection = client.get_or_create_collection(name=collection_name)

print(f"we have {collection.count()} documents in the collection.")
# COMMAND ----------
# DBTITLE 1,Collection Building Function
# We can look at other splitters later - probably paragraph- and/or sentence-aware splitting
# (a sketch of one alternative follows after the build step below).
def collection_builder(source_docs: list,
                       collection: chromadb.api.models.Collection.Collection):

    assert collection.count() == 0, "WARNING this function appends to the collection regardless of whether it has already been populated"

    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)

    # we will process the documents page by page
    for doc in source_docs:

        # This regex only works for arxiv-style filenames, e.g. /1234.56789.pdf
        match = re.search(r'/([\d.]+)\.pdf$', doc)
        article_number = match.group(1)

        loader = PyMuPDFLoader(doc)
        pages = loader.load_and_split()

        texts = text_splitter.split_documents(pages)

        doc_list = [x.page_content for x in texts]
        embed_list = embeddings.embed_documents(doc_list)

        # enumerate gives each chunk a unique id even when two chunks have identical text
        collection.add(
            documents=doc_list,
            embeddings=embed_list,
            metadatas=[x.metadata for x in texts],
            ids=[f'{article_number}_{i}' for i, _ in enumerate(texts)]
        )

    # See: https://github.com/chroma-core/chroma/issues/275
    client.persist()
# COMMAND ----------
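# DBTITLE 1,Alternative splitter sketch
# A hedged sketch of the "other splitters" mentioned above; it is not used in the build.
# RecursiveCharacterTextSplitter tries paragraph breaks first, then sentences, then words,
# which usually keeps chunks more semantically coherent than a plain CharacterTextSplitter.
# The chunk_size / chunk_overlap / separators values here are illustrative assumptions only.
from langchain.text_splitter import RecursiveCharacterTextSplitter

alt_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    separators=["\n\n", "\n", ". ", " ", ""]
)
# To try it, swap alt_splitter in for text_splitter inside collection_builder and rebuild the collection.
# COMMAND ----------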
try:
    collection_builder(source_docs, collection)
    print(f"we now have {collection.count()} documents in the collection.")
except AssertionError:
    print("Doing nothing, we will not rebuild the collection")
# COMMAND ----------
# MAGIC %md
# MAGIC # Setup LLM to interface with ChromaDB
# MAGIC Note that reloading the vector store purely through LangChain can be glitchy, which is why we reuse the existing client and embeddings manually.
# COMMAND ----------
# Load the collection
# we reuse the previous client and embeddings
docsource = Chroma(collection_name=collection_name,
                   persist_directory=linux_vector_store_directory,
                   embedding_function=embeddings)

# we can verify that the docsource index has objects in it with this
print('The index includes: {} documents'.format(docsource._collection.count()))
# COMMAND ----------
# MAGIC %md
# MAGIC Note that the model loading doesn't clean up after itself, so if you re-initialise the model repeatedly it will fill up the GPU VRAM.
# MAGIC
# MAGIC We add some code below so that we stop re-initialising it unnecessarily.
# MAGIC
# MAGIC In order to understand the HuggingFace pipeline we need to look at:
# MAGIC - https://huggingface.co/docs/transformers/main_classes/pipelines
# MAGIC
# MAGIC The task set for this pipeline is text-generation, which is defined here:
# MAGIC - https://huggingface.co/docs/transformers/main_classes/pipelines#transformers.TextGenerationPipeline
# MAGIC
# MAGIC `device` needs to be set in order to utilise the GPU (a sketch of such a pipeline follows the next cell):
# MAGIC - See: https://huggingface.co/transformers/v3.0.2/main_classes/pipelines.html#transformers.Pipeline
# COMMAND ----------
## One problem with the library at the moment is that GPU RAM doesn't get relinquished when the object is overwritten
# The only way to clear GPU RAM is to detach and reattach the cluster
# This snippet makes sure we don't keep reloading the model and running out of GPU RAM
try:
    llm_model
except NameError:
    if run_mode == 'cpu':
        # the CTransformers class interfaces with langchain differently
        from ctransformers.langchain import CTransformers
        llm_model = CTransformers(model='TheBloke/Llama-2-7B-Chat-GGML', model_type='llama')
    elif run_mode == 'gpu':
        pipe = load_model(run_mode, dbfs_tmp_cache)
        llm_model = HuggingFacePipeline(pipeline=pipe)
    else:
        pass
# COMMAND ----------
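# DBTITLE 1,Sketch of a text-generation pipeline
# load_model lives in ./utils, so its internals aren't visible here. This is a minimal, hedged
# sketch of what a GPU text-generation pipeline typically looks like; the model id, flag name,
# and generation parameters are illustrative assumptions, not necessarily what utils uses.
# AutoModelForCausalLM, AutoTokenizer, pipeline and torch are already imported at the top of this notebook.
run_pipeline_sketch = False  # flip to True only if you want to load a second copy of a model into VRAM

if run_pipeline_sketch:
    sketch_model_id = 'meta-llama/Llama-2-7b-chat-hf'  # assumed model id - gated on HuggingFace, illustrative only
    sketch_tokenizer = AutoTokenizer.from_pretrained(sketch_model_id)
    sketch_model = AutoModelForCausalLM.from_pretrained(sketch_model_id, torch_dtype=torch.float16)
    pipe_sketch = pipeline('text-generation',
                           model=sketch_model,
                           tokenizer=sketch_tokenizer,
                           device=0,               # device must be set for the pipeline to run on the GPU
                           max_new_tokens=512)
    llm_sketch = HuggingFacePipeline(pipeline=pipe_sketch)
# COMMAND ----------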
# MAGIC %md
# MAGIC Previously we used `RetrievalQA`, which has no concept of memory.
# MAGIC We can add memory by using the `ConversationalRetrievalChain` chain instead.
# COMMAND ----------
# DBTITLE 1,Setting up prompt template
from langchain import PromptTemplate
system_template = """<s>[INST] <<SYS>>
As a helpful assistant, answer questions from users but be polite and concise. If you don't know, say I don't know.
<</SYS>>
Based on the following context:
{context}
Answer the following question:
{question}[/INST]
"""
# Prompt templates in langchain need the input variables specified so they can be substituted into the string.
# Note that the names of the input_variables are particular to the chain type.
friendly_template = PromptTemplate(
    input_variables=["question", "context"], template=system_template
)
# COMMAND ----------
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain import LLMChain
memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer')

# Broken at the moment
memory_chain = ConversationalRetrievalChain.from_llm(llm=llm_model,
                                                     retriever=docsource.as_retriever(search_kwargs={"k": 2}),
                                                     chain_type='stuff',
                                                     return_source_documents=True,
                                                     output_key='answer',
                                                     verbose=True,
                                                     combine_docs_chain_kwargs={"prompt": friendly_template},
                                                     memory=memory,
                                                     get_chat_history=lambda h: h)
# COMMAND ----------
# DBTITLE 1,Verify docsource is valid
# Basic vector similarity search
query = "What is a token limit?"
query_embed = embeddings.embed_query(query)
docsource._collection.query(query_embeddings=query_embed, n_results=2)
# COMMAND ----------
# Let's test out querying
# Something seems off with the similarity search - are the embeddings not saving,
# or does docsource use a different embedding structure (the vectors don't line up)?
query_embed = embeddings.embed_query(query)
query_embed
docs = docsource.similarity_search_by_vector(query_embed)
# COMMAND ----------
memory_chain({"question": query}, return_only_outputs=True)
# COMMAND ----------
query = 'tell me more!'
memory_chain({"question": query}, return_only_outputs=True)
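# COMMAND ----------
# DBTITLE 1,Inspect the conversation memory
# A quick check, not part of the original flow: ConversationBufferMemory keeps the running
# chat history under the "chat_history" key, so after the two queries above we can dump it
# to confirm the follow-up question was answered with the prior exchange attached.
memory.load_memory_variables({})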
# COMMAND ----------
# MAGIC %md
# MAGIC # Adding Human Feedback
# MAGIC **EXPERIMENTAL**
# MAGIC Now we would like to add human-in-the-loop reasoning and make the chain more intelligent that way.
# MAGIC
# MAGIC We can try agents for this.
# COMMAND ----------
# The conversation memory doesn't apply at the retrieval stage, so let's use the earlier RetrievalQA chain
from langchain import PromptTemplate
system_template = """<s>[INST] <<SYS>>
As a helpful assistant, answer questions from users but be polite and concise. If you don't know, say I don't know.
<</SYS>>
Based on the following context:
{context}
Answer the following question:
{question}[/INST]
"""
# Prompt templates in langchain need the input variables specified so they can be substituted into the string.
# Note that the names of the input_variables are particular to the chain type.
prompt_template = PromptTemplate(
    input_variables=["question", "context"], template=system_template
)

# we smoke-test this chain in the next cell before handing it to an agent
qa = RetrievalQA.from_chain_type(llm=llm_model, chain_type="stuff",
                                 retriever=docsource.as_retriever(search_kwargs={"k": 3}),
                                 chain_type_kwargs={"prompt": prompt_template})
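# COMMAND ----------
# DBTITLE 1,Smoke-test the RetrievalQA chain
# A hedged sanity check before wrapping the chain as an agent tool: RetrievalQA here has a
# single output key, so .run returns just the answer string. The question is an illustrative
# example, not one from the original notebook.
qa.run("What is the token limit of a large language model?")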
# COMMAND ----------
from langchain.agents import Tool, load_tools
# turn our qa chain into a tool for the agent
retrieval_tool = Tool(
    name='Document Search',
    func=qa,
    description='this is a chain that has access to a cache of arxiv papers on deep learning and large language models'
)

tools = load_tools(
    ["human"],
    llm=llm_model,
)
tools.append(retrieval_tool)
# COMMAND ----------
# Setup the agent
from langchain.agents import initialize_agent, AgentType

agent_chain = initialize_agent(
    tools,
    llm_model,
    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
    verbose=True,
)
# Mileage may vary!!
agent_chain.run("What should I ask you about llms?")