# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.
import tensorflow as tf
from recommenders.models.sasrec.model import SASREC, Encoder, LayerNormalization
class SSEPT(SASREC):
    """SSE-PT model: a personalized variant of SASREC in which a user
    embedding is concatenated to every item embedding of the sequence.

    :Citation:

        Wu L., Li S., Hsieh C-J., Sharpnack J., SSE-PT: Sequential Recommendation
        Via Personalized Transformer, RecSys, 2020.
        TF 1.x codebase: https://github.com/SSE-PT/SSE-PT
        TF 2.x codebase (SASRec): https://github.com/nnkkmto/SASRec-tf2
    """

    def __init__(self, **kwargs):
        """Model initialization.

        Args:
            item_num (int): Number of items in the dataset.
            seq_max_len (int): Maximum number of items in user history.
            num_blocks (int): Number of Transformer blocks to be used.
            embedding_dim (int): Item embedding dimension.
            attention_dim (int): Transformer attention dimension.
            conv_dims (list): List of the dimensions of the Feedforward layer.
            dropout_rate (float): Dropout rate.
            l2_reg (float): Coefficient of the L2 regularization.
            num_neg_test (int): Number of negative examples used in testing.
            user_num (int): Number of users in the dataset.
            user_embedding_dim (int): User embedding dimension.
            item_embedding_dim (int): Item embedding dimension.
        """
        super().__init__(**kwargs)
        self.user_num = kwargs.get("user_num", None)
        self.conv_dims = kwargs.get("conv_dims", [200, 200])
        self.user_embedding_dim = kwargs.get("user_embedding_dim", self.embedding_dim)
        self.item_embedding_dim = kwargs.get("item_embedding_dim", self.embedding_dim)
        # Width of the transformer stack: user and item embeddings are
        # concatenated at every sequence position.
        self.hidden_units = self.item_embedding_dim + self.user_embedding_dim

        # User embedding table; index 0 is reserved for padding (mask_zero).
        self.user_embedding_layer = tf.keras.layers.Embedding(
            input_dim=self.user_num + 1,
            output_dim=self.user_embedding_dim,
            name="user_embeddings",
            mask_zero=True,
            input_length=1,
            embeddings_regularizer=tf.keras.regularizers.L2(self.l2_reg),
        )
        # Positional embeddings sized to the concatenated (user+item) width,
        # unlike plain SASREC which uses the item width only.
        self.positional_embedding_layer = tf.keras.layers.Embedding(
            self.seq_max_len,
            self.user_embedding_dim + self.item_embedding_dim,
            name="positional_embeddings",
            mask_zero=False,
            embeddings_regularizer=tf.keras.regularizers.L2(self.l2_reg),
        )
        self.dropout_layer = tf.keras.layers.Dropout(self.dropout_rate)
        self.encoder = Encoder(
            self.num_blocks,
            self.seq_max_len,
            self.hidden_units,
            self.hidden_units,
            self.attention_num_heads,
            self.conv_dims,
            self.dropout_rate,
        )
        self.mask_layer = tf.keras.layers.Masking(mask_value=0)
        self.layer_normalization = LayerNormalization(
            self.seq_max_len, self.hidden_units, 1e-08
        )

    def call(self, x, training):
        """Model forward pass.

        Args:
            x (tf.Tensor): Input tensor.
            training (tf.Tensor): Training tensor.

        Returns:
            tf.Tensor, tf.Tensor, tf.Tensor:
            - Logits of the positive examples.
            - Logits of the negative examples.
            - Mask for nonzero targets.
        """
        users = x["users"]
        input_seq = x["input_seq"]
        pos = x["positive"]
        neg = x["negative"]

        # 1.0 where the history holds a real item, 0.0 on padding positions.
        mask = tf.expand_dims(tf.cast(tf.not_equal(input_seq, 0), tf.float32), -1)
        seq_embeddings, positional_embeddings = self.embedding(input_seq)

        # User encoding, scaled by sqrt(dim) as in the original transformer.
        u_latent = self.user_embedding_layer(users)
        u_latent = u_latent * (self.user_embedding_dim**0.5)  # (b, 1, h)

        # Replicate the user embedding across every position of the sequence.
        u_latent = tf.tile(u_latent, [1, tf.shape(input_seq)[1], 1])  # (b, s, h)

        # Personalized input: item embedding concatenated with the user
        # embedding at each step, then positional embeddings are added.
        seq_embeddings = tf.reshape(
            tf.concat([seq_embeddings, u_latent], 2),
            [tf.shape(input_seq)[0], -1, self.hidden_units],
        )
        seq_embeddings += positional_embeddings

        seq_embeddings = self.dropout_layer(seq_embeddings, training=training)
        seq_embeddings *= mask

        # --- ATTENTION BLOCKS ---
        seq_attention = self.encoder(seq_embeddings, training, mask)  # (b, s, h1+h2)
        seq_attention = self.layer_normalization(seq_attention)  # (b, s, h1+h2)

        # --- PREDICTION LAYER ---
        pos = self.mask_layer(pos)
        neg = self.mask_layer(neg)

        # Flatten to (b*s, ...) so logits are computed per position.
        user_emb = tf.reshape(
            u_latent,
            [tf.shape(input_seq)[0] * self.seq_max_len, self.user_embedding_dim],
        )
        pos = tf.reshape(pos, [tf.shape(input_seq)[0] * self.seq_max_len])
        neg = tf.reshape(neg, [tf.shape(input_seq)[0] * self.seq_max_len])
        pos_emb = self.item_embedding_layer(pos)
        neg_emb = self.item_embedding_layer(neg)

        # Candidate items are personalized the same way as the inputs.
        pos_emb = tf.reshape(tf.concat([pos_emb, user_emb], 1), [-1, self.hidden_units])
        neg_emb = tf.reshape(tf.concat([neg_emb, user_emb], 1), [-1, self.hidden_units])

        seq_emb = tf.reshape(
            seq_attention,
            [tf.shape(input_seq)[0] * self.seq_max_len, self.hidden_units],
        )  # (b*s, d)

        # Dot-product scores between sequence state and candidate embeddings.
        pos_logits = tf.reduce_sum(pos_emb * seq_emb, -1)
        neg_logits = tf.reduce_sum(neg_emb * seq_emb, -1)
        pos_logits = tf.expand_dims(pos_logits, axis=-1)  # (bs, 1)
        neg_logits = tf.expand_dims(neg_logits, axis=-1)  # (bs, 1)

        # Mask so that padded targets do not contribute to the loss.
        istarget = tf.reshape(
            tf.cast(tf.not_equal(pos, 0), dtype=tf.float32),
            [tf.shape(input_seq)[0] * self.seq_max_len],
        )

        return pos_logits, neg_logits, istarget

    def predict(self, inputs):
        """Model prediction for candidate (negative) items.

        Args:
            inputs (dict): Dictionary with ``user``, ``input_seq`` and
                ``candidate`` tensors (evaluation uses a batch of one user).

        Returns:
            tf.Tensor: Logits of the last sequence position against the
            ``1 + num_neg_test`` candidate items.
        """
        training = False
        user = inputs["user"]
        input_seq = inputs["input_seq"]
        candidate = inputs["candidate"]

        mask = tf.expand_dims(tf.cast(tf.not_equal(input_seq, 0), tf.float32), -1)
        seq_embeddings, positional_embeddings = self.embedding(input_seq)  # (1, s, h)

        # User embedding replicated once per candidate item.
        u0_latent = self.user_embedding_layer(user)
        u0_latent = u0_latent * (self.user_embedding_dim**0.5)  # (1, 1, h)
        u0_latent = tf.squeeze(u0_latent, axis=0)  # (1, h)
        test_user_emb = tf.tile(u0_latent, [1 + self.num_neg_test, 1])  # (101, h)

        # User embedding replicated once per sequence position.
        u_latent = self.user_embedding_layer(user)
        u_latent = u_latent * (self.user_embedding_dim**0.5)  # (b, 1, h)
        u_latent = tf.tile(u_latent, [1, tf.shape(input_seq)[1], 1])  # (b, s, h)

        # Same personalized-input construction as in call().
        seq_embeddings = tf.reshape(
            tf.concat([seq_embeddings, u_latent], 2),
            [tf.shape(input_seq)[0], -1, self.hidden_units],
        )
        seq_embeddings += positional_embeddings  # (b, s, h1 + h2)
        seq_embeddings *= mask

        seq_attention = self.encoder(seq_embeddings, training, mask)
        seq_attention = self.layer_normalization(seq_attention)  # (b, s, h1+h2)
        seq_emb = tf.reshape(
            seq_attention,
            [tf.shape(input_seq)[0] * self.seq_max_len, self.hidden_units],
        )  # (b*s1, h1+h2)

        candidate_emb = self.item_embedding_layer(candidate)  # (b, s2, h2)
        candidate_emb = tf.squeeze(candidate_emb, axis=0)  # (s2, h2)
        candidate_emb = tf.reshape(
            tf.concat([candidate_emb, test_user_emb], 1), [-1, self.hidden_units]
        )  # (b*s2, h1+h2)
        candidate_emb = tf.transpose(candidate_emb, perm=[1, 0])  # (h1+h2, b*s2)

        test_logits = tf.matmul(seq_emb, candidate_emb)  # (b*s1, b*s2)
        test_logits = tf.reshape(
            test_logits,
            [tf.shape(input_seq)[0], self.seq_max_len, 1 + self.num_neg_test],
        )  # (1, s, 101)
        # Only the last position of the sequence scores the candidates.
        test_logits = test_logits[:, -1, :]  # (1, 101)
        return test_logits

    def loss_function(self, pos_logits, neg_logits, istarget):
        """Losses are calculated separately for the positive and negative
        items based on the corresponding logits. A mask is included to
        take care of the zero items (added for padding).

        Args:
            pos_logits (tf.Tensor): Logits of the positive examples.
            neg_logits (tf.Tensor): Logits of the negative examples.
            istarget (tf.Tensor): Mask for nonzero targets.

        Returns:
            float: Loss.
        """
        pos_logits = pos_logits[:, 0]
        neg_logits = neg_logits[:, 0]

        # Binary cross-entropy on sigmoid logits; the small epsilon guards
        # log(0), and padded positions (istarget == 0) are excluded from
        # both the sum and the normalization.
        loss = tf.reduce_sum(
            -tf.math.log(tf.math.sigmoid(pos_logits) + 1e-24) * istarget
            - tf.math.log(1 - tf.math.sigmoid(neg_logits) + 1e-24) * istarget
        ) / tf.reduce_sum(istarget)

        # Add the L2 penalties registered by the embedding regularizers.
        reg_loss = tf.compat.v1.losses.get_regularization_loss()
        loss += reg_loss
        return loss