-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtoOnnx.py
More file actions
64 lines (52 loc) · 2.58 KB
/
Copy pathtoOnnx.py
File metadata and controls
64 lines (52 loc) · 2.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch
import os
# 모델 구조가 정의된 파일 import
from drivingModel import DribingResNet
def convert_to_onnx():
# ==========================================
# 1. 파일 경로 설정
# ==========================================
# 변환할 .pth 파일 이름 (오타 주의: 저장된 실제 파일명과 일치해야 함)
pth_file_path = "project/DrivingProcess/best_dirving_model.pth"
# 생성될 .onnx 파일 이름
onnx_file_path = "project/DrivingProcess/DrivingResNet.onnx"
# 파일 존재 여부 확인
if not os.path.exists(pth_file_path):
print(f"❌ Error: '{pth_file_path}' 파일을 찾을 수 없습니다.")
return
print(f"Loading model from {pth_file_path}...")
# ==========================================
# 2. 모델 초기화 및 가중치 로드
# ==========================================
device = torch.device("cpu") # ONNX 변환은 호환성을 위해 CPU에서 수행 권장
# 모델 구조 생성 (학습할 때와 동일한 설정이어야 함)
model = DribingResNet(num_classes=3)
# 가중치(.pth) 로드
# map_location='cpu'를 사용하여 GPU에서 학습된 모델도 CPU에서 로드 가능하게 함
model.load_state_dict(torch.load(pth_file_path, map_location=device))
# 평가 모드로 전환 (Dropout, BatchNorm 등을 고정 - 필수!)
model.eval()
# ==========================================
# 3. ONNX 변환 (Export)
# ==========================================
# 모델 입력 크기에 맞는 더미 데이터 생성 (Batch=1, Ch=3, Height=224, Width=224)
dummy_input = torch.randn(1, 3, 224, 224).to(device)
print(f"Exporting to {onnx_file_path}...")
torch.onnx.export(
model, # 실행할 모델
dummy_input, # 더미 입력
onnx_file_path, # 저장 경로
export_params=True, # 가중치 포함 여부
opset_version=12, # ONNX 버전 (Unity Sentis 호환성 좋음)
do_constant_folding=True, # 최적화 수행
input_names=['input_image'], # Unity에서 참조할 입력 이름
output_names=['control_values'], # Unity에서 참조할 출력 이름
dynamic_axes={
'input_image': {0: 'batch_size'}, # 배치 크기 가변 허용
'control_values': {0: 'batch_size'}
}
)
print("✅ 변환 완료! 유니티 Assets 폴더에 넣으세요.")
print(f" -> {os.path.abspath(onnx_file_path)}")
if __name__ == "__main__":
convert_to_onnx()