여기에서는 Self RAG와 Corrective RAG를 결합하여 RAG의 성능을 향상시키는 방법에 대해 설명합니다. Self-Corrective RAG는 Corrective RAG처럼 Vector Store로 부터 얻어진 문서의 관련성을 확인하여 관련성이 없는 문서를 제외하고 웹 검색을 통해 결과를 보강합니다. 또한, Self RAG처럼 RAG의 결과가 환각(Hallucination)인지, 적절한 답변인지 검증하는 절차를 가지고 있습니다. 아래는 Self-Corrective RAG에 대한 acitivity diagram입니다.
- "retrieve"는 질문(question)과 관련된 문서를 Vector Store를 통해 조회합니다. 이때, "grade_generation" 동작을 위해 "web_fallback"을 True로 초기화합니다.
- "generator"는 Vector Store에서 얻어진 관련된 문서(documents)를 이용하여 답변(generation)을 생성합니다. 이때, retries count를 증가시킵니다.
- "grade_generation"은 "web_fallback"이 True이라면, "hallucination"과 "answer_question"에서 환각 및 답변의 적절성을 확인합니다. 환각일 경우에, 반복 횟수(retries)가 "max_retries"에 도달하지 않았다면 "generate"보내서 답변을 다시 생성하고, "max_retires"에 도달했다면 "websearch"로 보내서 웹 검색을 수행합니다. 또한 답변이 적절하지 않다면, 반복 횟수가 "max_reties"에 도달하기 않았다면, "rewrite"로 보내서 향상된 질문(better question)을 생성하고, 도달하였다면 "websearch"로 보내서 웹 검색을 수행합니다.
- "websearch"는 웹 검색을 통해 문서를 보강하고, "generate"에 보내서 답변을 생성합니다. 이때, "web_fallback"을 False로 설정하여 "grade_generation"에서 "finalized_response"로 보내도록 합니다.
- "rewrite"는 새로운 질문(better question)을 생성하여, "retrieve"에 전달합니다. 새로운 질문으로 전체 RAG 동작을 재수행합니다. 전체 RAG 동작은 무한 루프를 방지하기 위하여, "max_retries"만큼 수행할 수 있습니다.
- "finalize_response"는 최종 답변을 전달합니다.
상세 코드는 lambda_function.py을 참조합니다. 동작 결과는 cself-corrective-rag.ipynb에서 확인할 수 있습니다.
Self Corrective RAG를 위한 class와 환경 설정을 위한 config를 아래와 같이 정의합니다.
class SelfCorrectiveRagState(TypedDict):
messages: Annotated[list[BaseMessage], add_messages]
question: str
documents: list[Document]
candidate_answer: str
retries: int
web_fallback: bool
class GraphConfig(TypedDict):
max_retries: int
환각(Hallucination)을 평가하기 위한 get_hallucination_grader()을 정의합니다.
def get_hallucination_grader():
class GradeHallucinations(BaseModel):
"""Binary score for hallucination present in generation answer."""
binary_score: str = Field(
description="Answer is grounded in the facts, 'yes' or 'no'"
)
system = """You are a grader assessing whether an LLM generation is grounded in / supported by a set of retrieved facts. \n
Give a binary score 'yes' or 'no'. 'Yes' means that the answer is grounded in / supported by the set of facts."""
hallucination_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "Set of facts: \n\n {documents} \n\n LLM generation: {generation}"),
]
)
chat = get_chat()
structured_llm_grade_hallucination = chat.with_structured_output(GradeHallucinations)
hallucination_grader = hallucination_prompt | structured_llm_grade_hallucination
return hallucination_grader
답변의 유용성을 평가하기 위한 get_answer()를 정의합니다.
def get_answer_grader():
class GradeAnswer(BaseModel):
"""Binary score to assess answer addresses question."""
binary_score: str = Field(
description="Answer addresses the question, 'yes' or 'no'"
)
chat = get_chat()
structured_llm_grade_answer = chat.with_structured_output(GradeAnswer)
system = """You are a grader assessing whether an answer addresses / resolves a question \n
Give a binary score 'yes' or 'no'. Yes' means that the answer resolves the question."""
answer_prompt = ChatPromptTemplate.from_messages(
[
("system", system),
("human", "User question: \n\n {question} \n\n LLM generation: {generation}"),
]
)
answer_grader = answer_prompt | structured_llm_grade_answer
return answer_grader
답변을 생성하기 위한 generate()를 정의합니다.
def generate_with_retires(state: CragState):
print("###### generate ######")
question = state["question"]
documents = state["documents"]
retries = state["retries"] if state.get("retries") is not None else -1
# RAG generation
rag_chain = get_reg_chain()
generation = rag_chain.invoke({"context": documents, "question": question})
print('generation: ', generation.content)
return {"documents": documents, "question": question, "generation": generation, "retries": retries + 1}
관련된 문서들에서 각 문서별로 관련도를 LLM으로 평가합니다.
def grade_documents_with_count(state: SelfRagState):
print("###### grade_documents ######")
question = state["question"]
documents = state["documents"]
count = state["count"] if state.get("count") is not None else -1
# Score each doc
filtered_docs = []
retrieval_grader = get_retrieval_grader()
for doc in documents:
score = retrieval_grader.invoke({"question": question, "document": doc.page_content})
grade = score.binary_score
# Document relevant
if grade.lower() == "yes":
print("---GRADE: DOCUMENT RELEVANT---")
filtered_docs.append(doc)
# Document not relevant
else:
print("---GRADE: DOCUMENT NOT RELEVANT---")
# We do not include the document in filtered_docs
# We set a flag to indicate that we want to run web search
continue
print('len(docments): ', len(filtered_docs))
return {"question": question, "documents": filtered_docs, "count": count + 1}
최종 답변에서 응답을 추출하기 위한 Node입니다.
def finalize_response(state: SelfCorrectiveRagState):
return {"messages": [AIMessage(content=state["candidate_answer"])]}
생성된 문서의 관련도 평가를 기준으로 적절한 동작을 수행할 수 있도록 conditinal edge를 정의합니다.
def decide_to_generate_with_retires(state: SelfRagState, config):
print("###### decide_to_generate ######")
filtered_documents = state["documents"]
count = state["count"] if state.get("count") is not None else -1
max_count = config.get("configurable", {}).get("max_counts", MAX_RETRIES)
print("count: ", count)
if not filtered_documents:
# All documents have been filtered check_relevance
# We will re-generate a new query
print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
return "no document" if count < max_count else "not available"
else:
# We have relevant documents, so generate answer
print("---DECISION: GENERATE---")
return "document"
답변이 환각인지, 유용한 답변인지 확인해서 적절한 동작을 수행하기 위한 conditional edge를 정의합니다.
def grade_generation(state: SelfRagState, config):
print("###### grade_generation ######")
question = state["question"]
documents = state["documents"]
generation = state["generation"]
retries = state["retries"] if state.get("retries") is not None else -1
max_retries = config.get("configurable", {}).get("max_retries", MAX_RETRIES)
hallucination_grader = get_hallucination_grader()
score = hallucination_grader.invoke(
{"documents": documents, "generation": generation}
)
hallucination_grade = score.binary_score
print("hallucination_grade: ", hallucination_grade)
print("retries: ", retries)
# Check hallucination
answer_grader = get_answer_grader()
if hallucination_grade == "yes":
print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
# Check question-answering
print("---GRADE GENERATION vs QUESTION---")
score = answer_grader.invoke({"question": question, "generation": generation})
answer_grade = score.binary_score
print("answer_grade: ", answer_grade)
if answer_grade == "yes":
print("---DECISION: GENERATION ADDRESSES QUESTION---")
return "useful"
else:
print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
return "not useful" if retries < max_retries else "not available"
else:
print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
return "not supported" if retries < max_retries else "not available"
이제 Workflow를 정의하기 위한 Graph를 선언합니다.
def buildSelCorrectivefRAG():
workflow = StateGraph(SelfCorrectiveRagState)
# Define the nodes
workflow.add_node("retrieve", retrieve_for_scrag)
workflow.add_node("generate", generate_for_scrag)
workflow.add_node("rewrite", rewrite)
workflow.add_node("websearch", web_search)
workflow.add_node("finalize_response", finalize_response)
# Build graph
workflow.set_entry_point("retrieve")
workflow.add_edge("retrieve", "generate")
workflow.add_edge("rewrite", "retrieve")
workflow.add_edge("websearch", "generate")
workflow.add_edge("finalize_response", END)
workflow.add_conditional_edges(
"generate",
grade_generation,
{
"generate": "generate",
"websearch": "websearch",
"rewrite": "rewrite",
"finalize_response": "finalize_response",
},
)
# Compile
return workflow.compile()
scrag_app = buildSelfRAG()
이때 생성되는 Graph는 아래와 같습니다.
아래와 같이 Self Corrective RAG를 실행합니다.
def run_self_corrective_rag(connectionId, requestId, app, query):
global langMode
langMode = isKorean(query)
isTyping(connectionId, requestId)
inputs = {"question": query}
config = {"recursion_limit": 50}
for output in app.stream(inputs, config):
for key, value in output.items():
print(f"Finished running: {key}:")
print("value: ", value)
print('value: ', value)
readStreamMsg(connectionId, requestId, value["generation"].content)
return value["generation"].content
Self-Corrective RAG in LangGraph을 참조합니다.
아래와 같이 Hallucination인지 관련된 문서인지를 LLM을 통해 판별합니다. 설정된 루프보다 더 많은 task를 수행하면, 인터넷 검색을 통해 결과를 얻을 수 있습니다.