From f5f18fcbc6f694d7ab33e2c055d845a90eda7240 Mon Sep 17 00:00:00 2001
From: bo-10000
Date: Fri, 14 Mar 2025 22:53:54 +0900
Subject: [PATCH] update readme & refactor

---
 ai/README.md                     |  94 +++++++++++++-
 ai/dataset/download_dataset.bash |   7 +-
 ai/requirements.txt              |  58 +++++++++
 ai/src/eval.py                   |  21 +--
 ai/src/export.py                 |  95 ++++++++++++++
 ai/src/models/mobilenetv3.py     |   7 +-
 ai/src/onnx_eval.py              | 215 +++++++++++++++++++++++++++++++
 7 files changed, 471 insertions(+), 26 deletions(-)
 mode change 100644 => 100755 ai/dataset/download_dataset.bash
 create mode 100644 ai/requirements.txt
 create mode 100644 ai/src/export.py
 create mode 100644 ai/src/onnx_eval.py

diff --git a/ai/README.md b/ai/README.md
index 81429a2..543f408 100644
--- a/ai/README.md
+++ b/ai/README.md
@@ -1,18 +1,98 @@
 # CataScan AI
-Code for training the AI model and running inference.
+Code for an AI model that classifies eye images captured with a camera as cataract or non-cataract. It includes code for data preprocessing, model training, evaluation, and ONNX export.
 
-## Download Dataset
+For details, see the [AI model overview](../docs/ai/00_introduction.md).
+
+## 1. Tech Stack
+- **Language**: Python
+- **Framework**: PyTorch
+- **Model export**: ONNX
+
+
+## 2. Folder Structure
+```
+A-Eye-Lab-Research/
+├── dataset/             # dataset download code
+├── src/                 # training code
+│   ├── config/          # training configs
+│   ├── models/          # model code
+│   ├── modules/         # training utilities (optimizer, scheduler, etc.)
+│   ├── eval.py          # evaluation script
+│   ├── export.py        # ONNX export script
+│   ├── onnx_eval.py     # ONNX model evaluation script
+│   ├── train.py         # training script
+├── requirements.txt     # required packages
+├── README.md            # project description
+```
+
+## 3. Setup
+```bash
+pip install -r requirements.txt
+```
+
+## 4. Usage
+### Download Dataset
 ```bash
-./dataset/download_dataset.bash
+chmod +x ./dataset/download_dataset.bash
+
+./dataset/download_dataset.bash dataset/data/
 ```
 
-## Model Train
+
+### Model Train
 ```bash
 python src/train.py --cfg src/config/train.yaml
 ```
-## Model Evaluation (with extern dataset)
+- See "5. Config Settings" below for a description of each config parameter.
+
+### Model Evaluation (with external dataset)
+```bash
+python src/eval.py \
+    --dataset_path dataset/data/kaggle_cataract_nand \
+    --ckpt best_checkpoint_path \
+    --cfg src/config/train.yaml
+```
+
+### Export ONNX Model
 ```bash
-python src/eval.py --dataset_path /dataset/{your_dataset or kaggle_cataract_nand(default)}
+python src/export.py \
+    --ckpt best_checkpoint_path \
+    --cfg src/config/train.yaml \
+    --onnx_path models/
+
+# Evaluate the ONNX model
+python src/onnx_eval.py \
+    --onnx_path models/model_quantized.onnx \
+    --dataset_path dataset/data/kaggle_cataract_nand
 ```
 
-For details, see the [AI model overview](../docs/ai/00_introduction.md).
\ No newline at end of file
+## 5. Config Settings
+- Edit `src/config/train.yaml` to change the training settings.
+- The main config entries are:
+  - `DEVICE`: device to run on (cpu / cuda / mps, etc.)
+  - `RANDOM_SEED`: random seed
+  - `MODEL`
+    - `NAME`: name of the model to use
+    - `NUM_CLASSES`: number of classes (2, since this is binary classification)
+    - `PRETRAINED`: whether to start from pretrained weights
+  - `DATASET`
+    - `TRAIN_DATA_DIR`: list of dataset paths used for training
+    - `NUM_WORKERS`: number of data-loading workers
+    - `N_FOLDS`: number of folds for k-fold cross-validation; set to 1 to disable
+  - `TRAIN`
+    - `BATCH_SIZE`: batch size
+    - `EPOCHS`: number of training epochs
+    - `PATIENCE`: patience for early stopping
+    - `AMP`: whether to use automatic mixed precision
+  - `LOSS`
+    - `NAME`: name of the loss function to use
+  - `OPTIMIZER`
+    - `NAME`: name of the optimizer to use
+    - `lr`: learning rate
+    - `weight_decay`: weight decay
+  - `SCHEDULER`
+    - `scheduler_name`: name of the scheduler to use
+    - `mode`: scheduler mode
+    - `factor`: scheduler factor
+    - `patience`: scheduler patience
+    - `min_lr`: minimum learning rate for the scheduler
+  - `SAVE_DIR`: directory for training outputs
+  - `WANDB_PROJECT`: Weights & Biases project name
\ No newline at end of file
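A minimal sketch of how the section-5 keys are consumed in code (cf. `src/export.py` below). It assumes the same invocation style as the repo scripts (`python src/<script>.py` from the `ai/` directory, which makes `modules.utils` importable); the values shown in comments are illustrative:

```python
# Sketch: reading the config keys documented in README section 5.
from modules.utils import load_yaml  # helper used by src/eval.py and src/export.py

cfg = load_yaml("src/config/train.yaml")
device = cfg["DEVICE"]                     # e.g. "cuda" or "cpu"
model_name = cfg["MODEL"]["NAME"]          # resolved against src/models first, then torchvision
num_classes = cfg["MODEL"]["NUM_CLASSES"]  # 2 for this binary task
lr = cfg["OPTIMIZER"]["lr"]                # learning rate used during training
```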
diff --git a/ai/dataset/download_dataset.bash b/ai/dataset/download_dataset.bash
old mode 100644
new mode 100755
index eaa9ea4..54549d4
--- a/ai/dataset/download_dataset.bash
+++ b/ai/dataset/download_dataset.bash
@@ -1,6 +1,11 @@
 #!/bin/bash
 
-download_path="dataset/data"
+if [ -z "$1" ]; then
+    echo "Usage: $0 <download_path>"
+    exit 1
+fi
+
+download_path="$1"
 
 python dataset/download_datasets_1.py --download_path "$download_path"
 python dataset/download_datasets_2.py --download_path "$download_path"
diff --git a/ai/requirements.txt b/ai/requirements.txt
new file mode 100644
index 0000000..ce03e6b
--- /dev/null
+++ b/ai/requirements.txt
@@ -0,0 +1,58 @@
+annotated-types==0.7.0
+certifi==2025.1.31
+charset-normalizer==3.4.1
+click==8.1.8
+coloredlogs==15.0.1
+contourpy==1.3.1
+cycler==0.12.1
+docker-pycreds==0.4.0
+filelock==3.18.0
+flatbuffers==25.2.10
+fonttools==4.56.0
+fsspec==2025.3.0
+gitdb==4.0.12
+GitPython==3.1.44
+huggingface-hub==0.29.3
+humanfriendly==10.0
+idna==3.10
+Jinja2==3.1.6
+joblib==1.4.2
+kagglehub==0.3.10
+kiwisolver==1.4.8
+lightning-utilities==0.14.0
+MarkupSafe==3.0.2
+matplotlib==3.10.1
+mpmath==1.3.0
+networkx==3.4.2
+numpy==2.2.3
+onnx==1.17.0
+onnxruntime==1.21.0
+packaging==24.2
+pillow==11.1.0
+platformdirs==4.3.6
+protobuf==5.29.3
+psutil==7.0.0
+pydantic==2.10.6
+pydantic_core==2.27.2
+pyparsing==3.2.1
+python-dateutil==2.9.0.post0
+PyYAML==6.0.2
+requests==2.32.3
+safetensors==0.5.3
+scikit-learn==1.6.1
+scipy==1.15.2
+sentry-sdk==2.22.0
+setproctitle==1.3.5
+six==1.17.0
+smmap==5.0.2
+sympy==1.13.1
+tabulate==0.9.0
+threadpoolctl==3.6.0
+timm==1.0.15
+torch==2.6.0
+torchmetrics==1.6.3
+torchvision==0.21.0
+tqdm==4.67.1
+typing_extensions==4.12.2
+urllib3==2.3.0
+wandb==0.19.8
diff --git a/ai/src/eval.py b/ai/src/eval.py
index d386981..686909b 100644
--- a/ai/src/eval.py
+++ b/ai/src/eval.py
@@ -15,18 +15,13 @@ def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--cfg", type=str, required=True, help="Configuration file to use"
-    )
-    parser.add_argument(
-        "--ckpt", type=str, required=True, help="Checkpoint file to load model weights"
-    )
+    parser.add_argument("--cfg", type=str, required=True, help="Configuration file to use")
+    parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint file to load model weights")
     parser.add_argument(
         "--dataset_path",
         type=str,
-        default="dataset/kaggle_cataract_nand",
-        required=True,
-        help="Configuration file to use",
+        default="dataset/data/kaggle_cataract_nand",
+        help="test dataset path",
     )
 
     args = parser.parse_args()
@@ -81,13 +76,9 @@ def __getitem__(self, idx):
 
 def load_model(model_name, num_classes, pretrained):
     if hasattr(models, model_name):
-        return getattr(models, model_name)(
-            num_classes=num_classes, pretrained=pretrained
-        )
+        return getattr(models, model_name)(num_classes=num_classes, pretrained=pretrained)
     else:
-        return getattr(torchvision_models, model_name)(
-            num_classes=num_classes, pretrained=pretrained
-        )
+        return getattr(torchvision_models, model_name)(num_classes=num_classes, pretrained=pretrained)
 
 
 def main(args):
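For a quick sanity check of a trained checkpoint outside of `eval.py`, a single-image inference sketch might look like the following. It assumes the preprocessing used in `onnx_eval.py` (Resize + CenterCrop + ImageNet normalization) and the `load()` helper added to `MobileNet_V3_Large` later in this patch; the checkpoint and image paths are placeholders, and imports assume the same run context as the `src/` scripts:

```python
# A hedged single-image inference sketch (not part of this patch).
import torch
from PIL import Image
from torchvision import transforms

from models.mobilenetv3 import MobileNet_V3_Large

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

model = MobileNet_V3_Large(num_classes=2, pretrained=False)
model.load("best_checkpoint.pth")  # load() is added to the model later in this patch
model.eval()

x = transform(Image.open("eye.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    prob_cataract = torch.softmax(model(x), dim=1)[0, 1].item()
print(f"P(cataract) = {prob_cataract:.3f}")
```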
diff --git a/ai/src/export.py b/ai/src/export.py
new file mode 100644
index 0000000..72dfb81
--- /dev/null
+++ b/ai/src/export.py
@@ -0,0 +1,95 @@
+import argparse
+import os
+
+import models
+import onnx
+import torch
+import torch.onnx
+import torchvision.models as torchvision_models
+from modules.utils import load_yaml
+from onnxruntime.quantization import QuantType, quantize_dynamic
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--cfg", type=str, required=True, help="Configuration file to use")
+    parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint file to load model weights")
+    parser.add_argument("--onnx_path", type=str, default="models/", help="Directory to save the ONNX model")
+
+    args = parser.parse_args()
+
+    return args
+
+
+def load_model(model_name, num_classes, pretrained):
+    if hasattr(models, model_name):
+        return getattr(models, model_name)(num_classes=num_classes, pretrained=pretrained)
+    else:
+        return getattr(torchvision_models, model_name)(num_classes=num_classes, pretrained=pretrained)
+
+
+def export_to_onnx(model, input_shape, save_path, device):
+    """Export a PyTorch model to ONNX."""
+    model.eval()
+    dummy_input = torch.randn(input_shape).to(device)
+
+    torch.onnx.export(
+        model,
+        dummy_input,
+        save_path,
+        export_params=True,
+        opset_version=18,
+        do_constant_folding=True,
+        input_names=["input"],
+        output_names=["output1", "output2"],
+        dynamic_axes={"input": {0: "batch_size"}, "output1": {0: "batch_size"}, "output2": {0: "batch_size"}},
+        training=torch.onnx.TrainingMode.EVAL,
+        keep_initializers_as_inputs=True,
+    )
+    print(f"ONNX model saved to {save_path}.")
+
+
+def quantize_onnx_model(onnx_path, quantized_path):
+    """Apply dynamic quantization to an ONNX model."""
+    quantize_dynamic(
+        model_input=onnx_path,
+        model_output=quantized_path,
+        weight_type=QuantType.QUInt8,
+        per_channel=False,
+    )
+    print(f"Quantized model saved to {quantized_path}.")
+
+
+def verify_onnx_model(onnx_path):
+    """Validate an ONNX model with the ONNX checker."""
+    model = onnx.load(onnx_path)
+    onnx.checker.check_model(model)
+    print("ONNX model verification passed.")
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    cfg = load_yaml(args.cfg)
+    device = torch.device(cfg["DEVICE"])
+    model = load_model(
+        cfg["MODEL"]["NAME"],
+        cfg["MODEL"]["NUM_CLASSES"],
+        cfg["MODEL"]["PRETRAINED"],
+    )
+    model.load(args.ckpt)
+    model = model.to(device)
+
+    model.eval()
+    input_shape = (1, 3, 224, 224)  # example input shape
+
+    # Export to ONNX
+    onnx_path = os.path.join(args.onnx_path, "model.onnx")
+    export_to_onnx(model, input_shape, onnx_path, device)
+
+    # Verify the exported model
+    verify_onnx_model(onnx_path)
+
+    # Quantize
+    quantized_path = os.path.join(args.onnx_path, "model_quantized.onnx")
+    quantize_onnx_model(onnx_path, quantized_path)
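The exported file can be exercised directly with ONNX Runtime. A minimal sketch, assuming the `input` tensor name and `(1, 3, 224, 224)` shape set by `export.py` above; the random array stands in for a preprocessed image batch:

```python
# Minimal ONNX Runtime inference sketch for the quantized export (not part of this patch).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("models/model_quantized.onnx", providers=["CPUExecutionProvider"])
x = np.random.rand(1, 3, 224, 224).astype(np.float32)  # stand-in for a preprocessed image
logits = session.run(None, {"input": x})[0]
pred = int(np.argmax(logits, axis=1)[0])  # 0 = normal, 1 = cataract (see onnx_eval.py below)
print(pred)
```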
diff --git a/ai/src/models/mobilenetv3.py b/ai/src/models/mobilenetv3.py
index 8ef679b..259a90b 100644
--- a/ai/src/models/mobilenetv3.py
+++ b/ai/src/models/mobilenetv3.py
@@ -8,13 +8,14 @@ def __init__(self, num_classes, pretrained=False):
         super(MobileNet_V3_Large, self).__init__()
 
         # From timm
-        self.model = timm.create_model(
-            "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes
-        )
+        self.model = timm.create_model("mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes)
 
     def forward(self, x):
         return self.model(x)
 
+    def load(self, path):
+        self.load_state_dict(torch.load(path))
+
 
 if __name__ == "__main__":
     model = MobileNet_V3_Large(num_classes=2, pretrained=False)
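One way to verify the export end to end is to compare PyTorch and ONNX outputs on the same input. A hedged sketch, not part of the patch: the paths are illustrative, `models/model.onnx` is the fp32 export produced by `export.py`, and `load()` is the helper added above:

```python
# Sanity-check sketch: PyTorch vs. ONNX outputs on one random input.
import numpy as np
import onnxruntime as ort
import torch

from models.mobilenetv3 import MobileNet_V3_Large

model = MobileNet_V3_Large(num_classes=2, pretrained=False)
model.load("best_checkpoint.pth")  # uses the load() helper added above
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch_out = model(x).numpy()

sess = ort.InferenceSession("models/model.onnx", providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"input": x.numpy()})[0]  # "input" is the name set in export.py
print(np.abs(torch_out - onnx_out).max())  # expect only a small gap for the fp32 model
```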
diff --git a/ai/src/onnx_eval.py b/ai/src/onnx_eval.py
new file mode 100644
index 0000000..a72f0fe
--- /dev/null
+++ b/ai/src/onnx_eval.py
@@ -0,0 +1,215 @@
+import argparse
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import onnxruntime as ort
+from PIL import Image
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    confusion_matrix,
+    f1_score,
+    precision_score,
+    recall_score,
+    roc_curve,
+)
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+from tqdm import tqdm
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--onnx_path", type=str, default="models/model_quantized.onnx", help="Path to the ONNX model to evaluate")
+    parser.add_argument(
+        "--dataset_path",
+        type=str,
+        default="dataset/data/kaggle_cataract_nand",
+        help="test dataset path",
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+
+class CustomImageDataset(Dataset):
+    def __init__(self, dataset_path, transform=None):
+        """
+        Custom dataset class (loads images and labels).
+        """
+        self.dataset_path = dataset_path
+        self.transform = transform
+        self.data = []
+        self.labels = []
+        self._load_data()
+
+    def _load_data(self):
+        """
+        Load labeled images from the directory.
+        Expected directory structure:
+        dataset
+        └─ 0 : Normal
+        └─ 1 : Cataract
+        """
+        print(f"Loading images from: {self.dataset_path}")
+        classes = {"0": 0, "1": 1}
+
+        for label_dir, label in classes.items():
+            class_path = os.path.join(self.dataset_path, label_dir)
+            if not os.path.isdir(class_path):
+                raise ValueError(f"Directory '{class_path}' does not exist.")
+
+            for file_name in os.listdir(class_path):
+                file_path = os.path.join(class_path, file_name)
+                try:
+                    image = Image.open(file_path).convert("RGB")
+                    if self.transform:
+                        image = self.transform(image)
+                    self.data.append(image)
+                    self.labels.append(label)
+                except Exception as e:
+                    print(f"Failed to process image: {file_path}. Error: {e}")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        return self.data[idx], self.labels[idx]
+
+
+def specificity_score(y_true, y_pred, zero_division=0):
+    """
+    Compute specificity:
+    TN / (TN + FP)
+    """
+    cm = confusion_matrix(y_true, y_pred)
+    if len(cm) <= 1:
+        return zero_division
+    tn, fp = cm[0][0], cm[0][1]
+    if tn + fp == 0:
+        return zero_division
+    return tn / (tn + fp)
+
+
+class ONNXImageTestEvaluator:
+    def __init__(self, onnx_path, dataset_path, batch_size=16, image_size=(224, 224)):
+        self.onnx_path = onnx_path
+        self.dataset_path = dataset_path
+        self.image_size = image_size
+        self.batch_size = batch_size
+
+        # Fall back to CPU when CUDA is not available
+        providers = ["CPUExecutionProvider"]
+        try:
+            # Probe CUDA availability by building a CUDA-backed session
+            ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+        except Exception as e:
+            print("CUDA is not available; falling back to CPU.")
+            print(f"Error message: {str(e)}")
+
+        # Initialize the session with the resolved providers
+        self.session = ort.InferenceSession(onnx_path, providers=providers)
+
+        # Get the model input name
+        self.input_name = self.session.get_inputs()[0].name
+
+        self.norm = {"mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}
+
+        # Image transformation
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(image_size[0]),
+                transforms.CenterCrop(image_size),
+                transforms.ToTensor(),
+                transforms.Normalize(**self.norm),
+            ]
+        )
+
+    def load_data(self):
+        """
+        Load the data using a DataLoader.
+        """
+        dataset = CustomImageDataset(self.dataset_path, transform=self.transform)
+        self.dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)
+        print(f"Loaded {len(dataset)} images.")
+
+    def evaluate(self):
+        """
+        Evaluate the ONNX model on the test dataset and generate the ROC curve.
+        """
+        all_preds = []
+        all_labels = []
+        all_probs = []  # probabilities for the positive class
+
+        for data, label in tqdm(self.dataloader, desc="Evaluating", unit="batch"):
+            input_data = data.numpy()
+            outputs = self.session.run(None, {self.input_name: input_data})
+
+            # The model emits raw logits; softmax them to get class probabilities
+            exp = np.exp(outputs[0] - outputs[0].max(axis=1, keepdims=True))
+            probabilities = exp / exp.sum(axis=1, keepdims=True)
+            positive_probs = probabilities[:, 1]  # probability of the positive class (1)
+
+            preds = np.argmax(outputs[0], axis=1)
+
+            all_probs.extend(positive_probs)
+            all_preds.extend(preds)
+            all_labels.extend(label.numpy())
+
+        y_true = all_labels
+        y_pred = all_preds
+        y_prob = all_probs
+
+        # Compute the ROC curve
+        fpr, tpr, _ = roc_curve(y_true, y_prob)
+        roc_auc = auc(fpr, tpr)
+
+        # Plot the ROC curve
+        plt.figure(figsize=(8, 6))
+        plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
+        plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
+        plt.xlim([0.0, 1.0])
+        plt.ylim([0.0, 1.05])
+        plt.xlabel("False Positive Rate")
+        plt.ylabel("True Positive Rate")
+        plt.title("Receiver Operating Characteristic (ROC) Curve")
+        plt.legend(loc="lower right")
+
+        # Save the ROC curve
+        plt.savefig("roc_curve.png")
+        plt.close()
+
+        # Calculate metrics
+        metrics = {
+            "Confusion Matrix": confusion_matrix(y_true, y_pred),
+            "Precision": precision_score(y_true, y_pred, zero_division=0),
+            "Recall": recall_score(y_true, y_pred, zero_division=0),
+            "F1 Score": f1_score(y_true, y_pred, zero_division=0),
+            "Accuracy": accuracy_score(y_true, y_pred),
+            "Specificity": specificity_score(y_true, y_pred, zero_division=0),
+            "AUC-ROC": roc_auc,
+        }
+
+        return metrics
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    # Initialize the ONNX evaluator
+    evaluator = ONNXImageTestEvaluator(onnx_path=args.onnx_path, dataset_path=args.dataset_path, image_size=(224, 224))
+
+    # Load data and evaluate
+    evaluator.load_data()
+    results = evaluator.evaluate()
+
+    # Print the results
+    for metric, value in results.items():
+        if metric == "Confusion Matrix":
+            print(f"\n{metric}")
+            print(value, "\n")
+        else:
+            print(f"{metric:<15} : {value:.4f}")
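A worked example for the `specificity_score()` helper above, with an illustrative confusion matrix (labels are made up for the demonstration):

```python
# Worked example: specificity = TN / (TN + FP) from sklearn's confusion matrix.
from sklearn.metrics import confusion_matrix

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 1]
cm = confusion_matrix(y_true, y_pred)    # [[3, 1], [0, 2]] -> TN=3, FP=1
print(cm[0][0] / (cm[0][0] + cm[0][1]))  # specificity = 3 / 4 = 0.75
```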