Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/appfl/funcx/client_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ def get_model(cfg):
model = ModelClass(*cfg.model_args, **cfg.model_kwargs)
return model

def get_loss(cfg):
LossClass = get_executable_func(cfg.get_loss)()
loss_fn = LossClass()
return loss_fn

def load_global_state(cfg, global_state, temp_dir):
if CloudStorage.is_cloud_storage_object(global_state):
CloudStorage.init(cfg, temp_dir)
Expand Down
13 changes: 11 additions & 2 deletions src/appfl/funcx/funcx_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def client_training(
from appfl.algorithm.client_optimizer import ClientOptim
from appfl.algorithm.funcx_client_optimizer import FuncxClientOptim
from appfl.misc import client_log, get_dataloader
from appfl.funcx.client_utils import get_dataset, load_global_state, send_client_state, get_model
from appfl.funcx.client_utils import get_dataset, load_global_state, send_client_state, get_model, get_loss

## Load client configs
cfg.device = cfg.clients[client_idx].device
Expand All @@ -65,6 +65,10 @@ def client_training(
## Prepare output directory
output_filename = cfg.output_filename + "_client_%s" % (client_idx)
outfile = client_log(cfg.output_dirname, output_filename)

# Get custom loss function
if loss_fn == None:
loss_fn = get_loss(cfg)

## Instantiate training agent at client
client= eval(cfg.fed.clientname)(
Expand Down Expand Up @@ -123,7 +127,7 @@ def client_testing(
import os.path as osp
from appfl.misc import client_log, get_dataloader
from appfl.algorithm.client_optimizer import ClientOptim
from appfl.funcx.client_utils import get_dataset, load_global_state, get_model
from appfl.funcx.client_utils import get_dataset, load_global_state, get_model, get_loss
## Load client configs
cfg.device = cfg.clients[client_idx].device
cfg.output_dirname = cfg.clients[client_idx].output_dir
Expand All @@ -138,6 +142,11 @@ def client_testing(
## Prepare output directory
output_filename = cfg.output_filename + "_client_test_%s" % (client_idx)
outfile = client_log(cfg.output_dirname, output_filename)

# Get custom loss function
if loss_fn == None:
loss_fn = get_loss(cfg)

## Instantiate training agent at client
client= eval(cfg.fed.clientname)(
client_idx,
Expand Down