Skip to content

Commit c984552

Browse files
committed
split each file also on pairs
1 parent 39565e9 commit c984552

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

abnet3/dataloader.py

+35-2
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,16 @@ class PairsDataLoader(OriginalDataLoader):
357357
This dataloader takes a pair file as argument (instead of a cluster
358358
file like the other dataloaders)
359359
"""
360+
SPLIT_FILES = "files"
361+
SPLIT_EACH_FILE = "split_each_file"
362+
SPLIT_METHODS = [SPLIT_FILES, SPLIT_EACH_FILE]
360363

361364
def __init__(self, pairs_path, features_path, id_to_file,
362365
ratio_split_train_test=0.7,
363366
batch_size=8, train_iterations=10000, test_iterations=500,
364367
proportion_positive_pairs=0.5,
365-
align_different_words=True):
368+
align_different_words=True,
369+
split_method=SPLIT_EACH_FILE):
366370
self.pairs_path = pairs_path
367371
self.features_path = features_path
368372
self.features = None # type: Features_Accessor
@@ -373,6 +377,8 @@ def __init__(self, pairs_path, features_path, id_to_file,
373377
self.align_different_words = align_different_words
374378
self.iterations = {'train': train_iterations, 'test': test_iterations}
375379
self.proportion_positive_pairs = proportion_positive_pairs
380+
self.split_method = split_method
381+
assert split_method in self.SPLIT_METHODS
376382
self.tokens = {'train': [], 'test': []}
377383
self.statistics_training = defaultdict(int)
378384
self.files = set()
@@ -444,7 +450,10 @@ def load_pairs(self):
444450
self.files.add(file2)
445451
pairs.append(
446452
[file1, begin1, end1, file2, begin2, end2])
447-
self.pairs['train'], self.pairs['test'] = self.split_train_test(pairs)
453+
if self.split_method == self.SPLIT_FILES:
454+
self.pairs['train'], self.pairs['test'] = self.split_train_test(pairs)
455+
elif self.split_method == self.SPLIT_EACH_FILE:
456+
self.pairs['train'], self.pairs['test'] = self.split_train_test_each_file(pairs)
448457
tokens = {'train': set(), 'test': set()}
449458
for mode in ('train', 'test'):
450459
for file1, begin1, end1, file2, begin2, end2 in self.pairs[mode]:
@@ -471,6 +480,30 @@ def split_train_test(self, pairs):
471480

472481
return train_pairs, dev_pairs
473482

483+
def split_train_test_each_file(self, pairs):
484+
# fill len of each file
485+
len_files = defaultdict(int)
486+
for p in pairs:
487+
file1, s1, e1, file2, s2, e2 = p
488+
len_files[file1] = max(len_files[file1], e1)
489+
len_files[file2] = max(len_files[file2], e2)
490+
print(len_files)
491+
492+
# split on length
493+
train_threshold = {}
494+
for file in len_files:
495+
train_threshold[file] = len_files[file] * self.ratio_split_train_test
496+
print(train_threshold)
497+
# split clusters
498+
train_pairs, dev_pairs = [], []
499+
for p in pairs:
500+
file1, s1, e1, file2, s2, e2 = p
501+
if s1 > train_threshold[file1] and s2 > train_threshold[file2]:
502+
dev_pairs.append(p)
503+
elif s1 < train_threshold[file1] and s2 <= train_threshold[file2]:
504+
train_pairs.append(p)
505+
return train_pairs, dev_pairs
506+
474507
def batch_iterator(self, train_mode=True):
475508
print("constructing batches")
476509
mode = 'train' if train_mode else 'test'

0 commit comments

Comments
 (0)