-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscripting.py
More file actions
107 lines (94 loc) · 3.9 KB
/
scripting.py
File metadata and controls
107 lines (94 loc) · 3.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import argparse
import torch
import importlib
from pathlib import Path
import sys
import os
def parse_args(argv=None):
    """Parse command-line options for DGCNN -> TorchScript conversion.

    Args:
        argv: Optional list of argument strings to parse. Defaults to
            None, in which case argparse falls back to sys.argv[1:],
            so existing callers are unaffected.

    Returns:
        argparse.Namespace with the parsed options.
    """
    # BUG FIX: the original passed the description text as the first
    # positional argument of ArgumentParser, which argparse interprets
    # as `prog` (the program name shown in usage/--help), not as the
    # description. Pass it explicitly as description=.
    parser = argparse.ArgumentParser(
        description='Convert DGCNN model to TorchScript format')
    parser.add_argument('--model_path', type=str,
                        default='log/sem_seg/dgcnn/logs/best_model.pth',
                        help='Path to best_model.pth')
    parser.add_argument('--output_path', type=str,
                        default='deploy/dgcnn_segmentation.pt',
                        help='Output path for TorchScript model')
    parser.add_argument('--emb_dims', type=int, default=1024,
                        help='Dimension of embeddings')
    parser.add_argument('--k', type=int, default=20,
                        help='Num of nearest neighbors to use')
    parser.add_argument('--num_classes', type=int, default=3,
                        help='Number of segmentation classes')
    parser.add_argument('--gpu', type=str, default='0',
                        help='GPU device ID')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout rate')
    return parser.parse_args(argv)
def main():
    """Convert a trained DGCNN checkpoint to TorchScript and verify it.

    Workflow: parse CLI args -> make `models/` importable and load the
    `dgcnn` module -> restore weights from the checkpoint -> run a dummy
    forward pass -> convert via torch.jit.trace (falling back to
    torch.jit.script) -> save, then reload and compare outputs.

    Every failure path prints a diagnostic and returns early rather
    than raising.
    """
    args = parse_args()
    print(f"使用配置: {args}")
    # Add the models/ directory (next to this script) to sys.path so
    # that `dgcnn` can be imported dynamically below.
    BASE_DIR = os.path.dirname(os.path.abspath(__file__))
    ROOT_DIR = BASE_DIR
    sys.path.append(os.path.join(ROOT_DIR, 'models'))
    # Ensure the output directory exists before saving.
    output_dir = Path(args.output_path).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    # Load the model definition module.
    try:
        MODEL = importlib.import_module('dgcnn')
    except ImportError as e:
        print(f"无法导入dgcnn模型: {e}")
        print("请确保dgcnn.py文件在models/目录下")
        return
    # Instantiate the model on the requested GPU, or CPU if CUDA is
    # unavailable. `get_model` is expected to read its hyperparameters
    # (emb_dims, k, num_classes, dropout) from the args namespace.
    device = torch.device(f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu")
    model = MODEL.get_model(args).to(device)
    # Restore the trained weights; the checkpoint is expected to store
    # the state dict under the 'model_state_dict' key.
    try:
        checkpoint = torch.load(args.model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        print(f"成功加载模型权重: {args.model_path}")
    except Exception as e:
        print(f"加载模型失败: {e}")
        return
    # Build an example input in the expected layout.
    # Per the training code, the input shape is [batch_size, 9, num_points];
    # the 9 channels are xyz coordinates, rgb colors, and normalized
    # coordinates. (4096 points is assumed representative — the traced
    # graph is shape-specialized to this input.)
    example_input = torch.randn(1, 9, 4096).to(device)
    # Sanity-check a forward pass before attempting conversion.
    with torch.no_grad():
        try:
            output = model(example_input)
            print(f"模型验证通过,输出形状: {output.shape}")
        except Exception as e:
            print(f"模型推理失败: {e}")
            print("提示: 检查模型期望的输入通道数是否为9")
            return
    # Convert with TorchScript: try tracing first, fall back to
    # scripting if tracing fails (e.g. unsupported ops in trace mode).
    try:
        scripted_model = torch.jit.trace(model, example_input)
        print("模型已成功转换为TorchScript格式")
    except Exception as e:
        print(f"转换失败,尝试使用script方法: {e}")
        try:
            scripted_model = torch.jit.script(model)
            print("使用script方法转换成功")
        except Exception as script_e:
            print(f"script方法也失败: {script_e}")
            print("提示: 检查模型中是否有动态控制流或不支持的Python特性")
            return
    # Persist the scripted model.
    scripted_model.save(args.output_path)
    print(f"模型已保存至: {args.output_path}")
    # Round-trip check: reload the saved model and confirm its output
    # matches the in-memory model on the same example input.
    # NOTE(review): `assert` is stripped under `python -O`; the check
    # would then silently pass — acceptable for a dev-time script.
    try:
        loaded_model = torch.jit.load(args.output_path)
        with torch.no_grad():
            loaded_output = loaded_model(example_input)
            assert torch.allclose(output, loaded_output, atol=1e-6), "输出不匹配"
        print("保存的模型验证通过")
    except Exception as e:
        print(f"验证失败: {e}")
if __name__ == '__main__':
    main()