一文教你在MindSpore中实现A2C算法训练

mindspore,a2c · 浏览次数 : 0

小编点评

本文介绍了MindSpore A2C算法的基本原理和伪代码实现。A2C算法是一种强化学习算法,通过结合策略梯度和价值函数的方法,利用神经网络对状态和动作进行建模。本文主要阐述了: 1. A2C算法结构:A2C算法包括两个主要部分,分别是策略网络(Actor)和价值网络(Critic)。Actor负责选择一个动作,Critic负责评估状态-动作对的值。 2. 训练过程:文章提供了A2C算法的伪代码,详细说明了算法的整体流程,包括初始化、环境交互、状态值计算、TD误差计算、价值网络更新、策略网络更新等关键步骤。 3. 算法特点:文章还提到了A2C算法中的重要概念,如TD误差和优势函数(Advantage Function),这些概念对于理解算法的性能优化至关重要。 4. 参数设置:最后,文章列出了训练A2C算法所需的配置参数,如学习率、状态空间维数、动作空间维数、隐藏层大小以及折扣因子等,这些参数对于模型训练的效果有着直接的影响。 总之,通过对MindSpore A2C算法的结构、训练过程和参数设置的详细介绍,本文为读者提供了一个清晰的理解路径,帮助其更好地掌握和应用这一强化学习算法。

正文

本文分享自华为云社区《MindSpore A2C 强化学习》,作者:irrational。

Advantage Actor-Critic (A2C)算法是一个强化学习算法,它结合了策略梯度(Actor)和价值函数(Critic)的方法。A2C算法在许多强化学习任务中表现优越,因为它能够利用价值函数来减少策略梯度的方差,同时直接优化策略。

A2C算法的核心思想

  • Actor:根据当前策略选择动作。
  • Critic:评估一个状态-动作对的值(通常是使用状态值函数或动作值函数)。
  • 优势函数(Advantage Function):用来衡量某个动作相对于平均水平的好坏,通常定义为A(s,a)=Q(s,a)−V(s)。

A2C算法的伪代码

以下是A2C算法的伪代码:

Initialize policy network (actor) π with parameters θ
Initialize value network (critic) V with parameters w
Initialize learning rates α_θ for policy network and α_w for value network

for each episode do
    Initialize state s
    while state s is not terminal do
        # Actor: select action a according to the current policy π(a|s; θ)
        a = select_action(s, θ)
        
        # Execute action a in the environment, observe reward r and next state s'
        r, s' = environment.step(a)
        
        # Critic: compute the value of the current state V(s; w)
        V_s = V(s, w)
        
        # Critic: compute the value of the next state V(s'; w)
        V_s_prime = V(s', w)
        
        # Compute the TD error (δ)
        δ = r + γ * V_s_prime - V_s
        
        # Critic: update the value network parameters w
        w = w + α_w * δ * ∇_w V(s; w)
        
        # Compute the advantage function A(s, a)
        A = δ
        
        # Actor: update the policy network parameters θ
        θ = θ + α_θ * A * ∇_θ log π(a|s; θ)
        
        # Move to the next state
        s = s'
    end while
end for

解释

  1. 初始化:初始化策略网络(Actor)和价值网络(Critic)的参数,以及它们的学习率。
  2. 循环每个Episode:在每个Episode开始时,初始化状态。
  3. 选择动作:根据当前策略从Actor中选择动作。
  4. 执行动作:在环境中执行动作,并观察奖励和下一个状态。
  5. 计算状态值:用Critic评估当前状态和下一个状态的值。
  6. 计算TD误差:计算时序差分误差(Temporal Difference Error),它是当前奖励加上下一个状态的折扣值与当前状态值的差。
  7. 更新Critic:根据TD误差更新价值网络的参数。
  8. 计算优势函数:使用TD误差计算优势函数。
  9. 更新Actor:根据优势函数更新策略网络的参数。
  10. 更新状态:移动到下一个状态,重复上述步骤,直到Episode结束。

这个伪代码展示了A2C算法的核心步骤,实际实现中可能会有更多细节,如使用折扣因子γ、多个并行环境等。

代码如下:

import argparse

from mindspore import context
from mindspore import dtype as mstype
from mindspore.communication import init

from mindspore_rl.algorithm.a2c import config
from mindspore_rl.algorithm.a2c.a2c_session import A2CSession
from mindspore_rl.algorithm.a2c.a2c_trainer import A2CTrainer

parser = argparse.ArgumentParser(description="MindSpore Reinforcement A2C")
parser.add_argument("--episode", type=int, default=10000, help="total episode numbers.")
parser.add_argument(
    "--device_target",
    type=str,
    default="CPU",
    choices=["CPU", "GPU", "Ascend", "Auto"],
    help="Choose a devioptions.device_targece to run the ac example(Default: Auto).",
)
parser.add_argument(
    "--precision_mode",
    type=str,
    default="fp32",
    choices=["fp32", "fp16"],
    help="Precision mode",
)
parser.add_argument(
    "--env_yaml",
    type=str,
    default="../env_yaml/CartPole-v0.yaml",
    help="Choose an environment yaml to update the a2c example(Default: CartPole-v0.yaml).",
)
parser.add_argument(
    "--algo_yaml",
    type=str,
    default=None,
    help="Choose an algo yaml to update the a2c example(Default: None).",
)
parser.add_argument(
    "--enable_distribute",
    type=bool,
    default=False,
    help="Train in distribute mode (Default: False).",
)
parser.add_argument(
    "--worker_num",
    type=int,
    default=2,
    help="Worker num (Default: 2).",
)
options, _ = parser.parse_known_args()

首先初始化参数,然后我这里用cpu运行:options.device_targe = “CPU”

episode=options.episode
"""Train a2c"""
if options.device_target != "Auto":
    context.set_context(device_target=options.device_target)
if context.get_context("device_target") in ["CPU", "GPU"]:
    context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE)
