Skip to content

Commit ed8307f

Browse files
committed
Refactor multiple input handling.
Fixes common-workflow-lab#54. Implements count-lines7-wf.cwl conformance test.
1 parent 9933c3c commit ed8307f

File tree

10 files changed

+115
-82
lines changed

10 files changed

+115
-82
lines changed

lib/galaxy/dataset_collections/matching.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,10 @@ class CollectionsToMatch( object ):
1616

1717
def __init__( self ):
1818
self.collections = {}
19+
self.uses_ephemeral_collections = False
1920

2021
def add( self, input_name, hdca, subcollection_type=None, linked=True ):
22+
self.uses_ephemeral_collections = self.uses_ephemeral_collections or not hasattr( hdca, "hid" )
2123
self.collections[ input_name ] = bunch.Bunch(
2224
hdca=hdca,
2325
subcollection_type=subcollection_type,
@@ -45,6 +47,7 @@ def __init__( self ):
4547
self.linked_structure = None
4648
self.unlinked_structures = []
4749
self.collections = {}
50+
self.uses_ephemeral_collections = False
4851

4952
def __attempt_add_to_linked_match( self, input_name, hdca, collection_type_description, subcollection_type ):
5053
structure = get_structure( hdca, collection_type_description, leaf_subcollection_type=subcollection_type )
@@ -71,14 +74,19 @@ def structure( self ):
7174

7275
@property
7376
def implicit_inputs( self ):
74-
return list( self.collections.items() )
77+
if not self.uses_ephemeral_collections:
78+
# Consider doing something smarter here.
79+
return list( self.collections.items() )
80+
else:
81+
return []
7582

7683
@staticmethod
7784
def for_collections( collections_to_match, collection_type_descriptions ):
7885
if not collections_to_match.has_collections():
7986
return None
8087

8188
matching_collections = MatchingCollections()
89+
matching_collections.uses_ephemeral_collections = collections_to_match.uses_ephemeral_collections
8290
for input_key, to_match in collections_to_match.items():
8391
hdca = to_match.hdca
8492
collection_type_description = collection_type_descriptions.for_collection_type( hdca.collection.collection_type )

lib/galaxy/tools/actions/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,9 @@ def _record_inputs( self, trans, tool, job, incoming, inp_data, inp_dataset_coll
560560
reductions[name] = []
561561
reductions[name].append(dataset_collection)
562562

563+
if getattr( dataset_collection, "ephemeral", False ):
564+
dataset_collection = dataset_collection.persistent_object
565+
563566
# TODO: verify can have multiple with same name, don't want to lose traceability
564567
job.add_input_dataset_collection( name, dataset_collection )
565568

lib/galaxy/tools/cwl/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -963,7 +963,7 @@ def to_dict(self, itemwise=True):
963963
if self.input_type == INPUT_TYPE.FLOAT:
964964
as_dict["value"] = "0.0"
965965
elif self.input_type == INPUT_TYPE.DATA_COLLECTON:
966-
as_dict["collection_type"] = "record"
966+
as_dict["collection_type"] = self.collection_type
967967

968968
return as_dict
969969

lib/galaxy/tools/cwl/representation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,12 @@ def dataset_wrapper_to_file_json(dataset_wrapper):
130130
elif type_representation.name == "json":
131131
raw_value = param_dict_value.value
132132
return json.loads(raw_value)
133+
elif type_representation.name == "array":
134+
# TODO: generalize to lists of lists and lists of non-files...
135+
rval = []
136+
for value in param_dict_value:
137+
rval.append(dataset_wrapper_to_file_json(value))
138+
return rval
133139
elif type_representation.name == "record":
134140
rval = dict() # TODO: THIS NEEDS TO BE ORDERED BUT odict not json serializable!
135141
for key, value in param_dict_value.items():

lib/galaxy/tools/execute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def create_output_collections( self, trans, history, params ):
138138

139139
structure = self.collection_info.structure
140140

141-
if hasattr( self.collection_info, "collections" ):
141+
if not self.collection_info.uses_ephemeral_collections:
142142
# params is just one sample tool param execution with parallelized
143143
# collection replaced with a specific dataset. Need to replace this
144144
# with the collection and wrap everything up so can evaluate output
@@ -148,7 +148,7 @@ def create_output_collections( self, trans, history, params ):
148148
collection_names = ["collection %d" % c.hid for c in self.collection_info.collections.values()]
149149
on_text = on_text_for_names( collection_names )
150150
else:
151-
on_text = "implicitly create collection for inputs"
151+
on_text = "implicitly created collection from inputs"
152152

153153
collections = {}
154154

lib/galaxy/tools/parameters/basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,6 +1472,9 @@ def get_initial_value( self, trans, other_values ):
14721472
return hdca
14731473

14741474
def to_json( self, value, app, use_security ):
1475+
if getattr( value, "ephemeral", False ):
1476+
value = value.persistent_object
1477+
14751478
def single_to_json( value ):
14761479
src = None
14771480
if isinstance( value, dict ) and 'src' in value and 'id' in value:

lib/galaxy/tools/wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def __init__( self, job_working_directory, has_collection, dataset_paths=[], **k
384384
else:
385385
self.__input_supplied = True
386386

387-
if hasattr( has_collection, "name" ):
387+
if hasattr( has_collection, "history_content_type" ):
388388
# It is a HistoryDatasetCollectionAssociation
389389
collection = has_collection.collection
390390
self.name = has_collection.name

lib/galaxy/workflow/modules.py

Lines changed: 25 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
web
1515
)
1616
from galaxy.dataset_collections import matching
17-
from galaxy.dataset_collections.structure import leaf, Tree
1817
from galaxy.exceptions import ToolMissingException
1918
from galaxy.jobs.actions.post import ActionBox
2019
from galaxy.model import PostJobAction
@@ -861,25 +860,6 @@ def decode_runtime_state( self, runtime_state ):
861860
else:
862861
raise ToolMissingException( "Tool %s missing. Cannot recover runtime state." % self.tool_id )
863862

864-
def _check_for_scatters( self, step, tool, progress, tool_state ):
865-
scatter_collector = ScatterOverCollector(
866-
self.app
867-
)
868-
869-
def callback( input, prefixed_name, **kwargs ):
870-
replacement = progress.replacement_for_tool_input( step, input, prefixed_name )
871-
log.info("replacement for %s is %s" % (prefixed_name, replacement))
872-
if replacement:
873-
if isinstance(replacement, ScatterOver):
874-
scatter_collector.add_scatter(replacement)
875-
876-
return NO_REPLACEMENT
877-
878-
visit_input_values( tool.inputs, tool_state, callback, no_replacement_value=NO_REPLACEMENT )
879-
880-
# TODO: num slices is bad - what about empty arrays.
881-
return None if scatter_collector.num_slices == 0 else scatter_collector
882-
883863
def execute( self, trans, progress, invocation, step ):
884864
tool = trans.app.toolbox.get_tool( step.tool_id, tool_version=step.tool_version, tool_hash=step.tool_hash )
885865
tool_state = step.state
@@ -890,10 +870,10 @@ def execute( self, trans, progress, invocation, step ):
890870
collections_to_match = self._find_collections_to_match( tool, progress, step )
891871
# Have implicit collections...
892872
if collections_to_match.has_collections():
893-
# Is a MatchingCollections
894873
collection_info = self.trans.app.dataset_collections_service.match_collections( collections_to_match )
895874
else:
896-
collection_info = self._check_for_scatters( step, tool, progress, make_dict_copy( tool_state.inputs ) )
875+
collection_info = None
876+
897877
param_combinations = []
898878
if collection_info:
899879
iteration_elements_iter = collection_info.slice_collections()
@@ -1005,16 +985,17 @@ def callback( input, prefixed_name, **kwargs ):
1005985
is_data_param = isinstance( input, DataToolParameter )
1006986
if is_data_param and not input.multiple:
1007987
data = progress.replacement_for_tool_input( step, input, prefixed_name )
1008-
if isinstance( data, model.HistoryDatasetCollectionAssociation ):
988+
if hasattr( data, "collection" ):
1009989
collections_to_match.add( prefixed_name, data )
1010990

1011991
is_data_collection_param = isinstance( input, DataCollectionToolParameter )
1012992
if is_data_collection_param and not input.multiple:
1013993
data = progress.replacement_for_tool_input( step, input, prefixed_name )
1014994
history_query = input._history_query( self.trans )
1015-
subcollection_type_description = history_query.can_map_over( data )
1016-
if subcollection_type_description:
1017-
collections_to_match.add( prefixed_name, data, subcollection_type=subcollection_type_description.collection_type )
995+
if hasattr( data, "collection" ):
996+
subcollection_type_description = history_query.can_map_over( data )
997+
if subcollection_type_description:
998+
collections_to_match.add( prefixed_name, data, subcollection_type=subcollection_type_description.collection_type )
1018999

10191000
visit_input_values( tool.inputs, step.state.inputs, callback )
10201001
return collections_to_match
@@ -1142,57 +1123,28 @@ def load_module_sections( trans ):
11421123
return module_sections
11431124

11441125

1145-
class ScatterOverCollector(object):
1126+
class EphemeralCollection(object):
1127+
"""Interface for collecting datasets together in workflows and treating as collections.
11461128
1147-
def __init__(self, app):
1148-
self.inputs_per_name = {}
1149-
self.num_slices = 0
1150-
self.app = app
1151-
1152-
def add_scatter(self, scatter_over):
1153-
inputs = scatter_over.inputs
1154-
self.inputs_per_name[scatter_over.prefixed_name] = inputs
1155-
if self.num_slices > 0:
1156-
assert len(inputs) == self.num_slices
1157-
else:
1158-
self.num_slices = len(inputs)
1159-
1160-
def slice_collections(self):
1161-
slices = []
1162-
for i in range(self.num_slices):
1163-
this_slice = {}
1164-
for prefixed_name, inputs in self.inputs_per_name.items():
1165-
this_slice[prefixed_name] = SliceElement(inputs[i], str(i))
1166-
slices.append(this_slice)
1167-
return slices
1168-
1169-
@property
1170-
def structure(self):
1171-
collection_type_descriptions = self.app.dataset_collections_service.collection_type_descriptions
1172-
collection_type_description = collection_type_descriptions.for_collection_type("list")
1173-
children = []
1174-
for input in self.inputs_per_name.values()[0]:
1175-
children.append((input.element_identifier, leaf))
1176-
1177-
return Tree(children, collection_type_description)
1178-
1179-
@property
1180-
def implicit_inputs(self):
1181-
return []
1182-
1183-
1184-
class SliceElement(object):
1185-
1186-
def __init__(self, dataset_instance, element_identifier):
1187-
self.dataset_instance = dataset_instance
1188-
self.element_identifier = element_identifier
1129+
These aren't real collections in the database - just datasets grouped together
1130+
in some way by workflows for passing data around as collections.
1131+
"""
11891132

1133+
# Used to distinguish between datasets and collections frequently.
1134+
ephemeral = True
1135+
history_content_type = "dataset_collection"
1136+
name = "Dynamically generated collection"
11901137

1191-
class ScatterOver(object):
1138+
def __init__(self, collection, history):
1139+
self.collection = collection
1140+
self.history = history
11921141

1193-
def __init__(self, prefixed_name, inputs):
1194-
self.prefixed_name = prefixed_name
1195-
self.inputs = inputs
1142+
hdca = model.HistoryDatasetCollectionAssociation(
1143+
collection=collection,
1144+
history=history,
1145+
)
1146+
hdca.history.add_dataset_collection( hdca )
1147+
self.persistent_object = hdca
11961148

11971149

11981150
class DelayedWorkflowEvaluation(Exception):

lib/galaxy/workflow/run.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,67 @@ def replacement_for_tool_input( self, step, input, prefixed_name ):
301301
# We've mapped multiple individual inputs to a single parameter,
302302
# promote output to a collection.
303303
inputs = []
304-
for c in connection:
304+
input_history_content_type = None
305+
input_collection_type = None
306+
for i, c in enumerate(connection):
305307
input_from_connection = self.replacement_for_connection( c, is_data=is_data )
308+
input_history_content_type = input_from_connection.history_content_type
309+
if i == 0:
310+
if input_history_content_type == "dataset_collection":
311+
input_collection_type = input_from_connection.collection.collection_type
312+
else:
313+
input_collection_type = None
314+
else:
315+
if input_collection_type is None:
316+
if input_history_content_type != "dataset":
317+
raise Exception("Cannot map over a combination of datasets and collections.")
318+
else:
319+
if input_history_content_type != "dataset_collection":
320+
raise Exception("Cannot merge over combinations of datasets and collections.")
321+
elif input_from_connection.collection.collection_type != input_collection_type:
322+
raise Exception("Cannot merge collections of different collection types.")
323+
306324
inputs.append(input_from_connection)
307325

308-
replacement = modules.ScatterOver(
309-
prefixed_name,
310-
inputs,
326+
327+
if input.type == "data_collection":
328+
# TODO: Implement more nested types here...
329+
assert input.collection_types == ["list"], input.collection_types
330+
331+
collection = model.DatasetCollection()
332+
# If individual datasets are provided (type is None) - promote to a list.
333+
collection.collection_type = input_collection_type or "list"
334+
elements = []
335+
336+
next_index = 0
337+
for input in inputs:
338+
if input_collection_type is None:
339+
element = model.DatasetCollectionElement(
340+
element=input,
341+
element_index=next_index,
342+
element_identifier=str(next_index),
343+
)
344+
elements.append(element)
345+
next_index += 1
346+
elif input_collection_type == "list":
347+
for dataset_instance in input.dataset_instances:
348+
element = model.DatasetCollectionElement(
349+
element=dataset_instance,
350+
element_index=next_index,
351+
element_identifier=str(next_index),
352+
)
353+
elements.append(element)
354+
next_index += 1
355+
else:
356+
raise NotImplementedError()
357+
358+
collection.elements = elements
359+
360+
ephemeral_collection = modules.EphemeralCollection(
361+
collection=collection,
362+
history=self.workflow_invocation.history,
311363
)
364+
return ephemeral_collection
312365

313366
return replacement
314367

test/unit/tools/test_cwl.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,14 @@ def test_workflow_scatter_multiple_input():
184184
assert len(galaxy_workflow_dict["steps"]) == 3
185185

186186

187+
def test_workflow_multiple_input_merge_flattened():
188+
version = "v1.0"
189+
proxy = workflow_proxy(_cwl_tool_path("%s/count-lines7-wf.cwl" % version))
190+
191+
galaxy_workflow_dict = proxy.to_dict()
192+
assert len(galaxy_workflow_dict["steps"]) == 3
193+
194+
187195
def test_load_proxy_simple():
188196
cat3 = _cwl_tool_path("draft3/cat3-tool.cwl")
189197
tool_source = get_tool_source(cat3)

0 commit comments

Comments
 (0)