# Source code for recommenders.models.deeprec.models.sequential.gru4rec

# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.

import tensorflow as tf
from keras.layers.legacy_rnn.rnn_cell_impl import GRUCell, LSTMCell
from recommenders.models.deeprec.models.sequential.sequential_base_model import (
    SequentialBaseModel,
)
from tensorflow.compat.v1.nn import dynamic_rnn

__all__ = ["GRU4RecModel"]


class GRU4RecModel(SequentialBaseModel):
    """GRU4Rec Model

    Runs a recurrent cell over the concatenated item/category history
    embeddings and joins the final recurrent state with the target item
    embedding.

    :Citation:

        B. Hidasi, A. Karatzoglou, L. Baltrunas, D. Tikk, "Session-based Recommendations
        with Recurrent Neural Networks", ICLR (Poster), 2016.
    """

    def _build_seq_graph(self):
        """The main function to create GRU4Rec model.

        Returns:
            object: The final recurrent state concatenated with
            ``self.target_item_embedding`` along axis 1.
        """
        with tf.compat.v1.variable_scope("gru4rec"):
            # The GRU encoder is the one in use; `_build_lstm` is kept as an
            # alternative encoder that also returns a final state tensor.
            final_state = self._build_gru()
            model_output = tf.concat([final_state, self.target_item_embedding], 1)
            tf.compat.v1.summary.histogram("model_output", model_output)
            return model_output

    def _prepare_history_inputs(self):
        """Set the mask, sequence lengths and history embedding shared by both encoders.

        Populates ``self.mask``, ``self.sequence_length`` and
        ``self.history_embedding``. Must be called inside the caller's name
        scope so the created ops keep their scoped names.
        """
        self.mask = self.iterator.mask
        # Sum of the mask over axis 1 -- presumably the number of valid steps
        # per row (assumes a 0/1 mask; TODO confirm against the iterator).
        self.sequence_length = tf.reduce_sum(input_tensor=self.mask, axis=1)
        # Item and category history embeddings are joined along axis 2.
        self.history_embedding = tf.concat(
            [self.item_history_embedding, self.cate_history_embedding], 2
        )

    def _build_lstm(self):
        """Apply an LSTM for modeling.

        Returns:
            object: The output of LSTM section.
        """
        with tf.compat.v1.name_scope("lstm"):
            self._prepare_history_inputs()
            rnn_outputs, final_state = dynamic_rnn(
                LSTMCell(self.hidden_size),
                inputs=self.history_embedding,
                sequence_length=self.sequence_length,
                dtype=tf.float32,
                scope="lstm",
            )
            tf.compat.v1.summary.histogram("LSTM_outputs", rnn_outputs)
            # With LSTMCell's default tuple state the final state is a (c, h)
            # pair; index 1 selects h.
            return final_state[1]

    def _build_gru(self):
        """Apply a GRU for modeling.

        Returns:
            object: The output of GRU section.
        """
        with tf.compat.v1.name_scope("gru"):
            self._prepare_history_inputs()
            rnn_outputs, final_state = dynamic_rnn(
                GRUCell(self.hidden_size),
                inputs=self.history_embedding,
                sequence_length=self.sequence_length,
                dtype=tf.float32,
                scope="gru",
            )
            tf.compat.v1.summary.histogram("GRU_outputs", rnn_outputs)
            return final_state