基于 Torch-TensorRT 量化模型

Torch-TensorRT 是 PyTorch 与 NVIDIA TensorRT 的新集成,只需一行代码即可加速推理。

基于Torch-TensorRT量化模型-20250123172119

Torch-TensorRT 是 PyTorch/TorchScript/FX 的编译器。与 PyTorch 的即时 (JIT) 编译器不同,Torch-TensorRT 是一个提前 (AOT) 编译器,这意味着在部署 TorchScript 代码之前,您需要执行一个明确的编译步骤,将标准 TorchScript 或 FX 程序转换为面向 TensorRT 引擎的模块。

使用 Torch-TensorRT 编译 ResNet50

1. 加载并测试模型

1
2
3
4
5
6
7
8
9
10
11
import torch
import torchvision

# Skip torch.hub's "repo is not a fork" validation by patching a private helper.
# NOTE(review): presumably a workaround for a torch.hub loading failure —
# confirm it is still needed for the installed torch version.
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Load a pretrained ResNet-50 from the pytorch/vision v0.10.0 hub entry and
# switch it to inference mode.
resnet50_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
resnet50_model.eval()

# Baseline: benchmark the plain PyTorch model (no Torch-TensorRT) on the GPU.
# nn.Module.to() moves the module in place and returns it, so `model` is the
# same (already eval-mode) module as `resnet50_model`.
# `benchmark` is defined elsewhere (not shown in this post).
model = resnet50_model.to("cuda")
benchmark(model, input_shape=(128, 3, 224, 224), nruns=100)

基于Torch-TensorRT量化模型-20250123172120

2. 编译模型 - fp32

1
2
3
4
5
6
7
8
9
10
11
import torch_tensorrt

# enabled_precisions is the set of kernel precisions TensorRT may choose from;
# with only FP32 enabled, the compiled module runs at FP32 precision.
trt_model_fp32 = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((128, 3, 224, 224), dtype=torch.float32)],
    enabled_precisions={torch.float32},  # Run with FP32
    workspace_size=1 << 22,  # 1 << 22 bytes = 4 MiB of TensorRT builder workspace
)

# Obtain the average time taken by a batch of input.
benchmark(trt_model_fp32, input_shape=(128, 3, 224, 224), nruns=100)

基于Torch-TensorRT量化模型-20250123172120-1

3. 编译模型 - fp16

1
2
3
4
5
6
7
8
9
10
11
import torch_tensorrt

# Compile the same ResNet-50 for FP16: the engine takes half-precision inputs
# and TensorRT is allowed to pick FP16 kernels.
trt_model_fp16 = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((128, 3, 224, 224), dtype=torch.half)],
    enabled_precisions={torch.half},  # Run with FP16
    workspace_size=1 << 22,  # 1 << 22 bytes = 4 MiB of TensorRT builder workspace
)

# Obtain the average time taken by a batch of input.
# dtype='fp16' presumably makes benchmark() (defined elsewhere) feed
# half-precision inputs that match the engine's input spec — confirm.
benchmark(trt_model_fp16, input_shape=(128, 3, 224, 224), dtype='fp16', nruns=100)

基于Torch-TensorRT量化模型-20250123172121

使用 Torch-TensorRT 编译 TorchScript 模型

1. 加载并测试模型

1
2
3
# Build LeNet on the GPU in inference mode and benchmark it in plain PyTorch.
# (LeNet and benchmark are defined elsewhere, not shown in this post.)
model = LeNet().to("cuda").eval()
benchmark(model)

基于Torch-TensorRT量化模型-20250123172121-1

2. 生成 trace 模型,并测试速度

1
2
# Trace LeNet with one dummy 1x1x32x32 tensor allocated on the GPU (its values
# are irrelevant for tracing), then benchmark the traced TorchScript module.
traced_model = torch.jit.trace(model, torch.empty((1, 1, 32, 32), device="cuda"))
benchmark(traced_model)

基于Torch-TensorRT量化模型-20250123172121-2

3. 生成 script 模型,并测试速度

1
2
3
# Create a fresh LeNet on the GPU, put it in inference mode, compile it with
# torch.jit.script and benchmark the scripted module.
model = LeNet().to("cuda")
model.eval()
script_model = torch.jit.script(model)
benchmark(script_model)

