Skip to content

Commit 468750d

Browse files
committed
implemented ParallelBatch in another attempt to fix issue #4
1 parent 5838d5c commit 468750d

File tree

3 files changed

+61
-20
lines changed

3 files changed

+61
-20
lines changed

luiginlp/engine.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def inherit_parameters(Class, *ChildClasses):
433433
if isinstance(attr,luigi.Parameter) and not hasattr(Class, key):
434434
setattr(Class,key, attr)
435435

436-
def outputfrominput(self, inputformat, stripextension, addextension, outputdirparam='outputdir'):
436+
def outputfrominput(self, inputformat, stripextension, addextension, replaceinputdirparam='replaceinputdir', outputdirparam='outputdir'):
437437
"""Derives the output filename from the input filename, removing the input extension and adding the output extension. Supports outputdir parameter."""
438438

439439
if not hasattr(self,'in_' + inputformat):
@@ -448,15 +448,23 @@ def outputfrominput(self, inputformat, stripextension, addextension, outputdirpa
448448
if hasattr(self,outputdirparam):
449449
outputdir = getattr(self,outputdirparam)
450450
if outputdir and outputdir != '.':
451-
return TargetInfo(self, os.path.join(outputdir, os.path.basename(replaceextension(inputfilename, stripextension,addextension))))
452-
return TargetInfo(self, replaceextension(inputfilename, stripextension,addextension))
451+
if hasattr(self, replaceinputdirparam):
452+
replaceinputdir = getattr(self,replaceinputdirparam)
453+
if replaceinputdir:
454+
if inputfilename.startswith(replaceinputdir):
455+
return TargetInfo(self, os.path.join(outputdir, os.path.basename(replaceextension(inputfilename[len(replaceinputdir):], stripextension,addextension))))
456+
else:
457+
return TargetInfo(self, os.path.join(outputdir, os.path.basename(replaceextension(inputfilename, stripextension,addextension))))
458+
else:
459+
return TargetInfo(self, replaceextension(inputfilename, stripextension,addextension))
453460

454461

455462
class StandardWorkflowComponent(WorkflowComponent):
    """Workflow component that operates on exactly one input file."""

    # The single input file for this component; required (no default).
    inputfile = luigi.Parameter()

    # Optional output directory; the empty-string default means "not set".
    # Looked up by name through the ``outputdirparam`` default of
    # ``outputfrominput``.
    outputdir = luigi.Parameter(default='')

    # Optional input directory prefix; the empty-string default means
    # "not set". Looked up by name through the ``replaceinputdirparam``
    # default of ``outputfrominput``.
    replaceinputdir = luigi.Parameter(default='')
460468

461469
class TargetInfo(sciluigi.TargetInfo):
    """Subclass of ``sciluigi.TargetInfo`` that adds no behaviour of its own."""
@@ -479,6 +487,36 @@ def __init__(self, *args, **kwargs):
479487
def __hash__(self):
    """Return a hash built from the mapping's items in sorted order.

    Sorting first makes the hash independent of insertion order.
    """
    ordered_items = sorted(self.items())
    return hash(tuple(ordered_items))
481489

490+
class ParallelBatch(luigi.Task):
    """Meta workflow that runs one component over a batch of input files.

    ``requires()`` instantiates the component named by ``component`` once per
    input file; ``run()`` then writes the list of input files to a marker
    file that serves as the output target of the batch as a whole.

    Parameters:
        inputfiles: the input files, either a comma-separated string or an
            iterable of filenames.
        component: name of the component class, resolved through
            ``getcomponentclass``.
        passparameters: extra keyword arguments passed to every component
            instance. Accepts a ``PassParameters`` instance, a plain dict, or
            a string holding a JSON object (single quotes are tolerated).
    """
    inputfiles = luigi.Parameter()
    component = luigi.Parameter()
    passparameters = luigi.Parameter(default=PassParameters())

    def _batch_inputfiles(self):
        """Return the input files as a list without mutating the parameter.

        Empty entries (from an empty string or a stray comma) are dropped so
        no component is ever scheduled for an empty filename.
        """
        if isinstance(self.inputfiles, str):
            return [name for name in self.inputfiles.split(',') if name]
        return list(self.inputfiles)

    def _batch_passparameters(self):
        """Return ``passparameters`` normalised to a ``PassParameters`` instance.

        Raises:
            TypeError: if the value is not a str, dict or PassParameters.
        """
        value = self.passparameters
        if isinstance(value, PassParameters):
            return value
        if isinstance(value, str):
            # The string form typically uses single quotes (dict repr), which
            # JSON does not accept, hence the quote replacement.
            return PassParameters(json.loads(value.replace("'", '"')))
        if isinstance(value, dict):
            return PassParameters(value)
        raise TypeError("Keyword argument passparameters must be instance of PassParameters, got " + repr(value))

    def requires(self):
        """Return one component task per input file in this batch."""
        passparameters = self._batch_passparameters()
        ComponentClass = getcomponentclass(self.component)
        return [
            ComponentClass(inputfile=inputfile, **passparameters)
            for inputfile in self._batch_inputfiles()
        ]

    def run(self):
        """Write the batch's input files, one per line, to the marker file."""
        with self.output().open('w') as f:
            f.write("\n".join(self._batch_inputfiles()))

    def output(self):
        """Return the marker-file target that signals completion of the batch.

        The filename is derived from a digest of the task id rather than from
        ``hash(self)``: Python salts string hashes per interpreter run, so a
        name based on ``hash()`` would change on every run and a finished
        batch would never be recognised as complete.
        """
        import hashlib  # local import: only needed here

        digest = hashlib.sha1(str(self.task_id).encode('utf-8')).hexdigest()[:16]
        return luigi.LocalTarget('.parallelbatch-' + self.component + '-' + digest + '.done')
519+
482520
class Parallel(sciluigi.WorkflowTask):
483521
"""Meta workflow"""
484522
inputfiles = luigi.Parameter()

luiginlp/modules/folia.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import natsort
44
import subprocess
55
import pickle
6-
from luiginlp.engine import Task, TargetInfo, InputFormat, StandardWorkflowComponent, registercomponent, InputSlot, Parameter, BoolParameter, IntParameter
6+
from luiginlp.engine import Task, TargetInfo, InputFormat, StandardWorkflowComponent, registercomponent, InputSlot, Parameter, BoolParameter, IntParameter, PassParameters, ParallelBatch
77
from luiginlp.util import getlog, recursive_glob, waitforslot, waitforcompletion, replaceextension, chunk
88
from luiginlp.modules.openconvert import OpenConvert_folia
99

@@ -167,28 +167,27 @@ def on_failure(self, exception):
167167

168168
def run(self):
169169
#gather input files
170-
batchsize = 1000
171170
if self.outputdir and not os.path.exists(self.outputdir): os.makedirs(self.outputdir)
172171

173172
if os.path.exists(self.out_state().path):
174-
log.info("Loading index...")
173+
log.info("Collecting input files from saved state...")
175174
with open(self.out_state().path,'rb') as f:
176175
inputfiles = pickle.load(f)
177176
else:
178177
log.info("Collecting input files...")
179178
inputfiles = recursive_glob(self.in_foliadir().path, '*.' + self.folia_extension)
180179
log.info("Collected " + str(len(inputfiles)) + " input files")
180+
with open(self.out_state().path,'wb') as f:
181+
pickle.dump(inputfiles,f)
181182

182-
with open(self.out_state().path,'wb') as f:
183-
pickle.dump(inputfiles[batchsize:],f)
184-
185-
log.info("Scheduling validators, " + str(len(inputfiles)) + " left...")
186-
for taskbatch in chunk(inputfiles,batchsize): #schedule in batches of 1000 so we don't overload the scheduler
187-
if self.outputdir:
188-
yield [ FoliaValidator(inputfile=inputfile,folia_extension=self.folia_extension,outputdir=os.path.dirname(inputfile).replace(self.in_foliadir().path,self.outputdir)) for inputfile in taskbatch ]
189-
else:
190-
yield [ FoliaValidator(inputfile=inputfile,folia_extension=self.folia_extension) for inputfile in taskbatch ]
183+
log.info("Scheduling validators")
184+
if self.outputdir:
185+
passparameters = PassParameters(folia_extension=self.folia_extension,replaceinputdir=self.in_foliadir().path, outputdir=self.outputdir)
186+
else:
187+
passparameters = PassParameters(folia_extension=self.folia_extension)
191188

189+
for inputfiles_batch in chunk(inputfiles,1000): #schedule in batches of 1000 so we don't overload the scheduler
190+
yield ParallelBatch(component='FoliaValidator',inputfiles=inputfiles_batch,passparameters=passparameters)
192191

193192
log.info("Collecting output files...")
194193
#Gather all output files

test/scaletest.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import glob
55
import shutil
66
import luiginlp
7-
from luiginlp.engine import Task, StandardWorkflowComponent, PassParameters, InputFormat, InputComponent, InputSlot, Parameter, IntParameter, registercomponent, Parallel
7+
import luigi
8+
import json
9+
from luiginlp.engine import Task, StandardWorkflowComponent, PassParameters, InputFormat, InputComponent, InputSlot, Parameter, IntParameter, registercomponent, ParallelBatch
810
from luiginlp.util import getlog, chunk
911

1012
log = getlog()
@@ -49,6 +51,8 @@ def accepts(self):
4951

5052

5153

54+
55+
5256
class ScaleTestTask(Task):
5357

5458
in_txtdir = InputSlot()
@@ -67,11 +71,11 @@ def run(self):
6771
log.info("Collected " + str(len(inputfiles)) + " input files")
6872

6973
#inception aka dynamic dependencies: we yield a list of tasks to perform which could not have been predicted statically
70-
#in this case we run the OCR_singlepage component for each input file in the directory
74+
for inputfiles_chunk in chunk(inputfiles, 1000):
75+
yield ParallelBatch(component='Voweleater',inputfiles=','.join(inputfiles_chunk),passparameters=PassParameters(outputdir=self.out_txtdir().path))
7176

72-
chunks = [ Parallel(component='Voweleater',inputfiles=','.join(inputfiles_chunk),passparameters=PassParameters(outputdir=self.out_txtdir().path)) for inputfiles_chunk in chunk(inputfiles, 1000) ]
73-
log.info("Scheduling chunks: " + str(len(chunks)))
74-
yield chunks
77+
#log.info("Scheduling chunks: " + str(len(chunks)))
78+
#yield chunks
7579

7680
#yield [ Voweleater(inputfile=inputfile,outputdir=self.out_txtdir().path) for inputfile in inputfiles ]
7781

0 commit comments

Comments
 (0)