3 changes: 3 additions & 0 deletions .gitmodules
@@ -10,3 +10,6 @@
[submodule "vendor/sam2"]
	path = vendor/sam2
	url = https://github.com/facebookresearch/sam2.git
[submodule "vendor/rawnind_jddc"]
	path = vendor/rawnind_jddc
	url = https://github.com/trougnouf/rawnind_jddc
1 change: 1 addition & 0 deletions README.md
@@ -15,6 +15,7 @@ Currently targets the ONNX backend. The pipeline is designed to support addition
| [`mask-object-sam21-small`](models/mask-object-sam21-small/README.md) | mask | SAM 2.1 Hiera Small for interactive masking |
| [`mask-object-sam21-tiny`](models/mask-object-sam21-tiny/README.md) | mask | SAM 2.1 Hiera Tiny for interactive masking |
| [`mask-object-segnext-b2hq`](models/mask-object-segnext-b2hq/README.md) | mask | SegNext ViT-B SAx2 HQ for semantic masking |
| [`rawdenoise-nind`](models/rawdenoise-nind/README.md) | rawdenoise | UtNet2 raw denoiser trained on RawNIND (Bayer + linear Rec.2020 variants) |
| [`upscale-bsrgan`](models/upscale-bsrgan/README.md) | upscale | BSRGAN 2x and 4x blind super-resolution |

## Repository structure
47 changes: 37 additions & 10 deletions darktable_ai/demo.py
@@ -8,6 +8,27 @@
from darktable_ai.config import ModelConfig
from darktable_ai.convert import _import_script

_PROCESSED_IMAGE_EXTS = {".jpg", ".jpeg", ".png"}
_RAW_IMAGE_EXTS = {
    ".cr2", ".cr3", ".crw",  # Canon
    ".nef", ".nrw",          # Nikon
    ".arw", ".sr2", ".srf",  # Sony
    ".raf",                  # Fuji
    ".rw2",                  # Panasonic
    ".pef", ".ptx",          # Pentax
    ".orf",                  # Olympus
    ".rwl",                  # Leica
    ".srw",                  # Samsung
    ".dng",                  # Adobe generic
}
_SAMPLE_EXTS = _PROCESSED_IMAGE_EXTS | _RAW_IMAGE_EXTS

# Task → output file extension. Raw-domain tasks can't round-trip through PNG
# because they produce linear HDR or >8-bit data.
_OUTPUT_EXT_BY_TASK = {
    "rawdenoise": ".tif",
}


def run_demo(config: ModelConfig) -> None:
    """Run the model's demo.py on all sample images for its task."""
@@ -26,16 +26,47 @@ def run_demo(config: ModelConfig) -> None:

    module = _import_script(demo_script)
    model_kwargs = _model_type_kwargs(config)
    out_ext = _OUTPUT_EXT_BY_TASK.get(config.task, ".png")

    for img in sorted(images_dir.iterdir()):
        if img.suffix.lower() not in (".jpg", ".jpeg", ".png"):
            continue
    samples = sorted(p for p in images_dir.rglob("*")
                     if p.is_file() and p.suffix.lower() in _SAMPLE_EXTS)

    for img in samples:
        if img.stem.startswith("expected"):
            continue

        name = img.stem
        output_path = demo_output_dir / f"{name}.png"
        extra_kwargs = _image_kwargs(config, name)
        rel = img.relative_to(images_dir).with_suffix("")
        name = str(rel).replace("/", "_").replace("\\", "_")
        output_path = demo_output_dir / f"{name}{out_ext}"
        extra_kwargs = _image_kwargs(config, img, rel)

        print(f" {name}")
        module.demo(
@@ -60,15 +84,18 @@ def _model_type_kwargs(config: ModelConfig) -> dict:
return {"model": str(output_dir / "model.onnx")}


def _image_kwargs(config: ModelConfig, image_name: str) -> dict:
def _image_kwargs(config: ModelConfig, img: Path, rel: Path) -> dict:
    """Get extra demo kwargs for a specific image.

    Reads from a JSON sidecar file next to the sample image first
    (e.g. ``samples/mask-object/example_01.json``), then falls back
    to ``demo.image_args`` in model.yaml, keyed by either the flattened
    relative path or the bare filename stem.
    """
    sidecar = config.root_dir / "samples" / config.task / f"{image_name}.json"
    sidecar = img.with_suffix(".json")
    if sidecar.is_file():
        with open(sidecar) as f:
            return json.load(f)
    return config.demo.image_args.get(image_name, {})
    flat = str(rel).replace("/", "_").replace("\\", "_")
    return (config.demo.image_args.get(flat)
            or config.demo.image_args.get(img.stem, {}))
2 changes: 1 addition & 1 deletion models/denoise-nind/model.yaml
@@ -17,7 +17,7 @@ checkpoints:
model_card:
  long_description: "Image denoiser trained on the Natural Image Noise Dataset (NIND) from Wikimedia Commons"
  scope: "single-image denoising"
  author: "Benoit Brummer (University of Louvain)"
  author: "Benoit Brummer (UCLouvain)"
  source: "https://github.com/trougnouf/nind-denoise"
  paper: "https://arxiv.org/abs/1906.00270"
  license: "GPL-3.0"
88 changes: 88 additions & 0 deletions models/rawdenoise-nind/README.md
@@ -0,0 +1,88 @@
# RawNIND UtNet2 (Bayer + Linear Rec.2020 variants)

Two UtNet2 raw denoisers trained on the Raw Natural Image Noise Dataset
(RawNIND), bundled into a single `type: multi` package with sensor-based
auto-dispatch.

| Variant | Input | Output | Use for |
|----------------|------------------------------|------------------------------------------|----------------------------------|
| `model_bayer` | 4ch packed Bayer [R,G1,G2,B] | 3ch camRGB at 2× spatial, arbitrary gain | Bayer sensors (pre-demosaic) |
| `model_linear` | 3ch linear Rec.2020 | 3ch linear Rec.2020, arbitrary gain | X-Trans, Foveon, post-demosaic |

Both models perform the same denoising task. The Bayer variant additionally
performs demosaicing (via a PixelShuffle output head that upsamples 2×) and
emits its output in the camera's native RGB space; the ColorMatrix is not
baked into the graph, so consumers must apply it after inference to reach
linear Rec.2020. The linear variant is a pure 3→3 denoiser that takes and
returns linear Rec.2020. Both variants output at an arbitrary learned gain
and require a scalar gain-match against the input mean before use, as
sketched below.
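
The gain-match is a single scalar rescale. A minimal sketch (the helper name
is illustrative, not part of the package API):

```python
import numpy as np

def match_gain(output: np.ndarray, model_input: np.ndarray) -> np.ndarray:
    """Rescale the model output so its mean matches the input mean."""
    return output * (float(model_input.mean()) / max(float(output.mean()), 1e-8))
```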

## Source

- Repository: https://github.com/trougnouf/rawnind_jddc
- Paper: [Learning Joint Denoising, Demosaicing, and Compression from the Raw Natural Image Noise Dataset](https://arxiv.org/abs/2501.08924) (Brummer & De Vleeschouwer, 2025)
- License: GPL-3.0

## Architecture

UtNet2 is a 4-pool U-Net encoder-decoder (input H and W must be divisible by
16); the two output heads are sketched after this list:

- `funit=32`, activation `LeakyReLU` (package default for both variants)
- Bayer output head: `Conv2d(32 → 12, 1×1) → PixelShuffle(2)` (4 → 3 ch at 2× spatial)
- Linear output head: `Conv2d(32 → 3, 1×1)` (3 → 3 ch, same spatial)
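
A PyTorch sketch of the two heads as described above (shapes mirror the
tables in this README; this is not the upstream module verbatim):

```python
import torch.nn as nn

# Bayer head: 32 feature channels -> 12, then PixelShuffle rearranges
# (12, H, W) into (3, 2H, 2W), i.e. RGB at full sensor resolution.
bayer_head = nn.Sequential(nn.Conv2d(32, 12, kernel_size=1), nn.PixelShuffle(2))

# Linear head: plain 1x1 projection, spatial size unchanged.
linear_head = nn.Conv2d(32, 3, kernel_size=1)
```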

## Checkpoints

- Bayer: `DenoiserTrainingBayerToProfiledRGB_4ch_2024-02-21-bayer_ms-ssim_mgout_notrans_valeither_-4` (iter 4350000)
- Linear: `DenoiserTrainingProfiledRGBToProfiledRGB_3ch_2024-10-09-prgb_ms-ssim_mgout_notrans_valeither_-1` (iter 3900000)

Both are the canonical base variants from the `graph_denoise_models_definitions.yaml`
config map (`in_channels: 4` and `in_channels: 3`, no other options set). Training
used `match_gain: output` — the raw network outputs are at an arbitrary learned
scale; the demo rescales against the input mean at inference.

## ONNX Models

| File | Input | Output |
|-------------------|----------------------------------|----------------------------------|
| `model_bayer.onnx` | `input` — float32 [1, 4, H, W] | `output` — float32 [1, 3, 2H, 2W] |
| `model_linear.onnx` | `input` — float32 [1, 3, H, W] | `output` — float32 [1, 3, H, W] |

H and W must be divisible by 16.
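
A minimal onnxruntime sketch of the linear variant's contract, including the
mod-16 crop and gain-match (the input array here is a stand-in):

```python
import numpy as np
import onnxruntime as ort

img = np.random.rand(3, 601, 903).astype(np.float32)  # stand-in: linear Rec.2020 in [0, 1]
h, w = img.shape[1] // 16 * 16, img.shape[2] // 16 * 16
x = np.ascontiguousarray(img[None, :, :h, :w])  # crop to mod-16, add batch dim

sess = ort.InferenceSession("model_linear.onnx")
y = sess.run(["output"], {"input": x})[0]  # [1, 3, h, w], arbitrary gain
y *= x.mean() / y.mean()                   # gain-match against the input mean
```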

## Demo pipeline

`demo.py` auto-dispatches based on `rawpy.imread(image).raw_pattern.shape`; a
condensed sketch of the Bayer path follows the lists:

- `(2, 2)` → Bayer variant:
1. Normalise per-channel black level → white level, clip to [0, 1]
2. Pack to [R, G1, G2, B] half-resolution tensor
3. Crop to mod-16
4. Inference → camRGB (arbitrary scale, 2× input spatial size)
5. Gain-match to input mean
6. camRGB → linear Rec.2020 via `inv(rgb_xyz_matrix[:3,:]) → XYZ → Rec.2020`
- anything else (X-Trans 6×6, Foveon, …) → Linear variant:
1. `rawpy.postprocess` with linear Rec.2020 output, camera WB, no gamma
2. Crop to mod-16
3. Inference → lin-Rec.2020 (arbitrary scale)
4. Gain-match to input mean
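
A condensed sketch of the Bayer path (assumes an RGGB mosaic and a scalar
black level; the real demo handles per-channel levels and other patterns):

```python
import numpy as np
import rawpy

raw = rawpy.imread("sample.cr2")  # illustrative path
img = raw.raw_image_visible.astype(np.float32)
black = float(np.mean(raw.black_level_per_channel))  # simplification
img = np.clip((img - black) / (raw.white_level - black), 0.0, 1.0)

# Pack the 2x2 mosaic into [R, G1, G2, B] at half resolution, crop to mod-16.
x = np.stack([img[0::2, 0::2], img[0::2, 1::2],
              img[1::2, 0::2], img[1::2, 1::2]])[None]
h, w = x.shape[2] // 16 * 16, x.shape[3] // 16 * 16
x = np.ascontiguousarray(x[:, :, :h, :w])

# ... run model_bayer.onnx as in the sketch above, then gain-match ...

# camRGB -> XYZ -> linear Rec.2020: rgb_xyz_matrix maps XYZ to camRGB, so
# invert its top 3x3 block; the second matrix is the rounded BT.2020 constant.
xyz_from_cam = np.linalg.inv(raw.rgb_xyz_matrix[:3, :])
rec2020_from_xyz = np.array([[ 1.7167, -0.3557, -0.2534],
                             [-0.6667,  1.6165,  0.0158],
                             [ 0.0176, -0.0428,  0.9421]])
rec2020_from_cam = rec2020_from_xyz @ xyz_from_cam
```

The combined matrix is applied per pixel, e.g.
`np.einsum("ij,jhw->ihw", rec2020_from_cam, camrgb)`, where `camrgb` is the
gain-matched model output.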

Output is a 16-bit linear Rec.2020 TIFF (or an `.exr` if the output path has
that suffix). Linear Rec.2020 looks very dark in typical image viewers; open
the result in darktable, RawTherapee, or a PQ-aware viewer.
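
A minimal write sketch, assuming `tifffile` is available (the demo's actual
writer may differ, and no ICC profile is embedded here):

```python
import numpy as np
import tifffile

rgb = np.clip(np.random.rand(512, 768, 3), 0.0, 1.0)  # stand-in for the denoised HWC image
tifffile.imwrite("denoised.tif", (rgb * 65535.0 + 0.5).astype(np.uint16))
```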

## Selection Criteria

| Property | Value |
|--------------------------|-----------------------------------------------------------------------------------------------------------|
| Model license | GPL-3.0 |
| OSAID v1.0 | Open Source AI |
| MOF | Class I (Open Science) |
| Training data license | CC BY 4.0 / CC0 (per-image, Wikimedia Commons) |
| Training data provenance | [RawNIND](https://dataverse.uclouvain.be/dataverse/rawnind) – real-world raw noise/clean pairs captured by authors |
| Training code | [GPL-3.0](https://github.com/trougnouf/rawnind_jddc) |
| Known limitations | Authors flag the code as academic-quality; the 2× output upsample is baked into the Bayer variant |
| Published research | [arXiv:2501.08924](https://arxiv.org/abs/2501.08924) |
| Inference | Local only, no cloud dependencies |
| Scope | Raw and linear-RGB image denoising |
| Reproducibility | Full pipeline (setup, convert, clean, demo) |
183 changes: 183 additions & 0 deletions models/rawdenoise-nind/convert.py
@@ -0,0 +1,183 @@
"""Export RawNIND UtNet2 raw denoiser to ONNX.

Uses UtNet2 from the cloned rawnind_jddc repository:
https://github.com/trougnouf/rawnind_jddc

The bayer2prgb variant takes a 4-channel packed Bayer tensor and produces a
3-channel linear Rec.2020 RGB image at the same spatial resolution as the
packed Bayer input (i.e. half the sensor resolution on each axis).
"""

import argparse
import importlib.util
import os
import sys
import types

import torch

SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
DTAI_ROOT = os.environ.get("DTAI_ROOT", os.path.join(SCRIPT_DIR, "../.."))
_RAWNIND_SRC = os.path.join(DTAI_ROOT, "vendor", "rawnind_jddc", "src")


def _load_utnet2():
    """Load UtNet2 from raw_denoiser.py without triggering rawnind's package __init__.

    The upstream package's __init__.py chains in tools/libs/models, which pulls
    psutil, configargparse and a long tail of research-pipeline deps we don't
    need for ONNX export. UtNet2 itself only depends on torch; the only sibling
    import it makes (rawnind.libs.rawproc) is used exclusively by the
    Passthrough class, so we stub it out.
    """
    # Stub the parent packages + the one real sibling module UtNet2's file imports.
    for name in ("rawnind", "rawnind.libs", "rawnind.libs.rawproc"):
        if name not in sys.modules:
            sys.modules[name] = types.ModuleType(name)

    path = os.path.join(_RAWNIND_SRC, "rawnind", "models", "raw_denoiser.py")
    spec = importlib.util.spec_from_file_location("_rawnind_raw_denoiser", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.UtNet2


UtNet2 = _load_utnet2()

try:
    import onnxconverter_common
    HAS_ONNX_CONVERTER = True
except ImportError:
    HAS_ONNX_CONVERTER = False


def load_model(checkpoint_path, in_channels=4, funit=32, activation="PReLU",
               preupsample=False):
    # Note: the packaged RawNIND variants use LeakyReLU (see the model README);
    # pass activation="LeakyReLU" when loading those checkpoints.
    model = UtNet2(
        in_channels=in_channels,
        funit=funit,
        activation=activation,
        preupsample=preupsample,
    )
    # weights_only=False: the rawnind_jddc checkpoints are plain pickle dumps
    # from torch.save on the nn.Module / training state, not pure tensor dicts.
    loaded = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Full-model pickle (author saved the nn.Module itself)
    if isinstance(loaded, torch.nn.Module):
        state_dict = loaded.state_dict()
    elif isinstance(loaded, dict):
        state_dict = loaded
        for key in ("state_dict", "model_state_dict", "params", "params_ema", "model", "generator"):
            if key in state_dict:
                state_dict = state_dict[key]
                if isinstance(state_dict, torch.nn.Module):
                    state_dict = state_dict.state_dict()
                break
    else:
        raise TypeError(f"Unexpected checkpoint type: {type(loaded)}")

    # Strip "module." prefix if present (DataParallel)
    cleaned = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state_dict.items()
    }
    model.load_state_dict(cleaned, strict=True)
    model.eval()
    return model


def export_to_onnx(model, output_path, in_channels=4,
                   input_height=256, input_width=256,
                   dynamic_shapes=True, opset_version=17, fp16=False):
    dummy_input = torch.randn(1, in_channels, input_height, input_width)

    dynamic_axes = None
    if dynamic_shapes:
        dynamic_axes = {
            "input": {0: "batch_size", 2: "height", 3: "width"},
            "output": {0: "batch_size", 2: "height", 3: "width"},
        }

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
        dynamo=False,
    )
    print(f"Model exported to {output_path}")

    import onnx
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    print("ONNX model verification passed!")

    try:
        import onnxsim
        print("Simplifying model...")
        onnx_model, ok = onnxsim.simplify(onnx_model)
        if ok:
            onnx.save(onnx_model, output_path)
            print("Model simplified successfully")
        else:
            print("Warning: simplification failed, using unsimplified model")
    except ImportError:
        print("onnx-simplifier not installed, skipping.")

    if fp16:
        if not HAS_ONNX_CONVERTER:
            print("Warning: onnxconverter-common not installed. Skipping FP16 conversion.")
            return
        print("Converting to FP16...")
        from onnxconverter_common import float16
        fp16_model = float16.convert_float_to_float16(onnx_model)
        onnx.save(fp16_model, output_path)
        print(f"FP16 model saved to {output_path}")


def convert(checkpoint, output="model.onnx", in_channels=4, funit=32,
            activation="PReLU", preupsample=False,
            height=256, width=256, dynamic_shapes=True, opset=17, fp16=False):
    """Entry point for programmatic conversion."""
    os.makedirs(os.path.dirname(output) or ".", exist_ok=True)

    print("Loading RawNIND UtNet2 model...")
    model = load_model(checkpoint, in_channels=in_channels, funit=funit,
                       activation=activation, preupsample=preupsample)

    print("Exporting to ONNX...")
    export_to_onnx(model, output, in_channels=in_channels,
                   input_height=height, input_width=width,
                   dynamic_shapes=dynamic_shapes,
                   opset_version=opset, fp16=fp16)


def main():
    parser = argparse.ArgumentParser(description="Export RawNIND UtNet2 to ONNX")
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--output", type=str, default="model.onnx")
    parser.add_argument("--in-channels", type=int, default=4)
    parser.add_argument("--funit", type=int, default=32)
    parser.add_argument("--activation", type=str, default="PReLU")
    parser.add_argument("--preupsample", action="store_true")
    parser.add_argument("--height", type=int, default=256)
    parser.add_argument("--width", type=int, default=256)
    parser.add_argument("--opset", type=int, default=17)
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()

    convert(args.checkpoint, args.output,
            in_channels=args.in_channels, funit=args.funit,
            activation=args.activation, preupsample=args.preupsample,
            height=args.height, width=args.width,
            opset=args.opset, fp16=args.fp16)


if __name__ == "__main__":
    main()