基于Torch-TensorRT量化模型-20250123172122

4. 使用 Torch-TensorRT 编译 trace 模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch_tensorrt

# Compile the traced LeNet for a batch size of 1024 in half precision.
# NOTE(review): the spatial size is declared dynamic (32..34, optimised for 33),
# yet every input used below is 32x32. If LeNet's fully connected layers expect
# a fixed feature size, only some of these shapes are actually valid — confirm.
trt_ts_module = torch_tensorrt.compile(
    traced_model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1024, 1, 32, 32],
            opt_shape=[1024, 1, 33, 33],
            max_shape=[1024, 1, 34, 34],
            dtype=torch.half,
        )
    ],
    enabled_precisions={torch.half},
)

# Random FP16 batch on the GPU, matching the engine's declared input dtype.
input_data = torch.randn((1024, 1, 32, 32)).half().to("cuda")
result = trt_ts_module(input_data)
torch.jit.save(trt_ts_module, "trt_ts_module.ts")

benchmark(trt_ts_module, input_shape=(1024, 1, 32, 32), dtype="fp16")

基于Torch-TensorRT量化模型-20250123172122-1

5. 使用 Torch-TensorRT 编译 script 模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch_tensorrt

# Compile the scripted LeNet for a batch size of 1024 in half precision.
# NOTE(review): the spatial size is declared dynamic (32..34, optimised for 33),
# yet every input used below is 32x32. If LeNet's fully connected layers expect
# a fixed feature size, only some of these shapes are actually valid — confirm.
trt_script_module = torch_tensorrt.compile(
    script_model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1024, 1, 32, 32],
            opt_shape=[1024, 1, 33, 33],
            max_shape=[1024, 1, 34, 34],
            dtype=torch.half,
        )
    ],
    enabled_precisions={torch.half},
)

# Random FP16 batch on the GPU, matching the engine's declared input dtype.
input_data = torch.randn((1024, 1, 32, 32)).half().to("cuda")
result = trt_script_module(input_data)
torch.jit.save(trt_script_module, "trt_script_module.ts")

benchmark(trt_script_module, input_shape=(1024, 1, 32, 32), dtype="fp16")

基于Torch-TensorRT量化模型-20250123172123

使用 Torch-TensorRT 进行 PTQ 量化

1. 构建基线模型并使用 Torch-TensorRT 转换(不进行量化)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Export the trained baseline model to TorchScript by tracing it on one
# validation batch. Switch to eval mode first so the traced graph records
# inference behaviour (e.g. of BatchNorm/Dropout layers) rather than training mode.
model.eval()
with torch.no_grad():
    # Use the built-in next(): DataLoader iterators in recent PyTorch releases
    # no longer provide a Python-2-style .next() method.
    images, _ = next(iter(val_dataloader))
    jit_model = torch.jit.trace(model, images.to("cuda"))
    torch.jit.save(jit_model, "mobilenetv2_base.jit.pt")

# Load the TorchScript model and compile it into an FP32 TensorRT module.
baseline_model = torch.jit.load("mobilenetv2_base.jit.pt").eval()
compile_spec = {
    "inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
    "enabled_precisions": {torch.float},
}
trt_base = torch_tensorrt.compile(baseline_model, **compile_spec)

# Evaluate and benchmark the performance of the baseline TRT model (TRT FP32 model).
# NOTE(review): the engine is built for a fixed batch of 64, so val_dataloader
# presumably has to yield only full batches of 64 (e.g. drop_last=True) — confirm.
test_loss, test_acc = evaluate(trt_base, val_dataloader, criterion, 0)
print("Mobilenetv2 TRT Baseline accuracy: {:.2f}%".format(100 * test_acc))

benchmark(trt_base, input_shape=(64, 3, 224, 224))

基于Torch-TensorRT量化模型-20250123172123-1

2. 使用 Torch-TensorRT 进行 PTQ 量化

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Build an INT8 calibrator that feeds batches from calib_dataloader to
# TensorRT's min-max calibration algorithm (no calibration cache is used).
calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    calib_dataloader,
    use_cache=False,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.MINMAX_CALIBRATION,
    device=torch.device('cuda:0'),
)

