跳到主要内容

SAC 算法实战

在实战中,我们主要讲解 SAC 算法的第二个版本,即自动调节温度因子的版本。该版本的伪代码如下图所示,整个训练过程相对来说还是比较简洁的,只是需要额外定义一些网络以及用来调节温度因子的参数等。

SAC 算法伪代码

定义模型

首先我们定义值网络和策略网络,它们的结构跟前面介绍的 Actor-Critic 类算法其实是一样的,如下方代码清单所示。

代码清单:值网络与策略网络
class ValueNet(nn.Module):
    """State-value network V(s): two ReLU hidden layers and a scalar output head.

    Args:
        state_dim: number of input state features.
        hidden_dim: width of both hidden layers.
        init_w: the output layer's weight and bias are drawn from
            U(-init_w, init_w) so the head starts with small parameters;
            the hidden layers keep PyTorch's default initialisation.
    """

    def __init__(self, state_dim, hidden_dim, init_w=3e-3):
        super().__init__()
        self.linear1 = nn.Linear(state_dim, hidden_dim)   # input layer
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)  # hidden layer
        self.linear3 = nn.Linear(hidden_dim, 1)           # scalar value head
        for param in (self.linear3.weight, self.linear3.bias):
            nn.init.uniform_(param, -init_w, init_w)

    def forward(self, state):
        """Return V(state); the last dimension of the result has size 1."""
        hidden = F.relu(self.linear2(F.relu(self.linear1(state))))
        return self.linear3(hidden)
class PolicyNet(nn.Module):
    """Tanh-squashed diagonal-Gaussian policy network used by SAC.

    The network maps a state to the mean and (clamped) log standard deviation
    of a Gaussian over pre-squash actions ``z``; actions are ``tanh(z)``.

    Args:
        state_dim: number of input state features.
        action_dim: number of action dimensions.
        hidden_dim: width of both hidden layers.
        init_w: both output heads' weights and biases are drawn from
            U(-init_w, init_w).
        log_std_min: lower clamp for the predicted log standard deviation.
        log_std_max: upper clamp for the predicted log standard deviation.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3, log_std_min=-20, log_std_max=2):
        super(PolicyNet, self).__init__()
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

        self.linear1 = nn.Linear(state_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)

        # Small uniform initialisation of the two output heads.
        self.mean_linear = nn.Linear(hidden_dim, action_dim)
        self.mean_linear.weight.data.uniform_(-init_w, init_w)
        self.mean_linear.bias.data.uniform_(-init_w, init_w)

        self.log_std_linear = nn.Linear(hidden_dim, action_dim)
        self.log_std_linear.weight.data.uniform_(-init_w, init_w)
        self.log_std_linear.bias.data.uniform_(-init_w, init_w)

    def forward(self, state):
        """Return ``(mean, log_std)``, each with last dimension ``action_dim``.

        ``log_std`` is clamped to ``[log_std_min, log_std_max]`` so that
        ``exp(log_std)`` stays in a bounded range.
        """
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))

        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max)

        return mean, log_std

    def evaluate(self, state, epsilon=1e-6):
        """Sample a differentiable action and its log-probability (for training).

        Args:
            state: batch of states accepted by ``forward``.
            epsilon: small constant added inside the log of the tanh
                correction term to avoid ``log(0)`` when ``|action|`` -> 1.

        Returns:
            ``(action, log_prob, z, mean, log_std)`` where ``action = tanh(z)``
            and ``log_prob`` is summed over the last (action) dimension with
            ``keepdim=True``.
        """
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        # rsample() draws z = mean + std * eps (reparameterisation trick), so a
        # loss computed from `action` (e.g. the Q(s, a) term of the SAC policy
        # loss) back-propagates into the policy parameters. sample() would
        # detach z and cut that gradient path.
        z = normal.rsample()
        action = torch.tanh(z)
        # Change of variables for the tanh squash:
        # log pi(a) = log N(z) - log(1 - tanh(z)^2), summed over action dims.
        log_prob = normal.log_prob(z) - torch.log(1 - action.pow(2) + epsilon)
        log_prob = log_prob.sum(-1, keepdim=True)

        return action, log_prob, z, mean, log_std

    def get_action(self, state):
        """Sample one action for environment interaction, without gradients.

        Args:
            state: a single unbatched state (array-like); a batch dimension is
                added internally and removed from the result.

        Returns:
            The sampled squashed action as a NumPy array (shape
            ``(action_dim,)`` for a 1-D state).
        """
        # Build the input on the same device as the parameters so the policy
        # also works after being moved to a GPU.
        device = self.mean_linear.weight.device
        state = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        with torch.no_grad():
            mean, log_std = self.forward(state)
            normal = Normal(mean, log_std.exp())
            action = torch.tanh(normal.sample())
        return action.cpu().numpy()[0]

然后再额外定义一个 Soft Q 网络,如下方代码清单所示。

代码清单:Soft Q 网络
class SoftQNet(nn.Module):
    """Soft Q-network Q(s, a).

    State and action are concatenated along dimension 1 and passed through
    two ReLU hidden layers and a scalar output head.

    Args:
        state_dim: number of state features.
        action_dim: number of action features.
        hidden_dim: width of both hidden layers.
        init_w: the output layer's weight and bias are drawn from
            U(-init_w, init_w); the hidden layers keep PyTorch's default init.
    """

    def __init__(self, state_dim, action_dim, hidden_dim, init_w=3e-3):
        super().__init__()
        in_features = state_dim + action_dim
        self.linear1 = nn.Linear(in_features, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.linear3 = nn.Linear(hidden_dim, 1)
        for param in (self.linear3.weight, self.linear3.bias):
            nn.init.uniform_(param, -init_w, init_w)

    def forward(self, state, action):
        """Return Q(state, action); the last dimension of the result has size 1."""
        features = torch.cat((state, action), dim=1)
        for layer in (self.linear1, self.linear2):
            features = F.relu(layer(features))
        return self.linear3(features)