diff --git a/docs/user_guide/data_scientist_guide/job_recipe.rst b/docs/user_guide/data_scientist_guide/job_recipe.rst index 7a2315739f..ea7ba7e121 100644 --- a/docs/user_guide/data_scientist_guide/job_recipe.rst +++ b/docs/user_guide/data_scientist_guide/job_recipe.rst @@ -105,6 +105,9 @@ Use ``initial_ckpt`` to specify a path to pre-trained model weights: the recipe. It only needs to exist on the **server** when the model is actually loaded during job execution. * **PyTorch requires model architecture**: For PyTorch, you must provide ``model`` (class instance or dict config) along with ``initial_ckpt``, because PyTorch checkpoints contain only weights, not architecture. + * **PyTorch update schema**: The server-side PyTorch model or checkpoint defines the accepted + ``state_dict()`` key schema for client updates. A client may return only the subset of keys it trained, + but every returned key must already exist in the server schema. New client-only keys are rejected. * **TensorFlow/Keras can use checkpoint alone**: Keras ``.h5`` or SavedModel formats contain both architecture and weights, so ``initial_ckpt`` can be used without ``model``. If ``model`` is provided, use a subclassed Keras class instance (or dict config). diff --git a/docs/user_guide/nvflare_cli/preflight_check.rst b/docs/user_guide/nvflare_cli/preflight_check.rst index 0c70e82bf9..db97931568 100644 --- a/docs/user_guide/nvflare_cli/preflight_check.rst +++ b/docs/user_guide/nvflare_cli/preflight_check.rst @@ -13,8 +13,8 @@ General Usage .. code-block:: - nvflare preflight_check -p PACKAGE_PATH - nvflare preflight_check --package_path PACKAGE_PATH + nvflare preflight-check -p PACKAGE_PATH + nvflare preflight-check --package_path PACKAGE_PATH This preflight check script should be run on each site's machine. The ``PACKAGE_PATH`` is the path to the folder that contains @@ -23,6 +23,9 @@ the package to be checked. After running the script, for the checks that pass, users will see "PASSED". The problem and how to fix it is reported for checks that fail. +Exit code ``0`` means all applicable checks passed. Exit code ``1`` means at least one applicable check failed. +Exit code ``4`` means the package path or package format is invalid. + Below are the scripts to run the preflight check on each type of site and the possible problems that may be reported. @@ -34,7 +37,7 @@ on the server site, a user should run: .. code-block:: - nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/server1 + nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/server1 The problems that may be reported: @@ -42,10 +45,10 @@ The problems that may be reported: :header: Checks,Problems,How to fix :widths: 15, 20, 25 - Check grpc port binding,Can't bind to address ({grpc_target_address}) for grpc service: {e},Please check the DNS and port. - Check admin port binding,Can't bind to address ({admin_host}:{admin_port}) for admin service: {e},Please check the DNS and port. - Check snapshot storage writable,Can't write to {self.snapshot_storage_root}: {e}.,Please check the user permission. - Check job storage writable, Can't write to {self.job_storage_root}: {e}.,Please check the user permission. + Check FL port binding,Can't bind to address ({host}:{port}): {e},Please check the DNS and port. + Check admin port binding,Can't bind to address ({host}:{port}): {e},Please check the DNS and port. 
+ Check snapshot storage writable,Can't write to {snapshot_storage_root}: {e}.,Please check the user permission. + Check job storage writable,Can't write to {job_storage_root}: {e}.,Please check the user permission. Check dry run,Can't start successfully: {error},Please check the error message of dry run. @@ -59,7 +62,7 @@ So on the client site, a user will run: .. code-block:: - nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/site-1 + nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/site-1 The problems that may be reported: @@ -67,8 +70,8 @@ The problems that may be reported: :header: Checks,Problems,How to fix :widths: 15, 20, 25 - Check GRPC server available,Can't connect to grpc ({server_name}:{grpc_port}) server,Please check if server is up. - Check dry run, Can't start successfully: {error}, Please check the error message of dry run. + Check server available,Can't connect to {scheme} server ({host}:{port}),Please check if server is up. + Check dry run,Can't start successfully: {error},Please check the error message of dry run. Preflight check for admin consoles @@ -81,7 +84,7 @@ a user should run: .. code-block:: - nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/admin@nvidia.com + nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/admin@nvidia.com The problems that may be reported: @@ -89,5 +92,5 @@ The problems that may be reported: :header: Checks,Problems,How to fix :widths: 15, 20, 25 - Check GRPC server available,Can't connect to grpc ({server_name}:{grpc_port}) server,Please check if server is up. - Check dry run, Can't start successfully: {error}, Please check the error message of dry run. + Check server available,Can't connect to {scheme} server ({host}:{port}),Please check if server is up. + Check dry run,Can't start successfully: {error},Please check the error message of dry run. diff --git a/nvflare/app_opt/pt/model_persistence_format_manager.py b/nvflare/app_opt/pt/model_persistence_format_manager.py index 1a1e21eb60..3429fa22ee 100644 --- a/nvflare/app_opt/pt/model_persistence_format_manager.py +++ b/nvflare/app_opt/pt/model_persistence_format_manager.py @@ -140,9 +140,10 @@ def update(self, ml: ModelLearnable): introduce keys that do not already exist in the checkpoint. Notes: - Partial updates are supported: learned weights only need to cover the - subset of checkpoint keys that the client actually trained. The - original persisted weights for untouched keys are preserved. + The persisted checkpoint is the server schema for client updates. + Partial updates are supported: learned weights only need to cover a + subset of checkpoint keys that the client actually trained. New + client keys outside the server schema are rejected. """ err = validate_model_learnable(ml) if err: diff --git a/nvflare/app_opt/pt/utils.py b/nvflare/app_opt/pt/utils.py index 7438c7a748..fe47323e2e 100644 --- a/nvflare/app_opt/pt/utils.py +++ b/nvflare/app_opt/pt/utils.py @@ -219,7 +219,10 @@ def feed_vars(model: nn.Module, model_params): Notes: Empty payloads are treated as a no-op. Partial payloads are accepted as long as at least one key matches; unknown keys are ignored with a warning - instead of being applied to the local state dict. + instead of being applied to the local state dict. This is for loading a + received model into a local PyTorch module. 
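+        For example, with a module whose state dict has keys ``fc.weight`` and
+        ``fc.bias``, a payload of ``{"fc.weight": tensor}`` updates only that
+        parameter, while ``{"other.weight": tensor, "fc.weight": tensor}``
+        applies ``fc.weight`` and warns about ``other.weight``.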
Server-side validation of + learned client updates is handled by ``PTModelPersistenceFormatManager`` + and rejects keys outside the server checkpoint schema. """ _logger = get_module_logger(__name__, "AssignVariables") _logger.debug("AssignVariables...") diff --git a/nvflare/tool/package_checker/package_checker.py b/nvflare/tool/package_checker/package_checker.py index ed0298f605..e65b871cf1 100644 --- a/nvflare/tool/package_checker/package_checker.py +++ b/nvflare/tool/package_checker/package_checker.py @@ -16,12 +16,20 @@ import signal from abc import ABC, abstractmethod from collections import defaultdict +from enum import Enum, auto from subprocess import TimeoutExpired from nvflare.tool.package_checker.check_rule import CHECK_PASSED, CheckResult, CheckRule from nvflare.tool.package_checker.utils import run_command_in_subprocess, split_by_len +class CheckStatus(Enum): + PASS = auto() + PASS_WITH_CLEANUP = auto() + FAIL_WITH_CLEANUP = auto() + FAIL = auto() + + class PackageChecker(ABC): def __init__(self): self.report = defaultdict(list) @@ -67,15 +75,16 @@ def stop_dry_run(self, force: bool = True): print_human(f"killed dry run process output: {out}") print_human(f"killed dry run process err: {err}") - def check(self) -> int: + def check(self) -> CheckStatus: """Checks if the package is runnable on the current system. Returns: - 0: if no dry-run process started. - 1: if the dry-run process is started and return code is 0. - 2: if the dry-run process is started and return code is not 0. + CheckStatus.PASS: checks passed, no dry-run cleanup needed. + CheckStatus.PASS_WITH_CLEANUP: checks passed, dry-run process needs cleanup. + CheckStatus.FAIL_WITH_CLEANUP: checks failed, dry-run process needs cleanup. + CheckStatus.FAIL: checks failed, no dry-run cleanup needed. """ - ret_code = 0 + status = CheckStatus.PASS try: all_passed = True for rule in self.rules: @@ -96,23 +105,26 @@ def check(self) -> int: # check dry run if all_passed: - ret_code = self.check_dry_run() + status = self.check_dry_run() + else: + status = CheckStatus.FAIL except Exception as e: self.add_report( "Package Error", f"Exception happens in checking: {e}, this package is not in correct format.", "Please download a new package.", ) - finally: - return ret_code + status = CheckStatus.FAIL - def check_dry_run(self) -> int: + return status + + def check_dry_run(self) -> CheckStatus: """Runs dry run command. Returns: - 0: if no process started. - 1: if the process is started and return code is 0. - 2: if the process is started and return code is not 0. + CheckStatus.PASS_WITH_CLEANUP: dry run started successfully and needs cleanup. + CheckStatus.FAIL_WITH_CLEANUP: dry run started but failed and needs cleanup. + CheckStatus.FAIL: dry run could not be started. 
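+
+        Example:
+            A sketch of how a caller can map these statuses to dry-run cleanup,
+            mirroring ``check_packages`` in ``nvflare/tool/preflight_check.py``
+            (``checker`` is any concrete ``PackageChecker``)::
+
+                status = checker.check()
+                if status == CheckStatus.PASS_WITH_CLEANUP:
+                    checker.stop_dry_run(force=False)
+                elif status == CheckStatus.FAIL_WITH_CLEANUP:
+                    checker.stop_dry_run(force=True)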
""" command = self.get_dry_run_command() dry_run_input = self.get_dry_run_inputs() @@ -130,12 +142,14 @@ def check_dry_run(self) -> int: CHECK_PASSED, "N/A", ) + return CheckStatus.PASS_WITH_CLEANUP else: self.add_report( "Check dry run", f"Can't start successfully: {out}", "Please check the error message of dry run.", ) + return CheckStatus.FAIL_WITH_CLEANUP except TimeoutExpired: os.killpg(process.pid, signal.SIGTERM) # Assumption, preflight check is focused on the connectivity, so we assume all sub-systems should @@ -149,15 +163,14 @@ def check_dry_run(self) -> int: CHECK_PASSED, "N/A", ) - - finally: - if process: - if process.returncode == 0: - return 1 - else: - return 2 - else: - return 0 + return CheckStatus.PASS_WITH_CLEANUP + except Exception as e: + self.add_report( + "Check dry run", + f"Can't start successfully: {e}", + "Please check the error message of dry run.", + ) + return CheckStatus.FAIL def add_report(self, check_name, problem_text: str, fix_text: str): self.report[self.package_path].append((check_name, problem_text, fix_text)) diff --git a/nvflare/tool/preflight_check.py b/nvflare/tool/preflight_check.py index a52defae92..2fef25331a 100644 --- a/nvflare/tool/preflight_check.py +++ b/nvflare/tool/preflight_check.py @@ -16,6 +16,7 @@ import os from nvflare.tool.package_checker import ClientPackageChecker, NVFlareConsolePackageChecker, ServerPackageChecker +from nvflare.tool.package_checker.package_checker import CheckStatus _preflight_parser = None @@ -69,13 +70,13 @@ def check_packages(args): for p in package_checkers: p.init(package_path=package_path) - ret_code = 0 + check_status = CheckStatus.PASS if p.should_be_checked(): - ret_code = p.check() + check_status = p.check() p.print_report() component_name = p.__class__.__name__.replace("PackageChecker", "").lower() - status = "pass" if ret_code == 0 else "fail" + status = "fail" if check_status in [CheckStatus.FAIL, CheckStatus.FAIL_WITH_CLEANUP] else "pass" if status == "fail": overall_pass = False check_result = {"component": component_name, "status": status} @@ -84,9 +85,9 @@ def check_packages(args): check_result["details"] = details checks.append(check_result) - if ret_code == 1: + if check_status == CheckStatus.PASS_WITH_CLEANUP: p.stop_dry_run(force=False) - elif ret_code == 2: + elif check_status == CheckStatus.FAIL_WITH_CLEANUP: p.stop_dry_run(force=True) overall = "pass" if overall_pass else "fail" diff --git a/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json b/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json index fe14726c8d..6777d27d4a 100644 --- a/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json +++ b/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json @@ -10,7 +10,12 @@ { "id": "persistor", "path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor", - "args": {} + "args": { + "model": { + "path": "simple_network.SimpleNetwork", + "args": {} + } + } }, { "id": "shareable_generator", diff --git a/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json b/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json index 49076bacbc..23506c4bf4 100644 --- a/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json +++ b/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json @@ -10,7 +10,12 @@ { "id": "persistor", "path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor", - "args": {} + 
"args": { + "model": { + "path": "simple_network.SimpleNetwork", + "args": {} + } + } }, { "id": "shareable_generator", diff --git a/tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py b/tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py deleted file mode 100644 index 6541949e39..0000000000 --- a/tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from unittest.mock import patch - -import pytest -from cifar10trainer import Cifar10Trainer -from cifar10validator import Cifar10Validator - -from nvflare.apis.dxo import DXO, DataKind -from nvflare.apis.fl_constant import ReturnCode -from nvflare.apis.fl_context import FLContext -from nvflare.apis.signal import Signal - -TRAIN_TASK_NAME = "train" - - -@pytest.fixture() -def get_cifar_trainer(): - with patch.object(Cifar10Trainer, "_save_local_model") as mock_save: - with patch.object(Cifar10Trainer, "_load_local_model") as mock_load: - yield Cifar10Trainer(train_task_name=TRAIN_TASK_NAME, epochs=1) - - -class TestCifar10Trainer: - @pytest.mark.parametrize("num_rounds", [1, 3]) - def test_execute(self, get_cifar_trainer, num_rounds): - trainer = get_cifar_trainer - # just take first batch - iterator = iter(trainer._train_loader) - trainer._train_loader = [next(iterator)] - - dxo = DXO(data_kind=DataKind.WEIGHTS, data=trainer.model.state_dict()) - result = dxo.to_shareable() - for i in range(num_rounds): - result = trainer.execute(TRAIN_TASK_NAME, shareable=result, fl_ctx=FLContext(), abort_signal=Signal()) - assert result.get_return_code() == ReturnCode.OK - - @patch.object(Cifar10Trainer, "_save_local_model") - @patch.object(Cifar10Trainer, "_load_local_model") - def test_execute_rounds(self, mock_save, mock_load): - train_task_name = "train" - trainer = Cifar10Trainer(train_task_name=train_task_name, epochs=2) - # just take first batch - myitt = iter(trainer._train_loader) - trainer._train_loader = [next(myitt)] - - dxo = DXO(data_kind=DataKind.WEIGHTS, data=trainer.model.state_dict()) - result = dxo.to_shareable() - for i in range(3): - result = trainer.execute(train_task_name, shareable=result, fl_ctx=FLContext(), abort_signal=Signal()) - assert result.get_return_code() == ReturnCode.OK - - -class TestCifar10Validator: - def test_execute(self): - validate_task_name = "validate" - validator = Cifar10Validator(validate_task_name=validate_task_name) - # just take first batch - iterator = iter(validator._test_loader) - validator._test_loader = [next(iterator)] - - dxo = DXO(data_kind=DataKind.WEIGHTS, data=validator.model.state_dict()) - result = validator.execute( - validate_task_name, shareable=dxo.to_shareable(), fl_ctx=FLContext(), abort_signal=Signal() - ) - assert result.get_return_code() == ReturnCode.OK diff --git a/tests/integration_test/experiment_tracking_recipes_test.py b/tests/integration_test/experiment_tracking_recipes_test.py index 
5bb7878c3e..f05511cad7 100644 --- a/tests/integration_test/experiment_tracking_recipes_test.py +++ b/tests/integration_test/experiment_tracking_recipes_test.py @@ -35,6 +35,7 @@ or run in a separate recipe test suite (takes ~1-2 minutes). """ +import importlib.util import os from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe @@ -57,6 +58,17 @@ def client_script_path(self): repo_root = os.path.dirname(os.path.dirname(test_dir)) return os.path.join(repo_root, "examples/advanced/experiment-tracking/tensorboard/client.py") + @property + def client_script_dir(self): + return os.path.dirname(self.client_script_path) + + def _make_model(self): + model_path = os.path.join(self.client_script_dir, "model.py") + spec = importlib.util.spec_from_file_location("_exp_tracking_model", model_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.SimpleNetwork() + def test_tensorboard_tracking_integration(self): """Test TensorBoard tracking can be added and job completes.""" import tempfile @@ -67,6 +79,7 @@ def test_tensorboard_tracking_integration(self): name="test_tensorboard", min_clients=2, num_rounds=1, + model=self._make_model(), train_script=self.client_script_path, ) @@ -89,6 +102,7 @@ def test_mlflow_tracking_integration(self): name="test_mlflow", min_clients=2, num_rounds=1, + model=self._make_model(), train_script=self.client_script_path, ) diff --git a/tests/integration_test/preflight_check_test.py b/tests/integration_test/preflight_check_test.py index 70bc578d85..3510e165fd 100644 --- a/tests/integration_test/preflight_check_test.py +++ b/tests/integration_test/preflight_check_test.py @@ -81,7 +81,10 @@ def _parse_preflight_output(output: bytes) -> dict[str, str]: def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str], check_type: str): - """Verify that actual checks match expected checks with detailed error messages. + """Verify that expected checks are present with detailed error messages. + + Preflight output may include additional checks, such as "Check dry run", when + earlier required checks pass. Those extra checks are allowed here. 
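+    For example, actual checks ``{"Check server available": "PASSED",
+    "Check dry run": "PASSED"}`` pass against expected checks
+    ``{"Check server available": "PASSED"}``; the extra "Check dry run"
+    entry is simply ignored.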
Args: actual_checks: Dictionary of actual check results @@ -97,15 +100,6 @@ def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str f"Actual checks: {list(actual_checks.keys())}" ) - # Check if there are unexpected checks - extra_checks = set(actual_checks.keys()) - set(expected_checks.keys()) - if extra_checks: - raise AssertionError( - f"{check_type} preflight check has unexpected checks: {extra_checks}\n" - f"Expected checks: {list(expected_checks.keys())}\n" - f"Actual checks: {list(actual_checks.keys())}" - ) - # Verify each check's status failed_checks = [] for check_name, expected_status in expected_checks.items(): @@ -120,14 +114,24 @@ def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str raise AssertionError(error_msg) -def _run_preflight_check_command_in_subprocess(package_path: str): +def _raise_preflight_command_error(command: str, returncode: int, output: bytes): + output_text = output.decode("utf-8", errors="replace") + raise AssertionError( + f"Preflight command failed with return code {returncode}: {command}\n" f"Output:\n{output_text}" + ) + + +def _run_preflight_check_command_in_subprocess(package_path: str, expect_success: bool = True): command = f"{sys.executable} -m {PREFLIGHT_CHECK_SCRIPT} -p {package_path}" print(f"Executing command {command} in subprocess") - output = subprocess.check_output(shlex.split(command)) - return output + process = subprocess.run(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False) + print(f"Preflight command return code: {process.returncode}") + if expect_success and process.returncode != 0: + _raise_preflight_command_error(command, process.returncode, process.stdout) + return process.stdout -def _run_preflight_check_command_in_pseudo_terminal(package_path: str): +def _run_preflight_check_command_in_pseudo_terminal(package_path: str, expect_success: bool = True): command = f"{sys.executable} -m {PREFLIGHT_CHECK_SCRIPT} -p {package_path}" print(f"Executing command {command} in pty") @@ -138,16 +142,20 @@ def read(fd): output.write(data) return data - pty.spawn(shlex.split(command), read) + status = pty.spawn(shlex.split(command), read) + returncode = os.waitstatus_to_exitcode(status) + print(f"Preflight command return code: {returncode}") + if expect_success and returncode != 0: + _raise_preflight_command_error(command, returncode, output.getvalue()) return output.getvalue() -def _run_preflight_check_command(package_path: str, method: str = "subprocess"): +def _run_preflight_check_command(package_path: str, method: str = "subprocess", expect_success: bool = True): if method == "subprocess": - return _run_preflight_check_command_in_subprocess(package_path) + return _run_preflight_check_command_in_subprocess(package_path, expect_success=expect_success) else: - return _run_preflight_check_command_in_pseudo_terminal(package_path) + return _run_preflight_check_command_in_pseudo_terminal(package_path, expect_success=expect_success) @pytest.fixture( @@ -183,7 +191,6 @@ def test_run_check_on_server(self, setup_system): "Check admin port binding": "PASSED", "Check snapshot storage writable": "PASSED", "Check job storage writable": "PASSED", - "Check dry run": "PASSED", } print(f"Server '{server_name}', expecting checks: {list(expected_checks.keys())}") @@ -205,7 +212,6 @@ def test_run_check_on_client(self, setup_system): expected_checks = { "Check server available": "PASSED", - "Check dry run": "PASSED", } print(f"Client '{client_name}', expecting checks: 
{list(expected_checks.keys())}") @@ -229,7 +235,6 @@ def test_run_check_on_admin_console(self, setup_system): expected_checks = { "Check server available": "PASSED", - "Check dry run": "PASSED", } print(f"Admin console, expecting checks: {list(expected_checks.keys())}") diff --git a/tests/integration_test/run_integration_tests.sh b/tests/integration_test/run_integration_tests.sh index 9f18eeab0c..b52c1d8274 100755 --- a/tests/integration_test/run_integration_tests.sh +++ b/tests/integration_test/run_integration_tests.sh @@ -48,10 +48,57 @@ while getopts ":m:dc" option; do done [[ "$no_args" == "true" ]] && usage cmd="$base_cmd" +hosts_backup="" + +has_localhost_aliases() +{ + python - <<'PY' +import socket +import sys + +for host in ("localhost0", "localhost1"): + try: + addresses = {info[4][0] for info in socket.getaddrinfo(host, None)} + except OSError: + sys.exit(1) + if "127.0.0.1" not in addresses: + sys.exit(1) +PY +} + +restore_localhost_aliases() +{ + if [[ -n "$hosts_backup" && -f "$hosts_backup" ]]; then + echo "Restoring original /etc/hosts file." + cp "$hosts_backup" /etc/hosts + rm -f "$hosts_backup" + fi +} + +ensure_localhost_aliases() +{ + if has_localhost_aliases; then + return + fi + + if [[ ! -w /etc/hosts ]]; then + echo "ERROR: localhost0 and localhost1 must resolve to 127.0.0.1 before running integration tests." + echo "Run ci/run_integration.sh, or add this line to /etc/hosts before running this script directly:" + echo "127.0.0.1 localhost0 localhost1" + exit 1 + fi + + echo "Adding DNS entries for integration test localhost aliases." + hosts_backup=$(mktemp) + cp /etc/hosts "$hosts_backup" + trap restore_localhost_aliases EXIT + echo "127.0.0.1 localhost0 localhost1" >> /etc/hosts +} run_preflight_check_test() { echo "Running preflight check integration tests." + ensure_localhost_aliases cmd="$cmd --junitxml=./integration_test.xml preflight_check_test.py" echo "$cmd" eval "$cmd" diff --git a/tests/integration_test/xgb_histogram_recipe_test.py b/tests/integration_test/xgb_histogram_recipe_test.py index efa435ea8f..489a74df8d 100644 --- a/tests/integration_test/xgb_histogram_recipe_test.py +++ b/tests/integration_test/xgb_histogram_recipe_test.py @@ -68,6 +68,13 @@ def load_data(self): return dtrain, dval +def _make_horizontal_per_site_config(num_clients, n_samples=50, n_features=5): + return { + f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=n_samples, n_features=n_features)} + for site_id in range(1, num_clients + 1) + } + + class TestXGBHorizontalRecipe: """Smoke tests for XGBHorizontalRecipe. 
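+
+    All tests build their per-site loaders with the shared
+    ``_make_horizontal_per_site_config`` helper and derive the ``SimEnv``
+    client list from it, e.g. (``workspace`` is a placeholder path)::
+
+        per_site_config = _make_horizontal_per_site_config(num_clients=2)
+        env = SimEnv(clients=list(per_site_config), workspace_root=workspace)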
@@ -78,13 +85,8 @@ class TestXGBHorizontalRecipe: def test_histogram_algorithm(self): """Test histogram algorithm completes successfully.""" with tempfile.TemporaryDirectory() as tmpdir: - env = SimEnv(num_clients=2, workspace_root=os.path.join(tmpdir, "test_histogram")) - - # Configure per-site data loaders - per_site_config = { - f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=50, n_features=5)} - for site_id in range(1, 3) - } + per_site_config = _make_horizontal_per_site_config(num_clients=2) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_histogram")) recipe = XGBHorizontalRecipe( name="test_histogram", @@ -108,8 +110,6 @@ def test_histogram_algorithm(self): def test_custom_xgb_params(self): """Test that custom XGBoost parameters are accepted.""" with tempfile.TemporaryDirectory() as tmpdir: - env = SimEnv(num_clients=2, workspace_root=os.path.join(tmpdir, "test_custom_params")) - custom_params = { "max_depth": 5, "eta": 0.05, @@ -119,11 +119,8 @@ def test_custom_xgb_params(self): "nthread": 4, } - # Configure per-site data loaders - per_site_config = { - f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=50, n_features=5)} - for site_id in range(1, 3) - } + per_site_config = _make_horizontal_per_site_config(num_clients=2) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_custom_params")) recipe = XGBHorizontalRecipe( name="test_custom_params", @@ -143,13 +140,8 @@ def test_multiple_clients(self): """Test recipe works with more than 2 clients.""" with tempfile.TemporaryDirectory() as tmpdir: num_clients = 5 - env = SimEnv(num_clients=num_clients, workspace_root=os.path.join(tmpdir, "test_multi_client")) - - # Configure per-site data loaders - per_site_config = { - f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=30, n_features=5)} - for site_id in range(1, num_clients + 1) - } + per_site_config = _make_horizontal_per_site_config(num_clients=num_clients, n_samples=30) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_multi_client")) recipe = XGBHorizontalRecipe( name="test_multi_client", diff --git a/tests/integration_test/xgb_vertical_recipe_test.py b/tests/integration_test/xgb_vertical_recipe_test.py index 83e0625193..0e20605ff6 100644 --- a/tests/integration_test/xgb_vertical_recipe_test.py +++ b/tests/integration_test/xgb_vertical_recipe_test.py @@ -76,6 +76,19 @@ def load_data(self): return dtrain, dval +def _make_vertical_per_site_config(num_clients=2, label_owner="site-1", n_samples=50, n_features=3): + return { + f"site-{site_id}": { + "data_loader": MockVerticalDataLoader( + has_labels=f"site-{site_id}" == label_owner, + n_samples=n_samples, + n_features=n_features, + ) + } + for site_id in range(1, num_clients + 1) + } + + class TestXGBVerticalRecipe: """Smoke tests for XGBVerticalRecipe. 
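+
+    ``_make_vertical_per_site_config`` assigns ``has_labels=True`` only to the
+    ``label_owner`` site, e.g.::
+
+        # site-2 owns the labels; site-1 and site-3 load features only
+        per_site_config = _make_vertical_per_site_config(
+            num_clients=3, label_owner="site-2", n_samples=30, n_features=2
+        )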
@@ -86,14 +99,8 @@ class TestXGBVerticalRecipe: def test_vertical_basic(self): """Test basic vertical XGBoost completes successfully.""" with tempfile.TemporaryDirectory() as tmpdir: - env = SimEnv(num_clients=2, workspace_root=os.path.join(tmpdir, "test_vertical")) - - # Configure per-site data loaders - # site-1 has labels, site-2 has features only - per_site_config = { - "site-1": {"data_loader": MockVerticalDataLoader(has_labels=True, n_samples=50, n_features=3)}, - "site-2": {"data_loader": MockVerticalDataLoader(has_labels=False, n_samples=50, n_features=3)}, - } + per_site_config = _make_vertical_per_site_config() + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_vertical")) recipe = XGBVerticalRecipe( name="test_vertical", @@ -123,6 +130,7 @@ def test_label_owner_validation(self): min_clients=2, num_rounds=1, label_owner="site-1", # Valid + per_site_config=_make_vertical_per_site_config(), ) assert recipe.label_owner == "site-1" @@ -133,6 +141,7 @@ def test_label_owner_validation(self): min_clients=2, num_rounds=1, label_owner="client1", # Invalid format + per_site_config=_make_vertical_per_site_config(), ) def test_custom_xgb_params(self): @@ -152,6 +161,7 @@ def test_custom_xgb_params(self): num_rounds=1, label_owner="site-1", xgb_params=custom_params, + per_site_config=_make_vertical_per_site_config(), ) # Verify params are stored @@ -161,15 +171,10 @@ def test_multiple_clients_vertical(self): """Test vertical recipe works with more than 2 clients.""" with tempfile.TemporaryDirectory() as tmpdir: num_clients = 3 - env = SimEnv(num_clients=num_clients, workspace_root=os.path.join(tmpdir, "test_multi_vertical")) - - # Configure per-site data loaders - only site-2 has labels - per_site_config = {} - for site_id in range(1, num_clients + 1): - has_labels = site_id == 2 # Only site-2 has labels - per_site_config[f"site-{site_id}"] = { - "data_loader": MockVerticalDataLoader(has_labels=has_labels, n_samples=30, n_features=2) - } + per_site_config = _make_vertical_per_site_config( + num_clients=num_clients, label_owner="site-2", n_samples=30, n_features=2 + ) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_multi_vertical")) recipe = XGBVerticalRecipe( name="test_multi_vertical", @@ -191,6 +196,7 @@ def test_in_process_parameter(self): num_rounds=1, label_owner="site-1", in_process=True, # Default + per_site_config=_make_vertical_per_site_config(), ) assert recipe.in_process is True @@ -200,6 +206,7 @@ def test_in_process_parameter(self): num_rounds=1, label_owner="site-1", in_process=False, + per_site_config=_make_vertical_per_site_config(), ) assert recipe2.in_process is False @@ -212,5 +219,6 @@ def test_model_file_name_parameter(self): num_rounds=1, label_owner="site-1", model_file_name=custom_name, + per_site_config=_make_vertical_per_site_config(), ) assert recipe.model_file_name == custom_name diff --git a/tests/unit_test/app_opt/pt/pt_param_validation_test.py b/tests/unit_test/app_opt/pt/pt_param_validation_test.py index 076c089f4b..7074e73188 100644 --- a/tests/unit_test/app_opt/pt/pt_param_validation_test.py +++ b/tests/unit_test/app_opt/pt/pt_param_validation_test.py @@ -83,7 +83,8 @@ def test_feed_vars_raises_on_shape_mismatch(): feed_vars(model, params) -def test_feed_vars_warns_on_unexpected_keys_when_some_match(caplog): +def test_feed_vars_filters_global_keys_not_in_local_model(caplog): + """Local model loading can ignore extra global keys in non-strict mode.""" model = SimpleNet() params = 
_clone_state_dict(model) params["model.fc.weight"] = torch.ones_like(model.state_dict()["fc.weight"]) @@ -97,6 +98,7 @@ def test_feed_vars_warns_on_unexpected_keys_when_some_match(caplog): def test_persistence_manager_accepts_partial_known_updates(): + """Client updates may contain any subset of the server checkpoint schema.""" model = SimpleNet() manager = PTModelPersistenceFormatManager(_clone_state_dict(model)) new_weight = torch.full_like(model.state_dict()["fc.weight"], 5.0) @@ -107,7 +109,8 @@ def test_persistence_manager_accepts_partial_known_updates(): assert torch.equal(manager.var_dict["fc.bias"], model.state_dict()["fc.bias"]) -def test_persistence_manager_rejects_unexpected_keys(): +def test_persistence_manager_rejects_client_keys_outside_server_schema(): + """Client updates may not introduce keys outside the server checkpoint schema.""" model = SimpleNet() manager = PTModelPersistenceFormatManager(_clone_state_dict(model)) weights = { diff --git a/tests/unit_test/tool/preflight_output_test.py b/tests/unit_test/tool/preflight_output_test.py index aa54c1c33e..e49038b2a4 100644 --- a/tests/unit_test/tool/preflight_output_test.py +++ b/tests/unit_test/tool/preflight_output_test.py @@ -18,6 +18,7 @@ import pytest from nvflare.tool import cli_output +from nvflare.tool.package_checker.package_checker import CheckStatus class TestPreflightOutput: @@ -66,7 +67,7 @@ def test_all_pass_json_envelope(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 0 + mock_checker.check.return_value = CheckStatus.PASS mock_checker.__class__.__name__ = "ServerPackageChecker" mock_checker.report = {str(pkg_path): []} # Make print_report() write a sentinel to stderr so we can assert routing @@ -110,7 +111,7 @@ def test_fail_exits_1(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 1 # fail + mock_checker.check.return_value = CheckStatus.FAIL_WITH_CLEANUP mock_checker.__class__.__name__ = "ServerPackageChecker" args = MagicMock() @@ -132,6 +133,70 @@ def test_fail_exits_1(self, capsys, tmp_path): assert data["exit_code"] == 1 assert data["data"]["overall"] == "fail" + def test_failed_check_without_dry_run_exits_1_without_cleanup(self, capsys, tmp_path): + """A required-rule failure exits 1 but does not try dry-run cleanup.""" + from nvflare.tool.preflight_check import check_packages + + pkg_path = tmp_path / "package2b" + pkg_path.mkdir() + (pkg_path / "startup").mkdir() + + mock_checker = MagicMock() + mock_checker.should_be_checked.return_value = True + mock_checker.check.return_value = CheckStatus.FAIL + mock_checker.__class__.__name__ = "ServerPackageChecker" + + args = MagicMock() + args.package_path = str(pkg_path) + args.output = "json" + + with ( + patch("nvflare.tool.preflight_check.ServerPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.ClientPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.NVFlareConsolePackageChecker", return_value=mock_checker), + ): + with pytest.raises(SystemExit) as exc_info: + check_packages(args) + + assert exc_info.value.code == 1 + captured = capsys.readouterr() + data = json.loads(captured.out) + assert data["exit_code"] == 1 + assert data["data"]["overall"] == "fail" + assert all(check["status"] == "fail" for check in data["data"]["checks"]) + assert mock_checker.stop_dry_run.call_count == 0 + + def 
test_successful_dry_run_cleanup_code_is_pass(self, capsys, tmp_path): + """Dry-run cleanup status means pass, not check failure.""" + from nvflare.tool.preflight_check import check_packages + + pkg_path = tmp_path / "package2" + pkg_path.mkdir() + (pkg_path / "startup").mkdir() + + mock_checker = MagicMock() + mock_checker.should_be_checked.return_value = True + mock_checker.check.return_value = CheckStatus.PASS_WITH_CLEANUP + mock_checker.__class__.__name__ = "ServerPackageChecker" + + args = MagicMock() + args.package_path = str(pkg_path) + args.output = "json" + + with ( + patch("nvflare.tool.preflight_check.ServerPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.ClientPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.NVFlareConsolePackageChecker", return_value=mock_checker), + ): + check_packages(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert data["exit_code"] == 0 + assert data["data"]["overall"] == "pass" + assert all(check["status"] == "pass" for check in data["data"]["checks"]) + assert mock_checker.stop_dry_run.call_count == 3 + def test_per_component_checks(self, capsys, tmp_path): """Each checker appears as a separate entry; stdout has only one JSON line.""" from nvflare.tool.preflight_check import check_packages @@ -142,7 +207,7 @@ def test_per_component_checks(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 0 + mock_checker.check.return_value = CheckStatus.PASS mock_checker.report = {str(pkg_path): []} args = MagicMock() @@ -195,7 +260,7 @@ def test_failed_checks_omit_empty_details(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 1 + mock_checker.check.return_value = CheckStatus.FAIL_WITH_CLEANUP mock_checker.__class__.__name__ = "ServerPackageChecker" args = MagicMock()