compute_type = (
    mstype.float32 if options.precision_mode == "fp32" else mstype.float16
)
config.algorithm_config["policy_and_network"]["params"][
    "compute_type"
] = compute_type
if compute_type == mstype.float16 and options.device_target != "Ascend":
    raise ValueError("Fp16 mode is supported by Ascend backend.")
is_distribte = options.enable_distribute
if is_distribte:
    init()
    context.set_context(enable_graph_kernel=False)
    config.deploy_config["worker_num"] = options.worker_num
a2c_session = A2CSession(options.env_yaml, options.algo_yaml, is_distribte)

设置上下文管理器

import sys
import time
from io import StringIO

class RealTimeCaptureAndDisplayOutput(object):
    def __init__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        self.captured_output = StringIO()

    def write(self, text):
        self._original_stdout.write(text)  # 实时打印
        self.captured_output.write(text)   # 保存到缓冲区

    def flush(self):
        self._original_stdout.flush()
        self.captured_output.flush()

    def __enter__(self):
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
episode=10
# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured_new:
    a2c_session.run(class_type=A2CTrainer, episode=episode)
import re
import matplotlib.pyplot as plt

# 原始输出
raw_output = captured_new.captured_output.getvalue()

# 使用正则表达式从输出中提取loss和rewards
loss_pattern = r"loss=(\d+\.\d+)"
reward_pattern = r"running_reward=(\d+\.\d+)"
loss_values = [float(match.group(1)) for match in re.finditer(loss_pattern, raw_output)]
reward_values = [float(match.group(1)) for match in re.finditer(reward_pattern, raw_output)]

# 绘制loss曲线
plt.plot(loss_values, label='Loss')
plt.xlabel('Episode')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show()

# 绘制reward曲线
plt.plot(reward_values, label='Rewards')
plt.xlabel('Episode')
plt.ylabel('Rewards')
plt.title('Rewards Curve')
plt.legend()
plt.show()

展示结果:
image.png

image.png

下面我将详细解释你提供的 MindSpore A2C 算法训练配置参数的含义:

Actor 配置

'actor': {
  'number': 1,
  'type': mindspore_rl.algorithm.a2c.a2c.A2CActor,
  'params': {
    'collect_environment': PyFuncWrapper<
       (_envs): GymEnvironment<>
     >,
   'eval_environment': PyFuncWrapper<
     (_envs): GymEnvironment<>
     >,
   'replay_buffer': None,
   'a2c_net': ActorCriticNet<
     (common): Dense<input_channels=4, output_channels=128, has_bias=True>
     (actor): Dense<input_channels=128, output_channels=2, has_bias=True>
     (critic): Dense<input_channels=128, output_channels=1, has_bias=True>
     (relu): LeakyReLU<>
     >},
  'policies': [],
  'networks': ['a2c_net']
}
  • number: Actor 的实例数量,这里设置为1,表示使用一个 Actor 实例。
  • type: Actor 的类型,这里使用 mindspore_rl.algorithm.a2c.a2c.A2CActor
  • params: Actor 的参数配置。
    • collect_environment 和 eval_environment: 使用 PyFuncWrapper 包装的 GymEnvironment,用于数据收集和评估环境。
    • replay_buffer: 设置为 None,表示不使用经验回放缓冲区。
    • a2c_net: Actor-Critic 网络,包含一个公共层、一个 Actor 层和一个 Critic 层,以及一个 Leaky ReLU 激活函数。
  • policies 和 networks: Actor 关联的策略和网络,这里主要是 a2c_net

