联邦学习FedAvg自编写代码

更新时间:2023-07-19 05:56:01 阅读: 评论:0

联邦学习FedAvg自编写代码
联邦学习中,联邦平均算法获得了很大的使用空间,因此常常被用于进行同步训练操作
不多废话了,以下为FedAvg代码
由于使用场景为Non-IID场景,因此我使用了别人的一个MNIST数据集自定义的代码(见附录)
FedAvg代码如下,功能具体看注释
工作环境:python3.8.5 + pytorch(无cuda)
divergence模块可直接删除
# coding: utf-8
# In[1]:
import argparse
import copy
import os
import random
import re
import sys
import time

import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#import divergence

# Script name, used to label the output figure/text/checkpoint files.
name = str(sys.argv[0])
# In[2]:
home_path = "./"
class MyDatat(torch.utils.data.Dataset):
    """Custom image dataset: images live on disk under <root>/<label>/<filename>.

    Args:
        root: base directory containing one sub-directory per class label.
        data: sequence of image file names.
        label: sequence of class labels, parallel to `data`.
        transform: optional callable applied to each loaded PIL image.
        target_transform: optional callable for labels (stored but unused here).
    """

    def __init__(self, root, data, label, transform=None, target_transform=None):
        super(MyDatat, self).__init__()
        self.img_route = root
        # Pair each file name with its integer class label.
        self.imgs = [(data[i], int(label[i])) for i in range(len(data))]
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load and return (image, label) for the given index."""
        fn, label = self.imgs[index]
        # Images are grouped by class: <root>/<label>/<filename>.
        route = self.img_route + str(label) + "/" + fn
        img = Image.open(route)
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of samples (distinct from the DataLoader's batch count)."""
        return len(self.imgs)
# Build per-class lists of training image file names and matching labels:
# train_data[i] holds the file names of digit class i.
filePath = home_path + 'data/MNIST/image_turn/'
train_data = []
train_label = []
for i in range(10):
    train_data.append(os.listdir(filePath + str(i)))
    train_label.append([i] * len(train_data[i]))

# Same structure for the test split.
filePath = home_path + 'data/MNIST/image_test_turn/'
test_data = []
test_label = []
for i in range(10):
    test_data.append(os.listdir(filePath + str(i)))
    test_label.append([i] * len(test_data[i]))

# Flatten the per-class test lists into one dataset.
test_ori = []
test_label_ori = []
for x in range(10):
    test_ori += test_data[x]
    test_label_ori += test_label[x]

test_data = MyDatat(home_path + "data/MNIST/image_test_turn/", test_ori, test_label_ori,
                    transform=transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64)
# In[4]:
class MyConvNet(nn.Module):
    """Small CNN for 28x28 single-channel MNIST images; outputs 10 class logits."""

    def __init__(self):
        super(MyConvNet, self).__init__()
        # 1x28x28 -> 16x28x28 (conv, pad 1) -> 16x14x14 (avg pool).
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.AvgPool2d(
                kernel_size=2,
                stride=2),
        )
        # 16x14x14 -> 32x12x12 (conv, no pad) -> 32x6x6 (max pool).
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3, 1, 0),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(32 * 6 * 6, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, 32*6*6)
        output = self.classifier(x)
        return output
def train_model(model, traindataloader, criterion, optimizer, batch_max, num_epochs):
    """Train `model` for `num_epochs` epochs on `batch_max` randomly sampled batches.

    Args:
        model: the network to train (moved to CUDA by the caller if available).
        traindataloader: a plain list of (b_x, b_y) batch tuples (not a DataLoader;
            random.sample below requires a sequence).
        criterion: loss function.
        optimizer: optimizer over model.parameters().
        batch_max: number of batches sampled per epoch.
        num_epochs: number of local epochs.

    Returns:
        The trained model (same object, mutated in place).
    """
    train_loss_all = []
    train_acc_all = []
    for epoch in range(num_epochs):
        train_loss = 0.0
        train_corrects = 0
        train_num = 0
        # Sample a subset of batches; this keeps each client's local step cheap.
        temp = random.sample(traindataloader, batch_max)
        for (b_x, b_y) in temp:
            if torch.cuda.is_available():
                b_x = b_x.cuda()
                b_y = b_y.cuda()
            output = model(b_x)
            pre_lab = torch.argmax(output, 1)
            loss = criterion(output, b_y)
            optimizer.zero_grad()  # reconstructed: was garbled in the scraped source
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_corrects += torch.sum(pre_lab == b_y.data)
            train_num += b_x.size(0)
        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects.double().item() / train_num)
        print("Train Loss:{:.4f}  Train Acc: {:.4f}".format(train_loss_all[-1], train_acc_all[-1]))
    return model