compile_spec = {
    "inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
    "enabled_precisions": {torch.int8},
    "calibrator": calibrator,
    # Truncate int64 / float64 values to int32 / float32 during conversion.
    "truncate_long_and_double": True,
}
trt_ptq = torch_tensorrt.compile(baseline_model, **compile_spec)

# Evaluate the PTQ model.
# NOTE(review): the engine is built for a fixed batch of 64, so val_dataloader
# presumably has to yield only full batches of 64 — confirm.
test_loss, test_acc = evaluate(trt_ptq, val_dataloader, criterion, 0)
print("Mobilenetv2 PTQ accuracy: {:.2f}%".format(100 * test_acc))

benchmark(trt_ptq, input_shape=(64, 3, 224, 224))

基于Torch-TensorRT量化模型-20250123172123-2

使用 Torch-TensorRT 进行 QAT 量化

1. 定义并加载模型权重

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Must run before the model is constructed: afterwards the regular conv and FC
# layers are created as their quantized counterparts.
quant_modules.initialize()

# Define the model and load its weights.
feature_extract = False
q_model = models.mobilenet_v2(pretrained=True)
set_parameter_requires_grad(q_model, feature_extract)
# Replace the classifier's final layer with a 10-way output (1280 input features).
q_model.classifier[1] = nn.Linear(1280, 10)
q_model = q_model.cuda()

# mobilenetv2_base_ckpt is the checkpoint produced by the baseline MobileNetV2
# training step.
ckpt = torch.load("./mobilenetv2_base_ckpt")

# Remove the "module." prefix from parameter names (presumably added by an
# nn.DataParallel wrapper during baseline training). Match the full prefix,
# dot included, so keys that merely start with "module" are left intact instead
# of losing their first 7 characters.
modified_state_dict = {}
for key, val in ckpt["model_state_dict"].items():
    if key.startswith("module."):
        key = key[len("module."):]
    modified_state_dict[key] = val

# Load the pre-trained checkpoint.
q_model.load_state_dict(modified_state_dict)
# NOTE(review): `optimizer` is created elsewhere; it must have been built over
# q_model.parameters() for this restored state to apply to the QAT model.
optimizer.load_state_dict(ckpt["opt_state_dict"])

