Source code for recommenders.models.deeprec.models.sequential.sum_cells

# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.

import numpy as np
import tensorflow as tf
from keras.layers.legacy_rnn.rnn_cell_impl import LayerRNNCell
from tensorflow.python.eager import context
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.ops import init_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.util import nest


_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"


class SUMCell(LayerRNNCell):
    """Cell for Sequential User Matrix"""
    def __init__(
        self,
        num_units,
        slots,
        attention_size,
        input_size,
        activation=None,
        reuse=None,
        kernel_initializer=None,
        bias_initializer=None,
        name=None,
        dtype=None,
        **kwargs
    ):
        super(SUMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype, **kwargs)
        _check_supported_dtypes(self.dtype)
        if context.executing_eagerly() and context.num_gpus() > 0:
            logging.warn(
                "%s: Note that this cell is not optimized for performance. "
                "Please use keras.layers.cudnn_recurrent.CuDNNGRU for better "
                "performance on GPU.",
                self,
            )
        self._input_size = input_size
        self._slots = slots - 1  # the last channel is reserved for the highway slot
        self._num_units = num_units
        self._real_units = (self._num_units - input_size) // slots
        if activation:
            self._activation = activations.get(activation)
        else:
            self._activation = math_ops.tanh
        self._kernel_initializer = initializers.get(kernel_initializer)
        self._bias_initializer = initializers.get(bias_initializer)
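
    # State layout (flattened along the last axis), as consumed by ``call``:
    #   [ (slots - 1) attention slots, each of size real_units
    #   |  1 highway slot of size real_units
    #   |  the last input behavior of size input_size ]
    # so ``num_units`` is expected to equal ``slots * real_units + input_size``,
    # with ``real_units = (num_units - input_size) // slots``.
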
    @property
    def state_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    def _basic_build(self, inputs_shape):
        """Common initialization operations for SUM cell and its variants.
        This function creates parameters for the cell.
        """
        d = inputs_shape[-1]
        h = self._real_units
        s = self._slots

        # Parameters of the erase, reset and add gates.
        self._erase_W = self.add_variable(
            name="_erase_W", shape=[d + h, h], initializer=self._kernel_initializer
        )
        self._erase_b = self.add_variable(
            name="_erase_b",
            shape=[h],
            initializer=(
                self._bias_initializer
                if self._bias_initializer is not None
                else init_ops.constant_initializer(1.0, dtype=self.dtype)
            ),
        )
        self._reset_W = self.add_variable(
            name="_reset_W", shape=[d + h, 1], initializer=self._kernel_initializer
        )
        self._reset_b = self.add_variable(
            name="_reset_b",
            shape=[1],
            initializer=(
                self._bias_initializer
                if self._bias_initializer is not None
                else init_ops.constant_initializer(1.0, dtype=self.dtype)
            ),
        )
        self._add_W = self.add_variable(
            name="_add_W", shape=[d + h, h], initializer=self._kernel_initializer
        )
        self._add_b = self.add_variable(
            name="_add_b",
            shape=[h],
            initializer=(
                self._bias_initializer
                if self._bias_initializer is not None
                else init_ops.constant_initializer(1.0, dtype=self.dtype)
            ),
        )

        # Attention heads for addressing the memory slots, plus the learned
        # scalars beta (read-attention sharpening) and alpha (similarity-based
        # scaling of the writing strength).
        self.heads = self.add_variable(
            name="_heads", shape=[s, d], initializer=self._kernel_initializer
        )
        self._beta = self.add_variable(
            name="_beta_no_reg",
            shape=(),
            initializer=tf.compat.v1.constant_initializer(
                np.array([1.02]), dtype=np.float32
            ),
        )
        self._alpha = self.add_variable(
            name="_alpha_no_reg",
            shape=(),
            initializer=tf.compat.v1.constant_initializer(
                np.array([0.98]), dtype=np.float32
            ),
        )

    @tf_utils.shape_type_conversion
    def build(self, inputs_shape):
        """Initialization operations for SUM cell.
        This function creates all the parameters for the cell.
        """
        if inputs_shape[-1] is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, saw shape: %s"
                % str(inputs_shape)
            )
        _check_supported_dtypes(self.dtype)
        d = inputs_shape[-1]  # noqa: F841
        h = self._real_units  # noqa: F841
        s = self._slots  # noqa: F841

        self._basic_build(inputs_shape)

        self.parameter_set = [
            self._erase_W,
            self._erase_b,
            self._reset_W,
            self._reset_b,
            self._add_W,
            self._add_b,
            self.heads,
        ]

        self.built = True
    def call(self, inputs, state):
        """The real operations for SUM cell to process user behaviors.

        Args:
            inputs: (a batch of) user behaviors at time T.
            state: (a batch of) user states at time T-1.

        Returns:
            (state, state): after processing the user behavior at time T, the
            (batch of) new user states at time T is returned twice, once as
            the cell output and once as the next state.
        """
        _check_rnn_cell_input_dtypes([inputs, state])

        h = self._real_units
        s = self._slots + 1
        # Split the flat state into the memory matrix and the last behavior.
        state, last = state[:, : s * h], state[:, s * h :]
        state = tf.reshape(state, [-1, s, h])

        # Reading attention over the memory slots, sharpened by beta, and the
        # read summary h_hat (averaged with the highway slot).
        att_logit_mat = tf.matmul(inputs, self.heads, transpose_b=True)
        att_weights = tf.nn.softmax(self._beta * att_logit_mat, axis=-1)
        att_weights = tf.expand_dims(att_weights, 2)
        h_hat = tf.reduce_sum(
            input_tensor=tf.multiply(state[:, : self._slots, :], att_weights), axis=1
        )
        h_hat = (h_hat + state[:, self._slots, :]) / 2

        # Scale the writing strength by the similarity to the last behavior.
        n_a, n_b = tf.nn.l2_normalize(last, 1), tf.nn.l2_normalize(inputs, 1)
        dist = tf.expand_dims(tf.reduce_sum(input_tensor=n_a * n_b, axis=1), 1)
        dist = tf.math.pow(self._alpha, dist)
        att_weights = att_weights * tf.expand_dims(dist, 1)

        reset = tf.sigmoid(
            tf.compat.v1.nn.xw_plus_b(
                tf.concat([inputs, h_hat], axis=-1), self._reset_W, self._reset_b
            )
        )
        erase = tf.sigmoid(
            tf.compat.v1.nn.xw_plus_b(
                tf.concat([inputs, h_hat], axis=-1), self._erase_W, self._erase_b
            )
        )
        add = tf.tanh(
            tf.compat.v1.nn.xw_plus_b(
                tf.concat([inputs, reset * h_hat], axis=-1), self._add_W, self._add_b
            )
        )

        # Erase-then-add update of the attention slots.
        start_part01 = state[:, : self._slots, :]
        state01 = start_part01 * (
            tf.ones_like(start_part01) - att_weights * tf.expand_dims(erase, 1)
        )
        state01 = state01 + att_weights * tf.expand_dims(erase, 1) * tf.expand_dims(
            add, 1
        )
        state01 = tf.reshape(state01, [-1, self._slots * self._real_units])

        # Erase-then-add update of the highway slot.
        start_part02 = state[:, self._slots, :]
        state02 = start_part02 * (tf.ones_like(start_part02) - dist * erase)
        state02 = state02 + dist * erase * add

        state = tf.concat([state01, state02, inputs], axis=-1)
        return state, state
    def get_config(self):
        config = {
            "num_units": self._num_units,
            "kernel_initializer": initializers.serialize(self._kernel_initializer),
            "bias_initializer": initializers.serialize(self._bias_initializer),
            "activation": activations.serialize(self._activation),
            "reuse": self._reuse,
        }
        base_config = super(SUMCell, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
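
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows one
# way a SUMCell could be driven with ``tf.compat.v1.nn.dynamic_rnn`` in
# TF1-style graph mode; all shapes and parameter values below are assumptions,
# chosen so that ``num_units == slots * real_units + input_size``, which is
# what the cell's flat state layout requires.
# ---------------------------------------------------------------------------
def _sum_cell_usage_sketch():
    """Build a small graph that unrolls a SUMCell over a behavior sequence."""
    with tf.Graph().as_default():
        batch_size, seq_len, input_size = 32, 50, 40
        slots, real_units = 4, 32
        num_units = slots * real_units + input_size  # 4 * 32 + 40 = 168

        # A batch of user behavior sequences, one embedding vector per step.
        behaviors = tf.compat.v1.placeholder(
            tf.float32, shape=[batch_size, seq_len, input_size]
        )
        cell = SUMCell(
            num_units=num_units,
            slots=slots,
            attention_size=real_units,  # accepted by the constructor; value assumed
            input_size=input_size,
        )
        # Each step's output (and state) is the flattened memory matrix
        # followed by the current behavior vector: [batch_size, num_units].
        outputs, final_state = tf.compat.v1.nn.dynamic_rnn(
            cell, behaviors, dtype=tf.float32
        )
        return outputs, final_state
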
class SUMV2Cell(SUMCell):
    """A variant of the SUM cell that upgrades the writing attention."""

    @tf_utils.shape_type_conversion
    def build(self, inputs_shape):
        """Initialization operations for SUMV2 cell.
        This function creates all the parameters for the cell.
        """
        if inputs_shape[-1] is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, saw shape: %s"
                % str(inputs_shape)
            )
        _check_supported_dtypes(self.dtype)
        d = inputs_shape[-1]
        h = self._real_units
        s = self._slots

        self._basic_build(inputs_shape)

        # Additional parameters for the learned (two-layer) writing attention.
        self._writing_W = self.add_variable(
            name="_writing_W", shape=[d + h, h], initializer=self._kernel_initializer
        )
        self._writing_b = self.add_variable(
            name="_writing_b",
            shape=[h],
            initializer=(
                self._bias_initializer
                if self._bias_initializer is not None
                else init_ops.constant_initializer(1.0, dtype=self.dtype)
            ),
        )
        self._writing_W02 = self.add_variable(
            name="_writing_W02", shape=[h, s], initializer=self._kernel_initializer
        )

        self.parameter_set = [
            self._erase_W,
            self._erase_b,
            self._reset_W,
            self._reset_b,
            self._add_W,
            self._add_b,
            self.heads,
            self._writing_W,
            self._writing_W02,
            self._writing_b,
        ]

        self.built = True
    def call(self, inputs, state):
        """The real operations for SUMV2 cell to process user behaviors.

        Args:
            inputs: (a batch of) user behaviors at time T.
            state: (a batch of) user states at time T-1.

        Returns:
            (state, state): after processing the user behavior at time T, the
            (batch of) new user states at time T is returned twice, once as
            the cell output and once as the next state.
        """
        _check_rnn_cell_input_dtypes([inputs, state])

        h = self._real_units
        s = self._slots + 1
        state, last = state[:, : s * h], state[:, s * h :]
        state = tf.reshape(state, [-1, s, h])

        att_logit_mat = tf.matmul(inputs, self.heads, transpose_b=True)
        att_weights = tf.nn.softmax(self._beta * att_logit_mat, axis=-1)
        att_weights = tf.expand_dims(att_weights, 2)
        h_hat = tf.reduce_sum(
            input_tensor=tf.multiply(state[:, : self._slots, :], att_weights), axis=1
        )
        h_hat = (h_hat + state[:, self._slots, :]) / 2

        # Get the true writing attentions from a two-layer projection of the
        # current behavior and the read summary.
        writing_input = tf.concat([inputs, h_hat], axis=1)
        att_weights = tf.compat.v1.nn.xw_plus_b(
            writing_input, self._writing_W, self._writing_b
        )
        att_weights = tf.nn.relu(att_weights)
        att_weights = tf.matmul(att_weights, self._writing_W02)
        att_weights = tf.nn.softmax(att_weights, axis=-1)
        att_weights = tf.expand_dims(att_weights, 2)

        n_a, n_b = tf.nn.l2_normalize(last, 1), tf.nn.l2_normalize(inputs, 1)
        dist = tf.expand_dims(tf.reduce_sum(input_tensor=n_a * n_b, axis=1), 1)
        dist = tf.math.pow(self._alpha, dist)
        att_weights = att_weights * tf.expand_dims(dist, 1)

        reset = tf.sigmoid(
            tf.compat.v1.nn.xw_plus_b(
                tf.concat([inputs, h_hat], axis=-1), self._reset_W, self._reset_b
            )
        )
        erase = tf.sigmoid(
            tf.compat.v1.nn.xw_plus_b(
                tf.concat([inputs, h_hat], axis=-1), self._erase_W, self._erase_b
            )
        )
        add = tf.tanh(
            tf.compat.v1.nn.xw_plus_b(
                tf.concat([inputs, reset * h_hat], axis=-1), self._add_W, self._add_b
            )
        )

        start_part01 = state[:, : self._slots, :]
        state01 = start_part01 * (
            tf.ones_like(start_part01) - att_weights * tf.expand_dims(erase, 1)
        )
        state01 = state01 + att_weights * tf.expand_dims(erase, 1) * tf.expand_dims(
            add, 1
        )
        state01 = tf.reshape(state01, [-1, self._slots * self._real_units])

        start_part02 = state[:, self._slots, :]
        state02 = start_part02 * (tf.ones_like(start_part02) - dist * erase)
        state02 = state02 + dist * erase * add

        state = tf.concat([state01, state02, inputs], axis=-1)
        return state, state
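
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): SUMV2Cell keeps
# SUMCell's constructor signature and state layout, so it is a drop-in
# replacement in the sketch above; only the writing attention changes.
# Parameter values here are assumptions.
# ---------------------------------------------------------------------------
def _sumv2_cell_usage_sketch():
    """Unroll a SUMV2Cell exactly as a SUMCell would be unrolled."""
    with tf.Graph().as_default():
        behaviors = tf.compat.v1.placeholder(tf.float32, shape=[None, 50, 40])
        cell = SUMV2Cell(num_units=168, slots=4, attention_size=32, input_size=40)
        outputs, final_state = tf.compat.v1.nn.dynamic_rnn(
            cell, behaviors, dtype=tf.float32
        )
        return outputs, final_state
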
def _check_rnn_cell_input_dtypes(inputs):
    for t in nest.flatten(inputs):
        _check_supported_dtypes(t.dtype)


def _check_supported_dtypes(dtype):
    if dtype is None:
        return
    dtype = dtypes.as_dtype(dtype)
    if not (dtype.is_floating or dtype.is_complex):
        raise ValueError(
            "RNN cell only supports floating point inputs, "
            "but saw dtype: %s" % dtype
        )