diff --git a/src/appfl/funcx/client_utils.py b/src/appfl/funcx/client_utils.py index 6be0dd3..41c3d13 100644 --- a/src/appfl/funcx/client_utils.py +++ b/src/appfl/funcx/client_utils.py @@ -21,6 +21,11 @@ def get_model(cfg): model = ModelClass(*cfg.model_args, **cfg.model_kwargs) return model +def get_loss(cfg): + LossClass = get_executable_func(cfg.get_loss)() + loss_fn = LossClass() + return loss_fn + def load_global_state(cfg, global_state, temp_dir): if CloudStorage.is_cloud_storage_object(global_state): CloudStorage.init(cfg, temp_dir) diff --git a/src/appfl/funcx/funcx_client.py b/src/appfl/funcx/funcx_client.py index 0c21405..77d86c1 100644 --- a/src/appfl/funcx/funcx_client.py +++ b/src/appfl/funcx/funcx_client.py @@ -43,7 +43,7 @@ def client_training( from appfl.algorithm.client_optimizer import ClientOptim from appfl.algorithm.funcx_client_optimizer import FuncxClientOptim from appfl.misc import client_log, get_dataloader - from appfl.funcx.client_utils import get_dataset, load_global_state, send_client_state, get_model + from appfl.funcx.client_utils import get_dataset, load_global_state, send_client_state, get_model, get_loss ## Load client configs cfg.device = cfg.clients[client_idx].device @@ -65,6 +65,10 @@ def client_training( ## Prepare output directory output_filename = cfg.output_filename + "_client_%s" % (client_idx) outfile = client_log(cfg.output_dirname, output_filename) + + # Get custom loss function + if loss_fn == None: + loss_fn = get_loss(cfg) ## Instantiate training agent at client client= eval(cfg.fed.clientname)( @@ -123,7 +127,7 @@ def client_testing( import os.path as osp from appfl.misc import client_log, get_dataloader from appfl.algorithm.client_optimizer import ClientOptim - from appfl.funcx.client_utils import get_dataset, load_global_state, get_model + from appfl.funcx.client_utils import get_dataset, load_global_state, get_model, get_loss ## Load client configs cfg.device = cfg.clients[client_idx].device cfg.output_dirname = cfg.clients[client_idx].output_dir @@ -138,6 +142,11 @@ def client_testing( ## Prepare output directory output_filename = cfg.output_filename + "_client_test_%s" % (client_idx) outfile = client_log(cfg.output_dirname, output_filename) + + # Get custom loss function + if loss_fn == None: + loss_fn = get_loss(cfg) + ## Instantiate training agent at client client= eval(cfg.fed.clientname)( client_idx,