Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/user_guide/data_scientist_guide/job_recipe.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ Use ``initial_ckpt`` to specify a path to pre-trained model weights:
the recipe. It only needs to exist on the **server** when the model is actually loaded during job execution.
* **PyTorch requires model architecture**: For PyTorch, you must provide ``model`` (class instance or
dict config) along with ``initial_ckpt``, because PyTorch checkpoints contain only weights, not architecture.
* **PyTorch update schema**: The server-side PyTorch model or checkpoint defines the accepted
``state_dict()`` key schema for client updates. A client may return only the subset of keys it trained,
but every returned key must already exist in the server schema. New client-only keys are rejected.
* **TensorFlow/Keras can use checkpoint alone**: Keras ``.h5`` or SavedModel formats contain both architecture
and weights, so ``initial_ckpt`` can be used without ``model``. If ``model`` is provided, use a subclassed
Keras class instance (or dict config).
Expand Down
29 changes: 16 additions & 13 deletions docs/user_guide/nvflare_cli/preflight_check.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ General Usage

.. code-block::

nvflare preflight_check -p PACKAGE_PATH
nvflare preflight_check --package_path PACKAGE_PATH
nvflare preflight-check -p PACKAGE_PATH
nvflare preflight-check --package_path PACKAGE_PATH


This preflight check script should be run on each site's machine. The ``PACKAGE_PATH`` is the path to the folder that contains
Expand All @@ -23,6 +23,9 @@ the package to be checked.
After running the script, for the checks that pass, users will see "PASSED". The problem and how
to fix it is reported for checks that fail.

Exit code ``0`` means all applicable checks passed. Exit code ``1`` means at least one applicable check failed.
Exit code ``4`` means the package path or package format is invalid.

Below are the scripts to run the preflight check on each type of site and the possible problems that may be reported.


Expand All @@ -34,18 +37,18 @@ on the server site, a user should run:

.. code-block::

nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/server1
nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/server1

The problems that may be reported:

.. csv-table::
:header: Checks,Problems,How to fix
:widths: 15, 20, 25

Check grpc port binding,Can't bind to address ({grpc_target_address}) for grpc service: {e},Please check the DNS and port.
Check admin port binding,Can't bind to address ({admin_host}:{admin_port}) for admin service: {e},Please check the DNS and port.
Check snapshot storage writable,Can't write to {self.snapshot_storage_root}: {e}.,Please check the user permission.
Check job storage writable, Can't write to {self.job_storage_root}: {e}.,Please check the user permission.
Check FL port binding,Can't bind to address ({host}:{port}): {e},Please check the DNS and port.
Check admin port binding,Can't bind to address ({host}:{port}): {e},Please check the DNS and port.
Check snapshot storage writable,Can't write to {snapshot_storage_root}: {e}.,Please check the user permission.
Check job storage writable,Can't write to {job_storage_root}: {e}.,Please check the user permission.
Check dry run,Can't start successfully: {error},Please check the error message of dry run.


Expand All @@ -59,16 +62,16 @@ So on the client site, a user will run:

.. code-block::

nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/site-1
nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/site-1

The problems that may be reported:

.. csv-table::
:header: Checks,Problems,How to fix
:widths: 15, 20, 25

Check GRPC server available,Can't connect to grpc ({server_name}:{grpc_port}) server,Please check if server is up.
Check dry run, Can't start successfully: {error}, Please check the error message of dry run.
Check server available,Can't connect to {scheme} server ({host}:{port}),Please check if server is up.
Check dry run,Can't start successfully: {error},Please check the error message of dry run.


Preflight check for admin consoles
Expand All @@ -81,13 +84,13 @@ a user should run:

.. code-block::

nvflare preflight_check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/admin@nvidia.com
nvflare preflight-check -p /path_to_NVFlare/NVFlare/workspace/example_project/prod_00/admin@nvidia.com

The problems that may be reported:

.. csv-table::
:header: Checks,Problems,How to fix
:widths: 15, 20, 25

Check GRPC server available,Can't connect to grpc ({server_name}:{grpc_port}) server,Please check if server is up.
Check dry run, Can't start successfully: {error}, Please check the error message of dry run.
Check server available,Can't connect to {scheme} server ({host}:{port}),Please check if server is up.
Check dry run,Can't start successfully: {error},Please check the error message of dry run.
7 changes: 4 additions & 3 deletions nvflare/app_opt/pt/model_persistence_format_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,10 @@ def update(self, ml: ModelLearnable):
introduce keys that do not already exist in the checkpoint.

Notes:
Partial updates are supported: learned weights only need to cover the
subset of checkpoint keys that the client actually trained. The
original persisted weights for untouched keys are preserved.
The persisted checkpoint is the server schema for client updates.
Partial updates are supported: learned weights only need to cover a
subset of checkpoint keys that the client actually trained. New
client keys outside the server schema are rejected.
"""
err = validate_model_learnable(ml)
if err:
Expand Down
5 changes: 4 additions & 1 deletion nvflare/app_opt/pt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ def feed_vars(model: nn.Module, model_params):
Notes:
Empty payloads are treated as a no-op. Partial payloads are accepted as
long as at least one key matches; unknown keys are ignored with a warning
instead of being applied to the local state dict.
instead of being applied to the local state dict. This is for loading a
received model into a local PyTorch module. Server-side validation of
learned client updates is handled by ``PTModelPersistenceFormatManager``
and rejects keys outside the server checkpoint schema.
"""
_logger = get_module_logger(__name__, "AssignVariables")
_logger.debug("AssignVariables...")
Expand Down
55 changes: 34 additions & 21 deletions nvflare/tool/package_checker/package_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@
import signal
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum, auto
from subprocess import TimeoutExpired