# In[6]:
def local_train(local_convnet_dict, traindataloader, epochs, batch_max):
    """Run one client's local training round and return its weight delta.

    Args:
        local_convnet_dict: the global model's state dict (starting weights).
        traindataloader: this client's list of (b_x, b_y) batches.
        epochs: number of local epochs for train_model.
        batch_max: batches sampled per epoch in train_model.

    Returns:
        A state dict of deltas: global weights minus locally trained weights,
        per parameter/buffer name (consumed by Central_model_update).
    """
    if torch.cuda.is_available():
        local_convnet = MyConvNet().cuda()
    else:
        local_convnet = MyConvNet()
    local_convnet.load_state_dict(local_convnet_dict)
    optimizer = optim.Adam(local_convnet.parameters(), lr=0.01, weight_decay=5e-4)
    criterion = nn.CrossEntropyLoss()
    local_convnet = train_model(local_convnet, traindataloader, criterion, optimizer,
                                batch_max, epochs)
    # Fresh state dict used only as a container for the per-tensor deltas.
    minus_convnet_dict = MyConvNet().state_dict()
    for name in local_convnet.state_dict():
        minus_convnet_dict[name] = local_convnet_dict[name] - local_convnet.state_dict()[name]
    return minus_convnet_dict
# In[7]:
def Central_model_update(Central_model, minus_convnet_client):
    """FedAvg aggregation: subtract the averaged client deltas from the global model.

    Args:
        Central_model: the global model to update in place.
        minus_convnet_client: list of delta state dicts from local_train
            (delta = global weights - locally trained weights, so subtracting
            the average delta moves the global model toward the clients).

    Returns:
        The updated global model (same object).
    """
    weight = 1  # global step size applied to the averaged delta
    model_dict = Central_model.state_dict()
    for name in Central_model.state_dict():
        for local_dict in minus_convnet_client:
            model_dict[name] = model_dict[name] - weight * local_dict[name] / len(minus_convnet_client)
    Central_model.load_state_dict(model_dict)
    return Central_model
# In[12]:
def train_data_loader(client_num, ClientSort1):
    """Build one list of (b_x, b_y) batches per client.

    In this (IID) variant every client receives all 10 digit classes;
    `ClientSort1` is unused here (see the disabled non-IID variant below).

    Returns:
        List of length `client_num`; each element is a list of batch tuples
        suitable for random.sample() inside train_model.
    """
    global train_data
    global train_label
    train_loaders = []
    for i in range(client_num):
        train_ori = []
        label_ori = []
        for j in range(10):
            train_ori += train_data[j]
            label_ori += train_label[j]
        train_datas = MyDatat(home_path + "data/MNIST/image_turn/", train_ori, label_ori,
                              transform=transforms.ToTensor())
        train_loader = DataLoader(dataset=train_datas, batch_size=100, shuffle=True)
        # Materialize all batches so clients can sample them as plain lists.
        train_list = []
        for step, (b_x, b_y) in enumerate(train_loader):
            train_list.append((b_x, b_y))
        train_loaders.append(train_list)
    return train_loaders
