-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_scripting.py
More file actions
206 lines (175 loc) · 7.92 KB
/
test_scripting.py
File metadata and controls
206 lines (175 loc) · 7.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import time
def load_point_cloud(file_path):
    """Load point-cloud data from a headerless CSV with exactly three columns (X, Y, Z).

    Coordinates are assumed to be in millimeters and are converted to meters.

    Args:
        file_path: Path (or file-like object) readable by ``pandas.read_csv``.

    Returns:
        ``np.ndarray`` of shape (N, 3), dtype float32, coordinates in meters.

    Raises:
        Exception: If the file cannot be read or does not have exactly 3 columns;
            the original error is chained as ``__cause__``.
    """
    try:
        data = pd.read_csv(file_path, header=None).values
        if data.shape[1] != 3:
            raise ValueError(f"预期输入 CSV 有 3 列(X, Y, Z),实际得到 {data.shape[1]} 列")
        points = data[:, :3].astype(np.float32)
        points /= 1000.0  # millimeters -> meters
        return points
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback
        # (the bare re-raise previously discarded the chain context).
        raise Exception(f"无法加载点云数据 {file_path}:{e}") from e
def preprocess_block(points, num_points=16384):
    """Prepare a single point-cloud block for inference.

    Builds a 9-channel feature vector per point (XYZ, RGB fixed to white,
    normalized XYZ), recenters X/Y around the block mean (Z untouched), and
    samples or pads the block to exactly ``num_points`` points.

    Args:
        points: (N, 3) float32 array of XYZ coordinates.
        num_points: Target number of points per block.

    Returns:
        Tuple of (input_tensor, indices) where ``input_tensor`` has shape
        [1, 9, num_points] and ``indices`` maps tensor positions back to rows
        of ``points``; returns (None, None) when the block has fewer than
        1024 points.
    """
    n = points.shape[0]
    if n < 1024:  # minimum point count for a usable block
        return None, None

    # Block center used to recenter X/Y only (Z stays absolute).
    center = np.mean(points, axis=0)
    center[2] = 0

    # Per-axis extent for normalizing coordinates into roughly [-1, 1].
    extent = np.max(np.abs(points), axis=0)

    color = np.full((n, 3), 255, dtype=np.float32)
    features = np.hstack([points, color, points / (extent + 1e-6)])
    features[:, 0] -= center[0]
    features[:, 1] -= center[1]
    features[:, 3:6] /= 255.0  # scale RGB into [0, 1]

    # Sample without replacement when the block is large enough; otherwise
    # keep every point and pad by resampling with replacement.
    if n >= num_points:
        indices = np.random.choice(n, num_points, replace=False)
    else:
        pad = np.random.choice(n, num_points - n, replace=True)
        indices = np.concatenate([np.arange(n), pad])

    tensor = torch.from_numpy(features[indices]).float().unsqueeze(0)
    return tensor.transpose(1, 2), indices  # [1, 9, num_points]
def process_point_cloud_without_replacement(points, num_points=16384, batch_size=32, num_votes=3):
    """Yield inference batches that cover the whole cloud via sampling without replacement.

    For each of ``num_votes`` passes the point indices are reshuffled and the
    cloud is partitioned into consecutive blocks of ``num_points`` indices;
    each block is preprocessed and accumulated into batches of up to
    ``batch_size`` tensors before being yielded.

    Args:
        points: (N, 3) float32 array of XYZ coordinates.
        num_points: Points per block (passed through to ``preprocess_block``).
        batch_size: Maximum number of blocks per yielded batch.
        num_votes: Number of shuffled passes over the cloud.

    Yields:
        Tuples ``(batch_tensor, batch_point_idxs, batch_block_indices,
        vote_label_pool)`` where ``batch_tensor`` is [B, 9, num_points],
        ``batch_point_idxs`` maps each block to original point indices,
        ``batch_block_indices`` are the per-block sampled indices returned by
        ``preprocess_block``, and ``vote_label_pool`` is the shared (N, 3)
        vote accumulator — the SAME object on every yield.
    """
    total_points = points.shape[0]
    print(f"总点数:{total_points}")
    # Shared vote accumulator, handed to the caller with every batch.
    vote_label_pool = np.zeros((total_points, 3))
    remaining_indices = np.arange(total_points)  # index order for the current pass
    for vote in range(num_votes):
        print(f"处理投票 {vote + 1}/{num_votes}")
        np.random.seed(42 + vote)  # distinct but reproducible seed per vote
        remaining_indices = np.random.permutation(remaining_indices)  # reshuffle each pass
        sampled_count = 0
        while sampled_count < total_points:
            # Points still available for the next block in this pass.
            num_to_sample = min(num_points, total_points - sampled_count)
            if num_to_sample < 1024:  # too few points left — end this pass
                break
            batch_inputs = []
            batch_point_idxs = []
            batch_block_indices = []
            batch_count = 0
            for i in range(0, total_points - sampled_count, num_points):
                end_idx = min(i + num_points, total_points - sampled_count)
                block_indices = remaining_indices[i:end_idx]
                if len(block_indices) < 1024:
                    # NOTE(review): undersized tail blocks are skipped, so a
                    # partially filled batch left over when the pass ends is
                    # discarded without being yielded — confirm intentional.
                    continue
                block_points = points[block_indices]
                input_tensor, indices = preprocess_block(block_points, num_points)
                if input_tensor is None:
                    continue
                batch_inputs.append(input_tensor)
                batch_point_idxs.append(block_indices)
                batch_block_indices.append(indices)
                batch_count += 1
                sampled_count += len(block_indices)
                # Flush a full batch, or the final batch of the pass.
                if batch_count >= batch_size or sampled_count >= total_points:
                    if batch_inputs:
                        batch_tensor = torch.cat(batch_inputs, dim=0)
                        yield batch_tensor, batch_point_idxs, batch_block_indices, vote_label_pool
                        batch_inputs = []
                        batch_point_idxs = []
                        batch_block_indices = []
                        batch_count = 0
            # Remaining points are already governed by the slicing above; no removal needed.
            if sampled_count >= total_points:
                break
        print(f"投票 {vote + 1} 完成,已采样 {sampled_count} 个点")
    # In a generator this value surfaces only via StopIteration.value;
    # the consuming loop in main() never reads it.
    return vote_label_pool
