Skip to content

Commit ae11f56

Browse files
committed
Implement MultipleInputRequirements workflow Requirement.
Provide an alternative implementation of "collection info" object to tool execution if non-collection map-over needs to happen. This is a bit sloppy still - I need to: - Rename "collection_info" everywhere - maybe map_over_info. - Build an interface that tool execution environment can consume. - Implement a blended approach that allows mapping over collections and individual inputs. - Rename "implicit_inputs" property on these to "implicit_input_collections".
1 parent d1ad64e commit ae11f56

File tree

8 files changed

+213
-91
lines changed

8 files changed

+213
-91
lines changed

lib/galaxy/dataset_collections/matching.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def structure( self ):
6969
effective_structure = effective_structure.multiply( linked_structure )
7070
return None if effective_structure.is_leaf else effective_structure
7171

72+
@property
73+
def implicit_inputs( self ):
74+
return list( self.collection_info.collections.items() )
75+
7276
@staticmethod
7377
def for_collections( collections_to_match, collection_type_descriptions ):
7478
if not collections_to_match.has_collections():

lib/galaxy/managers/collections.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def create( self, trans, parent, name, collection_type, element_identifiers=None
7171
name=name,
7272
)
7373
if implicit_collection_info:
74-
for input_name, input_collection in implicit_collection_info[ "implicit_inputs" ]:
75-
dataset_collection_instance.add_implicit_input_collection( input_name, input_collection )
74+
implicit_inputs = implicit_collection_info[ "implicit_inputs" ]
75+
if implicit_inputs:
76+
for input_name, input_collection in implicit_inputs:
77+
dataset_collection_instance.add_implicit_input_collection( input_name, input_collection )
7678
for output_dataset in implicit_collection_info.get( "outputs" ):
7779
if output_dataset not in trans.sa_session:
7880
output_dataset = trans.sa_session.query( type( output_dataset ) ).get( output_dataset.id )

lib/galaxy/tools/cwl/parser.py

Lines changed: 44 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import six
1616

1717
from galaxy.tools.hash import build_tool_hash
18-
from galaxy.util import safe_makedirs
18+
from galaxy.util import listify, safe_makedirs
1919
from galaxy.util.bunch import Bunch
2020
from galaxy.util.odict import odict
2121

@@ -41,6 +41,7 @@
4141
"InlineJavascriptRequirement",
4242
"ShellCommandRequirement",
4343
"ScatterFeatureRequirement",
44+
"MultipleInputFeatureRequirement",
4445
]
4546

4647

@@ -484,19 +485,24 @@ def input_connections_by_step(self, step_proxies):
484485
for cwl_input in cwl_inputs:
485486
cwl_input_id = cwl_input["id"]
486487
cwl_source_id = cwl_input["source"]
487-
step_name, input_name = split_step_reference(cwl_input_id)
488-
output_step_name, output_name = split_step_reference(cwl_source_id)
489-
output_step_id = self.cwl_id + "#" + output_step_name
490-
if output_step_id not in cwl_ids_to_index:
491-
template = "Output [%s] does not appear in ID-to-index map [%s]."
492-
msg = template % (output_step_id, cwl_ids_to_index)
493-
raise AssertionError(msg)
494-
495-
input_connections_step[input_name] = {
496-
"id": cwl_ids_to_index[output_step_id],
497-
"output_name": output_name,
498-
"input_type": "dataset"
499-
}
488+
step_name, input_name = split_step_references(cwl_input_id, multiple=False)
489+
# Consider only allow multiple if MultipleInputFeatureRequirement is enabled
490+
for (output_step_name, output_name) in split_step_references(cwl_source_id):
491+
output_step_id = self.cwl_id + "#" + output_step_name
492+
if output_step_id not in cwl_ids_to_index:
493+
template = "Output [%s] does not appear in ID-to-index map [%s]."
494+
msg = template % (output_step_id, cwl_ids_to_index)
495+
raise AssertionError(msg)
496+
497+
if input_name not in input_connections_step:
498+
input_connections_step[input_name] = []
499+
500+
input_connections_step[input_name].append({
501+
"id": cwl_ids_to_index[output_step_id],
502+
"output_name": output_name,
503+
"input_type": "dataset"
504+
})
505+
500506
input_connections_by_step.append(input_connections_step)
501507

502508
return input_connections_by_step
@@ -551,24 +557,34 @@ def cwl_object_to_annotation(self, cwl_obj):
551557
return cwl_obj.get("doc", None)
552558

