强化学习rllib简明教程ray
强化学习rllib简明教程 ray
之前说到强化 学习的库,推荐了tianshou,但是tianshou实现的功能还不够多,于是转向rllib,个⼈还是很期待tianshou的发展。
回到rllib,rllib是基于ray的⼀个⼯具(不知道这么说是不是合适),ray和rllib的关系就像,mllib之于spark,ray是个分布式的计算框架。
,。进⼊官⽹,可以看到,蚂蚁⾦服也在使⽤这个框架,⼤⼚使⽤,不过本⼈只是为了快速实现⼀个强化学习的实验。
不过其⽂档存在着⼀些问题,⽐如:官⽅案例运⾏出错,⽂档长久未更新等。给我这种为了快速完成强化学习的菜鸟造成了⼀定的困难,本⼈亲⾃采坑,通过阅读源码等⽅式,把⼀些坑给踩了,现在记录在此,为后来的⼈少踩点坑。本⽂主要介绍rllib的⼀些基本功能和使⽤,在基于⽂档的基础上进⾏⼀些问题的解决和修补。最多会带⼀些tune的案例,对其他功能有需要的,请⾃⾏看⽂档采坑。
私以为⼀个完善的强化学习库,应该完成以下功能:
1. 训练参数
2. 训练结果
3. 测试过程 (强化学习有时候是为了找到最优⽅案或是ai认为期望最好的⽅案,所以在我们需要获得测试的过程)
4. 模型存储
5. 模型读取
6. ⾃定义环境
7. 结果复现ed
主要是以上7个功能,为了快速⼊门与简化过程,接下来会根据新的顺序来对以上七个功能进⾏实现。
ray版本1.2.0
1. ⾃定义环境
⾸先⼀个⾃定义环境必须继承⾃gym. Env,并实现ret和step⽅法,其他⽅法可实现可不实现,具体
可以参照gym的标准,我这⾥是根据tianshou的标准去写,但是和tianshou不同的是,在⽅法__init__中,必须带第⼆个参数,⽤于传递envconfig,否则会报错。
在这⾥我实现了⼀个简单的游戏,⽤于简化之后的实验,规则为
长度为10的线段,每次只能左右移动,节点标为0…9,起点为0,终点为9,超过100步则死亡-100。 到达9则胜利+100
myenv1.py
import random
import gym
import gym.spaces
import numpy as np
import traceback
import pprint
class GridEnv1(gym.Env):
'''
长度为10的线段,每次只能左右移动,节点标为0..9,
起点为0,终点为9,超过100步则死亡-100
到达9则胜利+100
'''
def__init__(lf,env_config):
lf.action_space=gym.spaces.Discrete(2)
lf.obrvation_space=gym.spaces.Box(np.array([0]),np.array([9]))
<()
def ret(lf):
'''
:return: state
'''
lf.obrvation =[0]
#lf.reward = 10
lf.done=Fal
lf.step_num=0
return[0]
def step(lf, action)->tuple:
'''
:param action:
:
return: tuple ->[obrvation,reward,done,info]
'''
#pprint.act_stack())
if action==0:
action=-1
lf.obrvation[0]+=action
lf.step_num+=1
reward=-1.0
if lf.step_num>100or lf.obrvation[0]<0:
reward=-100.0
lf.done=True
#print('last %d action %d now %d' % (lf.obrvation[0] - action, action, lf.obrvation[0])) return lf.obrvation,reward,lf.done,{}
if lf.obrvation[0]==9:
reward=100.0
lf.done=True
作文补习班#print('last %d action %d now %d'%(lf.obrvation[0]-action,action,lf.obrvation[0]))
return lf.obrvation,reward,lf.done,{}
def render(lf, mode='human'):
pass
2.训练参数
众所周知,深度学习⼜被称作炼丹,超参数很多,rllib的实验有两种启动⽅法,⼀种是rllib的底层api进⾏组合调⽤,另⼀种是tune.run 进⾏调⽤。以dqn为例
rllib-api
import ray
from ray.rllib.agents.dqn import DQNTrainer
from myenv1 import GridEnv1
ray.init()
trainer=DQNTrainer(
env=GridEnv1,
config={'framework':'tfe',
}
)
for i in range(10):
其中train⽅法调⽤⼀次即为训练⼀个世代,这是底层api,⽆法快速控制结束条件等其他参数,所以官⽅更推荐tune.run。
framework参数代表你要⽤什么框架,
tf:tensorflow,tfe: TensorFlow eager, torch: PyTorch。
其中tfe是⼯程模式,即刻计算张量,如果是tf,则会在构建图完成之后才计算,调试解阶段tfe可以看到过程。tf速度更快。
在过去,可以设置config中的config[“eager”] = True,完成模式的更改,现在这个设置已被弃⽤,想⽤⼯程模式的请使⽤framework
tune.run
from ray import tune
import ray
from ray.rllib.agents.dqn import DQNTrainer
from myenv1 import GridEnv1
ray.init()
t=tune.run(
无可奈何DQNTrainer,#此处可以⽤字符串,请⾃⾏进⼊⽂档查阅对应字符串
config={
'env':GridEnv1,
},
stop={
'episode_reward_max':91
}
)
tune会⾃动⽣成报告,并以stop为结束条件,上⾯为当⼀个世代的最⼤得分超过91时,停⽌训练。同时tune可以进⾏超参数寻优,但这不是本篇的主要内容。
上⾯是开始训练的两种⽅法,那config中有什么可以设置呢,config中的设置主要来源于两个地⽅,⼀个是基本的默认设置,另⼀个是根据你选定的trainer的默认设置,⽐如dqn就有⼀些其他算法没有的设置。第⼀种的设置如下
,算法配置请在算法列表中⾃⾏查找。
3.训练结果
如果是为了查看每⼀个世代的训练情况按照以下操作即可
rllib-API
ain()
print(t)
tune.run
运⾏tune.run之后会⾃⾏打印结果
同时有时候还会有获得过程中最优值的需求
这样的需求则需要调⽤回调类,回调类必须继承⾃DefaultCallbacks。因为需要传⼊⼀个类,所以我⾃⾏完成了⼀下动态类,供⼤家参考,主要是回调与全局锁变量。代码
record.py
import ray
@
class BestRecord:
人与自然关系def__init__(lf):
lf.bestVal =0
lf.bestAction =[]
lf.poolAction ={}
# eps_id:list[action]
lf.poolVal ={}
# eps_id:reward
def add(lf, sample):
'''
sample key
obs new_obs actions rewards dones
agent_index eps_id unroll_id weights
一霎时的意思'''
for index,item in enumerate(sample['eps_id']):
if not item in lf.poolVal:
lf.poolVal[item]=0
lf.poolAction[item]=[]
lf.poolAction[item].append(sample['actions'][index])
lf.poolVal[item]+=sample['rewards'][index]
if lf.poolVal[item]>lf.bestVal:
lf.bestVal=lf.poolVal[item]
lf.bestAction=lf.poolAction[item]
if sample['dones'][index]:
del lf.poolVal[item]
del lf.poolAction[item]
def getBest(lf):
return lf.bestVal,lf.bestAction
def getAll(lf):
return lf.poolVal,lf.poolAction
mycallback.py
长安旅夜from typing import Dict
import numpy as np
import ray
from ray.rllib.agents.callbacks import DefaultCallbacks
from v import BaEnv
坏事变好事的例子from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
def initDefaultCallbacks(logPrint=Fal,isRecord=Fal):
#MyCallbacks = type('MyCallbacks', (DefaultCallbacks))
class MyCallbacks(DefaultCallbacks):
pass
if logPrint:
def on_episode_start(lf,*, worker: RolloutWorker, ba_env: BaEnv,
policies: Dict[str, Policy],
policies: Dict[str, Policy],
episode: MultiAgentEpisode, env_index:int,**kwargs): print("episode {} (env-idx={}) started.".format(
episode.episode_id, env_index))
episode.ur_data["pole_angles"]=[]
episode.hist_data["pole_angles"]=[]
def on_episode_step(lf,*, worker: RolloutWorker, ba_env: BaEnv,
道路标志牌episode: MultiAgentEpisode, env_index:int,**kwargs): # pole_angle = abs(episode.last_obrvation_for()[2])
# #print(episode.last_obrvation_for())
# #返回最后⼀次观察
# raw_angle = abs(episode.last_raw_obs_for()[2])
# #print(episode.last_raw_obs_for())
# #返回指定代理的最后⼀个未预处理的对象存储服务
# asrt pole_angle == raw_angle
# episode.ur_data["pole_angles"].append(pole_angle)
# #print(episode.)
# print('episode are running')
# print(episode.last_obrvation_for())
asrt episode.last_obrvation_for()== episode.last_raw_obs_for()
def on_episode_end(lf,*, worker: RolloutWorker, ba_env: BaEnv,
policies: Dict[str, Policy], episode: MultiAgentEpisode,
env_index:int,**kwargs):
pole_angle = np.mean(episode.ur_data["pole_angles"])
print("episode {} (env-idx={}) ended with length {} and pole "
"angles {}".format(episode.episode_id, env_index, episode.length,
pole_angle))
episode.custom_metrics["pole_angle"]= pole_angle
episode.hist_data["pole_angles"]= episode.ur_data["pole_angles"]
def on_train_result(lf,*, trainer, result:dict,**kwargs):
print("ain() result: {} -> {} episodes".format(
trainer, result["episodes_this_iter"]))
行政人事专员
# you can mutate the result dict to add new fields to return
result["callback_ok"]=True
def on_postprocess_trajectory(
lf,*, worker: RolloutWorker, episode: MultiAgentEpisode,
agent_id:str, policy_id:str, policies: Dict[str, Policy],
postprocesd_batch: SampleBatch,
original_batches: Dict[str, SampleBatch],**kwargs):
print("postprocesd {} steps".format(unt))
if"num_batches"not in episode.custom_metrics:
episode.custom_metrics["num_batches"]=0
episode.custom_metrics["num_batches"]+=1
<_train_result=on_train_result
<_episode_start=on_episode_start
<_episode_step=on_episode_step
<_episode_end=on_episode_end
<_postprocess_trajectory=on_postprocess_trajectory
if isRecord:
def on_sample_end(lf,*, worker:"RolloutWorker", samples: SampleBatch, **kwargs)->None:
'''
sample key
obs new_obs actions rewards dones
agent_index eps_id unroll_id weights
:param worker:
:param samples:
:param kwargs:
:return:
'''