# Copyright (c) Recommenders contributors.
# Licensed under the MIT License.
# The original code is from
# https://github.com/kang205/SASRec/blob/master/sampler.py
import numpy as np
from multiprocessing import Process, Queue
def random_neq(left, right, s):
    """Draw a random integer from ``[left, right)`` that is not contained in ``s``.

    Candidates are drawn uniformly with ``np.random.randint`` and rejected
    until one is found that is not a member of ``s``.

    Args:
        left (int): inclusive lower bound of the candidate range
        right (int): exclusive upper bound of the candidate range
        s (set): values that must not be returned; must support ``in`` and ``len``

    Returns:
        int: a value ``t`` with ``left <= t < right`` and ``t not in s``

    Raises:
        ValueError: if the candidate range is empty or every integer in
            ``[left, right)`` is already contained in ``s``
    """
    # Rejection sampling would spin forever if no admissible value exists.
    # The length comparison is a cheap necessary condition, so the full scan
    # only runs when exhaustion is actually possible.
    if len(s) >= right - left and all(v in s for v in range(left, right)):
        raise ValueError(
            "no integer in [{}, {}) is available outside the excluded set".format(
                left, right
            )
        )

    t = np.random.randint(left, right)
    while t in s:
        t = np.random.randint(left, right)
    return t
def sample_function(
    user_train, usernum, itemnum, batch_size, maxlen, result_queue, seed
):
    """Batch sampler that creates a sequence of negative items based on the
    original sequence of items (positive) that the user has interacted with.

    Runs forever: each iteration builds ``batch_size`` samples and puts them
    on ``result_queue`` as ``zip(*samples)``, i.e. grouped per field as
    ``(users, seqs, positives, negatives)``.

    Args:
        user_train (dict): dictionary of training examples for each user,
            keyed by user id, with the user's item sequence as value
        usernum (int): number of users; user ids are drawn from ``1..usernum``
        itemnum (int): number of items; negatives are drawn from ``1..itemnum``
        batch_size (int): batch size
        maxlen (int): maximum input sequence length
        result_queue (multiprocessing.Queue): queue for storing sample results
        seed (int): seed for random generator
    """

    def sample():
        """Build one ``(user, seq, pos, neg)`` training sample."""
        # Redraw until the user has at least two interactions: one input
        # item and one target item are needed to form a training pair.
        user = np.random.randint(1, usernum + 1)
        while len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum + 1)

        # Fixed-length arrays, zero-initialised; untouched slots stay 0.
        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen], dtype=np.int32)
        nxt = user_train[user][-1]
        idx = maxlen - 1

        # Items the user has already seen are excluded from negative sampling.
        ts = set(user_train[user])

        # Walk the history backwards and fill the arrays from the right, so
        # the most recent items are kept when the history exceeds maxlen.
        for i in reversed(user_train[user][:-1]):
            seq[idx] = i
            pos[idx] = nxt
            if nxt != 0:
                neg[idx] = random_neq(1, itemnum + 1, ts)
            nxt = i
            idx -= 1
            if idx == -1:
                break

        return (user, seq, pos, neg)

    np.random.seed(seed)
    while True:
        one_batch = [sample() for _ in range(batch_size)]
        # Transpose the per-sample tuples into per-field groups.
        result_queue.put(zip(*one_batch))
class WarpSampler(object):
    """Sampler object that creates an iterator for feeding batch data while training.

    On construction, ``n_workers`` daemon processes are started; each one runs
    ``sample_function`` and pushes batches onto a shared bounded queue.

    Args:
        User (dict): all the users (keys) with items as values
        usernum (int): total number of users
        itemnum (int): total number of items
        batch_size (int): batch size
        maxlen (int): maximum input sequence length
        n_workers (int): number of workers for parallel execution

    Attributes:
        result_queue (multiprocessing.Queue): queue the workers put batches on,
            created with ``maxsize=n_workers * 10``
        processors (list): the started worker ``Process`` objects
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            self.processors.append(
                Process(
                    target=sample_function,
                    args=(
                        User,
                        usernum,
                        itemnum,
                        batch_size,
                        maxlen,
                        self.result_queue,
                        # Each worker gets its own seed drawn from the parent's
                        # RNG; the bound is passed as an int, not a float.
                        np.random.randint(2000000000),
                    ),
                )
            )
            # Daemon workers are terminated automatically when the parent exits.
            self.processors[-1].daemon = True
            self.processors[-1].start()

    def next_batch(self):
        """Return the next batch from the queue, blocking until one is available."""
        return self.result_queue.get()

    def close(self):
        """Terminate every worker process and wait for it to exit."""
        for p in self.processors:
            p.terminate()
            p.join()