跳到主要内容

Dueling DQN 算法实战

算法主要是改了网络结构,其他地方跟 是一模一样的,如代码清单 1 所示。

代码清单 1 网络结构
class DuelingQNetwork(nn.Module):
def __init__(self, state_dim, action_dim,hidden_dim=128):
super(DuelingQNetwork, self).__init__()
# 隐藏层
self.hidden_layer = nn.Sequential(
nn.Linear(state_dim, hidden_dim),
nn.ReLU()
)
# 优势层
self.advantage_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, action_dim)
)
# 价值层
self.value_layer = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, 1)
)

def forward(self, state):
x = self.hidden_layer(state)
advantage = self.advantage_layer(x)
value = self.value_layer(x)
return value + advantage - advantage.mean() # Q(s,a) = V(s) + A(s,a) - mean(A(s,a))

最后我们展示一下它在 环境下的训练结果,如图 1 所示,完整的代码同样可以参考本书的代码仓库。

图 1 环境 算法训练曲线

由于环境比较简单,暂时还看不出来 算法的优势,但是在复杂的环境下,比如 游戏中, 算法的效果就会比 算法好很多,读者可以在 仓库中找到更复杂环境下的训练结果便于更好地进行对比。