基于 RAG 搭建表格问答系统
场景描述及解决思路
与 RAG 类似,但是 RAG 无法解决少量提示词的问题,比如:
- 2019 年 7 月一共出现多少次异常?
- 张三一共解决多少个问题?
- 2019 年 10 月 1 日的异常是谁解决的?
诸如此类,涉及时间、人物的精确查询,RAG 过程很难精确检索到,鉴于此,本文基于 langchain 实现针对表格的通用数据分析应用,其架构如下:
首先使用 LLM 对提问进行分类,一共分两类:普通查询、精确查询,普通查询是 RAG 过程,精确查询首先通过 LLM 生成查询命令,如 pandas 命令,然后执行命令拿到结果
关键代码分析
分类 LLM
1 2 3 4 5 6 7 8 9
| class_prompt=ChatPromptTemplate.from_template("""你是一名问题归类员,你的任务识别以下由三个引号包含的问题,然后将其分类为:“普通查询”、“精准查询”,当提问涉及准确的时间、地点、人物时,归类为精准查询,否则归类为普通查询。
直接输出“普通查询”、“精准查询”之一,不要输出其他任何信息
问题:```{question}```
""")
class_chain=class_prompt|llm|StrOutputParser()
|
RAG 查询
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
| rag_prompt=PromptTemplate.from_template( """ # 角色 你是一个严格的根据上下文生成回答的回答者,请根据以下由三个单引号包含的上下文回答问题,注意回答内容必须来自上下文,不能输出上下文不包含的内容,最后使用表格输出
# 技能 1. 分析用户问题 2. 根据上下文回答用户问题
## 技能1 1. 根据提问检索到关联的上下文 2. 将关联上下文直接以表格的形式输出,不作任何改变
## 技能2 1. 分析提问及关联的上下文,给出最终答案 2. 答案以列表输出,并使用加粗优化输出
# 要求 1. 回答内容必须来自上下文,不得输出上下文不存在的内容 2. 如果提问没有涉及的上下文,回答"知识库无法检索到该内容"
上下文:'''{context}''' 问题:{question}
""")
def format_docs(docs): return "\n\n".join(doc.page_content for doc in docs)
rag_chain = {"context": itemgetter("question")|vectordb.as_retriever(search_kwargs={'k': 4, 'lambda_mult': 0.25})|format_docs, "question": RunnablePassthrough()} \ |rag_prompt \ | llm \ | StrOutputParser()
|
精确查询
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
| system_string="你正在使用pandas处理DataFrame,请根据三个引号分隔的问题输出pandas命令,其中`print(df.head())`的结果如下:\n"
system_string+=df.head().to_markdown()+"\n" system_string+= "如果不清楚使用哪一列过滤结果,请使用“现象描述、异常分析,解决方案”三列检索内容是否出现\n" system_string+= "不要转换提问的文字,比如提问:相机标定不成功如何解决的?直接用'相机标定不成功'检索,而不是用'视觉'\n" system_string+= "请使用1行代码完成需求\n" system_string+= "只输出代码,不要输出其他任何信息,也不能写任何注释\n" system_string+= "确保代码可运行\n"
sql_prompt = ChatPromptTemplate.from_messages([("system", system_string), ("human", "'''{question}'''")])
def _sanitize_output(pandas_cmd: str): if pandas_cmd[-1]==']': ss,match_end=find_matching_bracket(pandas_cmd) pandas_cmd=pandas_cmd[ss-2:match_end+1] pandas_cmd+='.to_markdown()' pandas_cmd="print("+pandas_cmd+")" return pandas_cmd
def postprocess(sql_result): row=sql_result.count('\n')-2 append_text="根据以上找到的信息,统计出现{}条记录。".format(row) return sql_result+'\n'+append_text
sql_chain = sql_prompt | llm | StrOutputParser() |_sanitize_output | PythonREPL(_globals=globals(), _locals=None).run | postprocess
|
合并所有过程
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
|
def route(info): if "精准查询" in info["class"]: return sql_chain elif "普通查询" in info["class"]: return rag_chain else: return rag_chain
from langchain_core.runnables import RunnableLambda full_chain={"class":class_chain,"question":RunnablePassthrough()}\ |RunnableLambda(route)
response = full_chain.invoke(message)
|