@@ -357,12 +357,16 @@ class PairsDataLoader(OriginalDataLoader):
    This dataloader takes a pair file as argument (instead of a cluster
    file like the other dataloaders)
    """
+    SPLIT_FILES = "files"
+    SPLIT_EACH_FILE = "split_each_file"
+    SPLIT_METHODS = [SPLIT_FILES, SPLIT_EACH_FILE]

    def __init__(self, pairs_path, features_path, id_to_file,
                 ratio_split_train_test=0.7,
                 batch_size=8, train_iterations=10000, test_iterations=500,
                 proportion_positive_pairs=0.5,
-                 align_different_words=True):
+                 align_different_words=True,
+                 split_method=SPLIT_EACH_FILE):
        self.pairs_path = pairs_path
        self.features_path = features_path
        self.features = None  # type: Features_Accessor
@@ -373,6 +377,8 @@ def __init__(self, pairs_path, features_path, id_to_file,
        self.align_different_words = align_different_words
        self.iterations = {'train': train_iterations, 'test': test_iterations}
        self.proportion_positive_pairs = proportion_positive_pairs
+        self.split_method = split_method
+        assert split_method in self.SPLIT_METHODS
        self.tokens = {'train': [], 'test': []}
        self.statistics_training = defaultdict(int)
        self.files = set()
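
Taken together, the two hunks above add a class-level list of split strategies and a split_method keyword argument that is validated in __init__. A minimal usage sketch follows; the file paths and the id_to_file mapping are hypothetical placeholders, not values from the repository:

# Sketch only: paths and id_to_file below are illustrative placeholders.
loader = PairsDataLoader(
    pairs_path="data/pairs.txt",
    features_path="data/features.h5",
    id_to_file={"0": "file_a", "1": "file_b"},
    split_method=PairsDataLoader.SPLIT_EACH_FILE,  # or PairsDataLoader.SPLIT_FILES
)
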
@@ -444,7 +450,10 @@ def load_pairs(self):
            self.files.add(file2)
            pairs.append(
                [file1, begin1, end1, file2, begin2, end2])
-        self.pairs['train'], self.pairs['test'] = self.split_train_test(pairs)
+        if self.split_method == self.SPLIT_FILES:
+            self.pairs['train'], self.pairs['test'] = self.split_train_test(pairs)
+        elif self.split_method == self.SPLIT_EACH_FILE:
+            self.pairs['train'], self.pairs['test'] = self.split_train_test_each_file(pairs)
        tokens = {'train': set(), 'test': set()}
        for mode in ('train', 'test'):
            for file1, begin1, end1, file2, begin2, end2 in self.pairs[mode]:
@@ -471,6 +480,30 @@ def split_train_test(self, pairs):

        return train_pairs, dev_pairs

+    def split_train_test_each_file(self, pairs):
+        # fill len of each file
+        len_files = defaultdict(int)
+        for p in pairs:
+            file1, s1, e1, file2, s2, e2 = p
+            len_files[file1] = max(len_files[file1], e1)
+            len_files[file2] = max(len_files[file2], e2)
+        print(len_files)
+
+        # split on length
+        train_threshold = {}
+        for file in len_files:
+            train_threshold[file] = len_files[file] * self.ratio_split_train_test
+        print(train_threshold)
+        # split pairs on the per-file threshold
+        train_pairs, dev_pairs = [], []
+        for p in pairs:
+            file1, s1, e1, file2, s2, e2 = p
+            if s1 > train_threshold[file1] and s2 > train_threshold[file2]:
+                dev_pairs.append(p)
+            elif s1 <= train_threshold[file1] and s2 <= train_threshold[file2]:
+                train_pairs.append(p)
+        return train_pairs, dev_pairs
+
    def batch_iterator(self, train_mode=True):
        print("constructing batches")
        mode = 'train' if train_mode else 'test'
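
The new split_train_test_each_file method estimates each file's length from the largest end time seen in the pairs, places the train/dev boundary at ratio_split_train_test of that length, keeps a pair for training only when both segments start before their files' thresholds, and sends it to dev only when both start after them; pairs that straddle a threshold in either file are dropped, which avoids overlap between the two sets within a file. A self-contained sketch of that logic with made-up data (not the repository's code):

from collections import defaultdict

def split_pairs_per_file(pairs, ratio=0.7):
    # Largest end time seen per file approximates the file length.
    length = defaultdict(float)
    for f1, s1, e1, f2, s2, e2 in pairs:
        length[f1] = max(length[f1], e1)
        length[f2] = max(length[f2], e2)
    threshold = {f: length[f] * ratio for f in length}
    train, dev = [], []
    for p in pairs:
        f1, s1, _, f2, s2, _ = p
        if s1 > threshold[f1] and s2 > threshold[f2]:
            dev.append(p)      # both segments start in the last (1 - ratio) part
        elif s1 <= threshold[f1] and s2 <= threshold[f2]:
            train.append(p)    # both segments start in the first ratio part
    return train, dev

# Made-up example: one 10-unit file, so the threshold is 7.0.
pairs = [("a", 1.0, 2.0, "a", 3.0, 4.0),    # both before 7.0 -> train
         ("a", 8.0, 9.0, "a", 8.5, 10.0),   # both after 7.0  -> dev
         ("a", 2.0, 3.0, "a", 8.0, 9.0)]    # straddles       -> dropped
train, dev = split_pairs_per_file(pairs)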