553559

554-
def split_step_reference(step_reference):
560+
def split_step_references(step_references, multiple=True):
555561
"""Split a CWL step input or output reference into step id and name."""
556562
# Trim off the workflow id part of the reference.
557-
assert "#" in step_reference
558-
cwl_workflow_id, step_reference = step_reference.split("#", 1)
563+
step_references = listify(step_references)
564+
split_references = []
565+
566+
for step_reference in step_references:
567+
assert "#" in step_reference
568+
cwl_workflow_id, step_reference = step_reference.split("#", 1)
559569

560-
# Now just grab the step name and input/output name.
561-
assert "#" not in step_reference
562-
if "/" in step_reference:
563-
step_name, io_name = step_reference.split("/", 1)
570+
# Now just grab the step name and input/output name.
571+
assert "#" not in step_reference
572+
if "/" in step_reference:
573+
step_name, io_name = step_reference.split("/", 1)
574+
else:
575+
# Referencing an input, not a step.
576+
# In Galaxy workflows input steps have an implicit output named
577+
# "output" for consistency with tools - in cwl land
578+
# just the input name is referenced.
579+
step_name = step_reference
580+
io_name = "output"
581+
split_references.append((step_name, io_name))
582+
583+
if multiple:
584+
return split_references
564585
else:
565-
# Referencing an input, not a step.
566-
# In Galaxy workflows input steps have an implicit output named
567-
# "output" for consistency with tools - in cwl land
568-
# just the input name is referenced.
569-
step_name = step_reference
570-
io_name = "output"
571-
return (step_name, io_name)
586+
assert len(split_references) == 1
587+
return split_references[0]
572588

573589

574590
class StepProxy(object):

lib/galaxy/tools/execute.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,18 +138,21 @@ def create_output_collections( self, trans, history, params ):
138138

139139
structure = self.collection_info.structure
140140

141-
# params is just one sample tool param execution with parallelized
142-
# collection replaced with a specific dataset. Need to replace this
143-
# with the collection and wrap everything up so can evaluate output
144-
# label.
145-
params.update( self.collection_info.collections ) # Replace datasets with source collections for labelling outputs.
146-
147-
collection_names = ["collection %d" % c.hid for c in self.collection_info.collections.values()]
148-
on_text = on_text_for_names( collection_names )
141+
if hasattr( self.collection_info, "collections" ):
142+
# params is just one sample tool param execution with parallelized
143+
# collection replaced with a specific dataset. Need to replace this
144+
# with the collection and wrap everything up so can evaluate output
145+
# label.
146+
params.update( self.collection_info.collections ) # Replace datasets with source collections for labelling outputs.
147+
148+
collection_names = ["collection %d" % c.hid for c in self.collection_info.collections.values()]
149+
on_text = on_text_for_names( collection_names )
150+
else:
151+
on_text = "implicitly create collection for inputs"
149152

150153
collections = {}
151154

152-
implicit_inputs = list(self.collection_info.collections.items())
155+
implicit_inputs = self.collection_info.implicit_inputs
153156
for output_name, outputs in self.outputs_by_output_name.items():
154157
if not len( structure ) == len( outputs ):
155158
# Output does not have the same structure, if all jobs were

lib/galaxy/workflow/modules.py

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
web
1515
)
1616
from galaxy.dataset_collections import matching
17+
from galaxy.dataset_collections.structure import leaf, Tree
1718
from galaxy.exceptions import ToolMissingException
1819
from galaxy.jobs.actions.post import ActionBox
1920
from galaxy.model import PostJobAction
@@ -58,6 +59,7 @@ class WorkflowModule( object ):
5859

5960
def __init__( self, trans, content_id=None, **kwds ):
6061
self.trans = trans
62+
self.app = trans.app
6163
self.content_id = content_id
6264
self.state = DefaultToolState()
6365

@@ -859,6 +861,25 @@ def decode_runtime_state( self, runtime_state ):
859861
else:
860862
raise ToolMissingException( "Tool %s missing. Cannot recover runtime state." % self.tool_id )
861863

