# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.

import tensorflow as tf
from tensorflow.compat.v1.nn import dynamic_rnn
from recommenders.models.deeprec.models.sequential.sequential_base_model import (
    SequentialBaseModel,
)
from recommenders.models.deeprec.models.sequential.sum_cells import (
    SUMCell,
    SUMV2Cell,
)

class SUMModel(SequentialBaseModel):
    """Sequential User Matrix Model

    :Citation:

        Lian, J., Batal, I., Liu, Z., Soni, A., Kang, E. Y., Wang, Y., & Xie, X.,
        "Multi-Interest-Aware User Modeling for Large-Scale Sequential
        Recommendations", arXiv preprint arXiv:2102.09211, 2021.
    """

    def _build_seq_graph(self):
        """The main function to create SUM model.

        Concatenates the item and category history embeddings, runs them
        through a SUM cell, merges the resulting memory states with respect to
        the target item, and records histogram summaries along the way.

        Returns:
            object: The output of SUM section, which is a concatenation of user
            vector and target item vector.
        """
        with tf.compat.v1.variable_scope("sum"):
            self.history_embedding = tf.concat(
                [self.item_history_embedding, self.cate_history_embedding], 2
            )
            cell = self._create_sumcell()
            self.cell = cell
            # Give the cell a back-reference to this model.
            cell.model = self
            final_state = self._build_sum(cell)

            # Each histogram is tagged with the name of the tensor it records.
            for _p in cell.parameter_set:
                tf.compat.v1.summary.histogram(_p.name, _p)
            if hasattr(cell, "_alpha") and hasattr(cell._alpha, "name"):
                tf.compat.v1.summary.histogram(cell._alpha.name, cell._alpha)
            if hasattr(cell, "_beta") and hasattr(cell._beta, "name"):
                tf.compat.v1.summary.histogram(cell._beta.name, cell._beta)

            # The attention weights are not needed for the model output.
            final_state, _ = self._attention_query_by_state(
                final_state, self.target_item_embedding
            )
            model_output = tf.concat([final_state, self.target_item_embedding], 1)
            tf.compat.v1.summary.histogram("model_output", model_output)
            return model_output

    def _attention_query_by_state(self, seq_output, query):
        """Merge a user's memory states conditioned by a query item.

        When ``hparams.slots`` is greater than 1, the memory states are
        combined by a softmax attention over the slots, scored against the
        query through a learned bilinear weight. Otherwise ``seq_output`` is
        returned unchanged and the weights are the scalar constant 1.0.

        Args:
            seq_output: A flatten representation of SUM memory states for
                (a batch of) users.
            query: (a batch of) target item candidates.

        Returns:
            tf.Tensor, tf.Tensor: Merged user representation. Attention weights
            of each memory channel.
        """
        dim_q = query.shape[-1]
        # Default weights for the single-slot case, where nothing is merged.
        att_weights = tf.constant(1.0, dtype=tf.float32)
        with tf.compat.v1.variable_scope("query_att"):
            if self.hparams.slots > 1:
                query_att_W = tf.compat.v1.get_variable(
                    name="query_att_W",
                    shape=[self.hidden_size, dim_q],
                    initializer=self.initializer,
                )

                # reshape the memory states to (BatchSize, Slots, HiddenSize)
                memory_state = tf.reshape(
                    seq_output, [-1, self.hparams.slots, self.hidden_size]
                )

                # Project each slot into the query space, score it against the
                # query, and normalize the scores over the slot axis.
                # NOTE(review): assumes query is (BatchSize, dim_q) - confirm.
                projected = tf.tensordot(memory_state, query_att_W, axes=1)
                scores = tf.squeeze(
                    tf.matmul(projected, tf.expand_dims(query, -1)), -1
                )
                att_weights = tf.nn.softmax(scores, -1)

                # merge the memory states, the final shape is (BatchSize, HiddenSize)
                att_res = tf.reduce_sum(
                    input_tensor=memory_state * tf.expand_dims(att_weights, -1),
                    axis=1,
                )
            else:
                att_res = seq_output

        return att_res, att_weights

    def _create_sumcell(self):
        """Create a SUM cell.

        Returns:
            object: An initialized SUM cell.

        Raises:
            ValueError: If ``hparams.cell`` is not a supported cell type.
        """
        hparams = self.hparams
        sumcells = {"SUM": SUMCell, "SUMV2": SUMV2Cell}

        # Validate before the lookup so an unsupported type raises the
        # descriptive ValueError instead of a bare KeyError.
        if hparams.cell not in sumcells:
            raise ValueError("ERROR! Cell type not support: {0}".format(hparams.cell))

        input_embedding_dim = self.history_embedding.shape[-1]
        input_params = [
            hparams.hidden_size * hparams.slots + input_embedding_dim,
            hparams.slots,
            hparams.attention_size,
            input_embedding_dim,
        ]
        return sumcells[hparams.cell](*input_params)

    def _build_sum(self, cell):
        """Generate user memory states from behavior sequence.

        Args:
            cell (object): An initialized SUM cell.

        Returns:
            object: A flatten representation of user memory states, in the
            shape of (BatchSize, SlotsNum x HiddenSize)
        """
        hparams = self.hparams
        # This scope (and the ``scope="sum"`` below) nests inside the "sum"
        # scope opened by _build_seq_graph; renaming any of them would change
        # the names of the variables that get created.
        with tf.compat.v1.variable_scope("sum"):
            self.mask = self.iterator.mask
            # Per-row sum of the mask is used as the sequence length.
            self.sequence_length = tf.reduce_sum(input_tensor=self.mask, axis=1)

            sum_outputs, final_state = dynamic_rnn(
                cell,
                inputs=self.history_embedding,
                dtype=tf.float32,
                sequence_length=self.sequence_length,
                scope="sum",
                initial_state=cell.zero_state(
                    tf.shape(input=self.history_embedding)[0], tf.float32
                ),
            )

            # Keep only the leading slots * hidden_size columns of the state.
            # NOTE(review): the trailing columns are presumably auxiliary cell
            # state - confirm against the cell implementation.
            final_state = final_state[:, : hparams.slots * hparams.hidden_size]

            self.heads = cell.heads
            self.alpha = cell._alpha
            self.beta = cell._beta
            tf.compat.v1.summary.histogram("SUM_outputs", sum_outputs)

        return final_state