Matt Zheng

RAG教學

RAG是什麼

檢索增強生成(Retrieval-Augmented Generation, RAG)是一種結合了搜尋檢索和生成能力的自然語言處理架構。透過這個架構,模型可以從外部知識庫搜尋相關信息,然後使用這些信息來生成回應或完成特定的NLP任務。

更通俗一點的說,RAG就像考試時教授允許大家帶的A4大抄,你可以在考試的時候邊看邊回答問題。

如何實作

RAG的運作流程如下圖

RAG flow

當中的 「相似度匹配」,近乎是RAG技術的核心

我們在大抄裡面優先問題的答案,最快的方法是:“找到與問題最相關的詞語”。而想要將自然語言進行“相關度匹配”的時候,就要使用到Sentence Transformer這個工具了。

Sentence Transformer

目前我們最常處理自然語言的方法是:利用專用的模型,將自然語言嵌入成向量,也就是Vector Embedding。

這些特別的模型(Sentence Transformer),是專門針對相近語意的資訊進行訓練。最後模型就可以做到:評比兩句話之間語意有多相似,最後再給一個相似程度的分數。

RAG

我們要的就是將使用者的問題,利用Sentence Transformer與大抄中內容進行比對。

如果找到了分數夠高的內容,那說明我們在大抄中找到了答案。我們就將這段大抄送給LLM一併進行生成,由此就可以達成擴增外部知識庫的功能了。
RAG flow

取得額外資料後,最後生成的流程可以依照自己的需求重新決定(如 Prompt Parameter…)。

Code

以下用一段簡短的程式碼進行示例

1
2
3
4
5
6
7
8
9
10
# requirements.txt
accelerate
chromadb
jq
langchain
langchain_community
langchain-chroma
sentence-transformers
torch
transformers
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# main.py
# Author: Matt Zheng

# Kernel generate tool
from langchain_community.llms import Ollama

# For streaming output
from langchain.callbacks.streaming_stdout import (
StreamingStdOutCallbackHandler
)
from langchain.callbacks.manager import CallbackManager
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# langchain chainer
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

# Embedding
from langchain_community.embeddings import HuggingFaceEmbeddings

# RAG
from langchain_chroma import Chroma
from langchain.chains import RetrievalQA
from langchain_community.document_loaders import CSVLoader
from langchain_community.document_loaders import JSONLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Embedding
from langchain_community.embeddings import HuggingFaceEmbeddings

# utils
import shutil, os
import json

if os.path.isdir('./db'): shutil.rmtree('./db')

# load json data
loader = JSONLoader(file_path="data/RAG_data.json", jq_schema=".", text_content=False)
data = loader.load()

# load csv data
# loader = CSVLoader(file_path="data/test.csv", encoding='utf8')
# data = loader.load()

# split data for batch embedding
text_splitter = RecursiveCharacterTextSplitter(
separators=["}"],
chunk_size=100,
chunk_overlap=0
)
data = text_splitter.split_documents(data)

# load embedding model
model_name = "intfloat/multilingual-e5-small"
model_kwargs = {'device': 'cpu'}
embedding = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs
)

# store embedding vectors into database
persist_directory = 'db'
vectordb = Chroma.from_documents(
documents=data,
embedding=embedding,
persist_directory=persist_directory,
collection_metadata={"hnsw:space": "cosine"}
)

# Load model
llm = Ollama(
model=your_model_name,
keep_alive=30,
temperature=0,
top_k=40,
top_p=0.95,
verbose=True,
callbacks=callback_manager
)

# RAG parameter
top_k = 10

# load RAG database
vectordb = Chroma(persist_directory='db', embedding_function=embedding)
retriever = vectordb.as_retriever(search_kwargs={"k": top_k})

# Creating LangChain
print('Creating LangChain')

rag_template = """請根據 `{input}` 幫我在檔案中找出所有相關的內容,並將結果統整給我,輸出必須為中文
"""

# Create RAG search engine
qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type="stuff",
retriever=retriever,
verbose=True
)

while 1:
user_input = input("User:")

# prepare rag search template
rag_input = rag_template.format_map({'input': user_input})

# print similarity search result (debug field)
print()
print('========== debug field ==========')
docs = vectordb.similarity_search_with_score(rag_input, k=top_k)
for doc in docs:
print(doc)
print(len(docs))

# get RAG result
qa.invoke(rag_input)