From b1552196440dd5bcb14362a2f793b814438a8e94 Mon Sep 17 00:00:00 2001 From: He Date: Tue, 1 Aug 2023 14:54:14 -0500 Subject: [PATCH 1/2] use loss function --- src/appfl/funcx/client_utils.py | 6 ++++++ src/appfl/funcx/funcx_client.py | 13 +++++++++++-- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/appfl/funcx/client_utils.py b/src/appfl/funcx/client_utils.py index 6be0dd3..777b1d7 100644 --- a/src/appfl/funcx/client_utils.py +++ b/src/appfl/funcx/client_utils.py @@ -21,6 +21,12 @@ def get_model(cfg): model = ModelClass(*cfg.model_args, **cfg.model_kwargs) return model +def get_loss(cfg): + print("===== get loss class =====") + LossClass = get_executable_func(cfg.get_loss)() + loss_fn = LossClass() + return loss_fn + def load_global_state(cfg, global_state, temp_dir): if CloudStorage.is_cloud_storage_object(global_state): CloudStorage.init(cfg, temp_dir) diff --git a/src/appfl/funcx/funcx_client.py b/src/appfl/funcx/funcx_client.py index 0c21405..77d86c1 100644 --- a/src/appfl/funcx/funcx_client.py +++ b/src/appfl/funcx/funcx_client.py @@ -43,7 +43,7 @@ def client_training( from appfl.algorithm.client_optimizer import ClientOptim from appfl.algorithm.funcx_client_optimizer import FuncxClientOptim from appfl.misc import client_log, get_dataloader - from appfl.funcx.client_utils import get_dataset, load_global_state, send_client_state, get_model + from appfl.funcx.client_utils import get_dataset, load_global_state, send_client_state, get_model, get_loss ## Load client configs cfg.device = cfg.clients[client_idx].device @@ -65,6 +65,10 @@ def client_training( ## Prepare output directory output_filename = cfg.output_filename + "_client_%s" % (client_idx) outfile = client_log(cfg.output_dirname, output_filename) + + # Get custom loss function + if loss_fn == None: + loss_fn = get_loss(cfg) ## Instantiate training agent at client client= eval(cfg.fed.clientname)( @@ -123,7 +127,7 @@ def client_testing( import os.path as osp from appfl.misc import client_log, get_dataloader from appfl.algorithm.client_optimizer import ClientOptim - from appfl.funcx.client_utils import get_dataset, load_global_state, get_model + from appfl.funcx.client_utils import get_dataset, load_global_state, get_model, get_loss ## Load client configs cfg.device = cfg.clients[client_idx].device cfg.output_dirname = cfg.clients[client_idx].output_dir @@ -138,6 +142,11 @@ def client_testing( ## Prepare output directory output_filename = cfg.output_filename + "_client_test_%s" % (client_idx) outfile = client_log(cfg.output_dirname, output_filename) + + # Get custom loss function + if loss_fn == None: + loss_fn = get_loss(cfg) + ## Instantiate training agent at client client= eval(cfg.fed.clientname)( client_idx, From 567a05d4e4a239be62fc98f370b8072165d076fc Mon Sep 17 00:00:00 2001 From: He Date: Tue, 1 Aug 2023 15:29:08 -0500 Subject: [PATCH 2/2] clean up --- src/appfl/funcx/client_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/appfl/funcx/client_utils.py b/src/appfl/funcx/client_utils.py index 777b1d7..41c3d13 100644 --- a/src/appfl/funcx/client_utils.py +++ b/src/appfl/funcx/client_utils.py @@ -22,7 +22,6 @@ def get_model(cfg): return model def get_loss(cfg): - print("===== get loss class =====") LossClass = get_executable_func(cfg.get_loss)() loss_fn = LossClass() return loss_fn