From 2efe529adba14c5d811bc0adb4d34b5440eb19fe Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 30 Apr 2026 14:25:33 -0700 Subject: [PATCH 1/8] Refresh stale integration tests (cherry picked from commit 0a303f808a1a132bb57e74753fabd9399c3bd4b9) --- .../jobs/hello-pt/app/custom/test_custom.py | 79 ------------------- .../experiment_tracking_recipes_test.py | 37 +++++++++ .../integration_test/preflight_check_test.py | 73 ++++++++++++----- .../integration_test/run_integration_tests.sh | 47 +++++++++++ .../xgb_histogram_recipe_test.py | 34 +++----- .../xgb_vertical_recipe_test.py | 42 ++++++---- 6 files changed, 174 insertions(+), 138 deletions(-) delete mode 100644 tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py diff --git a/tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py b/tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py deleted file mode 100644 index 6541949e39..0000000000 --- a/tests/integration_test/data/jobs/hello-pt/app/custom/test_custom.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from unittest.mock import patch - -import pytest -from cifar10trainer import Cifar10Trainer -from cifar10validator import Cifar10Validator - -from nvflare.apis.dxo import DXO, DataKind -from nvflare.apis.fl_constant import ReturnCode -from nvflare.apis.fl_context import FLContext -from nvflare.apis.signal import Signal - -TRAIN_TASK_NAME = "train" - - -@pytest.fixture() -def get_cifar_trainer(): - with patch.object(Cifar10Trainer, "_save_local_model") as mock_save: - with patch.object(Cifar10Trainer, "_load_local_model") as mock_load: - yield Cifar10Trainer(train_task_name=TRAIN_TASK_NAME, epochs=1) - - -class TestCifar10Trainer: - @pytest.mark.parametrize("num_rounds", [1, 3]) - def test_execute(self, get_cifar_trainer, num_rounds): - trainer = get_cifar_trainer - # just take first batch - iterator = iter(trainer._train_loader) - trainer._train_loader = [next(iterator)] - - dxo = DXO(data_kind=DataKind.WEIGHTS, data=trainer.model.state_dict()) - result = dxo.to_shareable() - for i in range(num_rounds): - result = trainer.execute(TRAIN_TASK_NAME, shareable=result, fl_ctx=FLContext(), abort_signal=Signal()) - assert result.get_return_code() == ReturnCode.OK - - @patch.object(Cifar10Trainer, "_save_local_model") - @patch.object(Cifar10Trainer, "_load_local_model") - def test_execute_rounds(self, mock_save, mock_load): - train_task_name = "train" - trainer = Cifar10Trainer(train_task_name=train_task_name, epochs=2) - # just take first batch - myitt = iter(trainer._train_loader) - trainer._train_loader = [next(myitt)] - - dxo = DXO(data_kind=DataKind.WEIGHTS, data=trainer.model.state_dict()) - result = dxo.to_shareable() - for i in range(3): - result = trainer.execute(train_task_name, shareable=result, fl_ctx=FLContext(), abort_signal=Signal()) - assert result.get_return_code() == ReturnCode.OK - - -class TestCifar10Validator: - def test_execute(self): - validate_task_name = "validate" - validator = Cifar10Validator(validate_task_name=validate_task_name) - # just take first batch - iterator = iter(validator._test_loader) - validator._test_loader = [next(iterator)] - - dxo = DXO(data_kind=DataKind.WEIGHTS, data=validator.model.state_dict()) - result = validator.execute( - validate_task_name, shareable=dxo.to_shareable(), fl_ctx=FLContext(), abort_signal=Signal() - ) - assert result.get_return_code() == ReturnCode.OK diff --git a/tests/integration_test/experiment_tracking_recipes_test.py b/tests/integration_test/experiment_tracking_recipes_test.py index 5bb7878c3e..2d6aa283ac 100644 --- a/tests/integration_test/experiment_tracking_recipes_test.py +++ b/tests/integration_test/experiment_tracking_recipes_test.py @@ -36,6 +36,7 @@ """ import os +import sys from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe from nvflare.recipe import SimEnv @@ -57,6 +58,40 @@ def client_script_path(self): repo_root = os.path.dirname(os.path.dirname(test_dir)) return os.path.join(repo_root, "examples/advanced/experiment-tracking/tensorboard/client.py") + @property + def client_script_dir(self): + return os.path.dirname(self.client_script_path) + + def _is_client_script_module(self, module): + spec = getattr(module, "__spec__", None) + origin = getattr(spec, "origin", None) + if not origin or origin in {"built-in", "frozen"}: + return False + + client_script_dir = os.path.realpath(self.client_script_dir) + module_origin = os.path.realpath(origin) + return module_origin.startswith(client_script_dir + os.sep) + + def _make_model(self): + original_sys_path = list(sys.path) + previous_modules = { + name: module + for name, module in list(sys.modules.items()) + if name == "model" or self._is_client_script_module(module) + } + sys.modules.pop("model", None) + sys.path.insert(0, self.client_script_dir) + try: + from model import SimpleNetwork + + return SimpleNetwork() + finally: + for name, module in list(sys.modules.items()): + if name == "model" or self._is_client_script_module(module): + sys.modules.pop(name, None) + sys.modules.update(previous_modules) + sys.path[:] = original_sys_path + def test_tensorboard_tracking_integration(self): """Test TensorBoard tracking can be added and job completes.""" import tempfile @@ -67,6 +102,7 @@ def test_tensorboard_tracking_integration(self): name="test_tensorboard", min_clients=2, num_rounds=1, + model=self._make_model(), train_script=self.client_script_path, ) @@ -89,6 +125,7 @@ def test_mlflow_tracking_integration(self): name="test_mlflow", min_clients=2, num_rounds=1, + model=self._make_model(), train_script=self.client_script_path, ) diff --git a/tests/integration_test/preflight_check_test.py b/tests/integration_test/preflight_check_test.py index 70bc578d85..905b99fc0d 100644 --- a/tests/integration_test/preflight_check_test.py +++ b/tests/integration_test/preflight_check_test.py @@ -27,7 +27,7 @@ from tests.integration_test.src.constants import PREFLIGHT_CHECK_SCRIPT TEST_CASES = [ - {"project_yaml": "data/projects/dummy.yml", "admin_name": "super@test.org"}, + {"project_yaml": "data/projects/dummy.yml", "admin_name": "super@test.org", "is_dummy_overseer": True}, ] @@ -81,7 +81,10 @@ def _parse_preflight_output(output: bytes) -> dict[str, str]: def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str], check_type: str): - """Verify that actual checks match expected checks with detailed error messages. + """Verify that expected checks are present with detailed error messages. + + Preflight output may include additional checks, such as "Check dry run", when + earlier required checks pass. Those extra checks are allowed here. Args: actual_checks: Dictionary of actual check results @@ -97,15 +100,6 @@ def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str f"Actual checks: {list(actual_checks.keys())}" ) - # Check if there are unexpected checks - extra_checks = set(actual_checks.keys()) - set(expected_checks.keys()) - if extra_checks: - raise AssertionError( - f"{check_type} preflight check has unexpected checks: {extra_checks}\n" - f"Expected checks: {list(expected_checks.keys())}\n" - f"Actual checks: {list(actual_checks.keys())}" - ) - # Verify each check's status failed_checks = [] for check_name, expected_status in expected_checks.items(): @@ -123,8 +117,8 @@ def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str def _run_preflight_check_command_in_subprocess(package_path: str): command = f"{sys.executable} -m {PREFLIGHT_CHECK_SCRIPT} -p {package_path}" print(f"Executing command {command} in subprocess") - output = subprocess.check_output(shlex.split(command)) - return output + process = subprocess.run(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False) + return process.stdout def _run_preflight_check_command_in_pseudo_terminal(package_path: str): @@ -156,6 +150,7 @@ def _run_preflight_check_command(package_path: str, method: str = "subprocess"): def setup_system(request): test_config = request.param project_yaml_path = test_config["project_yaml"] + is_dummy_overseer = test_config["is_dummy_overseer"] admin_name = test_config["admin_name"] if not os.path.isfile(project_yaml_path): @@ -166,24 +161,27 @@ def setup_system(request): admin_folder_root = os.path.abspath(os.path.join(workspace_root, admin_name)) - return site_launcher, admin_folder_root + return site_launcher, is_dummy_overseer, admin_folder_root @pytest.mark.xdist_group(name="preflight_tests_group") class TestPreflightCheck: - def test_run_check_on_server(self, setup_system): - site_launcher, _ = setup_system + def test_run_check_on_server_after_overseer_start(self, setup_system): + site_launcher, is_dummy_overseer, _ = setup_system try: + if not is_dummy_overseer: + site_launcher.start_overseer() + # preflight-check on server for server_name, server_props in site_launcher.server_properties.items(): output = _run_preflight_check_command(package_path=server_props.root_dir) actual_checks = _parse_preflight_output(output) + # Get expected checks based on communication scheme expected_checks = { "Check FL port binding": "PASSED", "Check admin port binding": "PASSED", "Check snapshot storage writable": "PASSED", "Check job storage writable": "PASSED", - "Check dry run": "PASSED", } print(f"Server '{server_name}', expecting checks: {list(expected_checks.keys())}") @@ -193,23 +191,55 @@ def test_run_check_on_server(self, setup_system): site_launcher.stop_all_sites() site_launcher.cleanup() + def test_run_check_on_server_before_overseer_start(self, setup_system): + site_launcher, is_dummy_overseer, _ = setup_system + try: + # preflight-check on server + for server_name, server_props in site_launcher.server_properties.items(): + output = _run_preflight_check_command(package_path=server_props.root_dir) + actual_checks = _parse_preflight_output(output) + + # Get expected checks based on communication scheme + expected_checks = { + "Check FL port binding": "PASSED", + "Check admin port binding": "PASSED", + "Check snapshot storage writable": "PASSED", + "Check job storage writable": "PASSED", + } + + if is_dummy_overseer: + print(f"Server '{server_name}', expecting checks: {list(expected_checks.keys())}") + _verify_checks(actual_checks, expected_checks, f"Server '{server_name}'") + else: + assert any(status != "PASSED" for status in actual_checks.values()), ( + f"Server '{server_name}' preflight check expected some failures before overseer start, " + f"Actual checks: {actual_checks}" + ) + finally: + site_launcher.stop_all_sites() + site_launcher.cleanup() + def test_run_check_on_client(self, setup_system): - site_launcher, _ = setup_system + site_launcher, is_dummy_overseer, _ = setup_system try: + if not is_dummy_overseer: + site_launcher.start_overseer() site_launcher.start_servers() time.sleep(SERVER_START_TIME) + # preflight-check on clients for client_name, client_props in site_launcher.client_properties.items(): output = _run_preflight_check_command(package_path=client_props.root_dir) actual_checks = _parse_preflight_output(output) + # Get expected checks based on communication scheme expected_checks = { "Check server available": "PASSED", - "Check dry run": "PASSED", } print(f"Client '{client_name}', expecting checks: {list(expected_checks.keys())}") + # Verify checks match expectations based on communication scheme _verify_checks(actual_checks, expected_checks, f"Client '{client_name}'") except Exception: raise @@ -218,8 +248,10 @@ def test_run_check_on_client(self, setup_system): site_launcher.cleanup() def test_run_check_on_admin_console(self, setup_system): - site_launcher, admin_folder_root = setup_system + site_launcher, is_dummy_overseer, admin_folder_root = setup_system try: + if not is_dummy_overseer: + site_launcher.start_overseer() site_launcher.start_servers() time.sleep(SERVER_START_TIME) @@ -229,7 +261,6 @@ def test_run_check_on_admin_console(self, setup_system): expected_checks = { "Check server available": "PASSED", - "Check dry run": "PASSED", } print(f"Admin console, expecting checks: {list(expected_checks.keys())}") diff --git a/tests/integration_test/run_integration_tests.sh b/tests/integration_test/run_integration_tests.sh index 9f18eeab0c..b52c1d8274 100755 --- a/tests/integration_test/run_integration_tests.sh +++ b/tests/integration_test/run_integration_tests.sh @@ -48,10 +48,57 @@ while getopts ":m:dc" option; do done [[ "$no_args" == "true" ]] && usage cmd="$base_cmd" +hosts_backup="" + +has_localhost_aliases() +{ + python - <<'PY' +import socket +import sys + +for host in ("localhost0", "localhost1"): + try: + addresses = {info[4][0] for info in socket.getaddrinfo(host, None)} + except OSError: + sys.exit(1) + if "127.0.0.1" not in addresses: + sys.exit(1) +PY +} + +restore_localhost_aliases() +{ + if [[ -n "$hosts_backup" && -f "$hosts_backup" ]]; then + echo "Restoring original /etc/hosts file." + cp "$hosts_backup" /etc/hosts + rm -f "$hosts_backup" + fi +} + +ensure_localhost_aliases() +{ + if has_localhost_aliases; then + return + fi + + if [[ ! -w /etc/hosts ]]; then + echo "ERROR: localhost0 and localhost1 must resolve to 127.0.0.1 before running integration tests." + echo "Run ci/run_integration.sh, or add this line to /etc/hosts before running this script directly:" + echo "127.0.0.1 localhost0 localhost1" + exit 1 + fi + + echo "Adding DNS entries for integration test localhost aliases." + hosts_backup=$(mktemp) + cp /etc/hosts "$hosts_backup" + trap restore_localhost_aliases EXIT + echo "127.0.0.1 localhost0 localhost1" >> /etc/hosts +} run_preflight_check_test() { echo "Running preflight check integration tests." + ensure_localhost_aliases cmd="$cmd --junitxml=./integration_test.xml preflight_check_test.py" echo "$cmd" eval "$cmd" diff --git a/tests/integration_test/xgb_histogram_recipe_test.py b/tests/integration_test/xgb_histogram_recipe_test.py index efa435ea8f..489a74df8d 100644 --- a/tests/integration_test/xgb_histogram_recipe_test.py +++ b/tests/integration_test/xgb_histogram_recipe_test.py @@ -68,6 +68,13 @@ def load_data(self): return dtrain, dval +def _make_horizontal_per_site_config(num_clients, n_samples=50, n_features=5): + return { + f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=n_samples, n_features=n_features)} + for site_id in range(1, num_clients + 1) + } + + class TestXGBHorizontalRecipe: """Smoke tests for XGBHorizontalRecipe. @@ -78,13 +85,8 @@ class TestXGBHorizontalRecipe: def test_histogram_algorithm(self): """Test histogram algorithm completes successfully.""" with tempfile.TemporaryDirectory() as tmpdir: - env = SimEnv(num_clients=2, workspace_root=os.path.join(tmpdir, "test_histogram")) - - # Configure per-site data loaders - per_site_config = { - f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=50, n_features=5)} - for site_id in range(1, 3) - } + per_site_config = _make_horizontal_per_site_config(num_clients=2) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_histogram")) recipe = XGBHorizontalRecipe( name="test_histogram", @@ -108,8 +110,6 @@ def test_histogram_algorithm(self): def test_custom_xgb_params(self): """Test that custom XGBoost parameters are accepted.""" with tempfile.TemporaryDirectory() as tmpdir: - env = SimEnv(num_clients=2, workspace_root=os.path.join(tmpdir, "test_custom_params")) - custom_params = { "max_depth": 5, "eta": 0.05, @@ -119,11 +119,8 @@ def test_custom_xgb_params(self): "nthread": 4, } - # Configure per-site data loaders - per_site_config = { - f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=50, n_features=5)} - for site_id in range(1, 3) - } + per_site_config = _make_horizontal_per_site_config(num_clients=2) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_custom_params")) recipe = XGBHorizontalRecipe( name="test_custom_params", @@ -143,13 +140,8 @@ def test_multiple_clients(self): """Test recipe works with more than 2 clients.""" with tempfile.TemporaryDirectory() as tmpdir: num_clients = 5 - env = SimEnv(num_clients=num_clients, workspace_root=os.path.join(tmpdir, "test_multi_client")) - - # Configure per-site data loaders - per_site_config = { - f"site-{site_id}": {"data_loader": MockXGBDataLoader(n_samples=30, n_features=5)} - for site_id in range(1, num_clients + 1) - } + per_site_config = _make_horizontal_per_site_config(num_clients=num_clients, n_samples=30) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_multi_client")) recipe = XGBHorizontalRecipe( name="test_multi_client", diff --git a/tests/integration_test/xgb_vertical_recipe_test.py b/tests/integration_test/xgb_vertical_recipe_test.py index 83e0625193..0e20605ff6 100644 --- a/tests/integration_test/xgb_vertical_recipe_test.py +++ b/tests/integration_test/xgb_vertical_recipe_test.py @@ -76,6 +76,19 @@ def load_data(self): return dtrain, dval +def _make_vertical_per_site_config(num_clients=2, label_owner="site-1", n_samples=50, n_features=3): + return { + f"site-{site_id}": { + "data_loader": MockVerticalDataLoader( + has_labels=f"site-{site_id}" == label_owner, + n_samples=n_samples, + n_features=n_features, + ) + } + for site_id in range(1, num_clients + 1) + } + + class TestXGBVerticalRecipe: """Smoke tests for XGBVerticalRecipe. @@ -86,14 +99,8 @@ class TestXGBVerticalRecipe: def test_vertical_basic(self): """Test basic vertical XGBoost completes successfully.""" with tempfile.TemporaryDirectory() as tmpdir: - env = SimEnv(num_clients=2, workspace_root=os.path.join(tmpdir, "test_vertical")) - - # Configure per-site data loaders - # site-1 has labels, site-2 has features only - per_site_config = { - "site-1": {"data_loader": MockVerticalDataLoader(has_labels=True, n_samples=50, n_features=3)}, - "site-2": {"data_loader": MockVerticalDataLoader(has_labels=False, n_samples=50, n_features=3)}, - } + per_site_config = _make_vertical_per_site_config() + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_vertical")) recipe = XGBVerticalRecipe( name="test_vertical", @@ -123,6 +130,7 @@ def test_label_owner_validation(self): min_clients=2, num_rounds=1, label_owner="site-1", # Valid + per_site_config=_make_vertical_per_site_config(), ) assert recipe.label_owner == "site-1" @@ -133,6 +141,7 @@ def test_label_owner_validation(self): min_clients=2, num_rounds=1, label_owner="client1", # Invalid format + per_site_config=_make_vertical_per_site_config(), ) def test_custom_xgb_params(self): @@ -152,6 +161,7 @@ def test_custom_xgb_params(self): num_rounds=1, label_owner="site-1", xgb_params=custom_params, + per_site_config=_make_vertical_per_site_config(), ) # Verify params are stored @@ -161,15 +171,10 @@ def test_multiple_clients_vertical(self): """Test vertical recipe works with more than 2 clients.""" with tempfile.TemporaryDirectory() as tmpdir: num_clients = 3 - env = SimEnv(num_clients=num_clients, workspace_root=os.path.join(tmpdir, "test_multi_vertical")) - - # Configure per-site data loaders - only site-2 has labels - per_site_config = {} - for site_id in range(1, num_clients + 1): - has_labels = site_id == 2 # Only site-2 has labels - per_site_config[f"site-{site_id}"] = { - "data_loader": MockVerticalDataLoader(has_labels=has_labels, n_samples=30, n_features=2) - } + per_site_config = _make_vertical_per_site_config( + num_clients=num_clients, label_owner="site-2", n_samples=30, n_features=2 + ) + env = SimEnv(clients=list(per_site_config), workspace_root=os.path.join(tmpdir, "test_multi_vertical")) recipe = XGBVerticalRecipe( name="test_multi_vertical", @@ -191,6 +196,7 @@ def test_in_process_parameter(self): num_rounds=1, label_owner="site-1", in_process=True, # Default + per_site_config=_make_vertical_per_site_config(), ) assert recipe.in_process is True @@ -200,6 +206,7 @@ def test_in_process_parameter(self): num_rounds=1, label_owner="site-1", in_process=False, + per_site_config=_make_vertical_per_site_config(), ) assert recipe2.in_process is False @@ -212,5 +219,6 @@ def test_model_file_name_parameter(self): num_rounds=1, label_owner="site-1", model_file_name=custom_name, + per_site_config=_make_vertical_per_site_config(), ) assert recipe.model_file_name == custom_name From 4bc9e60fc928fd126c76d2c89d7d10b7cf492f3a Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 1 May 2026 12:01:55 -0700 Subject: [PATCH 2/8] Fix hello-pt PyTorch persistor config (cherry picked from commit 1f6fe68cc074c6329eace83cc39ee191396df65f) --- .../data/jobs/hello-pt/app/config/config_fed_server.json | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json b/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json index 49076bacbc..23506c4bf4 100644 --- a/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json +++ b/tests/integration_test/data/jobs/hello-pt/app/config/config_fed_server.json @@ -10,7 +10,12 @@ { "id": "persistor", "path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor", - "args": {} + "args": { + "model": { + "path": "simple_network.SimpleNetwork", + "args": {} + } + } }, { "id": "shareable_generator", From ac391bf4dc4488e35edaee77ea8610ccd7b83c4c Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 1 May 2026 12:46:23 -0700 Subject: [PATCH 3/8] Fix pt_init_client PyTorch persistor config (cherry picked from commit 16c81410fb9082cf49a90a0b53937b90cd3d861c) --- .../apps/pt_init_client/app/config/config_fed_server.json | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json b/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json index fe14726c8d..6777d27d4a 100644 --- a/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json +++ b/tests/integration_test/data/apps/pt_init_client/app/config/config_fed_server.json @@ -10,7 +10,12 @@ { "id": "persistor", "path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor", - "args": {} + "args": { + "model": { + "path": "simple_network.SimpleNetwork", + "args": {} + } + } }, { "id": "shareable_generator", From 56452fe5a75aa0196b7b0bc2035f52f160a1f3fe Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 1 May 2026 13:31:20 -0700 Subject: [PATCH 4/8] Address integration test review feedback (cherry picked from commit 2642d7db7e0544a0289b97504386f033fed9ee11) --- .../experiment_tracking_recipes_test.py | 35 ++++--------------- .../integration_test/preflight_check_test.py | 31 ++++++++++++---- 2 files changed, 30 insertions(+), 36 deletions(-) diff --git a/tests/integration_test/experiment_tracking_recipes_test.py b/tests/integration_test/experiment_tracking_recipes_test.py index 2d6aa283ac..f05511cad7 100644 --- a/tests/integration_test/experiment_tracking_recipes_test.py +++ b/tests/integration_test/experiment_tracking_recipes_test.py @@ -35,8 +35,8 @@ or run in a separate recipe test suite (takes ~1-2 minutes). """ +import importlib.util import os -import sys from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe from nvflare.recipe import SimEnv @@ -62,35 +62,12 @@ def client_script_path(self): def client_script_dir(self): return os.path.dirname(self.client_script_path) - def _is_client_script_module(self, module): - spec = getattr(module, "__spec__", None) - origin = getattr(spec, "origin", None) - if not origin or origin in {"built-in", "frozen"}: - return False - - client_script_dir = os.path.realpath(self.client_script_dir) - module_origin = os.path.realpath(origin) - return module_origin.startswith(client_script_dir + os.sep) - def _make_model(self): - original_sys_path = list(sys.path) - previous_modules = { - name: module - for name, module in list(sys.modules.items()) - if name == "model" or self._is_client_script_module(module) - } - sys.modules.pop("model", None) - sys.path.insert(0, self.client_script_dir) - try: - from model import SimpleNetwork - - return SimpleNetwork() - finally: - for name, module in list(sys.modules.items()): - if name == "model" or self._is_client_script_module(module): - sys.modules.pop(name, None) - sys.modules.update(previous_modules) - sys.path[:] = original_sys_path + model_path = os.path.join(self.client_script_dir, "model.py") + spec = importlib.util.spec_from_file_location("_exp_tracking_model", model_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.SimpleNetwork() def test_tensorboard_tracking_integration(self): """Test TensorBoard tracking can be added and job completes.""" diff --git a/tests/integration_test/preflight_check_test.py b/tests/integration_test/preflight_check_test.py index 905b99fc0d..5ce68217de 100644 --- a/tests/integration_test/preflight_check_test.py +++ b/tests/integration_test/preflight_check_test.py @@ -114,14 +114,25 @@ def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str raise AssertionError(error_msg) -def _run_preflight_check_command_in_subprocess(package_path: str): +def _raise_preflight_command_error(command: str, returncode: int, output: bytes): + output_text = output.decode("utf-8", errors="replace") + raise AssertionError( + f"Preflight command failed with return code {returncode}: {command}\n" + f"Output:\n{output_text}" + ) + + +def _run_preflight_check_command_in_subprocess(package_path: str, expect_success: bool = True): command = f"{sys.executable} -m {PREFLIGHT_CHECK_SCRIPT} -p {package_path}" print(f"Executing command {command} in subprocess") process = subprocess.run(shlex.split(command), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, check=False) + print(f"Preflight command return code: {process.returncode}") + if expect_success and process.returncode != 0: + _raise_preflight_command_error(command, process.returncode, process.stdout) return process.stdout -def _run_preflight_check_command_in_pseudo_terminal(package_path: str): +def _run_preflight_check_command_in_pseudo_terminal(package_path: str, expect_success: bool = True): command = f"{sys.executable} -m {PREFLIGHT_CHECK_SCRIPT} -p {package_path}" print(f"Executing command {command} in pty") @@ -132,16 +143,20 @@ def read(fd): output.write(data) return data - pty.spawn(shlex.split(command), read) + status = pty.spawn(shlex.split(command), read) + returncode = os.waitstatus_to_exitcode(status) + print(f"Preflight command return code: {returncode}") + if expect_success and returncode != 0: + _raise_preflight_command_error(command, returncode, output.getvalue()) return output.getvalue() -def _run_preflight_check_command(package_path: str, method: str = "subprocess"): +def _run_preflight_check_command(package_path: str, method: str = "subprocess", expect_success: bool = True): if method == "subprocess": - return _run_preflight_check_command_in_subprocess(package_path) + return _run_preflight_check_command_in_subprocess(package_path, expect_success=expect_success) else: - return _run_preflight_check_command_in_pseudo_terminal(package_path) + return _run_preflight_check_command_in_pseudo_terminal(package_path, expect_success=expect_success) @pytest.fixture( @@ -173,7 +188,9 @@ def test_run_check_on_server_after_overseer_start(self, setup_system): site_launcher.start_overseer() # preflight-check on server for server_name, server_props in site_launcher.server_properties.items(): - output = _run_preflight_check_command(package_path=server_props.root_dir) + output = _run_preflight_check_command( + package_path=server_props.root_dir, expect_success=is_dummy_overseer + ) actual_checks = _parse_preflight_output(output) # Get expected checks based on communication scheme From 13f7f9e992f39d0c911cfb68d61db9e90ac210ad Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 1 May 2026 13:45:58 -0700 Subject: [PATCH 5/8] Format preflight integration test (cherry picked from commit 0b31d98af299754e824eb23537032efe237c697f) --- tests/integration_test/preflight_check_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/integration_test/preflight_check_test.py b/tests/integration_test/preflight_check_test.py index 5ce68217de..4015c50352 100644 --- a/tests/integration_test/preflight_check_test.py +++ b/tests/integration_test/preflight_check_test.py @@ -117,8 +117,7 @@ def _verify_checks(actual_checks: dict[str, str], expected_checks: dict[str, str def _raise_preflight_command_error(command: str, returncode: int, output: bytes): output_text = output.decode("utf-8", errors="replace") raise AssertionError( - f"Preflight command failed with return code {returncode}: {command}\n" - f"Output:\n{output_text}" + f"Preflight command failed with return code {returncode}: {command}\n" f"Output:\n{output_text}" ) From fbe29f2fe57680c43dd8cde9db23ccb9e5cbacd9 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Tue, 5 May 2026 21:15:16 -0700 Subject: [PATCH 6/8] Clarify preflight check status handling (cherry picked from commit a0db392039592b79507c67874610396ac491a138) --- .../nvflare_cli/preflight_check.rst | 13 ++-- .../tool/package_checker/package_checker.py | 55 ++++++++------ nvflare/tool/preflight_check.py | 11 +-- tests/unit_test/tool/preflight_output_test.py | 73 ++++++++++++++++++- 4 files changed, 117 insertions(+), 35 deletions(-) diff --git a/docs/user_guide/nvflare_cli/preflight_check.rst b/docs/user_guide/nvflare_cli/preflight_check.rst index 0c70e82bf9..b8376dd3ef 100644 --- a/docs/user_guide/nvflare_cli/preflight_check.rst +++ b/docs/user_guide/nvflare_cli/preflight_check.rst @@ -13,8 +13,8 @@ General Usage .. code-block:: - nvflare preflight_check -p PACKAGE_PATH - nvflare preflight_check --package_path PACKAGE_PATH + nvflare preflight-check -p PACKAGE_PATH + nvflare preflight-check --package_path PACKAGE_PATH This preflight check script should be run on each site's machine. The ``PACKAGE_PATH`` is the path to the folder that contains @@ -23,6 +23,9 @@ the package to be checked. After running the script, for the checks that pass, users will see "PASSED". The problem and how to fix it is reported for checks that fail. +Exit code ``0`` means all applicable checks passed. Exit code ``1`` means at least one applicable check failed. +Exit code ``4`` means the package path or package format is invalid. + Below are the scripts to run the preflight check on each type of site and the possible problems that may be reported. @@ -34,7 +37,7 @@ on the server site, a user should run: .. code-block:: - nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/server1 + nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/server1 The problems that may be reported: @@ -59,7 +62,7 @@ So on the client site, a user will run: .. code-block:: - nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/site-1 + nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/site-1 The problems that may be reported: @@ -81,7 +84,7 @@ a user should run: .. code-block:: - nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/admin@nvidia.com + nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/admin@nvidia.com The problems that may be reported: diff --git a/nvflare/tool/package_checker/package_checker.py b/nvflare/tool/package_checker/package_checker.py index ed0298f605..e65b871cf1 100644 --- a/nvflare/tool/package_checker/package_checker.py +++ b/nvflare/tool/package_checker/package_checker.py @@ -16,12 +16,20 @@ import signal from abc import ABC, abstractmethod from collections import defaultdict +from enum import Enum, auto from subprocess import TimeoutExpired from nvflare.tool.package_checker.check_rule import CHECK_PASSED, CheckResult, CheckRule from nvflare.tool.package_checker.utils import run_command_in_subprocess, split_by_len +class CheckStatus(Enum): + PASS = auto() + PASS_WITH_CLEANUP = auto() + FAIL_WITH_CLEANUP = auto() + FAIL = auto() + + class PackageChecker(ABC): def __init__(self): self.report = defaultdict(list) @@ -67,15 +75,16 @@ def stop_dry_run(self, force: bool = True): print_human(f"killed dry run process output: {out}") print_human(f"killed dry run process err: {err}") - def check(self) -> int: + def check(self) -> CheckStatus: """Checks if the package is runnable on the current system. Returns: - 0: if no dry-run process started. - 1: if the dry-run process is started and return code is 0. - 2: if the dry-run process is started and return code is not 0. + CheckStatus.PASS: checks passed, no dry-run cleanup needed. + CheckStatus.PASS_WITH_CLEANUP: checks passed, dry-run process needs cleanup. + CheckStatus.FAIL_WITH_CLEANUP: checks failed, dry-run process needs cleanup. + CheckStatus.FAIL: checks failed, no dry-run cleanup needed. """ - ret_code = 0 + status = CheckStatus.PASS try: all_passed = True for rule in self.rules: @@ -96,23 +105,26 @@ def check(self) -> int: # check dry run if all_passed: - ret_code = self.check_dry_run() + status = self.check_dry_run() + else: + status = CheckStatus.FAIL except Exception as e: self.add_report( "Package Error", f"Exception happens in checking: {e}, this package is not in correct format.", "Please download a new package.", ) - finally: - return ret_code + status = CheckStatus.FAIL - def check_dry_run(self) -> int: + return status + + def check_dry_run(self) -> CheckStatus: """Runs dry run command. Returns: - 0: if no process started. - 1: if the process is started and return code is 0. - 2: if the process is started and return code is not 0. + CheckStatus.PASS_WITH_CLEANUP: dry run started successfully and needs cleanup. + CheckStatus.FAIL_WITH_CLEANUP: dry run started but failed and needs cleanup. + CheckStatus.FAIL: dry run could not be started. """ command = self.get_dry_run_command() dry_run_input = self.get_dry_run_inputs() @@ -130,12 +142,14 @@ def check_dry_run(self) -> int: CHECK_PASSED, "N/A", ) + return CheckStatus.PASS_WITH_CLEANUP else: self.add_report( "Check dry run", f"Can't start successfully: {out}", "Please check the error message of dry run.", ) + return CheckStatus.FAIL_WITH_CLEANUP except TimeoutExpired: os.killpg(process.pid, signal.SIGTERM) # Assumption, preflight check is focused on the connectivity, so we assume all sub-systems should @@ -149,15 +163,14 @@ def check_dry_run(self) -> int: CHECK_PASSED, "N/A", ) - - finally: - if process: - if process.returncode == 0: - return 1 - else: - return 2 - else: - return 0 + return CheckStatus.PASS_WITH_CLEANUP + except Exception as e: + self.add_report( + "Check dry run", + f"Can't start successfully: {e}", + "Please check the error message of dry run.", + ) + return CheckStatus.FAIL def add_report(self, check_name, problem_text: str, fix_text: str): self.report[self.package_path].append((check_name, problem_text, fix_text)) diff --git a/nvflare/tool/preflight_check.py b/nvflare/tool/preflight_check.py index a52defae92..2fef25331a 100644 --- a/nvflare/tool/preflight_check.py +++ b/nvflare/tool/preflight_check.py @@ -16,6 +16,7 @@ import os from nvflare.tool.package_checker import ClientPackageChecker, NVFlareConsolePackageChecker, ServerPackageChecker +from nvflare.tool.package_checker.package_checker import CheckStatus _preflight_parser = None @@ -69,13 +70,13 @@ def check_packages(args): for p in package_checkers: p.init(package_path=package_path) - ret_code = 0 + check_status = CheckStatus.PASS if p.should_be_checked(): - ret_code = p.check() + check_status = p.check() p.print_report() component_name = p.__class__.__name__.replace("PackageChecker", "").lower() - status = "pass" if ret_code == 0 else "fail" + status = "fail" if check_status in [CheckStatus.FAIL, CheckStatus.FAIL_WITH_CLEANUP] else "pass" if status == "fail": overall_pass = False check_result = {"component": component_name, "status": status} @@ -84,9 +85,9 @@ def check_packages(args): check_result["details"] = details checks.append(check_result) - if ret_code == 1: + if check_status == CheckStatus.PASS_WITH_CLEANUP: p.stop_dry_run(force=False) - elif ret_code == 2: + elif check_status == CheckStatus.FAIL_WITH_CLEANUP: p.stop_dry_run(force=True) overall = "pass" if overall_pass else "fail" diff --git a/tests/unit_test/tool/preflight_output_test.py b/tests/unit_test/tool/preflight_output_test.py index aa54c1c33e..e49038b2a4 100644 --- a/tests/unit_test/tool/preflight_output_test.py +++ b/tests/unit_test/tool/preflight_output_test.py @@ -18,6 +18,7 @@ import pytest from nvflare.tool import cli_output +from nvflare.tool.package_checker.package_checker import CheckStatus class TestPreflightOutput: @@ -66,7 +67,7 @@ def test_all_pass_json_envelope(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 0 + mock_checker.check.return_value = CheckStatus.PASS mock_checker.__class__.__name__ = "ServerPackageChecker" mock_checker.report = {str(pkg_path): []} # Make print_report() write a sentinel to stderr so we can assert routing @@ -110,7 +111,7 @@ def test_fail_exits_1(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 1 # fail + mock_checker.check.return_value = CheckStatus.FAIL_WITH_CLEANUP mock_checker.__class__.__name__ = "ServerPackageChecker" args = MagicMock() @@ -132,6 +133,70 @@ def test_fail_exits_1(self, capsys, tmp_path): assert data["exit_code"] == 1 assert data["data"]["overall"] == "fail" + def test_failed_check_without_dry_run_exits_1_without_cleanup(self, capsys, tmp_path): + """A required-rule failure exits 1 but does not try dry-run cleanup.""" + from nvflare.tool.preflight_check import check_packages + + pkg_path = tmp_path / "package2b" + pkg_path.mkdir() + (pkg_path / "startup").mkdir() + + mock_checker = MagicMock() + mock_checker.should_be_checked.return_value = True + mock_checker.check.return_value = CheckStatus.FAIL + mock_checker.__class__.__name__ = "ServerPackageChecker" + + args = MagicMock() + args.package_path = str(pkg_path) + args.output = "json" + + with ( + patch("nvflare.tool.preflight_check.ServerPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.ClientPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.NVFlareConsolePackageChecker", return_value=mock_checker), + ): + with pytest.raises(SystemExit) as exc_info: + check_packages(args) + + assert exc_info.value.code == 1 + captured = capsys.readouterr() + data = json.loads(captured.out) + assert data["exit_code"] == 1 + assert data["data"]["overall"] == "fail" + assert all(check["status"] == "fail" for check in data["data"]["checks"]) + assert mock_checker.stop_dry_run.call_count == 0 + + def test_successful_dry_run_cleanup_code_is_pass(self, capsys, tmp_path): + """Dry-run cleanup status means pass, not check failure.""" + from nvflare.tool.preflight_check import check_packages + + pkg_path = tmp_path / "package2" + pkg_path.mkdir() + (pkg_path / "startup").mkdir() + + mock_checker = MagicMock() + mock_checker.should_be_checked.return_value = True + mock_checker.check.return_value = CheckStatus.PASS_WITH_CLEANUP + mock_checker.__class__.__name__ = "ServerPackageChecker" + + args = MagicMock() + args.package_path = str(pkg_path) + args.output = "json" + + with ( + patch("nvflare.tool.preflight_check.ServerPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.ClientPackageChecker", return_value=mock_checker), + patch("nvflare.tool.preflight_check.NVFlareConsolePackageChecker", return_value=mock_checker), + ): + check_packages(args) + + captured = capsys.readouterr() + data = json.loads(captured.out) + assert data["exit_code"] == 0 + assert data["data"]["overall"] == "pass" + assert all(check["status"] == "pass" for check in data["data"]["checks"]) + assert mock_checker.stop_dry_run.call_count == 3 + def test_per_component_checks(self, capsys, tmp_path): """Each checker appears as a separate entry; stdout has only one JSON line.""" from nvflare.tool.preflight_check import check_packages @@ -142,7 +207,7 @@ def test_per_component_checks(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 0 + mock_checker.check.return_value = CheckStatus.PASS mock_checker.report = {str(pkg_path): []} args = MagicMock() @@ -195,7 +260,7 @@ def test_failed_checks_omit_empty_details(self, capsys, tmp_path): mock_checker = MagicMock() mock_checker.should_be_checked.return_value = True - mock_checker.check.return_value = 1 + mock_checker.check.return_value = CheckStatus.FAIL_WITH_CLEANUP mock_checker.__class__.__name__ = "ServerPackageChecker" args = MagicMock() From 7f72aa8c68e2ac9ca18a5f646d1bec2d89570954 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Tue, 5 May 2026 21:25:53 -0700 Subject: [PATCH 7/8] Clarify preflight docs and PyTorch update schema (cherry picked from commit b61b739f79aadbf36b98e2e64f4adb7321195106) --- .../data_scientist_guide/job_recipe.rst | 3 +++ docs/user_guide/nvflare_cli/preflight_check.rst | 16 ++++++++-------- .../pt/model_persistence_format_manager.py | 7 ++++--- nvflare/app_opt/pt/utils.py | 5 ++++- .../app_opt/pt/pt_param_validation_test.py | 7 +++++-- 5 files changed, 24 insertions(+), 14 deletions(-) diff --git a/docs/user_guide/data_scientist_guide/job_recipe.rst b/docs/user_guide/data_scientist_guide/job_recipe.rst index 7a2315739f..ea7ba7e121 100644 --- a/docs/user_guide/data_scientist_guide/job_recipe.rst +++ b/docs/user_guide/data_scientist_guide/job_recipe.rst @@ -105,6 +105,9 @@ Use ``initial_ckpt`` to specify a path to pre-trained model weights: the recipe. It only needs to exist on the **server** when the model is actually loaded during job execution. * **PyTorch requires model architecture**: For PyTorch, you must provide ``model`` (class instance or dict config) along with ``initial_ckpt``, because PyTorch checkpoints contain only weights, not architecture. + * **PyTorch update schema**: The server-side PyTorch model or checkpoint defines the accepted + ``state_dict()`` key schema for client updates. A client may return only the subset of keys it trained, + but every returned key must already exist in the server schema. New client-only keys are rejected. * **TensorFlow/Keras can use checkpoint alone**: Keras ``.h5`` or SavedModel formats contain both architecture and weights, so ``initial_ckpt`` can be used without ``model``. If ``model`` is provided, use a subclassed Keras class instance (or dict config). diff --git a/docs/user_guide/nvflare_cli/preflight_check.rst b/docs/user_guide/nvflare_cli/preflight_check.rst index b8376dd3ef..db97931568 100644 --- a/docs/user_guide/nvflare_cli/preflight_check.rst +++ b/docs/user_guide/nvflare_cli/preflight_check.rst @@ -45,10 +45,10 @@ The problems that may be reported: :header: Checks,Problems,How to fix :widths: 15, 20, 25 - Check grpc port binding,Can't bind to address ({grpc_target_address}) for grpc service: {e},Please check the DNS and port. - Check admin port binding,Can't bind to address ({admin_host}:{admin_port}) for admin service: {e},Please check the DNS and port. - Check snapshot storage writable,Can't write to {self.snapshot_storage_root}: {e}.,Please check the user permission. - Check job storage writable, Can't write to {self.job_storage_root}: {e}.,Please check the user permission. + Check FL port binding,Can't bind to address ({host}:{port}): {e},Please check the DNS and port. + Check admin port binding,Can't bind to address ({host}:{port}): {e},Please check the DNS and port. + Check snapshot storage writable,Can't write to {snapshot_storage_root}: {e}.,Please check the user permission. + Check job storage writable,Can't write to {job_storage_root}: {e}.,Please check the user permission. Check dry run,Can't start successfully: {error},Please check the error message of dry run. @@ -70,8 +70,8 @@ The problems that may be reported: :header: Checks,Problems,How to fix :widths: 15, 20, 25 - Check GRPC server available,Can't connect to grpc ({server_name}:{grpc_port}) server,Please check if server is up. - Check dry run, Can't start successfully: {error}, Please check the error message of dry run. + Check server available,Can't connect to {scheme} server ({host}:{port}),Please check if server is up. + Check dry run,Can't start successfully: {error},Please check the error message of dry run. Preflight check for admin consoles @@ -92,5 +92,5 @@ The problems that may be reported: :header: Checks,Problems,How to fix :widths: 15, 20, 25 - Check GRPC server available,Can't connect to grpc ({server_name}:{grpc_port}) server,Please check if server is up. - Check dry run, Can't start successfully: {error}, Please check the error message of dry run. + Check server available,Can't connect to {scheme} server ({host}:{port}),Please check if server is up. + Check dry run,Can't start successfully: {error},Please check the error message of dry run. diff --git a/nvflare/app_opt/pt/model_persistence_format_manager.py b/nvflare/app_opt/pt/model_persistence_format_manager.py index 1a1e21eb60..3429fa22ee 100644 --- a/nvflare/app_opt/pt/model_persistence_format_manager.py +++ b/nvflare/app_opt/pt/model_persistence_format_manager.py @@ -140,9 +140,10 @@ def update(self, ml: ModelLearnable): introduce keys that do not already exist in the checkpoint. Notes: - Partial updates are supported: learned weights only need to cover the - subset of checkpoint keys that the client actually trained. The - original persisted weights for untouched keys are preserved. + The persisted checkpoint is the server schema for client updates. + Partial updates are supported: learned weights only need to cover a + subset of checkpoint keys that the client actually trained. New + client keys outside the server schema are rejected. """ err = validate_model_learnable(ml) if err: diff --git a/nvflare/app_opt/pt/utils.py b/nvflare/app_opt/pt/utils.py index 7438c7a748..fe47323e2e 100644 --- a/nvflare/app_opt/pt/utils.py +++ b/nvflare/app_opt/pt/utils.py @@ -219,7 +219,10 @@ def feed_vars(model: nn.Module, model_params): Notes: Empty payloads are treated as a no-op. Partial payloads are accepted as long as at least one key matches; unknown keys are ignored with a warning - instead of being applied to the local state dict. + instead of being applied to the local state dict. This is for loading a + received model into a local PyTorch module. Server-side validation of + learned client updates is handled by ``PTModelPersistenceFormatManager`` + and rejects keys outside the server checkpoint schema. """ _logger = get_module_logger(__name__, "AssignVariables") _logger.debug("AssignVariables...") diff --git a/tests/unit_test/app_opt/pt/pt_param_validation_test.py b/tests/unit_test/app_opt/pt/pt_param_validation_test.py index 076c089f4b..7074e73188 100644 --- a/tests/unit_test/app_opt/pt/pt_param_validation_test.py +++ b/tests/unit_test/app_opt/pt/pt_param_validation_test.py @@ -83,7 +83,8 @@ def test_feed_vars_raises_on_shape_mismatch(): feed_vars(model, params) -def test_feed_vars_warns_on_unexpected_keys_when_some_match(caplog): +def test_feed_vars_filters_global_keys_not_in_local_model(caplog): + """Local model loading can ignore extra global keys in non-strict mode.""" model = SimpleNet() params = _clone_state_dict(model) params["model.fc.weight"] = torch.ones_like(model.state_dict()["fc.weight"]) @@ -97,6 +98,7 @@ def test_feed_vars_warns_on_unexpected_keys_when_some_match(caplog): def test_persistence_manager_accepts_partial_known_updates(): + """Client updates may contain any subset of the server checkpoint schema.""" model = SimpleNet() manager = PTModelPersistenceFormatManager(_clone_state_dict(model)) new_weight = torch.full_like(model.state_dict()["fc.weight"], 5.0) @@ -107,7 +109,8 @@ def test_persistence_manager_accepts_partial_known_updates(): assert torch.equal(manager.var_dict["fc.bias"], model.state_dict()["fc.bias"]) -def test_persistence_manager_rejects_unexpected_keys(): +def test_persistence_manager_rejects_client_keys_outside_server_schema(): + """Client updates may not introduce keys outside the server checkpoint schema.""" model = SimpleNet() manager = PTModelPersistenceFormatManager(_clone_state_dict(model)) weights = { From 6aec921540f0ebffe085e21882376c538c3fdba2 Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Thu, 7 May 2026 14:03:53 -0700 Subject: [PATCH 8/8] Align preflight backport with source PR --- .../integration_test/preflight_check_test.py | 56 +++---------------- 1 file changed, 7 insertions(+), 49 deletions(-) diff --git a/tests/integration_test/preflight_check_test.py b/tests/integration_test/preflight_check_test.py index 4015c50352..3510e165fd 100644 --- a/tests/integration_test/preflight_check_test.py +++ b/tests/integration_test/preflight_check_test.py @@ -27,7 +27,7 @@ from tests.integration_test.src.constants import PREFLIGHT_CHECK_SCRIPT TEST_CASES = [ - {"project_yaml": "data/projects/dummy.yml", "admin_name": "super@test.org", "is_dummy_overseer": True}, + {"project_yaml": "data/projects/dummy.yml", "admin_name": "super@test.org"}, ] @@ -164,7 +164,6 @@ def _run_preflight_check_command(package_path: str, method: str = "subprocess", def setup_system(request): test_config = request.param project_yaml_path = test_config["project_yaml"] - is_dummy_overseer = test_config["is_dummy_overseer"] admin_name = test_config["admin_name"] if not os.path.isfile(project_yaml_path): @@ -175,24 +174,18 @@ def setup_system(request): admin_folder_root = os.path.abspath(os.path.join(workspace_root, admin_name)) - return site_launcher, is_dummy_overseer, admin_folder_root + return site_launcher, admin_folder_root @pytest.mark.xdist_group(name="preflight_tests_group") class TestPreflightCheck: - def test_run_check_on_server_after_overseer_start(self, setup_system): - site_launcher, is_dummy_overseer, _ = setup_system + def test_run_check_on_server(self, setup_system): + site_launcher, _ = setup_system try: - if not is_dummy_overseer: - site_launcher.start_overseer() - # preflight-check on server for server_name, server_props in site_launcher.server_properties.items(): - output = _run_preflight_check_command( - package_path=server_props.root_dir, expect_success=is_dummy_overseer - ) + output = _run_preflight_check_command(package_path=server_props.root_dir) actual_checks = _parse_preflight_output(output) - # Get expected checks based on communication scheme expected_checks = { "Check FL port binding": "PASSED", "Check admin port binding": "PASSED", @@ -207,55 +200,22 @@ def test_run_check_on_server_after_overseer_start(self, setup_system): site_launcher.stop_all_sites() site_launcher.cleanup() - def test_run_check_on_server_before_overseer_start(self, setup_system): - site_launcher, is_dummy_overseer, _ = setup_system - try: - # preflight-check on server - for server_name, server_props in site_launcher.server_properties.items(): - output = _run_preflight_check_command(package_path=server_props.root_dir) - actual_checks = _parse_preflight_output(output) - - # Get expected checks based on communication scheme - expected_checks = { - "Check FL port binding": "PASSED", - "Check admin port binding": "PASSED", - "Check snapshot storage writable": "PASSED", - "Check job storage writable": "PASSED", - } - - if is_dummy_overseer: - print(f"Server '{server_name}', expecting checks: {list(expected_checks.keys())}") - _verify_checks(actual_checks, expected_checks, f"Server '{server_name}'") - else: - assert any(status != "PASSED" for status in actual_checks.values()), ( - f"Server '{server_name}' preflight check expected some failures before overseer start, " - f"Actual checks: {actual_checks}" - ) - finally: - site_launcher.stop_all_sites() - site_launcher.cleanup() - def test_run_check_on_client(self, setup_system): - site_launcher, is_dummy_overseer, _ = setup_system + site_launcher, _ = setup_system try: - if not is_dummy_overseer: - site_launcher.start_overseer() site_launcher.start_servers() time.sleep(SERVER_START_TIME) - # preflight-check on clients for client_name, client_props in site_launcher.client_properties.items(): output = _run_preflight_check_command(package_path=client_props.root_dir) actual_checks = _parse_preflight_output(output) - # Get expected checks based on communication scheme expected_checks = { "Check server available": "PASSED", } print(f"Client '{client_name}', expecting checks: {list(expected_checks.keys())}") - # Verify checks match expectations based on communication scheme _verify_checks(actual_checks, expected_checks, f"Client '{client_name}'") except Exception: raise @@ -264,10 +224,8 @@ def test_run_check_on_client(self, setup_system): site_launcher.cleanup() def test_run_check_on_admin_console(self, setup_system): - site_launcher, is_dummy_overseer, admin_folder_root = setup_system + site_launcher, admin_folder_root = setup_system try: - if not is_dummy_overseer: - site_launcher.start_overseer() site_launcher.start_servers() time.sleep(SERVER_START_TIME)