使用 Torch_TensorRT 量化分割模型

本文使用 Torch-TensorRT 量化 DeepLabV3+ 分割模型。

Torch-TensorRT 是 PyTorch 的推理编译器,它借助 NVIDIA 的 TensorRT 深度学习优化器和运行时,面向 NVIDIA GPU 进行优化。它既支持通过 torch.compile 接口进行的即时 (JIT) 编译工作流,也支持预先 (AOT) 编译工作流。Torch-TensorRT 无缝集成到 PyTorch 生态系统中,支持将优化后的 TensorRT 代码与标准 PyTorch 代码混合执行。

由于 Torch-TensorRT 接受 TorchScript 模型作为输入,并在优化后输出 TorchScript (ts) 模型,下文将分别从 base(原始 PyTorch 模型)、PTQ(训练后量化)和 QAT(量化感知训练)三个方向测试 Torch-TensorRT。

为了评估量化前后的推理速度,我们定义一个基准测试(benchmark)函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Helper function to benchmark the model
def benchmark(model, input_shape=(1024, 1, 32, 32), dtype='fp32', nwarmup=50, nruns=800,
              device="cuda"):
    """Time repeated forward passes of ``model`` on a random input and print stats.

    Args:
        model: Callable (eager module, TorchScript or Torch-TensorRT module) to time.
        input_shape: Shape of the random input tensor fed on every run.
        dtype: ``'fp16'`` casts the input to half precision; any other value keeps fp32.
        nwarmup: Number of untimed forward passes run before timing starts.
        nruns: Number of timed forward passes (must be >= 1).
        device: Device the input is placed on. Defaults to ``"cuda"``; the CUDA
            synchronisation is only performed for CUDA devices.

    Returns:
        numpy array with the per-run latency in milliseconds.

    Raises:
        ValueError: if ``nruns`` < 1 (no statistics could be computed).
    """
    if nruns < 1:
        raise ValueError("nruns must be >= 1")

    input_data = torch.randn(input_shape).to(device)
    if dtype == 'fp16':
        input_data = input_data.half()

    is_cuda = torch.device(device).type == "cuda"

    def _sync():
        # CUDA kernels are launched asynchronously; wait for completion so the
        # timer measures the actual compute, not just the kernel launch.
        if is_cuda:
            torch.cuda.synchronize()

    with torch.no_grad():
        for _ in range(nwarmup):
            model(input_data)
        _sync()

        timings = []
        for _ in range(nruns):
            # perf_counter is monotonic and high-resolution, unlike time.time().
            start_time = time.perf_counter()
            model(input_data)
            _sync()
            timings.append(time.perf_counter() - start_time)

    timings_ms = np.asarray(timings) * 1000
    print('Average batch time: %.2f ms, median: %.2f ms, min-max: %.2f-%.2f ms'
          % (timings_ms.mean(), np.median(timings_ms), timings_ms.min(), timings_ms.max()))
    return timings_ms

base->TorchScript->Torch_tensorRT

基于原始的 .pth 权重构建模型,先将其导出为 TorchScript,再使用 Torch-TensorRT 进行优化。

首先导入 pth 模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
from modelingv2.deeplab import DeepLab
# Build the FP32 DeepLabV3+ model and load the trained baseline checkpoint.
model = DeepLab(in_channels=3, num_classes=2, pretrained=False)
# deeplabv3plus_base.pth is the baseline checkpoint trained beforehand.
# Load onto the CPU first; load_state_dict copies the values into the model.
ckpt = torch.load("./models/deeplabv3plus_base.pth", map_location="cpu")
# Checkpoints saved from a (Distributed)DataParallel wrapper typically prefix every
# key with 'module.'; strip exactly that prefix (not any key starting with
# 'module') so the keys match the bare model.
prefix = "module."
modified_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): val
    for key, val in ckpt.items()
}
# Load the pre-trained checkpoint
model.load_state_dict(modified_state_dict)
model = model.cuda()

其次,导出为 torchscript

1
2
3
4
5
6
7
# Exporting to TorchScript
# torch.jit.trace freezes train/eval-dependent behaviour (BatchNorm, Dropout)
# in whatever mode the module is in while tracing, so switch to eval first.
model.eval()
with torch.no_grad():
    # DataLoader iterators no longer expose a `.next()` method; use the builtin.
    images, _ = next(iter(train_dataloader))
    jit_model = torch.jit.trace(model, images.to("cuda"))
    torch.jit.save(jit_model, "models/deeplabv3plus_base.jit.pt")
benchmark(jit_model, input_shape=(16, 3, 512, 512), nruns=100)

最后使用 Torch_tensorRT 优化

1
2
3
4
5
6
7
8
9
10
# Loading the Torchscript model and compiling it into a TensorRT model
baseline_model = torch.jit.load("models/deeplabv3plus_base.jit.pt").eval()

# The engine was previously built for a static batch of 4 but benchmarked with
# a batch of 16. Use a dynamic batch range: opt keeps the original batch size 4,
# and max covers the batch-16 benchmark below.
compile_spec = {
    "inputs": [torch_tensorrt.Input(min_shape=[1, 3, 512, 512],
                                    opt_shape=[4, 3, 512, 512],
                                    max_shape=[16, 3, 512, 512])],
    "enabled_precisions": {torch.float},
    "truncate_long_and_double": True,
}
trt_base = torch_tensorrt.compile(baseline_model, **compile_spec)
torch.jit.save(trt_base, "models/deeplabv3plus_base_trt.ts")
benchmark(trt_base, input_shape=(16, 3, 512, 512), nruns=100)

base->ptq->TorchScript->Torch_tensorRT

首先加载 .pth 模型权重,并使用 pytorch_quantization 对模型进行训练后量化(PTQ)校准。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Per the pytorch_quantization docs, initialize() patches torch.nn layers with
# quantized counterparts, so the DeepLab model built below contains
# TensorQuantizer modules.
quant_modules.initialize()
q_model = DeepLab(in_channels=3, num_classes=2, pretrained=False)
q_model = q_model.cuda()
ckpt = torch.load("./models/deeplabv3plus_base.pth", map_location="cpu")
# Strip exactly the 'module.' prefix left by a (Distributed)DataParallel
# wrapper; keys that merely start with 'module' are left untouched.
prefix = "module."
modified_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): val
    for key, val in ckpt.items()
}
# Load the pre-trained checkpoint
q_model.load_state_dict(modified_state_dict)

def compute_amax(model, **kwargs):
    """Load the collected calibration results (amax) into every calibrated quantizer.

    Only TensorQuantizer modules that own a calibrator are touched. Max
    calibrators are loaded without options; all other calibrators receive
    ``kwargs`` (e.g. ``method="percentile"``). The model is moved to CUDA at the end.
    """
    calibrated = (
        quantizer
        for _, quantizer in model.named_modules()
        if isinstance(quantizer, quant_nn.TensorQuantizer) and quantizer._calibrator is not None
    )
    for quantizer in calibrated:
        if isinstance(quantizer._calibrator, calib.MaxCalibrator):
            quantizer.load_calib_amax()
        else:
            quantizer.load_calib_amax(**kwargs)
    # NOTE(review): presumably so any amax buffers created above end up on the GPU.
    model.cuda()

def collect_stats(model, data_loader, num_batches):
    """Feed data to the network and collect calibration statistics.

    While collecting, every TensorQuantizer that has a calibrator stops
    quantizing and starts calibrating. Quantizers without a calibrator are
    disabled. Both are switched back afterwards, even if a forward pass raises.

    Args:
        model: Model built after ``quant_modules.initialize()``.
        data_loader: Iterable yielding ``(image, target)`` pairs; targets are ignored.
        num_batches: Exactly how many batches to feed (previously one extra
            batch was consumed because the break check ran after the forward).
    """
    import itertools

    quantizers = [m for m in model.modules() if isinstance(m, quant_nn.TensorQuantizer)]

    # Enable calibrators
    for quantizer in quantizers:
        if quantizer._calibrator is not None:
            quantizer.disable_quant()
            quantizer.enable_calib()
        else:
            quantizer.disable()

    # In train mode every forward pass would update BatchNorm running stats
    # (even under no_grad) and calibrate on batch statistics. Calibrate in eval
    # mode and restore the caller's mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Feed data to the network for collecting stats
        for image, _ in tqdm(itertools.islice(data_loader, num_batches), total=num_batches):
            model(image.cuda())
    finally:
        model.train(was_training)
        # Disable calibrators
        for quantizer in quantizers:
            if quantizer._calibrator is not None:
                quantizer.enable_quant()
                quantizer.disable_calib()
            else:
                quantizer.enable()
# Calibrate the model using max calibration technique.
# Calibrate on training data only. Calibrators accumulate statistics across
# collect_stats calls, so a second pass over val_dataloader (as before) would
# fold held-out data into the quantization ranges and bias later evaluation.
with torch.no_grad():
    collect_stats(q_model, train_dataloader, num_batches=20)
    compute_amax(q_model, method="max")
    # Alternatives (used only by histogram calibrators):
    # compute_amax(q_model, method="entropy")
    # compute_amax(q_model, method="percentile")
    # compute_amax(q_model, method="mse")

其次,导出为 torchscript

1
2
3
4
5
6
7
8
9
# Export quantizers as PyTorch-native fake-quant ops so the traced graph
# carries the quantize/dequantize nodes.
quant_nn.TensorQuantizer.use_fb_fake_quant = True

# Exporting to TorchScript
# Trace in eval mode: torch.jit.trace freezes BatchNorm/Dropout behaviour in
# the mode the module is in during tracing.
q_model.eval()
with torch.no_grad():
    # DataLoader iterators no longer expose `.next()`; use the builtin.
    images, _ = next(iter(train_dataloader))
    jit_model = torch.jit.trace(q_model, images.to("cuda"))
    torch.jit.save(jit_model, "models/deeplabv3plus_ptq.jit.pt")
benchmark(jit_model, input_shape=(16, 3, 512, 512), nruns=100)

最后使用 Torch_tensorRT 优化

1
2
3
4
5
6
7
8
9
# Loading the Torchscript model and compiling it into a TensorRT model
ptq_model = torch.jit.load("models/deeplabv3plus_ptq.jit.pt").eval()

# Dynamic batch range: opt keeps the original batch size 4, and max covers the
# batch-16 benchmark below (a static batch-4 engine cannot run batch 16).
compile_spec = {
    "inputs": [torch_tensorrt.Input(min_shape=[1, 3, 512, 512],
                                    opt_shape=[4, 3, 512, 512],
                                    max_shape=[16, 3, 512, 512])],
    "enabled_precisions": {torch.int8},
}
trt_ptq = torch_tensorrt.compile(ptq_model, **compile_spec)
# Save the PTQ engine (previously trt_base was saved under this name by mistake).
torch.jit.save(trt_ptq, "models/deeplabv3plus_ptq_trt.ts")
benchmark(trt_ptq, input_shape=(16, 3, 512, 512), nruns=100)

ptq->qat->TorchScript->Torch_tensorRT

首先在 PTQ 校准后的模型(q_model)基础上进行 QAT(量化感知训练)微调。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
def train(model, dataloader, crit, opt):
    """Run one training epoch of ``model`` over ``dataloader`` using optimizer ``opt``.

    The loss is ``Focal_Loss + Dice_loss`` (helpers defined elsewhere).
    NOTE(review): ``crit`` is accepted but never used; confirm whether it was
    meant to replace or augment the hard-coded losses.
    """
    model.train()
    for inputs, targets in dataloader:
        inputs = inputs.cuda()
        targets = targets.cuda()
        opt.zero_grad()
        preds = model(inputs)
        total_loss = Focal_Loss(preds, targets) + Dice_loss(preds, targets)
        total_loss.backward()
        opt.step()

crit = torch.nn.CrossEntropyLoss()
# Optimise the *quantized* model. Building the optimizer from the FP32 `model`
# (as before) left q_model's weights untouched during fine-tuning.
# NOTE(review): weight_decay was 0.9, which for Adam adds 0.9*w to every
# gradient and drives the weights toward zero; 1e-4 is a conventional value, tune as needed.
optimizer = torch.optim.Adam(q_model.parameters(), lr=1e-4, weight_decay=1e-4)
q_model = q_model.train()
# Finetune the QAT model for num_epochs epochs
num_epochs = 10
for epoch in range(num_epochs):
    print('Epoch: [%5d / %5d]' % (epoch + 1, num_epochs))
    train(q_model, train_dataloader, crit, optimizer)
    test_loss, acc = evaluate(q_model, val_dataloader, crit)
    print("Test Loss: {:.5f} Test acc {:.2f}%".format(test_loss, acc * 100))

其次,导出为 torchscript

1
2
3
4
5
6
7
8
9
10
# Export quantizers as PyTorch-native fake-quant ops so the traced graph
# carries the quantize/dequantize nodes.
quant_nn.TensorQuantizer.use_fb_fake_quant = True
# Trace in eval mode so BatchNorm/Dropout are frozen in inference behaviour.
q_model = q_model.eval()

# Exporting to TorchScript
with torch.no_grad():
    # DataLoader iterators no longer expose `.next()`; use the builtin.
    images, _ = next(iter(train_dataloader))
    jit_model = torch.jit.trace(q_model, images.to("cuda"))
    torch.jit.save(jit_model, "models/deeplabv3plus_qat.jit.pt")
benchmark(jit_model, input_shape=(16, 3, 512, 512), nruns=100)

最后使用 Torch_tensorRT 优化

1
2
3
4
5
6
7
8
9
# Loading the Torchscript model and compiling it into a TensorRT model
qat_model = torch.jit.load("models/deeplabv3plus_qat.jit.pt").eval()

# Dynamic batch range: opt keeps the original batch size 4, and max covers the
# batch-16 benchmark below (a static batch-4 engine cannot run batch 16).
compile_spec = {
    "inputs": [torch_tensorrt.Input(min_shape=[1, 3, 512, 512],
                                    opt_shape=[4, 3, 512, 512],
                                    max_shape=[16, 3, 512, 512])],
    "enabled_precisions": {torch.int8},
}
trt_qat = torch_tensorrt.compile(qat_model, **compile_spec)
# Save and benchmark the QAT engine (previously trt_base was used by mistake).
torch.jit.save(trt_qat, "models/deeplabv3plus_qat_trt.ts")
benchmark(trt_qat, input_shape=(16, 3, 512, 512), nruns=100)

耗时汇总

以上 3 个流程的结果与耗时汇总如下:

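| 实验 | jit 结果情况 | jit 耗时 | Torch_TensorRT 结果 | Torch_TensorRT 耗时 |
| --- | --- | --- | --- | --- |
| base | 正确 | 81.43 | 正确 | - |
| ptq | 正确 | 94.3 | 正确 | - |
| qat | 正确 | 94.5 | - | - |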