<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <title>年轻人起来冲</title>
  
  
  <link href="https://www.shaogui.life/atom.xml" rel="self"/>
  
  <link href="https://www.shaogui.life/"/>
  <updated>2025-03-01T08:01:48.305Z</updated>
  <id>https://www.shaogui.life/</id>
  
  <author>
    <name>Shaogui</name>
    
  </author>
  
  <generator uri="https://hexo.io/">Hexo</generator>
  
  <entry>
    <title>Building a Knowledge Card Creation System with Ollama</title>
    <link href="https://www.shaogui.life/posts/3905846024.html"/>
    <id>https://www.shaogui.life/posts/3905846024.html</id>
    <published>2025-03-01T08:31:02.000Z</published>
    <updated>2025-03-01T08:01:48.305Z</updated>
    
    <content type="html"><![CDATA[<p>This post builds a knowledge Q&A tool on top of the Ollama inference framework. It uses a large language model to learn from notes stored in Obsidian and, together with Obsidian's flashcard plugins, covers the two key steps of the traditional review software Anki: card creation and review. The user only needs to review; the tedious card-making work is automated. Because the questions are produced by a large model, they come from more varied angles, discouraging rote memorization in favor of understanding. That said, generated questions sometimes fail to match the source documents; this post discusses those issues as well.</p><span id="more"></span><p>Much of my knowledge lives in Obsidian. As the notes accumulated over time, each review session meant rereading everything, which was inconvenient. My original idea was to build a review system with Anki-like software, but two problems stood in the way: (1) making cards is very time-consuming; (2) the knowledge and the reviews live in two separate apps, Obsidian and Anki, which makes updating and maintenance awkward. So I decided to use Ollama to call a large model that generates the cards automatically, solving problem 1, and to use Obsidian's spaced-repetition plugin, solving problem 2.</p><h3 id="认识-obsidian-的-2-个插件"><a class="markdownIt-Anchor" href="#认识-obsidian-的-2-个插件"></a> Getting to know two Obsidian plugins</h3><p>The system is a further development of two Obsidian plugins, <code>Quiz Generator</code> and <code>Spaced Repetition</code>. The first generates <strong>Q/A pairs</strong> from documents; the second is a <strong>spaced-repetition plugin</strong>.</p><h4 id="quiz-generator"><a class="markdownIt-Anchor" href="#quiz-generator"></a> Quiz Generator</h4><p>After installing <code>Quiz Generator</code>, configure the path to the model. Since this runs on a personal computer, I use the Ollama framework for model inference.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301104809.png" alt="利用ollama搭建知识卡片制作系统-20250301104809"></p><p>Then select the document or folder to generate questions from, which produces Q/A pairs like the following:</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301110420.png" alt="利用ollama搭建知识卡片制作系统-20250301110420"></p><p>The tool can generate different question types, including single choice, multiple choice, fill-in-the-blank, and short answer. For open-ended types such as fill-in-the-blank and short answer, whose answers are not unique, it can use the large model to grade your response.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301104833.png" alt="利用ollama搭建知识卡片制作系统-20250301104833"></p><p>The plugin can also save the generated questions, in either <code>callout</code> or <code>Spaced Repetition</code> format. I choose the latter, so that the <code>Spaced Repetition</code> plugin described below can be used to review previously generated questions.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301104933.png" alt="利用ollama搭建知识卡片制作系统-20250301104933"></p><h4 id="spaced-repetition"><a class="markdownIt-Anchor" href="#spaced-repetition"></a> Spaced Repetition</h4><p><code>Quiz Generator</code> has already produced card files that <code>Spaced Repetition</code> can parse; opening <code>Spaced Repetition</code> directly shows the following interface:</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301110544.png" alt="利用ollama搭建知识卡片制作系统-20250301110544"></p><h3 id="修改-quiz-generator"><a class="markdownIt-Anchor" href="#修改-quiz-generator"></a> Modifying Quiz Generator</h3><p>The two plugins above already provide everything this review system needs, but in practice <code>Quiz Generator</code> has the following problems:</p><ol><li>When showing an answer, it does not show which material the question was generated from, so the user cannot verify whether their choice is correct (the model's own answer is sometimes inaccurate as well)</li><li>Card generation gives no feedback at all; there is no way to know the progress, you can only wait</li></ol><p>To address these two problems, I made the following changes to <code>Quiz Generator</code>:</p><ol><li>The plugin cannot cite a question's source because it generates questions from the whole article at once. The fix: after loading a document, split it into chunks first, generate questions per chunk, and display the source alongside the answer in both plugins during review</li><li>Report question-generation progress in real time based on the number of chunks, which makes the tool friendlier to use</li></ol><p>In the modified source, documents are chunked by heading, i.e. questions are generated from the content under each heading. The image below shows a document and the questions generated from it:</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301150347.png" alt="利用ollama搭建知识卡片制作系统-20250301150347"></p><p>With the source modified, both <code>Quiz Generator</code> and <code>Spaced Repetition</code> display the question's source, similar to this:</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8ollama%E6%90%AD%E5%BB%BA%E7%9F%A5%E8%AF%86%E5%8D%A1%E7%89%87%E5%88%B6%E4%BD%9C%E7%B3%BB%E7%BB%9F-20250301150436.png" alt="利用ollama搭建知识卡片制作系统-20250301150436"></p><h3 id="安装方式"><a class="markdownIt-Anchor" href="#安装方式"></a> Installation</h3><p>This system is a modification of the <code>Quiz Generator</code> plugin and has not been published to the Obsidian plugin marketplace, so it must be installed from source:</p><ol><li>Download the source from <a href="https://github.com/WuShaogui/obsidian-quiz-generator.git">GitHub - WuShaogui/obsidian-quiz-generator: optimized for my own documents</a></li><li>Install the <code>Quiz Generator</code> plugin in Obsidian</li><li>In the source tree, find the following 4 files and copy them into the vault's <code>.obsidian\plugins\quiz-generator</code> directory</li><li>Restart Obsidian</li></ol><p>The four files are: <code>main.js, data.json, styles.css, manifest.json</code></p><p>To develop further on this basis, first run <code>npm install</code> in the project directory to install dependencies, then use <code>npm run dev</code> to debug the code and <code>npm run build</code> to build it.</p><p>Summary:</p><ol><li>Modified <code>Quiz Generator</code> to generate Q/A pairs from the content under each heading, yielding more accurate questions</li><li>Added source citations, removing the need to dig through notes during review</li><li>Real-time progress output while the large model generates Q/A pairs, which is friendlier</li></ol>]]></content>
    
    
    <summary type="html">&lt;p&gt;本文基于 ollama 推理框架，搭建了一个知识问答工具，可以利用大模型学习 obsidian 上的知识，并结合其中的闪卡插件，完成传统复习软件 Anki 的两个关键过程，即 &quot;制卡与复习&quot;，使用者只需要复习知识，不必进行复杂的制作卡片工作，基于大模型产出的问题，问题角度更发散，避免个人 “死背书”，还是倾向理解问题，但是同样也存在以下问题和文档内容不符合的问题，本文将一并说明&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    
  </entry>
  
  <entry>
    <title>Exploring How GraphRAG Works with LlamaIndex</title>
    <link href="https://www.shaogui.life/posts/2804851745.html"/>
    <id>https://www.shaogui.life/posts/2804851745.html</id>
    <published>2025-02-19T07:08:30.000Z</published>
    <updated>2025-02-19T07:26:55.326Z</updated>
    
    <content type="html"><![CDATA[<p>This post uses LlamaIndex's <code>KnowledgeGraphIndex</code> to build a knowledge graph from unstructured documents, and then uses that knowledge graph to answer questions</p><span id="more"></span><h2 id="环境设置"><a class="markdownIt-Anchor" href="#环境设置"></a> Environment setup</h2><p>Configure the inference model and the embedding model</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> llama_index.core <span class="keyword">import</span> Settings</span><br><span class="line"><span class="keyword">from</span> llama_index.llms.ollama <span class="keyword">import</span> Ollama</span><br><span class="line"><span class="keyword">from</span> llama_index.embeddings.ollama <span class="keyword">import</span> OllamaEmbedding</span><br><span class="line">base_url=<span class="string">'http://192.168.3.165:11434'</span></span><br><span class="line">Settings.llm = Ollama(model=<span class="string">"qwen2.5:latest"</span>, request_timeout=<span class="number">360.0</span>,base_url=base_url)</span><br><span class="line">Settings.embed_model = OllamaEmbedding(model_name=<span class="string">"quentinz/bge-large-zh-v1.5:latest"</span>,base_url=base_url)</span><br><span class="line">Settings.chunk_size = <span class="number">512</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Helper to visualize the knowledge graph</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">show_knowledge_graph</span>(<span 
class="params">index,name=<span class="string">'example.html'</span></span>):</span><br><span class="line">    <span class="keyword">from</span> pyvis.network <span class="keyword">import</span> Network</span><br><span class="line"></span><br><span class="line">    g = index.get_networkx_graph()</span><br><span class="line">    net = Network(notebook=<span class="literal">True</span>, cdn_resources=<span class="string">"in_line"</span>, directed=<span class="literal">True</span>)</span><br><span class="line">    net.from_nx(g)</span><br><span class="line">    net.show(name)</span><br></pre></td></tr></tbody></table></figure><h2 id="构建知识图谱"><a class="markdownIt-Anchor" href="#构建知识图谱"></a> Building the knowledge graph</h2><p>First, prepare the unstructured documents (they could also be read from disk with <code>SimpleDirectoryReader</code>; here they are defined inline)</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 1. Define the documents manually</span></span><br><span class="line"><span class="keyword">from</span> llama_index.core.schema <span class="keyword">import</span> TextNode,Document</span><br><span class="line">documents=[</span><br><span class="line">    Document(id_=<span class="string">'123'</span>,text=<span class="string">'阴雨绵绵的清晨，李华见徐天行老师在校门口焦急张望，立刻推出自己的新自行车：“老师，我送您去教室！”放学后，伍娟摸着儿子湿透的校服眼眶发红，李山拍拍儿子肩膀：“懂得承担责任，是男子汉了。'</span>),</span><br><span class="line">    Document(id_=<span class="string">'456'</span>,text=<span class="string">'徐天行举着李华的作业本：“横平竖直，一夜进步神速！”原来李华熬夜练习书法，只因徐老师曾说“字如心正”。'</span>),</span><br><span class="line">    Document(id_=<span class="string">'789'</span>,text=<span 
class="string">'张志抱着一摞昆虫图鉴冲进教室：“生物课展示靠你了！”李华苦笑——他最怕甲虫，但想起徐老师“直面恐惧”的鼓励，咬牙点头。两人熬夜制作模型，李山帮忙3D打印鞘翅，伍娟端来牛奶。次日展示时，徐老师按下录音键：“这份协作，值得全班倾听。” 。'</span>)</span><br><span class="line">]</span><br></pre></td></tr></tbody></table></figure><p>Next, use <code>KnowledgeGraphIndex</code> to build a knowledge graph over the global knowledge (this graph is the index)</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> llama_index.core <span class="keyword">import</span> SimpleDirectoryReader,StorageContext, KnowledgeGraphIndex</span><br><span class="line"><span class="keyword">from</span> llama_index.core.graph_stores <span class="keyword">import</span> SimpleGraphStore</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line">save_kg_path=<span class="string">r'db\example\graph.db'</span></span><br><span class="line"><span class="keyword">if</span> os.path.exists(save_kg_path):</span><br><span class="line">    graph_store = SimpleGraphStore().from_persist_dir(save_kg_path)</span><br><span class="line"></span><br><span class="line">    index = 
KnowledgeGraphIndex.from_documents(</span><br><span class="line">        documents,</span><br><span class="line">        max_triplets_per_chunk=<span class="number">50</span>,</span><br><span class="line">        graph_store=graph_store,</span><br><span class="line">        show_progress=<span class="literal">True</span></span><br><span class="line">    )</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">    graph_store = SimpleGraphStore()</span><br><span class="line">    storage_context = StorageContext.from_defaults(graph_store=graph_store)</span><br><span class="line">    </span><br><span class="line">    index = KnowledgeGraphIndex.from_documents(</span><br><span class="line">        documents,</span><br><span class="line">        max_triplets_per_chunk=<span class="number">50</span>,</span><br><span class="line">        storage_context=storage_context,</span><br><span class="line">        show_progress=<span class="literal">True</span></span><br><span class="line">    )</span><br><span class="line">    index.storage_context.persist(persist_dir=<span class="string">r'db\example'</span>,graph_store_fname=<span class="string">'graph.db'</span>)</span><br></pre></td></tr></tbody></table></figure><pre><code>Parsing nodes:   0%|          | 0/3 [00:00&lt;?, ?it/s]Processing nodes:   0%|          | 0/3 [00:00&lt;?, ?it/s]</code></pre><p>Finally, inspect the contents of the knowledge graph with <code>pyvis</code></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">show_knowledge_graph(index,name=<span class="string">'example.html'</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Ellamaindex%E6%8E%A2%E7%A9%B6graphrag%E5%8E%9F%E7%90%86-20250219151832.png" alt="alt text"></p><p>The visualization shows that the LLM-built knowledge graph has turned the Chinese names into English, and overall the graph quality does not seem high.</p><p>In fact, a knowledge graph is essentially a set of N triplets (entity 1, entity 2, relation), so triplets can also be added to the index object directly, building the knowledge graph "by hand"</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 2. Split the documents into chunks</span></span><br><span class="line"><span class="keyword">from</span> llama_index.core.node_parser <span class="keyword">import</span> SentenceSplitter</span><br><span class="line">node_parser=SentenceSplitter(chunk_size=<span class="number">512</span>,chunk_overlap=<span class="number">32</span>)</span><br><span class="line">nodes=node_parser.get_nodes_from_documents(documents)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 3. Build the index</span></span><br><span class="line"><span class="keyword">from</span> llama_index.core <span class="keyword">import</span> KnowledgeGraphIndex</span><br><span class="line">index = KnowledgeGraphIndex(</span><br><span class="line">    [],</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Relations for document 1</span></span><br><span class="line">node_tups=[(<span class="string">'李华'</span>,<span 
class="string">'母亲是'</span>,<span class="string">'伍娟'</span>),(<span class="string">'李华'</span>,<span class="string">'父亲是'</span>,<span class="string">'李山'</span>)]</span><br><span class="line"><span class="keyword">for</span> node_tup <span class="keyword">in</span> node_tups:</span><br><span class="line">    index.upsert_triplet_and_node(node_tup,nodes[<span class="number">0</span>],include_embeddings=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Relations for document 2</span></span><br><span class="line">node_tups=[(<span class="string">'李华'</span>,<span class="string">'同学是'</span>,<span class="string">'张志'</span>)]</span><br><span class="line"><span class="keyword">for</span> node_tup <span class="keyword">in</span> node_tups:</span><br><span class="line">    index.upsert_triplet_and_node(node_tup,nodes[<span class="number">1</span>],include_embeddings=<span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Relations for document 3</span></span><br><span class="line">node_tups=[(<span class="string">'张志'</span>,<span class="string">'老师是'</span>,<span class="string">'徐天行'</span>)]</span><br><span class="line"><span class="keyword">for</span> node_tup <span class="keyword">in</span> node_tups:</span><br><span class="line">    index.upsert_triplet_and_node(node_tup,nodes[<span class="number">2</span>],include_embeddings=<span class="literal">False</span>)</span><br><span class="line">    </span><br><span class="line">show_knowledge_graph(index,name=<span class="string">'custom_example.html'</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Ellamaindex%E6%8E%A2%E7%A9%B6graphrag%E5%8E%9F%E7%90%86-20250219151833.png" alt="alt text"></p><h2 id="使用知识图谱"><a class="markdownIt-Anchor" href="#使用知识图谱"></a> Using the knowledge graph</h2><figure class="highlight 
python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 直接使用索引，分析文本中的三元组</span></span><br><span class="line">index._extract_triplets(<span class="string">"李华和王五去徐天行老师家补课"</span>)</span><br></pre></td></tr></tbody></table></figure><pre><code>[('李华', '去', '徐天行老师家'), ('王五', '去', '徐天行老师家'), ('补课', '发生在', '徐天行老师家')]</code></pre><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 索引-&gt;检索器</span></span><br><span class="line"><span class="keyword">from</span> llama_index.core.indices.knowledge_graph.retrievers <span class="keyword">import</span> KGRetrieverMode</span><br><span class="line">retriever=index.as_retriever(retriever_mode=KGRetrieverMode.KEYWORD)</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> llama_index.core.schema <span class="keyword">import</span> QueryBundle</span><br><span class="line">query=QueryBundle(query_str=<span class="string">"李华和同学张志一起完成了什么任务？"</span>)</span><br><span class="line">retriever._retrieve(query)</span><br></pre></td></tr></tbody></table></figure><figure class="highlight markdown"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">[NodeWithScore(node=TextNode(id<span class="emphasis">_='d38689d8-a7c2-41f8-a8de-d79c84bc4034', embedding=None, metadata={}, excluded_</span>embed<span class="emphasis">_metadata_</span>keys=[], excluded<span class="emphasis">_llm_</span>metadata<span 
class="emphasis">_keys=[], relationships={<span class="language-xml"><span class="tag">&lt;<span class="name">NodeRelationship.SOURCE:</span> '<span class="attr">1</span>'&gt;</span></span>: RelatedNodeInfo(node_</span>id='97ee7e69-7ebf-4343-a587-8ee10e281cae', node<span class="emphasis">_type='4', metadata={}, hash='96f4e9c8f81325976a25213572d9255024347d55bce55e980bc6d1ec03bd33f7')}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text='徐天行举着李华的作业本：“横平竖直，一夜进步神速！”原来李华熬夜练习书法，只因徐老师曾说“字如心正”。', mimetype='text/plain', start_</span>char<span class="emphasis">_idx=0, end_</span>char<span class="emphasis">_idx=51, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0),</span><br><span class="line"> NodeWithScore(node=TextNode(id<span class="emphasis">_='3a354c0b-3051-41ae-80f0-089627928e8c', embedding=None, metadata={}, excluded_</span>embed<span class="emphasis">_metadata_</span>keys=[], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=[], relationships={<span class="language-xml"><span class="tag">&lt;<span class="name">NodeRelationship.SOURCE:</span> '<span class="attr">1</span>'&gt;</span></span>: RelatedNodeInfo(node_</span>id='a3aa4ae9-40e0-4303-b800-4a88d4f90b54', node<span class="emphasis">_type='4', metadata={}, hash='44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a')}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text='', mimetype='text/plain', start_</span>char<span class="emphasis">_idx=0, end_</span>char<span class="emphasis">_idx=0, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0),</span><br><span class="line"> NodeWithScore(node=TextNode(id<span class="emphasis">_='132f96dc-4d68-4313-ad64-d3a19f33581b', embedding=None, metadata={}, excluded_</span>embed<span 
class="emphasis">_metadata_</span>keys=[], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=[], relationships={<span class="language-xml"><span class="tag">&lt;<span class="name">NodeRelationship.SOURCE:</span> '<span class="attr">1</span>'&gt;</span></span>: RelatedNodeInfo(node_</span>id='822127f3-1511-4a31-921b-c527fd0f3a15', node<span class="emphasis">_type='4', metadata={}, hash='904fbb3a3e01cf71a67cd33a5dafe1966fade6f3c538516a59823b2e78041b62')}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text='阴雨绵绵的清晨，李华见徐天行老师在校门口焦急张望，立刻推出自己的新自行车：“老师，我送您去教室！”放学后，伍娟摸着儿子湿透的校服眼眶发红，李山拍拍儿子肩膀：“懂得承担责任，是男子汉了。', mimetype='text/plain', start_</span>char<span class="emphasis">_idx=0, end_</span>char<span class="emphasis">_idx=92, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0),</span><br><span class="line"> NodeWithScore(node=TextNode(id<span class="emphasis">_='64566ef0-e951-4512-86d3-b3c5dcf17f8a', embedding=None, metadata={'kg_</span>rel<span class="emphasis">_texts': ["['张志', '老师是', '徐天行']", "['李华', '母亲是', '伍娟']", "['李华', '父亲是', '李山']", "['李华', '同学是', '张志']", "['张志', '老师是', '徐天行']"], 'kg_</span>rel<span class="emphasis">_map': {'任务': [], '完成': [], '同学': [], '张志': [['张志', '老师是', '徐天行']], '李华': [['李华', '母亲是', '伍娟'], ['李华', '父亲是', '李山'], ['李华', '同学是', '张志'], ['张志', '老师是', '徐天行']]}}, excluded_</span>embed<span class="emphasis">_metadata_</span>keys=['kg<span class="emphasis">_rel_</span>map', 'kg<span class="emphasis">_rel_</span>texts'], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=['kg_</span>rel<span class="emphasis">_map', 'kg_</span>rel<span class="emphasis">_texts'], relationships={}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text="The following are knowledge sequence in max depth 2 in the form of directed graph like:\n`subject -[predicate]-&gt;, 
object, &lt;-[predicate_</span>next<span class="emphasis">_hop]-, object_</span>next<span class="emphasis">_hop ...`\n['张志', '老师是', '徐天行']\n['李华', '母亲是', '伍娟']\n['李华', '父亲是', '李山']\n['李华', '同学是', '张志']\n['张志', '老师是', '徐天行']", mimetype='text/plain', start_</span>char<span class="emphasis">_idx=None, end_</span>char<span class="emphasis">_idx=None, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0)]</span><br></pre></td></tr></tbody></table></figure><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Index -&gt; RAG engine</span></span><br><span class="line">query_engine = index.as_query_engine(</span><br><span class="line">    include_text=<span class="literal">True</span>, response_mode=<span class="string">"tree_summarize"</span>,similarity_top_k=<span class="number">3</span></span><br><span class="line">)</span><br><span class="line">response = query_engine.query(</span><br><span class="line">    <span class="string">"老师如何评价李华和张志的课程作业"</span>,</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> IPython.display <span class="keyword">import</span> Markdown</span><br><span class="line">display(Markdown(<span class="string">f"&lt;b&gt;<span class="subst">{response}</span>&lt;/b&gt;"</span>))</span><br></pre></td></tr></tbody></table></figure><blockquote><p><b>Xu Tianxing spoke highly of Li Hua's calligraphy progress, saying his characters were "横平竖直，一夜进步神速" (neat strokes, rapid overnight progress). As for Zhang Zhi, the text contains no specific evaluation of his coursework.</b></p></blockquote><p>In the last two examples, the knowledge graph failed to retrieve the right documents, or to answer the question accurately, because retrieval used only entities (keywords) against the graph. Below we add each document's overall semantic information on top of the knowledge graph, so that retrieval considers both entities and semantics, and see how retrieval and answering perform</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line">index = KnowledgeGraphIndex(</span><br><span class="line">    [],</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Relations for document 1</span></span><br><span class="line">node_tups=[(<span class="string">'李华'</span>,<span class="string">'母亲是'</span>,<span class="string">'伍娟'</span>),(<span class="string">'李华'</span>,<span class="string">'父亲是'</span>,<span class="string">'李山'</span>)]</span><br><span class="line"><span class="keyword">for</span> node_tup <span class="keyword">in</span> node_tups:</span><br><span class="line">    index.upsert_triplet_and_node(node_tup,nodes[<span class="number">0</span>],include_embeddings=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Relations for document 2</span></span><br><span class="line">node_tups=[(<span class="string">'李华'</span>,<span class="string">'同学是'</span>,<span 
class="string">'张志'</span>)]</span><br><span class="line"><span class="keyword">for</span> node_tup <span class="keyword">in</span> node_tups:</span><br><span class="line">    index.upsert_triplet_and_node(node_tup,nodes[<span class="number">1</span>],include_embeddings=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 针对第3个文档的关系</span></span><br><span class="line">node_tups=[(<span class="string">'张志'</span>,<span class="string">'老师是'</span>,<span class="string">'徐天行'</span>)]</span><br><span class="line"><span class="keyword">for</span> node_tup <span class="keyword">in</span> node_tups:</span><br><span class="line">    index.upsert_triplet_and_node(node_tup,nodes[<span class="number">2</span>],include_embeddings=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 索引-&gt;检索器</span></span><br><span class="line"><span class="keyword">from</span> llama_index.core.indices.knowledge_graph.retrievers <span class="keyword">import</span> KGRetrieverMode</span><br><span class="line"></span><br><span class="line">retriever=index.as_retriever(retriever_mode=KGRetrieverMode.HYBRID)</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> llama_index.core.schema <span class="keyword">import</span> QueryBundle</span><br><span class="line">query=QueryBundle(query_str=<span class="string">"老师如何评价李华和张志的课程作业"</span>)</span><br><span class="line">retriever._retrieve(query)</span><br></pre></td></tr></tbody></table></figure><figure class="highlight markdown"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">[NodeWithScore(node=TextNode(id<span class="emphasis">_='d38689d8-a7c2-41f8-a8de-d79c84bc4034', embedding=None, metadata={}, 
excluded_</span>embed<span class="emphasis">_metadata_</span>keys=[], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=[], relationships={<span class="language-xml"><span class="tag">&lt;<span class="name">NodeRelationship.SOURCE:</span> '<span class="attr">1</span>'&gt;</span></span>: RelatedNodeInfo(node_</span>id='97ee7e69-7ebf-4343-a587-8ee10e281cae', node<span class="emphasis">_type='4', metadata={}, hash='96f4e9c8f81325976a25213572d9255024347d55bce55e980bc6d1ec03bd33f7')}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text='徐天行举着李华的作业本：“横平竖直，一夜进步神速！”原来李华熬夜练习书法，只因徐老师曾说“字如心正”。', mimetype='text/plain', start_</span>char<span class="emphasis">_idx=0, end_</span>char<span class="emphasis">_idx=51, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0),</span><br><span class="line"> NodeWithScore(node=TextNode(id<span class="emphasis">_='3a354c0b-3051-41ae-80f0-089627928e8c', embedding=None, metadata={}, excluded_</span>embed<span class="emphasis">_metadata_</span>keys=[], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=[], relationships={<span class="language-xml"><span class="tag">&lt;<span class="name">NodeRelationship.SOURCE:</span> '<span class="attr">1</span>'&gt;</span></span>: RelatedNodeInfo(node_</span>id='a3aa4ae9-40e0-4303-b800-4a88d4f90b54', node<span class="emphasis">_type='4', metadata={}, hash='44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a')}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text='', mimetype='text/plain', start_</span>char<span class="emphasis">_idx=0, end_</span>char<span class="emphasis">_idx=0, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0),</span><br><span class="line"> NodeWithScore(node=TextNode(id<span 
class="emphasis">_='132f96dc-4d68-4313-ad64-d3a19f33581b', embedding=None, metadata={}, excluded_</span>embed<span class="emphasis">_metadata_</span>keys=[], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=[], relationships={<span class="language-xml"><span class="tag">&lt;<span class="name">NodeRelationship.SOURCE:</span> '<span class="attr">1</span>'&gt;</span></span>: RelatedNodeInfo(node_</span>id='822127f3-1511-4a31-921b-c527fd0f3a15', node<span class="emphasis">_type='4', metadata={}, hash='904fbb3a3e01cf71a67cd33a5dafe1966fade6f3c538516a59823b2e78041b62')}, metadata_</span>template='{key}: {value}', metadata<span class="emphasis">_separator='\n', text='阴雨绵绵的清晨，李华见徐天行老师在校门口焦急张望，立刻推出自己的新自行车：“老师，我送您去教室！”放学后，伍娟摸着儿子湿透的校服眼眶发红，李山拍拍儿子肩膀：“懂得承担责任，是男子汉了。', mimetype='text/plain', start_</span>char<span class="emphasis">_idx=0, end_</span>char<span class="emphasis">_idx=92, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0),</span><br><span class="line"> NodeWithScore(node=TextNode(id<span class="emphasis">_='f730db7b-3fe4-4133-b749-2872dee27412', embedding=None, metadata={'kg_</span>rel<span class="emphasis">_texts': ["('张志', '老师是', '徐天行')", "['张志', '老师是', '徐天行']", "('李华', '同学是', '张志')", "['李华', '母亲是', '伍娟']", "['李华', '同学是', '张志']", "['李华', '父亲是', '李山']"], 'kg_</span>rel<span class="emphasis">_map': {'课程作业': [], '老师': [], '李华': [['李华', '母亲是', '伍娟'], ['李华', '父亲是', '李山'], ['李华', '同学是', '张志'], ['张志', '老师是', '徐天行']], '张志': [['张志', '老师是', '徐天行']], '评价': []}}, excluded_</span>embed<span class="emphasis">_metadata_</span>keys=['kg<span class="emphasis">_rel_</span>map', 'kg<span class="emphasis">_rel_</span>texts'], excluded<span class="emphasis">_llm_</span>metadata<span class="emphasis">_keys=['kg_</span>rel<span class="emphasis">_map', 'kg_</span>rel<span class="emphasis">_texts'], relationships={}, metadata_</span>template='{key}: {value}', metadata<span 
class="emphasis">_separator='\n', text="The following are knowledge sequence in max depth 2 in the form of directed graph like:\n`subject -[predicate]-&gt;, object, &lt;-[predicate_</span>next<span class="emphasis">_hop]-, object_</span>next<span class="emphasis">_hop ...`\n('张志', '老师是', '徐天行')\n['张志', '老师是', '徐天行']\n('李华', '同学是', '张志')\n['李华', '母亲是', '伍娟']\n['李华', '同学是', '张志']\n['李华', '父亲是', '李山']", mimetype='text/plain', start_</span>char<span class="emphasis">_idx=None, end_</span>char<span class="emphasis">_idx=None, metadata_</span>seperator='\n', text<span class="emphasis">_template='{metadata_</span>str}\n\n{content}'), score=1000.0)]</span><br></pre></td></tr></tbody></table></figure><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 索引-&gt;RAG引擎</span></span><br><span class="line">query_engine = index.as_query_engine(</span><br><span class="line">    include_text=<span class="literal">True</span>, response_mode=<span class="string">"refine"</span></span><br><span class="line">)</span><br><span class="line">response = query_engine.query(</span><br><span class="line">    <span class="string">"老师如何评价李华和张志的生物课程作业"</span>,</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> IPython.display <span class="keyword">import</span> Markdown</span><br><span class="line">display(Markdown(<span class="string">f"&lt;b&gt;<span 
class="subst">{response}</span>&lt;/b&gt;"</span>))</span><br></pre></td></tr></tbody></table></figure><blockquote><p><b>根据新提供的信息，我们仍然没有关于老师对李华和张志生物课程作业的具体评价。文中仅描述了徐天行是张志的老师以及与李华相关的一些家庭成员关系，并未提及任何有关生物课程作业的内容。因此，无法回答关于他们的生物课程作业情况。</b></p></blockquote><p>可以发现，检索、回答效果并没有改善，看来 GraphRAG 并非在所有场景都有效，其效果依赖于知识图谱的质量；在本例中，知识图谱规模太小，作用受限</p><h2 id="构建知识图谱的原理"><a class="markdownIt-Anchor" href="#构建知识图谱的原理"></a> 构建知识图谱的原理</h2><p>构建知识图谱的关键是使用 llm 提取文本中的三元组，我们首先来看看相关的 prompt</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> llama_index.core <span class="keyword">import</span> KnowledgeGraphIndex</span><br><span class="line"></span><br><span class="line">index=KnowledgeGraphIndex([])</span><br><span class="line">index.kg_triplet_extract_template</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><blockquote><p>PromptTemplate(metadata={‘prompt_type’: &lt;PromptType.KNOWLEDGE_TRIPLET_EXTRACT: ‘knowledge_triplet_extract’&gt;}, template_vars=[‘max_knowledge_triplets’, ‘text’], kwargs={‘max_knowledge_triplets’: 10}, output_parser=None, template_var_mappings=None, function_mappings=None, template=“Some text is provided below. Given the text, extract up to {max_knowledge_triplets} knowledge triplets in the form of (subject, predicate, object).
Avoid stopwords.\n---------------------\nExample:Text: Alice is Bob’s mother.Triplets:\n(Alice, is mother of, Bob)\nText: Philz is a coffee shop founded in Berkeley in 1982.\nTriplets:\n(Philz, is, coffee shop)\n(Philz, founded in, Berkeley)\n(Philz, founded in, 1982)\n---------------------\nText: {text}\nTriplets:\n”)</p></blockquote><p>其中 prompt 内容如下:</p><figure class="highlight markdown"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">Some text is provided below. Given the text, extract up to {max<span class="emphasis">_knowledge_</span>triplets} knowledge triplets in the form of (subject, predicate, object). Avoid stopwords.\n</span><br><span class="line">---------------------\n</span><br><span class="line">Example:Text: Alice is Bob's mother.Triplets:\n(Alice, is mother of, Bob)\n</span><br><span class="line">Text: Philz is a coffee shop founded in Berkeley in 1982.\n</span><br><span class="line">Triplets:\n(Philz, is, coffee shop)\n</span><br><span class="line">(Philz, founded in, Berkeley)\n</span><br><span class="line">(Philz, founded in, 1982)\n</span><br><span class="line">---------------------\n</span><br><span class="line">Text: {text}\n</span><br><span class="line">Triplets:\n</span><br></pre></td></tr></tbody></table></figure><p>内容显示，要求 llm 提供数量为 max_knowledge_triplets 个的 (subject, predicate, object) s 三元组，并且给出提示例子</p><h2 id="应用小例子"><a class="markdownIt-Anchor" href="#应用小例子"></a> 应用小例子</h2><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span 
class="keyword">from</span> llama_index.core <span class="keyword">import</span> SimpleDirectoryReader</span><br><span class="line">documents = SimpleDirectoryReader(<span class="string">"./三国演义白话文/"</span>,recursive=<span class="literal">True</span>).load_data(show_progress=<span class="literal">True</span>)</span><br></pre></td></tr></tbody></table></figure><pre><code>Loading files: 100%|██████████| 8/8 [00:00&lt;00:00, 74.28file/s]</code></pre><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> llama_index.core <span class="keyword">import</span> StorageContext, KnowledgeGraphIndex</span><br><span class="line"><span class="keyword">from</span> llama_index.core.graph_stores <span class="keyword">import</span> SimpleGraphStore</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> os</span><br><span class="line">save_kg_path=<span class="string">r'db\sanguo\graph.db'</span></span><br><span class="line"><span class="keyword">if</span> os.path.exists(save_kg_path):</span><br><span class="line">    graph_store = 
SimpleGraphStore().from_persist_dir(save_kg_path)</span><br><span class="line"></span><br><span class="line">    index = KnowledgeGraphIndex.from_documents(</span><br><span class="line">        documents,</span><br><span class="line">        max_triplets_per_chunk=<span class="number">50</span>,</span><br><span class="line">        graph_store=graph_store,</span><br><span class="line">        show_progress=<span class="literal">True</span></span><br><span class="line">    )</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line">    graph_store = SimpleGraphStore()</span><br><span class="line">    storage_context = StorageContext.from_defaults(graph_store=graph_store)</span><br><span class="line">    </span><br><span class="line">    index = KnowledgeGraphIndex.from_documents(</span><br><span class="line">        documents,</span><br><span class="line">        max_triplets_per_chunk=<span class="number">50</span>,</span><br><span class="line">        storage_context=storage_context,</span><br><span class="line">        show_progress=<span class="literal">True</span></span><br><span class="line">    )</span><br><span class="line">    index.storage_context.persist(persist_dir=<span class="string">r'db\sanguo'</span>,graph_store_fname=<span class="string">'graph.db'</span>)</span><br></pre></td></tr></tbody></table></figure><pre><code>Parsing nodes:   0%|          | 0/8 [00:00&lt;?, ?it/s]Processing nodes:   0%|          | 0/80 [00:00&lt;?, ?it/s]</code></pre><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">query_engine = index.as_query_engine(</span><br><span class="line">    
include_text=<span class="literal">False</span>, response_mode=<span class="string">"tree_summarize"</span></span><br><span class="line">)</span><br><span class="line">response = query_engine.query(</span><br><span class="line">    <span class="string">"董卓如何将吕布收为义子的？"</span>,</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> IPython.display <span class="keyword">import</span> Markdown</span><br><span class="line">display(Markdown(<span class="string">f"&lt;b&gt;<span class="subst">{response}</span>&lt;/b&gt;"</span>))</span><br></pre></td></tr></tbody></table></figure><blockquote><p><b>董卓通过给予礼物的方式收买了吕布，将他收为了义子。根据信息，董卓用一匹名为 “赤兔马” 的千里马和数不尽的黄金珠宝作为礼物送给了吕布，成功地让他拜自己为义父。</b></p></blockquote><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">query_engine = index.as_query_engine(</span><br><span class="line">    include_text=<span class="literal">True</span>, response_mode=<span class="string">"tree_summarize"</span></span><br><span class="line">)</span><br><span class="line">response = query_engine.query(</span><br><span class="line">    <span class="string">"刘备如何为朝廷效力？"</span>,</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line">display(Markdown(<span class="string">f"&lt;b&gt;<span class="subst">{response}</span>&lt;/b&gt;"</span>))</span><br></pre></td></tr></tbody></table></figure><blockquote><p><b>根据提供的信息，刘备最初是通过军事才能为朝廷效力的。在故事中提到，关羽和张飞在战场上英勇无敌，而刘备则既有勇又有谋，三人接连取胜，战功赫赫。但是，随着时间推移，朝廷变得更加腐败，实行卖官鬻爵的行为。尽管如此，刘备并未放弃为国家贡献力量的机会，在担任安喜县县尉期间，他与百姓秋毫无犯，深得民心，展现了其忠于职守的一面。</b></p></blockquote><p><strong>总结</strong>：本文使用 llamaindex 学习了知识图谱的构建，并基于知识图谱创建 RAG 
应用，发现 GraphRAG 并非在所有场景都能提升回答质量，其效果依赖于高质量的知识图谱</p>]]></content>
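上文的 prompt 要求 llm 按 (subject, predicate, object) 的格式逐行输出三元组。下面用纯 Python 给出一个解析这类输出的最小示例（`parse_triplets` 是演示用的假设函数，llamaindex 内部另有自己的解析实现，这里仅示意原理）：

```python
import re

def parse_triplets(llm_output: str):
    """解析 LLM 按 (subject, predicate, object) 格式逐行返回的三元组文本。

    仅作原理演示：llamaindex 内部有自己的解析逻辑，此处为独立的最小实现。
    """
    triplets = []
    for line in llm_output.splitlines():
        # 匹配形如 "(主语, 谓语, 宾语)" 的一行，三个字段用逗号分隔
        m = re.match(r"^\((.+?),\s*(.+?),\s*(.+?)\)$", line.strip())
        if m:
            triplets.append((m.group(1), m.group(2), m.group(3)))
    return triplets

# 模拟 LLM 按上述 prompt 要求返回的文本
output = "(Alice, is mother of, Bob)\n(Philz, founded in, Berkeley)"
print(parse_triplets(output))
```

解析得到的三元组即可作为图的边写入 graph store，这也是 KnowledgeGraphIndex 构建索引时的核心步骤。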
    
    
    <summary type="html">&lt;p&gt;本文利用 llamaindex 的 &lt;code&gt;KnowledgeGraphIndex&lt;/code&gt; 构建非结构化文档的知识图谱，并使用该知识图谱回答问题&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>RAG 演进 05-AgenticRAG</title>
    <link href="https://www.shaogui.life/posts/3434692338.html"/>
    <id>https://www.shaogui.life/posts/3434692338.html</id>
    <published>2025-02-14T06:14:51.000Z</published>
    <updated>2025-02-14T08:04:49.300Z</updated>
    
    <content type="html"><![CDATA[<p>本文用于记录学习 AgenticRAG 的过程</p><span id="more"></span><p>RAG 的标准过程是 “索引 - 检索 - 生成”，他们的影响关系如下：</p><ul><li>索引的目的是为了更好的检索</li><li>检索的目的是为了提供更好的上下文</li><li>基于更好的上下文，产生更高质量的回答</li></ul><p>在 RAG 发展的不同阶段，对以上 3 个过程进行改进或扩展，到 Agentic RAG 为止，其技术演进思路如下：</p><table><thead><tr><th>范式</th><th>关键特征</th><th>优势</th><th>优化</th></tr></thead><tbody><tr><td> Naive RAG</td><td> 基于关键词的检索 (如 TFIDF、BM25)</td><td> 适合事实性查询</td><td> -</td></tr><tr><td>Advanced RAG</td><td> 密集检索模型 (如 DPR)<br>神经排序和重排序模型<br>多跳检索</td><td>提高上下文相关性</td><td>索引、检索</td></tr><tr><td> Modular RAG</td><td> 混合检索 (稀疏 + 密集)<br>工具或 API 集成<br>可组合流水线</td><td>高度灵活和定制性<br>可扩展</td><td>索引、检索、生成</td></tr><tr><td> Graph RAG</td><td> 图结构数据集成<br>多跳响应<br>通过节点丰富上下文</td><td>关系推理能力<br>减少幻觉<br>适合结构化数据任务</td><td>索引、检索、生成</td></tr><tr><td> Agentic RAG</td><td> 自主智能体<br>动态决策<br>迭代优化与工作流调整</td><td>适应实时变化<br>高可扩展性<br>适合多模态任务</td><td>生成</td></tr></tbody></table><h2 id="rag演进历史"><a class="markdownIt-Anchor" href="#rag演进历史"></a> RAG 演进历史</h2><h3 id="naive-rag"><a class="markdownIt-Anchor" href="#naive-rag"></a> Naive RAG</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155654.png" alt="RAG演进05-AgenticRAG-20250214155654"></p><p>Naive RAG 是 RAG 的基础实现，侧重于关键词检索技术，如 TFIDF,BM25 等传统检索技术，从静态数据集中获取上下文，然后生成输出</p><p><strong>缺点</strong></p><ul><li><strong>缺乏上下文感知</strong>&nbsp;：由于依赖关键词匹配，导致无法捕捉查询的语义细微差别</li><li><strong>输出碎片化</strong>&nbsp;：缺乏高级预处理或上下文集成，导致生成的响应可能不连贯或过于通用。</li><li><strong>可扩展性问题</strong>&nbsp;：基于关键词的检索技术，无法处理大规模数据集</li></ul><h3 id="advanced-rag"><a class="markdownIt-Anchor" href="#advanced-rag"></a> Advanced RAG</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155654-1.png" alt="RAG演进05-AgenticRAG-20250214155654-1"></p><p>针对 Naive RAG 的缺点，Advanced RAG 着重提升检索阶段的质量，包括以下改进：</p><ul><li><strong>密集向量搜索</strong>：查询和文档在高维空间表示，实现了用户查询和检索文档之间的更好 
语义对齐</li><li><strong>上下文重排</strong>：神经模型重新排序检索到的文档，优先考虑上下文最相关的信息</li><li><strong>迭代检索</strong>：Advanced RAG 引入了多跳检索机制，能够对复杂查询进行跨多个文档的推理</li></ul><h3 id="modular-rag"><a class="markdownIt-Anchor" href="#modular-rag"></a> Modular RAG</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155655.png" alt="RAG演进05-AgenticRAG-20250214155655"></p><p>Modular RAG 在灵活性和定制化方向，扩展 RAG 的应用，将 RAG 的关键过程（如：检索、生成、路由）整合成独立的组件，方便复用，它有以下优势：</p><ul><li><strong>混合检索策略</strong>&nbsp;：结合稀疏和密集检索方法，以最大化不同查询类型的准确性。</li><li><strong>工具集成</strong>&nbsp;：整合外部 API、数据库或计算工具来处理专业任务，例如实时数据分析或特定领域的计算。</li><li><strong>可组合管道</strong>&nbsp;：允许独立替换、增强或重新配置检索器、生成器和其他组件。</li></ul><h3 id="graph-rag"><a class="markdownIt-Anchor" href="#graph-rag"></a> Graph RAG</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155655-1.png" alt="RAG演进05-AgenticRAG-20250214155655-1"></p><p>通过对 “全局文档” 构建知识图谱，实现对多跳、复杂推理的进一步提升，本质上是重构数据的索引方式，由原来片段式的索引，变为全局实体及关系的索引</p><p>GraphRAG 的特点：</p><ul><li><strong>节点连接</strong>：捕获和响应不同推理实体之间的关系</li><li><strong>分层知识管理</strong>：通过图形化的结构处理结构化和非结构化数据</li><li><strong>上下文扩充</strong>：基于图形化的路径增加关系的理解</li></ul><p>GraphRAG 存在以下局限性：</p><ul><li>扩展限制：图结构化数据限制其扩展性，尤其是使用大量数据源的情况下，知识图谱更新将变得困难</li><li>数据依赖：需要构建高质量的知识图谱，在非结构化数据或者缺少注释的数据上难以做到</li><li>集成的复杂性：检索图形化数据、非结构化数据，其集成将使得系统变得复杂</li></ul><h3 id="agentic-rag"><a class="markdownIt-Anchor" href="#agentic-rag"></a> Agentic RAG</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155655-2.png" alt="RAG演进05-AgenticRAG-20250214155655-2"></p><p>Agentic RAG 通过引入动态决策和流程化的自主代理，实现 “静态系统” 到 “动态系统” 的转换，可以进行迭代优化和自适应检索策略处理复杂问题</p><p>Agentic RAG 的主要特征包括：</p><ul><li>自主决策：根据查询的复杂性，自主评估和管理检索策略</li><li>迭代优化：整合反馈循环，提高检索的准确性和响应相关性</li><li>工作流优化：动态编排任务，提高实时应用程序的效率</li></ul><p>Agentic RAG
存在以下挑战：</p><ul><li><strong>协调复杂性</strong>：管理代理之间的交互需要复杂的编排机制</li><li><strong>计算开销</strong>：使用多个代理会增加复杂工作所需的资源</li><li><strong>可扩展性限制</strong>：虽然可扩展，但系统的动态特性可能会使高查询量的计算资源紧张</li></ul><h2 id="agentic-rag-详解"><a class="markdownIt-Anchor" href="#agentic-rag-详解"></a> Agentic RAG 详解</h2><h3 id="agent的组成及能力"><a class="markdownIt-Anchor" href="#agent的组成及能力"></a> Agent 的组成及能力</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155656.png" alt="RAG演进05-AgenticRAG-20250214155656"></p><p><strong>Agent 的组成</strong></p><ul><li>LLM (定义角色和任务)：作为 Agent 的主要推理引擎，它解释用户查询、生成响应</li><li> Memory (短期和长期)：捕获上下文相关的数据，短期记忆追踪对话状态，长期记忆积累知识和代理经验</li><li> Planning (反思和自我批评)：通过反思和自我评判，指导代理迭代过程，确保复杂任务被分解</li><li> Tool (工具)：将代理功能扩展到文本生成之外，支持访问外部资源，使用外部工具</li></ul><p>支撑 Agent 动态地构建工作流的模式有四种：</p><p><strong>1. 反思 (Reflection)</strong></p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155656-1.png" alt="RAG演进05-AgenticRAG-20250214155656-1"></p><p>代理工作流中的基本模式，通过评估去迭代优化输出，评估指标包括输出的正确性、风格等</p><p>在多智能体系统中，反思可能涉及不同的角色，例如一个智能体生成输出，而另一个智能体批评输出，从而促进协作改进。</p><p>2. 规划 (Planning)</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155657.png" alt="RAG演进05-AgenticRAG-20250214155657"></p><p>规划是另一种代理设计模式，它使代理能够自主将复杂任务分解为更小、可管理的子任务，但是与 Reflection 相比，Planning 产生的结果更难预测，因为子任务相互依赖，后续子任务的质量依赖前序子任务的完成质量</p><p>3. 工具使用 (Tool use)</p><p>工具使用是代理的另一种工作流，它允许代理调用外部工具、API，提供更准确的上下文</p><p>4.
多代理 (Multi-Agent)</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155657-1.png" alt="RAG演进05-AgenticRAG-20250214155657-1"></p><p>多代理是代理工作流的一种关键模式，它支持任务专业化和并行处理。代理之间可以共享数据、相互指派任务，每个代理内部有自己的记忆、工具、反思和规划，从而实现动态和协作地解决问题</p><h3 id="agentic-rag-分类"><a class="markdownIt-Anchor" href="#agentic-rag-分类"></a> Agentic RAG 分类</h3><h4 id="single-agent-agentic-rag-router"><a class="markdownIt-Anchor" href="#single-agent-agentic-rag-router"></a> Single-Agent Agentic RAG: Router</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155657-2.png" alt="RAG演进05-AgenticRAG-20250214155657-2"></p><p>由单一的 Agent 管理检索、路由和整合的过程，简化了系统设计，但是工具或者数据源的数量有限，因为单一的 Agent 无法使用大量的工具</p><p>工作流程：</p><ol><li>查询提交与评估：代理分析查询，确定合适的数据源</li><li>数据源选择：<ol><li>结构化数据：对于需要表格访问的查询，需要将查询转为 SQL 处理</li><li>语义搜索：处理非结构化数据时，基于向量检索即可</li><li>网络搜索：对于实时或者广泛上下文，系统利用互联网工具访问最新在线数据</li><li>推荐系统：对于个性化或上下文查询，利用推荐引擎提供定制化建议</li></ol></li><li>数据集成及生成：将检索到的数据集成到 prompt 中，然后交给 llm 生成</li></ol><h4 id="multi-agent-agentic-rag-systems"><a class="markdownIt-Anchor" href="#multi-agent-agentic-rag-systems"></a> Multi-Agent Agentic RAG Systems</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155658.png" alt="RAG演进05-AgenticRAG-20250214155658"></p><p>由于一个代理绑定的工具有限，那么在 Single-Agent Agentic RAG 的基础上增加多个代理，用于负责不同工具的调用，降低每个 agent 的 prompt 复杂程度</p><p>工作流程：</p><ol><li>查询提交与评估：代理分析查询，确定合适的代理</li><li>代理选择：<ol><li>代理 X：负责向量化检索</li><li>代理 Y：负责网络搜索</li><li>代理 Z：负责检索邮件或者聊天</li></ol></li><li>将检索到的数据整合进 prompt，交给 llm 生成</li></ol><h4 id="hierarchical-agentic-rag-systems"><a class="markdownIt-Anchor" href="#hierarchical-agentic-rag-systems"></a> Hierarchical Agentic RAG Systems</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155658-1.png"
alt="RAG演进05-AgenticRAG-20250214155658-1"></p><p>相比较 Multi-Agent Agentic RAG，Hierarchical Agentic RAG 采取多层次的方法构建 Agent，高级代理监督指导低级代理</p><p>工作流程：</p><ol><li>接受查询，委派代理：用户输入查询，通过主代理分析，并将任务委派到子代理，可以委派多个代理</li><li>子代理执行：子代理执行，并将结果返回给主代理，可以有多个子代理同时执行</li><li>生成：主代理接收汇总所有子代理的结果，然后生成回答</li></ol><h4 id="agentic-corrective-rag"><a class="markdownIt-Anchor" href="#agentic-corrective-rag"></a> Agentic Corrective RAG</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155659.png" alt="RAG演进05-AgenticRAG-20250214155659"></p><p>通过引入自我纠正机制，迭代细化上下文文档和回答，最小化错误并最大化相关性。这个方法是通过以下 5 个代理完成：</p><ol><li>上下文检索代理：负责从向量库中检索上下文</li><li>相关性评估代理：评估检索到的文档的相关性，过滤或纠正被标记为不相关的文档</li><li>查询细化代理：利用语义理解来优化检索，以获得更好上下文</li><li>外部知识代理：当上下文不足时，通过网络搜索补充知识</li><li>回答合成代理：将经过验证的信息整合，生成连贯且准确的回答</li></ol><h4 id="adaptive-agentic-rag"><a class="markdownIt-Anchor" href="#adaptive-agentic-rag"></a> Adaptive Agentic RAG</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214155659-1.png" alt="RAG演进05-AgenticRAG-20250214155659-1"></p><p>根据查询的复杂性调整处理策略，然后选择合适的方法回答问题，范围从单步推理到多步推理，甚至对简单问题直接回答，而不需检索</p><p>Adaptive Agentic RAG 的核心在于根据查询复杂性动态调整检索策略，比如以下几类查询：</p><ul><li><strong>直接查询</strong>：对于不需要额外检索的基于事实的问题（例如，“水的沸点是多少？”），系统使用预先存在的知识直接生成答案。</li><li><strong>简单查询</strong>：对于需要最少上下文的中等复杂任务（例如，“我的最新电费单的状态如何？”），系统会执行单步检索以获取相关详细信息。</li><li><strong>复杂查询</strong>：对于需要迭代推理的多层查询（例如，“过去十年 X 市的人口变化如何，影响因素是什么？”），系统采用多步骤检索，逐步提炼中间结果以提供全面的答案</li></ul><p>Adaptive Agentic RAG 由以下 3 个核心组件组成</p><ol><li>分类器：一个小的 llm 模型，用于分析用户输入的复杂性，分类器使用从过去模型结果和查询模式派生的自动标注数据训练</li><li>动态策略选择：直接查询，使用 llm 回答即可；简单查询，单步检索即可；复杂查询，使用多步检索，迭代细化结果</li><li> LLM 整合：整合检索到的信息，生成回答</li></ol><h4 id="graph-based-agentic-rag"><a class="markdownIt-Anchor" href="#graph-based-agentic-rag"></a> Graph-Based Agentic RAG</h4><h5 id="agent-g-agentic-framework-for-graph-rag"><a class="markdownIt-Anchor"
href="#agent-g-agentic-framework-for-graph-rag"></a> Agent-G: Agentic Framework for Graph RAG</h5><p>Agent-G 将图知识库和非结构化的检索相结合，通过结构化的知识图谱和非结构化数据增强 RAG 的上下文能力，在设计上还使用评估模块，确保可以不断迭代细化，产生高质量输出</p><p>工作流程：</p><ol><li>知识图谱：使用结构化的数据在全文层次提取实体、关系</li><li>非结构化数据：使用向量检索检索相关文档</li><li>评估模块：评估检索到信息的相关性和质量，确保与查询一致</li><li>反馈循环：通过迭代验证和重新查询细化检索和合成</li></ol><h5 id="geargraph-enhancedagentforretrieval-augmentedgeneration"><a class="markdownIt-Anchor" href="#geargraph-enhancedagentforretrieval-augmentedgeneration"></a> GeAR: Graph-Enhanced Agent for Retrieval-Augmented Generation</h5><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214135311.png" alt="RAG演进05-AgenticRAG-20250214135311"></p><p>通过整合基于图的检索机制来增强传统的 RAG 系统。通过利用图扩展技术和基于代理的架构，GeAR 解决了多跳检索场景中的挑战，提高了系统处理复杂查询的能力</p><ul><li>增强的多跳检索：GeAR 的图形扩展允许系统处理需要对多个互连信息进行推理的复杂查询。</li><li>代理决策：代理框架支持动态和自主选择检索策略，从而提高效率和相关性。</li><li>提高准确性：通过整合结构化图形数据，GeAR 提高了检索信息的精度，从而产生更准确和上下文合适的响应。</li><li>可扩展性：代理框架的模块化特性允许根据需要集成其他检索策略和数据源。</li></ul><h4 id="agentic-document-workflows-in-agentic-rag"><a class="markdownIt-Anchor" href="#agentic-document-workflows-in-agentic-rag"></a> Agentic Document Workflows in Agentic RAG</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B05-AgenticRAG-20250214135333.png" alt="RAG演进05-AgenticRAG-20250214135333"></p><p>智能体文档工作流（Agentic Document Workflows, ADW）整合文档解析、检索、推理和结构化输出与智能代理。通过维护状态、协调多步工作流，并将领域特定逻辑应用于文档，解决了智能文档处理（IDP）和 RAG 的限制</p><p>工作流程：</p><ol><li>文档解析和信息结构化：使用企业级工具（例如 LlamaParse）解析文档，提取相关数据字段，如发票号码、日期、供应商信息、明细项和付款条款</li><li>跨流程的状态维护： 系统维护文档上下文的状态，确保多步工作流之间的一致性和相关性</li><li>知识检索：从外部知识库或向量索引中检索相关参考文献</li><li>代理协调：智能代理应用业务规则，执行多跳推理，并生成可操作的建议</li><li>可操作输出生成：以结构化格式呈现输出，针对特定用例进行定制</li></ol><h3 id="agentic-rag框架比较"><a class="markdownIt-Anchor" href="#agentic-rag框架比较"></a> Agentic RAG 框架比较</h3><table><thead><tr><th>特征</th><th>传统 RAG</th><th>Agentic RAG</th><th>
代理文档工作流（ADW）</th></tr></thead><tbody><tr><td>聚焦</td><td>孤立的检索和生成任务</td><td>多代理协作和推理</td><td>以文档为中心的端到端工作流</td></tr><tr><td>上下文维护</td><td>有限的</td><td>通过记忆模块实现</td><td>在多步骤工作流中保持状态</td></tr><tr><td>动态适应性</td><td>最小的</td><td>高</td><td>针对文档工作流定制</td></tr><tr><td>工作流编排</td><td>无</td><td>协调多代理任务</td><td>集成多步骤文档处理</td></tr><tr><td>外部工具 / API 的使用</td><td>基本集成（例如，检索工具）</td><td>通过工具如 API 和知识库扩展</td><td>深度集成业务规则和特定领域的工具</td></tr><tr><td>可扩展性</td><td>限制在小型数据集或查询</td><td>多代理系统可扩展</td><td>适用于多领域企业工作流的扩展</td></tr><tr><td>复杂推理</td><td>基本的（例如，简单的问答）</td><td>与代理进行多步骤推理</td><td>在文档间进行结构化推理</td></tr><tr><td>主要应用</td><td>问答系统，知识检索</td><td>多领域知识和推理</td><td>合同审查、发票处理、索赔分析</td></tr><tr><td>优势</td><td>简单，快速设置</td><td>高准确性，协作推理</td><td>端到端自动化，领域特定智能</td></tr><tr><td>挑战</td><td>较差的上下文理解</td><td>协调复杂性</td><td>资源开销，领域标准化</td></tr></tbody></table><p>Agentic RAG 的出现，意味着 RAG 和 Agent 结合迈向下一步，扩展了 RAG 解决复杂查询的能力，如果 llm 没有进一步发展，可以试想这个架构将统治 RAG 很长一段时间。</p>]]></content>
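正文描述的 Single-Agent Router 的 “查询评估 → 数据源选择 → 生成” 流程，可以用如下最小骨架示意（route_query 及其中的关键词规则均为演示用的假设实现；实际系统中这一步通常由一个小型 llm 完成分类）：

```python
# Single-Agent Router 的最小骨架：根据查询特征选择数据源。
# 这里用关键词规则代替 llm 分类器，仅作流程示意。

def route_query(query: str) -> str:
    if any(k in query for k in ("最新", "今天", "实时")):
        return "web_search"      # 实时信息 -> 网络搜索
    if any(k in query for k in ("平均", "总数", "统计")):
        return "sql"             # 表格 / 聚合类 -> 结构化查询
    return "vector_search"       # 默认 -> 语义（向量）检索

print(route_query("今天的天气如何"))    # web_search
print(route_query("订单的平均金额"))    # sql
print(route_query("什么是 GraphRAG"))  # vector_search
```

路由结果决定调用哪条检索链路，检索到的数据再整合进 prompt 交给 llm 生成；Multi-Agent 版本只是把每条链路交给独立的子代理负责。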
    
    
    <summary type="html">&lt;p&gt;本文用于记录学习 AgenticRAG 的过程&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>RAG 演进 04-GraphRAG</title>
    <link href="https://www.shaogui.life/posts/1558141483.html"/>
    <id>https://www.shaogui.life/posts/1558141483.html</id>
    <published>2025-02-12T07:19:19.000Z</published>
    <updated>2025-02-12T07:20:27.005Z</updated>
    
    <content type="html"><![CDATA[<p>本文总结 GraphRAG 的定义、和传统 RAG 的区别，以及如何实现 GraphRAG</p><span id="more"></span><h3 id="什么是graphrag"><a class="markdownIt-Anchor" href="#什么是graphrag"></a> 什么是 GraphRAG？</h3><p>GraphRAG（Graph-based Retrieval-Augmented Generation）是微软提出的一种结合知识图谱的增强检索生成框架，旨在解决传统 RAG 的局限性。其核心思想是通过构建<strong>结构化知识图谱</strong>，对文档内容进行全局语义建模，从而提升大语言模型（LLM）在复杂问题中的推理能力和回答准确性。</p><p>与传统 RAG 依赖非结构化文本的向量检索不同，GraphRAG 通过图谱的拓扑结构（如节点、边、子图）挖掘深层语义关联，解决传统 RAG 在长尾问题、多跳推理和全局一致性上的不足，总的来说，它们之间有以下区别：</p><table><thead><tr><th><strong>维度</strong></th><th><strong>传统 RAG</strong></th><th><strong>GraphRAG</strong></th></tr></thead><tbody><tr><td><strong> 数据组织方式</strong></td><td>非结构化文本的向量化表示</td><td>结构化知识图谱（实体、关系、子图）</td></tr><tr><td><strong>检索逻辑</strong></td><td>基于向量相似度的片段匹配</td><td>基于图谱拓扑的关联推理（多跳查询）</td></tr><tr><td><strong>上下文理解</strong></td><td>局部语义匹配，缺乏全局关联性</td><td>全局图谱支持复杂逻辑推理</td></tr><tr><td><strong>可解释性</strong></td><td>黑盒检索，结果依赖向量空间分布</td><td>白盒化路径检索，答案支持图谱溯源</td></tr><tr><td><strong>适用场景</strong></td><td>简单问答、单跳检索任务</td><td>复杂推理、多跳问答、事件关联分析</td></tr><tr><td><strong>数据动态性</strong></td><td>更新需全量重新嵌入</td><td>支持增量更新（增删节点 / 边）</td></tr><tr><td><strong>计算开销</strong></td><td>低（仅需向量索引）</td><td>高（需构建和维护知识图谱）</td></tr></tbody></table><p>RAG 每次检索出的 chunk，只是某个文档的内容，而通过知识图谱，GraphRAG 在全局语料层面检索知识，弥补了传统 RAG 在复杂关联推理中的不足，同时提升了回答的可解释性。</p><h3 id="graphrag的实现过程"><a class="markdownIt-Anchor" href="#graphrag的实现过程"></a> GraphRAG 的实现过程</h3><pre><code class="highlight mermaid">flowchart LR
A[文档]
B[实体/关系抽取]
C[构建知识图谱]
D[社区划分]
E[检索子图]
F[llm生成答案]
A--&gt;B--&gt;C--&gt;D--&gt;E--&gt;F</code></pre><p>GraphRAG 的实现分为以下关键步骤：</p><ol><li><p><strong>索引（数据预处理与知识图谱构建）</strong></p><ul><li><strong>实体与关系抽取</strong>：使用 LLM 解析文档，提取实体（如人物、地点、事件）及其关系（如 “属于”“导致”）。</li><li><strong>图谱生成</strong>：将实体和关系存储为图结构（节点 = 实体，边 = 关系），支持属性扩展（如时间、上下文描述）。</li><li><strong>社区检测</strong>：通过图聚类算法（如 Louvain）将图谱划分为主题社区（如 “技术”“市场”），每个社区代表一个语义子图。</li><li><strong>生成社区摘要</strong>：使用自下而上的方法为每个 community 及其中的重要部分生成摘要。这些摘要包括 Community 内的主要 Entity、Entity
的关系和关键 Claim。这一步为整个数据集提供了概览，并为后续查询提供了有用的上下文信息</li></ul></li><li><p><strong>查询生成 (检索)</strong><br>利用用户提问到知识图谱检索上下文，根据检索的范围，可以分为 2 种检索方式：</p></li></ol><ul><li>全局查询：在整个知识图谱检索上下文，可以跨多文档检索</li><li>本地查询：在特定子图（如社区、领域节点），聚焦在某个实体上</li></ul><p>举个例子说明两者之间的关系：</p><table><thead><tr><th></th><th>全局查询</th><th>本地查询</th></tr></thead><tbody><tr><td><strong>目标</strong></td><td>从全局图谱中提取跨社区、跨实体的关联信息，解决复杂问题</td><td>快速检索与特定实体直接相关的信息，解决简单问题</td></tr><tr><td><strong>步骤</strong></td><td> 1. 问题解析：使用 LLM 解析问题，提取关键实体（如 “公司 A”“产品策略”“市场地位”）。<br><br>2. 图谱范围检索：<br>       社区检测：通过聚类算法（如 Louvain）识别与实体相关的多个社区（如 “产品社区”“市场社区”）。<br>       跨社区路径发现：使用图遍历算法（如随机游走、广度优先搜索）查找跨社区的关系路径。<br><br>3. 子图融合与增强：<br>      合并相关社区的局部子图，生成全局语义视图。<br>      通过图嵌入（如 Node2Vec）生成向量表示，捕捉全局语义关联。<br><br>4. 生成答案：<br>       将融合后的子图信息转化为自然语言提示，输入 LLM 生成回答。</td><td>1. 实体定位：<br>   使用命名实体识别（NER）或关键词匹配定位目标实体（如 “产品 A”）。<br><br>2. 邻域检索：<br>   提取实体的直接邻居节点（如 “功能”“用户群体”“技术参数”）。<br>   限定跳数（如 1~2 跳）遍历子图，避免过度扩展。<br><br>3. 上下文生成：<br>   将邻域关系转化为文本描述（如 “产品 A 的功能包括 X、Y，主要用户是 Z”）。<br><br>4. 
生成答案：<br>   直接基于局部上下文生成简洁回答，无需复杂推理。</td></tr><tr><td><strong>问题</strong></td><td>产品 A 的销量下降导致公司 B 股价下跌？</td><td>产品 A 的发布日期是什么时候？</td></tr><tr><td><strong>检索路径</strong></td><td>产品 A（社区 1）→ 用户投诉（社区 1）→ 市场份额下降（社区 2）→ 财报数据（社区 3）→ 股价下跌（社区 3）</td><td>产品 A → 发布日期 → “2023 年 1 月”</td></tr><tr><td> 使用算法</td><td>社区检测、图嵌入</td><td>邻域遍历、关键词匹配</td></tr><tr><td>使用工具</td><td> Neo4j（图遍历）、Gephi（可视化）</td><td>Elasticsearch（关键词检索）、NetworkX</td></tr></tbody></table><p>所以全局查询强调整体关联，适合复杂推理，但计算成本高，本地查询注重效率，适合简单问答，但依赖图谱构建质量，实际使用建议混合使用，当用户提出关于特定 Entity（如人名、地点、组织等）的问题时，建议使用本地搜索工作流程，若结果不足再触发全局检索</p><ol start="3"><li><strong>生成</strong><ul><li><strong>子图检索</strong>：根据用户问题<strong>定位相关社区或子图</strong>，提取结构化信息（如实体链、关系路径）。</li><li><strong>上下文增强</strong>：将子图信息转化为文本提示（Prompt），输入 LLM 生成最终回答。</li></ul></li></ol><h3 id="graphrag的例子"><a class="markdownIt-Anchor" href="#graphrag的例子"></a> GraphRAG 的例子</h3><p><strong>背景</strong>：微软基于 GraphRAG 构建新闻事件分析系统，用于回答涉及多实体关联的复杂问题，例如：“2023 年某科技公司并购案对行业竞争格局的影响？”</p><p><strong>实施流程</strong>：</p><ol><li><strong>数据源</strong>：爬取 100 万篇科技新闻，清洗后提取实体（公司、人物、产品）及关系（并购、合作、竞争）。</li><li><strong>图谱构建</strong>：在 Neo4j 中存储节点（公司 A、公司 B）、边（“并购”）、属性（时间、金额）。</li><li><strong>混合检索</strong>：<ul><li>向量检索：通过 Milvus 检索与 “并购” 相关的文本片段。</li><li>图谱检索：查询 “公司 A→并购→公司 B→影响→行业竞争” 的多跳路径。</li></ul></li><li><strong>答案生成</strong>：将检索结果（文本片段 + 子图结构）输入 GPT-4，生成包含因果链的答案。</li></ol><p><strong>效果对比</strong>：</p><ul><li><strong>传统 RAG</strong>：仅能返回单篇新闻中的并购金额等片段化信息。</li><li><strong>GraphRAG</strong>：可生成跨时间、跨实体的综合分析，例如：“并购导致市场份额向头部集中，触发反垄断调查”。</li></ul><h3 id="总结"><a class="markdownIt-Anchor" href="#总结"></a> 总结</h3><p>GraphRAG 通过知识图谱的引入，解决了传统 RAG 在复杂推理任务中的局限性，但其实现成本较高，需权衡业务需求与资源投入。未来方向包括：</p><ol><li><strong>轻量化图谱构建</strong>：利用 LLM 自动生成三元组。</li><li><strong>动态图谱更新</strong>：结合流数据处理技术实现实时更新。</li><li><strong>多模态扩展</strong>：融合文本、图像、表格等多模态数据构建异构图谱。</li></ol><p><strong>适用场景建议</strong>：</p><ul><li><strong>推荐采用 GraphRAG</strong>：金融风控（关联交易分析）、医疗诊断（病因推理）、法律合规（案件关联性审查）。</li><li><strong>仍适用传统 
RAG</strong>：客服问答（单轮对话）、文档摘要（无复杂逻辑）。</li></ul><p>通过灵活选择 RAG 与 GraphRAG 的组合，可最大化大模型在垂直领域落地的价值。</p>]]></content>
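文中 “本地查询” 的核心是从目标实体出发、限定跳数地遍历邻域。下面用纯 Python 草绘这一检索过程（triples 示例数据与 get_neighborhood 均为演示用的假设实现，实际系统会使用 Neo4j 等图数据库）：

```python
# 玩具知识图谱上的"本地查询"：限定跳数的 BFS 邻域检索。
from collections import deque

triples = [
    ("产品A", "发布日期", "2023年1月"),
    ("产品A", "属于", "公司B"),
    ("公司B", "股价", "下跌"),
]

def get_neighborhood(start, max_hops=1):
    """从 start 实体出发，收集 max_hops 跳以内的 (s, p, o) 三元组。"""
    adj = {}
    for s, p, o in triples:
        adj.setdefault(s, []).append((p, o))
    seen, result = {start}, []
    frontier = deque([(start, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth >= max_hops:
            continue  # 超出限定跳数，不再向外扩展
        for p, o in adj.get(node, []):
            result.append((node, p, o))
            if o not in seen:
                seen.add(o)
                frontier.append((o, depth + 1))
    return result

print(get_neighborhood("产品A", max_hops=1))  # 仅直接邻居
print(get_neighborhood("产品A", max_hops=2))  # 多跳后可关联到"股价下跌"
```

检索到的三元组再转化为文本描述拼进 prompt 即可生成回答；全局查询则相当于在多个社区的子图上做类似的遍历与融合。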
    
    
    <summary type="html">&lt;p&gt;本文总结 GraphRAG 的定义、和传统 RAG 的区别，以及如何实现 GraphRAG&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>RAG 演进 02-AdvancedRAG</title>
    <link href="https://www.shaogui.life/posts/3705592184.html"/>
    <id>https://www.shaogui.life/posts/3705592184.html</id>
    <published>2025-02-07T08:02:02.000Z</published>
    <updated>2025-02-07T09:39:08.613Z</updated>
    
    <content type="html"><![CDATA[<p>AdvancedRAG 侧重于增强检索到的文档的相关性和范围，这些技术（包括密集检索、混合搜索、重新排名和查询扩展）解决了 NaiveRAG 基于关键字的检索的限制</p><span id="more"></span><p>文章<a href="https://www.willowtreeapps.com/guides/advanced-rag-techniques">《15 Advanced RAG Techniques from Pre-Retrieval to Generation》</a> 总结了提升 RAG 系统性能的 4 类 15 种技术，以优化 RAG 的输出质量、成本及鲁棒性</p><p>这些技术分别是：</p><ul><li>预检索与数据索引技术<ul><li>技术 1：使用 LLM 提高信息密度（如 GPT-4 提取网页关键信息，减少冗余和噪声）</li><li>技术 2：分层索引检索（通过摘要实现多层检索，提升效率）</li><li>技术 3：假设性问题索引（生成 QA 对嵌入，解决查询 - 文档不对称问题）</li><li>技术 4：LLM 去重（聚类嵌入空间，合并重复信息）</li><li>技术 5：分块策略优化（A/B 测试分块大小、重叠率等参数）</li></ul></li><li>检索技术<ul><li>技术 6：LLM 优化搜索查询（适配 Google 语法或对话上下文）</li><li>技术 7：HyDE（生成假设性文档嵌入，提升语义相似性）</li><li>技术 8：RAG 决策器模式（判断是否需要检索，降低成本）</li></ul></li><li>后检索技术<ul><li>技术 9：重排序（优先展示最相关文档）</li><li>技术 10：上下文提示压缩（如 LLMLingua 框架，压缩无关信息）</li><li>技术 11：纠正性 RAG（T5 模型过滤不相关结果）</li></ul></li><li>生成技术<ul><li>技术 12：链式思考（CoT）提示（通过推理减少噪声影响）</li><li>技术 13：Self-RAG（自省标记，动态调用检索并批判输出）</li><li>技术 14：微调模型（提升忽略无关上下文的能力）</li><li>技术 15：自然语言推理（NLI 模型过滤无关内容）</li></ul></li></ul><h2 id="预检索与数据索引技术"><a class="markdownIt-Anchor" href="#预检索与数据索引技术"></a> 预检索与数据索引技术</h2><p>预检索优化主要目的是提高检索的质量，减少在检索阶段的冗余、错误信息，有助于降低 RAG 系统的成本及幻觉。</p><h3 id="技术1使用llm提高信息密度"><a class="markdownIt-Anchor" href="#技术1使用llm提高信息密度"></a> 技术 1：使用 LLM 提高信息密度</h3><p>RAG 外挂的数据源有不同的信息密度，如 markdown 等文本文件的信息密度一般比 Web 数据信息密度高，因为 Web 数据除了文本数据，还包含一些 html 标记数据，总的来说，数据源可能出现以下问题：</p><ul><li>信息密度低 / 信息冗余：低信息内容要求输入更多 token 到 llm，这会导致使用 llm 成本上升</li><li>不相关信息或噪声：噪声可能误导 llm，甚至在一些常识问题上犯错，也就是外挂数据源引入了新幻觉</li></ul><p>所以我们在存储数据前，可以使用 llm 提高数据源的信息密度，减少冗余数据，比如使用 llm 剥离 Web 数据中的 html 标记等信息，然后再存储到知识库中，</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B02-AdvancedRAG-20250207152903.png" alt="RAG演进02-AdvancedRAG-20250207152903"></p><p>利用这个方法可以明显降低 token 数据</p><ul><li>原始 html：~55000 token</li><li> 精简 html：1500 token</li><li>llm 处理后的 html：330 token</li></ul><p>但是使用这个方法，可能会去除一些有用信息，需要提防</p><h3 id="技术2分层索引检索"><a 
class="markdownIt-Anchor" href="#技术2分层索引检索"></a> 技术 2：分层索引检索</h3><p>直接用用户提问去检索原始文本，可能因为原始文本的信息密度问题导致检索失败，假设知识库存在以下两个片段，当用户提问是：“llm 的定义？”，可能检索到片段 2，因为 llm 出现更加频繁，但是我们实际需要的是片段 1</p><ol><li>llm 是 xxxxxxx</li><li> 我要使用 llm、llm 真好用</li></ol><p>此时我们除了检索原始文本，还去检索由原始文本总结的摘要，形成摘要、原始文本双层的索引检索，提高检索的质量</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B02-AdvancedRAG-20250207155904.png" alt="RAG演进02-AdvancedRAG-20250207155904"></p><h3 id="技术3假设性问题索引"><a class="markdownIt-Anchor" href="#技术3假设性问题索引"></a> 技术 3：假设性问题索引</h3><p>直接用用户提问去检索原始文本，除了信息密度问题，还有一个问题是：通过问题无法检索到答案相关片段，可能是因为答案片段完全没有出现问题的任何关键词或相近词</p><p>这时可使用 llm 针对每个片段生成其可能对应的提问，然后将问题与片段一起存储，这样检索类似场景的难度就会下降，比如有以下片段：</p><blockquote><p>LLM 是大型语言模型（Large Language Model）的缩写，是一种基于深度学习的人工智能技术，通过大量文本数据训练，能够理解和生成自然语言。LLM 可用于文本生成、问答、翻译、摘要、情感分析等任务，还能进行代码生成和多模态内容创作。LLM 采用 Transformer 架构，通过自注意力机制捕捉上下文信息，利用无监督学习方法，如下一个单词预测和掩码语言模型，学习语言的模式和规律。</p></blockquote><p>可以将其拆分为 3 个 QA 组合：</p><blockquote><p>问题：llm 的定义？ 回答：[以上原始文本]<br>问题：llm 能做什么？ 回答：[以上原始文本]<br>问题：llm 的原理？ 回答：[以上原始文本]</p></blockquote><p>可以看出，使用这个方法会扩充知识库，提高检索成本，这时候可以改用 HyDE 技术，使用 llm 先回答用户提问，然后使用 [提问，llm 回答] 去检索文档，相当于没有扩充知识库，而是丰富了提问</p><h3 id="技术4llm去重"><a class="markdownIt-Anchor" href="#技术4llm去重"></a> 技术 4：LLM 去重</h3><p>不同数据源可能包含重复信息，并且会排挤真正有用的上下文，比如以下 3 个片段</p><blockquote><ol><li>苹果清新甜美，苹果口感爽脆多汁，果香浓郁</li><li>苹果，一种常见的水果，富含维生素和膳食纤维，口感清脆多汁，营养丰富，可鲜食或加工成果汁、果酱等</li><li>苹果清新甜美，苹果口感爽脆多汁，果香浓郁</li></ol></blockquote><p>假设用户提问是 “苹果是什么？”，若检索系统恰好认为片段 1 比片段 2 评分高，由于片段 1、3 是重复的，所以 3 个片段的检索评分是 1=3&gt;2，在检索系统要求返回评分 topk2 的片段时，只会返回片段 1、3，但是实际需要的是片段 2</p><p>所以说数据源的去重，有助于提高上下文的准确性，提高模型回答质量</p><h3 id="技术5分块策略优化"><a class="markdownIt-Anchor" href="#技术5分块策略优化"></a> 技术 5：分块策略优化</h3><p>不同数据源，信息密度不同，在分块构建索引时，不同的分块策略，可能影响检索准确性，所以通过调整分块大小、重叠率以及选择不同的嵌入模型来测试检索效果，很有必要</p><p>llamaindex 已经提供流程化的工具去评估检索器，可以通过调整以上参数，测试检索器性能，选择最佳的组合</p><h2 id="检索技术"><a class="markdownIt-Anchor" href="#检索技术"></a> 检索技术</h2><h3 id="技术6llm优化搜索查询"><a class="markdownIt-Anchor" 
href="#技术6llm优化搜索查询"></a> 技术 6：LLM 优化搜索查询</h3><p>在检索内容时，当用户的提问以特定的形式出现时，检索的结果更加精确，比如用户提问：张三 2025 年 2 月 7 日做了什么？直接检索可能检索到其他日期的事情，因为单个数字在检索系统起到的作用很小，此时我们将用户提问转换为<br>[“张三”,“2025-02-07”]，然后使用这两个字段，类似数据库的方式去查找记录，得到的结果就是 100% 准确的</p><p>在多轮的对话系统中，直接使用某次对话去检索上下文可能失败，因为有些信息包含在对话历史中，比如以下对话</p><blockquote><p>用户：xxx 的价格是多少<br>机器：xxx<br>用户：它能解决 xxxx</p></blockquote><p>如果直接使用 "它能解决 xxxx" 去检索知识库，得到的肯定是错误答案，此时基于对话历史先用 llm 优化查询很重要</p><h3 id="技术7hyde"><a class="markdownIt-Anchor" href="#技术7hyde"></a> 技术 7：HyDE</h3><p>在查询 vs 文档不对称的 RAG 系统中，直接使用查询去检索文档的准确性不高，此时先使用 llm 生成回答，然后使用 [查询、llm 回答] 去检索文档，准确性得到提高</p><blockquote><p>所谓不对称是指目标文档未直接包含查询内容，或者目标文档在语义层面回答了问题，但是没在字面上回答问题</p></blockquote><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B02-AdvancedRAG-20250207165135.png" alt="RAG演进02-AdvancedRAG-20250207165135"></p><h3 id="技术8rag决策器模式"><a class="markdownIt-Anchor" href="#技术8rag决策器模式"></a> 技术 8：RAG 决策器模式</h3><p>在准备检索文档前，先使用 llm 判断是否需要检索，对于一些常识问题或者在对话历史出现的问题，直接使用 llm 回答即可，而不需要再次去检索文档，以降低 RAG 系统的成本</p><p>比如提问：</p><blockquote><p>请统计 "xxx" 这段话中的人物？<br>请将 "" 翻译为英文？</p></blockquote><h2 id="后检索技术"><a class="markdownIt-Anchor" href="#后检索技术"></a> 后检索技术</h2><h3 id="技术9重排序"><a class="markdownIt-Anchor" href="#技术9重排序"></a> 技术 9：重排序</h3><p>在检索系统找到的文档中，根据查询可以分为以下 4 类</p><ul><li>相关文档（可以直接回答查询的）</li><li>有联系但是不相关</li><li>无联系并且不相关</li><li>反事实文档（与相关文档相反的）</li></ul><p>理想的检索系统是只检索到 “相关文档”，但是实际情况往往会出现其他类型的文档，此时使用 llm 对检索到的文档进行重排序，可以有效提高 llm 的回答质量</p><h3 id="技术10上下文提示压缩"><a class="markdownIt-Anchor" href="#技术10上下文提示压缩"></a> 技术 10：上下文提示压缩</h3><p>和重排序不同的是，这里使用 llm 压缩检索到的文档，目的是提高相关文档的重要性，过滤掉其他文档的干扰</p><p>这里 llm 相当于一个过滤器，智能选择能回答提问的上下文</p><h3 id="技术11纠正性rag"><a class="markdownIt-Anchor" href="#技术11纠正性rag"></a> 技术 11：纠正性 RAG</h3><p>使用 llm 评估 RAG 系统的输出，将其分类为 [正确、不正确]，对于不正确的输出，直接丢弃</p><p>比如提问 “苹果是什么？”，检索到的文档如下，使用每个检索到的文档去生成回答，一共得到 3 个回答，然后使用 llm 
选择最佳答案</p><blockquote><ol><li>苹果清新甜美，苹果口感爽脆多汁，果香浓郁</li><li>苹果，一种常见的水果，富含维生素和膳食纤维，口感清脆多汁，营养丰富，可鲜食或加工成果汁、果酱等</li><li>苹果清新甜美，苹果口感爽脆多汁，果香浓郁</li></ol></blockquote><h2 id="生成技术"><a class="markdownIt-Anchor" href="#生成技术"></a> 生成技术</h2><h3 id="技术12链式思考cot提示"><a class="markdownIt-Anchor" href="#技术12链式思考cot提示"></a> 技术 12：链式思考（CoT）提示</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/RAG%E6%BC%94%E8%BF%9B02-AdvancedRAG-20250207171659.png" alt="RAG演进02-AdvancedRAG-20250207171659"></p><p>利用链式思考（CoT）提示，提升模型在噪声或不相关文档干扰下的回答准确性</p><h3 id="技术13self-rag"><a class="markdownIt-Anchor" href="#技术13self-rag"></a> 技术 13：Self-RAG</h3><p>给定一个输入提示和先前的生成，Self-RAG 首先确定使用检索到的段落来扩充后续生成是否有帮助。如果是，它会输出一个检索标记，该标记会根据需要调用检索器模型。</p><p>随后，Self-RAG 会同时处理多个检索到的段落，评估它们的相关性，然后生成相应的任务输出。接着，它会生成批评标记来批评自己的输出，并从事实性和整体质量方面选择最佳的输出</p><h3 id="技术14微调模型"><a class="markdownIt-Anchor" href="#技术14微调模型"></a> 技术 14：微调模型</h3><p>鉴于 LLM 通常不会针对 RAG 进行明确训练或调整，因此可以推断，针对此用例对模型进行微调可以提高模型忽略不相关上下文的能力</p><h3 id="技术15自然语言推理"><a class="markdownIt-Anchor" href="#技术15自然语言推理"></a> 技术 15：自然语言推理</h3><p>使用自然语言推理 (NLI) 模型来识别不相关的上下文，并将其过滤掉</p>]]></content>
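上文技术 4 的 “先去重、再取 topk” 思路，可以用下面的极简 Python 草图演示。其中 embed 是假设的玩具向量化函数，cosine、retrieve_topk 也都是为说明原理而虚构的名字，并非真实检索系统的实现：

```python
import math

def embed(text):
    # 玩具向量化（假设）：按字符哈希到 8 维并归一化，仅用于演示
    vec = [0.0] * 8
    for ch in text:
        vec[hash(ch) % 8] += 1.0
    norm = math.sqrt(sum(v * v for v in vec)) or 1.0
    return [v / norm for v in vec]

def cosine(a, b):
    # 两个已归一化向量的余弦相似度
    return sum(x * y for x, y in zip(a, b))

def retrieve_topk(query, chunks, k=2, dedup=True):
    # 技术 4：先去重，避免重复片段挤占 topk 名额
    if dedup:
        chunks = list(dict.fromkeys(chunks))
    q = embed(query)
    ranked = sorted(chunks, key=lambda c: cosine(q, embed(c)), reverse=True)
    return ranked[:k]
```

对应正文的苹果例子：不去重时，重复片段会占满 topk2；去重后，片段 2 才有机会被返回。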
    
    
    <summary type="html">&lt;p&gt;AdvancedRAG 侧重于增强检索到的文档的相关性和范围，这些技术（包括密集检索、混合搜索、重新排名和查询扩展）解决了 NaiveRAG 基于关键字的检索的限制&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>基于 OpenVINO 在 CPU 上部署模型</title>
    <link href="https://www.shaogui.life/posts/1484111444.html"/>
    <id>https://www.shaogui.life/posts/1484111444.html</id>
    <published>2025-02-06T09:03:59.000Z</published>
    <updated>2025-02-07T02:59:25.081Z</updated>
    
    <content type="html"><![CDATA[<p>在实际场景中，深度学习模型运行的机器不总是有 GPU，这时候基于 tensorrt 的部署方案就无法满足要求，本文基于 openvino 压缩量化模型，然后将其部署到 CPU 上，实验证明：量化前后推理效果相当，CPU 上的推理耗时也接近 GPU 部署</p><span id="more"></span><h2 id="1-模型优化与性能提升"><a class="markdownIt-Anchor" href="#1-模型优化与性能提升"></a> 1. 模型优化与性能提升</h2><p>在深度学习模型应用中，模型的大小和复杂度往往与其推理速度和内存占用密切相关。为了在资源有限的环境中部署大型网络，通常需要对模型进行优化。这一优化过程主要包括以下几方面：</p><ul><li><strong>量化（Quantization）</strong>：降低数值精度，将模型权重和激活值从 32 位浮点转换为更紧凑的表示形式（如 8 位或 4 位整数），从而减少内存占用并加快推理速度。</li><li><strong>模型压缩（Model Compression）</strong>：通过剪枝冗余的网络结构、知识蒸馏等技术，减少不必要的计算步骤，进一步降低模型复杂度。</li><li><strong>优化后端（Post-Tuning）</strong>：针对特定硬件架构（如 CPU）的性能特点，对模型进行微调，使其在资源利用上达到最佳状态。</li></ul><p>OpenVINO 框架集成了这些技术，为模型优化提供了通用的解决方案，支持多种深度学习模型的部署和推理任务。</p><h2 id="2-基于ncff量化模型"><a class="markdownIt-Anchor" href="#2-基于ncff量化模型"></a> 2. 基于 NNCF 量化模型</h2><p>NNCF（Neural Network Compression Framework）是 OpenVINO 生态中用于对 ONNX 等模型进行量化压缩的工具。其核心原理是通过对模型中的神经网络层参数和激活值进行量化，降低数据类型的精度，从而降低计算复杂度和内存占用</p><p>NNCF 量化模型主要包括以下几个步骤：首先，选择适合的量化位数；然后，使用训练数据对模型进行校准，以找到最佳的量化参数；最后，将原始模型中的神经网络层、权重和激活值转换为量化形式，具体代码如下：</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td 
class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">quantiz_by_onnx</span>(<span class="params">onnx_path,save_dir=<span class="string">''</span></span>):</span><br><span class="line">    <span class="comment"># 1. 准备量化需要的校准数据集</span></span><br><span class="line">    train_dataset = DatasetBuilder()  <span class="comment"># 针对自己数据编写加载逻辑，和训练时加载逻辑一样</span></span><br><span class="line">    train_loader = DataLoader(train_dataset, batch_size=<span class="number">1</span>, shuffle=<span class="literal">True</span>)</span><br><span class="line">    </span><br><span class="line">    <span class="comment"># 校准时只需要图片数据，不需要标签</span></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">transform_fn</span>(<span class="params">data_item</span>):</span><br><span class="line">        img, label, weight, is_ng = data_item</span><br><span class="line">        <span class="keyword">return</span> {<span class="string">'input'</span>: img.numpy()} </span><br><span class="line"></span><br><span class="line">    <span class="comment"># 构建校准数据集</span></span><br><span class="line">    calibration_dataset = nncf.Dataset(train_loader, transform_fn)</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 2. 使用nncf量化onnx</span></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">'calibration....'</span>)</span><br><span class="line">    onnx_model = onnx.load(onnx_path)</span><br><span class="line">    quantized_model = nncf.quantize(</span><br><span class="line">        model=onnx_model, </span><br><span class="line">        calibration_dataset=calibration_dataset</span><br><span class="line">        )</span><br><span class="line"></span><br><span class="line">    <span class="comment"># 3. 
保存量化模型</span></span><br><span class="line">    quantized_onnx_path=os.path.join(save_dir,os.path.basename(onnx_path))</span><br><span class="line">    onnx.save(quantized_model, quantized_onnx_path)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">'quantized_onnx_path:'</span>,quantized_onnx_path)</span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> quantized_onnx_path</span><br></pre></td></tr></tbody></table></figure><h2 id="3-基于openvino部署量化模型"><a class="markdownIt-Anchor" href="#3-基于openvino部署量化模型"></a> 3. 基于 openvino 部署量化模型</h2><p>按照以下步骤基于 openvino 部署量化模型</p><ol><li>模型转换：利用 ov.convert_model 和 ov.save_model，把 ONNX 文件转换成 OpenVINO 的 XML（IR）格式</li><li>编译模型：按 CPU 的性能特点设置编译参数，例如选择 LATENCY 性能模式、关闭 Hyper-threading、开启 CPU 绑定，然后编译模型用于推理</li></ol><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 1. onnx-&gt;openvino</span></span><br><span class="line">quantized_openvino_path=quantized_onnx_path.replace(<span class="string">'.onnx'</span>,<span class="string">'.xml'</span>)</span><br><span class="line">ov_quantized_model = ov.convert_model(quantized_onnx_path)</span><br><span class="line">ov.save_model(ov_quantized_model, quantized_openvino_path)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 2. 
编译模型</span></span><br><span class="line">config = {</span><br><span class="line">   hints.performance_mode: hints.PerformanceMode.LATENCY,</span><br><span class="line">   hints.enable_hyper_threading(): <span class="literal">False</span>,</span><br><span class="line">   hints.enable_cpu_pinning(): <span class="literal">True</span>}</span><br><span class="line">core = ov.Core()</span><br><span class="line">core.set_property({props.cache_dir: <span class="string">r"E:\MyCode\python-openvino\cache"</span>})</span><br><span class="line">ov_model = core.read_model(model=quantized_openvino_path)</span><br><span class="line"></span><br><span class="line">compiled_model = core.compile_model(model=ov_quantized_model, device_name=<span class="string">"CPU"</span>,config=config)</span><br><span class="line"></span><br></pre></td></tr></tbody></table></figure><p>基于编译好的模型，我们使用代码去测试其推理效果</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span 
class="line"><span class="comment"># 模型推理</span></span><br><span class="line">img_path=<span class="string">r'1.png'</span></span><br><span class="line">img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), <span class="number">1</span>)</span><br><span class="line">img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)</span><br><span class="line"><span class="comment"># 前处理</span></span><br><span class="line">img = cv2.resize(img, (<span class="number">512</span>,<span class="number">512</span>), interpolation=cv2.INTER_NEAREST)</span><br><span class="line">img_data=img.astype(np.float32)</span><br><span class="line">img_data/=<span class="number">255.0</span></span><br><span class="line"><span class="built_in">print</span>(img_data.shape)</span><br><span class="line">img_data-=<span class="number">0.5</span></span><br><span class="line">img_data/=<span class="number">0.5</span></span><br><span class="line">h,w=img_data.shape[:<span class="number">2</span>]</span><br><span class="line">img_data=img_data.reshape(<span class="number">1</span>,<span class="number">1</span>,h,w)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 推理</span></span><br><span class="line"><span class="built_in">print</span>(<span class="string">'infer....'</span>)</span><br><span class="line">output_layer = compiled_model.output(<span class="number">0</span>)</span><br><span class="line"></span><br><span class="line">iters=<span class="number">100</span></span><br><span class="line">start_time = time.time()  <span class="comment"># 获取当前时间</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(iters):</span><br><span class="line">   output=compiled_model([img_data])[output_layer]</span><br><span class="line">end_time = time.time()  <span class="comment"># 获取当前时间</span></span><br><span class="line"><span class="built_in">print</span>(<span class="string">f"运行时间：<span 
class="subst">{(end_time - start_time)*<span class="number">1.0</span>/iters}</span>秒"</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 后处理</span></span><br><span class="line">pred = np.zeros_like(output)</span><br><span class="line">pred[output&gt;<span class="number">0.5</span>]=<span class="number">255</span></span><br><span class="line">pred=pred[<span class="number">0</span>,<span class="number">0</span>].astype(np.uint8)</span><br><span class="line">result=cv2.addWeighted(img,<span class="number">1.0</span>,pred,<span class="number">0.3</span>,<span class="number">0</span>)</span><br><span class="line">cv2.imwrite(<span class="string">'ai.png'</span>,result)</span><br></pre></td></tr></tbody></table></figure><h2 id="4-在c部署openvino推理过程"><a class="markdownIt-Anchor" href="#4-在c部署openvino推理过程"></a> 4. 在 C++ 部署 openvino 推理过程</h2><p>基于 openvino 的 cpp 版，在 CPU 上部署模型，关键代码如下：</p><figure class="highlight c"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span 
class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span> <span class="title function_">OpenvinoRuntime::sync_infer</span><span class="params">(<span class="built_in">std</span>::<span class="built_in">vector</span>&lt;cv::Mat&gt;&amp; frames)</span></span><br><span class="line">{</span><br><span class="line">    <span class="keyword">for</span> (<span class="type">int</span> si = <span class="number">0</span>; si &lt; frames.size(); si++)</span><br><span class="line">    {</span><br><span class="line">        <span class="type">int</span> imgWidth = frames[si].size().width;</span><br><span class="line">        <span class="type">int</span> imgHeight = frames[si].size().height;</span><br><span class="line"></span><br><span class="line">        <span class="comment">// 1.前处理</span></span><br><span class="line">        <span class="type">int</span> retn = preprocess(frames[si]);</span><br><span class="line">        <span class="keyword">if</span> (retn &lt; <span class="number">0</span>)</span><br><span class="line">        {</span><br><span class="line">            <span class="keyword">return</span> <span class="number">-1</span>;</span><br><span class="line">        }</span><br><span class="line"></span><br><span class="line">        <span class="comment">// 2.推理</span></span><br><span class="line">        try {</span><br><span class="line">            <span class="type">const</span> ov::Tensor input_tensor = ov::Tensor(compiled_model.input().get_element_type(), ov::Shape(modelInputShape), (<span class="type">float</span>*)(frames[si].data));</span><br><span class="line">            curr_request.set_input_tensor(input_tensor);</span><br><span class="line">            curr_request.infer();</span><br><span class="line">        
}</span><br><span class="line">        catch (ov::Exception&amp; e)</span><br><span class="line">        {</span><br><span class="line">            <span class="built_in">std</span>::<span class="built_in">cout</span> &lt;&lt; e.what() &lt;&lt; <span class="built_in">endl</span>;</span><br><span class="line">            <span class="keyword">return</span> <span class="number">-2</span>;</span><br><span class="line">        }</span><br><span class="line"></span><br><span class="line">        <span class="comment">// 3.后处理</span></span><br><span class="line">        <span class="type">float</span>* detections = curr_request.get_output_tensor().data&lt;<span class="type">float</span>&gt;();</span><br><span class="line">        cv::Mat <span class="title function_">predMat</span><span class="params">(cv::Size(modelInputWidth, modelInputHeight), CV_32F, detections)</span>;</span><br><span class="line">        retn = postprocess(frames[si],predMat, imgWidth, imgHeight,si);</span><br><span class="line">        <span class="keyword">if</span> (retn &lt; <span class="number">0</span>)</span><br><span class="line">        {</span><br><span class="line">            <span class="keyword">return</span> <span class="number">-3</span>;</span><br><span class="line">        }</span><br><span class="line">    }</span><br><span class="line"></span><br><span class="line">    <span class="keyword">return</span> <span class="number">0</span>;</span><br><span class="line">}</span><br></pre></td></tr></tbody></table></figure><p>针对 openvino 提供的批处理及异步推理，相应还编写以下推理方法</p><figure class="highlight c"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="type">int</span> <span class="title function_">sync_infer</span><span class="params">(<span class="built_in">std</span>::<span 
class="built_in">vector</span>&lt;cv::Mat&gt;&amp; frames)</span>;</span><br><span class="line"><span class="type">int</span> <span class="title function_">sync_infer_batch</span><span class="params">(<span class="built_in">std</span>::<span class="built_in">vector</span>&lt;cv::Mat&gt;&amp; frames)</span>;</span><br><span class="line"><span class="type">int</span> <span class="title function_">async_infer</span><span class="params">(<span class="built_in">std</span>::<span class="built_in">vector</span>&lt;cv::Mat&gt;&amp; frames)</span>;</span><br><span class="line"><span class="type">int</span> <span class="title function_">sync_infer_preprocess</span><span class="params">(<span class="built_in">std</span>::<span class="built_in">vector</span>&lt;cv::Mat&gt;&amp; frames)</span>;</span><br><span class="line"><span class="type">int</span> <span class="title function_">async_infer_preprocess</span><span class="params">(<span class="built_in">std</span>::<span class="built_in">vector</span>&lt;cv::Mat&gt;&amp; frames)</span>; </span><br></pre></td></tr></tbody></table></figure><p>除此之外，将以上的 C<ins> 代码编译为 dll，并使用 CLR 封装其接口，以便供 C</ins> 或者 C# 使用</p><h2 id="5-效果分析"><a class="markdownIt-Anchor" href="#5-效果分析"></a> 5. 效果分析</h2><p>以下统计在不同软硬件平台部署模型的耗时，可以看出在 CPU 部署模型时，openvino 部署方式更快，已经接近在 GPU 部署</p><table><thead><tr><th>软件</th><th>硬件</th><th>耗时</th><th>备注</th></tr></thead><tbody><tr><td> onnxruntime</td><td>CPU</td><td>1000ms</td><td></td></tr><tr><td>openvino</td><td>CPU</td><td>30ms</td><td> 输入图片 (512x512)</td></tr><tr><td>openvino</td><td>CPU</td><td>12ms</td><td> 输入图片 (320x512)</td></tr><tr><td>tensorrt</td><td>GPU</td><td>8ms</td><td> 输入图片 (320x512)</td></tr></tbody></table><h2 id="总结"><a class="markdownIt-Anchor" href="#总结"></a> 总结</h2><p>本文基于 openvino 构建了一个 onnx 模型的部署流程，其耗时和 GPU 部署接近，并最终编译为 dll，提供到不同的软件开发平台使用。</p><p>使用 openvino 丰富了模型的应用场景，从传统的要求使用 GPU，推广到 CPU，降低了模型应用的成本。</p>]]></content>
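第 5 节的耗时统计方式，可以抽象成下面这个通用的计时函数（纯 Python 草图，benchmark、infer_fn 均为示意命名，与 openvino API 无关；其中 warmup 预热是补充的假设做法，用于排除首次推理的初始化开销）：

```python
import time

def benchmark(infer_fn, data, iters=100, warmup=10):
    # 预热若干次，排除首次推理的编译/缓存开销
    for _ in range(warmup):
        infer_fn(data)
    # 多次推理取平均，减少单次计时的抖动
    start = time.perf_counter()
    for _ in range(iters):
        infer_fn(data)
    return (time.perf_counter() - start) / iters
```

把正文的 compiled_model 调用包成 infer_fn 传入，即可复现表格中 “每次推理平均耗时” 的统计口径。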
    
    
    <summary type="html">&lt;p&gt;在实际场景中，深度学习模型运行的机器不总是有 GPU，这时候基于 tensorrt 的部署方案就无法满足要求，本文基于 openvino 压缩量化模型，然后将其部署到 CPU 上，实验证明：推理效果相当、推理耗时接近&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="D-深度学习部署" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/"/>
    
    <category term="openvino" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/openvino/"/>
    
    
  </entry>
  
  <entry>
    <title>利用分割模型分析场景变化</title>
    <link href="https://www.shaogui.life/posts/2821877742.html"/>
    <id>https://www.shaogui.life/posts/2821877742.html</id>
    <published>2025-02-03T00:11:33.000Z</published>
    <updated>2025-02-07T02:59:12.882Z</updated>
    
    <content type="html"><![CDATA[<p>在一些场景中，需要分析图片在不同时刻的变化情况，比如居家的早上、晚上变化，卫星图片不同年份的变化，通过分析变化，掌握空间内事物的变化趋势，那么应该如何分析不同时刻下图片的变化呢？</p><span id="more"></span><p>所谓变化，就是场景中新增那些目标、那些目标去除了，由于图片不是对齐的，所以无法直接通过图片加减法去找到图片变化位置，并且如果涉及变化的目标类别，该方法也无法做到。</p><p>本文使用语义分割模型实现图片变化分析，常规的语义分割模型输入一张图片，然后分割图片目标，这里将范式改为：<strong>输入前后时刻的图片，输出变化的二值图</strong>，最终实现下图效果</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8%E5%88%86%E5%89%B2%E6%A8%A1%E5%9E%8B%E5%88%86%E6%9E%90%E5%9C%BA%E6%99%AF%E5%8F%98%E5%8C%96-20250203082348.png" alt="利用分割模型分析场景变化-20250203082348"></p><p>上图分别是前后不同时刻同一位置的卫星图片、人工认定的变化区域、语义分割模型分析的变化区域。可以看出语义分割模型基本找到变化区域</p><h3 id="模型定义"><a class="markdownIt-Anchor" href="#模型定义"></a> 模型定义</h3><p>本文使用 deeplabv3 + 模型，输入通道由原来的 3 通道变为 6 通道（两张 3 通道的图片合并），输出为 1 个类别的概率图</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">class</span> <span class="title class_">DeepLabV3Plus</span>(nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self, aspp_dilations=<span class="literal">None</span>, aspp_dropout=<span class="literal">True</span>, dropout_prob=<span class="number">0.2</span>, num_classes=<span class="number">1</span>, 
pretrained=<span class="literal">True</span></span>):</span><br><span class="line">        <span class="built_in">super</span>(DeepLabV3Plus, <span class="variable language_">self</span>).__init__()</span><br><span class="line"></span><br><span class="line">        <span class="comment"># 改动输入为6通道转为3通道</span></span><br><span class="line">        <span class="variable language_">self</span>.get_input=nn.Conv2d(<span class="number">6</span>, <span class="number">3</span>, kernel_size=<span class="number">7</span>, stride=<span class="number">2</span>, padding=<span class="number">3</span>, bias=<span class="literal">False</span>)</span><br><span class="line">        <span class="variable language_">self</span>.backbone = resnet18(pretrained=pretrained)</span><br><span class="line"></span><br><span class="line">        <span class="variable language_">self</span>.aspp = ASPP(aspp_dilations=aspp_dilations, aspp_dropout=aspp_dropout)</span><br><span class="line">        <span class="variable language_">self</span>.decoder = Decoder(num_classes=num_classes, dropout_prob=dropout_prob)</span><br><span class="line">        <span class="variable language_">self</span>.freeze_modules = <span class="variable language_">self</span>.backbone</span><br><span class="line"></span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, <span class="built_in">input</span></span>):</span><br><span class="line">        <span class="built_in">input</span>=<span class="variable language_">self</span>.get_input(<span class="built_in">input</span>)</span><br><span class="line">        x, feat_2x, feat_4x = <span class="variable language_">self</span>.backbone(<span class="built_in">input</span>)</span><br><span class="line">        x = <span class="variable language_">self</span>.aspp(x)</span><br><span class="line">        x = <span class="variable language_">self</span>.decoder(x, feat_2x, 
feat_4x)</span><br><span class="line">        x = F.interpolate(x, size=<span class="built_in">input</span>.size()[<span class="number">2</span>:], mode=<span class="string">"bilinear"</span>, align_corners=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">        <span class="keyword">return</span> x</span><br></pre></td></tr></tbody></table></figure><h3 id="数据加载"><a class="markdownIt-Anchor" href="#数据加载"></a> 数据加载</h3><p>由于此处分析的图片是卫星图片，分辨率很大，这里首先使用 sahi 将图片重叠切分为小图，然后再进行训练，切割后图片如下</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> sahi.slicing <span class="keyword">import</span> slice_image</span><br><span class="line"></span><br><span class="line">img1_path=<span class="string">'/home/wushaogui/MyCodes/Pytorch_Change_detection/CD_Data_GZ/labels_change/P_GZ_test4_2010_2019.png'</span></span><br><span class="line">data_name=<span class="string">'test'</span></span><br><span class="line"></span><br><span class="line">SliceImageResult=slice_image(</span><br><span class="line">    image=img1_path,</span><br><span class="line">    output_file_name=data_name,</span><br><span class="line">    output_dir=<span class="literal">None</span>,</span><br><span class="line">    slice_height=<span class="number">1024</span>,</span><br><span class="line">    slice_width=<span class="number">1024</span>,</span><br><span class="line">  
  overlap_height_ratio=<span class="number">0.5</span>,</span><br><span class="line">    overlap_width_ratio=<span class="number">0.5</span></span><br><span class="line">)</span><br><span class="line"></span><br><span class="line">show_images(SliceImageResult.images)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8%E5%88%86%E5%89%B2%E6%A8%A1%E5%9E%8B%E5%88%86%E6%9E%90%E5%9C%BA%E6%99%AF%E5%8F%98%E5%8C%96-20250203084716.png" alt="利用分割模型分析场景变化-20250203084716"></p><h3 id="训练过程"><a class="markdownIt-Anchor" href="#训练过程"></a> 训练过程</h3><p>使用 tensorboard 查看训练过程，可以将看出 fscore 逐渐提升，最终训练集 fscore&gt;0.98，验证集 fscore~0.87</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8%E5%88%86%E5%89%B2%E6%A8%A1%E5%9E%8B%E5%88%86%E6%9E%90%E5%9C%BA%E6%99%AF%E5%8F%98%E5%8C%96-20250203084109.png" alt="利用分割模型分析场景变化-20250203084109"></p>]]></content>
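正文用 sahi 做重叠切图（slice_height=1024、overlap_ratio=0.5），其切片起点的计算思路可以用下面的纯 Python 草图说明。这是按重叠率推导的假设实现，仅演示坐标计算，sahi 内部逻辑可能有差异；末尾切片会对齐图片边界：

```python
def slice_coords(size, tile=1024, overlap_ratio=0.5):
    # 一维方向上按重叠率生成切片起点，末尾补一片对齐边界
    step = max(int(tile * (1.0 - overlap_ratio)), 1)
    starts = list(range(0, max(size - tile, 0) + 1, step))
    if starts[-1] + tile != size and size > tile:
        starts.append(size - tile)
    return starts

def slice_boxes(width, height, tile=1024, overlap_ratio=0.5):
    # 组合 x/y 两个方向的起点，得到每个小图的 (x1, y1, x2, y2)
    boxes = []
    for y in slice_coords(height, tile, overlap_ratio):
        for x in slice_coords(width, tile, overlap_ratio):
            boxes.append((x, y, x + tile, y + tile))
    return boxes
```

例如 2048x2048 的图在 0.5 重叠率下会切出 3x3=9 个小图；尺寸不是 step 整数倍时，最后一片从 size - tile 处开始，与边界对齐。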
    
    
    <summary type="html">&lt;p&gt;在一些场景中，需要分析图片在不同时刻的变化情况，比如居家的早上、晚上变化，卫星图片不同年份的变化，通过分析变化，掌握空间内事物的变化趋势，那么应该如何分析不同时刻下图片的变化呢？&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="B-视觉模型" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/B-%E8%A7%86%E8%A7%89%E6%A8%A1%E5%9E%8B/"/>
    
    <category term="1-基础视觉任务CNN" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/B-%E8%A7%86%E8%A7%89%E6%A8%A1%E5%9E%8B/1-%E5%9F%BA%E7%A1%80%E8%A7%86%E8%A7%89%E4%BB%BB%E5%8A%A1CNN/"/>
    
    <category term="语义分割" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/B-%E8%A7%86%E8%A7%89%E6%A8%A1%E5%9E%8B/1-%E5%9F%BA%E7%A1%80%E8%A7%86%E8%A7%89%E4%BB%BB%E5%8A%A1CNN/%E8%AF%AD%E4%B9%89%E5%88%86%E5%89%B2/"/>
    
    
  </entry>
  
  <entry>
    <title>利用 nlp 抽取 pdf 文件关键信息</title>
    <link href="https://www.shaogui.life/posts/1325313955.html"/>
    <id>https://www.shaogui.life/posts/1325313955.html</id>
    <published>2025-02-02T03:56:37.000Z</published>
    <updated>2025-02-02T04:53:48.406Z</updated>
    
    <content type="html"><![CDATA[<p>本文使用 nlp 快速检索出 pdf 文件中的关键信息，比如现有一堆合同文件，当你想搜索合同的甲乙方、金额时，不需要一个个文件打开看，只需通过该方法</p><span id="more"></span><p>本方法是利用 paddle 的两个工具，一个是 ocr，用于从 pdf 抽取文本，另一个是 nlp 模型，用于从文本中抽取关键信息，开始介绍前先安装环境</p><figure class="highlight bash"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">conda install paddlepaddle-gpu cudatoolkit</span><br><span class="line">python -m pip install opencv-contrib-python==4.4.0.46 paddleocr paddlenlp==2.4</span><br></pre></td></tr></tbody></table></figure><h3 id="提取图片上的文字"><a class="markdownIt-Anchor" href="#提取图片上的文字"></a> 提取图片上的文字</h3><p>本方法处理的文件是 pdf，这里利用 <code>PaddleOCR</code> 抽取其中的文本数据，该工具输入是图片，所以需要其他工具将 pdf 转为图片</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span 
class="line">36</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> paddleocr <span class="keyword">import</span> PaddleOCR, draw_ocr  <span class="comment"># type: ignore</span></span><br><span class="line"><span class="keyword">from</span> PIL <span class="keyword">import</span> Image</span><br><span class="line"><span class="keyword">import</span> cv2</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"></span><br><span class="line"><span class="comment"># paddleocr目前支持中英文、英文、法语、德语、韩语、日语等80个语种，可以通过修改lang参数进行切换</span></span><br><span class="line">ocr = PaddleOCR(use_angle_cls=<span class="literal">False</span>, lang=<span class="string">"ch"</span>, det_db_box_thresh=<span class="number">0.3</span>, use_dilation=<span class="literal">True</span>)</span><br><span class="line"><span class="comment"># 印章部分造成的文本遮盖，影响了文本识别结果，因此可以考虑通道提取，去除图片中的红色印章</span></span><br><span class="line"><span class="comment">#读入图像,三通道</span></span><br><span class="line">image=cv2.imread(<span class="string">"./test_img/hetong3.jpg"</span>,cv2.IMREAD_COLOR) <span class="comment">#timg.jpeg</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> time</span><br><span class="line"><span class="keyword">from</span> my_read_code_tools <span class="keyword">import</span> *</span><br><span class="line">start=time.time()</span><br><span class="line"><span class="comment"># 合同文本信息提取</span></span><br><span class="line"><span class="comment"># 合同照片的红色通道被分离，获得了一张相对更干净的图片，此时可以再次使用ppocr模型提取文本内容</span></span><br><span class="line">img_path = <span class="string">'./red_channel.jpg'</span> <span class="comment"># 使用红色通道图片排除印章影响</span></span><br><span class="line">result = ocr.ocr(img_path, cls=<span 
class="literal">False</span>)</span><br><span class="line"><span class="comment"># 可视化结果，不想可视化可以注释下面几行代码</span></span><br><span class="line">image = Image.<span class="built_in">open</span>(img_path).convert(<span class="string">'RGB'</span>)</span><br><span class="line">boxes = [line[<span class="number">0</span>] <span class="keyword">for</span> line <span class="keyword">in</span> result[<span class="number">0</span>]]</span><br><span class="line">txts = [line[<span class="number">1</span>][<span class="number">0</span>] <span class="keyword">for</span> line <span class="keyword">in</span> result[<span class="number">0</span>]]</span><br><span class="line">scores = [line[<span class="number">1</span>][<span class="number">1</span>] <span class="keyword">for</span> line <span class="keyword">in</span> result[<span class="number">0</span>]]</span><br><span class="line">im_show = draw_ocr(image, boxes, txts, scores, font_path=<span class="string">'./simfang.ttf'</span>)</span><br><span class="line">im_show = Image.fromarray(im_show)</span><br><span class="line">vis = np.array(im_show)</span><br><span class="line"><span class="comment"># im_show.show()</span></span><br><span class="line">show_images([vis])</span><br><span class="line"><span class="comment">#忽略检测框内容，提取完整的合同文本：</span></span><br><span class="line">txts = [line[<span class="number">1</span>][<span class="number">0</span>] <span class="keyword">for</span> line <span class="keyword">in</span> result[<span class="number">0</span>]]</span><br><span class="line">all_context = <span class="string">"\n"</span>.join(txts)</span><br><span class="line"><span class="built_in">print</span>(all_context)</span><br><span class="line"></span><br><span class="line">end=time.time()</span><br><span class="line"><span class="built_in">print</span>(<span class="string">'cost time:{}s'</span>.<span class="built_in">format</span>(end-start))</span><br></pre></td></tr></tbody></table></figure><pre><code>[2024/04/10 16:06:16] 
ppocr DEBUG: dt_boxes num : 24, elapsed : 1.3868777751922607[2024/04/10 16:06:16] ppocr DEBUG: rec_res num  : 24, elapsed : 0.5372314453125</code></pre><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8nlp%E6%8A%BD%E5%8F%96pdf%E6%96%87%E4%BB%B6%E5%85%B3%E9%94%AE%E4%BF%A1%E6%81%AF-20250202123801.png" alt="png"><br>甲方：佛山市禅城区住房城乡建设和水利局<br>乙方：交通银行股份有限公司佛山分行<br>丙方：佛山市禅城区盈恒置业有限公司<br>为加强商品房预售管理，规范商品房预售款使用行为，根据《广<br>东省商品房预售管理条例》（以下简称《条例》）和《佛山市商品房预<br>售款监督管理实施办法》（以下简称《办法》），经甲方、乙方和丙方三<br>方协商，就坐落于_佛山市禅城区佛罗路南侧、化纤路北侧，项目名<br>称为保利芳华苑 12 座，监控账号 446899991010003029262<br>的商品房屋预售款收存和划拨使用订立如下协议，共同遵守。<br>一、 权利<br>1、甲方负责贯彻实施《条例》和《办法》有关规定，行使商品房<br>预售款收存和使用的日常监督管理的权利。<br>2、乙方在为丙方办理预售款拨付时，应要求丙方出具经甲方审核<br>同意的《佛山市商品房预售款使用申请表》。<br>公有限<br>3、预售项目完成初始登记并达到购房人可单方办理转移登记条件<br>V<br>的，丙方可持有关证明文件向甲方申请办理专用账户解除监管手续。<br>经甲方核准同意的，在《佛山市商品房预售款监管专用账户取消监管<br>申请表》上加具同意的意见后，视同本协议取消。丙方凭《佛山市商<br>cost time:2.178391218185425s</p><h3 id="关键信息抽取"><a class="markdownIt-Anchor" href="#关键信息抽取"></a> 关键信息抽取</h3><p>使用 <code>paddlenlp</code> 从提取的文本中抽取关键信息，比如下面抽取 [“甲方”,“乙方”,“总价”,“大写”,“小写”,“项目”] 等信息</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> paddlenlp <span class="keyword">import</span> Taskflow <span class="comment"># type: ignore</span></span><br><span class="line"></span><br><span class="line">schema = [<span class="string">"甲方"</span>,<span class="string">"乙方"</span>,<span class="string">"总价"</span>,<span class="string">"大写"</span>,<span 
class="string">"小写"</span>,<span class="string">"项目"</span>]</span><br><span class="line">ie = Taskflow(<span class="string">'information_extraction'</span>, schema=schema)</span><br><span class="line">ie.set_schema(schema)</span><br><span class="line"></span><br><span class="line">start=time.time()</span><br><span class="line"></span><br><span class="line">result = ie(all_context)</span><br><span class="line"><span class="built_in">print</span>(result)</span><br><span class="line"></span><br><span class="line">end=time.time()</span><br><span class="line"><span class="built_in">print</span>(<span class="string">'cost time:{}s'</span>.<span class="built_in">format</span>(end-start))</span><br></pre></td></tr></tbody></table></figure><pre><code>[{'甲方': [{'text': '佛山市禅城区住房城乡建设和水利局', 'start': 3, 'end': 19, 'probability': 0.8409922679564374}], '乙方': [{'text': '交通银行股份有限公司佛山分行', 'start': 23, 'end': 37, 'probability': 0.8475471161905048}], '项目': [{'text': '保利芳华苑12座', 'start': 183, 'end': 191, 'probability': 0.5095629052180897}]}]cost time:0.20002102851867676s</code></pre><p>使用 nuitka 将以上过程打包为 exe，可以自定义处理的文档及关键信息</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8nlp%E6%8A%BD%E5%8F%96pdf%E6%96%87%E4%BB%B6%E5%85%B3%E9%94%AE%E4%BF%A1%E6%81%AF-20250202123803.png" alt=""></p><p>运行时，后台日志如下：</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8nlp%E6%8A%BD%E5%8F%96pdf%E6%96%87%E4%BB%B6%E5%85%B3%E9%94%AE%E4%BF%A1%E6%81%AF-20250202123803-1.png" alt="alt text"></p><p><strong>结果展示</strong></p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%88%A9%E7%94%A8nlp%E6%8A%BD%E5%8F%96pdf%E6%96%87%E4%BB%B6%E5%85%B3%E9%94%AE%E4%BF%A1%E6%81%AF-20250202123804.png" alt="alt text"></p>]]></content>
    
    
    <summary type="html">&lt;p&gt;本文使用 nlp 快速检索出 pdf 文件中的关键信息，比如现有一堆合同文件，当你想搜索合同的甲乙方、金额时，不需要一个个文件打开看，只需通过该方法&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="C-语言模型" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/C-%E8%AF%AD%E8%A8%80%E6%A8%A1%E5%9E%8B/"/>
    
    
  </entry>
  
  <entry>
    <title>基于 Torch-TensorRT 量化模型</title>
    <link href="https://www.shaogui.life/posts/743849622.html"/>
    <id>https://www.shaogui.life/posts/743849622.html</id>
    <published>2025-02-01T09:08:50.000Z</published>
    <updated>2025-02-02T02:52:38.869Z</updated>
    
    <content type="html"><![CDATA[<p>Torch TensorRT 是 PyTorch 与 NVIDIA TensorRT 的新集成，它用一行代码加速推理</p><span id="more"></span><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172119.png" alt="基于Torch-TensorRT量化模型-20250123172119"></p><p>Torch-TensorRT 是 PyTorch/TorchScript/FX 的编译器。与 PyTorch 的即时 （JIT） 编译器不同，Torch-TensorRT 是一个提前 （AOT） 编译器，这意味着在部署 TorchScript 代码之前，您需要执行一个明确的编译步骤，将标准 TorchScript 或 FX 程序转换为面向 TensorRT 引擎的模块</p><h3 id="使用-torch-tensorrt-编译-resnet50"><a class="markdownIt-Anchor" href="#使用-torch-tensorrt-编译-resnet50"></a> 使用 Torch-TensorRT 编译 ResNet50</h3><p><strong>1. 加载并测试模型</strong></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"></span><br><span class="line">torch.hub._validate_not_a_forked_repo=<span class="keyword">lambda</span> a,b,c: <span class="literal">True</span></span><br><span class="line"></span><br><span class="line">resnet50_model = torch.hub.load(<span class="string">'pytorch/vision:v0.10.0'</span>, <span class="string">'resnet50'</span>, pretrained=<span class="literal">True</span>)</span><br><span class="line">resnet50_model.<span class="built_in">eval</span>()</span><br><span class="line"></span><br><span class="line"><span class="comment"># Model benchmark without Torch-TensorRT</span></span><br><span class="line">model = resnet50_model.<span 
class="built_in">eval</span>().to(<span class="string">"cuda"</span>)</span><br><span class="line">benchmark(model, input_shape=(<span class="number">128</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172120.png" alt="基于Torch-TensorRT量化模型-20250123172120"></p><p><strong>2. 编译模型 - fp32</strong></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch_tensorrt</span><br><span class="line"></span><br><span class="line"><span class="comment"># The compiled module will have precision as specified by "op_precision".</span></span><br><span class="line"><span class="comment"># Here, it will have FP32 precision.</span></span><br><span class="line">trt_model_fp32 = torch_tensorrt.<span class="built_in">compile</span>(model, inputs = [torch_tensorrt.Input((<span class="number">128</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>), dtype=torch.float32)],</span><br><span class="line">    enabled_precisions = torch.float32, <span class="comment"># Run with FP32</span></span><br><span class="line">    workspace_size = <span class="number">1</span> &lt;&lt; <span class="number">22</span></span><br><span class="line">)</span><br><span class="line"></span><br><span 
class="line"><span class="comment"># Obtain the average time taken by a batch of input</span></span><br><span class="line">benchmark(trt_model_fp32, input_shape=(<span class="number">128</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172120-1.png" alt="基于Torch-TensorRT量化模型-20250123172120-1"></p><p><strong>3. 编译模型 - fp16</strong></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch_tensorrt</span><br><span class="line"></span><br><span class="line"><span class="comment"># The compiled module will have precision as specified by "op_precision".</span></span><br><span class="line"><span class="comment"># Here, it will have FP32 precision.</span></span><br><span class="line">trt_model_fp32 = torch_tensorrt.<span class="built_in">compile</span>(model, inputs = [torch_tensorrt.Input((<span class="number">128</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>), dtype=torch.float32)],</span><br><span class="line">    enabled_precisions = , dtype=torch.half, <span class="comment"># Run with FP32</span></span><br><span class="line">    workspace_size = <span class="number">1</span> &lt;&lt; <span class="number">22</span></span><br><span class="line">)</span><br><span 
class="line"></span><br><span class="line"><span class="comment"># Obtain the average time taken by a batch of input</span></span><br><span class="line">benchmark(trt_model_fp32, input_shape=(<span class="number">128</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>), dtype=<span class="string">'fp16'</span>, nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172121.png" alt="基于Torch-TensorRT量化模型-20250123172121"></p><h3 id="使用-torch-tensorrt-编译-torchscript模型"><a class="markdownIt-Anchor" href="#使用-torch-tensorrt-编译-torchscript模型"></a> 使用 Torch-TensorRT 编译 TorchScript 模型</h3><p><strong>1. 加载并测试模型</strong></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">model = LeNet()</span><br><span class="line">model.to(<span class="string">"cuda"</span>).<span class="built_in">eval</span>()</span><br><span class="line">benchmark(model)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172121-1.png" alt="基于Torch-TensorRT量化模型-20250123172121-1"></p><p><strong>2. 
生成 trace 模型，并测试速度</strong></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">traced_model = torch.jit.trace(model, torch.empty([<span class="number">1</span>,<span class="number">1</span>,<span class="number">32</span>,<span class="number">32</span>]).to(<span class="string">"cuda"</span>))</span><br><span class="line">benchmark(traced_model)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172121-2.png" alt="基于Torch-TensorRT量化模型-20250123172121-2"></p><p><strong>3. 生成 script 模型，并测试速度</strong></p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">model = LeNet().to(<span class="string">"cuda"</span>).<span class="built_in">eval</span>()</span><br><span class="line">script_model = torch.jit.script(model)</span><br><span class="line">benchmark(script_model)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172122.png" alt="基于Torch-TensorRT量化模型-20250123172122"></p><p>4. 
使用 Torch-TensorRT 编译 trace 模型</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch_tensorrt</span><br><span class="line"></span><br><span class="line"><span class="comment"># We use a batch-size of 1024, and half precision</span></span><br><span class="line">trt_ts_module = torch_tensorrt.<span class="built_in">compile</span>(traced_model, inputs=[torch_tensorrt.Input(</span><br><span class="line">            min_shape=[<span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, <span class="number">32</span>],</span><br><span class="line">            opt_shape=[<span class="number">1024</span>, <span class="number">1</span>, <span class="number">33</span>, <span class="number">33</span>],</span><br><span class="line">            max_shape=[<span class="number">1024</span>, <span class="number">1</span>, <span class="number">34</span>, <span class="number">34</span>],</span><br><span class="line">            dtype=torch.half</span><br><span class="line">            )], </span><br><span class="line">            enabled_precisions = {torch.half})</span><br><span class="line"></span><br><span class="line">input_data = torch.randn((<span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, 
<span class="number">32</span>))</span><br><span class="line">input_data = input_data.half().to(<span class="string">"cuda"</span>)</span><br><span class="line"></span><br><span class="line">input_data = input_data.half()</span><br><span class="line">result = trt_ts_module(input_data)</span><br><span class="line">torch.jit.save(trt_ts_module, <span class="string">"trt_ts_module.ts"</span>)</span><br><span class="line"></span><br><span class="line">benchmark(trt_ts_module, input_shape=(<span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, <span class="number">32</span>), dtype=<span class="string">"fp16"</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172122-1.png" alt="基于Torch-TensorRT量化模型-20250123172122-1"></p><p>5. 使用 Torch-TensorRT 编译 script 模型</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch_tensorrt</span><br><span class="line"></span><br><span class="line">trt_script_module = torch_tensorrt.<span class="built_in">compile</span>(script_model, inputs = [torch_tensorrt.Input(</span><br><span class="line">            min_shape=[<span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, <span class="number">32</span>],</span><br><span class="line">            opt_shape=[<span class="number">1024</span>, <span class="number">1</span>, <span class="number">33</span>, <span class="number">33</span>],</span><br><span class="line">            max_shape=[<span class="number">1024</span>, <span class="number">1</span>, <span class="number">34</span>, <span class="number">34</span>],</span><br><span class="line">            dtype=torch.half</span><br><span class="line">            )],</span><br><span class="line">            enabled_precisions={torch.half})</span><br><span class="line"></span><br><span class="line">input_data = torch.randn((<span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, <span class="number">32</span>))</span><br><span class="line">input_data = input_data.half().to(<span class="string">"cuda"</span>)</span><br><span class="line"></span><br><span class="line">input_data = input_data.half()</span><br><span class="line">result = trt_script_module(input_data)</span><br><span class="line">torch.jit.save(trt_script_module, <span class="string">"trt_script_module.ts"</span>)</span><br><span class="line"></span><br><span class="line">benchmark(trt_script_module, input_shape=(<span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, <span class="number">32</span>), dtype=<span class="string">"fp16"</span>)</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172123.png" alt="基于Torch-TensorRT量化模型-20250123172123"></p><h3 id="使用torch-tensorrt进行ptq量化"><a class="markdownIt-Anchor" href="#使用torch-tensorrt进行ptq量化"></a> 使用 Torch-TensorRT 进行 PTQ 量化</h3><p>1. 
构建模型，并使用 Torch-TensorRT 转换，但是未进行量化</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Exporting to TorchScript</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    data = <span class="built_in">iter</span>(val_dataloader)</span><br><span class="line">    images, _ = <span class="built_in">next</span>(data)</span><br><span class="line">    jit_model = torch.jit.trace(model, images.to(<span class="string">"cuda"</span>))</span><br><span class="line">    torch.jit.save(jit_model, <span class="string">"mobilenetv2_base.jit.pt"</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment">#Loading the Torchscript model and compiling it into a TensorRT model</span></span><br><span class="line">baseline_model = torch.jit.load(<span class="string">"mobilenetv2_base.jit.pt"</span>).<span class="built_in">eval</span>()</span><br><span class="line">compile_spec = {<span class="string">"inputs"</span>: [torch_tensorrt.Input([<span class="number">64</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>])]</span><br><span class="line">               , <span class="string">"enabled_precisions"</span>: torch.<span class="built_in">float</span></span><br><span class="line">               }</span><br><span class="line">trt_base = torch_tensorrt.<span class="built_in">compile</span>(baseline_model, **compile_spec)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Evaluate and benchmark the performance of the baseline TRT model (TRT FP32 Model)</span></span><br><span class="line">test_loss, test_acc = evaluate(trt_base, val_dataloader, criterion, <span class="number">0</span>)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">"Mobilenetv2 TRT Baseline accuracy: {:.2f}%"</span>.<span class="built_in">format</span>(<span class="number">100</span> * test_acc))</span><br><span class="line"></span><br><span class="line">benchmark(trt_base, input_shape=(<span class="number">64</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>))</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172123-1.png" alt="基于Torch-TensorRT量化模型-20250123172123-1"></p><p>2. 
使用 torch-tensorRT 进行 PTQ 量化</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(calib_dataloader,</span><br><span class="line">                                              use_cache=<span class="literal">False</span>,</span><br><span class="line">                                              algo_type=torch_tensorrt.ptq.CalibrationAlgo.MINMAX_CALIBRATION,</span><br><span class="line">                                              device=torch.device(<span class="string">'cuda:0'</span>))</span><br><span class="line"></span><br><span class="line">compile_spec = {</span><br><span class="line">         <span class="string">"inputs"</span>: [torch_tensorrt.Input([<span class="number">64</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>])],</span><br><span class="line">         <span class="string">"enabled_precisions"</span>: torch.int8,</span><br><span class="line">         <span class="string">"calibrator"</span>: calibrator,</span><br><span class="line">        <span class="string">"truncate_long_and_double"</span>: <span class="literal">True</span></span><br><span class="line">         </span><br><span class="line">     }</span><br><span class="line">trt_ptq = torch_tensorrt.<span 
class="built_in">compile</span>(baseline_model, **compile_spec)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Evaluate the PTQ model</span></span><br><span class="line">test_loss, test_acc = evaluate(trt_ptq, val_dataloader, criterion, <span class="number">0</span>)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">"Mobilenetv2 PTQ accuracy: {:.2f}%"</span>.<span class="built_in">format</span>(<span class="number">100</span> * test_acc))</span><br><span class="line"></span><br><span class="line">benchmark(trt_ptq, input_shape=(<span class="number">64</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>))</span><br></pre></td></tr></tbody></table></figure><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172123-2.png" alt="基于Torch-TensorRT量化模型-20250123172123-2"></p><h3 id="使用torch-tensorrt进行qat量化"><a class="markdownIt-Anchor" href="#使用torch-tensorrt进行qat量化"></a> 使用 Torch-TensorRT 进行 QAT 量化</h3><p>1. 
定义并加载模型权重</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line">quant_modules.initialize()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义并加载模型权重</span></span><br><span class="line"><span class="comment"># All the regular conv, FC layers will be converted to their quantized counterparts due to quant_modules.initialize()</span></span><br><span class="line">feature_extract = <span class="literal">False</span></span><br><span class="line">q_model = models.mobilenet_v2(pretrained=<span class="literal">True</span>)</span><br><span class="line">set_parameter_requires_grad(q_model, feature_extract)</span><br><span class="line">q_model.classifier[<span class="number">1</span>] = nn.Linear(<span class="number">1280</span>, <span class="number">10</span>)</span><br><span class="line">q_model = q_model.cuda()</span><br><span class="line"></span><br><span class="line"><span class="comment"># mobilenetv2_base_ckpt is the checkpoint generated from Step 2 : Training a baseline Mobilenetv2 model.</span></span><br><span class="line">ckpt = torch.load(<span class="string">"./mobilenetv2_base_ckpt"</span>)</span><br><span 
class="line">modified_state_dict={}</span><br><span class="line"><span class="keyword">for</span> key, val <span class="keyword">in</span> ckpt[<span class="string">"model_state_dict"</span>].items():</span><br><span class="line">    <span class="comment"># Remove 'module.' from the key names</span></span><br><span class="line">    <span class="keyword">if</span> key.startswith(<span class="string">'module'</span>):</span><br><span class="line">        modified_state_dict[key[<span class="number">7</span>:]] = val</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        modified_state_dict[key] = val</span><br><span class="line"></span><br><span class="line"><span class="comment"># Load the pre-trained checkpoint</span></span><br><span class="line">q_model.load_state_dict(modified_state_dict)</span><br><span class="line">optimizer.load_state_dict(ckpt[<span class="string">"opt_state_dict"</span>])</span><br></pre></td></tr></tbody></table></figure><p>2. 
定义校准规则并校准，这里使用 max 校准</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 定义校准规则并校准，这里使用max校准</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">compute_amax</span>(<span class="params">model, **kwargs</span>):</span><br><span class="line">    <span class="comment"># Load calib result</span></span><br><span class="line">    <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">        <span 
class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">            <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                <span class="keyword">if</span> <span class="built_in">isinstance</span>(module._calibrator, calib.MaxCalibrator):</span><br><span class="line">                    module.load_calib_amax()</span><br><span class="line">                <span class="keyword">else</span>:</span><br><span class="line">                    module.load_calib_amax(**kwargs)</span><br><span class="line">    model.cuda()</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">collect_stats</span>(<span class="params">model, data_loader, num_batches</span>):</span><br><span class="line">    <span class="string">"""Feed data to the network and collect statistics"""</span></span><br><span class="line">    <span class="comment"># Enable calibrators</span></span><br><span class="line">    <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">        <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">            <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                module.disable_quant()</span><br><span class="line">                module.enable_calib()</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                module.disable()</span><br><span class="line"></span><br><span class="line">    <span class="comment"># Feed data to the network for 
collecting stats</span></span><br><span class="line">    <span class="keyword">for</span> i, (image, _) <span class="keyword">in</span> tqdm(<span class="built_in">enumerate</span>(data_loader), total=num_batches):</span><br><span class="line">        model(image.cuda())</span><br><span class="line">        <span class="keyword">if</span> i &gt;= num_batches:</span><br><span class="line">            <span class="keyword">break</span></span><br><span class="line"></span><br><span class="line">    <span class="comment"># Disable calibrators</span></span><br><span class="line">    <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">        <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">            <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                module.enable_quant()</span><br><span class="line">                module.disable_calib()</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                module.enable()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># Calibrate the model using max calibration</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    collect_stats(q_model, train_dataloader, num_batches=<span class="number">32</span>)</span><br><span class="line">    compute_amax(q_model, method=<span class="string">"max"</span>)</span><br></pre></td></tr></tbody></table></figure><p>3. 
微调 QAT 模型 2 个 epoch</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 微调QAT模型2个epoch</span></span><br><span class="line">num_epochs=<span class="number">2</span></span><br><span class="line">lr = <span class="number">0.001</span></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">'Epoch: [%5d / %5d] LR: %f'</span> % (epoch + <span class="number">1</span>, num_epochs, lr))</span><br><span class="line"></span><br><span class="line">    train(q_model, train_dataloader, criterion, optimizer, epoch)</span><br><span class="line">    test_loss, test_acc = evaluate(q_model, val_dataloader, criterion, epoch)</span><br><span class="line"></span><br><span class="line">    <span class="built_in">print</span>(<span class="string">"Test Loss: {:.5f} Test Acc: {:.2f}%"</span>.<span class="built_in">format</span>(test_loss, <span class="number">100</span> * test_acc))</span><br><span class="line">    </span><br><span class="line">save_checkpoint({<span class="string">'epoch'</span>: epoch + <span class="number">1</span>,</span><br><span class="line">                 <span class="string">'model_state_dict'</span>: q_model.state_dict(),</span><br><span class="line">       
          <span class="string">'acc'</span>: test_acc,</span><br><span class="line">                 <span class="string">'opt_state_dict'</span>: optimizer.state_dict()</span><br><span class="line">                },</span><br><span class="line">                ckpt_path=<span class="string">"mobilenetv2_qat_ckpt"</span>)</span><br></pre></td></tr></tbody></table></figure><p>4. 导出 QAT 模型，得到 Torchscript 模型</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 导出QAT模型，得到Torchscript模型</span></span><br><span class="line">quant_nn.TensorQuantizer.use_fb_fake_quant = <span class="literal">True</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    data = <span class="built_in">iter</span>(val_dataloader)</span><br><span class="line">    images, _ = <span class="built_in">next</span>(data)</span><br><span class="line">    jit_model = torch.jit.trace(q_model, images.to(<span class="string">"cuda"</span>))</span><br><span class="line">    torch.jit.save(jit_model, <span class="string">"mobilenetv2_qat.jit.pt"</span>)</span><br></pre></td></tr></tbody></table></figure><p>5. 
加载 Torchscript 模型并编译为 TensorRT 模型</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#加载Torchscript模型并编译为TensorRT模型</span></span><br><span class="line">qat_model = torch.jit.load(<span class="string">"mobilenetv2_qat.jit.pt"</span>).<span class="built_in">eval</span>()</span><br><span class="line">compile_spec = {<span class="string">"inputs"</span>: [torch_tensorrt.Input([<span class="number">64</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>])],</span><br><span class="line">                <span class="string">"enabled_precisions"</span>: torch.int8</span><br><span class="line">               }</span><br><span class="line">trt_mod = torch_tensorrt.<span class="built_in">compile</span>(qat_model, **compile_spec)</span><br><span class="line"></span><br><span class="line"><span class="comment">#Evaluate and benchmark the performance of the QAT-TRT model (TRT INT8)</span></span><br><span class="line">test_loss, test_acc = evaluate(trt_mod, val_dataloader, criterion, <span class="number">0</span>)</span><br><span class="line"><span class="built_in">print</span>(<span class="string">"Mobilenetv2 QAT accuracy using TensorRT: {:.2f}%"</span>.<span class="built_in">format</span>(<span class="number">100</span> * test_acc))</span><br><span class="line">benchmark(trt_mod, input_shape=(<span class="number">64</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>))</span><br></pre></td></tr></tbody></table></figure><p><img 
data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8ETorch-TensorRT%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123172124.png" alt="基于Torch-TensorRT量化模型-20250123172124"></p>]]></content>
    
    
    <summary type="html">&lt;p&gt;Torch TensorRT 是 PyTorch 与 NVIDIA TensorRT 的新集成，它用一行代码加速推理&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="D-深度学习部署" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/"/>
    
    
  </entry>
  
  <entry>
    <title>Hexo 系列 8 - 网站维护</title>
    <link href="https://www.shaogui.life/posts/3331678657.html"/>
    <id>https://www.shaogui.life/posts/3331678657.html</id>
    <published>2025-01-31T00:57:36.000Z</published>
    <updated>2025-02-02T02:52:46.503Z</updated>
    
    <content type="html"><![CDATA[<p>网站通过备案完成后，相当于已经正式发布到互联网，后续还需要做好维护与安全设置</p><span id="more"></span><h3 id="提升网站的seo"><a class="markdownIt-Anchor" href="#提升网站的seo"></a> 提升网站的 SEO</h3><p>SEO（Search Engine Optimization，搜索引擎优化）是一种通过优化网站内容和结构，提高网站在搜索引擎自然搜索结果中排名的技术和策略。SEO 的目的是增加网站的流量、提升用户体验</p><p>next 主题支持在文件 <code>_config.next.yml</code> 中配置搜索引擎验证</p><figure class="highlight yaml"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Google Webmaster tools verification.</span></span><br><span class="line"><span class="comment"># See: https://developers.google.com/search</span></span><br><span class="line"><span class="attr">google_site_verification:</span> </span><br><span class="line"></span><br><span class="line"><span class="comment"># Bing Webmaster tools verification.</span></span><br><span class="line"><span class="comment"># See: https://www.bing.com/webmasters</span></span><br><span class="line"><span class="attr">bing_site_verification:</span> </span><br><span class="line"></span><br><span class="line"><span class="comment"># Yandex Webmaster tools verification.</span></span><br><span class="line"><span class="comment"># See: https://webmaster.yandex.ru</span></span><br><span class="line"><span class="attr">yandex_site_verification:</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Baidu Webmaster tools verification.</span></span><br><span class="line"><span class="comment"># See: 
https://ziyuan.baidu.com/site</span></span><br><span class="line"><span class="attr">baidu_site_verification:</span> </span><br></pre></td></tr></tbody></table></figure><p>分别访问以上链接，注册帐号并获取 site_verification，填入上述位置，重新部署提交后，到对应网站测试是否成功，以下是 Bing 的提交界面</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131104456.png" alt="Hexo学习8-网站维护-20250131104456"></p><p>提交完成后，到对应的搜索引擎下检索自己的网站，检索成功即设置成功</p><h3 id="三-安全设置"><a class="markdownIt-Anchor" href="#三-安全设置"></a> 三、安全设置 <sup class="footnote-ref"><a href="#fn1" id="fnref1">[1]</a></sup></h3><p>以前将 hexo 发布到 GitHub Pages，相当于服务器在 GitHub，带宽和流量都不用自己担心，现在这些都在自己掌控下，服务器和域名这类一次性付费的项目不用担心，但是带宽和流量是动态计费的，这就需要设置预警门槛，别把自己钱包弄空了</p><p>目前动态计费项目有：</p><ul><li>COS 部分：因为 COS 上主要存储着图片，主要计费点是访问下行流量及上传图片存储，主要风险是盗刷</li><li> CDN 部分：通过加速网站内容的分发，提高用户访问速度</li></ul><h4 id="跨域访问-cors-设置"><a class="markdownIt-Anchor" href="#跨域访问-cors-设置"></a> 跨域访问 CORS 设置</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131105718.png" alt="Hexo学习8-网站维护-20250131105718"></p><p><strong>防盗链设置</strong><br><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131105747.png" alt="Hexo学习8-网站维护-20250131105747"></p><h4 id="自定义cdn加速域名付费项目"><a class="markdownIt-Anchor" href="#自定义cdn加速域名付费项目"></a> 自定义 CDN 加速域名 (付费项目)</h4><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131105928.png" alt="Hexo学习8-网站维护-20250131105928"></p><h3 id="cdn配置部分"><a class="markdownIt-Anchor" href="#cdn配置部分"></a> CDN 配置部分</h3><p><strong>设置防盗链和访问频率</strong><br><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131111149.png" 
alt="Hexo学习8-网站维护-20250131111149"></p><p><strong>用量封顶配置</strong><br><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131111328.png" alt="Hexo学习8-网站维护-20250131111328"></p><p><strong>购买 CDN 流量包</strong><br><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Hexo%E5%AD%A6%E4%B9%A08-%E7%BD%91%E7%AB%99%E7%BB%B4%E6%8A%A4-20250131111618.png" alt="Hexo学习8-网站维护-20250131111618"></p><p>参考：</p><hr class="footnotes-sep"><section class="footnotes"><ol class="footnotes-list"><li id="fn1" class="footnote-item"><p><a href="https://www.ruletree.club/archives/3327/">腾讯云 CDN 和 COS 防止高额欠费策略，配置教程 - 规则之树</a> <a href="#fnref1" class="footnote-backref">↩︎</a></p></li></ol></section>]]></content>
    
    
    <summary type="html">&lt;p&gt;网站通过备案完成后，相当与已经被正式发布到互联网，后续维护与安全需要设置&lt;/p&gt;</summary>
    
    
    
    <category term="4-系统软件" scheme="https://www.shaogui.life/categories/4-%E7%B3%BB%E7%BB%9F%E8%BD%AF%E4%BB%B6/"/>
    
    <category term="B-Hexo" scheme="https://www.shaogui.life/categories/4-%E7%B3%BB%E7%BB%9F%E8%BD%AF%E4%BB%B6/B-Hexo/"/>
    
    
  </entry>
  
  <entry>
    <title>Hexo 系列 7 - 网站备案</title>
    <link href="https://www.shaogui.life/posts/2572137130.html"/>
    <id>https://www.shaogui.life/posts/2572137130.html</id>
    <published>2025-01-31T00:56:41.000Z</published>
    <updated>2025-02-02T02:52:46.485Z</updated>
    
    <content type="html"><![CDATA[<p>网站备案在已经申请到域名及服务器后，提交个人资料到相关部门备案，无需额外费用</p><span id="more"></span><p>备案包含两个部门的备案，即 ICP 备案、公安备案，两者区别如下</p><table><thead><tr><th>备案类型</th><th>目的</th><th>主管部门</th></tr></thead><tbody><tr><td> ICP 备案</td><td>规范互联网信息服务活动，促进互联网信息服务健康有序发展，保障互联网信息服务提供者和用户的合法权益</td><td>工信部</td></tr><tr><td>公安备案</td><td>加强互联网信息安全管理，保障国家安全和社会公共利益，便于公安机关网安部门对互联网信息服务活动进行监管，提高互联网信息服务提供者的安全意识和防范能力</td><td>各地公安局相关部门</td></tr></tbody></table><p>由于本人的域名、服务器都在腾讯云，所以全程参考腾讯云提供的文档完成备案，两个备案的操作手册：</p><ul><li><a href="https://cloud.tencent.com/document/product/243/97668">ICP 备案 首次备案 - ICP 备案操作指引（PC 端）- 文档中心 - 腾讯云</a></li><li><a href="https://cloud.tencent.com/document/product/243/19142">ICP 备案 公安备案流程 - 公安备案与经营性备案 - 文档中心 - 腾讯云</a></li></ul><h3 id="icp备案"><a class="markdownIt-Anchor" href="#icp备案"></a> ICP 备案</h3><p>ICP 备案分为两个阶段，第一阶段在腾讯云提交审核资料，腾讯云完成初步审核并与本人确认后，再代为提交到工信部；第二阶段是在工信部网站完成的</p><h3 id="公安备案"><a class="markdownIt-Anchor" href="#公安备案"></a> 公安备案</h3><p>ICP 审核完成后，可以申请公安备案，这个过程直接在<a href="https://beian.mps.gov.cn/#/">全国互联网安全管理平台</a>完成</p><p>首先注册个人帐号，然后申请主体，完成网站备案</p><h3 id="hexo配置备案信息"><a class="markdownIt-Anchor" href="#hexo配置备案信息"></a> hexo 配置备案信息</h3><p>完成以上两个备案后，在 hexo 下找到<code>_config.next.yml</code>，设置以下字段：</p><figure class="highlight yaml"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Beian ICP and gongan information for Chinese users. 
See: https://beian.miit.gov.cn, https://beian.mps.gov.cn</span></span><br><span class="line"><span class="attr">beian:</span></span><br><span class="line">  <span class="attr">enable:</span> <span class="literal">true</span></span><br><span class="line">  <span class="attr">icp:</span> <span class="string">粤ICP备</span> <span class="number">2025368593</span><span class="string">号-1</span></span><br><span class="line">  <span class="comment"># The digit in the num of gongan beian.</span></span><br><span class="line">  <span class="attr">gongan_id:</span> <span class="number">44011302004761</span></span><br><span class="line">  <span class="comment"># The full num of gongan beian.</span></span><br><span class="line">  <span class="attr">gongan_num:</span> <span class="string">粤公网安备</span> <span class="number">44011302004761</span><span class="string">号</span></span><br><span class="line">  <span class="comment"># The icon for gongan beian. Login and See: https://beian.mps.gov.cn/web/business/businessHome/website</span></span><br><span class="line">  <span class="attr">gongan_icon_url:</span> <span class="string">/images/beian.png</span></span><br></pre></td></tr></tbody></table></figure><p>重新部署网站到服务器后，即可看到网站备案信息，此时可通过域名访问网站</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;网站备案在已经申请到域名及服务器后，提交个人资料到相关部门备案，无需额外费用&lt;/p&gt;</summary>
    
    
    
    <category term="4-系统软件" scheme="https://www.shaogui.life/categories/4-%E7%B3%BB%E7%BB%9F%E8%BD%AF%E4%BB%B6/"/>
    
    <category term="B-Hexo" scheme="https://www.shaogui.life/categories/4-%E7%B3%BB%E7%BB%9F%E8%BD%AF%E4%BB%B6/B-Hexo/"/>
    
    
  </entry>
  
  <entry>
    <title>基于 tensorrt 量化模型</title>
    <link href="https://www.shaogui.life/posts/3790857351.html"/>
    <id>https://www.shaogui.life/posts/3790857351.html</id>
    <published>2025-01-28T08:10:32.000Z</published>
    <updated>2025-02-02T03:07:24.244Z</updated>
    
    <content type="html"><![CDATA[<p>本文讨论 TensorRT 的量化原理</p><span id="more"></span><h3 id="tensorrt的量化原理"><a class="markdownIt-Anchor" href="#tensorrt的量化原理"></a> TensorRT 的量化原理？</h3><ul><li><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E5%AF%B9%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E9%87%8F%E5%8C%96-20250123161943.png" alt="基于tensorrt对模型进行量化-20250123161943"></li><li>TensorRT 支持量化浮点数，可以显著提高算术吞吐量，同时降低存储要求和内存带宽。在量化浮点张量时，TensorRT 需要知道它的动态范围 —— 即需要表示什么范围的值。动态范围信息可由构建器根据代表性输入数据计算（这称为校准 <code>calibration</code>）。或者，您可以在框架中执行量化感知训练，并将模型与必要的动态范围信息一起导入到 TensorRT，分别对应以下的 PTQ 量化、QAT 量化</li><li> TensorRT 支持 2 种量化，一种是 INT8 推理，也就是 PTQ 量化，另一种是带有 Q/DQ 算子以指定量化参数的量化（QAT 量化）</li><li>INT8 量化需要进行校准，以便确定激活值的量化范围；而带有 Q/DQ 操作的模型则将量化参数自带在 Q/DQ 算子中，可以直接量化</li></ul><h3 id="tensort如何进行ptq量化"><a class="markdownIt-Anchor" href="#tensort如何进行ptq量化"></a> <strong> TensorRT 如何进行 PTQ 量化？</strong></h3><ul><li><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E5%AF%B9%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E9%87%8F%E5%8C%96-20250123161944.png" alt="基于tensorrt对模型进行量化-20250123161944"></li><li>PTQ 量化时，量化激活值阶段需要进行校准，可以使用有代表性的数据集进行校准，也可以外部输入校准表进行校准</li><li>所需的输入数据量取决于应用，但实验表明，约 500 幅图像足以校准 ImageNet 分类网络。</li><li>官方提供配套工具 <a href="https://github.com/NVIDIA/TensorRT/tree/main/tools/pytorch-quantization">pytorch-quantization</a> 进行 PTQ 量化</li></ul><h3 id="tensort如何使用经过qat的onnx模型"><a class="markdownIt-Anchor" href="#tensort如何使用经过qat的onnx模型"></a> TensorRT 如何使用经过 QAT 的 ONNX 模型？</h3><ul><li><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E5%AF%B9%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E9%87%8F%E5%8C%96-20250123161944-1.png" alt="基于tensorrt对模型进行量化-20250123161944-1"></li><li>TensorRT 8 可以直接加载通过 QAT 量化后且导出为 ONNX 的模型，官方提供配套工具 <a href="https://github.com/NVIDIA/TensorRT/tree/main/tools/pytorch-quantization">pytorch-quantization</a> 进行 QAT 量化</li></ul><h3 
id="tensort如何使用使用pytorch模型进行量化推理"><a class="markdownIt-Anchor" href="#tensort如何使用使用pytorch模型进行量化推理"></a> <strong> TensorRT 如何使用 PyTorch 模型进行量化推理</strong></h3><ul><li><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E5%AF%B9%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E9%87%8F%E5%8C%96-20250123161944-2.png" alt="基于tensorrt对模型进行量化-20250123161944-2"></li><li>已经训练好的模型，经过 Quantization Toolkit 进行模型量化</li><li>使用 <code>torch.onnx.export</code> 导出模型为 ONNX</li><li>TensorRT 根据 ONNX 中 Q/DQ 的设置量化模型，并结合通用优化操作，生成推理引擎，执行推理</li></ul><h3 id="逐张量per-tensor量化与逐通道per-channel量化"><a class="markdownIt-Anchor" href="#逐张量per-tensor量化与逐通道per-channel量化"></a> <strong>逐张量 (Per-tensor) 量化与逐通道 (Per-channel) 量化？</strong></h3><ul><li><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E5%AF%B9%E6%A8%A1%E5%9E%8B%E8%BF%9B%E8%A1%8C%E9%87%8F%E5%8C%96-20250123161945.png" alt="基于tensorrt对模型进行量化-20250123161945"></li><li>逐张量 (Per-tensor) 量化：使用单个尺度值（标量）来缩放整个张量</li><li>逐通道 (Per-channel) 量化：常用于卷积神经网络，尺度张量沿给定轴（如通道轴）广播</li></ul><h3 id="tensorrt的显式量化"><a class="markdownIt-Anchor" href="#tensorrt的显式量化"></a> <strong> TensorRT 的显式量化？</strong></h3><ul><li>在显式量化的网络中，<strong>量化值和非量化值之间转换的缩放操作由图中的 IQuantizeLayer 和 IDequantizeLayer 节点显式表示</strong>，这些节点此后将被称为 Q/DQ 节点。与隐式量化相比，显式形式精确地指定了向 INT8 和从 INT8 进行转换的位置，优化器将仅执行由模型语义决定的精度转换</li><li>当 PyTorch 或 TensorFlow 中的模型导出到 ONNX 时，ONNX 使用显式量化表示法，框架图中的每个伪量化操作都导出为 Q，后跟 DQ</li><li> 当 TensorRT 检测到模型中有 Q/DQ 算子的时候，就会触发<strong>显式量化</strong></li></ul><h3 id="tensorrt的隐式量化"><a class="markdownIt-Anchor" href="#tensorrt的隐式量化"></a> <strong> TensorRT 的隐式量化？</strong></h3><ul><li>在处理隐式量化网络时，TensorRT 在应用图优化时将模型视为浮点模型，并利用 INT8 的机会优化层执行时间。<strong>如果一个层在 INT8 中运行得更快，那么它就在 INT8 中执行。否则，使用 FP32 或 FP16。</strong> 在此模式下，TensorRT 仅针对性能进行优化，您几乎无法控制 INT8 的使用位置 —— 即使您在 API 级别明确设置了一个层的精度，TensorRT 也可能在图优化期间将该层与另一层融合，并丢失它必须在 INT8 中执行的信息</li><li> TensorRT 的 PTQ 能力生成隐式量化网络</li></ul><h3 id="tensorrt的network-level精度设置"><a 
class="markdownIt-Anchor" href="#tensorrt的network-level精度设置"></a> <strong> TensorRT 的 Network-Level 精度设置？</strong></h3><ul><li>FP32 是大多数框架的默认训练精度，因此我们将首先使用 FP32 进行推断。<strong>推理通常需要比训练更少的数值精度</strong>。较低的精度可以在不牺牲准确率的情况下实现更快的计算和更低的内存消耗</li><li>降低的精度支持取决于您的硬件，参考：<a href="https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix">Hardware and&nbsp;Precision</a></li><li><strong> 注意：</strong> TensorRT 为层选择精度，但是如果设置精度导致速度变慢或者该层没有低精度实现时，这些层将使用高精度</li><li> C++ 检测硬件支持精度并设置精度 <figure class="highlight c"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> (builder-&gt;platformHasFastFp16()) { … };</span><br><span class="line">config-&gt;setFlag(BuilderFlag::kFP16);</span><br><span class="line">  </span><br></pre></td></tr></tbody></table></figure></li></ul><h3 id="tensorrt的layer-level的精度设置"><a class="markdownIt-Anchor" href="#tensorrt的layer-level的精度设置"></a> <strong>TensorRT 的 Layer-Level 精度设置？</strong></h3><ul><li><code>config-&gt;setFlag(BuilderFlag::kFP16);</code> 提供了粗粒度的精度控制。然而有时网络的一部分需要更高的动态范围或对数值精度敏感，<strong>可以约束每层的输入和输出类型</strong></li><li> C++ 上为某一层设置输入输出类型 <figure class="highlight c"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">layer-&gt;setPrecision(DataType::kFP16)</span><br><span class="line">layer-&gt;setOutputType(out_tensor_index, DataType::kFLOAT)</span><br></pre></td></tr></tbody></table></figure></li><li><strong>计算使用精度：</strong> 计算将使用与输入首选的浮点类型相同的浮点类型。大多数 TensorRT 实现具有相同的输入和输出浮点类型；然而，卷积、反卷积和 FullyConnected 可以支持量化 INT8 输入和非量化 FP16 或 FP32 输出，因为有时需要使用来自量化输入的更高精度输出来保持精度。</li><li><strong>上下层精度设置冲突：</strong> 设置精度约束提示 TensorRT 应该选择输入和输出与首选类型匹配的层实现，如果上一层的输出和下一层的输入与请求的类型不匹配，则插入重新格式化操作。</li></ul><h3 id="什么是pytorch-quantization"><a class="markdownIt-Anchor" 
href="#什么是pytorch-quantization"></a> <strong>什么是 PyTorch-Quantization？</strong></h3><ul><li><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123183219.png" alt="基于tensorrt量化模型-20250123183219"></li><li>TensorRT 开发的量化工具包，可以便捷地将 PyTorch 模型量化为 TensorRT 支持的量化模型，它支持将 PyTorch 模型按 PTQ 或 QAT 方式量化。</li></ul><h3 id="使用pytorch-quantization导出pqt模型"><a class="markdownIt-Anchor" href="#使用pytorch-quantization导出pqt模型"></a> 使用 PyTorch-Quantization 导出 PTQ 模型</h3><p>使用以下代码会量化每个模块，如果不希望所有模块都量化，则应手动替换量化模块</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> pytorch_quantization <span class="keyword">import</span> quant_modules</span><br><span class="line">quant_modules.initialize()</span><br></pre></td></tr></tbody></table></figure><p>为了高效推断，我们希望为每个量化器选择一个固定范围。从预先训练的模型开始，最简单的方法是校准</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span 
class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br><span class="line">60</span><br><span class="line">61</span><br><span class="line">62</span><br><span class="line">63</span><br><span class="line">64</span><br><span class="line">65</span><br><span class="line">66</span><br></pre></td><td class="code"><pre><span class="line">quant_desc_input = QuantDescriptor(calib_method=<span class="string">'histogram'</span>)</span><br><span class="line">quant_nn.QuantConv2d.set_default_quant_desc_input(quant_desc_input)</span><br><span class="line">quant_nn.QuantLinear.set_default_quant_desc_input(quant_desc_input)</span><br><span class="line"></span><br><span class="line">model = models.resnet50(pretrained=<span class="literal">True</span>)</span><br><span class="line">model.cuda()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义校准数据集</span></span><br><span class="line">data_path = <span class="string">"PATH to imagenet"</span></span><br><span class="line">batch_size = <span 
class="number">512</span></span><br><span class="line"></span><br><span class="line">traindir = os.path.join(data_path, <span class="string">'train'</span>)</span><br><span class="line">valdir = os.path.join(data_path, <span class="string">'val'</span>)</span><br><span class="line">dataset, dataset_test, train_sampler, test_sampler = load_data(traindir, valdir, <span class="literal">False</span>, <span class="literal">False</span>)</span><br><span class="line"></span><br><span class="line">data_loader = torch.utils.data.DataLoader(</span><br><span class="line">    dataset, batch_size=batch_size,</span><br><span class="line">    sampler=train_sampler, num_workers=<span class="number">4</span>, pin_memory=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">data_loader_test = torch.utils.data.DataLoader(</span><br><span class="line">    dataset_test, batch_size=batch_size,</span><br><span class="line">    sampler=test_sampler, num_workers=<span class="number">4</span>, pin_memory=<span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 定义校准规则并校准</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">collect_stats</span>(<span class="params">model, data_loader, num_batches</span>):</span><br><span class="line">     <span class="string">"""Feed data to the network and collect statistic"""</span></span><br><span class="line"></span><br><span class="line">     <span class="comment"># Enable calibrators</span></span><br><span class="line">     <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">         <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">             <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span 
class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                 module.disable_quant()</span><br><span class="line">                 module.enable_calib()</span><br><span class="line">             <span class="keyword">else</span>:</span><br><span class="line">                 module.disable()</span><br><span class="line"></span><br><span class="line">     <span class="keyword">for</span> i, (image, _) <span class="keyword">in</span> tqdm(<span class="built_in">enumerate</span>(data_loader), total=num_batches):</span><br><span class="line">         model(image.cuda())</span><br><span class="line">         <span class="keyword">if</span> i &gt;= num_batches:</span><br><span class="line">             <span class="keyword">break</span></span><br><span class="line"></span><br><span class="line">     <span class="comment"># Disable calibrators</span></span><br><span class="line">     <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">         <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">             <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                 module.enable_quant()</span><br><span class="line">                 module.disable_calib()</span><br><span class="line">             <span class="keyword">else</span>:</span><br><span class="line">                 module.enable()</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">compute_amax</span>(<span class="params">model, **kwargs</span>):</span><br><span class="line">     <span class="comment"># Load calib result</span></span><br><span class="line">     <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">         <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">             <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                 <span class="keyword">if</span> <span class="built_in">isinstance</span>(module._calibrator, calib.MaxCalibrator):</span><br><span class="line">                     module.load_calib_amax()</span><br><span class="line">                 <span class="keyword">else</span>:</span><br><span class="line">                     module.load_calib_amax(**kwargs)</span><br><span class="line">             <span class="built_in">print</span>(<span class="string">F"<span class="subst">{name:<span class="number">40</span>}</span>: <span class="subst">{module}</span>"</span>)</span><br><span class="line">     model.cuda()</span><br><span class="line"></span><br><span class="line"><span class="comment"># It is a bit slow since we collect histograms on CPU</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">     collect_stats(model, data_loader, num_batches=<span class="number">2</span>)</span><br><span class="line">     compute_amax(model, method=<span class="string">"percentile"</span>, percentile=<span class="number">99.99</span>)</span><br></pre></td></tr></tbody></table></figure><p>评估量化 + 校准后的模型，并保存</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">criterion = nn.CrossEntropyLoss()</span><br><span class="line"><span
class="keyword">with</span> torch.no_grad():</span><br><span class="line">    evaluate(model, criterion, data_loader_test, device=<span class="string">"cuda"</span>, print_freq=<span class="number">20</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Save the model</span></span><br><span class="line">torch.save(model.state_dict(), <span class="string">"/tmp/quant_resnet50-calibrated.pth"</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="使用pytorch-quantization导出qat模型"><a class="markdownIt-Anchor" href="#使用pytorch-quantization导出qat模型"></a> 使用 PyTorch-Quantization 导出 QAT 模型</h3><p>对 PTQ 校准后的模型进行微调，即得到 QAT 模型</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">criterion = nn.CrossEntropyLoss()</span><br><span class="line">optimizer = torch.optim.SGD(model.parameters(), lr=<span class="number">0.0001</span>)</span><br><span class="line">lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=<span class="number">1</span>, gamma=<span class="number">0.1</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Training takes about one and half hour per epoch on a single V100</span></span><br><span class="line">train_one_epoch(model, criterion, optimizer, data_loader, <span class="string">"cuda"</span>, <span class="number">0</span>, <span class="number">100</span>)</span><br><span class="line"></span><br><span class="line"><span class="comment"># Save the model</span></span><br><span class="line">torch.save(model.state_dict(), <span
class="string">"/tmp/quant_resnet50-finetuned.pth"</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="将经过pytorch-quantization量化的模型导出为onnx"><a class="markdownIt-Anchor" href="#将经过pytorch-quantization量化的模型导出为onnx"></a> 将经过 PyTorch-Quantization 量化的模型导出为 ONNX</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E5%9F%BA%E4%BA%8Etensorrt%E9%87%8F%E5%8C%96%E6%A8%A1%E5%9E%8B-20250123183219-1.png" alt="基于tensorrt量化模型-20250123183219-1"><br>导出到 ONNX 的目标是之后通过 TensorRT 部署推理，而不是用 ONNX Runtime 运行。因此我们只需将假量化（fake quantization）模型导出为 TensorRT 可识别的形式，假量化会被分解成一对 QuantizeLinear/DequantizeLinear ONNX 算子</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> pytorch_quantization <span class="keyword">import</span> nn <span class="keyword">as</span> quant_nn</span><br><span class="line"><span class="keyword">from</span> pytorch_quantization <span class="keyword">import</span> quant_modules</span><br><span class="line">quant_nn.TensorQuantizer.use_fb_fake_quant = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line">quant_modules.initialize()</span><br><span class="line">model = torchvision.models.resnet50()</span><br><span class="line"><span class="comment"># load the calibrated model</span></span><br><span class="line">state_dict = torch.load(<span class="string">"quant_resnet50-entropy-1024.pth"</span>, map_location=<span class="string">"cpu"</span>)</span><br><span class="line">model.load_state_dict(state_dict)</span><br><span class="line">model.cuda()</span><br><span class="line"></span><br><span class="line">dummy_input = torch.randn(<span class="number">128</span>, <span class="number">3</span>, <span class="number">224</span>, <span class="number">224</span>, device=<span class="string">'cuda'</span>)</span><br><span class="line"></span><br><span class="line">input_names = [ <span class="string">"actual_input_1"</span> ]</span><br><span class="line">output_names = [ <span class="string">"output1"</span> ]</span><br><span class="line"></span><br><span class="line"><span class="comment"># enable_onnx_checker needs to be disabled.</span></span><br><span class="line">torch.onnx.export(</span><br><span class="line">    model, dummy_input, <span class="string">"quant_resnet50.onnx"</span>, verbose=<span class="literal">True</span>, opset_version=<span class="number">10</span>, enable_onnx_checker=<span class="literal">False</span>)</span><br></pre></td></tr></tbody></table></figure><p>参考：</p><ol><li><a href="https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#working-with-int8">Developer Guide :: NVIDIA Deep Learning TensorRT Documentation</a></li><li><a href="https://developer.nvidia.com/zh-cn/blog/accelerating-quantized-networks-with-qat-toolkit-and-tensorrt/">使用 NVIDIA QAT 工具包为 TensorFlow 和 NVIDIA TensorRT 加速量化网络 - NVIDIA 技术博客</a></li><li><a href="https://developer.nvidia.com/zh-cn/blog/chieving-fp32-accuracy-for-int8-inference-using-quantization-aware-training-with-tensorrt/">利用 NVIDIA TensorRT 量化感知训练实现 INT8 推理的 FP32 精度 - NVIDIA 技术博客</a></li><li><a href="https://oldpan.me/archives/quantize-in-action-tensorrt-8">一起实践量化番外篇 ——TensorRT-8 的量化细节 - Oldpan 的个人博客</a></li></ol>]]></content>
    
    
    <summary type="html">&lt;p&gt;本文讨论 tensorrt 的量化原理&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="D-深度学习部署" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/"/>
    
    
  </entry>
  
  <entry>
    <title>基于 Pytorch 量化模型</title>
    <link href="https://www.shaogui.life/posts/3346997870.html"/>
    <id>https://www.shaogui.life/posts/3346997870.html</id>
    <published>2025-01-27T08:04:32.000Z</published>
    <updated>2025-02-02T03:07:13.294Z</updated>
    
    <content type="html"><![CDATA[<p>总结 Pytorch 原生的两种量化方式</p><span id="more"></span><h3 id="pytorch原生量化之fx-graph-mode-quantization"><a class="markdownIt-Anchor" href="#pytorch原生量化之fx-graph-mode-quantization"></a> Pytorch 原生量化之 FX Graph Mode Quantization?</h3><ul><li>FX Graph Mode Quantization 是 PyTorch 中一个新的自动量化框架，目前它是一个原型功能。它通过添加对函数（functional）的支持和<strong>自动</strong>化量化过程来改进 Eager Mode Quantization</li></ul><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> torch.quantization <span class="keyword">import</span> get_default_qconfig</span><br><span class="line"><span class="keyword">from</span> torch.quantization.quantize_fx <span class="keyword">import</span> prepare_fx, convert_fx</span><br><span class="line">float_model.<span class="built_in">eval</span>()  <span class="comment"># 因为是PTQ，所以推理模式就够了</span></span><br><span class="line">qconfig = get_default_qconfig(<span class="string">"fbgemm"</span>)  <span class="comment"># 指定量化细节配置</span></span><br><span class="line">qconfig_dict = {<span class="string">""</span>: qconfig}             <span class="comment"># 指定量化选项</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">calibrate</span>(<span class="params">model, data_loader</span>):       <span class="comment"># 校准功能函数</span></span><br><span class="line">    model.<span class="built_in">eval</span>()</span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> image, target <span class="keyword">in</span> data_loader:</span><br><span class="line">            model(image)</span><br><span class="line">prepared_model = prepare_fx(float_model, qconfig_dict)  <span class="comment"># 准备量化模型，比如融合CONV+BN+RELU，然后插入量化观察节点</span></span><br><span class="line">calibrate(prepared_model, data_loader_test)  <span class="comment"># 用校准数据集进行校准</span></span><br><span class="line">quantized_model = convert_fx(prepared_model)  <span class="comment"># 把校准后的模型转化为量化版本模型</span></span><br></pre></td></tr></tbody></table></figure><h3 id="pytorch原生量化之eager-mode-quantization"><a class="markdownIt-Anchor" href="#pytorch原生量化之eager-mode-quantization"></a> Pytorch 原生量化之 Eager Mode Quantization?</h3><ul><li>Eager Mode Quantization 是一项测试功能。用户需要手动进行算子融合，并手动指定量化和反量化的位置，而且它只支持 Module 而不支持 functional </li><li><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span
class="line">31</span><br><span class="line">32</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="comment"># define a floating point model where some layers could be statically quantized</span></span><br><span class="line"><span class="keyword">class</span> <span class="title class_">M</span>(torch.nn.Module):</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">__init__</span>(<span class="params">self</span>):</span><br><span class="line">        <span class="built_in">super</span>(M, <span class="variable language_">self</span>).__init__()</span><br><span class="line">        <span class="comment"># QuantStub converts tensors from floating point to quantized</span></span><br><span class="line">        <span class="variable language_">self</span>.quant = torch.quantization.QuantStub()</span><br><span class="line">        <span class="variable language_">self</span>.conv = torch.nn.Conv2d(<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>)</span><br><span class="line">        <span class="variable language_">self</span>.relu = torch.nn.ReLU()</span><br><span class="line">        <span class="comment"># DeQuantStub converts tensors from quantized to floating point</span></span><br><span class="line">        <span class="variable language_">self</span>.dequant = torch.quantization.DeQuantStub()</span><br><span class="line">    <span class="keyword">def</span> <span class="title function_">forward</span>(<span class="params">self, x</span>):</span><br><span class="line">        <span class="comment"># 自己指定开始量化的层</span></span><br><span class="line">        x = <span class="variable language_">self</span>.quant(x)</span><br><span class="line">        x = <span class="variable language_">self</span>.conv(x)</span><br><span class="line">        x = <span class="variable language_">self</span>.relu(x)</span><br><span class="line">        <span class="comment"># 指定结束量化的层</span></span><br><span class="line">        x = <span class="variable language_">self</span>.dequant(x)</span><br><span class="line">        <span class="keyword">return</span> x</span><br><span class="line"><span class="comment"># create a model instance</span></span><br><span class="line">model_fp32 = M()</span><br><span class="line"><span class="comment"># model must be set to eval mode for static quantization logic to work</span></span><br><span class="line">model_fp32.<span class="built_in">eval</span>()</span><br><span class="line">model_fp32.qconfig = torch.quantization.get_default_qconfig(<span class="string">'fbgemm'</span>)</span><br><span class="line"><span class="comment"># 指定融合的层</span></span><br><span class="line">model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [[<span class="string">'conv'</span>, <span class="string">'relu'</span>]])</span><br><span class="line">model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)</span><br><span class="line">input_fp32 = torch.randn(<span class="number">4</span>, <span class="number">1</span>, <span class="number">4</span>, <span class="number">4</span>)</span><br><span class="line">model_fp32_prepared(input_fp32)</span><br><span class="line">model_int8 = torch.quantization.convert(model_fp32_prepared)</span><br><span class="line">res = model_int8(input_fp32)</span><br></pre></td></tr></tbody></table></figure></li></ul><p>参考：</p><ol><li><a href="https://oldpan.me/archives/torch-fx-second-quantize-with-fx">TORCH.FX 第二篇 ——PTQ 量化实操 - Oldpan 的个人博客</a></li><li><a href="https://www.zhihu.com/question/431572414">www.zhihu.com/question/431572414</a></li></ol>]]></content>
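上面两种模式最终都依赖 observer 统计每个张量的 min/max，再换算成 scale 和 zero_point。下面用一段不依赖 torch 的纯 Python 代码示意 fbgemm 后端常用的非对称（affine）量化参数计算（affine_qparams、quantize、dequantize 均为本文演示假设的函数名，对应 quint8、范围 0..255 的简化情形）：

```python
def affine_qparams(x_min, x_max, qmin=0, qmax=255):
    # 由张量的 min/max 计算非对称量化参数 (scale, zero_point)
    x_min = min(x_min, 0.0)  # 扩展范围以保证 0 可被精确表示
    x_max = max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = int(round(qmin - x_min / scale))
    return scale, max(qmin, min(qmax, zero_point))

def quantize(x, scale, zero_point, qmin=0, qmax=255):
    # 浮点数映射到整数网格并截断到 [qmin, qmax]
    return max(qmin, min(qmax, int(round(x / scale)) + zero_point))

def dequantize(q, scale, zero_point):
    # 整数还原为近似浮点值
    return (q - zero_point) * scale

scale, zp = affine_qparams(-1.0, 3.0)  # scale=4/255, zero_point=64
```

例如张量范围为 [-1.0, 3.0] 时，scale 为 4/255、zero_point 为 64；对任意落在范围内的值，量化再反量化的误差不超过一个 scale。
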
    
    
    <summary type="html">&lt;ul&gt;
&lt;li&gt;&lt;/li&gt;
&lt;/ul&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="D-深度学习部署" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/"/>
    
    
  </entry>
  
  <entry>
    <title>深度学习模型量化技术</title>
    <link href="https://www.shaogui.life/posts/4174027832.html"/>
    <id>https://www.shaogui.life/posts/4174027832.html</id>
    <published>2025-01-26T06:37:27.000Z</published>
    <updated>2025-02-02T03:06:57.241Z</updated>
    
    <content type="html"><![CDATA[<p>总结模型量化的技术</p><span id="more"></span><h3 id="量化原理"><a class="markdownIt-Anchor" href="#量化原理"></a> 量化原理</h3><p>量化是模型压缩的一种方式，通过<strong>将模型参数从宽数值范围映射到窄数值范围，使得计算量降低</strong></p><p>未量化前权重和激活的表示范围是 FP32，量化后可以是 FP16、INT8，甚至是 4bit、2bit</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Pasted-image-20230107090001.png" alt="Pasted-image-20230107090001"></p><ul><li>FP32 表示的范围：浮点数范围在 [-3.4e38,3.4e38] 之间，这个区间被称为动态范围 (dynamic-range)</li><li>int8 量化：使用 scale-factor 将浮点张量的动态范围映射到 [-128,127]，又称对称量化，因为范围关于原点对称，TensorRT 使用对称量化来表示激活值与权重</li><li>量化计算公式：对任意浮点张量分布，拿到绝对值最大的元素 max，量化的目的是将浮点数从 [-max,max] 缩小到 [-128,127]，先将 [-max,max] 等分为 256 份，再判断浮点数落在第几份上</li></ul><p class="katex-block"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mtable rowspacing="0.15999999999999992em" columnalign="center" columnspacing="1em"><mtr><mtd><mstyle scriptlevel="0" displaystyle="false"><mrow><msub><mi>a</mi><mrow><mi>m</mi><mi>a</mi><mi>x</mi></mrow></msub><mo>=</mo><mi>max</mi><mo>⁡</mo><mrow><mo fence="true">(</mo><mi>a</mi><mi>b</mi><mi>s</mi><mrow><mo fence="true">(</mo><msub><mi>x</mi><mi>f</mi></msub><mo fence="true">)</mo></mrow><mo fence="true">)</mo></mrow></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="false"><mrow><mtext>&nbsp;scale&nbsp;</mtext><mo>=</mo><mo stretchy="false">(</mo><mn>2</mn><mo>∗</mo><msub><mi>a</mi><mrow><mi>m</mi><mi>a</mi><mi>x</mi></mrow></msub><mo stretchy="false">)</mo><mi mathvariant="normal">/</mi><mn>256</mn></mrow></mstyle></mtd></mtr><mtr><mtd><mstyle scriptlevel="0" displaystyle="false"><mrow><msub><mi>x</mi><mi>q</mi></msub><mo>=</mo><mi mathvariant="normal">Clip</mi><mo>⁡</mo><mrow><mo fence="true">(</mo><mi mathvariant="normal">Round</mi><mo>⁡</mo><mrow><mo fence="true">(</mo><msub><mi>x</mi><mi>f</mi></msub><mi
mathvariant="normal">/</mi><mtext>&nbsp;scale&nbsp;</mtext><mo fence="true">)</mo></mrow><mo fence="true">)</mo></mrow></mrow></mstyle></mtd></mtr></mtable><annotation encoding="application/x-tex">\begin{array}{c}a_{max} =\max \left(a b s\left(x_{f}\right)\right) \\\text { scale }=(2 * a_{max} ) / 256 \\x_{q}=\operatorname{Clip}\left(\operatorname{Round}\left(x_{f} / \text { scale }\right)\right)\end{array} </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:3.6000000000000005em;vertical-align:-1.5500000000000007em;"></span><span class="mord"><span class="mtable"><span class="arraycolsep" style="width:0.5em;"></span><span class="col-align-c"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:2.05em;"><span style="top:-4.21em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">m</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop">max</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;">(</span><span class="mord 
mathnormal">a</span><span class="mord mathnormal">b</span><span class="mord mathnormal">s</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10764em;">f</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mclose delimcenter" style="top:0em;">)</span></span><span class="mclose delimcenter" style="top:0em;">)</span></span></span></span><span style="top:-3.0099999999999993em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord text"><span class="mord">&nbsp;scale&nbsp;</span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mopen">(</span><span class="mord">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal">a</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.151392em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span 
class="mord mathnormal mtight">m</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight">x</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mord">/</span><span class="mord">2</span><span class="mord">5</span><span class="mord">6</span></span></span><span style="top:-1.8099999999999994em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.15139200000000003em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.03588em;">q</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mop"><span class="mord mathrm">C</span><span class="mord mathrm">l</span><span class="mord mathrm">i</span><span class="mord mathrm">p</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;">(</span><span class="mop"><span class="mord mathrm">R</span><span class="mord mathrm">o</span><span class="mord mathrm">u</span><span class="mord mathrm">n</span><span class="mord mathrm">d</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span 
class="mopen delimcenter" style="top:0em;">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.3361079999999999em;"><span style="top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.10764em;">f</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.286108em;"><span></span></span></span></span></span></span><span class="mord">/</span><span class="mord text"><span class="mord">&nbsp;scale&nbsp;</span></span><span class="mclose delimcenter" style="top:0em;">)</span></span><span class="mclose delimcenter" style="top:0em;">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.5500000000000007em;"><span></span></span></span></span></span><span class="arraycolsep" style="width:0.5em;"></span></span></span></span></span></span></span></p><p>为什么量化后模型推理变快了呢？可以从以下角度分析：</p><ul><li>吞吐量变大：使用更加快速的 INT8 内核进行计算，其速度更快</li><li>数据交换时间更少：尤其是数据交换从原来的 32 位，变为 8 位，交换的数据量减少到 1/8</li><li> 减少带宽占用：有些层有带宽限制（内存有限）。这意味着它们的实现将大部分时间用于读写数据，因此减少它们的计算时间并不会减少它们的总体运行时间。带宽限制层从减少的带宽需求中获益最大</li><li>减少内存占用：模型需要更少的存储空间，参数更新更小，缓存利用率更高</li></ul><h3 id="量化按技术路线"><a class="markdownIt-Anchor" href="#量化按技术路线"></a> 量化按技术路线</h3><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/202501121710846.png" alt="模型量化-20230704212354"></p><p><strong>训练后量化 (PTQ)</strong>：</p><ul><li>在训练得到的高精度模型后，<strong>统计权重和激活值的动态范围和量化参数（q-parms）后</strong>，再进行量化操作</li><li>分为 2 个量化阶段：1) 量化权重，这个很简单，因为权重可以直接被访问，所以可以很轻松计算得到其分布；2) 量化激活，这个比较麻烦，因为必须使用实际输入数据才能测出其分布</li><li><strong>校准 (Calibration)：</strong> 
在量化激活时，输入代表性数据集，获得层间激活分布的过程</li><li>训练得到模型后再进行量化，按照是否提前确定激活值量化参数，又分为静态量化：提前通过校准确定量化参数；动态量化：前向推理过程中动态计算量化参数</li></ul><p><strong>量化感知 (QAT)：</strong></p><ul><li>训练后量化 (PTQ) 有时候出现无法接受的精度损失，这时候就需要使用量化感知训练 Quantization Aware Training (QAT)，它的主要思想是：在训练阶段加入量化算子，训练时学习量化参数，使得网络可以适应量化后的权值与激活</li><li>通过在训练图中插入量化操作 (Q) 和反量化操作 (DQ)，实现将量化误差包含在网络中，以使得 <strong>量化参数</strong> 更加符合网络，减少精度损失</li></ul><p>训练后量化的核心过程是校准，比较简单；量化感知训练涉及训练过程，比较复杂，我们来看看是怎么回事</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Pasted-image-20230103145704.png" alt="Pasted-image-20230103145704"></p><ul><li><strong>QAT</strong> 量化中最重要的就是 <strong>fake 量化算子</strong>，fake 算子负责将输入该算子的参数和输入<strong>先量化后反量化</strong>，然后记录这个 scale</li><li> 上图原始网络精度是 FP32，输入和权重因此也是 FP32</li><li>FQ (fake-quan) 算子会将 FP32 精度的输入和权重转化为 INT8 再转回 FP32，<strong>记住转换过程中的尺度信息</strong></li><li> FQ (fake-quan) 算子在 ONNX 中可以表示为 QDQ 算子</li></ul><p>在带量化算子的 pytorch 模型转为 onnx 后，模型中会带有 Q/DQ 的算子，那么在 tensorrt 部署框架如何使用这个模型呢？</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Pasted-image-20230103150631.png" alt="Pasted-image-20230103150631"></p><ul><li>输入 X 是 FP32 类型的 op，输出是 FP32，在输入 A 这个 op 时会经过 Q（即量化）操作，这个时候操作 A 我们会默认是 INT8 类型的操作，A 操作之后会经过 DQ（即反量化）操作将 A 输出的 INT8 类型的结果转化为 FP32 类型的结果并传给下一个 FP32 类型的 op</li><li> 有了 QDQ 信息，TensorRT 在解析模型的时候会根据 QDQ 的位置找到可量化的 op，然后与 QDQ 融合（<strong>吸收尺度信息到 OP 中</strong>）<br><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Pasted-image-20230103150719.png" alt="Pasted-image-20230103150719"></li></ul><p>在 tensorrt 优化这个 Q/DQ 操作的方式很多，比如 QDQ-ONNX 网络在输入到 TensorRT 中的时候，TensorRT 的算法会 propagate 整个网络，根据一些规则适当移动 Q/DQ 算子的位置（<strong>需要尽可能拼凑出 QDQ 结构，使整个网络尽可能多的 op 变为量化算子</strong>），然后再执行 QDQ 融合策略<br><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Pasted-image-20230103151041.png" alt="Pasted-image-20230103151041"></p><p>然后尽可能将 DQ 算子推迟（推迟反量化操作），同时尽可能将 Q 算子提前（提前量化操作）<br><img 
data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Pasted-image-20230103150920.png" alt="Pasted-image-20230103150920"></p><p>参考：</p><ol><li><a href="https://zhuanlan.zhihu.com/p/405571578?utm_medium=social&amp;utm_oi=627160086720811008">zhuanlan.zhihu.com/p/405571578</a></li></ol>]]></content>
    
    
    <summary type="html">&lt;p&gt;总结模型量化的技术&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="D-深度学习部署" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/"/>
    
    
  </entry>
  
  <entry>
    <title>Multi-Class, Multi-Object Recognition in High-Resolution Images</title>
    <link href="https://www.shaogui.life/posts/2793251604.html"/>
    <id>https://www.shaogui.life/posts/2793251604.html</id>
    <published>2025-01-25T05:53:32.000Z</published>
    <updated>2025-02-02T03:06:22.131Z</updated>
    
    <content type="html"><![CDATA[<p>In multi-class, multi-object recognition on high-resolution images, the technology has evolved from traditional template matching to modern AI. Template matching performs well when the target's appearance is fixed, but it struggles to handle rotation, scaling, or partial occlusion, which limits its use in complex scenes. With recent advances, applying AI directly has become practical</p><span id="more"></span><p>The method described here first slices the high-resolution image into overlapping tiles, which lets every region be processed in detail and avoids losing small targets to downscaling. Multithreading then processes the tiles in parallel, significantly improving throughput. On this basis, an AI model performs fast and accurate recognition on each tile, and an intelligent result-fusion strategy finally merges the per-tile detections, improving both recognition accuracy and system robustness.</p><h3 id="思路"><a class="markdownIt-Anchor" href="#思路"></a> Approach</h3><p>This invention provides a multi-class, multi-object recognition method for high-resolution images, comprising the following steps:</p><p>1.&nbsp;<strong>Overlapping slicing:</strong> When processing a high-resolution image, we first apply overlapping slicing. This strategy solves the problem that a target cut off at a tile edge cannot be learned by the AI. Taking a 14400x10800-pixel image as an example, we divide it into 1024x1024-pixel tiles with a 20% overlap in each of the four directions, which guarantees that even a target lying on an edge appears intact in at least one tile. This yields 235 overlapping tiles from the original image and provides more precise input for the subsequent AI recognition.<br>2.&nbsp;<strong>Training the AI recognizer</strong>: After the slicing is done, the next task is to train an AI recognizer on the tiles. We choose the YOLOv10 model, known for its speed and accuracy in object detection. Trained on the sliced tiles, YOLOv10 learns the features of multiple classes and targets and recognizes them efficiently on each tile. This step is the key to fast yet accurate recognition and lays a solid foundation for the subsequent multithreaded parallel processing.<br>3.&nbsp;<strong>Multithreaded recognition of the full image</strong>: Multithreading markedly improves the efficiency of analyzing the full image. In this step we process the sliced tiles with multiple threads, applying the trained recognizer to each tile in parallel. This fully exploits modern multi-core processors, drastically shortens the overall processing time, and enables a fast response on high-resolution images.<br>4.&nbsp;<strong>Fusing the detection results</strong>: Finally, to resolve the duplicate detections introduced by overlapping slicing, we adopt Non-Maximum Suppression (NMS). Detections are filtered and fused by confidence: among the detections of the same target from different tiles, the highest-confidence box is kept and the others are suppressed, which effectively removes duplicates and improves accuracy. This fusion step ensures the final result is both precise and consistent.</p><p><strong>Innovation</strong><br>Vanilla NMS sorts predicted boxes by their scores and then filters overlapping boxes by IoU. This requires trusting the score: a good box must score higher than a slightly worse one, which does hold when optimizing on a single image. In the figure below, suppose the image contains targets A, B, and C, the red boxes are the ground-truth annotations, and target A has two predicted boxes, labeled 1 and 2, with box 2 clearly the better one. The network learns by minimizing the distance between predictions and the ground truth, and box 2's distance is smaller, so training puts the two boxes' scores in competition: box 2 moves ever closer to the ground truth and its score rises, while box 1 drifts away and its score falls, so the high-quality result can ultimately be selected by score. We can therefore conclude: <strong>the network learns to output increasingly accurate boxes and to assign them increasingly high scores.</strong></p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Drawing-20240909112754.excalidraw.png" 
alt="Drawing-20240909112754.excalidraw"></p><p>Given this conclusion, when the model predicts the same target several times, the higher-scoring box can be assumed to be the more accurate one, so NMS can keep the high-scoring box and discard the low-scoring ones. This traditional post-processing, however, does not work for overlapping predictions on high-resolution images.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Drawing-20240909112835.excalidraw.png" alt="Drawing-20240909112835.excalidraw"></p><p>Suppose an overlapping cut happens to pass through a target: tile R1 contains only part of target A while R2 contains the whole target. R1 and R2 are trained on separately, and the network has no way of knowing that the two targets are different parts of the same object in the full image, so their scores have no competitive relationship; training simply pushes up the scores of box 1 and box 2 independently. This creates a problem: box 1 covers only a small fragment and is easier to fit, so its score may end up higher than box 2's, and plain NMS would then discard box 2, even though box 2 is the one we want.</p><p>As for box fusion: in the same figure, box 2 is already close to the ground truth. A weighted average of boxes 1 and 2 yields a worse box, and if there are several boxes like box 1, the fused result drifts toward them, which still fails to achieve "keep box 2, drop box 1".</p><p>Stepping back, NMS fails here because overlapping slicing causes one target to be predicted on several tiles whose scores do not compete, so the highest-scoring box is not necessarily the best-fitting one. Since the slicing is the cause, we modify NMS to keep the prediction with the larger area and filter the smaller ones. The effect in practice:</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E4%B8%80%E7%A7%8D%E5%AE%9E%E7%8E%B0%E5%A4%A7%E5%88%86%E8%BE%A8%E7%8E%87%E5%9B%BE%E7%89%87%E4%B8%8B%E7%9A%84%E5%A4%9A%E7%B1%BB%E5%88%AB%E5%A4%9A%E7%9B%AE%E6%A0%87%E8%AF%86%E5%88%AB%E6%96%B9%E6%B3%95-20250123135039.png" alt="一种实现大分辨率图片下的多类别多目标识别方法-20250123135039"></p><p>Panels 1, 2, and 3 show no NMS, score-based NMS, and area-based NMS respectively. Panel 1 shows two predictions for the highlighted target; the box kept by score-based NMS has the higher score but fails to enclose the target, while the box kept by area-based NMS scores slightly lower yet fits the target much better.</p><p>Going further, replacing the score with the area works because of an assumption: under overlapping slicing, every target appears complete in at least one tile (the one with the largest predicted area). For a large target spanning several tiles, as in the figure below, neither R1 nor R2 can predict the complete object; area-based NMS alone would either pick one of boxes 1 and 2 or treat them as two separate targets, and both outcomes are wrong. Here the two small boxes should instead be fused into one large box. So when should area-based NMS be used alone, and when should boxes be fused?</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Drawing-20240909112943.excalidraw.png" alt="Drawing-20240909112943.excalidraw"></p><p>This article computes the "ratio of the overlap area to the area of the box under analysis" as a post-processing criterion. Consider small and large targets separately: for the small target, the overlap covers 100% and 40% of boxes 1 and 2 respectively; since box 1's ratio reaches 100%, area-based NMS filtering suffices. For the large target, the ratios are 30% and 40%, meaning the two boxes are largely non-overlapping, so WBF is used to fuse them.</p><p><img 
data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Drawing-20240909113014.excalidraw.png" alt="Drawing-20240909113014.excalidraw"></p><p>Practice shows that adding the "overlap-to-box-area ratio" threshold fixes two problems of filtering by area alone. The first is the case where the NMS IoU threshold is larger than the actual IoU, so the overlap cannot be filtered. In the figure below, panels 1, 2, and 3 show no NMS, area-based NMS, and area-based NMS plus the ratio threshold. Panel 1 shows that the AI outputs three boxes; in panel 2, one overlapping box survives because its computed IoU is below the NMS filtering threshold; after applying the ratio threshold, the inner box is filtered out.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E4%B8%80%E7%A7%8D%E5%AE%9E%E7%8E%B0%E5%A4%A7%E5%88%86%E8%BE%A8%E7%8E%87%E5%9B%BE%E7%89%87%E4%B8%8B%E7%9A%84%E5%A4%9A%E7%B1%BB%E5%88%AB%E5%A4%9A%E7%9B%AE%E6%A0%87%E8%AF%86%E5%88%AB%E6%96%B9%E6%B3%95-20240909113041.png" alt=""></p><p>A further advantage is that recombining boxes improves localization accuracy. In the figure below (panels 1, 2, and 3 as before), the AI outputs two boxes; area filtering alone can remove only one of them, whereas with the ratio threshold the two boxes are recombined into a new box that encloses the target more accurately.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E4%B8%80%E7%A7%8D%E5%AE%9E%E7%8E%B0%E5%A4%A7%E5%88%86%E8%BE%A8%E7%8E%87%E5%9B%BE%E7%89%87%E4%B8%8B%E7%9A%84%E5%A4%9A%E7%B1%BB%E5%88%AB%E5%A4%9A%E7%9B%AE%E6%A0%87%E8%AF%86%E5%88%AB%E6%96%B9%E6%B3%95-20240909113103.png" alt=""></p><p>To summarize the NMS scheme for high-resolution images: when the box under analysis is kept, the AI is considered to have predicted two distinct targets rather than two overlapping boxes of one target, and fusion simply takes the minimum and maximum of the horizontal and vertical coordinates of the confirmed box and the box under analysis. Suppose the top-left and bottom-right coordinates of the confirmed box are <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi> x</mi><mn>1</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>y</mi><mn>1</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>x</mi><mn>1</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>y</mi><mn>1</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup></mrow><annotation encoding="application/x-tex">x_1^{LT},y_1^{LT},x_1^{BR},y_1^{BR}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span 
class="base"><span class="strut" style="height:1.0894389999999998em;vertical-align:-0.24810799999999997em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24810799999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.24810799999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24810799999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" 
style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24810799999999997em;"><span></span></span></span></span></span></span></span></span></span>, and the top-left and bottom-right coordinates of the box under analysis are <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msubsup><mi> x</mi><mn>2</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>y</mi><mn>2</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>x</mi><mn>2</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>y</mi><mn>2</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup></mrow><annotation encoding="application/x-tex">x_2^{LT},y_2^{LT},x_2^{BR},y_2^{BR}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.0894389999999998em;vertical-align:-0.24810799999999997em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24810799999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24810799999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:0.24810799999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8413309999999999em;"><span style="top:-2.4518920000000004em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.063em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.24810799999999997em;"><span></span></span></span></span></span></span></span></span></span>, and the coordinates of the fused box become</p><p class="katex-block"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>m</mi><mi>i</mi><mi>n</mi><mo stretchy="false">(</mo><msubsup><mi>x</mi><mn>1</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>x</mi><mn>2</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo stretchy="false">)</mo><mo separator="true">,</mo><msubsup><mrow><mi>min</mi><mo>⁡</mo><mo stretchy="false">(</mo><mi>y</mi></mrow><mn>1</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>y</mi><mn>2</mn><mrow><mi>L</mi><mi>T</mi></mrow></msubsup><mo stretchy="false">)</mo><mo separator="true">,</mo><msubsup><mrow><mi>m</mi><mi>a</mi><mi>x</mi><mo 
stretchy="false">(</mo><mi>x</mi></mrow><mn>1</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>x</mi><mn>2</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup><mo stretchy="false">)</mo><mo separator="true">,</mo><msubsup><mrow><mi>m</mi><mi>a</mi><mi>x</mi><mo stretchy="false">(</mo><mi>y</mi></mrow><mn>1</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup><mo separator="true">,</mo><msubsup><mi>y</mi><mn>2</mn><mrow><mi>B</mi><mi>R</mi></mrow></msubsup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">min(x_1^{LT},x_2^{LT}),{\min(y}_1^{LT},y_2^{LT}),{max(x}_1^{BR},x_2^{BR}),{max(y}_1^{BR},y_2^{BR})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.2809309999999998em;vertical-align:-0.29969999999999997em;"></span><span class="mord mathnormal">m</span><span class="mord mathnormal">i</span><span class="mord mathnormal">n</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span 
class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord"><span class="mop">min</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.981231em;"><span style="top:-2.4003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.29969999999999997em;"><span></span></span></span></span></span></span><span 
class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">L</span><span class="mord mathnormal mtight" style="margin-right:0.13889em;">T</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">m</span><span class="mord mathnormal">a</span><span class="mord mathnormal">x</span><span class="mopen">(</span><span class="mord mathnormal">x</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.981231em;"><span style="top:-2.4003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" 
style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.29969999999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-left:0em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord"><span class="mord mathnormal">m</span><span class="mord mathnormal">a</span><span class="mord mathnormal">x</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right:0.03588em;">y</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.981231em;"><span style="top:-2.4003em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span><span 
style="top:-3.2029em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.29969999999999997em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">y</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8913309999999999em;"><span style="top:-2.4530000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span><span style="top:-3.113em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right:0.05017em;">B</span><span class="mord mathnormal mtight" style="margin-right:0.00773em;">R</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.247em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Drawing-20240831133137.excalidraw.png" alt="Drawing-20240831133137.excalidraw"></p><h3 id="效果"><a class="markdownIt-Anchor" href="#效果"></a> Results</h3><p>The proposed multi-class, multi-object recognition method for high-resolution images greatly improves both recognition efficiency and accuracy. In a practical analysis of a 14400x10800 image containing more than 100 targets, template matching took 5 minutes with only 60% accuracy, and running AI directly on the full image took 1 minute with 85% accuracy, whereas the method in this article achieves over 98% recognition within 5 seconds, surpassing existing approaches in both speed and precision.</p>]]></content>
    
    
    <summary type="html">&lt;p&gt;In multi-class, multi-object recognition on high-resolution images, the technology has evolved from traditional template matching to modern AI. Template matching performs well when the target&apos;s appearance is fixed, but it struggles to handle rotation, scaling, or partial occlusion, which limits its use in complex scenes. With recent advances, applying AI directly has become practical&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="B-视觉模型" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/B-%E8%A7%86%E8%A7%89%E6%A8%A1%E5%9E%8B/"/>
    
    <category term="1-基础视觉任务CNN" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/B-%E8%A7%86%E8%A7%89%E6%A8%A1%E5%9E%8B/1-%E5%9F%BA%E7%A1%80%E8%A7%86%E8%A7%89%E4%BB%BB%E5%8A%A1CNN/"/>
    
    <category term="目标检测" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/B-%E8%A7%86%E8%A7%89%E6%A8%A1%E5%9E%8B/1-%E5%9F%BA%E7%A1%80%E8%A7%86%E8%A7%89%E4%BB%BB%E5%8A%A1CNN/%E7%9B%AE%E6%A0%87%E6%A3%80%E6%B5%8B/"/>
    
    
  </entry>
  
  <entry>
    <title>使用 Torch_TensorRT 量化分割模型</title>
    <link href="https://www.shaogui.life/posts/2058951322.html"/>
    <id>https://www.shaogui.life/posts/2058951322.html</id>
    <published>2025-01-24T00:45:12.000Z</published>
    <updated>2025-02-03T00:09:12.161Z</updated>
    
    <content type="html"><![CDATA[<p>本文使用 Torch_TensorRT 量化 deeplabv3+ 模型。</p><span id="more"></span><p>Torch-TensorRT 是 PyTorch 的推理编译器，借助 NVIDIA 的 TensorRT 深度学习优化器和运行时，以 NVIDIA GPU 为推理目标。它通过其接口支持即时（JIT）编译工作流以及预先（AOT）编译工作流。Torch-TensorRT 无缝集成到 PyTorch 生态系统中，支持将优化的 TensorRT 代码与标准 PyTorch 代码混合执行。</p><p>由于 Torch-TensorRT 接受 TorchScript 输入，优化后输出 ts 模型，所以下文将从 pytorch、ptq、qat 三个方向测试 Torch-TensorRT。</p><p>为了评估量化水平，我们定义一个评估函数：</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Helper function to benchmark the model</span></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">benchmark</span>(<span class="params">model, input_shape=(<span class="params"><span class="number">1024</span>, <span class="number">1</span>, <span class="number">32</span>, <span class="number">32</span></span>), dtype=<span class="string">'fp32'</span>, nwarmup=<span class="number">50</span>, nruns=<span class="number">800</span></span>):</span><br><span class="line">    input_data = torch.randn(input_shape)</span><br><span class="line">    input_data = input_data.to(<span class="string">"cuda"</span>)</span><br><span class="line">    <span class="keyword">if</span> dtype==<span class="string">'fp16'</span>:</span><br><span class="line">      
  input_data = input_data.half()</span><br><span class="line">        </span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> _ <span class="keyword">in</span> <span class="built_in">range</span>(nwarmup):</span><br><span class="line">            features = model(input_data)</span><br><span class="line">    torch.cuda.synchronize()</span><br><span class="line">    timings = []</span><br><span class="line">    <span class="keyword">with</span> torch.no_grad():</span><br><span class="line">        <span class="keyword">for</span> i <span class="keyword">in</span> <span class="built_in">range</span>(<span class="number">1</span>, nruns+<span class="number">1</span>):</span><br><span class="line">            start_time = time.time()</span><br><span class="line">            output = model(input_data)</span><br><span class="line">            torch.cuda.synchronize()</span><br><span class="line">            end_time = time.time()</span><br><span class="line">            timings.append(end_time - start_time)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">'Average batch time: %.2f ms, median: %.2f ms, max: %.2f ms'</span>%(np.mean(timings)*<span class="number">1000</span>, np.median(timings)*<span class="number">1000</span>, np.max(timings)*<span class="number">1000</span>))</span><br></pre></td></tr></tbody></table></figure><h3 id="base-torchscript-torch_tensorrt"><a class="markdownIt-Anchor" href="#base-torchscript-torch_tensorrt"></a> base-&gt;TorchScript-&gt;Torch_tensorRT</h3><p>基于原始的 pth，编译为 torchscript，然后再使用 Torch_tensorRT 优化</p><p>首先导入 pth 模型</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span 
class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">from</span> modelingv2.deeplab <span class="keyword">import</span> DeepLab</span><br><span class="line">model=DeepLab(in_channels=<span class="number">3</span>,num_classes=<span class="number">2</span>,pretrained=<span class="literal">False</span>)</span><br><span class="line">model = model.cuda()</span><br><span class="line"><span class="comment"># deeplabv3plus_base.pth is the checkpoint generated from training the baseline DeepLabV3+ model.</span></span><br><span class="line">ckpt = torch.load(<span class="string">"./models/deeplabv3plus_base.pth"</span>)</span><br><span class="line">modified_state_dict={}</span><br><span class="line"><span class="keyword">for</span> key, val <span class="keyword">in</span> ckpt.items():</span><br><span class="line">    <span class="comment"># Remove 'module.' from the key names</span></span><br><span class="line">    <span class="keyword">if</span> key.startswith(<span class="string">'module'</span>):</span><br><span class="line">        modified_state_dict[key[<span class="number">7</span>:]] = val</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        modified_state_dict[key] = val</span><br><span class="line"><span class="comment"># Load the pre-trained checkpoint</span></span><br><span class="line">model.load_state_dict(modified_state_dict)</span><br><span class="line">model = model.cuda()</span><br></pre></td></tr></tbody></table></figure><p>Next, export it to TorchScript</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Exporting to TorchScript</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    data = <span class="built_in">iter</span>(train_dataloader)</span><br><span class="line">    images, _ = <span class="built_in">next</span>(data)</span><br><span class="line">    jit_model = torch.jit.trace(model, images.to(<span class="string">"cuda"</span>))</span><br><span class="line">    torch.jit.save(jit_model, <span class="string">"models/deeplabv3plus_base.jit.pt"</span>)</span><br><span class="line">benchmark(jit_model, input_shape=(<span class="number">16</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><p>Finally, optimize it with Torch_tensorRT</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#Loading the Torchscript model and compiling it into a TensorRT model</span></span><br><span class="line"><span class="keyword">import</span> torch_tensorrt</span><br><span class="line">baseline_model = torch.jit.load(<span class="string">"models/deeplabv3plus_base.jit.pt"</span>).<span class="built_in">eval</span>()</span><br><span class="line"></span><br><span class="line">compile_spec = {<span class="string">"inputs"</span>: [torch_tensorrt.Input([<span class="number">4</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>])],</span><br><span class="line">                <span class="string">"enabled_precisions"</span>: {torch.<span class="built_in">float</span>},</span><br><span class="line">                <span class="string">"truncate_long_and_double"</span>: <span class="literal">True</span></span><br><span class="line">               }</span><br><span class="line">trt_base = torch_tensorrt.<span class="built_in">compile</span>(baseline_model, **compile_spec)</span><br><span class="line">torch.jit.save(trt_base, <span class="string">"models/deeplabv3plus_base_trt.ts"</span>)</span><br><span class="line">benchmark(trt_base, input_shape=(<span class="number">16</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="base-ptq-torchscript-torch_tensorrt"><a class="markdownIt-Anchor" href="#base-ptq-torchscript-torch_tensorrt"></a> base-&gt;ptq-&gt;TorchScript-&gt;Torch_tensorRT</h3><p>First, load the .pth model and quantize it with pytorch_quantization</p><figure class="highlight python"><table><tbody><tr><td 
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br><span class="line">53</span><br><span class="line">54</span><br><span class="line">55</span><br><span class="line">56</span><br><span class="line">57</span><br><span class="line">58</span><br><span class="line">59</span><br></pre></td><td class="code"><pre><span class="line">quant_modules.initialize()</span><br><span 
class="line">q_model=DeepLab(in_channels=<span class="number">3</span>,num_classes=<span class="number">2</span>,pretrained=<span class="literal">False</span>)</span><br><span class="line">q_model = q_model.cuda()</span><br><span class="line">ckpt = torch.load(<span class="string">"./models/deeplabv3plus_base.pth"</span>)</span><br><span class="line">modified_state_dict={}</span><br><span class="line"><span class="keyword">for</span> key, val <span class="keyword">in</span> ckpt.items():</span><br><span class="line">    <span class="comment"># Remove 'module.' from the key names</span></span><br><span class="line">    <span class="keyword">if</span> key.startswith(<span class="string">'module'</span>):</span><br><span class="line">        modified_state_dict[key[<span class="number">7</span>:]] = val</span><br><span class="line">    <span class="keyword">else</span>:</span><br><span class="line">        modified_state_dict[key] = val</span><br><span class="line"><span class="comment"># Load the pre-trained checkpoint</span></span><br><span class="line">q_model.load_state_dict(modified_state_dict)</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">compute_amax</span>(<span class="params">model, **kwargs</span>):</span><br><span class="line">    <span class="comment"># Load calib result</span></span><br><span class="line">    <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">        <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">            <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                <span class="keyword">if</span> <span class="built_in">isinstance</span>(module._calibrator, 
calib.MaxCalibrator):</span><br><span class="line">                    module.load_calib_amax()</span><br><span class="line">                <span class="keyword">else</span>:</span><br><span class="line">                    module.load_calib_amax(**kwargs)</span><br><span class="line">    model.cuda()</span><br><span class="line"></span><br><span class="line"><span class="keyword">def</span> <span class="title function_">collect_stats</span>(<span class="params">model, data_loader, num_batches</span>):</span><br><span class="line">    <span class="string">"""Feed data to the network and collect statistics"""</span></span><br><span class="line">    <span class="comment"># Enable calibrators</span></span><br><span class="line">    <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">        <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">            <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                module.disable_quant()</span><br><span class="line">                module.enable_calib()</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                module.disable()</span><br><span class="line">    <span class="comment"># Feed data to the network for collecting stats</span></span><br><span class="line">    <span class="keyword">for</span> i, (image, _) <span class="keyword">in</span> tqdm(<span class="built_in">enumerate</span>(data_loader), total=num_batches):</span><br><span class="line">        model(image.cuda())</span><br><span class="line">        <span class="keyword">if</span> i &gt;= num_batches:</span><br><span class="line">            <span class="keyword">break</span></span><br><span 
class="line">    <span class="comment"># Disable calibrators</span></span><br><span class="line">    <span class="keyword">for</span> name, module <span class="keyword">in</span> model.named_modules():</span><br><span class="line">        <span class="keyword">if</span> <span class="built_in">isinstance</span>(module, quant_nn.TensorQuantizer):</span><br><span class="line">            <span class="keyword">if</span> module._calibrator <span class="keyword">is</span> <span class="keyword">not</span> <span class="literal">None</span>:</span><br><span class="line">                module.enable_quant()</span><br><span class="line">                module.disable_calib()</span><br><span class="line">            <span class="keyword">else</span>:</span><br><span class="line">                module.enable()</span><br><span class="line"><span class="comment"># Calibrate the model using max calibration technique.</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    collect_stats(q_model, train_dataloader, num_batches=<span class="number">20</span>)</span><br><span class="line">    compute_amax(q_model, method=<span class="string">"max"</span>)</span><br><span class="line">    <span class="comment"># compute_amax(q_model, method="entropy")</span></span><br><span class="line">    <span class="comment"># compute_amax(q_model, method="percentile")</span></span><br><span class="line">    <span class="comment"># compute_amax(q_model, method="mse")</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    collect_stats(q_model, val_dataloader, num_batches=<span class="number">20</span>)</span><br><span class="line">    compute_amax(q_model, method=<span class="string">"max"</span>)</span><br></pre></td></tr></tbody></table></figure><p>Next, export it to TorchScript</p><figure class="highlight python"><table><tbody><tr><td 
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">quant_nn.TensorQuantizer.use_fb_fake_quant = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Exporting to TorchScript</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    data = <span class="built_in">iter</span>(train_dataloader)</span><br><span class="line">    images, _ = <span class="built_in">next</span>(data)</span><br><span class="line">    jit_model = torch.jit.trace(q_model, images.to(<span class="string">"cuda"</span>))</span><br><span class="line">    torch.jit.save(jit_model, <span class="string">"models/deeplabv3plus_ptq.jit.pt"</span>)</span><br><span class="line">benchmark(jit_model, input_shape=(<span class="number">16</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><p>Finally, optimize it with Torch_tensorRT</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#Loading the Torchscript model and compiling it into a TensorRT model</span></span><br><span class="line">ptq_model = torch.jit.load(<span class="string">"models/deeplabv3plus_ptq.jit.pt"</span>).<span 
class="built_in">eval</span>()</span><br><span class="line"></span><br><span class="line">compile_spec = {<span class="string">"inputs"</span>: [torch_tensorrt.Input([<span class="number">4</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>])],</span><br><span class="line">                <span class="string">"enabled_precisions"</span>: {torch.int8},</span><br><span class="line">               }</span><br><span class="line">trt_ptq = torch_tensorrt.<span class="built_in">compile</span>(ptq_model, **compile_spec)</span><br><span class="line">torch.jit.save(trt_ptq, <span class="string">"models/deeplabv3plus_ptq_trt.ts"</span>)</span><br><span class="line">benchmark(trt_ptq, input_shape=(<span class="number">16</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="ptq-qat-torchscript-torch_tensorrt"><a class="markdownIt-Anchor" href="#ptq-qat-torchscript-torch_tensorrt"></a> ptq-&gt;qat-&gt;TorchScript-&gt;Torch_tensorRT</h3><p>First, fine-tune the quantized model with QAT</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">def</span> <span class="title function_">train</span>(<span 
class="params">model, dataloader, crit, opt</span>):</span><br><span class="line">    model.train()</span><br><span class="line">    <span class="keyword">for</span> batch, (data, labels) <span class="keyword">in</span> <span class="built_in">enumerate</span>(dataloader):</span><br><span class="line">        data, labels = data.cuda(), labels.cuda()</span><br><span class="line">        opt.zero_grad()</span><br><span class="line">        outputs = model(data)</span><br><span class="line">        loss = Focal_Loss(outputs, labels)+Dice_loss(outputs, labels)</span><br><span class="line">        loss.backward()</span><br><span class="line">        opt.step()</span><br><span class="line"></span><br><span class="line">crit=torch.nn.CrossEntropyLoss()</span><br><span class="line">optimizer = torch.optim.Adam(q_model.parameters(),lr=<span class="number">1e-4</span>,weight_decay=<span class="number">0.9</span>)</span><br><span class="line">q_model=q_model.train()</span><br><span class="line"><span class="comment"># Finetune the QAT model for num_epochs epochs</span></span><br><span class="line">num_epochs=<span class="number">10</span></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> <span class="built_in">range</span>(num_epochs):</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">'Epoch: [%5d / %5d]'</span> % (epoch + <span class="number">1</span>, num_epochs))</span><br><span class="line">    train(q_model, train_dataloader, crit, optimizer)</span><br><span class="line">    test_loss,acc = evaluate(q_model, val_dataloader, crit)</span><br><span class="line">    <span class="built_in">print</span>(<span class="string">"Test Loss: {:.5f} Test acc {:.2f}%"</span>.<span class="built_in">format</span>(test_loss,acc*<span class="number">100</span>))</span><br></pre></td></tr></tbody></table></figure><p>Next, export it to TorchScript</p><figure class="highlight python"><table><tbody><tr><td 
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">quant_nn.TensorQuantizer.use_fb_fake_quant = <span class="literal">True</span></span><br><span class="line">q_model=q_model.<span class="built_in">eval</span>()</span><br><span class="line"></span><br><span class="line"><span class="comment"># Exporting to TorchScript</span></span><br><span class="line"><span class="keyword">with</span> torch.no_grad():</span><br><span class="line">    data = <span class="built_in">iter</span>(train_dataloader)</span><br><span class="line">    images, _ = <span class="built_in">next</span>(data)</span><br><span class="line">    jit_model = torch.jit.trace(q_model, images.to(<span class="string">"cuda"</span>))</span><br><span class="line">    torch.jit.save(jit_model, <span class="string">"models/deeplabv3plus_qat.jit.pt"</span>)</span><br><span class="line">benchmark(jit_model, input_shape=(<span class="number">16</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><p>Finally, optimize it with Torch_tensorRT</p><figure class="highlight python"><table><tbody><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment">#Loading the Torchscript model and compiling it into a TensorRT model</span></span><br><span 
class="line">qat_model = torch.jit.load(<span class="string">"models/deeplabv3plus_qat.jit.pt"</span>).<span class="built_in">eval</span>()</span><br><span class="line"></span><br><span class="line">compile_spec = {<span class="string">"inputs"</span>: [torch_tensorrt.Input([<span class="number">4</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>])],</span><br><span class="line">                <span class="string">"enabled_precisions"</span>: {torch.int8}</span><br><span class="line">               }</span><br><span class="line">trt_qat = torch_tensorrt.<span class="built_in">compile</span>(qat_model, **compile_spec)</span><br><span class="line">torch.jit.save(trt_qat, <span class="string">"models/deeplabv3plus_qat_trt.ts"</span>)</span><br><span class="line">benchmark(trt_qat, input_shape=(<span class="number">16</span>, <span class="number">3</span>, <span class="number">512</span>, <span class="number">512</span>), nruns=<span class="number">100</span>)</span><br></pre></td></tr></tbody></table></figure><h3 id="耗时汇总"><a class="markdownIt-Anchor" href="#耗时汇总"></a> Timing summary</h3><p>The timings of the three pipelines above are summarized below</p><table><thead><tr><th>Experiment</th><th>JIT result</th><th>JIT time (ms)</th><th>Torch_TensorRT result</th><th>Torch_TensorRT time (ms)</th></tr></thead><tbody><tr><td>base</td><td>correct</td><td>81.43</td><td>correct</td><td></td></tr><tr><td>ptq</td><td>correct</td><td>94.3</td><td>correct</td><td></td></tr><tr><td>qat</td><td>correct</td><td>94.5</td><td></td><td></td></tr></tbody></table>]]></content>
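The benchmark helper at the top of the post reduces to a warm-up phase plus timed runs whose mean and median are reported. The aggregation step can be sketched in plain Python (the GPU warm-up and `torch.cuda.synchronize()` calls are omitted; `summarize_timings` is a hypothetical helper, not part of the post's code):

```python
# Sketch of the timing aggregation used by the benchmark helper.
# Input: a list of per-run latencies in seconds; output: (mean, median) in ms.
def summarize_timings(timings):
    n = len(timings)
    mean_ms = sum(timings) / n * 1000.0
    ordered = sorted(timings)
    mid = n // 2
    if n % 2:  # odd count: take the middle element
        median_ms = ordered[mid] * 1000.0
    else:      # even count: average the two middle elements
        median_ms = (ordered[mid - 1] + ordered[mid]) / 2.0 * 1000.0
    return mean_ms, median_ms

mean_ms, median_ms = summarize_timings([0.080, 0.082, 0.081, 0.090])
print('Average batch time: %.2f ms, median: %.2f ms' % (mean_ms, median_ms))
```

Reporting the median alongside the mean is useful here because a few slow outlier runs (for example the first iterations after compilation) can skew the mean upward.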
    
    
    <summary type="html">&lt;p&gt;本文使用 Torch_TensorRT 量化 deeplapv3 + 模型，&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="D-深度学习部署" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/D-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E9%83%A8%E7%BD%B2/"/>
    
    
  </entry>
  
  <entry>
    <title>RAG Evolution 03 - Modular RAG</title>
    <link href="https://www.shaogui.life/posts/2448275523.html"/>
    <id>https://www.shaogui.life/posts/2448275523.html</id>
    <published>2025-01-23T03:41:37.000Z</published>
    <updated>2025-02-02T02:52:41.511Z</updated>
    
    <content type="html"><![CDATA[<p>As the RAG field evolves and sees more engineering applications, RAG frameworks need to become more diverse and flexible, which motivates abstracting them as Modular RAG</p><span id="more"></span><p>Modular RAG is a <strong>highly extensible paradigm</strong> that decomposes a RAG system into a three-layer structure of Module Type (Indexing) - Module (Chunk/Structural) - Operator (…)</p><p>![[认识 ModularRAG-20250123114804.png]]</p><p>With these Operators, different RAG applications can be composed; the most common patterns are the following:</p><h4 id="1-sequential"><a class="markdownIt-Anchor" href="#1-sequential"></a> 1. Sequential</h4><p>A RAG Flow with a linear structure: the modules are organized linearly into a pipeline. With the Pre-Retrieval and Post-Retrieval Module Types it is the typical Advanced RAG paradigm; with them removed it is the typical Naive RAG paradigm.</p><p>![[NextRAG-20250115222238.png]]</p><p>Sequential is currently the most widely used RAG pipeline. The most common arrangement adds a Query Rewrite operator before retrieval and a Rerank operator after retrieval, as in QAnything.</p><p>![[NextRAG-20250115222238-1.png]]</p><p>Rewrite-Retrieve-Read (RRR) is also a typical sequential structure. Its Query Rewrite module is a small trainable language model that uses the final LLM output as its reward; in a reinforcement-learning setting, optimizing the rewriter is formalized as a Markov decision process. The retriever is the sparse encoder BM25.</p><p>![[NextRAG-20250115222239.png]]</p><h4 id="2-conditional"><a class="markdownIt-Anchor" href="#2-conditional"></a> 2. Conditional</h4><p>A RAG Flow with a conditional structure selects a different RAG route depending on conditions. Routing is usually handled by a Routing module that judges on the query's keywords or semantics, typically routing different question types or applicable scenarios to different flows. For example, when users ask serious questions, political questions, or entertainment questions, the tolerance for LLM hallucination differs. Routing branches usually differ in retrieval sources, retrieval flow, configuration, model choice, and prompts.</p><p>![[NextRAG-20250115222240.png]]</p><p>A classic implementation of Conditional RAG is Semantic Router.</p><h4 id="3-branching"><a class="markdownIt-Anchor" href="#3-branching"></a> 3. Branching</h4><p>A RAG Flow with a branching structure. Unlike Conditional, which selects one route among several, Branching runs multiple branches in parallel. Structurally there are two kinds:</p><ul><li><strong>Pre-retrieval branching</strong> (Multi-Query, Parallel Retrieval). The original query is expanded into multiple sub-queries, and each sub-query is retrieved separately. After retrieval, an answer can be generated immediately from each sub-query and its retrieved content, or the expanded retrieval results can be merged into a unified context for a single generation.</li><li><strong>Post-retrieval branching (Single Query, Parallel Generation)</strong>. The original query is kept; after multiple document chunks are retrieved, generation runs in parallel on the original query paired with each chunk, and the generated results are finally merged.</li></ul><p>![[NextRAG-20250115222240-1.png]]</p><p>REPLUG is a typical post-retrieval branching structure: it predicts token probabilities in each branch, aggregates the branches through a weighted probability ensemble, and uses the final generation result as feedback to fine-tune the retriever, Contriever.</p><p>![[NextRAG-20250115222241.png]]</p><h4 id="4-loop"><a class="markdownIt-Anchor" href="#4-loop"></a> 4. Loop</h4><p>A RAG Flow with a loop structure, an important characteristic of Modular RAG: the retrieval and reasoning steps influence each other. It usually includes a Judge module to control the flow, and can be divided into iterative, recursive, and active retrieval.</p><p>![[NextRAG-20250115222242.png]]</p><h4 id="5-iterative-retrieval"><a class="markdownIt-Anchor" href="#5-iterative-retrieval"></a> 5. Iterative Retrieval</h4><p>A single round of retrieval and generation sometimes cannot solve complex questions that require extensive knowledge, so RAG can be run iteratively, usually with a fixed number of iterations. A typical example of iterative retrieval is ITER-RETGEN.</p><p>In each iteration, ITER-RETGEN uses the model output of the previous iteration as specific context to help retrieve more relevant knowledge, which may improve the generation. Termination is decided by a preset iteration count.</p><p>![[NextRAG-20250115222243.png]]</p><h4 id="6-recursive-retrieval"><a class="markdownIt-Anchor" href="#6-recursive-retrieval"></a> 6. Recursive Retrieval</h4><p>Unlike iterative retrieval, recursive retrieval explicitly depends on the previous step and keeps digging deeper, usually with a judgment mechanism as the recursion's exit. In RAG systems, recursive retrieval is typically paired with Query Transformation, each retrieval relying on a newly rewritten query.</p><p>A typical recursive retrieval implementation is ToC. Starting from an initial Ambiguous Question (AQ), it recursively applies RAC (<strong>Retrieval-Augmented Clarification</strong>) to insert child nodes into a clarification tree; at each expansion step, passages are re-ranked against the current query and a Disambiguous Question (DQ) is generated. Tree exploration ends when the maximum number of valid nodes or the maximum depth is reached. Once the clarification tree is built, ToC collects all valid nodes and generates a comprehensive long-form answer to the AQ.</p><p>![[NextRAG-20250115222244.png]]</p><h4 id="7-adaptive-active-retrieval"><a class="markdownIt-Anchor" href="#7-adaptive-active-retrieval"></a> 7. Adaptive (Active) Retrieval</h4><p>As RAG develops, it has moved beyond passive retrieval toward adaptive retrieval (also called active retrieval), partly thanks to the strong capabilities of LLMs; the core idea is similar to that of LLM agents.</p><p>The RAG system can actively decide when to retrieve and when to end the whole flow and output the final result. By the basis of that judgment, approaches divide into prompt-based and tuning-based.</p><ul><li><strong>Prompt-based.</strong> The LLM controls the flow through prompt engineering. A typical implementation is FLARE. Its core idea is that the LM should retrieve only when it lacks the required knowledge, avoiding the unnecessary or inappropriate retrieval seen in passively retrieval-augmented LMs. FLARE iteratively generates the next tentative sentence and checks whether it contains low-probability tokens; if so, the system retrieves relevant documents and regenerates the sentence.</li></ul><p>![[NextRAG-20250115222244-1.png]]</p><ul><li><strong>Tuning-based</strong>. The LLM is fine-tuned to generate special tokens that trigger retrieval or generation. This idea traces back to Toolformer, where generating special content assists tool calls; in a RAG system it is used to control the retrieval and generation steps. A typical example is Self-RAG. Concretely:</li></ul><p>(1) Given an input prompt and the preceding generation, it first predicts the special token "Retrieve" to judge whether augmenting the continuation with retrieved passages would help.<br>(2) If so, the retrieval model is called. The model generates a critique token to assess the relevance of each retrieved passage, the next response segment, and another critique token to assess whether the information in the response segment is supported by the retrieved passage.<br>(3) Finally, a new critique token evaluates the overall utility of the response. The model processes these in parallel and selects the best result as the final output.</p><p>![[NextRAG-20250115222245.png]]</p><h3 id="最佳行业案例"><a class="markdownIt-Anchor" href="#最佳行业案例"></a> Industry best practices</h3><h4 id="openai"><a class="markdownIt-Anchor" href="#openai"></a> OpenAI</h4><p>Compiled from an OpenAI Demo Day talk, so it may not fully represent OpenAI's actual practice. In this RAG success story, the OpenAI team started from 45% accuracy and tried many methods, marking which ones were ultimately adopted in production. They tried Hypothetical Document Embeddings (HyDE) and fine-tuned embeddings, with unsatisfying results. By experimenting with chunks of different sizes and embedding different content sections, they raised accuracy to 65%. Reranking and special handling for different question categories brought them to 85%. Finally, combining prompt engineering, query expansion, and other methods, they reached 98% accuracy. The team stressed the strong potential of combining model fine-tuning with RAG: even without sophisticated techniques, simple fine-tuning and prompt engineering brought them close to the industry-leading level.</p><p>![[NextRAG-20250115222248.png]]</p><h4 id="baichuan"><a class="markdownIt-Anchor" href="#baichuan"></a> Baichuan</h4><p>Compiled from Baichuan's public material (see the original). For users' increasingly complex questions, Baichuan borrowed Meta's <strong>CoVe</strong> technique to decompose a complex prompt into multiple independent, search-friendly queries that can be retrieved in parallel. Its in-house <strong>TSF (Think-Step Further)</strong> technique infers and mines the deeper questions behind the user's input to understand intent more precisely and comprehensively. For the retrieval step, Baichuan built its own Baichuan-TextEmbedding vector model and also introduced <strong>sparse retrieval</strong> and a <strong>rerank</strong> model (undisclosed), forming a hybrid of parallel dense and sparse retrieval that greatly improves recall of target documents. It further introduced <strong>self-Critique</strong>, letting the LLM reflect on the retrieved content against the prompt for relevance and usability, taking a second look to filter out the candidates that best match the prompt and are of the highest quality.</p><p>![[NextRAG-20250115222249.png]]</p><h4 id="datdatabricks"><a class="markdownIt-Anchor" href="#datdatabricks"></a> Databricks</h4><p>As a leading vendor in big data, Databricks keeps its own characteristics and strengths in RAG design (see the original). A user's question is answered by fetching related information from a pre-built text vector index and applying prompt engineering. The upper half, the Unstructured Data Pipeline, is the mainstream RAG approach, with nothing special about it.</p><p>![[NextRAG-20250115222249-1.png]]</p><p>The lower half, the Structured Data Pipeline, is Databricks' feature-engineering flow and the most distinctive part of Databricks RAG. Starting from its big-data expertise, Databricks performs additional retrieval from existing high-accuracy data stores, fully leveraging its strength in Real Time Data Serving. Databricks' strategy in the GenAI era is to take RAG applications with broad market demand and deeply fuse its powerful Lakehouse data processing with generative AI into an integrated solution.</p><p>References:</p><ol><li><a href="https://www.falkordb.com/blog/advanced-rag/">Advanced RAG Techniques: What They Are &amp; How to Use Them</a></li><li><a href="https://blog.csdn.net/TgqDT3gGaMdkHasLZv/article/details/135985279">技术动态 | 模块化（Modular）RAG 和 RAG Flow-CSDN 博客</a></li><li><a href="https://zhuanlan.zhihu.com/p/722159912">Site Unreachable</a></li></ol>]]></content>
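The Conditional flow described above routes a query to one of several flows before retrieval. A minimal keyword-based sketch (the flow names and keyword lists are illustrative assumptions, not taken from Semantic Router, which matches on embeddings rather than literal keywords):

```python
# Minimal sketch of a Conditional RAG router: pick a flow by query keywords.
# Route names and keyword lists below are illustrative, not a real taxonomy.
ROUTES = {
    "strict_flow": ["policy", "election", "legal"],   # low hallucination tolerance
    "casual_flow": ["movie", "music", "game"],        # high hallucination tolerance
}

def route_query(query, routes=ROUTES, default="default_flow"):
    words = query.lower().split()
    for flow, keywords in routes.items():
        if any(k in words for k in keywords):
            return flow
    return default

print(route_query("which election is next"))    # strict_flow
print(route_query("recommend a movie"))         # casual_flow
print(route_query("how do transformers work"))  # default_flow
```

In a production router the keyword match would be replaced by semantic similarity between the query embedding and per-route example utterances, but the control flow (one route chosen among many) is the same.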
    
    
    <summary type="html">&lt;p&gt;随着 RAG 领域发展和更多工程上的应用，RAG 框架需要更多加多样化和灵活，因此需要抽象为 ModularRAG&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>RAG Techniques 07 - Generation</title>
    <link href="https://www.shaogui.life/posts/1737521514.html"/>
    <id>https://www.shaogui.life/posts/1737521514.html</id>
    <published>2025-01-23T03:14:32.000Z</published>
    <updated>2025-02-02T02:52:41.372Z</updated>
    
    <content type="html"><![CDATA[<p>This stage answers the user's question using the retrieved results as context</p><span id="more"></span><h3 id="回答方式"><a class="markdownIt-Anchor" href="#回答方式"></a> Answering strategies</h3><p>RAG can use multiple retrieved texts in different ways, which can be grouped as follows:</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/202501142325523.png" alt=""></p><table><thead><tr><th>Combination method</th><th>How it works</th><th>Pros and cons</th></tr></thead><tbody><tr><td>concate</td><td>Concatenate all retrieved chunks and feed them to the LLM together with the question</td><td>Uses every document, but the content may be too long and get truncated; cannot adapt to arbitrary context lengths</td></tr><tr><td>Map_reduce</td><td>The LLM answers the question against each retrieved document separately, then an LLM selects the final answer</td><td>Each retrieved chunk is handled independently; questions that require combining several chunks cannot be answered</td></tr><tr><td>Refine</td><td>The LLM answers iteratively over the retrieved documents, keeping the better result each round</td><td>Builds the answer progressively, usually with good results, and adapts to arbitrary context lengths</td></tr><tr><td>Map_Rerank</td><td>The LLM answers per document and outputs a score; the highest-scoring answer is returned</td><td>Each retrieved chunk is handled independently; questions that require combining several chunks cannot be answered</td></tr></tbody></table><h3 id="self-rag"><a class="markdownIt-Anchor" href="#self-rag"></a> self-RAG</h3><p>Because this stage uses an LLM, it involves prompt design, so many prompting techniques apply here, including chain-of-thought.</p><p>Agents can even be brought in so that the LLM automatically judges the relevance of the answer to the question and then adjusts its subsequent actions. Self-RAG is such a technique: the model does not rely only on the initial retrieval, but actively re-evaluates and adjusts its approach by generating follow-up queries and responses. This iterative process lets the model correct its own mistakes, fill gaps, and improve the quality of the final output.</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E4%BD%BF%E7%94%A8Self-RAG%E8%BF%AD%E4%BB%A3%E7%AD%94%E6%A1%88-20241216140253.png" alt="使用Self-RAG迭代答案-20241216140253"></p><h3 id="评估"><a class="markdownIt-Anchor" href="#评估"></a> Evaluation</h3><p>The accuracy of the LLM-generated answer is usually assessed in two ways:</p><ul><li>Human evaluation: judge whether the answer responds to the question and how well it uses the context</li><li>Model evaluation: use an LLM to assess the answer; one can even have an LLM generate (question, context, answer) triples and evaluate the system automatically by comparing questions and answers</li></ul>]]></content>
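Of the combination strategies in the table above, Map_Rerank is the simplest to sketch: answer against each document with a score, then keep the top-scoring answer. The `llm_answer` stub below stands in for a real LLM call (an assumption for illustration, not an actual API):

```python
# Sketch of Map_Rerank: query a (stubbed) LLM once per retrieved chunk,
# collect (answer, score) pairs, and return the highest-scoring answer.
def llm_answer(question, chunk):
    # Stub standing in for an LLM call: the "score" here is just crude
    # word overlap between question and chunk, not a real model judgment.
    overlap = len(set(question.lower().split()) & set(chunk.lower().split()))
    return f"Answer based on: {chunk}", overlap

def map_rerank(question, chunks):
    scored = [llm_answer(question, c) for c in chunks]
    return max(scored, key=lambda pair: pair[1])[0]

chunks = [
    "Paris is the capital of France",
    "The Nile is a river in Africa",
]
print(map_rerank("what is the capital of France", chunks))
```

Because each chunk is scored independently, this sketch also makes the table's caveat concrete: an answer that needs information from two chunks at once can never win, since no single call sees both.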
    
    
    <summary type="html">&lt;p&gt;This stage uses the retrieved results as context to answer the user's question&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>RAG Techniques 06 - Post-Retrieval</title>
    <link href="https://www.shaogui.life/posts/3971780201.html"/>
    <id>https://www.shaogui.life/posts/3971780201.html</id>
    <published>2025-01-23T02:52:51.000Z</published>
    <updated>2025-02-02T02:52:41.346Z</updated>
    
    <content type="html"><![CDATA[<p>After retrieving context relevant to the question, feeding every retrieved chunk straight into the LLM may not be the best choice: the retrieved text can contain redundant information, or the documents may be too long and need to be compressed</p><span id="more"></span><h3 id="rerank"><a class="markdownIt-Anchor" href="#rerank"></a> Rerank</h3><p>Reorder the documents by their relevance to the query, using a reranker model to re-rank the retrieval results</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E6%A3%80%E7%B4%A2%E7%BB%93%E6%9E%9C%E9%87%8D%E6%8E%92%E5%BA%8FRerank-20241216135816.png" alt="检索结果重排序Rerank-20241216135816"></p><p>There are two families of approaches for reordering the documents</p><ul><li>ReRank-Rulebase: compute metrics and reorder the chunks according to certain rules; common metrics include diversity, relevance, and MMR (Maximal Marginal Relevance)</li><li>ReRank-Modelbase: a dedicated AI model (a reranker model) that evaluates how relevant each retrieved document is to the user query and prioritizes the documents accordingly</li></ul><h3 id="compression"><a class="markdownIt-Anchor" href="#compression"></a> Compression</h3><p>A common misconception in RAG is that retrieving as many relevant documents as possible and concatenating them into one long prompt is beneficial. In practice, excessive context introduces noise and weakens the LLM's focus on the key information. A common remedy is to compress and select from the retrieved content</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Compression-20250116093106.png" alt="Compression-20250116093106"></p><h3 id="selection"><a class="markdownIt-Anchor" href="#selection"></a> Selection</h3><p>With self-Query retrieval, the retrieved results can additionally be filtered by matching the user query against document metadata</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/%E8%87%AA%E6%9F%A5%E8%AF%A2selfQuery-20241216135424.png" alt="自查询selfQuery-20241216135424"></p><p>CRAG takes a different approach, introducing a lightweight retrieval evaluator that assesses the overall quality of the retrieved documents and outputs a confidence level that triggers different knowledge-retrieval actions ("Correct", "Incorrect", or "Ambiguous")</p><p>CRAG can also incorporate web search to determine whether the retrieved results are relevant, addressing the limitations of a static corpus</p><p><img data-src="https://picgo-1304919305.cos.ap-guangzhou.myqcloud.com/picGo/Selection-20250116093147.png" alt="Selection-20250116093147"></p>]]></content>
    
    
    <summary type="html">&lt;p&gt;After retrieving context relevant to the question, feeding every retrieved chunk straight into the LLM may not be the best choice: the retrieved text can contain redundant information, or the documents may be too long and need to be compressed&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
  <entry>
    <title>BM25 算法原理探索</title>
    <link href="https://www.shaogui.life/posts/12153461.html"/>
    <id>https://www.shaogui.life/posts/12153461.html</id>
    <published>2025-01-23T01:50:47.000Z</published>
    <updated>2025-02-02T02:52:41.148Z</updated>
    
    <content type="html"><![CDATA[<p>在 RAG 技术中，有使用 BM25 算法去检索相关文档，那么他的原理是怎样的？</p><span id="more"></span><p>BM25（Best Matching 25）是一种用于信息检索的算法，广泛应用于搜索引擎和文本挖掘领域。它是对 TF-IDF（词频 - 逆文档频率）模型的改进，主要通过引入文档长度归一化和词频饱和机制，更准确地评估文档与查询之间的相关性</p><h3 id="原理"><a class="markdownIt-Anchor" href="#原理"></a> 原理</h3><ol><li><strong>逆文档频率（IDF）</strong><br>用于衡量一个词的 “稀有性”。如果一个词在较少的文档中出现，其 IDF 值越高，表明这个词具有更好的区分能力。</li></ol><p class="katex-block"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>I</mi><mi>D</mi><mi>F</mi><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mo>=</mo><mi>log</mi><mo>⁡</mo><mrow><mo fence="true">(</mo><mfrac><mrow><mi>N</mi><mo>−</mo><mi>n</mi><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mo>+</mo><mn>0.5</mn></mrow><mrow><mi>n</mi><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mo>+</mo><mn>0.5</mn></mrow></mfrac><mo fence="true">)</mo></mrow></mrow><annotation encoding="application/x-tex">IDF(q_i) = \log \left( \frac{N - n(q_i) + 0.5}{n(q_i) + 0.5} \right)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord mathnormal" style="margin-right:0.13889em;">F</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 
size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size3">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.427em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal">n</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">0</span><span 
class="mord">.</span><span class="mord">5</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.10903em;">N</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord mathnormal">n</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">0</span><span class="mord">.</span><span class="mord">5</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.936em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size3">)</span></span></span></span></span></span></span></p><p>其中，N 是文档总数，n (qi​) 是包含词 qi​ 的文档数</p><ol start="2"><li><strong>词频（TF）调整</strong><br>BM25 引入了词频饱和机制，避免长文档因词频过高而获得不合理的高分</li></ol><p 
class="katex-block"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>B</mi><mi>M</mi><mn>25</mn><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo separator="true">,</mo><mi>d</mi><mo stretchy="false">)</mo><mo>=</mo><mtext>IDF</mtext><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo stretchy="false">)</mo><mo>⋅</mo><mfrac><mrow><mo stretchy="false">(</mo><msub><mi>k</mi><mn>1</mn></msub><mo>+</mo><mn>1</mn><mo stretchy="false">)</mo><mo>⋅</mo><mtext>TF</mtext><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo separator="true">,</mo><mi>d</mi><mo stretchy="false">)</mo></mrow><mrow><mtext>TF</mtext><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo separator="true">,</mo><mi>d</mi><mo stretchy="false">)</mo><mo>+</mo><msub><mi>k</mi><mn>1</mn></msub><mo>⋅</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mi>b</mi><mo>+</mo><mi>b</mi><mo>⋅</mo><mfrac><mi>d</mi><mtext>avgdi</mtext></mfrac><mo stretchy="false">)</mo></mrow></mfrac></mrow><annotation encoding="application/x-tex">BM25(q_i,d) = \text{IDF}(q_i) \cdot \frac{(k_1 + 1) \cdot \text{TF}(q_i,d)}{\text{TF}(q_i,d) + k_1 \cdot (1 - b + b \cdot \frac{d}{\text{avgdi}})}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.05017em;">B</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mord">2</span><span class="mord">5</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span 
style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord text"><span class="mord">IDF</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span></span><span class="base"><span class="strut" style="height:2.678216em;vertical-align:-1.2512159999999999em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" 
style="height:1.427em;"><span style="top:-2.229892em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord text"><span class="mord">TF</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mopen">(</span><span 
class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord mathnormal">b</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">avgdi</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.481108em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" 
style="margin-right:0.03148em;">k</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.30110799999999993em;"><span style="top:-2.5500000000000003em;margin-left:-0.03148em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">1</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">⋅</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord text"><span class="mord">TF</span></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" 
style="height:1.2512159999999999em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p><p>其中 <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>T</mi><mi>F</mi><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo separator="true">,</mo><mi>d</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">TF(q_i,d)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.13889em;">T</span><span class="mord mathnormal" style="margin-right:0.13889em;">F</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mclose">)</span></span></span></span> 表示词汇 qi 在文档 d 中出现的次数，k1 和 b 是可调参数，|d | 是文档长度，avgdl 是文档平均长度</p><ol start="3"><li><p><strong>文档长度归一化</strong><br>BM25 通过文档长度归一化处理，使得文档长度对权重的影响是非线性的。这避免了长文档仅因词数多而得分更高的问题，即以上公式的 <span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mfrac><mi>d</mi><mtext>avgdi</mtext></mfrac></mrow><annotation 
encoding="application/x-tex">\frac{d}{\text{avgdi}}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1.3612159999999998em;vertical-align:-0.481108em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.8801079999999999em;"><span style="top:-2.6550000000000002em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord text mtight"><span class="mord mtight">avgdi</span></span></span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.394em;"><span class="pstrut" style="height:3em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.481108em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></p></li><li><p><strong>查询与文档的相关性</strong><br>对于一个查询 Q 和文档 d，BM25 的最终得分是查询中每个词与文档相关性的累加</p></li></ol><p class="katex-block"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>B</mi><mi>M</mi><mn>25</mn><mo stretchy="false">(</mo><mi>Q</mi><mo separator="true">,</mo><mi>d</mi><mo stretchy="false">)</mo><mo>=</mo><munderover><mo>∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>n</mi></munderover><mi>B</mi><mi>M</mi><mn>25</mn><mo stretchy="false">(</mo><msub><mi>q</mi><mi>i</mi></msub><mo separator="true">,</mo><mi>d</mi><mo stretchy="false">)</mo></mrow><annotation 
encoding="application/x-tex">BM25(Q,d) = \sum_{i=1}^{n} BM25(q_i,d) </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.05017em;">B</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mord">2</span><span class="mord">5</span><span class="mopen">(</span><span class="mord mathnormal">Q</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.929066em;vertical-align:-1.277669em;"></span><span class="mop op-limits"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.6513970000000002em;"><span style="top:-1.872331em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">i</span><span class="mrel mtight">=</span><span class="mord mtight">1</span></span></span></span><span style="top:-3.050005em;"><span class="pstrut" style="height:3.05em;"></span><span><span class="mop op-symbol large-op">∑</span></span></span><span style="top:-4.3000050000000005em;margin-left:0em;"><span class="pstrut" style="height:3.05em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">n</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:1.277669em;"><span></span></span></span></span></span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.05017em;">B</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mord">2</span><span class="mord">5</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right:0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:0.31166399999999994em;"><span style="top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em;"><span class="pstrut" style="height:2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">i</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.15em;"><span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal">d</span><span class="mclose">)</span></span></span></span></span></p><p>n 表示 Q 中的词汇数量</p><h3 id="例子"><a class="markdownIt-Anchor" href="#例子"></a> 例子</h3><p>假设我们有以下文档集合（3 篇文档）和一个查询：</p><p><strong>文档集合：</strong></p><ol><li><strong>文档 D1</strong>：<code>"苹果 苹果 苹果"</code>（3 个词，文档长度为 3）</li><li><strong>文档 D2</strong>：<code>"苹果"</code>（1 个词，文档长度为 1）</li><li><strong>文档 D3</strong>：<code>"苹果 香蕉"</code>（2 个词，文档长度为 2）</li></ol><p><strong>查询 Q</strong>：<code>"苹果"</code></p><ol><li><strong>计算 IDF</strong><br>首先，我们需要计算查询词 “苹果” 的逆文档频率（IDF）</li></ol><p class="katex-block"><span class="katex-display"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>I</mi><mi>D</mi><mi>F</mi><mo stretchy="false">(</mo><mi mathvariant="normal">"</mi><mtext> 苹果</mtext><mi mathvariant="normal"> "</mi><mo stretchy="false">)</mo><mtext> 的</mtext><mi> log</mi><mo>⁡</mo><mrow><mo 
fence="true">(</mo><mfrac><mrow><mn>3</mn><mo>−</mo><mn>3</mn><mo>+</mo><mn>0.5</mn></mrow><mrow><mn>3</mn><mo>+</mo><mn>0.5</mn></mrow></mfrac><mo fence="true">)</mo></mrow><mo>=</mo><mi>log</mi><mo>⁡</mo><mrow><mo fence="true">(</mo><mfrac><mn>0.5</mn><mn>3.5</mn></mfrac><mo fence="true">)</mo></mrow><mo>=</mo><mi>log</mi><mo>⁡</mo><mo stretchy="false">(</mo><mn>0.1429</mn><mo stretchy="false">)</mo><mo>≈</mo><mo>−</mo><mn>1.90</mn></mrow><annotation encoding="application/x-tex">IDF ("苹果") 的 \log \left ( \frac {3 - 3 + 0.5}{3 + 0.5} \right) = \log \left ( \frac {0.5}{3.5} \right) = \log (0.1429) \approx -1.90 </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mord mathnormal" style="margin-right:0.07847em;">I</span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord mathnormal" style="margin-right:0.13889em;">F</span><span class="mopen">(</span><span class="mord">"</span><span class="mord cjk_fallback"> 苹</span><span class="mord cjk_fallback">果</span><span class="mord"> "</span><span class="mclose">)</span><span class="mord cjk_fallback"> 的</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mop"> lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size3">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" 
style="margin-right:0.2222222222222222em;"></span><span class="mord">0</span><span class="mord">.</span><span class="mord">5</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">0</span><span class="mord">.</span><span class="mord">5</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.7693300000000001em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size3">)</span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.40003em;vertical-align:-0.95003em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="minner"><span class="mopen delimcenter" style="top:0em;"><span class="delimsizing size3">(</span></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span 
class="mord"><span class="mord">3</span><span class="mord">.</span><span class="mord">5</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">0</span><span class="mord">.</span><span class="mord">5</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose delimcenter" style="top:0em;"><span class="delimsizing size3">)</span></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mop">lo<span style="margin-right:0.01389em;">g</span></span><span class="mopen">(</span><span class="mord">0</span><span class="mord">.</span><span class="mord">1</span><span class="mord">4</span><span class="mord">2</span><span class="mord">9</span><span class="mclose">)</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">≈</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.72777em;vertical-align:-0.08333em;"></span><span class="mord">−</span><span class="mord">1</span><span class="mord">.</span><span class="mord">9</span><span class="mord">0</span></span></span></span></span></p><p>注意：这里的 IDF 值是负数，因为 “苹果” 在所有文档中都出现了。在实际应用中，IDF 通常取正值，可以通过调整公式或对结果取绝对值来处理。</p><ol start="2"><li><strong>计算文档长度相关参数</strong></li></ol><p>接下来，计算文档长度相关的参数</p><p class="katex-block"><span class="katex-display"><span class="katex"><span 
class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML" display="block"><semantics><mrow><mi>a</mi><mi>v</mi><mi>g</mi><mi>d</mi><mi>l</mi><mo>=</mo><mfrac><mrow><mi>D</mi><mn>1</mn><mo>+</mo><mi>D</mi><mn>2</mn><mo>+</mo><mi>D</mi><mn>3</mn></mrow><mn>3</mn></mfrac><mo>=</mo><mfrac><mrow><mn>3</mn><mo>+</mo><mn>1</mn><mo>+</mo><mn>2</mn></mrow><mn>3</mn></mfrac><mo>=</mo><mn>2</mn></mrow><annotation encoding="application/x-tex">avgdl = \frac{D1 + D2 + D3} {3} = \frac{3 + 1 + 2}{3} = 2 </annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.8888799999999999em;vertical-align:-0.19444em;"></span><span class="mord mathnormal">a</span><span class="mord mathnormal" style="margin-right:0.03588em;">v</span><span class="mord mathnormal" style="margin-right:0.03588em;">g</span><span class="mord mathnormal">d</span><span class="mord mathnormal" style="margin-right:0.01968em;">l</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.04633em;vertical-align:-0.686em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.36033em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">3</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span 
class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord">2</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord">3</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:2.00744em;vertical-align:-0.686em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height:1.32144em;"><span style="top:-2.314em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">3</span></span></span><span style="top:-3.23em;"><span class="pstrut" style="height:3em;"></span><span class="frac-line" style="border-bottom-width:0.04em;"></span></span><span style="top:-3.677em;"><span class="pstrut" style="height:3em;"></span><span class="mord"><span class="mord">3</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">1</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right:0.2222222222222222em;"></span><span class="mord">2</span></span></span></span><span class="vlist-s">​</span></span><span 
class="vlist-r"><span class="vlist" style="height:0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mspace" style="margin-right:0.2777777777777778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right:0.2777777777777778em;"></span></span><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord">2</span></span></span></span></span></p><ol start="3"><li><strong>Compute the BM25 scores</strong><br>Assuming the parameters k1 = 1.2 and b = 0.75, we compute each document’s BM25 score separately:</li></ol><table><thead><tr><th>Computation</th><th>Value</th></tr></thead><tbody><tr><td><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi> B</mi><mi>M</mi><mn>25</mn><mo stretchy="false">(</mo><mi mathvariant="normal">"</mi><mtext> 苹果</mtext><mi mathvariant="normal"> "</mi><mo separator="true">,</mo><mi>D</mi><mn>1</mn><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">BM25 ("苹果",D1)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.05017em;">B</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mord">2</span><span class="mord">5</span><span class="mopen">(</span><span class="mord">"</span><span class="mord cjk_fallback"> 苹</span><span class="mord cjk_fallback">果</span><span class="mord"> "</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord">1</span><span class="mclose">)</span></span></span></span></td><td>2.62</td></tr><tr><td><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>B</mi><mi>M</mi><mn>25</mn><mo stretchy="false">(</mo><mi 
mathvariant="normal">"</mi><mtext> 苹果</mtext><mi mathvariant="normal"> "</mi><mo separator="true">,</mo><mi>D</mi><mn>2</mn><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">BM25 ("苹果",D2)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.05017em;">B</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mord">2</span><span class="mord">5</span><span class="mopen">(</span><span class="mord">"</span><span class="mord cjk_fallback"> 苹</span><span class="mord cjk_fallback">果</span><span class="mord"> "</span><span class="mpunct">,</span><span class="mspace" style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord">2</span><span class="mclose">)</span></span></span></span></td><td>2.41</td></tr><tr><td><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>B</mi><mi>M</mi><mn>25</mn><mo stretchy="false">(</mo><mi mathvariant="normal">"</mi><mtext> 苹果</mtext><mi mathvariant="normal"> "</mi><mo separator="true">,</mo><mi>D</mi><mn>3</mn><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">BM25 ("苹果",D3)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:1em;vertical-align:-0.25em;"></span><span class="mord mathnormal" style="margin-right:0.05017em;">B</span><span class="mord mathnormal" style="margin-right:0.10903em;">M</span><span class="mord">2</span><span class="mord">5</span><span class="mopen">(</span><span class="mord">"</span><span class="mord cjk_fallback"> 苹</span><span class="mord cjk_fallback">果</span><span class="mord"> "</span><span class="mpunct">,</span><span class="mspace" 
style="margin-right:0.16666666666666666em;"></span><span class="mord mathnormal" style="margin-right:0.02778em;">D</span><span class="mord">3</span><span class="mclose">)</span></span></span></span></td><td>1.90</td></tr></tbody></table><p>D3 contains both “苹果” (apple) and “香蕉” (banana), while the query only asks about “苹果”, so D3 is less relevant than D1 and D2.</p><blockquote><p>Note: in practice, a query Q usually contains several terms, and a document’s final score is the sum of the BM25 scores of the individual query terms.</p></blockquote>]]></content>
    
    
    <summary type="html">&lt;p&gt;RAG systems often use the BM25 algorithm to retrieve relevant documents. How does it actually work?&lt;/p&gt;</summary>
    
    
    
    <category term="2-深度学习" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/"/>
    
    <category term="LLM开发工程师指南" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/"/>
    
    <category term="RAG" scheme="https://www.shaogui.life/categories/2-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0/LLM%E5%BC%80%E5%8F%91%E5%B7%A5%E7%A8%8B%E5%B8%88%E6%8C%87%E5%8D%97/RAG/"/>
    
    
  </entry>
  
</feed>
