Model Quantization with PyTorch

PyTorch native quantization: FX Graph Mode Quantization?

  • FX Graph Mode Quantization is a newer automated quantization framework in PyTorch, currently a prototype feature. It improves on Eager Mode Quantization by adding support for functionals and by automating the quantization workflow; a short sketch of the functional support follows the code example below.
import torch
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

# float_model and data_loader_test are assumed to be defined elsewhere
float_model.eval()  # this is PTQ, so inference mode is enough
qconfig = get_default_qconfig("fbgemm")  # specify the quantization configuration details
qconfig_dict = {"": qconfig}  # apply this qconfig to the whole model

def calibrate(model, data_loader):  # calibration helper
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)

prepared_model = prepare_fx(float_model, qconfig_dict)  # prepare the model: fuse modules such as Conv+BN+ReLU, then insert observer nodes
calibrate(prepared_model, data_loader_test)  # run the calibration dataset through the observed model
quantized_model = convert_fx(prepared_model)  # convert the calibrated model into its quantized version
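
To make the "support for functionals" point above concrete, here is a minimal sketch: because prepare_fx symbolically traces the model, free functions such as torch.nn.functional.relu and operators such as + appear in the graph and are quantized automatically, with no QuantStub/DeQuantStub needed. FuncModel is a hypothetical toy model, and the prepare_fx(model, qconfig_dict) call assumes the same pre-1.13 API as the example above:

import torch
import torch.nn.functional as F
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class FuncModel(torch.nn.Module):  # hypothetical toy model for illustration
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        # F.relu and + are a functional and an operator: Eager Mode cannot
        # attach observers to them, but FX rewrites the traced graph so
        # they get quantized like module calls
        return F.relu(self.conv(x)) + x

m = FuncModel().eval()
qconfig_dict = {"": get_default_qconfig("fbgemm")}
prepared = prepare_fx(m, qconfig_dict)   # observers are inserted around the functional ops too
prepared(torch.randn(1, 1, 4, 4))        # one-batch calibration pass
quantized = convert_fx(prepared)
print(quantized)                         # the printed graph shows quantized conv/add ops

In Eager Mode, by contrast, neither F.relu nor + could be observed, because observers can only be attached to module instances.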

PyTorch native quantization: Eager Mode Quantization?

  • Eager Mode Quantization is a beta feature. The user has to perform module fusion manually and specify by hand where quantization and dequantization happen, and it only supports modules, not functionals; a short sanity-check sketch follows the example below.
import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 1, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually mark the layer where quantization starts
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually mark the layer where quantization ends
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()
# model must be set to eval mode for static quantization logic to work
model_fp32.eval()
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
# manually specify which layers to fuse
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)  # insert observers
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)  # calibrate on sample input
model_int8 = torch.quantization.convert(model_fp32_prepared)  # convert to the quantized model
res = model_int8(input_fp32)  # run inference with the int8 model
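
As a quick sanity check of the conversion mentioned above, here is a sketch reusing the names from this example; the exact fused class name depends on the PyTorch version, and fuse_modules/prepare/convert are assumed to use their default inplace=False, so model_fp32 remains a working float model:

# after convert, 'conv' and 'relu' are replaced by one fused quantized module,
# e.g. torch.nn.intrinsic.quantized.ConvReLU2d
print(type(model_int8.conv))
print(model_int8.conv.weight().dtype)  # torch.qint8: weights are stored quantized

# outputs should stay close to the float model if calibration was representative
with torch.no_grad():
    res_fp32 = model_fp32(input_fp32)  # QuantStub/DeQuantStub are identity before prepare
    print((res_fp32 - res).abs().max())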
