.. _multi_tower: 多塔结构 ======== 在多目标建模领域,如 MMoE 所展现的那样,专家网络(Expert)承担着挖掘不同任务之间共享的底层特征表示的重任,而门控网络(Gate)则灵活地动态分配专家权重,依据不同任务的特性需求进行精准适配。这种由 “共享专家 + 任务专属门控” 构成的架构,与生俱来地具备处理共性(共享专家所提取的通用特征)与特性(门控网络赋予的特定权重)的卓越能力。 多场景建模与多任务学习类似但关注点不同:多任务学习处理相同场景/分布下的不同任务(如单样本同时预估CTR、CVR),而多场景建模处理不同场景/分布下的相同任务(如不同场景预估相同CTR)。前者是对于一条样本预估多个不同的目标值,后者是对于不同的样本预估相同的目标值。多场景建模若采用独立模型,会忽视场景共性,导致小场景效果差且资源消耗剧增;若混合样本训练单一模型,则会忽视场景差异,降低预测精度。 .. figure:: ../../img/star_1.png 多目标与多场景建模的差异(图片来自阿里妈妈博客) 本小节将会介绍基于多塔结构建模时,在利用多场景共性的前提下,显示的使用不同场景的信号来捕捉场景的特性。 HMoE ---- 在多任务建模小节中,介绍了MMoE(Mixture-of-Experts)底层通过多专家网络作为多任务的共享特征,顶层对于不同的任务使用门控机制融合专家特征实现不同任务差异化的学习。在多场景建模中HMoE借鉴了MMoE的思路,底层同样适用多专家网络提取提取多个场景的特征作为共享特征,只不过顶层的多个塔不再是多个任务的输出,而是多个场景的输出,HMoE模型结构如下: .. figure:: ../../img/hmoe.png HMoE模型结构 模型的底层使用多个专家抽取多个场景的特征,并通过一组门控网络将多个专家的输出结果进行融合,最后输入给上层不同的场景塔。 .. math:: M(x) = \sum_{i=1}^{K} G_i(x) E_i(x) 原论文中是对于所有场景的塔都使用同一组门控融合后的专家特征,这种方式可以看成是多任务建模中的Shared-Bottom式的特征共享,只不过以多个FCN的融合输出替代了单个FCN的输出。从MMoE的经验来看,如果多个任务之间的相关性较差,底层这种特征硬共享可能会出现负迁移的现象。所以这种方式也不一定就是多场景建模的最优方案,也可以尝试对于不同的场景,使用不同门控融合后的专家特征。如第\ :math:`t`\ 个场景的输入特征表示为\ :math:`M_t(x) = \sum_{i=1}^{K} G_i^t(x) E_i(x)`\ ,最终哪种效果更好可以根据自己的场景做实验得到。 在得到了底层多场景特征之后,模型单场景的最终预估值不是简单的直接使用对应场景Tower打分,而是将多个场景输出打分融合为单个场景的打分。第\ :math:`t`\ 个场景的模型打分表示如下: .. math:: out_t = \sum_{i=1}^{T} W_i(x) S_i(x) 其中\ :math:`W_i(x)`\ 是场景\ :math:`i`\ 的融合权重,原论文中对于不同的场景下打分融合的\ :math:`W`\ 是否共享也未明确说明,但可以根据MMoE的思路,给每个场景都学习一个融合的权重,即第\ :math:`t`\ 个场景的预估值可以表示为:\ :math:`out_t = \sum_{i=1}^{T} W_i^t(x) S_i(x)` 从最终单场景由多个场景打分融合可以看出,对于某个场景\ :math:`t`\ 的样本,HMoE不仅需要计算它在场景\ :math:`t`\ 下的打分,还需要计算它在场景下的打分,计算场景\ :math:`t`\ 最终的打分时,其他场景的打分对\ :math:`t`\ 场景也是有参考价值的。 虽然在前向推理时可以将一条样本预估出不同场景的打分,但是对于某个场景\ :math:`t`\ 的样本来说应该只影响当前场景的参数(主要是场景塔的参数),否则\ :math:`a`\ 场景下的样本直接影响\ :math:`b`\ 场景的参数,很容易导致模型对于场景的感知下降,进而让整个多场景的模型效果变差。因此在计算融合打分时候,需要抑制其他场景打分的梯度回传,最终场景\ :math:`t`\ 的打分表示如下 .. math:: out_t(x) = W_t(x) S_t(x) + \sum_{j=1, j \neq t}^{T} W_j(x) \underbrace{S_j(x)}_{\text{stop gradient}} 不共享融合权重的打分公式为:\ :math:`out_t(x) = W_t^t(x) S_t(x) + \sum_{j=1, j \neq t}^{T} W_j^t(x) \underbrace{S_j(x)}_{\text{stop gradient}}` HMoE核心代码如下,其中包括了是否共享门控和融合打分权重的部分。 .. raw:: latex \diilbookstyleinputcell .. code:: python def build_hmoe_model(feature_columns, num_domains, domain_feature_name, share_gate=False, share_domain_w=False, shared_expert_nums=5, shared_expert_dnn_units=[256,128], gate_dnn_units=[256,128], domain_tower_units=[128,64], domain_weight_units=[128,64], linear_logits=False ): # 构建输入层字典 input_layer_dict = build_input_layer(feature_columns) domain_input = input_layer_dict[domain_feature_name] # 构建特征嵌入表字典 group_embedding_feature_dict = build_group_feature_embedding_table_dict(feature_columns, input_layer_dict, prefix="embedding/") # 连接不同组的嵌入向量作为各个网络的输入 dnn_inputs = concat_group_embedding(group_embedding_feature_dict, 'dnn') # 创建多个专家 expert_output_list = [] for i in range(shared_expert_nums): expert_output = DNNs(shared_expert_dnn_units, name=f"expert_{str(i)}")(dnn_inputs) expert_output_list.append(expert_output) expert_concat = tf.keras.layers.Lambda(lambda x: tf.stack(x, axis=1))(expert_output_list) # (None, expert_num, dims) if share_gate: # 共享Gate domain_tower_input_list = [] gate_output = DNNs(gate_dnn_units, name=f"shared_gates")(dnn_inputs) gate_output = tf.keras.layers.Dense(shared_expert_nums, use_bias=False, activation='softmax', name=f"domain_{i}_softmax")(gate_output) gate_output = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(gate_output) # (None,expert_num, 1) gate_expert_output = tf.keras.layers.Lambda(lambda x: x[0] * x[1])([gate_output, expert_concat]) gate_expert_output = tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=1, keepdims=False))(gate_expert_output) for _ in range(num_domains): domain_tower_input_list.append(gate_expert_output) else: domain_tower_input_list = [] for i in range(num_domains): gate_output = DNNs(gate_dnn_units, name=f"domain_{str(i)}_gates")(dnn_inputs) gate_output = tf.keras.layers.Dense(shared_expert_nums, use_bias=False, activation='softmax', name=f"domain_{i}_softmax")(gate_output) gate_output = tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, axis=-1))(gate_output) # (None,expert_num, 1) gate_expert_output = tf.keras.layers.Lambda(lambda x: x[0] * x[1])([gate_output, expert_concat]) gate_expert_output = tf.keras.layers.Lambda(lambda x: tf.reduce_sum(x, axis=1, keepdims=False))(gate_expert_output) for _ in range(num_domains): domain_tower_input_list.append(gate_expert_output) # 定义domain tower domain_tower_output_list = [] for i in range(num_domains): domain_dnn_input = domain_tower_input_list[i] task_output = DNNs(domain_tower_units)(domain_dnn_input) domain_tower_output_list.append(task_output) # 定义domain权重 domain_weight_list = [] if share_domain_w: # 共享domain权重 domain_weight = DNNs(domain_weight_units)(dnn_inputs) for i in range(num_domains): domain_weight_list.append(domain_weight) else: for i in range(num_domains): domain_weight = DNNs(domain_weight_units)(dnn_inputs) domain_weight = tf.keras.layers.Lambda(lambda x: tf.nn.softmax(x, axis=1))(domain_weight) domain_weight_list.append(domain_weight) # 融合domain信息 domain_output_list = [] for i in range(num_domains): domain_weight = domain_weight_list[i] domain_tower_output = domain_tower_output_list[i] weighted_output = tf.keras.layers.Lambda(lambda x: x[0] * x[1])([domain_weight, domain_tower_output]) for j in range(num_domains): if i == j: continue grad_output = tf.keras.layers.Lambda(lambda x: tf.stop_gradient(x))(domain_tower_output_list[j]) weighted_output = tf.keras.layers.Add()([ weighted_output, tf.keras.layers.Multiply()([domain_weight_list[i][:, j:j+1], grad_output]) ]) dummy_domain = tf.keras.layers.Lambda(lambda x: tf.ones_like(x[0]) * tf.cast(x[1], tf.int32))([domain_input, i]) domain_mask = tf.keras.layers.Lambda(lambda x: tf.squeeze(tf.equal(x[0], x[1]), axis=-1))([domain_input, dummy_domain]) domain_output = tf.keras.layers.Lambda(lambda x: tf.boolean_mask(x[0], x[1]))([weighted_output, domain_mask]) domain_output_list.append(domain_output) # 将所有domain的数据拼接成batch final_domain_output = tf.keras.layers.Concatenate(axis=0)(domain_output_list) dnn_logits = PredictLayer(activation=None, name="dnn_logits")(final_domain_output) if linear_logits: linear_logits = get_linear_logits(input_layer_dict, feature_columns) final_logits = dnn_logits + linear_logits else: final_logits = dnn_logits # 构建模型 model = tf.keras.Model(inputs=list(input_layer_dict.values()), outputs=final_logits) return model 完整的实践流程: **1. 导入相关代码包** .. raw:: latex \diilbookstyleinputcell .. code:: python import sys import funrec from funrec.utils import build_metrics_table **2. 特征处理** .. raw:: latex \diilbookstyleinputcell .. code:: python config = funrec.load_config('hmoe') train_data, test_data = funrec.load_data(config.data) feature_columns, processed_data = funrec.prepare_features(config.features, train_data, test_data) **3. 模型定义及训练** .. raw:: latex \diilbookstyleinputcell .. code:: python model = funrec.train_model(config.training, feature_columns, processed_data) **4. 模型效果评估** .. raw:: latex \diilbookstyleinputcell .. code:: python metrics = funrec.evaluate_model(model, processed_data, config.evaluation, feature_columns) print(build_metrics_table(metrics)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output +--------+--------+---------------+ | auc | gauc | valid_users | +========+========+===============+ | 0.5865 | 0.5491 | 217 | +--------+--------+---------------+ STAR ---- STAR(Star Topology Adaptive Recommender)模型采用星型拓扑结构,实现场景私有参数和场景共享参数同时建模场景差异性和共性。场景私有参数以及场景共享参数最终聚合得到每个场景的模型。STAR结构如下图所示。 .. figure:: ../../img/star_2.png STAR模型结构 相比于单场景的模型,STAR有三个针对多场景建模的创新思路值得学习,分别是星型拓扑结构的全连接网络(STAR Topology Fully-Connected Network),Partitioned Normalization 以及辅助网络,下面将以此进行介绍。 **STAR Topology Fully-Connected Network** 星形拓扑全连接结构的核心思想是对于每一个全连接网络(FCN)都有场景共享和场景独占的部分,每个场景最终的参数由共享和独占参数通过element-wise product融合计算得到。 .. figure:: ../../img/star_fcn.png STAR FCN结构 具体而言,对于第\ :math:`p`\ 个场景的FCN的最终参数\ :math:`W_p^{\star},b_p^{\star}`\ 表示如下: .. math:: W_p^{\star} = W_p \otimes W \\ b_p^{\star} = b_p + b 其中\ :math:`W_p,W`\ 分别表示第\ :math:`p`\ 个场景独有和全场景共享的参数,\ :math:`b_p,b`\ 也一样。 如果用\ :math:`in_p`\ 表示第\ :math:`p`\ 个场景FCN的输入,则该层星形FCN的输出\ :math:`out_p`\ 表示为: .. math:: out_p = \phi((W_p^\star)^\top in_p + b_p^\star), 其中\ :math:`\phi`\ 是激活函数。 STAR Topology Fully-Connected Network的具体实现如下: .. raw:: latex \diilbookstyleinputcell .. code:: python class StarTopologyFCN(tf.keras.layers.Layer): def __init__(self, num_domain, hidden_units, activation="relu", dropout=0., l2_reg=0., **kwargs): self.num_domain = num_domain self.hidden_units = hidden_units self.activation_list = [tf.keras.layers.Activation(activation) for _ in hidden_units] self.dropout_list = [tf.keras.layers.Dropout(dropout) for _ in hidden_units] self.l2_reg = l2_reg super(StarTopologyFCN, self).__init__( **kwargs) def build(self, input_shape): input_shape = input_shape[0] self.shared_bias = [ self.add_weight( name=f"shared_bias_{i}", shape=[1, i], initializer=tf.keras.initializers.Zeros(), trainable=True ) for i in self.hidden_units ] self.domain_bias_list = [ tf.keras.layers.Embedding( self.num_domain, output_dim=i, embeddings_initializer=tf.keras.initializers.Zeros() ) for i in self.hidden_units ] hidden_units = self.hidden_units.copy() hidden_units.insert(0, input_shape[-1]) self.shared_weights = [ self.add_weight( name=f"shared_weight_{i}", shape=[1, hidden_units[i], hidden_units[i+1]], initializer="glorot_uniform", regularizer=tf.keras.regularizers.l2(self.l2_reg), trainable=True ) for i in range(len(hidden_units) - 1) ] self.domain_weights_list = [ tf.keras.layers.Embedding( self.num_domain, hidden_units[i] * hidden_units[i + 1], embeddings_initializer="glorot_uniform", embeddings_regularizer=tf.keras.regularizers.l2(self.l2_reg) ) for i in range(len(hidden_units) - 1) ] def call(self, inputs, training=None): inputs, domain_index = inputs output = tf.expand_dims(inputs, axis=1) for i in range(len(self.hidden_units)): domain_weight = tf.reshape(self.domain_weights_list[i](domain_index), [-1] + self.shared_weights[i].shape.as_list()[1:]) weight = self.shared_weights[i] * domain_weight domain_bias = tf.reshape(self.domain_bias_list[i](domain_index), [-1] + self.shared_bias[i].shape.as_list()[1:]) bias = self.shared_bias[i] + domain_bias fc = tf.matmul(output, weight) + tf.expand_dims(bias, 1) output = self.activation_list[i](fc, training=training) output = self.dropout_list[i](output, training=training) return tf.squeeze(output, axis=1) **Partitioned Normalization** 在神经网络训练时,为了加快模型的收敛常会在模型中加入BN(Batch Normalization)。但是在多场景建模中,样本只在相同的场景内才满足独立同分布,多个场景混合的样本得到的统计量会忽略了不同场景独有的分布差异。为此应该让多场景中不同的场景独享统计量,这就是PN(Partitioned Normalization)提出的主要动机。 在介绍PN之前,先简单回顾一下经典的BN的原理: .. math:: \mathbf{z'} = \gamma \frac{\mathbf{z} - \mathbf{E}}{\sqrt{\mathrm{Var} + \epsilon}} + \beta 其中\ :math:`\mathbf{E},\mathrm{Var}`\ 分别是移动的均值和方差,\ :math:`\gamma,\beta`\ 是可学习的参数用来对数据进行缩放和平移。 PN相比BN来说,不仅可学习的缩放和平移参数包括场景共享和独占两部分的参数,统计的移动均值和方差也是在不同场景样本上得到的,具体表示如下: .. math:: \mathbf{z'} = (\gamma * \gamma_p) \frac{\mathbf{z} - \mathbf{E_p}}{\sqrt{\mathrm{Var_p} + \epsilon}} + (\beta + \beta_p) 其中\ :math:`\gamma,\beta`\ 和\ :math:`\gamma_p,\beta_p`\ 分别表示场景共享和独占的参数,\ :math:`\mathbf{E_p},\mathrm{Var_p}`\ 表示在场景\ :math:`p`\ 的样本中统计得到的移动均值和方差。由于PN是基于Batch样本计算的,为了得到不同场景下更稳定的均值和方差,训练时的Batch Size可以调的稍微大一些。 Partitioned Normalization的具体实现如下: .. raw:: latex \diilbookstyleinputcell .. code:: python class PartitionedNormalization(tf.keras.layers.Layer): def __init__(self, num_domain, name=None, **kwargs): self.bn_list = [tf.keras.layers.BatchNormalization(center=False, scale=False, name=f"bn_{i}") for i in range(num_domain)] super(PartitionedNormalization, self).__init__(name=name) def build(self, input_shape): assert len(input_shape) == 2 and len(input_shape[1]) <= 2 dim = input_shape[0][-1] self.global_gamma = self.add_weight( name="global_gamma", shape=[dim], initializer=tf.keras.initializers.Constant(0.5), trainable=True ) self.global_beta = self.add_weight( name="global_beta", shape=[dim], initializer=tf.keras.initializers.Zeros(), trainable=True ) self.domain_gamma = self.add_weight( name="domain_gamma", shape=[len(self.bn_list), dim], initializer=tf.keras.initializers.Constant(0.5), trainable=True ) self.domain_beta = self.add_weight( name="domain_beta", shape=[len(self.bn_list), dim], initializer=tf.keras.initializers.Zeros(), trainable=True ) def generate_grid_tensor(self, indices, dim): y = tf.range(dim) x_grid, y_grid = tf.meshgrid(indices, y) return tf.transpose(tf.stack([x_grid, y_grid], axis=-1), [1, 0, 2]) def call(self, inputs, training=None): inputs, domain_index = inputs domain_index = tf.cast(tf.reshape(domain_index, [-1]), "int32") dim = inputs.shape.as_list()[-1] output = inputs # compute each domain's BN individually for i, bn in enumerate(self.bn_list): mask = tf.equal(domain_index, i) single_bn = self.bn_list[i](tf.boolean_mask(inputs, mask), training=training) single_bn = (self.global_gamma + self.domain_gamma[i]) * single_bn + (self.global_beta + self.domain_beta[i]) # get current domain samples' indices indices = tf.boolean_mask(tf.range(tf.shape(inputs)[0]), mask) indices = self.generate_grid_tensor(indices, dim) output = tf.cond( tf.reduce_any(mask), lambda: tf.reshape(tf.tensor_scatter_nd_update(output, indices, single_bn), [-1, dim]), lambda: output ) return output 为了进一步加强场景特征对模型输出的影响,在STAR中还会单独构建一个场景的辅助网络(Auxiliary Network),辅助网络将场景特征和其他特征共同输入到浅层网络中得到一个辅助的Logits,最终和主网络的Logits相加计算得到最终的CTR预估值: .. math:: pCTR = Sigmoid(Logits_{main} + Logits_{aux}) STAR模型的实现代码如下: .. raw:: latex \diilbookstyleinputcell .. code:: python def build_star_model( feature_columns, num_domains, domain_feature_name, star_dnn_units=[128, 64], aux_dnn_units=[128, 64], star_fcn_activation='relu', dropout=0.2, l2_reg=1e-5, linear_logits=False): # 构建输入层字典 input_layer_dict = build_input_layer(feature_columns) domain_input = input_layer_dict[domain_feature_name] # 构建特征嵌入表字典 group_embedding_feature_dict = build_group_feature_embedding_table_dict(feature_columns, input_layer_dict, prefix="embedding/") # 连接不同组的嵌入向量作为各个网络的输入 domain_embeddings = concat_group_embedding(group_embedding_feature_dict, 'domain') dnn_inputs = concat_group_embedding(group_embedding_feature_dict, 'dnn') fcn_inputs = PartitionedNormalization(num_domain=num_domains, name="fcn_pn_layer")([dnn_inputs, domain_input]) fcn_output = StarTopologyFCN(num_domains, star_dnn_units, star_fcn_activation, dropout, l2_reg, name="star_fcn_layer")([fcn_inputs, domain_input]) fcn_logit = PredictLayer(activation=None, name='fcn_logits')(fcn_output) aux_inputs = concat_func([domain_embeddings, dnn_inputs], axis=-1) aux_inputs = PartitionedNormalization(num_domain=num_domains, name="aux_pn_layer")([aux_inputs, domain_input]) aux_output = DNNs(aux_dnn_units, dropout_rate=dropout)(aux_inputs) aux_logit = PredictLayer(activation=None, name='aux_logits')(aux_output) if linear_logits: linear_logits = get_linear_logits(input_layer_dict, feature_columns) final_logits = add_tensor_func([linear_logits, fcn_logit, aux_logit]) else: final_logits = add_tensor_func([fcn_logit, aux_logit]) final_prediction = PredictLayer(activation=None, name='final_prediction')(final_logits) model = tf.keras.Model(inputs=list(input_layer_dict.values()), outputs=final_prediction) return model 完整的实践流程: **1. 导入相关代码包** .. raw:: latex \diilbookstyleinputcell .. code:: python import sys import funrec from funrec.utils import build_metrics_table **2. 特征处理** .. raw:: latex \diilbookstyleinputcell .. code:: python config = funrec.load_config('star') train_data, test_data = funrec.load_data(config.data) feature_columns, processed_data = funrec.prepare_features(config.features, train_data, test_data) **3. 模型定义及训练** .. raw:: latex \diilbookstyleinputcell .. code:: python model = funrec.train_model(config.training, feature_columns, processed_data) **4. 模型效果评估** .. raw:: latex \diilbookstyleinputcell .. code:: python metrics = funrec.evaluate_model(model, processed_data, config.evaluation, feature_columns) print(build_metrics_table(metrics)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output +--------+--------+---------------+ | auc | gauc | valid_users | +========+========+===============+ | 0.6391 | 0.6144 | 693 | +--------+--------+---------------+