def main():
    """Run TorchScript point-cloud segmentation inference over a CSV point cloud.

    Loads a serialized segmentation model, streams preprocessed block batches
    from the input cloud, accumulates per-point class votes across passes, and
    writes the labeled points (converted back to millimeters) to an output CSV.
    """
    # Fix random seeds for reproducible sampling.
    np.random.seed(42)
    torch.manual_seed(42)
    # Configuration (hard-coded deployment paths).
    model_path = r'D:\company\0125\gitnet\DGCNN0\DGCNN\deploy\dgcnn_segmentationcir.pt'
    input_csv = r'D:\company\data\2024-12-16-14-13-59.csv'
    output_csv = r'D:\company\data\labeled_point_cloud16.csv'
    num_points = 16384  # must match the value used at training time
    num_classes = 3
    num_votes = 3
    batch_size = 32
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备:{device}")
    # Ensure the output directory exists.
    output_dir = Path(output_csv).parent
    output_dir.mkdir(parents=True, exist_ok=True)
    # Load the TorchScript model.
    try:
        model = torch.jit.load(model_path, map_location=device)
        model.eval()
        print(f"成功加载模型:{model_path}")
    except Exception as e:
        print(f"加载模型失败:{e}")
        return
    # Load the point cloud.
    try:
        points = load_point_cloud(input_csv)
        print(f"加载点云数据,点数:{points.shape[0]}(已缩放为米单位)")
    except Exception as e:
        print(f"加载或预处理点云失败:{e}")
        return
    # Stream batches from the generator. NOTE(review): the for-loop target
    # rebinds `vote_label_pool` to the pool object yielded by the generator,
    # so the zeros array created on the next line is discarded after the
    # first iteration — all votes accumulate in the generator's shared pool.
    vote_label_pool = np.zeros((points.shape[0], num_classes))
    for batch_tensor, batch_point_idxs, batch_block_indices, vote_label_pool in tqdm(
        process_point_cloud_without_replacement(points, num_points, batch_size, num_votes),
        desc="处理批次"
    ):
        start_time = time.time()
        batch_tensor = batch_tensor.to(device)
        with torch.no_grad():
            seg_pred = model(batch_tensor)  # [batch_size, num_classes, num_points]
            seg_pred = seg_pred.permute(0, 2, 1)  # [batch_size, num_points, num_classes]
            pred_labels = seg_pred.argmax(dim=2).cpu().numpy()  # [batch_size, num_points]
        for b in range(len(batch_point_idxs)):
            point_idxs = batch_point_idxs[b]
            block_indices = batch_block_indices[b]
            # Map each sampled tensor position back to its original point id
            # and cast one vote for the predicted class.
            for i, idx in enumerate(point_idxs[block_indices]):
                vote_label_pool[idx, pred_labels[b, i]] += 1
        print(f"批次处理时间:{time.time() - start_time:.2f} 秒")
    # Majority vote per point across all passes.
    pred_labels = np.argmax(vote_label_pool, axis=1)
    print(f"推理完成,预测标签形状:{pred_labels.shape}")
    # Points that never received a vote (skipped undersized blocks).
    unpredicted_points = np.sum(vote_label_pool.sum(axis=1) == 0)
    print(f"未预测的点数:{unpredicted_points}")
    # Keep only points that received at least one vote.
    predicted_mask = vote_label_pool.sum(axis=1) > 0
    predicted_points = points[predicted_mask]
    predicted_labels = pred_labels[predicted_mask]
    print(f"已预测点数:{len(predicted_points)}")
    # Report the label distribution.
    unique, counts = np.unique(predicted_labels, return_counts=True)
    print("标签分布:")
    for label, count in zip(unique, counts):
        print(f" 标签 {label} ({['liaodui', 'dabi', 'madao'][label]}): {count} 个点")
    # Save predicted points to CSV, converting coordinates back to millimeters.
    # The single multi-field `fmt` string formats a whole row at once.
    output_data = np.hstack([predicted_points * 1000.0, predicted_labels.reshape(-1, 1)])
    try:
        np.savetxt(output_csv, output_data, delimiter=',', fmt='%.6f,%.6f,%.6f,%d')
        print(f"结果已保存至:{output_csv},包含 {len(output_data)} 个点")
    except Exception as e:
        print(f"保存结果失败:{e}")


if __name__ == '__main__':
    main()