Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,51 @@
"type": "debugpy",
"request": "launch",
"program": "${file}"
},
{
"name": "Python: pytest all tests",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["tests"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal"
},
{
"name": "Python: pytest test_data_list_nodes.py",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["tests/test_data_list_nodes.py"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal"
},
{
"name": "Python: pytest tests.test_data_list_nodes.test_create_from_int",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["-k", "test_create_from_int", "tests/test_data_list_nodes.py"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal"
},
{
"name": "Python: pytest tests.test_data_list_nodes.test_shuffle",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["-k", "test_shuffle", "tests/test_data_list_nodes.py"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal"
},
{
"name": "Python: pytest tests.test_list_nodes.test_shuffle",
"type": "python",
"request": "launch",
"module": "pytest",
"args": ["-k", "test_shuffle", "tests/test_list_nodes.py"],
"cwd": "${workspaceFolder}",
"console": "integratedTerminal"
}
]
}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "basic_data_handling"
version = "1.4.0"
version = "1.5.0"
description = """Basic Python functions for manipulating data that every programmer is used to, lightweight with no additional dependencies.

Supported data types:
Expand Down
47 changes: 40 additions & 7 deletions src/basic_data_handling/path_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,44 @@ def get_input_directory():
get_output_directory = get_input_directory


def _require_numpy():
try:
import numpy as np
return np
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"basic_data_handling: Missing dependency 'numpy'. It seems your ComfyUI installation is faulty."
"Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install numpy`."
) from e


def _require_torch():
try:
import torch
return torch
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"basic_data_handling: Missing dependency 'torch'. It seems your ComfyUI installation is faulty."
"Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install torch`."
) from e


def _require_pillow():
try:
from PIL import Image, ImageOps
return Image, ImageOps
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"basic_data_handling: Missing dependency 'pillow'. It seems your ComfyUI installation is faulty."
"Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install pillow`."
) from e


# helper functions:

def load_image_helper(path: str):
"""Helper function to load an image from a path"""
from PIL import Image, ImageOps
Image, ImageOps = _require_pillow()
try:
import pillow_jxl # noqa: F401 - imported but unused, kept for JPEG XL support
except ModuleNotFoundError:
Expand All @@ -49,8 +82,8 @@ def load_image_helper(path: str):

def extract_mask_from_alpha(img):
"""Extract a mask from the alpha channel of an image"""
import numpy as np
import torch
np = _require_numpy()
torch = _require_torch()

if 'A' in img.getbands():
alpha = np.array(img.getchannel('A')).astype(np.float32) / 255.0
Expand All @@ -70,8 +103,8 @@ def extract_mask_from_alpha(img):

def extract_mask_from_greyscale(img):
"""Extract a mask from a greyscale image or the red channel of an RGB image"""
import numpy as np
import torch
np = _require_numpy()
torch = _require_torch()

if img.mode == 'L':
# Image is already greyscale
Expand Down Expand Up @@ -1141,8 +1174,8 @@ def save_image_with_mask(self, images, mask, path: str, format: str = "png",
# Create PIL image (RGB)
pil_img = Image.fromarray(img_np)

# Create alpha channel image
alpha_img = Image.fromarray(alpha_np, mode='L')
# Create alpha channel image (avoid deprecated 'mode' kwarg in Pillow 13+)
alpha_img = Image.fromarray(alpha_np).convert("L")

# Convert to RGBA and add alpha channel
pil_img_rgba = pil_img.convert("RGBA")
Expand Down
9 changes: 8 additions & 1 deletion src/basic_data_handling/tensor_nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
import torch
try:
import torch
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"basic_data_handling: Missing dependency 'torch'. It seems your ComfyUI installation is faulty."
"Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install torch`."
) from e

from inspect import cleandoc
from typing import Any

Expand Down
24 changes: 24 additions & 0 deletions tests/test_dict_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
DictCreateFromBoolean,
DictCreateFromFloat,
DictCreateFromInt,
DictCreateFromItemsDataList,
DictCreateFromItemsList,
DictCreateFromLists,
DictCreateFromString,
DictExcludeKeys,
Expand Down Expand Up @@ -110,6 +112,20 @@ def test_dict_create_from_string():
assert node.create() == ({},)


def test_dict_create_from_items_datalist():
node = DictCreateFromItemsDataList()
assert node.create_from_items(item=[("key1", "value1"), ("key2", "value2")]) == (_dict_x2,)
with pytest.raises(ValueError):
node.create_from_items(item=[("key1", "value1", "extra")])


def test_dict_create_from_items_list():
node = DictCreateFromItemsList()
assert node.create_from_items(items=[("key1", "value1"), ("key2", "value2")]) == (_dict_x2,)
with pytest.raises(ValueError):
node.create_from_items(items=[("key1", "value1", "extra")])


@pytest.mark.parametrize("dict_type", _tested_dict_types)
def test_dict_pop_random(dict_type):
node = DictPopRandom()
Expand Down Expand Up @@ -379,6 +395,14 @@ def test_dict_invert(dict_type, in_dict, out_dict, success, message):
assert type(result[0]) == dict_type, f"Wrong type: {message}"


def test_dict_invert_unhashable_values():
node = DictInvert()
my_dict = {"key1": [1], "key2": [2]}
result, success = node.invert(my_dict)
assert result == my_dict
assert success is False


def test_dict_create_from_lists():
node = DictCreateFromLists()
keys = ["key1", "key2", "key3"]
Expand Down
41 changes: 41 additions & 0 deletions tests/test_dynamic_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from src.basic_data_handling._dynamic_input import ContainsDynamicDict


def test_contains_dynamic_dict_basic_lookup():
d = ContainsDynamicDict({
'value': ('x', {'_dynamic': 'number'}),
'fixed': 'y',
})

# direct behavior
assert 'value' in d
assert d['value'] == ('x', {'_dynamic': 'number'})

# dynamic numeric key lookup
assert 'value1' in d
assert d['value1'] == ('x', {'_dynamic': 'number'})
assert 'value999' in d
assert d['value999'] == ('x', {'_dynamic': 'number'})

# non-dynamic key and fallback
assert 'fixed' in d
assert d['fixed'] == 'y'

# non-matching key should not be present
assert 'novalue' not in d
with pytest.raises(KeyError):
_ = d['novalue']


def test_contains_dynamic_dict_partial_prefix_not_numeric():
d = ContainsDynamicDict({'val': ('z', {'_dynamic': 'number'})})

assert 'val' in d
assert d['val'] == ('z', {'_dynamic': 'number'})
assert 'val1' in d
assert d['val1'] == ('z', {'_dynamic': 'number'})
assert 'valx' not in d
with pytest.raises(KeyError):
_ = d['valx']
21 changes: 20 additions & 1 deletion tests/test_path_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
PathSetExtension, PathNormalize, PathRelative, PathGlob, PathExpandVars, PathGetCwd,
PathListDir, PathIsAbsolute, PathCommonPrefix, PathLoadStringFile, PathSaveStringFile,
PathLoadImageRGB, PathSaveImageRGB, PathLoadImageRGBA, PathSaveImageRGBA,
PathLoadMaskFromAlpha, PathLoadMaskFromGreyscale,
PathLoadMaskFromAlpha, PathLoadMaskFromGreyscale, PathInputDir, PathOutputDir,
)


Expand Down Expand Up @@ -519,6 +519,25 @@ def test_path_glob(tmp_path):
assert len(no_match_result[0]) == 0


def test_path_input_output_dir(monkeypatch):
monkeypatch.setattr("src.basic_data_handling.path_nodes.get_input_directory", lambda: "/tmp/input")
monkeypatch.setattr("src.basic_data_handling.path_nodes.get_output_directory", lambda: "/tmp/output")

input_node = PathInputDir()
output_node = PathOutputDir()

assert input_node.execute() == ("/tmp/input",)
assert output_node.execute() == ("/tmp/output",)

# fallback to defaults if functions are not patched
monkeypatch.setattr("src.basic_data_handling.path_nodes.get_input_directory", lambda: "./")
monkeypatch.setattr("src.basic_data_handling.path_nodes.get_output_directory", lambda: "./")

assert PathInputDir().execute() == ("./",)
assert PathOutputDir().execute() == ("./",)



def test_path_expand_vars(monkeypatch):
node = PathExpandVars()
monkeypatch.setenv("TEST_VAR", "test_value")
Expand Down
Loading