利用 langchian 搭建通用的表格数据分析工具

基于 RAG 搭建表格问答系统

场景描述及解决思路

与 RAG 类似,但是 RAG 无法解决少量提示词的问题,比如:

  1. 2019 年 7 月一共出现多少次异常?
  2. 张三一共解决多少个问题?
  3. 2019 年 10 月 1 日的异常是谁解决的?

诸如此类,涉及时间、人物的精确查询,RAG 过程很难精确检索到,鉴于此,本文基于 langchain 实现针对表格的通用数据分析应用,其架构如下:

Drawing-2024-11-14-13.55.10.excalidraw
首先使用 LLM 对提问进行分类,一共分两类:普通查询、精确查询,普通查询是 RAG 过程,精确查询首先通过 LLM 生成查询命令,如 pandas 命令,然后执行命令拿到结果

关键代码分析

分类 LLM

1
2
3
4
5
6
7
8
9
class_prompt=ChatPromptTemplate.from_template("""你是一名问题归类员,你的任务识别以下由三个引号包含的问题,然后将其分类为:“普通查询”、“精准查询”,当提问涉及准确的时间、地点、人物时,归类为精准查询,否则归类为普通查询。

直接输出“普通查询”、“精准查询”之一,不要输出其他任何信息

问题:```{question}```  

""")

class_chain=class_prompt|llm|StrOutputParser()

RAG 查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
rag_prompt=PromptTemplate.from_template(
"""
# 角色
你是一个严格的根据上下文生成回答的回答者,请根据以下由三个单引号包含的上下文回答问题,注意回答内容必须来自上下文,不能输出上下文不包含的内容,最后使用表格输出

# 技能
1. 分析用户问题
2. 根据上下文回答用户问题

## 技能1
1. 根据提问检索到关联的上下文
2. 将关联上下文直接以表格的形式输出,不作任何改变

## 技能2
1. 分析提问及关联的上下文,给出最终答案
2. 答案以列表输出,并使用加粗优化输出

# 要求
1. 回答内容必须来自上下文,不得输出上下文不存在的内容
2. 如果提问没有涉及的上下文,回答"知识库无法检索到该内容"

上下文:'''{context}'''
问题:{question}

""")

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

rag_chain = {"context": itemgetter("question")|vectordb.as_retriever(search_kwargs={'k': 4, 'lambda_mult': 0.25})|format_docs, "question": RunnablePassthrough()} \
    |rag_prompt \
    | llm \
    | StrOutputParser()

精确查询

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# 首先,构造一个提示模版字符串:`template_string`
system_string="你正在使用pandas处理DataFrame,请根据三个引号分隔的问题输出pandas命令,其中`print(df.head())`的结果如下:\n"

system_string+=df.head().to_markdown()+"\n"
system_string+= "如果不清楚使用哪一列过滤结果,请使用“现象描述、异常分析,解决方案”三列检索内容是否出现\n"
system_string+= "不要转换提问的文字,比如提问:相机标定不成功如何解决的?直接用'相机标定不成功'检索,而不是用'视觉'\n"
system_string+= "请使用1行代码完成需求\n"
system_string+= "只输出代码,不要输出其他任何信息,也不能写任何注释\n"
system_string+= "确保代码可运行\n"

sql_prompt = ChatPromptTemplate.from_messages([("system", system_string), ("human", "'''{question}'''")])

# 修改pandas命令
def _sanitize_output(pandas_cmd: str):
    if pandas_cmd[-1]==']':
        ss,match_end=find_matching_bracket(pandas_cmd)
        pandas_cmd=pandas_cmd[ss-2:match_end+1]
        pandas_cmd+='.to_markdown()'
    pandas_cmd="print("+pandas_cmd+")"
    return pandas_cmd

# 统计sql找到记录的数量
def postprocess(sql_result):
    row=sql_result.count('\n')-2 # 除去表头及分隔
    append_text="根据以上找到的信息,统计出现{}条记录。".format(row)
    return sql_result+'\n'+append_text

sql_chain = sql_prompt | llm | StrOutputParser() |_sanitize_output
| PythonREPL(_globals=globals(), _locals=None).run
| postprocess

合并所有过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 定义路由

def route(info):
    if "精准查询" in info["class"]:
        return sql_chain
    elif "普通查询" in info["class"]:
        return rag_chain
    else:
        return rag_chain

from langchain_core.runnables import RunnableLambda
full_chain={"class":class_chain,"question":RunnablePassthrough()}\
|RunnableLambda(route)

response = full_chain.invoke(message)