1
+ from dataloader import load_data
2
+
3
+ import gensim
4
+ from gensim .corpora import Dictionary
5
+ from collections import defaultdict
6
+
7
+ import pickle
8
+ import scipy .sparse as sp
9
+ import numpy as np
10
+ import random
11
+ import torch
12
+
13
+
14
class MakeTMDataSet():
    """Builds topic-model (LDA) training data from the interaction log.

    Loads the user/item interaction dataframe via ``load_data``, maps items
    to human-readable (normalized) product titles, groups each user's items
    into one "document", and produces a gensim ``Dictionary`` plus BoW corpus.
    """

    def __init__(self):
        self.item2idx, self.df = load_data()

        self.df["item_idx"] = self.df["item"].map(self.item2idx)
        self.df['item_name'] = self.df['item'].map(self.item2name())

        # NOTE: inter_dict and df must keep the same user ordering, hence
        # sort=False — user2idx below is derived from this insertion order.
        self.inter_dict = self.df.groupby('user', sort=False)['item_name'].apply(set).apply(list).to_dict()
        self.user_ids = list(self.inter_dict.keys())
        self.user2idx = {user_id: index for index, user_id in enumerate(self.user_ids)}

        self.df["user_idx"] = self.df["user"].map(self.user2idx)

        self.num_item, self.num_user = len(self.item2idx), len(self.user2idx)

        self.dictionary, self.corpus = self.TM_traindata()

    def item2name(self):
        """Return a {product id: normalized title} mapping.

        Titles are normalized in a single pass: quotes and commas removed,
        parentheses turned into spaces, lowercased, and whitespace runs
        collapsed. (Equivalent to the previous chain of five ``.map`` calls:
        replace → lower → split(' ') → join+split → join.)
        """
        # NOTE(review): hard-coded absolute path — consider making this a
        # parameter or config value.
        with open('/home/user/pickle/product_info_df.pickle', 'rb') as fr:
            product_info = pickle.load(fr)

        product_data = product_info.copy()
        product_data['title'] = product_data['title'].map(
            lambda x: ' '.join(
                x.replace("'", '')
                 .replace(',', '')
                 .replace('(', ' ')
                 .replace(')', ' ')
                 .lower()
                 .split()
            )
        )

        dict_products = product_data[['id', 'title']].set_index('id').to_dict()['title']

        return dict_products

    def TM_traindata(self):
        """Build the gensim Dictionary and BoW corpus (one document per user)."""
        documents = list(self.inter_dict.values())
        dictionary = Dictionary(documents)
        corpus = [dictionary.doc2bow(document) for document in documents]
        return dictionary, corpus

    def get_dictionary(self):
        """Return the gensim Dictionary built from user documents."""
        return self.dictionary

    def get_corpus(self):
        """Return the BoW corpus (one bag-of-words per user)."""
        return self.corpus
