Source code for recommenders.models.sasrec.util

# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.

from collections import defaultdict


class SASRecDataSet:
    """Dataset container producing the train/validation/test splits for SASRec.

    The data file holds one interaction per line, either ``user item`` or
    ``user item timestamp``, separated by ``col_sep``. User and item ids are
    assumed to be integers starting from 1.

    Attributes:
        usernum: integer, total number of users (largest user id seen)
        itemnum: integer, total number of items (largest item id seen)
        User: dict, all the users (keys) with items as values
        Items: set of all the items
        user_train: dict, subset of User that are used for training
        user_valid: dict, subset of User that are used for validation
        user_test: dict, subset of User that are used for testing
        col_sep: column separator in the data file
        filename: data filename
        with_time: bool, True when the data file has a third (timestamp) column
    """

    def __init__(self, **kwargs):
        """Initialise an empty dataset.

        Args:
            **kwargs: Optional settings.
                col_sep (str): column separator in the data file, default ``" "``.
                filename (str): path of the data file; when given, its first
                    line is inspected to detect a timestamp column.
        """
        self.usernum = 0
        self.itemnum = 0
        self.User = defaultdict(list)
        self.Items = set()
        self.user_train = {}
        self.user_valid = {}
        self.user_test = {}
        self.col_sep = kwargs.get("col_sep", " ")
        self.filename = kwargs.get("filename", None)
        # Always define the attribute so that split() works even when the
        # filename is only supplied later.
        self.with_time = False
        if self.filename:
            self.with_time = self._has_time_column()

    def _has_time_column(self):
        """Return True when the first line of ``self.filename`` has 3 columns."""
        with open(self.filename, "r") as fr:
            sample = fr.readline()
        # Compare the *number* of columns, not the list of columns itself.
        return len(sample.strip().split(self.col_sep)) == 3

    def _assign_splits(self, user):
        """Fill train/valid/test for one user from ``self.User[user]``.

        Users with fewer than 3 interactions are kept for training only;
        otherwise the last item is the test target, the second to last the
        validation target, and the rest is the training sequence.
        """
        sequence = self.User[user]
        if len(sequence) < 3:
            self.user_train[user] = sequence
            self.user_valid[user] = []
            self.user_test[user] = []
        else:
            self.user_train[user] = sequence[:-2]
            self.user_valid[user] = [sequence[-2]]
            self.user_test[user] = [sequence[-1]]

    def split(self, **kwargs):
        """Read the data file and build the train/validation/test splits.

        Args:
            **kwargs: Optional settings.
                filename (str): path of the data file; overrides the one
                    given at construction time.

        Raises:
            ValueError: if no filename is available.
        """
        previous = self.filename
        self.filename = kwargs.get("filename", self.filename)
        if not self.filename:
            raise ValueError("Filename is required")
        # The file format was detected for the previous filename (if any);
        # re-detect it when a different file is supplied here.
        if self.filename != previous:
            self.with_time = self._has_time_column()

        if self.with_time:
            self.data_partition_with_time()
        else:
            self.data_partition()

    def data_partition(self):
        """Load ``user item`` lines (already in chronological order) and split."""
        # assume user/item index starting from 1
        with open(self.filename, "r") as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # tolerate blank / trailing empty lines
                    continue
                u, i = line.split(self.col_sep)
                u = int(u)
                i = int(i)
                self.usernum = max(u, self.usernum)
                self.itemnum = max(i, self.itemnum)
                self.User[u].append(i)
                self.Items.add(i)

        for user in self.User:
            self._assign_splits(user)

    def data_partition_with_time(self):
        """Load ``user item timestamp`` lines, order each user by time and split."""
        # assume user/item index starting from 1
        with open(self.filename, "r") as f:
            for line in f:
                line = line.rstrip()
                if not line:
                    # tolerate blank / trailing empty lines
                    continue
                u, i, t = line.split(self.col_sep)
                u = int(u)
                i = int(i)
                t = float(t)
                self.usernum = max(u, self.usernum)
                self.itemnum = max(i, self.itemnum)
                self.User[u].append((i, t))
                self.Items.add(i)

        for user in self.User:
            # sort by time (stable, so ties keep file order), keep only the items
            ordered = sorted(self.User[user], key=lambda x: x[1])
            self.User[user] = [item for item, _ in ordered]
            self._assign_splits(user)