Learner 配置

'learner': {
  'number': 1,
  'type': mindspore_rl.algorithm.a2c.a2c.A2CLearner,
  'params': {
    'gamma': 0.99,
    'state_space_dim': 4,
    'action_space_dim': 2,
    'a2c_net': ActorCriticNet<
      (common): Dense<input_channels=4, output_channels=128, has_bias=True>
      (actor): Dense<input_channels=128, output_channels=2, has_bias=True>
      (critic): Dense<input_channels=128, output_channels=1, has_bias=True>
      (relu): LeakyReLU<>
    >,
    'a2c_net_train': TrainOneStepCell<
      (network): Loss<
        (a2c_net): ActorCriticNet<
          (common): Dense<input_channels=4, output_channels=128, has_bias=True>
          (actor): Dense<input_channels=128, output_channels=2, has_bias=True>
          (critic): Dense<input_channels=128, output_channels=1, has_bias=True>
          (relu): LeakyReLU<>
        >
        (smoothl1_loss): SmoothL1Loss<>
      >
      (optimizer): Adam<>
      (grad_reducer): Identity<>
    >
  },
  'networks': ['a2c_net_train', 'a2c_net']
}
  • number: Learner 的实例数量,这里设置为1,表示使用一个 Learner 实例。
  • type: Learner 的类型,这里使用 mindspore_rl.algorithm.a2c.a2c.A2CLearner
  • params: Learner 的参数配置。
    • gamma: 折扣因子,用于未来奖励的折扣计算。
    • state_space_dim: 状态空间的维度,这里为4。
    • action_space_dim: 动作空间的维度,这里为2。
    • a2c_net: Actor-Critic 网络定义,与 Actor 中相同。
    • a2c_net_train: 用于训练的网络,包含损失函数(SmoothL1Loss)、优化器(Adam)和梯度缩减器(Identity)。
  • networks: Learner 关联的网络,包括 a2c_net_train 和 a2c_net

Policy and Network 配置

'policy_and_network': {
  'type': mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork,
  'params': {
    'lr': 0.01,
    'state_space_dim': 4,
    'action_space_dim': 2,
    'hidden_size': 128,
    'gamma': 0.99,
    'compute_type': mindspore.float32,
    'environment_config': {
      'id': 'CartPole-v0',
      'entry_point': 'gym.envs.classic_control:CartPoleEnv',
      'reward_threshold': 195.0,
      'nondeterministic': False,
      'max_episode_steps': 200,
      '_kwargs': {},
      '_env_name': 'CartPole'
    }
  }
}
  • type: 策略和网络的类型,这里使用 mindspore_rl.algorithm.a2c.a2c.A2CPolicyAndNetwork
  • params: 策略和网络的参数配置。
    • lr: 学习率,这里为0.01。
    • state_space_dim 和 action_space_dim: 状态和动作空间的维度。
    • hidden_size: 隐藏层的大小,这里为128。
    • gamma: 折扣因子。
    • compute_type: 计算类型,这里为 mindspore.float32
    • environment_config: 环境配置,包括环境 ID、入口、奖励阈值、最大步数等。

Collect Environment 配置

'collect_environment': {
  'number': 1,
  'type': mindspore_rl.environment.gym_environment.GymEnvironment,
  'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],
  'params': {
    'GymEnvironment': {
      'name': 'CartPole-v0',
      'seed': 42
    },
    'name': 'CartPole-v0'
  }
}
  • number: 环境实例数量,这里为1。
  • type: 环境的类型,这里使用 mindspore_rl.environment.gym_environment.GymEnvironment
  • wrappers: 环境使用的包装器,这里是 PyFuncWrapper
  • params: 环境的参数配置,包括环境名称 CartPole-v0 和随机种子 42

Eval Environment 配置

'eval_environment': {
  'number': 1,
  'type': mindspore_rl.environment.gym_environment.GymEnvironment,
  'wrappers': [mindspore_rl.environment.pyfunc_wrapper.PyFuncWrapper],
  'params': {
    'GymEnvironment': {
      'name': 'CartPole-v0',
      'seed': 42
    },
    'name': 'CartPole-v0'
  }
}
  • 配置与 collect_environment 类似,用于评估模型性能。

