-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdemo.py
78 lines (66 loc) · 1.9 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import flight as fl
from flight.learning import federated_split
from flight.learning.torch import TorchModule
from flight.learning.torch.types import TensorLoss
# Number of target classes; sizes the model's output layer and is also handed
# to `federated_split` in `main`.
NUM_LABELS = 10


class MyModule(TorchModule):
    """Fully connected classifier trained with cross-entropy and Adam.

    Inputs are flattened and passed through four linear layers
    (784 -> 2352 -> 784 -> 28 -> ``num_labels``) with ReLU activations in
    between. The first layer expects 28 * 28 = 784 features per sample after
    flattening.

    Args:
        lr: Learning rate given to ``torch.optim.Adam``. Defaults to 0.02,
            the value this demo has always used.
        num_labels: Size of the output layer (number of classes). Defaults to
            ``NUM_LABELS``.
    """

    def __init__(self, lr: float = 0.02, num_labels: int = NUM_LABELS):
        super().__init__()
        self.lr = lr
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 28 * 28 * 3),
            nn.ReLU(),
            nn.Linear(28 * 28 * 3, 28 * 28),
            nn.ReLU(),
            nn.Linear(28 * 28, 28),
            nn.ReLU(),
            # Raw logits: `cross_entropy` in `training_step` applies the softmax.
            nn.Linear(28, num_labels),
        )

    def forward(self, x):
        """Return the unnormalized class scores (logits) for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx) -> TensorLoss:
        """Compute the cross-entropy loss for one ``(inputs, targets)`` batch.

        ``batch_idx`` is unused here; it is presumably required by the
        ``TorchModule`` interface — confirm against flight's docs.
        """
        x, y = batch
        y_hat = self(x)
        return nn.functional.cross_entropy(y_hat, y)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
def main(
    *,
    data_root: str = "~/Research/Data/Torch-Data/",
    download: bool = False,
    num_samples: int = 200,
    num_nodes: int = 10,
    rounds: int = 10,
) -> None:
    """Run a small federated-averaging demo on MNIST and plot training loss.

    All arguments are keyword-only and default to the values the demo
    originally hard-coded, so a plain ``main()`` call behaves as before.

    Args:
        data_root: Directory holding (or receiving) the MNIST files.
        download: Whether torchvision may download MNIST if it is missing.
            With ``False`` the files must already exist under ``data_root``.
        num_samples: How many leading examples of the MNIST test split to use.
            Clamped to the dataset length.
        num_nodes: Argument passed to ``fl.flat_topology``; presumably the
            number of worker nodes — confirm against flight's docs.
        rounds: Number of federated rounds passed to ``fl.federated_fit``.
    """
    # Uses the MNIST test split (train=False).
    data = MNIST(
        root=data_root,
        download=download,
        train=False,
        transform=ToTensor(),
    )
    # Keep only a small leading slice. Clamp so an oversized request does not
    # produce out-of-range indices.
    num_samples = min(num_samples, len(data))
    data = Subset(data, indices=list(range(num_samples)))

    topo = fl.flat_topology(num_nodes)
    module = MyModule()
    # NOTE(review): label_alpha/sample_alpha look like Dirichlet concentration
    # parameters (large values -> near-uniform split) — confirm in flight docs.
    fed_data = federated_split(
        topo=topo,
        data=data,
        num_labels=NUM_LABELS,
        label_alpha=100.0,
        sample_alpha=100.0,
    )
    trained_module, records = fl.federated_fit(
        topo, module, fed_data, strategy="fedavg", rounds=rounds
    )

    # The records must carry the "train/time", "train/loss" and "node/idx"
    # keys used for plotting below.
    df = pd.DataFrame.from_records(records)
    print(df.head())
    # One loss curve per node. Pass errorbar=None to drop the confidence band.
    sns.lineplot(
        data=df,
        x="train/time",
        y="train/loss",
        hue="node/idx",
    ).set(yscale="linear")
    plt.show()


if __name__ == "__main__":
    main()