diff --git a/.vscode/launch.json b/.vscode/launch.json index b50b37a..27c9a00 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,6 +9,51 @@ "type": "debugpy", "request": "launch", "program": "${file}" + }, + { + "name": "Python: pytest all tests", + "type": "python", + "request": "launch", + "module": "pytest", + "args": ["tests"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal" + }, + { + "name": "Python: pytest test_data_list_nodes.py", + "type": "python", + "request": "launch", + "module": "pytest", + "args": ["tests/test_data_list_nodes.py"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal" + }, + { + "name": "Python: pytest tests.test_data_list_nodes.test_create_from_int", + "type": "python", + "request": "launch", + "module": "pytest", + "args": ["-k", "test_create_from_int", "tests/test_data_list_nodes.py"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal" + }, + { + "name": "Python: pytest tests.test_data_list_nodes.test_shuffle", + "type": "python", + "request": "launch", + "module": "pytest", + "args": ["-k", "test_shuffle", "tests/test_data_list_nodes.py"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal" + }, + { + "name": "Python: pytest tests.test_list_nodes.test_shuffle", + "type": "python", + "request": "launch", + "module": "pytest", + "args": ["-k", "test_shuffle", "tests/test_list_nodes.py"], + "cwd": "${workspaceFolder}", + "console": "integratedTerminal" } ] } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f0fd4d2..86a55ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "basic_data_handling" -version = "1.4.0" +version = "1.5.0" description = """Basic Python functions for manipulating data that every programmer is used to, lightweight with no additional dependencies. Supported data types: diff --git a/src/basic_data_handling/path_nodes.py b/src/basic_data_handling/path_nodes.py index ea6dbad..9274bea 100644 --- a/src/basic_data_handling/path_nodes.py +++ b/src/basic_data_handling/path_nodes.py @@ -25,11 +25,44 @@ def get_input_directory(): get_output_directory = get_input_directory +def _require_numpy(): + try: + import numpy as np + return np + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "basic_data_handling: Missing dependency 'numpy'. It seems your ComfyUI installation is faulty." + "Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install numpy`." + ) from e + + +def _require_torch(): + try: + import torch + return torch + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "basic_data_handling: Missing dependency 'torch'. It seems your ComfyUI installation is faulty." + "Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install torch`." + ) from e + + +def _require_pillow(): + try: + from PIL import Image, ImageOps + return Image, ImageOps + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "basic_data_handling: Missing dependency 'pillow'. It seems your ComfyUI installation is faulty." + "Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install pillow`." + ) from e + + # helper functions: def load_image_helper(path: str): """Helper function to load an image from a path""" - from PIL import Image, ImageOps + Image, ImageOps = _require_pillow() try: import pillow_jxl # noqa: F401 - imported but unused, kept for JPEG XL support except ModuleNotFoundError: @@ -49,8 +82,8 @@ def load_image_helper(path: str): def extract_mask_from_alpha(img): """Extract a mask from the alpha channel of an image""" - import numpy as np - import torch + np = _require_numpy() + torch = _require_torch() if 'A' in img.getbands(): alpha = np.array(img.getchannel('A')).astype(np.float32) / 255.0 @@ -70,8 +103,8 @@ def extract_mask_from_alpha(img): def extract_mask_from_greyscale(img): """Extract a mask from a greyscale image or the red channel of an RGB image""" - import numpy as np - import torch + np = _require_numpy() + torch = _require_torch() if img.mode == 'L': # Image is already greyscale @@ -1141,8 +1174,8 @@ def save_image_with_mask(self, images, mask, path: str, format: str = "png", # Create PIL image (RGB) pil_img = Image.fromarray(img_np) - # Create alpha channel image - alpha_img = Image.fromarray(alpha_np, mode='L') + # Create alpha channel image (avoid deprecated 'mode' kwarg in Pillow 13+) + alpha_img = Image.fromarray(alpha_np).convert("L") # Convert to RGBA and add alpha channel pil_img_rgba = pil_img.convert("RGBA") diff --git a/src/basic_data_handling/tensor_nodes.py b/src/basic_data_handling/tensor_nodes.py index 9b825d6..52b4683 100644 --- a/src/basic_data_handling/tensor_nodes.py +++ b/src/basic_data_handling/tensor_nodes.py @@ -1,4 +1,11 @@ -import torch +try: + import torch +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "basic_data_handling: Missing dependency 'torch'. It seems your ComfyUI installation is faulty." + "Only for development purposes: Install it with `pip install .[dev] numpy torch pillow` or `pip install torch`." + ) from e + from inspect import cleandoc from typing import Any diff --git a/tests/test_dict_nodes.py b/tests/test_dict_nodes.py index 81e2417..a47b7a2 100644 --- a/tests/test_dict_nodes.py +++ b/tests/test_dict_nodes.py @@ -7,6 +7,8 @@ DictCreateFromBoolean, DictCreateFromFloat, DictCreateFromInt, + DictCreateFromItemsDataList, + DictCreateFromItemsList, DictCreateFromLists, DictCreateFromString, DictExcludeKeys, @@ -110,6 +112,20 @@ def test_dict_create_from_string(): assert node.create() == ({},) +def test_dict_create_from_items_datalist(): + node = DictCreateFromItemsDataList() + assert node.create_from_items(item=[("key1", "value1"), ("key2", "value2")]) == (_dict_x2,) + with pytest.raises(ValueError): + node.create_from_items(item=[("key1", "value1", "extra")]) + + +def test_dict_create_from_items_list(): + node = DictCreateFromItemsList() + assert node.create_from_items(items=[("key1", "value1"), ("key2", "value2")]) == (_dict_x2,) + with pytest.raises(ValueError): + node.create_from_items(items=[("key1", "value1", "extra")]) + + @pytest.mark.parametrize("dict_type", _tested_dict_types) def test_dict_pop_random(dict_type): node = DictPopRandom() @@ -379,6 +395,14 @@ def test_dict_invert(dict_type, in_dict, out_dict, success, message): assert type(result[0]) == dict_type, f"Wrong type: {message}" +def test_dict_invert_unhashable_values(): + node = DictInvert() + my_dict = {"key1": [1], "key2": [2]} + result, success = node.invert(my_dict) + assert result == my_dict + assert success is False + + def test_dict_create_from_lists(): node = DictCreateFromLists() keys = ["key1", "key2", "key3"] diff --git a/tests/test_dynamic_input.py b/tests/test_dynamic_input.py new file mode 100644 index 0000000..ba7d90c --- /dev/null +++ b/tests/test_dynamic_input.py @@ -0,0 +1,41 @@ +import pytest + +from src.basic_data_handling._dynamic_input import ContainsDynamicDict + + +def test_contains_dynamic_dict_basic_lookup(): + d = ContainsDynamicDict({ + 'value': ('x', {'_dynamic': 'number'}), + 'fixed': 'y', + }) + + # direct behavior + assert 'value' in d + assert d['value'] == ('x', {'_dynamic': 'number'}) + + # dynamic numeric key lookup + assert 'value1' in d + assert d['value1'] == ('x', {'_dynamic': 'number'}) + assert 'value999' in d + assert d['value999'] == ('x', {'_dynamic': 'number'}) + + # non-dynamic key and fallback + assert 'fixed' in d + assert d['fixed'] == 'y' + + # non-matching key should not be present + assert 'novalue' not in d + with pytest.raises(KeyError): + _ = d['novalue'] + + +def test_contains_dynamic_dict_partial_prefix_not_numeric(): + d = ContainsDynamicDict({'val': ('z', {'_dynamic': 'number'})}) + + assert 'val' in d + assert d['val'] == ('z', {'_dynamic': 'number'}) + assert 'val1' in d + assert d['val1'] == ('z', {'_dynamic': 'number'}) + assert 'valx' not in d + with pytest.raises(KeyError): + _ = d['valx'] diff --git a/tests/test_path_nodes.py b/tests/test_path_nodes.py index 66c5396..7516d70 100644 --- a/tests/test_path_nodes.py +++ b/tests/test_path_nodes.py @@ -11,7 +11,7 @@ PathSetExtension, PathNormalize, PathRelative, PathGlob, PathExpandVars, PathGetCwd, PathListDir, PathIsAbsolute, PathCommonPrefix, PathLoadStringFile, PathSaveStringFile, PathLoadImageRGB, PathSaveImageRGB, PathLoadImageRGBA, PathSaveImageRGBA, - PathLoadMaskFromAlpha, PathLoadMaskFromGreyscale, + PathLoadMaskFromAlpha, PathLoadMaskFromGreyscale, PathInputDir, PathOutputDir, ) @@ -519,6 +519,25 @@ def test_path_glob(tmp_path): assert len(no_match_result[0]) == 0 +def test_path_input_output_dir(monkeypatch): + monkeypatch.setattr("src.basic_data_handling.path_nodes.get_input_directory", lambda: "/tmp/input") + monkeypatch.setattr("src.basic_data_handling.path_nodes.get_output_directory", lambda: "/tmp/output") + + input_node = PathInputDir() + output_node = PathOutputDir() + + assert input_node.execute() == ("/tmp/input",) + assert output_node.execute() == ("/tmp/output",) + + # fallback to defaults if functions are not patched + monkeypatch.setattr("src.basic_data_handling.path_nodes.get_input_directory", lambda: "./") + monkeypatch.setattr("src.basic_data_handling.path_nodes.get_output_directory", lambda: "./") + + assert PathInputDir().execute() == ("./",) + assert PathOutputDir().execute() == ("./",) + + + def test_path_expand_vars(monkeypatch): node = PathExpandVars() monkeypatch.setenv("TEST_VAR", "test_value")