94 changes: 87 additions & 7 deletions ai/README.md
@@ -1,18 +1,98 @@
# CataScan AI

Code for AI model training and inference.
Code for the AI model that classifies whether cataracts are present from eye images captured with a camera. It includes code for data preprocessing, model training, evaluation, and ONNX conversion.

## Download Dataset
For details, see the [AI Model Overview](../docs/ai/00_introduction.md).

## 1. Tech Stack
- **Programming language**: Python
- **Framework**: PyTorch
- **Model conversion**: ONNX


## 2. Folder Structure
```
A-Eye-Lab-Research/
├── dataset/             # Dataset-related code
├── src/                 # Training code
│   ├── config/          # Training configs
│   ├── models/          # Model code
│   ├── modules/         # Training utilities (optimizer, scheduler, etc.)
│   ├── eval.py          # Evaluation script
│   ├── train.py         # Training script
├── requirements.txt     # Required package list
├── README.md            # Project description
```

## 3. Setup
```bash
pip install -r requirements.txt
```

## 4. Usage
### Download Dataset
```bash
./dataset/download_dataset.bash
chmod +x ./dataset/download_dataset.bash

./dataset/download_dataset.bash dataset/data/
```
## Model Train

### Model Train
```bash
python src/train.py --cfg src/config/train.yaml
```
## Model Evaluation (with extern dataset)
- See "5. Config Settings" below for a description of each config parameter.

### Model Evaluation (with external dataset)
```bash
python src/eval.py \
--dataset_path /dataset/data/kaggle_cataract_nand \
--ckpt best_checkpoint_path \
--cfg src/config/train.yaml
```

### Export ONNX Model
```bash
python src/eval.py --dataset_path /dataset/{your_dataset or kaggle_cataract_nand(default)}
python src/export.py \
--ckpt best_checkpoint_path \
--cfg src/config/train.yaml \
--onnx_path models/

# Evaluate the ONNX model
python src/onnx_eval.py \
--onnx_path models/model_quantized.onnx \
--dataset_path dataset/data/kaggle_cataract_nand
```

For details, see the [AI Model Overview](../docs/ai/00_introduction.md).
## 5. Config Settings
- Edit `src/config/train.yaml` to change the settings used for training.
- The main config entries are described below.
- `DEVICE`: Device to use (cpu/cuda/mps, etc.)
- `RANDOM_SEED`: Random seed
- `MODEL`
  - `NAME`: Model name to use
  - `NUM_CLASSES`: Number of classes (2, since this is binary classification)
  - `PRETRAINED`: Whether to use pretrained weights
- `DATASET`
  - `TRAIN_DATA_DIR`: List of dataset paths used for training
  - `NUM_WORKERS`: Number of workers for data loading
  - `N_FOLDS`: Number of folds for K-Fold cross-validation; set to 1 to disable
- `TRAIN`
  - `BATCH_SIZE`: Batch size
  - `EPOCHS`: Number of training epochs
  - `PATIENCE`: Patience value for early stopping
  - `AMP`: Whether to use Automatic Mixed Precision
- `LOSS`
  - `NAME`: Loss function to use
- `OPTIMIZER`
  - `NAME`: Optimizer to use
  - `lr`: Learning rate
  - `weight_decay`: Weight decay
- `SCHEDULER`
  - `scheduler_name`: Scheduler to use
  - `mode`: Scheduler mode
  - `factor`: Scheduler factor
  - `patience`: Scheduler patience
  - `min_lr`: Minimum learning rate for the scheduler
- `SAVE_DIR`: Directory where training outputs are saved
- `WANDB_PROJECT`: Weights & Biases project name
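A config with the entries listed above can be loaded with PyYAML (in `requirements.txt`), presumably what `load_yaml` wraps. The snippet below is a sketch: the values are illustrative placeholders, and the exact nesting of `LOSS`/`OPTIMIZER`/`SCHEDULER` is an assumption, not copied from the repository's `train.yaml`.

```python
# Sketch of loading a train.yaml-shaped config with PyYAML.
# All values here are illustrative, not the project's actual settings.
import yaml

cfg_text = """
DEVICE: cpu
RANDOM_SEED: 42
MODEL:
  NAME: MobileNet_V3_Large
  NUM_CLASSES: 2
  PRETRAINED: true
DATASET:
  TRAIN_DATA_DIR: [dataset/data/kaggle_cataract_nand]
  NUM_WORKERS: 4
  N_FOLDS: 1
TRAIN:
  BATCH_SIZE: 32
  EPOCHS: 50
  PATIENCE: 10
  AMP: false
LOSS:
  NAME: CrossEntropyLoss
OPTIMIZER:
  NAME: AdamW
  lr: 0.001
  weight_decay: 0.0001
SCHEDULER:
  scheduler_name: ReduceLROnPlateau
  mode: min
  factor: 0.1
  patience: 5
  min_lr: 1.0e-6
SAVE_DIR: results/
WANDB_PROJECT: catascan
"""

cfg = yaml.safe_load(cfg_text)
# Nested keys are accessed the way export.py does, e.g. cfg["MODEL"]["NAME"].
print(cfg["MODEL"]["NUM_CLASSES"])  # 2
```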
7 changes: 6 additions & 1 deletion ai/dataset/download_dataset.bash
100644 → 100755
@@ -1,6 +1,11 @@
#!/bin/bash

download_path="dataset/data"
if [ -z "$1" ]; then
    echo "Usage: $0 <download_path>"
    exit 1
fi

download_path="$1"

python dataset/download_datasets_1.py --download_path "$download_path"
python dataset/download_datasets_2.py --download_path "$download_path"
58 changes: 58 additions & 0 deletions ai/requirements.txt
@@ -0,0 +1,58 @@
annotated-types==0.7.0
certifi==2025.1.31
charset-normalizer==3.4.1
click==8.1.8
coloredlogs==15.0.1
contourpy==1.3.1
cycler==0.12.1
docker-pycreds==0.4.0
filelock==3.18.0
flatbuffers==25.2.10
fonttools==4.56.0
fsspec==2025.3.0
gitdb==4.0.12
GitPython==3.1.44
huggingface-hub==0.29.3
humanfriendly==10.0
idna==3.10
Jinja2==3.1.6
joblib==1.4.2
kagglehub==0.3.10
kiwisolver==1.4.8
lightning-utilities==0.14.0
MarkupSafe==3.0.2
matplotlib==3.10.1
mpmath==1.3.0
networkx==3.4.2
numpy==2.2.3
onnx==1.17.0
onnxruntime==1.21.0
packaging==24.2
pillow==11.1.0
platformdirs==4.3.6
protobuf==5.29.3
psutil==7.0.0
pydantic==2.10.6
pydantic_core==2.27.2
pyparsing==3.2.1
python-dateutil==2.9.0.post0
PyYAML==6.0.2
requests==2.32.3
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.15.2
sentry-sdk==2.22.0
setproctitle==1.3.5
six==1.17.0
smmap==5.0.2
sympy==1.13.1
tabulate==0.9.0
threadpoolctl==3.6.0
timm==1.0.15
torch==2.6.0
torchmetrics==1.6.3
torchvision==0.21.0
tqdm==4.67.1
typing_extensions==4.12.2
urllib3==2.3.0
wandb==0.19.8
21 changes: 6 additions & 15 deletions ai/src/eval.py
@@ -15,18 +15,13 @@

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg", type=str, required=True, help="Configuration file to use"
    )
    parser.add_argument(
        "--ckpt", type=str, required=True, help="Checkpoint file to load model weights"
    )
    parser.add_argument("--cfg", type=str, required=True, help="Configuration file to use")
    parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint file to load model weights")
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="dataset/kaggle_cataract_nand",
        required=True,
        help="Configuration file to use",
        default="dataset/data/kaggle_cataract_nand",
        help="test dataset path",
    )

    args = parser.parse_args()
@@ -81,13 +76,9 @@ def __getitem__(self, idx):

def load_model(model_name, num_classes, pretrained):
    if hasattr(models, model_name):
        return getattr(models, model_name)(
            num_classes=num_classes, pretrained=pretrained
        )
        return getattr(models, model_name)(num_classes=num_classes, pretrained=pretrained)
    else:
        return getattr(torchvision_models, model_name)(
            num_classes=num_classes, pretrained=pretrained
        )
        return getattr(torchvision_models, model_name)(num_classes=num_classes, pretrained=pretrained)


def main(args):
95 changes: 95 additions & 0 deletions ai/src/export.py
@@ -0,0 +1,95 @@
import argparse
import os

import models
import onnx
import torch
import torch.onnx
import torchvision.models as torchvision_models
from modules.utils import load_yaml
from onnxruntime.quantization import QuantType, quantize_dynamic


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--cfg", type=str, required=True, help="Configuration file to use")
    parser.add_argument("--ckpt", type=str, required=True, help="Checkpoint file to load model weights")
    parser.add_argument("--onnx_path", type=str, default="models/", help="Path to save ONNX model")

    args = parser.parse_args()

    return args


def load_model(model_name, num_classes, pretrained):
    if hasattr(models, model_name):
        return getattr(models, model_name)(num_classes=num_classes, pretrained=pretrained)
    else:
        return getattr(torchvision_models, model_name)(num_classes=num_classes, pretrained=pretrained)


def export_to_onnx(model, input_shape, save_path, device):
    """Convert a PyTorch model to ONNX."""
    model.eval()
    dummy_input = torch.randn(input_shape).to(device)

    torch.onnx.export(
        model,
        dummy_input,
        save_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output1", "output2"],
        dynamic_axes={"input": {0: "batch_size"}, "output1": {0: "batch_size"}, "output2": {0: "batch_size"}},
        training=torch.onnx.TrainingMode.EVAL,
        keep_initializers_as_inputs=True,
    )
    print(f"ONNX model saved to {save_path}.")


def quantize_onnx_model(onnx_path, quantized_path):
    """Dynamically quantize an ONNX model."""
    quantize_dynamic(
        model_input=onnx_path,
        model_output=quantized_path,
        weight_type=QuantType.QUInt8,
        per_channel=False,
    )
    print(f"Quantized model saved to {quantized_path}.")


def verify_onnx_model(onnx_path):
    """Validate an ONNX model."""
    model = onnx.load(onnx_path)
    onnx.checker.check_model(model)
    print("ONNX model verification complete.")


if __name__ == "__main__":
    args = parse_args()

    cfg = load_yaml(args.cfg)
    device = torch.device(cfg["DEVICE"])
    model = load_model(
        cfg["MODEL"]["NAME"],
        cfg["MODEL"]["NUM_CLASSES"],
        cfg["MODEL"]["PRETRAINED"],
    )
    model.load(args.ckpt)
    model = model.to(device)

    model.eval()
    input_shape = (1, 3, 224, 224)  # example input shape

    # Convert to ONNX
    onnx_path = os.path.join(args.onnx_path, "model.onnx")
    export_to_onnx(model, input_shape, onnx_path, device)

    # Verify the model
    verify_onnx_model(onnx_path)

    # Quantize
    quantized_path = os.path.join(args.onnx_path, "model_quantized.onnx")
    quantize_onnx_model(onnx_path, quantized_path)
7 changes: 4 additions & 3 deletions ai/src/models/mobilenetv3.py
@@ -8,13 +8,14 @@ def __init__(self, num_classes, pretrained=False):
        super(MobileNet_V3_Large, self).__init__()

        # From timm
        self.model = timm.create_model(
            "mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes
        )
        self.model = timm.create_model("mobilenetv3_large_100", pretrained=pretrained, num_classes=num_classes)

    def forward(self, x):
        return self.model(x)

    def load(self, path):
        self.load_state_dict(torch.load(path))


if __name__ == "__main__":
    model = MobileNet_V3_Large(num_classes=2, pretrained=False)