-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdata_transform.py
More file actions
24 lines (19 loc) · 842 Bytes
/
data_transform.py
File metadata and controls
24 lines (19 loc) · 842 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torchvision.transforms as transforms
def get_transform(size=224, jitter=0.4):
    """Build the training and validation image pipelines.

    Both pipelines end with ``ToTensor`` followed by ``Normalize`` using the
    widely used ImageNet per-channel mean/std values. Inputs are presumably
    PIL images (anything ``RandomResizedCrop``/``Resize`` and ``ToTensor``
    accept) -- confirm against the dataset that applies these transforms.

    Args:
        size: Output side length in pixels; both pipelines produce
            ``size x size`` images.
        jitter: Strength used by ``ColorJitter`` for brightness, contrast and
            saturation. Hue uses the same value clamped to 0.5. Defaults to
            0.4, the value this function previously hard-coded.
            ``ColorJitter`` itself raises ``ValueError`` for invalid values
            (e.g. negative strengths).

    Returns:
        A ``(train_transform, val_transform)`` tuple of
        ``transforms.Compose`` pipelines.
    """
    # Shared normalization statistics so the two pipelines cannot drift apart.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Training pipeline: data augmentation, then tensor conversion.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(size, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        # ColorJitter only accepts hue in [0, 0.5], so clamp it while the
        # other factors use the full jitter strength.
        transforms.ColorJitter(
            brightness=jitter,
            contrast=jitter,
            saturation=jitter,
            hue=min(0.5, jitter),
        ),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Validation pipeline: deterministic. Resizing straight to a square does
    # not preserve the input's aspect ratio.
    val_transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    return train_transform, val_transform
if __name__ == '__main__':
    # Smoke check: build both pipelines with the default settings and print
    # them, so running this module directly shows the configured transforms.
    train_transform, val_transform = get_transform()
    print('train transform:', train_transform, sep='\n')
    print('val transform:', val_transform, sep='\n')