Commit 5f2653e

Add spark versions of walker classes (ReadWalkerSpark, AssemblyRegionWalkerSpark, IntervalWalkerSpark, VariantWalkerSpark) and examples.
1 parent 37ca5bb · commit 5f2653e

16 files changed: +960 −15 lines

src/main/java/org/broadinstitute/hellbender/engine/ReadsContext.java

+12 −1
@@ -20,6 +20,7 @@
 public final class ReadsContext implements Iterable<GATKRead> {

     private final ReadsDataSource dataSource;
+    private final Iterable<GATKRead> iterable;

     private final SimpleInterval interval;

@@ -41,16 +42,23 @@ public ReadsContext() {
      */
     public ReadsContext( final ReadsDataSource dataSource, final SimpleInterval interval ) {
         this.dataSource = dataSource;
+        this.iterable = null;
         this.interval = interval;
     }

+    public ReadsContext( Shard<GATKRead> shard ) {
+        this.dataSource = null;
+        this.iterable = shard;
+        this.interval = shard.getInterval();
+    }
+
     /**
      * Does this context have a backing source of reads data?
      *
      * @return true if there is a backing ReadsDataSource, otherwise false
      */
     public boolean hasBackingDataSource() {
-        return dataSource != null;
+        return dataSource != null || iterable != null;
     }

     /**
@@ -71,6 +79,9 @@ public SimpleInterval getInterval() {
      */
     @Override
     public Iterator<GATKRead> iterator() {
+        if (iterable != null && interval != null) {
+            return iterable.iterator();
+        }
         // We can't perform a query if we lack either a dataSource or an interval to query on
         if ( dataSource == null || interval == null ) {
             return Collections.<GATKRead>emptyList().iterator();
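
Usage note: the new Shard-backed constructor lets executor-side code iterate a shard's reads without any ReadsDataSource. A minimal sketch, assuming a Shard<GATKRead> is already in hand (countReadsInShard is a hypothetical helper, not part of this commit):

    // Hypothetical helper: count the reads overlapping a shard by iterating
    // a ReadsContext built from the shard itself (no ReadsDataSource needed).
    static long countReadsInShard(final Shard<GATKRead> shard) {
        final ReadsContext readsContext = new ReadsContext(shard);
        long count = 0;
        for (final GATKRead read : readsContext) {
            count++;
        }
        return count;
    }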
src/main/java/org/broadinstitute/hellbender/engine/filters/VariantFilterLibrary.java

+3 −1
@@ -1,8 +1,10 @@
 package org.broadinstitute.hellbender.engine.filters;

+import java.io.Serializable;
+
 /**
  * Collects common variant filters.
  */
 public final class VariantFilterLibrary {
-    public static VariantFilter ALLOW_ALL_VARIANTS = variant -> true;
+    public static VariantFilter ALLOW_ALL_VARIANTS = (VariantFilter & Serializable) variant -> true;
 }
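
Usage note: the (VariantFilter & Serializable) intersection cast makes the compiler generate a lambda class that also implements Serializable, which is what allows Spark to ship the filter inside task closures. A self-contained sketch of the pattern, using java.util.function.Predicate as a stand-in for VariantFilter:

    import java.io.Serializable;
    import java.util.function.Predicate;

    public class SerializableLambdaDemo {
        public static void main(String[] args) {
            // A plain lambda is not Serializable and would fail Java
            // serialization when Spark serializes the closure...
            Predicate<String> plain = s -> true;
            // ...but an intersection cast yields a serializable lambda.
            Predicate<String> spark = (Predicate<String> & Serializable) s -> true;
            System.out.println(plain instanceof Serializable); // false
            System.out.println(spark instanceof Serializable); // true
        }
    }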
src/main/java/org/broadinstitute/hellbender/engine/spark/AssemblyRegionWalkerSpark.java

+189
@@ -0,0 +1,189 @@
+package org.broadinstitute.hellbender.engine.spark;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import htsjdk.samtools.SAMFileHeader;
+import htsjdk.samtools.SAMSequenceDictionary;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.broadcast.Broadcast;
+import org.broadinstitute.hellbender.cmdline.Advanced;
+import org.broadinstitute.hellbender.cmdline.Argument;
+import org.broadinstitute.hellbender.engine.*;
+import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
+import org.broadinstitute.hellbender.engine.filters.ReadFilter;
+import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
+import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
+import org.broadinstitute.hellbender.utils.IntervalUtils;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+import org.broadinstitute.hellbender.utils.read.GATKRead;
+import scala.Tuple3;
+
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * A Spark version of {@link AssemblyRegionWalker}.
+ */
+public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool {
+    private static final long serialVersionUID = 1L;
+
+    @Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
+    protected int readShardSize = defaultReadShardSize();
+
+    @Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
+    protected int readShardPadding = defaultReadShardPadding();
+
+    @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
+    protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize();
+
+    @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
+    protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();
+
+    @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
+    protected int assemblyRegionPadding = defaultAssemblyRegionPadding();
+
+    @Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
+    protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();
+
+    @Advanced
+    @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
+    protected double activeProbThreshold = defaultActiveProbThreshold();
+
+    @Advanced
+    @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
+    protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance();
+
+    /**
+     * @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultReadShardSize();
+
+    /**
+     * @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultReadShardPadding();
+
+    /**
+     * @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultMinAssemblyRegionSize();
+
+    /**
+     * @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultMaxAssemblyRegionSize();
+
+    /**
+     * @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultAssemblyRegionPadding();
+
+    /**
+     * @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultMaxReadsPerAlignmentStart();
+
+    /**
+     * @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
+     */
+    protected abstract double defaultActiveProbThreshold();
+
+    /**
+     * @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
+     */
+    protected abstract int defaultMaxProbPropagationDistance();
+
+    @Argument(doc = "whether to use the shuffle implementation or not", shortName = "shuffle", fullName = "shuffle", optional = true)
+    public boolean shuffle = false;
+
+    @Override
+    public final boolean requiresReads() { return true; }
+
+    @Override
+    public final boolean requiresReference() { return true; }
+
+    public List<ReadFilter> getDefaultReadFilters() {
+        final List<ReadFilter> defaultFilters = new ArrayList<>(2);
+        defaultFilters.add(new WellformedReadFilter());
+        defaultFilters.add(new ReadFilterLibrary.MappedReadFilter());
+        return defaultFilters;
+    }
+
+    /**
+     * @return The evaluator to be used to determine whether each locus is active or not. Must be implemented by tool authors.
+     *         The results of this per-locus evaluator are used to determine the bounds of each active and inactive region.
+     */
+    public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();
+
+    private List<ShardBoundary> intervalShards;
+
+    @Override
+    protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals) {
+        SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
+        List<SimpleInterval> intervals = rawIntervals == null ? IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals;
+        intervalShards = intervals.stream()
+                .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
+                .collect(Collectors.toList());
+        List<SimpleInterval> paddedIntervalsForReads =
+                intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList());
+        return paddedIntervalsForReads;
+    }
+
+    /**
+     * Loads assembly regions and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
+     *
+     * If no intervals were specified, returns all the assembly regions.
+     *
+     * @return all assembly regions as a {@link JavaRDD}, bounded by intervals if specified.
+     */
+    public JavaRDD<Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>> getAssemblyRegions(JavaSparkContext ctx) {
+        SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
+        JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
+        Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
+        Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
+        return shardedReads.flatMap(getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(),
+                assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance));
+    }
+
+    private static FlatMapFunction<Shard<GATKRead>, Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>> getAssemblyRegionsFunction(
+            final Broadcast<ReferenceMultiSource> bReferenceSource,
+            final Broadcast<FeatureManager> bFeatureManager,
+            final SAMSequenceDictionary sequenceDictionary,
+            final SAMFileHeader header,
+            final AssemblyRegionEvaluator evaluator,
+            final int minAssemblyRegionSize,
+            final int maxAssemblyRegionSize,
+            final int assemblyRegionPadding,
+            final double activeProbThreshold,
+            final int maxProbPropagationDistance) {
+        return (FlatMapFunction<Shard<GATKRead>, Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>>) shardedRead -> {
+            SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
+            SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);
+
+            ReferenceDataSource reference = bReferenceSource == null ? null :
+                    new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary);
+            FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
+            ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
+            FeatureContext featureContext = new FeatureContext(features, paddedInterval);
+
+            final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(shardedRead,
+                    header, referenceContext, featureContext, evaluator,
+                    minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold,
+                    maxProbPropagationDistance);
+            return Iterables.transform(assemblyRegions, new Function<AssemblyRegion, Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>>() {
+                @Nullable
+                @Override
+                public Tuple3<AssemblyRegion, ReferenceContext, FeatureContext> apply(@Nullable AssemblyRegion assemblyRegion) {
+                    return new Tuple3<>(assemblyRegion,
+                            new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
+                            new FeatureContext(features, assemblyRegion.getExtendedSpan()));
+                }
+            });
+        };
+    }
+
+}
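
For orientation, a concrete tool would subclass AssemblyRegionWalkerSpark roughly as follows. This is a hedged sketch: the class name, default values, and runTool body are illustrative, not part of this commit.

    // Hypothetical subclass showing the contract imposed by AssemblyRegionWalkerSpark.
    public final class ExampleAssemblyRegionWalkerSpark extends AssemblyRegionWalkerSpark {
        private static final long serialVersionUID = 1L;

        @Override protected int defaultReadShardSize() { return 5000; }
        @Override protected int defaultReadShardPadding() { return 100; }
        @Override protected int defaultMinAssemblyRegionSize() { return 50; }
        @Override protected int defaultMaxAssemblyRegionSize() { return 300; }
        @Override protected int defaultAssemblyRegionPadding() { return 100; }
        @Override protected int defaultMaxReadsPerAlignmentStart() { return 0; }
        @Override protected double defaultActiveProbThreshold() { return 0.002; }
        @Override protected int defaultMaxProbPropagationDistance() { return 50; }

        @Override
        public AssemblyRegionEvaluator assemblyRegionEvaluator() {
            // A real tool returns its per-locus activity evaluator here
            // (e.g. the one HaplotypeCaller uses); omitted in this sketch.
            throw new UnsupportedOperationException("sketch only");
        }

        @Override
        protected void runTool(final JavaSparkContext ctx) {
            // Materialize the regions; a real tool would map over the
            // (AssemblyRegion, ReferenceContext, FeatureContext) tuples.
            System.out.println("assembly regions: " + getAssemblyRegions(ctx).count());
        }
    }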

src/main/java/org/broadinstitute/hellbender/engine/spark/GATKSparkTool.java

+19
@@ -2,6 +2,8 @@

 import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKCommandLinePluginDescriptor;
 import org.broadinstitute.hellbender.cmdline.GATKPlugin.GATKReadFilterPluginDescriptor;
+import org.broadinstitute.hellbender.engine.FeatureDataSource;
+import org.broadinstitute.hellbender.engine.FeatureManager;
 import org.broadinstitute.hellbender.utils.SerializableFunction;
 import com.google.cloud.genomics.dataflow.utils.GCSOptions;
 import htsjdk.samtools.SAMFileHeader;
@@ -92,6 +94,7 @@ public abstract class GATKSparkTool extends SparkCommandLineProgram {
     private ReferenceMultiSource referenceSource;
     private SAMSequenceDictionary referenceDictionary;
     private List<SimpleInterval> intervals;
+    protected FeatureManager features;

     /**
      * Return the list of GATKCommandLinePluginDescriptor objects to be used for this CLP.
@@ -354,6 +357,7 @@ protected void runPipeline( JavaSparkContext sparkContext ) {
     private void initializeToolInputs(final JavaSparkContext sparkContext) {
         initializeReference();
         initializeReads(sparkContext); // reference must be initialized before reads
+        initializeFeatures();
         initializeIntervals();
     }

@@ -393,6 +397,21 @@ private void initializeReference() {
         }
     }

+    /**
+     * Initialize our source of Feature data (or set it to null if no Feature argument(s) were provided).
+     *
+     * Package-private so that engine classes can access it, but concrete tool child classes cannot.
+     * May be overridden by traversals that require custom initialization of Feature data sources.
+     *
+     * By default, this method initializes the FeatureManager to use the lookahead cache of {@link FeatureDataSource#DEFAULT_QUERY_LOOKAHEAD_BASES} bases.
+     */
+    void initializeFeatures() {
+        features = new FeatureManager(this);
+        if ( features.isEmpty() ) {  // No available sources of Features discovered for this tool
+            features = null;
+        }
+    }
+
     /**
      * Loads our intervals using the best available sequence dictionary (as returned by {@link #getBestAvailableSequenceDictionary})
      * to parse/verify them. Does nothing if no intervals were specified.
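
Usage note: with the FeatureManager now initialized (and broadcast by the Spark walkers), executor-side tool code reads Features through a FeatureContext. A hedged sketch, where the knownVariants FeatureInput is a hypothetical tool argument:

    // Hypothetical: list the variants overlapping the current context's interval.
    static void printOverlappingVariants(final FeatureContext featureContext,
                                         final FeatureInput<VariantContext> knownVariants) {
        for (final VariantContext vc : featureContext.getValues(knownVariants)) {
            System.out.println(vc.getContig() + ":" + vc.getStart());
        }
    }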
src/main/java/org/broadinstitute/hellbender/engine/spark/IntervalWalkerSpark.java

+78
@@ -0,0 +1,78 @@
+package org.broadinstitute.hellbender.engine.spark;
+
+import htsjdk.samtools.SAMSequenceDictionary;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.broadcast.Broadcast;
+import org.broadinstitute.hellbender.cmdline.Argument;
+import org.broadinstitute.hellbender.engine.*;
+import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
+import org.broadinstitute.hellbender.utils.SimpleInterval;
+import org.broadinstitute.hellbender.utils.read.GATKRead;
+import scala.Tuple4;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * A Spark version of {@link IntervalWalker}.
+ */
+public abstract class IntervalWalkerSpark extends GATKSparkTool {
+    private static final long serialVersionUID = 1L;
+
+    @Override
+    public boolean requiresIntervals() {
+        return true;
+    }
+
+    @Argument(doc = "whether to use the shuffle implementation or not", shortName = "shuffle", fullName = "shuffle", optional = true)
+    public boolean shuffle = false;
+
+    @Argument(fullName="intervalShardPadding", shortName="intervalShardPadding", doc = "Each interval shard has this many bases of extra context on each side.", optional = true)
+    public int intervalShardPadding = 1000;
+
+    /**
+     * Customize initialization of the Feature data source for this traversal type to disable query lookahead.
+     */
+    void initializeFeatures() {
+        // Disable query lookahead in our FeatureManager for this traversal type. Query lookahead helps
+        // when our query intervals are overlapping and gradually increasing in position (as they are
+        // with ReadWalkers, typically), but with IntervalWalkers our query intervals are guaranteed
+        // to be non-overlapping, since our interval parsing code always merges overlapping intervals.
+        features = new FeatureManager(this, 0);
+        if ( features.isEmpty() ) {  // No available sources of Features for this tool
+            features = null;
+        }
+    }
+
+    /**
+     * Loads intervals and the corresponding reads, reference and features into a {@link JavaRDD}.
+     *
+     * @return all intervals as a {@link JavaRDD}.
+     */
+    public JavaRDD<Tuple4<SimpleInterval, ReadsContext, ReferenceContext, FeatureContext>> getIntervals(JavaSparkContext ctx) {
+        SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
+        // don't shard the intervals themselves, since we want each interval to be processed by a single task
+        final List<ShardBoundary> intervalShardBoundaries = getIntervals().stream()
+                .map(i -> new ShardBoundary(i, i)).collect(Collectors.toList());
+        JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShardBoundaries, Integer.MAX_VALUE, shuffle);
+        Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
+        Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
+        return shardedReads.map(getIntervalsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, intervalShardPadding));
+    }
+
+    private static org.apache.spark.api.java.function.Function<Shard<GATKRead>, Tuple4<SimpleInterval, ReadsContext, ReferenceContext, FeatureContext>> getIntervalsFunction(
+            Broadcast<ReferenceMultiSource> bReferenceSource, Broadcast<FeatureManager> bFeatureManager,
+            SAMSequenceDictionary sequenceDictionary, int intervalShardPadding) {
+        return (org.apache.spark.api.java.function.Function<Shard<GATKRead>, Tuple4<SimpleInterval, ReadsContext, ReferenceContext, FeatureContext>>) shard -> {
+            // get reference bases for this shard (padded)
+            SimpleInterval interval = shard.getInterval();
+            SimpleInterval paddedInterval = shard.getInterval().expandWithinContig(intervalShardPadding, sequenceDictionary);
+            ReadsContext readsContext = new ReadsContext(shard);
+            ReferenceDataSource reference = bReferenceSource == null ? null :
+                    new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, paddedInterval), sequenceDictionary);
+            FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
+            return new Tuple4<>(interval, readsContext, new ReferenceContext(reference, interval), new FeatureContext(features, interval));
+        };
+    }
+}
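
For orientation, a concrete tool would subclass IntervalWalkerSpark along these lines. This is a hedged sketch: the class name and output path are hypothetical, not part of this commit.

    // Hypothetical subclass: report each traversed interval with its read count.
    public final class ExampleIntervalWalkerSpark extends IntervalWalkerSpark {
        private static final long serialVersionUID = 1L;

        @Override
        protected void runTool(final JavaSparkContext ctx) {
            getIntervals(ctx)
                    .map(t -> t._1() + " reads=" + com.google.common.collect.Iterables.size(t._2()))
                    .saveAsTextFile("example-intervals"); // hypothetical output location
        }
    }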
