-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathserver.py
More file actions
61 lines (42 loc) · 2.21 KB
/
server.py
File metadata and controls
61 lines (42 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import OrderedDict
from omegaconf import DictConfig
import torch
from model import Net, test
def get_on_fit_config(config: DictConfig):
    """Return function that prepares config to send to clients."""

    def fit_config_fn(server_round: int):
        # Executed by the strategy inside its `configure_fit()` method.
        # The same settings are sent every round here, but the
        # `server_round` argument could be used to adapt them over time
        # (e.g. a smaller learning rate after N rounds of FL).
        settings = {}
        settings["lr"] = config.lr
        settings["momentum"] = config.momentum
        settings["local_epochs"] = config.local_epochs
        return settings

    return fit_config_fn
def get_evaluate_fn(num_classes: int, testloader):
    """Define function for global evaluation on the server."""

    def evaluate_fn(server_round: int, parameters, config):
        # Called by the strategy's `evaluate()` method with the current
        # round number and the parameters of the global model.
        # Rebuild the model locally and load the received weights into it.
        net = Net(num_classes)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        keys = net.state_dict().keys()
        state = OrderedDict(
            (name, torch.Tensor(weights))
            for name, weights in zip(keys, parameters)
        )
        net.load_state_dict(state, strict=True)

        # Evaluate the global model on the test set. In a more realistic
        # setting you would only do this at the end of the FL experiment
        # (use `server_round` to detect the last round) and rely on a
        # global validation set for intermediate rounds.
        loss, accuracy = test(net, testloader, device)

        # Report the loss plus any other metrics inside a dictionary —
        # here, the global test accuracy.
        return loss, {"accuracy": accuracy}

    return evaluate_fn