from nvflare.tool.package_checker.check_rule import CHECK_PASSED, CheckResult, CheckRule
from nvflare.tool.package_checker.utils import run_command_in_subprocess, split_by_len


class CheckStatus(Enum):
PASS = auto()
PASS_WITH_CLEANUP = auto()
FAIL_WITH_CLEANUP = auto()
FAIL = auto()


class PackageChecker(ABC):
def __init__(self):
self.report = defaultdict(list)
Expand Down Expand Up @@ -67,15 +75,16 @@ def stop_dry_run(self, force: bool = True):
print_human(f"killed dry run process output: {out}")
print_human(f"killed dry run process err: {err}")

def check(self) -> int:
def check(self) -> CheckStatus:
"""Checks if the package is runnable on the current system.

Returns:
0: if no dry-run process started.
1: if the dry-run process is started and return code is 0.
2: if the dry-run process is started and return code is not 0.
CheckStatus.PASS: checks passed, no dry-run cleanup needed.
CheckStatus.PASS_WITH_CLEANUP: checks passed, dry-run process needs cleanup.
CheckStatus.FAIL_WITH_CLEANUP: checks failed, dry-run process needs cleanup.
CheckStatus.FAIL: checks failed, no dry-run cleanup needed.
"""
ret_code = 0
status = CheckStatus.PASS
try:
all_passed = True
for rule in self.rules:
Expand All @@ -96,23 +105,26 @@ def check(self) -> int:

# check dry run
if all_passed:
ret_code = self.check_dry_run()
status = self.check_dry_run()
else:
status = CheckStatus.FAIL
except Exception as e:
self.add_report(
"Package Error",
f"Exception happens in checking: {e}, this package is not in correct format.",
"Please download a new package.",
)
finally:
return ret_code
status = CheckStatus.FAIL

def check_dry_run(self) -> int:
return status

def check_dry_run(self) -> CheckStatus:
"""Runs dry run command.

Returns:
0: if no process started.
1: if the process is started and return code is 0.
2: if the process is started and return code is not 0.
CheckStatus.PASS_WITH_CLEANUP: dry run started successfully and needs cleanup.
CheckStatus.FAIL_WITH_CLEANUP: dry run started but failed and needs cleanup.
CheckStatus.FAIL: dry run could not be started.
"""
command = self.get_dry_run_command()
dry_run_input = self.get_dry_run_inputs()
Expand All @@ -130,12 +142,14 @@ def check_dry_run(self) -> int:
CHECK_PASSED,
"N/A",
)
return CheckStatus.PASS_WITH_CLEANUP
else:
self.add_report(
"Check dry run",
f"Can't start successfully: {out}",
"Please check the error message of dry run.",
)
return CheckStatus.FAIL_WITH_CLEANUP
except TimeoutExpired:
os.killpg(process.pid, signal.SIGTERM)
# Assumption, preflight check is focused on the connectivity, so we assume all sub-systems should
Expand All @@ -149,15 +163,14 @@ def check_dry_run(self) -> int:
CHECK_PASSED,
"N/A",
)

finally:
if process:
if process.returncode == 0:
return 1
else:
return 2
else:
return 0
return CheckStatus.PASS_WITH_CLEANUP
except Exception as e:
self.add_report(
"Check dry run",
f"Can't start successfully: {e}",
"Please check the error message of dry run.",
)
return CheckStatus.FAIL
Comment thread
YuanTingHsieh marked this conversation as resolved.
Comment thread
YuanTingHsieh marked this conversation as resolved.

def add_report(self, check_name, problem_text: str, fix_text: str):
self.report[self.package_path].append((check_name, problem_text, fix_text))
Expand Down
11 changes: 6 additions & 5 deletions nvflare/tool/preflight_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import os

from nvflare.tool.package_checker import ClientPackageChecker, NVFlareConsolePackageChecker, ServerPackageChecker
from nvflare.tool.package_checker.package_checker import CheckStatus

_preflight_parser = None

Expand Down Expand Up @@ -69,13 +70,13 @@ def check_packages(args):

for p in package_checkers:
p.init(package_path=package_path)
ret_code = 0
check_status = CheckStatus.PASS
if p.should_be_checked():
ret_code = p.check()
check_status = p.check()
p.print_report()

component_name = p.__class__.__name__.replace("PackageChecker", "").lower()
status = "pass" if ret_code == 0 else "fail"
status = "fail" if check_status in [CheckStatus.FAIL, CheckStatus.FAIL_WITH_CLEANUP] else "pass"
if status == "fail":
overall_pass = False
check_result = {"component": component_name, "status": status}
Expand All @@ -84,9 +85,9 @@ def check_packages(args):
check_result["details"] = details
checks.append(check_result)

if ret_code == 1:
if check_status == CheckStatus.PASS_WITH_CLEANUP:
p.stop_dry_run(force=False)
elif ret_code == 2:
elif check_status == CheckStatus.FAIL_WITH_CLEANUP:
p.stop_dry_run(force=True)

overall = "pass" if overall_pass else "fail"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
{
"id": "persistor",
"path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor",
"args": {}
"args": {
"model": {
"path": "simple_network.SimpleNetwork",
"args": {}
}
}
},
{
"id": "shareable_generator",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
{
"id": "persistor",
"path": "nvflare.app_opt.pt.file_model_persistor.PTFileModelPersistor",
"args": {}
"args": {
"model": {
"path": "simple_network.SimpleNetwork",
"args": {}
}
}
},
{
"id": "shareable_generator",
Expand Down

This file was deleted.

Loading
Loading