Technology & AI

Step-by-Step Guide to Building and Comparing FedAvg and FedProx Federated Learning on Non-IID CIFAR-10 with NVIDIA FLARE

# Second half of the NVFLARE client script: append the per-site entry point to
# CLIENT_SCRIPT, which is started earlier in this file. main() calls
# dirichlet_partition, Net and evaluate and uses argparse/torch/torchvision/
# numpy/csv/os/flare, all of which the earlier part of the script is expected
# to define or import -- they are not visible in this excerpt.
CLIENT_SCRIPT += r'''
def main():
   """Run one NVFLARE client site for FedAvg (mu == 0) or FedProx (mu > 0).

   Every round: load the received global weights, evaluate them on the full
   CIFAR-10 test set (site 0 appends [round, accuracy] to
   <results_dir>/<tag>.csv), train locally for --local_epochs on this site's
   partition, then send the updated weights plus the optimizer step count.
   """
   p = argparse.ArgumentParser()
   p.add_argument("--num_sites", type=int, default=3)
   p.add_argument("--alpha", type=float, default=0.3)
   p.add_argument("--local_epochs", type=int, default=1)
   p.add_argument("--mu", type=float, default=0.0)
   p.add_argument("--max_samples", type=int, default=4000)
   p.add_argument("--batch_size", type=int, default=64)
   p.add_argument("--lr", type=float, default=0.01)
   p.add_argument("--data_root", type=str, default="/tmp/nvflare/data")
   p.add_argument("--results_dir", type=str, default="/tmp/nvflare/results")
   p.add_argument("--tag", type=str, default="fedavg")
   args = p.parse_args()
   device = "cuda" if torch.cuda.is_available() else "cpu"
   tf = T.Compose([T.ToTensor(),
                   T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
   train_set = torchvision.datasets.CIFAR10(args.data_root, train=True,  download=False, transform=tf)
   test_set  = torchvision.datasets.CIFAR10(args.data_root, train=False, download=False, transform=tf)
   flare.init()
   site_name = flare.get_site_name()
   # Site names are assumed to look like "<prefix>-<k>" with k starting at 1,
   # so "site-1" maps to partition 0.
   site_id   = int(site_name.split("-")[-1]) - 1
   labels   = np.array(train_set.targets)
   my_idx   = dirichlet_partition(labels, args.num_sites, args.alpha)[site_id]
   # NOTE(review): this keeps the *first* max_samples indices; it is only an
   # unbiased subsample if dirichlet_partition returns them shuffled -- confirm.
   if len(my_idx) > args.max_samples:
       my_idx = my_idx[:args.max_samples]
   train_loader = DataLoader(Subset(train_set, my_idx), batch_size=args.batch_size, shuffle=True)
   test_loader  = DataLoader(test_set, batch_size=512, shuffle=False)
   print(f"[{site_name}] mu={args.mu}  local samples={len(my_idx)}", flush=True)
   model     = Net().to(device)
   # NOTE(review): the optimizer (incl. momentum buffers) is created once and
   # persists across rounds, even though the weights are reset every round.
   optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
   criterion = nn.CrossEntropyLoss()
   if site_id == 0:
       # Site 0 is the only CSV writer; the results directory may not exist yet.
       os.makedirs(args.results_dir, exist_ok=True)
   while flare.is_running():
       input_model = flare.receive()
       # receive() may return None (e.g. when the job is shutting down);
       # stop cleanly instead of failing on input_model.current_round.
       if input_model is None:
           break
       rnd = input_model.current_round
       global_state = {k: torch.as_tensor(v) for k, v in input_model.params.items()}
       model.load_state_dict(global_state)
       acc = evaluate(model, test_loader, device)
       # Every site evaluates the same global model on the same test set, so a
       # single writer is enough and avoids duplicate rows per round.
       if site_id == 0:
           with open(os.path.join(args.results_dir, args.tag + ".csv"), "a", newline="") as f:
               csv.writer(f).writerow([rnd, acc])
       print(f"[{site_name}] round {rnd}: global test acc = {acc:.4f}", flush=True)
       # Frozen copy of the global weights: the anchor for the FedProx term.
       global_w = [w.detach().clone() for w in model.parameters()]
       model.train()
       steps = 0
       for _ in range(args.local_epochs):
           for x, y in train_loader:
               x, y = x.to(device), y.to(device)
               optimizer.zero_grad()
               loss = criterion(model(x), y)
               if args.mu > 0:
                   # FedProx: add (mu / 2) * ||w - w_global||^2 to the task loss.
                   prox = sum(((w - g) ** 2).sum() for w, g in zip(model.parameters(), global_w))
                   loss = loss + (args.mu / 2.0) * prox
               loss.backward()
               optimizer.step()
               steps += 1
       out = flare.FLModel(
           params={k: v.cpu().numpy() for k, v in model.state_dict().items()},
           metrics={"test_accuracy": acc},
           # Local step count; presumably used by the server as the aggregation
           # weight -- confirm against the controller configuration.
           meta={"NUM_STEPS_CURRENT_ROUND": steps},
       )
       flare.send(out)
if __name__ == "__main__":
   main()
'''
# Materialize the complete client script on disk (presumably referenced by the
# NVFLARE job as the client entry point -- confirm) and import its Net into this
# process so both sides share one model definition.
import importlib

with open("client_train.py", "w", encoding="utf-8") as f:
    f.write(CLIENT_SCRIPT)

# Prepend the working directory only if it is not already first, so
# re-executing this code does not keep stacking duplicate sys.path entries.
if os.getcwd() not in sys.path[:1]:
    sys.path.insert(0, os.getcwd())

# The file was just (re)written: clear the import system's finder caches and
# any previously imported copy, so the import below sees the current contents
# rather than a stale cached module.
importlib.invalidate_caches()
sys.modules.pop("client_train", None)
from client_train import Net

Related Articles

Leave a Reply

Your email address will not be published. Required fields are marked *

Back to top button