强化学习视角下的杆平衡问题

发布于 2024-12-25
更新于 2025-01-22

本文摘要:首先回顾一下贝尔曼方程,然后介绍CartPole问题,为便于理解CartPole问题,我们先从经典力学的角度分析,然后再通过强化学习方法解决。

贝尔曼方程

在强化学习中,我们希望找到一个最优策略,使得智能体在环境中能够获得最大的长期累积奖励。

一旦智能体确定了某个策略,那么该策略的价值函数就可以对每个状态或“状态—动作”二元组给出对应的期望回报值。最优价值函数对每个状态或“状态—动作”二元组给出了所有策略中最大的期望回报值。对于给定的MDP(马尔可夫决策过程),尽管状态或“状态—动作”二元组对应的最优价值函数是唯一的,但最优策略可能会有好多个。在最优价值函数的基础上,通过贪心算法得到的策略肯定是一个最优策略。

贝尔曼方程(Bellman Equation)是强化学习中的核心概念之一,用于描述最优策略下的状态价值(或状态-动作价值)。它建立在马尔可夫过程和动态规划的思想上推导得出的。

对于不同的问题,贝尔曼方程有两种常见的形式,分别是状态价值函数状态-动作价值函数 。在实际的强化学习算法应用中,通常会基于贝尔曼方程进行迭代计算来求解价值函数。

状态价值函数 $V(s)$: 某一状态的价值等于在该状态采取某一动作所获得的即时奖励,加上进入下一个状态后未来所有奖励的折扣总和。即从某状态 $s$ 开始的预期总奖励可以分解为:

  • 当前的即时奖励 $ R_t$。
  • 从下一状态 $ S_{t+1}$ 开始的未来奖励的折扣总和。

因此,有:

$$ V^\pi(s) = \mathbb{E}_\pi \left[ R_t + \gamma V^\pi(S_{t+1}) \mid S_t = s \right] $$

其中:

  • $\pi$:策略,定义了在每个状态下采取动作的概率分布。
  • $R_t$:当前时间步$t$的即时奖励。
  • $ \gamma$:折扣因子($ 0 \leq \gamma \leq 1$),表示未来奖励的重要程度,对后续的价值进行折扣加权,如果 $\gamma = 0$,智能体就只关心当前的即时奖励;而 $\gamma$ 接近 $1$ 时,智能体更看重长远的奖励情况。
  • $ V^\pi(s)$:表示在策略 $ \pi$ 下,从状态 $ s$ 开始的预期累计奖励。

利用期望的线性性质,这可以写成:

$$ V^\pi(s) = \sum_a \pi(a \mid s) \sum_{s'} P(s' \mid s, a) \left[ R(s, a, s') + \gamma V^\pi(s') \right] $$

