Skip to content

Commit 6fcb9d5

Browse files
authored
Merge pull request #71 from matthewsedam/bug-fixes
Bug fixes
2 parents 2cb7821 + b611247 commit 6fcb9d5

File tree

2 files changed

+42
-85
lines changed

2 files changed

+42
-85
lines changed

data_reader/dataset.py

Lines changed: 41 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
import os
22
from sklearn.feature_extraction.text import TfidfVectorizer
3-
from sklearn import preprocessing
4-
import sklearn.utils
53
import scipy
64
from scipy.sparse import csr_matrix
75
import sklearn
@@ -10,7 +8,6 @@
108
import pickle
119
from collections import namedtuple
1210
from copy import deepcopy
13-
from typing import Dict
1411

1512

1613
class Dataset(object):
@@ -50,10 +47,9 @@ class EmailDataset(Dataset):
5047
binary (boolean, optional): Feature type, continuous (False) by default
5148
"""
5249

53-
def __init__(self, path=None, raw=True, features=None, labels=None, norm='l1',
54-
binary=False, strip_accents_=None, ngram_range_=(1, 1),
55-
max_df_=1.0, min_df_=1, max_features_=1000, num_instances=0,
56-
standardization=False):
50+
def __init__(self, path=None, raw=True, features=None, labels=None,
51+
binary=True, strip_accents_=None, ngram_range_=(1, 1),
52+
max_df_=1.0, min_df_=1, max_features_=1000, num_instances=0):
5753
super(EmailDataset, self).__init__()
5854
self.num_instances = num_instances
5955
self.binary = binary
@@ -64,21 +60,15 @@ def __init__(self, path=None, raw=True, features=None, labels=None, norm='l1',
6460
self.labels, self.corpus = self._create_corpus(path)
6561
# Sklearn vectorizer used to fit/transform the corpus into the feature matrix.
6662
# Maybe optionally pass this in as a parameter instead.
67-
# stop words?
6863
self.vectorizer = \
6964
TfidfVectorizer(analyzer='word',
7065
strip_accents=strip_accents_,
7166
ngram_range=ngram_range_, max_df=max_df_,
7267
min_df=min_df_, max_features=max_features_,
73-
binary=self.binary, stop_words='english',
74-
use_idf=True, norm=norm)
68+
binary=False, stop_words='english',
69+
use_idf=True, norm=None)
7570
self.vectorizer = self.vectorizer.fit(self.corpus)
7671
self.features = self.vectorizer.transform(self.corpus)
77-
if standardization:
78-
dense_feature = self.features.todense()
79-
scaler = preprocessing.StandardScaler(with_mean=False)
80-
scaler.fit(dense_feature)
81-
self.features = csr_matrix(scaler.transform(dense_feature))
8272
else:
8373
self.labels, self.features = \
8474
self._load(path, os.path.splitext(path)[1][1:])
@@ -239,26 +229,18 @@ def _csv(self, outfile, save=True):
239229
serialize.writerow(instance)
240230
else:
241231
# TODO: throw exception if FileNotFoundError
242-
if self.binary:
243-
data = np.genfromtxt(outfile, delimiter=',')
244-
self.num_instances = data.shape[0]
245-
labels = data[:, :1]
246-
feats = data[:, 1:]
247-
mask = ~np.isnan(feats)
248-
col = feats[mask]
249-
row = np.concatenate([np.ones_like(x) * i
250-
for i, x in enumerate(feats)])[mask.flatten()]
251-
features = csr_matrix((np.ones_like(col), (row, col)),
252-
shape=(feats.shape[0],
253-
int(np.max(feats[mask])) + 1))
254-
return np.squeeze(labels), features
255-
else:
256-
data = np.genfromtxt(outfile, delimiter=',')
257-
self.num_instances = data.shape[0]
258-
labels = data[:, :1]
259-
feats = data[:, 1:]
260-
features = csr_matrix(feats)
261-
return np.squeeze(labels), features
232+
data = np.genfromtxt(outfile, delimiter=',')
233+
self.num_instances = data.shape[0]
234+
labels = data[:, :1]
235+
feats = data[:, 1:]
236+
mask = ~np.isnan(feats)
237+
col = feats[mask]
238+
row = np.concatenate([np.ones_like(x) * i
239+
for i, x in enumerate(feats)])[mask.flatten()]
240+
features = csr_matrix((np.ones_like(col), (row, col)),
241+
shape=(feats.shape[0],
242+
int(np.max(feats[mask])) + 1))
243+
return np.squeeze(labels), features
262244

263245
def _pickle(self, outfile, save=True):
264246
"""A fast method for saving and loading datasets as python objects.
@@ -320,62 +302,37 @@ def _load(self, path, format='pkl', binary=False):
320302
raise AttributeError('The given load format is not currently \
321303
supported.')
322304

323-
def split(self, fraction=0.5, seed=None, random=True):
305+
def split(self, split={'test': 50, 'train': 50}):
324306
"""Split the dataset into test and train sets using
325307
`sklearn.utils.shuffle()`.
326308
327309
Args:
328-
fraction (float/int, optional): fraction of training data in split
310+
split (Dict, optional): A dictionary specifying the splits between
311+
test and train sets. The values can be floats or ints.
312+
329313
Returns:
330314
trainset, testset (namedtuple, namedtuple): Split tuples containing
331315
share of shuffled data instances.
332316
333317
"""
334-
335-
if isinstance(fraction, Dict):
336-
fraction = fraction['train'] / 100
337-
if fraction < 0:
338-
raise ValueError('Split percentages must be positive values')
339-
if fraction > 1.0:
340-
fraction /= 100
341-
pivot = int(self.__len__() * fraction)
342-
if random:
343-
if seed:
344-
s_feats, s_labels = sklearn.utils.shuffle(self.features, self.labels,
345-
random_state=seed)
346-
else:
347-
s_feats, s_labels = sklearn.utils.shuffle(self.features, self.labels,
348-
random_state=scipy.random.seed())
349-
350-
return (self.__class__(raw=False, features=s_feats[:pivot, :],
351-
labels=s_labels[:pivot], num_instances=pivot,
352-
binary=self.binary),
353-
self.__class__(raw=False, features=s_feats[pivot:, :],
354-
labels=s_labels[pivot:],
355-
num_instances=self.num_instances - pivot,
356-
binary=self.binary))
318+
splits = list(split.values())
319+
for s in splits:
320+
if s < 0:
321+
raise ValueError('Split percentages must be positive values')
322+
# data = self.features.toarray()
323+
frac = 0
324+
if splits[0] < 1.0:
325+
frac = splits[0]
357326
else:
358-
return (self.__class__(raw=False, features=self.features[:pivot, :],
359-
labels=self.labels[:pivot], num_instances=pivot,
360-
binary=self.binary),
361-
self.__class__(raw=False, features=self.features[pivot:, :],
362-
labels=self.labels[pivot:],
363-
num_instances=self.num_instances - pivot,
364-
binary=self.binary))
365-
366-
def report(self):
367-
"""
368-
return a string that contains information about data set shape, size
369-
and positive/negative instance count and percentage
370-
"""
371-
s = "number of instances: {0}\n".format(self.num_instances)
372-
s += "instance feature length: {0}\n".format(self.shape[1])
373-
pos_cnt = self.labels.tolist().count(1)
374-
neg_cnt = self.labels.tolist().count(-1)
375-
s += "positive instance count and percentage: {0}, {1}%\n".format(pos_cnt,
376-
100 * pos_cnt / len(
377-
self.labels))
378-
s += "negative instance count and percentage: {0}, {1}%\n".format(neg_cnt,
379-
100 * neg_cnt / len(
380-
self.labels))
381-
return s
327+
frac = splits[0] / 100
328+
pivot = int(self.__len__() * frac)
329+
s_feats, s_labels = sklearn.utils.shuffle(self.features, self.labels)
330+
return (self.__class__(raw=False, features=s_feats[:pivot, :],
331+
labels=s_labels[:pivot], num_instances=pivot,
332+
binary=self.binary),
333+
self.__class__(raw=False, features=s_feats[pivot:, :],
334+
labels=s_labels[pivot:],
335+
num_instances=self.num_instances - pivot,
336+
binary=self.binary))
337+
# return (self.Data(s_feats[:pivot, :], s_labels[:pivot]),
338+
# self.Data(s_feats[pivot:, :], s_labels[pivot:]))

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def get_requirements():
1111

1212
setup(
1313
name='adlib',
14-
version='1.0.4',
14+
version='1.1.0',
1515
description='Game-theoretic adversarial machine learning library providing '
1616
'a set of learner and adversary modules.',
1717
url='https://github.com/vu-aml/adlib',

0 commit comments

Comments
 (0)