'''
# Disabled alternative (non-IID) loader, kept for reference:
# the first ClientSort1 clients each receive every class EXCEPT their own
# index; the remaining clients each receive only class (i - 10).
def train_data_loader(client_num, ClientSort1):
    global train_data
    global train_label
    train_loaders = []
    for i in range(client_num):
        if i < ClientSort1:  # synchronous part
            train_ori = []
            label_ori = []
            for j in range(10):
                if i != j:
                    train_ori += train_data[j]
                    label_ori += train_label[j]
        else:  # asynchronous part
            train_ori = []
            label_ori = []
            for j in range(10):
                if j == i - 10:
                    train_ori += train_data[j]
                    label_ori += train_label[j]
        train_datas = MyDatat(home_path + "data/MNIST/image_turn/", train_ori, label_ori,
                              transform=transforms.ToTensor())
        train_loader = DataLoader(dataset=train_datas, batch_size=128, shuffle=True)
        train_list = []
        for step, (b_x, b_y) in enumerate(train_loader):
            train_list.append((b_x, b_y))
        train_loaders.append(train_list)
    return train_loaders
'''
# In[13]:
def test_accuracy(Central_model):
    """Evaluate the global model on the shared test set.

    Reads module-level `test_loader` and `test_ori`; prints and returns the
    accuracy as a percentage (a 0-dim tensor, converted by callers via float()).
    """
    global test_loader
    test_correct = 0
    for data in test_loader:
        Central_model.eval()
        inputs, lables = data
        if torch.cuda.is_available():
            inputs = inputs.cuda()
        inputs, lables = Variable(inputs), Variable(lables)
        outputs = Central_model(inputs)
        if torch.cuda.is_available():
            outputs = outputs.cpu()
        # pred: index of the max logit per sample (renamed from `id`,
        # which shadowed the builtin).
        _, pred = torch.max(outputs.data, 1)
        test_correct += torch.sum(pred == lables.data)
    print("correct:%.3f%%" % (100 * test_correct / len(test_ori)))
    return 100 * test_correct / len(test_ori)
# In[14]:
##########################################
# Main federated (FedAvg) training loop. #
##########################################
if torch.cuda.is_available():
    Central_model = MyConvNet().cuda()
else:
    Central_model = MyConvNet()
local_client_num = 10  # number of local clients
ClientSort1 = 10
#Central_model.load_state_dict(torch.load('F:/params.pkl'))
global_epoch = 1000
#print(test_accuracy(Central_model))
# In[ ]:
train_loaders = train_data_loader(local_client_num, ClientSort1)
result = []

count = 0
for i in range(global_epoch):
    count += 1
    minus_model = []
    # Every client trains locally from the current global weights and
    # contributes its delta; the server then averages the deltas.
    for j in range(local_client_num):
        minus_model.append(local_train(Central_model.state_dict(), train_loaders[j], 1, 1))
    Central_model = Central_model_update(Central_model, minus_model)
    print("epoch: ", count, "\naccuracy:")
    result.append(float(test_accuracy(Central_model)))
    #divergence.cal_divergence(Central_model.state_dict(), i, "F:/同步Fig/")
# In[ ]:
# Plot per-round accuracy and persist results plus the final model weights.
plt.xlabel('round')
plt.ylabel('accuracy')
plt.plot(range(0, len(result)), result, color='r', linewidth='1.0', label='同步FedAvg')
plt.savefig(home_path + name + ".jpg")
# Use a context manager so the file is closed even if a write fails
# (original used a bare open()/close() pair with a garbled close call).
with open(home_path + name + ".txt", mode='w') as result_file:
    for acc in result:
        result_file.write(str(acc))
        result_file.write('\n')
# NOTE(review): `filePath` still points at the test-image directory at this
# point, so the checkpoint lands under data/MNIST/image_test_turn/ — confirm
# this is the intended save location.
torch.save(Central_model.state_dict(), filePath + name + '.pkl')
附录:
import numpy as np
import struct
from PIL import Image
import os
# Raw MNIST training-image file in IDX format (adjust the path as needed).
data_file ='./data/MNIST/raw/train-images-idx3-ubyte'# path that needs modifying
# 47040016 bytes total = 16-byte IDX header + 60000 * 28 * 28 pixel bytes.
data_file_size =47040016
# struct format string for all pixel bytes after the header, e.g. '47040000B'.
data_file_size =str(data_file_size -16)+'B'
data_buf =open(data_file,'rb').read()
magic, numImages, numRows, numColumns = struct.unpack_from(

本文发布于:2023-07-19 05:56:01,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/82/1104502.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:联邦   读取   代码
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图