其中:

  • $ \pi(a \mid s)$:在状态 $ s$ 下选择动作 $ a$ 的概率。
  • $ P(s' \mid s, a)$:从状态 $ s$ 通过动作 $ a$ 转移到状态 $ s'$ 的概率。
  • $ R(s, a, s')$:即时奖励。

对于最优策略,我们选择使得预期总奖励最大的动作 $a$,因此得到最优状态价值函数:

$$ V^*(s) = \max_a \sum_{s'} P(s' \mid s, a) \left[ R(s, a, s') + \gamma V^*(s') \right] $$

最优状态价值函数 $ V^*(s)$:
当选择使得预期总奖励最大的动作 $a$,策略 $ \pi$ 最优时,状态价值函数成为:

$$ V^*(s) = \max_a \mathbb{E} \left[ R_t + \gamma V^*(S_{t+1}) \mid S_t = s, A_t = a \right] $$

状态价值函数的直观解释是:在给定策略 $\pi$ 下,一个状态 $s$ 的价值等于按照该策略采取一个动作后,立即获得的奖励加上在下一状态 $s_{t + 1}$ 按照相同策略持续行动下去所能获得的价值(经过折扣因子 $\gamma$ 调整后的期望价值)的总和的期望。也就是说,一个状态的价值取决于当下能拿到的奖励以及后续状态的价值,它们共同构成了对当前状态整体价值的衡量。

状态-动作价值函数 $ Q(s, a)$:
Q函数,用 $Q(s,a)$ 表示,它表示在状态 $s$ 下采取动作 $a$ 后,按照最优策略持续行动下去未来能够获得的累积奖励的期望。

$$ Q^\pi(s, a) = \mathbb{E} \left[ R_t + \gamma V^\pi(S_{t+1}) \mid S_t = s, A_t = a \right] $$

最优情况下:

$$ Q^*(s, a) = \mathbb{E} \left[ R_t + \gamma \max_{a'} Q^*(S_{t+1}, a') \mid S_t = s, A_t = a \right] $$

状态-动作价值函数和状态价值函数的贝尔曼方程不同之处:

  • 状态-动作价值函数关注的是特定状态 $s$ 下采取具体动作 $a$ 的价值情况,也就是 $Q(s,a)$。
  • 方程右边在计算下一个状态的价值贡献时,是取 $\max_{a'} Q(s_{t + 1}, a')$,这是因为我们希望找到从下一个状态 $s_{t + 1}$ 出发采取最优动作(能使得Q值最大的动作)后对应的价值,毕竟我们的目标是要找到最优策略,所以要考虑按照最优方式行动所能带来的价值,而不是基于某个既定策略(像前面 $V^{\pi}(s)$ 那样)。

状态-动作价值函数含义是:在状态 $s$ 下执行动作 $a$ 的Q值,等于立即获得的奖励加上按照最优策略在下一状态 $s_{t + 1}$ 采取最优动作所能获得的最大Q值(经过折扣因子 $\gamma$ 调整后的)的期望。它建立了当前状态 - 动作对的价值与下一个状态最优动作价值之间的联系,为通过迭代方式寻找最优Q值从而确定最优策略提供了理论依据。

贝尔曼方程提供了强化学习的递归结构:

  1. 当前的最优值依赖于下一步的最优值。
  2. 智能体通过学习和迭代更新值函数或策略,最终找到最优解。

CartPole问题

前面的文章中我们讨论过双足机器人的线性倒立摆问题,那时是通过纯粹的动力学模型进行解算的,此外还有强化学习的相关入门,有多臂赌博机井字棋等问题,涉及马尔可夫决策过程动态规划、蒙特卡洛方法等。

CartPole问题是强化学习中一个经典的入门任务,是指在环境中有一个小车,有一个小车(Cart),上面竖立着一个杆(Pole),杆的一端铰接在小车的固定点上,杆可以自由摆动,我们需要控制小车的左右移动,使得杆保持竖直,并避免它倒下。小车的运动是一维的,仅x轴运动,环境的描述包括:小车在水平轴上的位置(x)、小车沿水平方向的速度(v)、杆与垂直方向的夹角(θ)、杆旋转的角速度(ω)。智能体的目标是找到一个策略(Policy),控制小车的运动,使得杆始终保持竖立。小车可以做出的动作包括向左推车或向右推车。

传统动力学模型

为便于理解,我们可以先用传统的力学模型来分析 CartPole 问题。问题的核心是通过控制小车的运动来保持杆的平衡,其物理系统可以通过牛顿第二定律建模,设:

  • m₁:小车质量
  • m₂:杆的质量
  • l:杆质心到与小车连接处的距离
  • g:重力加速度
  • x:小车的位置
  • θ:杆的角度(相对于竖直方向)
  • F:作用在小车上的水平力(智能体控制的输出)

为了便于理解,我们首先通过经典力学来分析一下这个问题。分析小车在水平方向上的加速度 $\ddot{x}$: 小车在水平方向上受到智能体施加的外力$F$和杆在固定点对小车施加的水平反作用力 $T_x$,根据牛顿第二定律,小车受到合力是:

$$ \text{合力} = \text{质量} \times \text{加速度} \\ F + T_x = m_1 \ddot{x} \tag{1} $$

再对杆进行分析,杆受到垂直方向的重力 $m_2 g$ 和小车在连接处的拉力,将拉力沿水平和垂直方向分解,水平分量 $T_x$ ,垂直分量 $T_y$。
杆的运动可以分解为绕质心的旋转和质心随小车的移动。
质心在水平方向上的位置为$x + l \sin\theta$,求二阶导,注意到$\theta$本身是时间$t$的导数,加速度为 $\ddot{x} + l \ddot{\theta} \cos\theta - l (\dot{\theta})^2 \sin\theta$,同理杆在垂直方向上的位置为 $ l \cos\theta$ ,加速度为 $-l \ddot{\theta} \sin\theta - l (\dot{\theta})^2 \cos\theta$

根据牛顿第二定律,杆在水平方向上

$$ T_x = m_2 \left(\ddot{x} + l \ddot{\theta} \cos\theta - l (\dot{\theta})^2 \sin\theta\right) \tag{2} $$

垂直方向上

$$ T_y - m_2 g = m_2 \left(-l \ddot{\theta} \sin\theta - l (\dot{\theta})^2 \cos\theta\right) \tag{3} $$

杆绕固定点的旋转满足角动量定理:

$$ \text{力矩} = \text{惯性矩} \times \text{角加速度} $$

杆的惯性矩是 $\frac{1}{3} m_2 l^2$。杆受到的力矩是 $T_x l$。所以:

$$ T_x l = \frac{1}{3} m_2 l^2 \ddot{\theta} \tag{4} $$

联立上面的四个关键方程,消去系统内力 $T_x$、$T_y$,可以得到小车的加速度 $\ddot{x}$ 和杆的角加速度 $\ddot{\theta}$ 的表达式。 小车的水平运动方程为:

$$ (m_1 + m_2) \ddot{x} + m_2 l \ddot{\theta} \cos\theta - m_2 l (\dot{\theta})^2 \sin\theta = F $$

杆的旋转运动方程为:

$$ l \ddot{\theta} - g \sin\theta + \ddot{x} \cos\theta = 0 $$

通过代数推导并消去 $\ddot{x}$ 和 $\ddot{\theta}$,可以得到小车系统的完整动力学描述。

上面的推导对下面的强化学习内容没什么用,只是利用旧的知识便于理解。

Q-learning解决CartPole问题

CartPole 问题可以被建模为一个马尔可夫决策过程(Markov Decision Process, MDP),状态由四个连续变量描述:

小车位置$x$,小车速度$\dot{x}$,杆的角度$\theta$,杆的杆的角速度 $\dot{\theta}$,有$s = [x, \dot{x}, \theta, \dot{\theta}]$。

动作空间是离散的:$a \in {0, 1}$,分别对应于:向左施加力和向右施加力。

奖励函数:

  • 如果杆仍然保持竖直,奖励 $r = 1$。
  • 如果杆倒下(超出一定角度 $\theta$ 或小车超出轨道范围),回合结束,且没有额外奖励。

Q-learning 是一种基于值函数的强化学习算法,其目标是学习一个最优的Q值函数 $Q(s, a)$,表示在状态 $s$ 下采取动作 $a$ 后,未来累积奖励的期望值。
Q-learning 的核心是 Bellman 方程,Q值的更新公式如下:

$$ Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right] $$
  • $s$:当前状态。
  • $a$:当前动作。
  • $r$:执行动作 $a$ 后获得的即时奖励。
  • $s'$:执行动作 $a$ 后的下一个状态。
  • $\alpha$:学习率,控制更新步长。
  • $\gamma$:折扣因子,衡量未来奖励的重要性。
  • $\max_{a'} Q(s', a')$:在下一个状态 $s'$ 下,选择使Q值最大的动作 $a'$。

在训练过程中,动作选择使用 $\epsilon$-贪婪策略:以概率 $\epsilon$ 随机选择动作(探索),以概率 $1-\epsilon$ 选择当前 Q值最大的动作(利用)。
在 Q-learning 中,智能体通过一个 Q-table 来记录每个状态-动作对的价值,而在 DQN 中,智能体会用深度神经网络来逼近 Q 值,从而处理更复杂的环境。

  1. 初始化Q值表 $Q(s, a)$ 为任意值(通常为0)。
  2. 在每个时间步:

    • 根据当前状态 $s$,使用策略(如 $\epsilon$-贪婪策略)选择动作 $a$。
    • 执行动作 $a$,观察奖励 $r$ 和下一个状态 $s'$。
    • 更新Q值:$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$。
    • 更新状态:$s \leftarrow s'$。
  3. 重复上述步骤,直到Q值收敛。

我们利用gym库来解决CartPole问题,文档地址
CartPole 的状态是连续的(小车位置、速度、杆子角度、角速度),需要将其离散化为有限的区间。

bash
pip install gym numpy matplotlib

Python 代码实现:

python
import gym
import numpy as np
import matplotlib.pyplot as plt

# 创建 CartPole 环境
env = gym.make('CartPole-v1')

# 超参数
state_size = env.observation_space.shape[0]  # 状态空间维度
action_size = env.action_space.n            # 动作空间维度
alpha = 0.1                                 # 学习率
gamma = 0.99                                # 折扣因子
epsilon = 1.0                               # 初始探索率
epsilon_min = 0.01                          # 最小探索率
epsilon_decay = 0.995                       # 探索率衰减率
n_episodes = 1000                           # 训练回合数

# 离散化状态空间
def discretize_state(state, bins):
    """
    将连续状态离散化为离散区间。
    :param state: 当前状态(NumPy 数组)
    :param bins: 每个状态维度的离散化区间
    :return: 离散化后的状态(元组)
    """
    return tuple(np.digitize(state[i], bins[i]) for i in range(len(state)))

# 初始化 Q 值表
q_table = np.zeros((10, 10, 10, 10, action_size))  # 假设每个状态维度离散化为 10 个区间

# 定义状态离散化的区间
state_bins = [
    np.linspace(-4.8, 4.8, 10),  # 小车位置
    np.linspace(-5, 5, 10),      # 小车速度
    np.linspace(-0.418, 0.418, 10),  # 杆子角度
    np.linspace(-5, 5, 10)       # 杆子角速度
]

# 存储每个回合的总奖励,用于可视化
episode_rewards = []

# 训练过程
for episode in range(n_episodes):
    # 重置环境并获取初始状态
    state, _ = env.reset()  # 忽略第二个返回值(空字典)
    # 离散化初始状态
    state = discretize_state(state, state_bins)

    done = False
    total_reward = 0

    while not done:
        # ε-贪婪策略选择动作
        if np.random.rand() < epsilon:
            action = env.action_space.sample()  # 随机探索
        else:
            action = np.argmax(q_table[state])  # 选择最优动作

        # 执行动作
        next_state, reward, done, _, _ = env.step(action)  # 忽略额外的返回值
        # 离散化下一个状态
        next_state = discretize_state(next_state, state_bins)

        # 更新 Q 值
        old_value = q_table[state][action]
        next_max = np.max(q_table[next_state])
        new_value = old_value + alpha * (reward + gamma * next_max - old_value)
        q_table[state][action] = new_value

        # 更新状态
        state = next_state
        total_reward += reward

    # 衰减探索率
    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

    # 记录当前回合的总奖励
    episode_rewards.append(total_reward)

    # 打印训练进度
    if (episode + 1) % 100 == 0:
        print(f"Episode: {episode + 1}, Total Reward: {total_reward}, Epsilon: {epsilon:.2f}")

# 训练结束后,绘制奖励曲线
plt.plot(episode_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Q-learning: Reward per Episode')
plt.show()

# 测试训练结果
state, _ = env.reset()  # 忽略第二个返回值(空字典)
state = discretize_state(state, state_bins)
done = False
total_reward = 0

while not done:
    action = np.argmax(q_table[state])
    next_state, reward, done, _, _ = env.step(action)  # 忽略额外的返回值
    next_state = discretize_state(next_state, state_bins)
    state = next_state
    total_reward += reward
    env.render()

print(f"Test Total Reward: {total_reward}")
env.close()

代码详解

环境重置env.reset()

将环境恢复到初始状态,并返回初始状态的信息。返回值通常是一个元组,具体格式取决于环境的实现。

对于 CartPole-v1 环境,env.reset() 的返回值如下:

python
(state, info)
  • state:环境的初始状态,通常是一个 NumPy 数组。对于 CartPole-v1,状态数组包含 4 个浮点数,分别表示:
    • 小车的位置(position
    • 小车的速度(velocity
    • 杆子的角度(angle
    • 杆子的角速度(angular velocity
  • info:额外的环境信息,通常是一个字典。对于 CartPole-v1info 是一个空字典 {}

以下是一个简单的示例,展示如何使用 env.reset()

python
import gym

# 创建 CartPole 环境
env = gym.make('CartPole-v1')

# 重置环境,获取初始状态
initial_state, info = env.reset()

print("Initial State:", initial_state)
print("Info:", info)

# 关闭环境
env.close()

输出

plaintext
Initial State: [ 0.01622422 -0.03802749 -0.03126878 -0.00195737]
Info: {}
  • 状态格式:不同环境的 state 格式可能不同。例如,CartPole-v1 的状态是一个包含 4 个浮点数的数组,而其他环境可能返回更复杂的数据结构。
  • 信息字典info 字典通常包含环境的额外信息,例如调试信息或统计信息。对于 CartPole-v1info 是一个空字典,但其他环境可能会返回有用的信息。
  • 随机性env.reset() 的初始状态通常是随机的,以确保智能体能够学习到在不同初始条件下的策略。

离散化discretize_state(state, bins)函数

作用是将连续的状态空间离散化为有限的离散区间。Q-learning 依赖于一个离散的 Q 值表来存储和更新状态-动作值。

  • 输入:连续状态 state 和离散化区间 bins
  • 输出:离散化后的状态,表示为一个元组,每个元素对应状态的一个维度所在的区间索引。

参数说明

  • state:当前的状态,通常是一个 NumPy 数组。对于 CartPole-v1,状态数组包含 4 个浮点数:

    • 小车的位置(position
    • 小车的速度(velocity
    • 杆子的角度(angle
    • 杆子的角速度(angular velocity
    • bins:一个列表,包含每个状态维度的离散化区间。例如:
python
bins = [
        np.linspace(-4.8, 4.8, 10),  # 小车位置的区间
        np.linspace(-5, 5, 10),      # 小车速度的区间
        np.linspace(-0.418, 0.418, 10),  # 杆子角度的区间
        np.linspace(-5, 5, 10)       # 杆子角速度的区间
    ]

每个维度的区间被均匀划分为若干个子区间(例如 10 个)。

实现原理

  • 对于状态数组中的每个值,使用 np.digitize 函数将其映射到对应的离散区间。
  • np.digitize 的作用是找到一个值在给定区间中的索引。例如:
python
np.digitize(0.5, [0, 1, 2])  # 返回 1,因为 0.5 在 [0, 1) 区间
  • 最终,离散化后的状态是一个元组,表示每个状态维度所在的区间索引。

示例
假设状态和离散化区间如下:

python
state = np.array([0.1, -0.2, 0.05, 0.3])
bins = [
    np.linspace(-1, 1, 3),  # 划分为 [-1, 0), [0, 1)
    np.linspace(-1, 1, 3),  # 划分为 [-1, 0), [0, 1)
    np.linspace(-1, 1, 3),  # 划分为 [-1, 0), [0, 1)
    np.linspace(-1, 1, 3)   # 划分为 [-1, 0), [0, 1)
]

调用 discretize_state(state, bins)

python
discretized_state = discretize_state(state, bins)
print(discretized_state)

输出

plaintext
(2, 1, 2, 2)
  • state[0] = 0.1 落在 [0, 1) 区间,索引为 2
  • state[1] = -0.2 落在 [-1, 0) 区间,索引为 1
  • state[2] = 0.05 落在 [0, 1) 区间,索引为 2
  • state[3] = 0.3 落在 [0, 1) 区间,索引为 2

注意边界处理np.digitize 默认将值映射到左闭右开区间。如果需要处理边界值,可以调整区间范围或使用 right=True 参数。

action = np.argmax(q_table[state])

代码action = np.argmax(q_table[state])作用是从 Q 值表中选择当前状态下具有最高 Q 值的动作。
Q 值表 (q_table) 是一个多维数组,存储了每个状态-动作对的 Q 值,对于 CartPole-v1,Q 值表的形状为 (10, 10, 10, 10, 2),其中前 4 个维度表示离散化后的状态空间(每个维度划分为 10 个区间),最后一个维度表示动作空间(2 个动作:向左或向右)。

state 是当前状态的离散化表示,通常是一个元组,例如 (2, 1, 2, 2),这个元组表示当前状态在每个维度上的区间索引。
q_table[state] 是从 Q 值表中提取的当前状态对应的 Q 值向量。对于 CartPole-v1q_table[state] 是一个长度为 2 的数组,表示在当前状态下,每个动作的 Q 值。例如:

python
q_table[state] = [0.5, 0.8]

这表示:动作 0(向左)的 Q 值为 0.5。动作 1(向右)的 Q 值为 0.8

np.argmax 是一个 NumPy 函数,用于返回数组中最大值所在的索引。

next_state, reward, done, _, _ = env.step(action)

作用是执行智能体选择的动作,并获取环境的反馈信息。env.step(action)gym 库中用于执行动作的方法。它接受一个动作 action 作为输入,并返回以下信息:

  • next_state
    • 执行动作后的下一个状态,通常是一个 NumPy 数组。
    • 对于 CartPole-v1next_state 包含 4 个浮点数:
      • 小车的位置(position
      • 小车的速度(velocity
      • 杆子的角度(angle
      • 杆子的角速度(angular velocity
  • reward
    • 执行动作后获得的即时奖励。
    • 对于 CartPole-v1,每保持杆子直立一个时间步,奖励为 +1
  • done
    • 一个布尔值,表示当前回合是否结束。
    • 如果 doneTrue,表示回合结束(例如,杆子倒下或小车超出边界)。
  • info
    • 一个字典,包含额外的调试信息。
    • 对于 CartPole-v1info 通常是一个空字典 {}
  • _
    • 在某些 gym 版本中,env.step() 可能返回第五个值,通常忽略。

在 Q-learning 中,env.step(action) 是智能体与环境交互的关键步骤。它的作用包括:

  1. 执行动作:智能体根据当前状态选择一个动作 action,并通过 env.step(action) 执行该动作。
  2. 获取反馈:环境返回执行动作后的下一个状态 next_state、即时奖励 reward 以及回合是否结束的标志 done
  3. 更新 Q 值表:使用 next_statereward 更新 Q 值表。
  4. 判断回合结束:如果 doneTrue,则结束当前回合,并调用 env.reset() 开始下一个回合。

深度Q网络(DQN)

DQN 是一种基于神经网络的强化学习算法,其核心是使用神经网络逼近 Q函数。

如果状态空间较大或连续(如 CartPole),我们无法用表格存储 $Q(s, a)$。DQN 使用神经网络近似 Q函数,其目标是最小化以下损失函数:

$$ L(\theta) = \mathbb{E}\left[ \left( y - Q(s, a; \theta) \right)^2 \right] $$

其中目标值 $y$ 是:

$$ y = r + \gamma \max_{a'} Q(s', a'; \theta^-) $$
  • $\theta$ 是当前网络的参数。
  • $\theta^-$ 是目标网络的参数(固定一段时间再更新)。

通过梯度下降更新网络参数,可以逐步逼近最优的 Q函数。

python
import numpy as np
import tensorflow as tf
import gym
import random
from collections import deque
import matplotlib.pyplot as plt

# 初始化环境
env = gym.make('CartPole-v1')

# 超参数
state_size = env.observation_space.shape[0]  # 状态空间维度
action_size = env.action_space.n  # 动作空间维度
gamma = 0.99  # 折扣因子
learning_rate = 0.001  # 学习率
batch_size = 64  # 经验回放的批量大小
max_episodes = 500  # 最大训练回合数
max_steps = 200  # 每回合最大步数
memory_size = 2000  # 经验回放池大小
epsilon = 1.0  # 初始探索概率
epsilon_decay = 0.995  # 探索概率衰减因子
epsilon_min = 0.01  # 最低探索概率
update_target_frequency = 5  # 更新目标网络的频率

# 经验回放池
memory = deque(maxlen=memory_size)

# 创建 Q 网络
def build_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(24, activation='relu', input_shape=(state_size,)),
        tf.keras.layers.Dense(24, activation='relu'),
        tf.keras.layers.Dense(action_size, activation='linear')
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss='mse')
    return model

# 主网络和目标网络
q_network = build_model()
target_network = build_model()
target_network.set_weights(q_network.get_weights())

# 存储奖励
rewards_per_episode = []

def replay_experience():
    """经验回放,用于训练 Q 网络。"""
    if len(memory) < batch_size:
        return

    batch = random.sample(memory, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)

    states = np.array(states)
    next_states = np.array(next_states)

    # 预测当前和下一个状态的 Q 值
    q_values = q_network.predict(states)
    q_next_values = target_network.predict(next_states)

    for i in range(batch_size):
        if dones[i]:
            q_values[i][actions[i]] = rewards[i]
        else:
            q_values[i][actions[i]] = rewards[i] + gamma * np.max(q_next_values[i])

    # 更新 Q 网络
    q_network.fit(states, q_values, epochs=1, verbose=0)

# 主循环
for episode in range(max_episodes):
    state = env.reset()[0]  # 获取初始状态
    total_reward = 0

    for step in range(max_steps):
        # 根据 epsilon-greedy 策略选择动作
        if np.random.rand() < epsilon:
            action = env.action_space.sample()
        else:
            q_values = q_network.predict(np.expand_dims(state, axis=0))
            action = np.argmax(q_values[0])

        # 执行动作
        next_state, reward, done, _ = env.step(action)
        next_state = next_state

        # 存储经验
        memory.append((state, action, reward, next_state, done))

        state = next_state
        total_reward += reward

        if done:
            break

        # 经验回放
        replay_experience()

    # 每隔一定频率更新目标网络
    if episode % update_target_frequency == 0:
        target_network.set_weights(q_network.get_weights())

    # 减少 epsilon
    epsilon = max(epsilon_min, epsilon * epsilon_decay)

    # 记录奖励
    rewards_per_episode.append(total_reward)

    if (episode + 1) % 10 == 0:
        print(f"Episode {episode + 1}: Average Reward: {np.mean(rewards_per_episode[-10:])}")

# 绘制学习曲线
plt.plot(rewards_per_episode)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Learning Curve for CartPole using DQN')
plt.show()

# 关闭环境
env.close()

最好在有较好GPU配置的环境下运行该程序。