总结一下,这些配置定义了 Actor-Critic 算法在 MindSpore 框架中的具体实现,包括 Actor 和 Learner 的设置、策略和网络的参数,以及训练和评估环境的配置。这个还是比较基础的。

点击关注,第一时间了解华为云新鲜技术~

 

与一文教你在MindSpore中实现A2C算法训练相似的内容:

一文教你在MindSpore中实现A2C算法训练

文中的配置定义了 Actor-Critic 算法在 MindSpore 框架中的具体实现,包括 Actor 和 Learner 的设置、策略和网络的参数,以及训练和评估环境的配置。

这是你没见过的MindSpore 2.0.0 for Windows GPU版

摘要:一文带你看看MindSpore 2.0.0 for Windows GPU版。 本文分享自华为云社区《MindSpore 2.0.0 for Windows GPU泄漏版尝鲜》,作者:张辉 。 在看了MindSpore架构师王磊老师的帖子( https://zhuanlan.zhihu.com

一文教你理解Kafka offset

日常开发中,相信大家都对 Kafka 有所耳闻,Kafka 作为一个分布式的流处理平台,一般用来存储和传输大量的消息数据。在 Kafka 中有三个重要概念,分别是 topic、partition 和 offset。 topic 是 kafka 中的消息以主题为单位进行归类的逻辑概念,生产者负责将消息

一文教你如何调用Ascend C算子

本文分享自华为云社区《一文教你如何调用Ascend C算子》,作者: 昇腾CANN。 Ascend C是CANN针对算子开发场景推出的编程语言,原生支持C和C++标准规范,兼具开发效率和运行性能。基于Ascend C编写的算子程序,通过编译器编译和运行时调度,运行在昇腾AI处理器上。使用Ascend

netty系列之: 在netty中使用 tls 协议请求 DNS 服务器

简介 在前面的文章中我们讲过了如何在netty中构造客户端分别使用tcp和udp协议向DNS服务器请求消息。在请求的过程中并没有进行消息的加密,所以这种请求是不安全的。 那么有同学会问了,就是请求解析一个域名的IP地址而已,还需要安全通讯吗? 事实上,不加密的DNS查询消息是很危险的,如果你在访问一

Docker 中的 .NET 异常了怎么抓 Dump

## 一:背景 ### 1. 讲故事 有很多朋友跟我说,在 Windows 上看过你文章知道了怎么抓 Crash, CPU爆高,内存暴涨 等各种Dump,为什么你没有写在 Docker 中如何抓的相关文章呢?瞧不上吗? 哈哈,在DUMP的分析旅程中,跑在 Docker 中的 .NET 占比真的不多,

4A 安全之授权:编程的门禁,你能解开吗?

概述 在安全管理系统里面,授权(Authorization)的概念常常是和认证(Authentication)、账号(Account)和审计(Audit)一起出现的,并称之为 4A。就像上一文章提到的,对于安全模块的实现,最好都遵循行业标准和最佳实践,授权也不例外。 作为安全系统的一部分,授权的职责

[转帖]怎么查看Linux服务器硬件信息,这些命令告诉你

https://zhuanlan.zhihu.com/p/144368206 Linux服务器配置文档找不到,你还在为查询Linux服务器硬件信息发愁吗?学会这些命令,让你轻松查看Linux服务器的CPU,内存,硬盘,SN序列号等信息,根本就不用去机房。 一、查看CPU信息 CPU信息常常包括查看C

文心一言,通营销之学,成一家之言,百度人工智能AI大数据模型文心一言Python3.10接入

“文心”取自《文心雕龙》一书的开篇,作者刘勰在书中引述了一个古代典故:春秋时期,鲁国有一位名叫孔文子的大夫,他在学问上非常有造诣,但是他的儿子却不学无术,孔文子非常痛心。 一天,孔文子在山上遇到了一位神仙,神仙告诉他:“你的儿子之所以不学无术,是因为你没有给他灌输文心,让他懂得文学的魅力和意义。”孔

.NET性能系列文章一:.NET7的性能改进

这些方法在.NET7中变得更快 照片来自 CHUTTERSNAP 的 Unsplash 欢迎阅读.NET性能系列的第一章。这一系列的特点是对.NET世界中许多不同的主题进行研究、比较性能。正如标题所说的那样,本章节在于.NET7中的性能改进。你将看到哪种方法是实现特定功能最快的方法,以及大量的技巧和