
Commit b1be106

Feat: TM2LGCN_modularization #62
1 parent f878072 commit b1be106

6 files changed (+545, -0)

model/TM2LGCN/args.py (+28)
@@ -0,0 +1,28 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda", type=str, help="cpu or cuda")

    # TM args
    parser.add_argument("--num_topics", default=24, type=int, help="number of topics")
    parser.add_argument("--random_state", default=42, type=int, help="LDA model random_state")
    parser.add_argument("--passes", default=20, type=int, help="LDA model passes")

    # model
    parser.add_argument("--emb_dim", default=24, type=int, help="hidden dimension size")
    # type must be float: "1e-5" cannot be parsed as an int from the command line
    parser.add_argument("--reg", default=1e-5, type=float, help="regularization")
    parser.add_argument("--n_layers", default=2, type=int, help="number of layers")
    parser.add_argument("--node_dropout", default=0.2, type=float, help="dropout rate")
    parser.add_argument("--valid_samples", default=2, type=int, help="valid samples per user")

    # train
    parser.add_argument("--seed", default=22, type=int, help="seed")
    parser.add_argument("--num_epochs", default=150, type=int, help="number of epochs")
    parser.add_argument("--batch_size", default=64, type=int, help="batch size")
    parser.add_argument("--lr", default=0.0001, type=float, help="learning rate")
    parser.add_argument("--n_batch", default=10, type=int, help="number of batches per epoch")

    args = parser.parse_args()

    return args
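
One detail worth flagging here: argparse applies type only to strings coming from the command line, never to the default value, so the original type=int on --reg was silent until someone actually passed --reg 1e-5, at which point parsing fails. A minimal standalone demonstration (illustrative values only):

import argparse

parser = argparse.ArgumentParser()
# With type=int, parse_args(["--reg", "1e-5"]) exits with
# "error: argument --reg: invalid int value: '1e-5'"
parser.add_argument("--reg", default=1e-5, type=float, help="regularization")

args = parser.parse_args(["--reg", "1e-5"])
print(args.reg)  # 1e-05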

model/TM2LGCN/dataloader.py (+56)
@@ -0,0 +1,56 @@
from google.oauth2 import service_account
from google.cloud import storage
import pickle

from datetime import datetime
import pandas as pd
import numpy as np
from collections import defaultdict


def preprocess(df):
    df = df[df['uri_first'] == 1].copy()
    df['timestamp'] = pd.to_datetime(df['local_time']).astype(int) // 10**9
    df = df[['hashed_ip', 'products', 'timestamp']]

    df['user'] = df['hashed_ip']
    df['item'] = df['products']
    df['time'] = df['timestamp']

    # sort_values returns a new frame; the result must be reassigned
    df = df.sort_values(['user', 'timestamp'])

    del df['hashed_ip'], df['products'], df['timestamp']

    # keep only users with at least 5 interactions
    user_interaction_counts = df['user'].value_counts()
    selected_users = user_interaction_counts[user_interaction_counts >= 5].index
    df = df[df['user'].isin(selected_users)]

    return df


def load_data():
    # load the item2idx pickle and the interaction CSV from GCS
    SERVICE_ACCOUNT_FILE = "/home/user/TM2LGCN/storage/level3-416207-893f91c9529e_api.json"
    credentials = service_account.Credentials.from_service_account_file(SERVICE_ACCOUNT_FILE)
    project_id = "level3-416207"
    storage_client = storage.Client(credentials=credentials, project=project_id)
    bucket_name = 'crwalnoti'
    bucket = storage_client.bucket(bucket_name)

    item2idx_name = '240320/item_to_idx.pickle'
    inter_name = '240320/inter_240129.csv'

    # prepare item2idx
    blob_item2idx = bucket.blob(item2idx_name)
    with blob_item2idx.open(mode='rb') as f:
        item2idx = pickle.load(f)

    # prepare interaction_df
    blob_inter = bucket.blob(inter_name)
    with blob_inter.open(mode='rb') as f:
        interaction_df = pd.read_csv(f)

    interaction_df = preprocess(interaction_df)

    return item2idx, interaction_df
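
A quick synthetic sanity check of the preprocess() contract, using a made-up frame (column names follow the code above; the values are invented): rows are filtered to uri_first == 1, columns are renamed to user/item/time, and users with fewer than 5 interactions are dropped.

import pandas as pd
from dataloader import preprocess

# Two fake users with 5 first-URI events each
raw = pd.DataFrame({
    'uri_first': [1] * 10,
    'hashed_ip': ['u1'] * 5 + ['u2'] * 5,
    'products': ['a', 'b', 'c', 'd', 'e'] * 2,
    'local_time': ['2024-01-29 00:00:00'] * 10,
})

df = preprocess(raw)
assert list(df.columns) == ['user', 'item', 'time']
assert set(df['user']) == {'u1', 'u2'}  # both users clear the >= 5 threshold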

model/TM2LGCN/dataset.py (+191)
@@ -0,0 +1,191 @@
from dataloader import load_data

import gensim
from gensim.corpora import Dictionary
from collections import defaultdict

import pickle
import scipy.sparse as sp
import numpy as np
import random
import torch


class MakeTMDataSet():
    def __init__(self):
        self.item2idx, self.df = load_data()

        self.df["item_idx"] = self.df["item"].map(self.item2idx)
        self.df['item_name'] = self.df['item'].map(self.item2name())

        # note: inter_dict and df must keep the same user order (sort=False preserves first appearance)
        self.inter_dict = self.df.groupby('user', sort=False)['item_name'].apply(set).apply(list).to_dict()
        self.user_ids = list(self.inter_dict.keys())
        self.user2idx = {user_id: index for index, user_id in enumerate(self.user_ids)}

        self.df["user_idx"] = self.df["user"].map(self.user2idx)

        self.num_item, self.num_user = len(self.item2idx), len(self.user2idx)

        self.dictionary, self.corpus = self.TM_traindata()

    def item2name(self):
        with open('/home/user/pickle/product_info_df.pickle', 'rb') as fr:
            product_info = pickle.load(fr)

        # normalize product titles: strip quotes/commas, replace parens with spaces,
        # lowercase, and collapse runs of whitespace
        product_data = product_info.copy()
        product_data['title'] = product_data['title'].map(
            lambda x: x.replace("'", '').replace(',', '').replace('(', ' ').replace(')', ' '))
        product_data['title'] = product_data['title'].map(lambda x: x.lower())
        product_data['title'] = product_data['title'].map(lambda x: ' '.join(x.split()))

        dict_products = product_data[['id', 'title']].set_index('id').to_dict()['title']

        return dict_products

    def TM_traindata(self):
        documents = list(self.inter_dict.values())
        dictionary = Dictionary(documents)
        corpus = [dictionary.doc2bow(document) for document in documents]
        return dictionary, corpus

    def get_dictionary(self):
        return self.dictionary

    def get_corpus(self):
        return self.corpus


class MakeLightGCNDataSet():
    def __init__(self, TM_dataset, lda_model, args):
        self.args = args
        self.TM_dataset = TM_dataset
        self.lda_model = lda_model

        self.df = self.TM_dataset.df
        self.user2idx = self.TM_dataset.user2idx
        self.item2idx = self.TM_dataset.item2idx
        self.num_user, self.num_item = self.TM_dataset.num_user, self.TM_dataset.num_item

        self.exist_users = [i for i in range(self.num_user)]
        self.exist_items = [i for i in range(self.num_item)]

        self.user_train, self.user_valid = self.generate_sequence_data()
        self.R_train, self.R_valid, self.R_total = self.generate_dok_matrix()
        self.ngcf_adj_matrix = self.generate_ngcf_adj_matrix()

        self.user_topic_tensor = self.get_TM_user_vector()

        self.n_train = len(self.R_train)  # number of stored (user, item) train interactions
        self.batch_size = self.args.batch_size

    def generate_sequence_data(self):
        """
        Split each user's interactions into train/valid.
        Duplicate items are allowed.
        """
        users = defaultdict(list)
        user_train = {}
        user_valid = {}
        for user, item, time in zip(self.df['user_idx'], self.df['item_idx'], self.df['time']):
            users[user].append(item)

        # random.sample is driven by the random module, so seed it here;
        # np.random.seed has no effect on it
        random.seed(self.args.seed)
        for user in users:
            user_total = users[user]
            valid_indices = random.sample(range(len(user_total)), self.args.valid_samples)
            valid = [user_total[idx] for idx in valid_indices]
            train = [user_total[idx] for idx in range(len(user_total)) if idx not in valid_indices]
            user_train[user] = train
            user_valid[user] = valid

        return user_train, user_valid

    def generate_dok_matrix(self):
        R_train = sp.dok_matrix((self.num_user, self.num_item), dtype=np.float32)
        R_valid = sp.dok_matrix((self.num_user, self.num_item), dtype=np.float32)
        R_total = sp.dok_matrix((self.num_user, self.num_item), dtype=np.float32)
        user_list = self.exist_users  # the index values stored in user2idx
        for user in user_list:
            train_items = self.user_train[user]
            valid_items = self.user_valid[user]

            for train_item in train_items:
                R_train[user, train_item] = 1.0
                R_total[user, train_item] = 1.0

            for valid_item in valid_items:
                R_valid[user, valid_item] = 1.0
                R_total[user, valid_item] = 1.0

        return R_train, R_valid, R_total

    def generate_ngcf_adj_matrix(self):
        adj_mat = sp.dok_matrix((self.num_user + self.num_item, self.num_user + self.num_item), dtype=np.float32)
        adj_mat = adj_mat.tolil()  # LIL supports efficient block assignment
        R = self.R_train.tolil()

        # bipartite adjacency: user rows on top, item rows below
        adj_mat[:self.num_user, self.num_user:] = R
        adj_mat[self.num_user:, :self.num_user] = R.T
        adj_mat = adj_mat.todok()

        def normalized_adj_single(adj):
            # symmetric normalization: D^{-1/2} A D^{-1/2}
            rowsum = np.array(adj.sum(1))
            d_inv = np.power(rowsum, -.5).flatten()
            d_inv[np.isinf(d_inv)] = 0.  # zero-degree nodes
            d_mat_inv = sp.diags(d_inv)
            norm_adj = d_mat_inv.dot(adj).dot(d_mat_inv)

            return norm_adj.tocoo()

        ngcf_adj_matrix = normalized_adj_single(adj_mat)
        return ngcf_adj_matrix.tocsr()

    def get_TM_user_vector(self):
        # one topic-distribution row per user, in user_idx order
        user_topic_matrix = np.zeros((self.num_user, self.args.num_topics))
        corpus = self.TM_dataset.get_corpus()

        user_topic_vectors = [self.lda_model.get_document_topics(bow, minimum_probability=0.0)
                              for bow in corpus]
        for i, user_vec in enumerate(user_topic_vectors):
            # i: user idx, user_vec: list of (topic, prob) pairs
            for topic, prob in user_vec:
                user_topic_matrix[i, topic] = prob

        # numpy array --> torch tensor
        user_topic_tensor = torch.tensor(user_topic_matrix, dtype=torch.float32)

        return user_topic_tensor

    def sampling(self):
        # BPR-style sampling: one positive and one negative item per sampled user
        users = random.sample(self.exist_users, self.args.batch_size)

        def sample_pos_items_for_u(u, num):
            pos_items = self.user_train[u]
            pos_batch = random.sample(pos_items, num)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            neg_items = list(set(self.exist_items) - set(self.user_train[u]))
            neg_batch = random.sample(neg_items, num)
            return neg_batch

        pos_items, neg_items = [], []
        for user in users:
            pos_items += sample_pos_items_for_u(user, 1)
            neg_items += sample_neg_items_for_u(user, 1)

        return users, pos_items, neg_items

    def get_train_valid_data(self):
        return self.user_train, self.user_valid

    def get_R_data(self):
        return self.R_train, self.R_valid, self.R_total

    def get_ngcf_adj_matrix_data(self):
        return self.ngcf_adj_matrix
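
As a tiny worked illustration of the symmetric normalization in generate_ngcf_adj_matrix() above, D^{-1/2} A D^{-1/2} on an invented 2-user, 3-item graph (shapes and interactions are made up, not from the real data):

import numpy as np
import scipy.sparse as sp

# Made-up ratings: user 0 -> items 0, 1; user 1 -> item 2
R = sp.dok_matrix((2, 3), dtype=np.float32)
R[0, 0] = R[0, 1] = R[1, 2] = 1.0
R = R.tolil()

A = sp.lil_matrix((5, 5), dtype=np.float32)  # (2 users + 3 items) square
A[:2, 2:] = R       # user -> item block
A[2:, :2] = R.T     # item -> user block

rowsum = np.array(A.sum(1)).flatten()  # node degrees: [2, 1, 1, 1, 1]
d_inv = np.power(rowsum, -0.5)         # all degrees > 0 here, so no inf to zero out
norm_A = sp.diags(d_inv).dot(A.tocsr()).dot(sp.diags(d_inv))

# Edge (user 0, item 0) is scaled by 1/sqrt(deg(u0) * deg(i0)) = 1/sqrt(2)
print(norm_A.toarray().round(3))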

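And since get_TM_user_vector() depends on the shape of gensim's output, a small self-contained demo of get_document_topics(bow, minimum_probability=0.0) on a toy corpus (the documents are invented; they stand in for the per-user item-name lists):

from gensim.corpora import Dictionary
from gensim.models import LdaModel

documents = [['shoes', 'socks'], ['socks', 'hat'], ['hat', 'shoes']]
dictionary = Dictionary(documents)
corpus = [dictionary.doc2bow(doc) for doc in documents]

lda = LdaModel(corpus=corpus, id2word=dictionary, num_topics=2, random_state=42)

# minimum_probability=0.0 forces every topic to appear, so each row of
# user_topic_matrix gets a full distribution over num_topics entries
print(lda.get_document_topics(corpus[0], minimum_probability=0.0))
# e.g. [(0, 0.73...), (1, 0.26...)]
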
model/TM2LGCN/main.py (+73)
@@ -0,0 +1,73 @@
from args import parse_args
from dataset import MakeTMDataSet, MakeLightGCNDataSet
from gensim.models import LdaModel
from model import LightGCN
from trainer import train, evaluate
from tqdm import tqdm

import torch
import os
import mlflow
import mlflow.pytorch


def main(args):

    print('----------------------Load TM Data & Make TM Dataset----------------------')
    TM_dataset = MakeTMDataSet()
    dictionary, corpus = TM_dataset.get_dictionary(), TM_dataset.get_corpus()
    print('Done.')

    print('----------------------Load & Train TM Model----------------------')
    print('...')
    lda_model = LdaModel(corpus=corpus, id2word=dictionary,
                         num_topics=args.num_topics,
                         random_state=args.random_state,
                         passes=args.passes)
    print('Done.')

    print('----------------------Make LGCN_dataset & LGCN_model----------------------')
    lightgcn_dataset = MakeLightGCNDataSet(TM_dataset, lda_model, args)
    ngcf_adj_matrix = lightgcn_dataset.get_ngcf_adj_matrix_data()
    R_train, R_valid, R_total = lightgcn_dataset.get_R_data()

    args.device = "cuda" if torch.cuda.is_available() else "cpu"

    model = LightGCN(
        n_users=lightgcn_dataset.num_user,
        n_items=lightgcn_dataset.num_item,
        args=args,
        adj_mtx=ngcf_adj_matrix,
        user_topic_tensor=lightgcn_dataset.user_topic_tensor,
    ).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    print('Done.')

    print('----------------------Training----------------------')
    best_hit = 0
    for epoch in range(1, args.num_epochs + 1):
        tbar = tqdm(range(1))
        for _ in tbar:
            train_loss = train(
                model=model,
                make_graph_data_set=lightgcn_dataset,
                optimizer=optimizer,
                n_batch=args.n_batch,
            )
            with torch.no_grad():
                ndcg, hit = evaluate(
                    u_emb=model.u_emb.detach(),
                    i_emb=model.i_emb.detach(),
                    Rtr=R_train,
                    Rte=R_valid,
                    args=args,
                    k=10,
                )
            # if best_hit < hit:
            #     best_hit = hit
            #     torch.save(model.state_dict(), os.path.join(args.model_path, args.model_name))
            tbar.set_description(f'Epoch: {epoch:3d}| Train loss: {train_loss:.5f}| NDCG@10: {ndcg:.5f}| HIT@10: {hit:.5f}')


if __name__ == "__main__":
    args = parse_args()
    main(args)
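
trainer.py is not included in this excerpt, so the bodies of train() and evaluate() are not visible; the call above suggests the usual LightGCN full-ranking evaluation. A hedged sketch of what a HIT@k of that shape could look like (function name and internals are assumptions, not the repo's code):

import numpy as np

def hit_at_k(u_emb, i_emb, Rtr, Rte, k=10):
    # Score every item for every user, hide items already seen in
    # training, then count users with at least one validation hit in top-k.
    scores = u_emb.cpu().numpy() @ i_emb.cpu().numpy().T  # (n_users, n_items)
    scores[Rtr.nonzero()] = -np.inf
    topk = np.argsort(-scores, axis=1)[:, :k]

    Rte = Rte.tocsr()
    hits = [len(set(topk[u]) & set(Rte[u].indices)) > 0
            for u in range(scores.shape[0])]
    return float(np.mean(hits))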