58
+
59
+
60
class MakeLightGCNDataSet():
    """Builds LightGCN training data on top of a ``MakeTMDataSet``.

    Produces the train/valid interaction split, the sparse user-item
    matrices, the symmetrically-normalized NGCF adjacency matrix, and a
    per-user topic-distribution tensor from a trained LDA model.
    """

    def __init__(self, TM_dataset, lda_model, args):
        self.args = args
        self.TM_dataset = TM_dataset
        self.lda_model = lda_model

        self.df = self.TM_dataset.df
        self.user2idx = self.TM_dataset.user2idx
        self.item2idx = self.TM_dataset.item2idx
        self.num_user, self.num_item = self.TM_dataset.num_user, self.TM_dataset.num_item

        self.exist_users = list(range(self.num_user))
        self.exist_items = list(range(self.num_item))

        self.user_train, self.user_valid = self.generate_sequence_data()
        self.R_train, self.R_valid, self.R_total = self.generate_dok_matrix()
        self.ngcf_adj_matrix = self.generate_ngcf_adj_matrix()

        self.user_topic_tensor = self.get_TM_user_vector()

        # Number of observed training interactions (nnz of the train matrix).
        self.n_train = self.R_train.nnz
        self.batch_size = self.args.batch_size

    def generate_sequence_data(self) -> tuple:
        """Split each user's interactions into train/valid dicts.

        Two interactions per user are held out for validation; duplicates in
        the log are kept. Assumes every user has at least 2 interactions.

        Returns:
            (user_train, user_valid): {user_idx: [item_idx, ...]} dicts.

        BUG FIX: the original seeded ``np.random`` (inside the loop) but
        sampled with the stdlib ``random`` module, so the seed had no effect
        and the split was not reproducible. A dedicated ``random.Random``
        seeded once with ``args.seed`` is used instead.
        """
        users = defaultdict(list)
        user_train = {}
        user_valid = {}
        for user, item, time in zip(self.df['user_idx'], self.df['item_idx'], self.df['time']):
            users[user].append(item)

        rng = random.Random(self.args.seed)  # seed once, not per user
        for user in users:
            user_total = users[user]
            valid_indices = rng.sample(range(len(user_total)), 2)
            held_out = set(valid_indices)
            user_valid[user] = [user_total[idx] for idx in valid_indices]
            user_train[user] = [user_total[idx] for idx in range(len(user_total))
                                if idx not in held_out]

        return user_train, user_valid

    def generate_dok_matrix(self):
        """Build sparse (num_user x num_item) interaction matrices.

        Returns:
            (R_train, R_valid, R_total): dok matrices with 1.0 at observed
            train / valid / all interactions respectively.
        """
        R_train = sp.dok_matrix((self.num_user, self.num_item), dtype=np.float32)
        R_valid = sp.dok_matrix((self.num_user, self.num_item), dtype=np.float32)
        R_total = sp.dok_matrix((self.num_user, self.num_item), dtype=np.float32)
        user_list = self.exist_users  # the values of user2idx
        for user in user_list:
            train_items = self.user_train[user]
            valid_items = self.user_valid[user]

            for train_item in train_items:
                R_train[user, train_item] = 1.0
                R_total[user, train_item] = 1.0

            for valid_item in valid_items:
                R_valid[user, valid_item] = 1.0
                R_total[user, valid_item] = 1.0

        return R_train, R_valid, R_total

    def generate_ngcf_adj_matrix(self):
        """Build the symmetrically-normalized bipartite adjacency matrix.

        A = [[0, R], [R^T, 0]] over (num_user + num_item) nodes, normalized
        as D^{-1/2} A D^{-1/2}. Returned in CSR format.
        """
        adj_mat = sp.dok_matrix((self.num_user + self.num_item, self.num_user + self.num_item), dtype=np.float32)
        adj_mat = adj_mat.tolil()
        R = self.R_train.tolil()

        adj_mat[:self.num_user, self.num_user:] = R
        adj_mat[self.num_user:, :self.num_user] = R.T
        adj_mat = adj_mat.todok()

        def normalized_adj_single(adj):
            # D^{-1/2} A D^{-1/2}; zero-degree nodes get 0 instead of inf.
            rowsum = np.array(adj.sum(1))
            with np.errstate(divide='ignore'):  # 0 ** -0.5 -> inf is expected
                d_inv = np.power(rowsum, -.5).flatten()
            d_inv[np.isinf(d_inv)] = 0.
            d_mat_inv = sp.diags(d_inv)
            norm_adj = d_mat_inv.dot(adj).dot(d_mat_inv)

            return norm_adj.tocoo()

        ngcf_adj_matrix = normalized_adj_single(adj_mat)
        return ngcf_adj_matrix.tocsr()

    def get_TM_user_vector(self):
        """Return a (num_user, num_topics) float32 tensor of topic probabilities.

        Row i is the LDA topic distribution of user i's document; corpus
        order must match user_idx order (both come from TM_dataset).
        """
        user_topic_matrix = np.zeros((self.num_user, self.args.num_topics))
        corpus = self.TM_dataset.get_corpus()

        user_topic_vectors = [self.lda_model.get_document_topics(bow, minimum_probability=0.0)
                              for bow in corpus]
        for i, user_vec in enumerate(user_topic_vectors):
            # i: user idx; user_vec: list of (topic, prob) pairs
            for topic, prob in user_vec:
                user_topic_matrix[i, topic] = prob

        # numpy array --> torch tensor
        user_topic_tensor = torch.tensor(user_topic_matrix, dtype=torch.float32)

        return user_topic_tensor

    def sampling(self):
        """Sample a BPR-style training batch.

        Returns:
            (users, pos_items, neg_items): batch_size users, each paired with
            one positive (train) item and one negative item.

        NOTE(review): negatives exclude only the user's *train* items, so a
        held-out valid item can be drawn as a negative — confirm intended.
        """
        users = random.sample(self.exist_users, self.args.batch_size)

        def sample_pos_items_for_u(u, num):
            # num items the user interacted with in the train split
            pos_items = self.user_train[u]
            pos_batch = random.sample(pos_items, num)
            return pos_batch

        def sample_neg_items_for_u(u, num):
            # num items outside the user's train interactions
            neg_items = list(set(self.exist_items) - set(self.user_train[u]))
            neg_batch = random.sample(neg_items, num)
            return neg_batch

        pos_items, neg_items = [], []
        for user in users:
            pos_items += sample_pos_items_for_u(user, 1)
            neg_items += sample_neg_items_for_u(user, 1)

        return users, pos_items, neg_items

    def get_train_valid_data(self):
        """Return the (user_train, user_valid) split dicts."""
        return self.user_train, self.user_valid

    def get_R_data(self):
        """Return the (R_train, R_valid, R_total) sparse matrices."""
        return self.R_train, self.R_valid, self.R_total

    def get_ngcf_adj_matrix_data(self):
        """Return the normalized adjacency matrix (CSR)."""
        return self.ngcf_adj_matrix
0 commit comments