Reinforcement Learning in Practice: Solving the Cliff Walking Problem with Sarsa


Source code
# Step 1: Import dependencies
import gym
import numpy as np
import time
import matplotlib.pyplot as plt
# Step 2: Define the agent
class SarsaAgent(object):
    def __init__(self, obs_n, act_n, lr, gamma, epsilon):
        self.obs_n = obs_n      # number of states
        self.act_n = act_n      # number of actions
        self.lr = lr            # learning rate (step size)
        self.gamma = gamma      # discount factor
        self.epsilon = epsilon  # exploration rate of the epsilon-greedy policy
        self.Q_table = np.zeros((obs_n, act_n))
    def sample(self, obs):
        """
        Sample an action for the given observation, with exploration.
        :param obs: current state
        :return: the action to take next
        """
        if np.random.uniform(0, 1) < (1.0 - self.epsilon):
            # Pick the action greedily from the Q table
            action = self.predict(obs)
        else:
            # With probability epsilon, explore by picking a random action
            action = np.random.choice(self.act_n)
        return action
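    # Note (added for clarity): with epsilon = 0.1 and 4 actions, and assuming a single
    # greedy action, sample() returns the greedy action with probability
    # (1 - epsilon) + epsilon / act_n = 0.9 + 0.025 = 0.925, and each of the other
    # actions with probability epsilon / act_n = 0.025.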
    def predict(self, obs):
        """
        Predict the greedy action for the given observation.
        :param obs: current state
        :return: the predicted action
        """
        Q_list = self.Q_table[obs, :]
        maxQ = np.max(Q_list)
        action_list = np.where(Q_list == maxQ)[0]  # maxQ may correspond to several actions
        action = np.random.choice(action_list)     # break ties randomly
        return action
    def learn(self, obs, act, reward, next_obs, next_act, done):
        """
        On-policy update.
        :param obs: observation before the interaction, s_t
        :param act: action chosen in this interaction, a_t
        :param reward: reward r received for this action
        :param next_obs: observation after the interaction, s_t+1
        :param next_act: action that will be taken from next_obs under the current Q table, a_t+1
        :param done: whether the episode has ended
        :return: None
        """
        predict_Q = self.Q_table[obs, act]
        if done:
            target_Q = reward  # there is no next state
        else:
            target_Q = reward + self.gamma * self.Q_table[next_obs, next_act]  # Sarsa target
        self.Q_table[obs, act] += self.lr * (target_Q - predict_Q)  # correct the Q value
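        # Note (added for clarity): the line above implements the Sarsa update
        #     Q(s_t, a_t) <- Q(s_t, a_t) + lr * (r + gamma * Q(s_t+1, a_t+1) - Q(s_t, a_t)),
        # where a_t+1 is chosen by the same epsilon-greedy behaviour policy, which is what
        # makes Sarsa on-policy. For comparison only, off-policy Q-learning would use
        #     target_Q = reward + self.gamma * np.max(self.Q_table[next_obs, :])
        # instead of the Q value of the actually sampled next action.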
    # Save the Q table to a file
    def save(self):
        npy_file = './q_table.npy'
        np.save(npy_file, self.Q_table)
        print(npy_file + ' saved.')

    # Load the Q table from a file
    def restore(self, npy_file='./q_table.npy'):
        self.Q_table = np.load(npy_file)
        print(npy_file + ' loaded.')
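# --- Optional sanity check (illustrative, not part of the original tutorial) ---
# One Sarsa update on a toy 2-state, 2-action table, to see the rule with concrete numbers:
# _demo = SarsaAgent(obs_n=2, act_n=2, lr=0.5, gamma=0.9, epsilon=0.0)
# _demo.learn(obs=0, act=1, reward=-1.0, next_obs=1, next_act=0, done=False)
# print(_demo.Q_table[0, 1])  # 0 + 0.5 * (-1.0 + 0.9 * 0 - 0) = -0.5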
# Step 3: Training && Test
def train_episode(env, agent, render=False):
    total_reward = 0
    total_steps = 0  # count how many steps this episode takes
    obs = env.reset()
    act = agent.sample(obs)
    while True:
        next_obs, reward, done, _ = env.step(act)  # one interaction with the environment
        next_act = agent.sample(next_obs)          # choose the next action with the same policy
        # Sarsa update
        agent.learn(obs, act, reward, next_obs, next_act, done)
        act = next_act
        obs = next_obs  # keep the latest observation
        total_reward += reward
        total_steps += 1
        if render:
            env.render()
        if done:
            break
    return total_reward, total_steps
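# Note (added for clarity): each pass through the loop above produces exactly the tuple
# (state, action, reward, next_state, next_action) that gives Sarsa its name. The next
# action must come from agent.sample() (the epsilon-greedy behaviour policy), not from
# agent.predict(); otherwise the update would no longer be on-policy.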
def test_episode(env, agent):
    total_reward = 0
    total_steps = 0  # count how many steps this episode takes
    obs = env.reset()
    while True:
        action = agent.predict(obs)  # greedy, no exploration
        next_obs, reward, done, _ = env.step(action)
        total_reward += reward
        total_steps += 1
        obs = next_obs
        # time.sleep(0.5)
        # env.render()
        if done:
            break
    return total_reward, total_steps
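# Note (added for clarity): testing uses agent.predict() only, i.e. the purely greedy policy.
# In CliffWalking the shortest path to the goal takes 13 steps, so the best possible episode
# return is -13. Because Sarsa learns on-policy under epsilon-greedy exploration, it often
# prefers a longer but safer path away from the cliff edge, so a test return somewhat below
# -13 is normal.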
# Step 4: Create the environment and the agent, then start training
# Build the cliff-walking environment with gym
env = gym.make("CliffWalking-v0")  # actions: 0 up, 1 right, 2 down, 3 left
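# Note (added for clarity): this script uses the pre-0.26 Gym API, where env.reset() returns
# only the observation and env.step() returns (obs, reward, done, info). CliffWalking is a
# 4 x 12 grid; in Gym's implementation every step gives reward -1, stepping into the cliff
# gives -100 and sends the agent back to the start, and the episode only ends at the goal.
# print(env.observation_space)  # Discrete(48)
# print(env.action_space)       # Discrete(4)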
# Create an agent instance and pass in the hyper-parameters
agent = SarsaAgent(
    obs_n=env.observation_space.n,
    act_n=env.action_space.n,
    lr=0.001,
    gamma=0.99,
    epsilon=0.1
)
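# Note (added for clarity): lr is the Sarsa step size, gamma the discount factor, and epsilon
# the exploration rate of the epsilon-greedy policy. A step size of 0.001 is on the small side
# for a tabular problem like CliffWalking; larger values (e.g. 0.1) are commonly used and
# usually converge faster, at the cost of noisier Q estimates.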
print("Start training ...")
total_reward_list =[]
# 训练1000个episode,打印每个episode的分数
for episode in range(1000):
a型血的男人ep_reward, ep_steps = train_episode(env, agent,Fal)
total_reward_list.append(ep_reward)
if episode %50==0:
print('Episode %s: steps = %s , reward = %.1f'%(episode, ep_steps, ep_reward))
print("Train end")
def show_reward(total_reward):
    N = len(total_reward)
    x = np.arange(N)  # one point per episode (np.linspace(0, N, 1000) only lines up when N == 1000)
    plt.plot(x, total_reward, 'b-', lw=1, ms=5)
    plt.show()

show_reward(total_reward_list)
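# Optional (added for clarity): the raw per-episode return is noisy; a moving average makes
# the learning curve easier to read. The window size of 20 below is an arbitrary choice.
# smoothed = np.convolve(total_reward_list, np.ones(20) / 20, mode='valid')
# plt.plot(smoothed)
# plt.show()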
# Training finished; check how the learned policy performs
test_reward, test_steps = test_episode(env, agent)
Output:
Start training ...
Episode 0: steps = 1160 , reward = -3239.0
Episode 50: steps = 138 , reward = -237.0