864+
def _check_for_scatters( self, step, tool, progress, tool_state ):
865+
scatter_collector = ScatterOverCollector(
866+
self.app
867+
)
868+
869+
def callback( input, prefixed_name, **kwargs ):
870+
replacement = progress.replacement_for_tool_input( step, input, prefixed_name )
871+
log.info("replacement for %s is %s" % (prefixed_name, replacement))
872+
if replacement:
873+
if isinstance(replacement, ScatterOver):
874+
scatter_collector.add_scatter(replacement)
875+
876+
return NO_REPLACEMENT
877+
878+
visit_input_values( tool.inputs, tool_state, callback, no_replacement_value=NO_REPLACEMENT )
879+
880+
# TODO: num slices is bad - what about empty arrays.
881+
return None if scatter_collector.num_slices == 0 else scatter_collector
882+
862883
def execute( self, trans, progress, invocation, step ):
863884
tool = trans.app.toolbox.get_tool( step.tool_id, tool_version=step.tool_version, tool_hash=step.tool_hash )
864885
tool_state = step.state
@@ -869,10 +890,10 @@ def execute( self, trans, progress, invocation, step ):
869890
collections_to_match = self._find_collections_to_match( tool, progress, step )
870891
# Have implicit collections...
871892
if collections_to_match.has_collections():
893+
# Is a MatchingCollections
872894
collection_info = self.trans.app.dataset_collections_service.match_collections( collections_to_match )
873895
else:
874-
collection_info = None
875-
896+
collection_info = self._check_for_scatters( step, tool, progress, make_dict_copy( tool_state.inputs ) )
876897
param_combinations = []
877898
if collection_info:
878899
iteration_elements_iter = collection_info.slice_collections()
@@ -1121,6 +1142,59 @@ def load_module_sections( trans ):
11211142
return module_sections
11221143

11231144

1145+
class ScatterOverCollector(object):
1146+
1147+
def __init__(self, app):
1148+
self.inputs_per_name = {}
1149+
self.num_slices = 0
1150+
self.app = app
1151+
1152+
def add_scatter(self, scatter_over):
1153+
inputs = scatter_over.inputs
1154+
self.inputs_per_name[scatter_over.prefixed_name] = inputs
1155+
if self.num_slices > 0:
1156+
assert len(inputs) == self.num_slices
1157+
else:
1158+
self.num_slices = len(inputs)
1159+
1160+
def slice_collections(self):
1161+
slices = []
1162+
for i in range(self.num_slices):
1163+
this_slice = {}
1164+
for prefixed_name, inputs in self.inputs_per_name.items():
1165+
this_slice[prefixed_name] = SliceElement(inputs[i], str(i))
1166+
slices.append(this_slice)
1167+
return slices
1168+
1169+
@property
1170+
def structure(self):
1171+
collection_type_descriptions = self.app.dataset_collections_service.collection_type_descriptions
1172+
collection_type_description = collection_type_descriptions.for_collection_type("list")
1173+
children = []
1174+
for input in self.inputs_per_name.values()[0]:
1175+
children.append((input.element_identifier, leaf))
1176+
1177+
return Tree(children, collection_type_description)
1178+
1179+
@property
1180+
def implicit_inputs(self):
1181+
return None
1182+
1183+
1184+
class SliceElement(object):
1185+
1186+
def __init__(self, dataset_instance, element_identifier):
1187+
self.dataset_instance = dataset_instance
1188+
self.element_identifier = element_identifier
1189+
1190+
1191+
class ScatterOver(object):
1192+
1193+
def __init__(self, prefixed_name, inputs):
1194+
self.prefixed_name = prefixed_name
1195+
self.inputs = inputs
1196+
1197+
11241198
class DelayedWorkflowEvaluation(Exception):
11251199

11261200
def __init__(self, why=None):

lib/galaxy/workflow/run.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,21 @@ def replacement_for_tool_input( self, step, input, prefixed_name ):
295295
replacement = replacement[ 0 ]
296296
else:
297297
is_data = input.type in ["data", "data_collection"]
298-
replacement = self.replacement_for_connection( connection[ 0 ], is_data=is_data )
298+
if len( connection ) == 1:
299+
replacement = self.replacement_for_connection( connection[ 0 ], is_data=is_data )
300+
else:
301+
# We've mapped multiple individual inputs to a single parameter,
302+
# promote output to a collection.
303+
inputs = []
304+
for c in connection:
305+
input_from_connection = self.replacement_for_connection( c, is_data=is_data )
306+
inputs.append(input_from_connection)
307+
308+
replacement = modules.ScatterOver(
309+
prefixed_name,
310+
inputs,
311+
)
312+
299313
return replacement
300314

301315
def replacement_for_connection( self, connection, is_data=True ):

0 commit comments

Comments
 (0)