2. 定义校准规则并校准,这里使用 max 校准

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# Define the calibration rule and calibrate; max calibration is used here.
def compute_amax(model, **kwargs):
    """Load the collected calibration result (amax) into each calibrated quantizer.

    Only ``quant_nn.TensorQuantizer`` modules that own a calibrator are touched.
    ``calib.MaxCalibrator`` instances are loaded without arguments; any other
    calibrator receives ``**kwargs`` (e.g. ``method=...``). The model is moved
    to CUDA afterwards.

    Args:
        model: Network containing ``quant_nn.TensorQuantizer`` modules.
        **kwargs: Forwarded to ``load_calib_amax`` for non-max calibrators.
    """
    calibrated_quantizers = (
        mod
        for mod in model.modules()
        if isinstance(mod, quant_nn.TensorQuantizer) and mod._calibrator is not None
    )
    for quantizer in calibrated_quantizers:
        uses_max = isinstance(quantizer._calibrator, calib.MaxCalibrator)
        if uses_max:
            quantizer.load_calib_amax()
        else:
            quantizer.load_calib_amax(**kwargs)
    model.cuda()

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect calibration statistics.

    While collecting, every ``quant_nn.TensorQuantizer`` that owns a calibrator
    has quantization disabled and calibration enabled; quantizers without a
    calibrator are disabled entirely. Afterwards, and also if an exception is
    raised mid-way, all quantizers are switched back to normal quantized operation.

    Args:
        model: Network containing ``quant_nn.TensorQuantizer`` modules.
        data_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        num_batches: Maximum number of batches to feed through ``model``.
    """
    import itertools

    _set_calibration_mode(model, calibrating=True)
    try:
        # islice stops after exactly num_batches batches (or earlier if the
        # loader is shorter). Checking `i >= num_batches` after the forward pass
        # would feed num_batches + 1 batches.
        batches = itertools.islice(data_loader, num_batches)
        for image, _ in tqdm(batches, total=num_batches):
            model(image.cuda())
    finally:
        _set_calibration_mode(model, calibrating=False)


def _set_calibration_mode(model, calibrating):
    """Switch every ``quant_nn.TensorQuantizer`` in ``model`` into or out of calibration mode."""
    for module in model.modules():
        if not isinstance(module, quant_nn.TensorQuantizer):
            continue
        if module._calibrator is None:
            # Quantizers without a calibrator are bypassed while calibrating.
            if calibrating:
                module.disable()
            else:
                module.enable()
        elif calibrating:
            module.disable_quant()
            module.enable_calib()
        else:
            module.enable_quant()
            module.disable_calib()


# Calibrate the model with max calibration: collect quantizer statistics on
# 32 training batches, then load the resulting amax values into the quantizers.
with torch.no_grad():
    collect_stats(q_model, train_dataloader, num_batches=32)
    compute_amax(q_model, method="max")

3. 微调 QAT 模型 2 个 epoch

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Fine-tune the QAT model for 2 epochs.
num_epochs = 2
lr = 0.001

# optimizer.load_state_dict() restores the saved param_groups, including their
# learning rate, so the LR from the baseline checkpoint would otherwise stay in
# effect and `lr` would only be printed. Apply the fine-tuning LR explicitly.
# NOTE(review): assumes train() does not manage the learning rate itself — confirm.
for param_group in optimizer.param_groups:
    param_group["lr"] = lr

for epoch in range(num_epochs):
    print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, num_epochs, lr))

    train(q_model, train_dataloader, criterion, optimizer, epoch)
    test_loss, test_acc = evaluate(q_model, val_dataloader, criterion, epoch)

    print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

    # Checkpoint after every epoch so the latest QAT weights are always on disk.
    save_checkpoint({'epoch': epoch + 1,
                     'model_state_dict': q_model.state_dict(),
                     'acc': test_acc,
                     'opt_state_dict': optimizer.state_dict()
                     },
                    ckpt_path="mobilenetv2_qat_ckpt")

4. 导出 QAT 模型,得到 Torchscript 模型

1
2
3
4
5
6
7
# Export the QAT model to TorchScript.
# use_fb_fake_quant makes TensorQuantizer use PyTorch's native fake-quantize
# ops, so they are what gets recorded in the traced graph.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
# Trace in eval mode so inference behaviour (e.g. of BatchNorm/Dropout layers)
# is what gets recorded.
q_model.eval()
with torch.no_grad():
    # Use the built-in next(): DataLoader iterators in recent PyTorch releases
    # no longer provide a Python-2-style .next() method.
    images, _ = next(iter(val_dataloader))
    jit_model = torch.jit.trace(q_model, images.to("cuda"))
    torch.jit.save(jit_model, "mobilenetv2_qat.jit.pt")

5. 加载 Torchscript 模型并编译为 TensorRT 模型

1
2
3
4
5
6
7
8
9
10
11
# Load the TorchScript QAT model and compile it into an INT8 TensorRT module.
# Unlike the PTQ flow, no calibrator is passed: the quantization parameters come
# from the fake-quantize ops recorded when the QAT model was exported.
qat_model = torch.jit.load("mobilenetv2_qat.jit.pt").eval()
compile_spec = {
    "inputs": [torch_tensorrt.Input([64, 3, 224, 224])],
    "enabled_precisions": {torch.int8},
}
trt_mod = torch_tensorrt.compile(qat_model, **compile_spec)

# Evaluate and benchmark the performance of the QAT-TRT model (TRT INT8).
# NOTE(review): the engine is built for a fixed batch of 64, so val_dataloader
# presumably has to yield only full batches of 64 — confirm.
test_loss, test_acc = evaluate(trt_mod, val_dataloader, criterion, 0)
print("Mobilenetv2 QAT accuracy using TensorRT: {:.2f}%".format(100 * test_acc))
benchmark(trt_mod, input_shape=(64, 3, 224, 224))

基于Torch-TensorRT量化模型-20250123172124