Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,10 @@ dmypy.json
# Pyre type checker
.pyre/

*.DS_Store
*.DS_Store

*.pth

*_DONE

data_custom
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,24 @@ gpu_ids: [0] # Our model can be trained using a single GPU with memory>20GB. You
2. Train the network.
```bash
python train.py --opt your_config_path

```

### Using
You can directly use our pre-trained model for low-light image enhancement in your own project. Here is an example code snippet to load the pre-trained model and enhance a low-light image.

**For a Single Image**

```python
python using.py --input path/to/image.jpg --output path/to/output.jpg --model path/to/config.yml
```

**For a Folder of Images**

```python
python using.py --input path/to/input_folder --output path/to/output_folder --model path/to/config.yml
```

## Citation
If you find our work useful for your research, please cite our paper
```
Expand Down
125 changes: 125 additions & 0 deletions code/confs/Custom_smallNet_custom.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#### general settings
name: train_custom_dataset_rebuttal_smallNet_ch32_blocks1
use_tb_logger: true
model: LLFlow
distortion: sr
scale: 1
gpu_ids: [0]
dataset: Custom
optimize_all_z: false
cond_encoder: ConEncoder1
train_gt_ratio: 0.2
avg_color_map: false

concat_histeq: true
histeq_as_input: false
concat_color_map: false
gray_map: false # concat 1-input.mean(dim=1) to the input

align_condition_feature: false
align_weight: 0.001
align_maxpool: true

to_yuv: false

encode_color_map: false
le_curve: false
# sigmoid_output: true

#### datasets
datasets:
train:
root: D:\codes\projects\rm\low-light-image-enhancement\data\rm_dataset_crops
quant: 32
use_shuffle: true
n_workers: 0 # per GPU
batch_size: 16
use_flip: true
color: RGB
use_crop: true
GT_size: 160 # 192
noise_prob: 0
noise_level: 5
log_low: true
gamma_aug: false

val:
root: D:\codes\projects\rm\low-light-image-enhancement\data\rm_dataset_crops
n_workers: 0
quant: 32
n_max: 20
batch_size: 1 # must be 1
log_low: true

#### Test Settings
# dataroot_GT: D:\LOLdataset\eval15\high
# dataroot_LR: D:\LOLdataset\eval15\low
dataroot_unpaired: D:\codes\projects\rm\LLFlow\code\using\using_crops\only_low_not_love

dataroot_GT: D:\codes\projects\rm\LLFlow\code\using\using_crops\high
dataroot_LR: D:\codes\projects\rm\LLFlow\code\using\using_crops\low
model_path: D:\codes\projects\rm\LLFlow\experiments\train_custom_dataset_rebuttal_smallNet_ch32_blocks1\models\latest_G.pth
heat: 0 # This is the standard deviation of the latent vectors

#### network structures
network_G:
which_model_G: LLFlow
in_nc: 3
out_nc: 3
nf: 32
nb: 4 # 12 for our low light encoder, 23 for LLFlow
train_RRDB: false
train_RRDB_delay: 0.5

flow:
K: 4 # 24.49 psnr用的12 # 16
L: 3 # 4
noInitialInj: true
coupling: CondAffineSeparatedAndCond
additionalFlowNoAffine: 2
conditionInFeaDim: 64
split:
enable: false
fea_up0: true
stackRRDB:
blocks: [1]
concat: true

#### path
path:
# pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
strict_load: true
resume_state: auto

#### training settings: learning rate scheme, loss
train:
manual_seed: 10
lr_G: !!float 5e-4 # normalizing flow 5e-4; l1 loss train 5e-5
weight_decay_G: 0 # 1e-5 # 5e-5 # 1e-5
beta1: 0.9
beta2: 0.99
lr_scheme: MultiStepLR
warmup_iter: -1 # no warm up
lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] # [0.2, 0.35, 0.5, 0.65, 0.8, 0.95] # [ 0.5, 0.75, 0.9, 0.95 ]
lr_gamma: 0.5

weight_l1: 0
# flow_warm_up_iter: -1
weight_fl: 1

niter: 1000 #200000
val_freq: 100 # 200

#### validation settings
val:
# heats: [ 0.0, 0.5, 0.75, 1.0 ]
n_sample: 4

test:
heats: [ 0.0, 0.7, 0.8, 0.9 ]

#### logger
logger:
# Debug print_freq: 100
print_freq: 100
save_checkpoint_freq: !!float 1e3
6 changes: 3 additions & 3 deletions code/confs/LOL_smallNet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ datasets:
# dataroot_GT: D:\LOLdataset\eval15\high
# dataroot_LR: D:\LOLdataset\eval15\low
dataroot_unpaired: /home/data/Dataset/LOL_test/Fusion
dataroot_GT: D:\Dataset\LOL-v2\LOL-v2\IntegratedTest\Test\high
dataroot_LR: D:\Dataset\LOL-v2\LOL-v2\IntegratedTest\Test\low
model_path: C:\Users\Yufei\OneDrive - Nanyang Technological University (1)\Project\AAAI2022-code-release\pretrained_model\LOL_smallNet.pth
dataroot_GT: D:\codes\projects\rm\low-light-image-enhancement\data\rm_dataset_crops\val\high
dataroot_LR: D:\codes\projects\rm\low-light-image-enhancement\data\rm_dataset_crops\val\low
model_path: D:\codes\projects\rm\LLFlow\code\models\LOL_smallNet.pth
heat: 0 # This is the standard deviation of the latent vectors

#### network structures
Expand Down
82 changes: 82 additions & 0 deletions code/data/LoL_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,3 +284,85 @@ def center_crop_tensor(img, size):
assert border_double % 2 == 0, (img.shape, size)
border = border_double // 2
return img[:, :, border:-border, border:-border]


# cutsom dataset for custom use
class Custom_Dataset(data.Dataset):
def __init__(self, opt, train, all_opt):
self.root = opt["root"]
self.opt = opt
self.concat_histeq = all_opt["concat_histeq"] if "concat_histeq" in all_opt.keys() else False
self.histeq_as_input = all_opt["histeq_as_input"] if "histeq_as_input" in all_opt.keys() else False
self.log_low = opt["log_low"] if "log_low" in opt.keys() else False
self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
self.use_noise = opt['noise_prob'] if "noise_prob" in opt.keys() else False
self.noise_prob = opt['noise_prob'] if self.use_noise else None
self.noise_level = opt['noise_level'] if "noise_level" in opt.keys() else 0
self.center_crop_hr_size = opt.get("center_crop_hr_size", None)
self.crop_size = opt.get("GT_size", None)

# Direct low/high structure
self.pairs = self.load_pairs(self.root)
self.to_tensor = ToTensor()
self.gamma_aug = opt['gamma_aug'] if 'gamma_aug' in opt.keys() else False

def load_pairs(self, folder_path):
low_list = os.listdir(os.path.join(folder_path, 'low'))
low_list = sorted(list(filter(lambda x: 'png' in x or 'jpg' in x, low_list)))
pairs = []
for f_name in low_list:
pairs.append([
cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'low', f_name)), cv2.COLOR_BGR2RGB),
cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'high', f_name)), cv2.COLOR_BGR2RGB),
f_name.split('.')[0]
])
pairs[-1].append(self.hiseq_color_cv2_img(pairs[-1][0]))
return pairs

def __len__(self):
return len(self.pairs)

def hiseq_color_cv2_img(self, img):
(b, g, r) = cv2.split(img)
bH = cv2.equalizeHist(b)
gH = cv2.equalizeHist(g)
rH = cv2.equalizeHist(r)
result = cv2.merge((bH, gH, rH))
return result

def __getitem__(self, item):
lr, hr, f_name, his = self.pairs[item]
if self.histeq_as_input:
lr = his

if self.use_crop:
hr, lr, his = random_crop(hr, lr, his, self.crop_size)

if self.center_crop_hr_size:
hr, lr, his = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size), center_crop(his, self.center_crop_hr_size)

if self.use_flip:
hr, lr, his = random_flip(hr, lr, his)

if self.use_rot:
hr, lr, his = random_rotation(hr, lr, his)

if self.gamma_aug:
gamma = random.uniform(0.4, 2.8)
lr = gamma_aug(lr, gamma=gamma)

hr = self.to_tensor(hr)
lr = self.to_tensor(lr)

if self.use_noise and random.random() < self.noise_prob:
lr = torch.randn(lr.shape) * (self.noise_level / 255) + lr
if self.log_low:
lr = torch.log(torch.clamp(lr + 1e-3, min=1e-3))

if self.concat_histeq:
his = self.to_tensor(his)
lr = torch.cat([lr, his], dim=0)

return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
2 changes: 2 additions & 0 deletions code/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ def create_dataset(dataset_opt):
mode = dataset_opt['mode']
if mode == 'LoL':
from data.LoL_dataset import LoL_Dataset as D
elif mode == 'Custom':
from data.LoL_dataset import Custom_Dataset as D
else:
raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))
dataset = D(dataset_opt)
Expand Down
Loading