pytorch学习:STGCN

更新时间:2023-07-30 14:58:53 阅读: 评论:0

pytorch学习:STGCN 1 main.ipynb
1.1 导⼊库
import random
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
春到沂河from load_data import *
from utils import *
from stgcn import *
多想
1.2 随机种⼦
torch.manual_ed(2021)
torch.cuda.manual_ed(2021)
np.random.ed(2021)
random.ed(2021)
torch.backends.cudnn.deterministic = True
1.3 cpu or gpu
if torch.cuda.is_available():
device = torch.device("cuda")
el:
device=torch.device("cpu")
1.4 file path
matrix_path = "datat/W_228.csv"
#邻接矩阵
#228*228,228是观测点的数量
data_path = "datat/V_228.csv"
#数据矩阵
#12672*228
#12672=288*44,288是⼀天中有⼏个5分钟,44是我数据集⼀共44天
save_path = "save/model.pt"
#模型保存路径
1.5 参数
day_slot = 288
#24⼩时*12(12是⼀⼩时有⼏个5分钟的时间⽚)
#⼀天有⼏个5分钟
n_train, n_val, n_test = 34, 5, 5
# 训练集(前34天)评估集(中间5天)测试集(最后5天)
n_his = 12
#⽤过去12个时间⽚段的交通数据
n_pred = 3
#预测未来的第3个时间⽚段的交通数据
n_route = 228
#⼦路段数量
Ks, Kt = 3, 3
#空间和时间卷积核⼤⼩
blocks = [[1, 32, 64], [64, 32, 128]]
##两个ST块各隐藏层⼤⼩
drop_prob = 0
#dropout概率
batch_size = 50
epochs = 50生意经
lr = 1e-3
1.6 图的⼀些操作
W = load_matrix(matrix_path)
#load_data⾥⾯的函数
芙蓉蛋
#邻接矩阵,是⼀个ndarray
L = scaled_laplacian(W)
#utils.py⾥⾯的函数
#标准化拉普拉斯矩阵,是⼀个ndarray
Lk = cheb_poly(L, Ks)
#L的切⽐雪夫多项式
#[Ks,n,n]⼤⼩的list(n是L的size)
Lk = torch.Tensor(Lk.astype(np.float32)).to(device) #转换成Tensor
1.7 归⼀化
train, val, test = load_data(
data_path,
n_train * day_slot,
怎么查询公司n_val * day_slot)
#训练集,测试集,验证集
#load_data load_data.py的函数
scaler = StandardScaler()
train = scaler.fit_transform(train)
val = ansform(val)
test = ansform(test)
#数据归⼀化(每⼀个点的数⼗天的数据归⼀化成N(0,1))
1.8 x,y的构造高发明
x_train, y_train = data_transform(train, n_his, n_pred, day_slot, device) #在load_data.py中
x_val, y_val = data_transform(val, n_his, n_pred, day_slot, device)
x_test, y_test = data_transform(test, n_his, n_pred, day_slot, device)
#分别是测试集、验证集和测试集的数据集和标签值
1.9 DataLoader
dataLoader部分见:
train_data = torch.utils.data.TensorDatat(x_train, y_train)
#先转化成pytorch可以识别的Datat格式
train_iter = torch.utils.data.DataLoader(
train_data,
batch_size,
shuffle=True)
#把datat导⼊dataloader,并设置batch_size和shuffle
val_data = torch.utils.data.TensorDatat(x_val, y_val)
val_iter = torch.utils.data.DataLoader(
val_data,
batch_size)
test_data = torch.utils.data.TensorDatat(x_test, y_test)
test_iter = torch.utils.data.DataLoader(
test_data,
batch_size)
'''
for x, y in train_iter:
print(x.size())
返回的结果都是:torch.Size([50, 1, 12, 228])
print(x.size())
返回的结果都是:torch.Size([50, 228])
'''
1.10 损失函数
loss = nn.MSELoss()
#均⽅误差
1.11 模型部分
model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device) #模型
1.12 优化函数
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
1.13 LRScheduler
scheduler = torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=5,
gamma=0.7)
#每经过5步,学习率乘0.7
1.14 模型的训练和保存
min_val_loss = np.inf
for epoch in range(1, epochs + 1):
l_sum, n = 0.0, 0
for x, y in train_iter:
y_pred = model(x).view(len(x), -1)
#x_size:50, 1, 12, 228]
l = loss(y_pred, y)
#计算误差
<_grad()
l.backward()
optimizer.step()
#pytorch三部曲
l_sum += l.item() * y.shape[0]
#y.shape[0]是50(⼀个batch 的数据量)
#因为我们的LOSS是MSELOSS,所以在计算loss的时候除了m(即50),这边就需要乘回去        n += y.shape[0]
#n表⽰⼀个epoch中总的数据量(其实就是34*288=9732)
scheduler.step()
#更新学习率
val_loss = evaluate_model(model, loss, val_iter)
关于春的古诗#在utils.py⾥⾯
#做⽤是求得验证集在当前这⼀组参属下的误差
if val_loss < min_val_loss:
欢度国庆简笔画min_val_loss = val_loss
torch.save(model.state_dict(), save_path)
#如果验证集得到的误差⼩,那么将验证集的参数保存
print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)
'''
epoch 1 , train loss: 0.2372948690231597 , validation loss: 0.17270135993722582
epoch 2 , train loss: 0.16071674468762734 , validation loss: 0.1874343464626883
epoch 3 , train loss: 0.15448020929178746 , validation loss: 0.15503579677238952
epoch 4 , train loss: 0.14851808142814324 , validation loss: 0.1571340094572001
epoch 5 , train loss: 0.14439846146427904 , validation loss: 0.1607034688638727
epoch 6 , train loss: 0.13501421282825268 , validation loss: 0.15179621507107777
epoch 7 , train loss: 0.13397674925686107 , validation loss: 0.1501583637547319
epoch 8 , train loss: 0.13199909963433504 , validation loss: 0.15549336293589894
epoch 9 , train loss: 0.13083163166267517 , validation loss: 0.1436274949678757
epoch 10 , train loss: 0.12860295229930127 , validation loss: 0.1711318050069313
epoch 11 , train loss: 0.12468195441724815 , validation loss: 0.14502346818845202
epoch 12 , train loss: 0.12422825037287816 , validation loss: 0.1424633072294893
epoch 13 , train loss: 0.12274483556448518 , validation loss: 0.14821374778003588
epoch 14 , train loss: 0.12206453774660224 , validation loss: 0.14754791510203025
epoch 15 , train loss: 0.12099895425406379 , validation loss: 0.14229175160183524
epoch 16 , train loss: 0.11788094088358396 , validation loss: 0.14172261148473642
epoch 17 , train loss: 0.11743906428081737 , validation loss: 0.14362958854023558
epoch 18 , train loss: 0.11658749032162606 , validation loss: 0.14289248521256187
epoch 19 , train loss: 0.11578559385394271 , validation loss: 0.14577691240684829
epoch 20 , train loss: 0.11517422387001339 , validation loss: 0.14248750845554972
epoch 21 , train loss: 0.11292880779622501 , validation loss: 0.14378667825384298
epoch 22 , train loss: 0.11236149522433111 , validation loss: 0.1418098776064215
epoch 23 , train loss: 0.11190123393005597 , validation loss: 0.14487336483532495
epoch 24 , train loss: 0.11122141592764404 , validation loss: 0.14256540075433952
epoch 24 , train loss: 0.11122141592764404 , validation loss: 0.14256540075433952 epoch 25 , train loss: 0.11055498759427415 , validation loss: 0.1417213207804156 epoch 26 , train loss: 0.10926588731084119 , validation loss: 0.14354881562673263 epoch 27 , train loss: 0.10878032141678218 , validation loss: 0.14406675109843703 epoch 28 , train loss: 0.10831604593266689 , validation loss: 0.14266293554356063 epoch 29 , train loss: 0.10783299739592932 , validation loss: 0.14181039777387233 epoch 30 , train loss: 0.10746425136239193 , validation loss: 0.14267496105256308 epoch 31 , train loss: 0.10646289705865472 , validation loss: 0.14362520976060064 epoch 32 , train loss: 0.10611696387435193 , validation loss: 0.1432999167183455 epoch 33 , train loss: 0.10574598974
132804 , validation loss: 0.14397347505020835 epoch 34 , train loss: 0.10544157493979493 , validation loss: 0.14419378039773798 epoch 35 , train loss: 0.1051575989090946 , validation loss: 0.1453490975537222
epoch 36 , train loss: 0.10441591932940965 , validation loss: 0.14409059120246964 epoch 37 , train loss: 0.10416163295225915 , validation loss: 0.1449487895915543 epoch 38 , train loss: 0.10386519186668972 , validation loss: 0.14444787363882047 epoch 39 , train loss: 0.10369502502373996 , validation loss: 0.14437076065988436 epoch 40 , train loss: 0.10344708665002564 , validation loss: 0.14485514112306339 epoch 41 , train loss: 0.10296985521567077 , validation loss: 0.1442400562801283 epoch 42 , train loss: 0.10274617794937922 , validation loss: 0.14564144609999047 epoch 43 , train loss: 0.10261664642584892 , validation loss: 0.14551366431924112 epoch 44 , train loss: 0.102446699424612 , validation loss: 0.14577252360699822
epoch 45 , train loss: 0.10227145068907287 , validation loss: 0.1455480455536477 epoch 46 , train loss: 0.10193707958222101 , validation loss: 0.1456132891050873 epoch 47 , train loss: 0.1017713555406352 , validation loss: 0.14567107602573223 epoch 48 , train loss: 0.10164602311826305 , validation loss: 0.14578005224194404 epoch 49 , train loss: 0.1015352703
7844785 , validation loss: 0.14653010304718123 epoch 50 , train loss: 0.10142039881116231 , validation loss: 0.1462976201607363
'''
1.15 加载最佳模型对应参数
best_model = STGCN(Ks, Kt, blocks, n_his, n_route, Lk, drop_prob).to(device)
best_model.load_state_dict(torch.load(save_path))
1.16  测评
l = evaluate_model(best_model, loss, test_iter)
MAE, MAPE, RMSE = evaluate_metric(best_model, test_iter, scaler)
print("test loss:", l, "\nMAE:", MAE, ", MAPE:", MAPE, ", RMSE:", RMSE)
'''
test loss: 0.13690029052052186
MAE: 2.2246220055150383 , MAPE: 0.051902304533065484 , RMSE: 3.995202803143325 '''
2 load_data.py
2.1 库函数导⼊
import torch
import numpy as np
import pandas as pd
2.2 load_matrix
def load_matrix(file_path):
ad_csv(file_path, header=None).values.astype(float)

本文发布于:2023-07-30 14:58:53,感谢您对本站的认可!

本文链接:https://www.wtabcd.cn/fanwen/fan/89/1102086.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:验证   数据   模型   保存   函数   数量   时间   测试
相关文章
留言与评论(共有 0 条评论)
   
验证码:
推荐文章
排行榜
Copyright ©2019-2022 Comsenz Inc.Powered by © 专利检索| 网站地图