Add spark versions of walker classes #2256
New file: AssemblyRegionWalkerContext.java
@@ -0,0 +1,32 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.hellbender.engine.AssemblyRegion;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;

/**
 * Encapsulates an {@link AssemblyRegion} with its {@link ReferenceContext} and {@link FeatureContext}.
 */
public final class AssemblyRegionWalkerContext {
    private final AssemblyRegion assemblyRegion;
    private final ReferenceContext referenceContext;
    private final FeatureContext featureContext;

    public AssemblyRegionWalkerContext(AssemblyRegion assemblyRegion, ReferenceContext referenceContext, FeatureContext featureContext) {
        this.assemblyRegion = assemblyRegion;
        this.referenceContext = referenceContext;
        this.featureContext = featureContext;
    }

    public AssemblyRegion getAssemblyRegion() {
        return assemblyRegion;
    }

    public ReferenceContext getReferenceContext() {
        return referenceContext;
    }

    public FeatureContext getFeatureContext() {
        return featureContext;
    }
}
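For a sense of how this context object flows through a tool, here is a hypothetical helper, not code from this PR (it assumes AssemblyRegion#getSpan() returns the region's primary interval):

    // Hypothetical: tally assembly regions per contig from the RDD a tool receives
    // (see AssemblyRegionWalkerSpark below).
    static Map<String, Long> countRegionsPerContig(JavaRDD<AssemblyRegionWalkerContext> rdd) {
        return rdd.map(context -> context.getAssemblyRegion().getSpan().getContig())
                .countByValue();
    }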
New file: AssemblyRegionWalkerSpark.java
@@ -0,0 +1,199 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
 * A Spark version of {@link AssemblyRegionWalker}. Subclasses should implement {@link #processAssemblyRegions(JavaRDD, JavaSparkContext)}
 * and operate on the passed-in RDD.
 */
public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool {

Review comment: I think you might want to add an abstract method that makes it obvious how to extend this class.
Reply: Done.
    private static final long serialVersionUID = 1L;

    @Argument(fullName = "readShardSize", shortName = "readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
    protected int readShardSize = defaultReadShardSize();

    @Argument(fullName = "readShardPadding", shortName = "readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
    protected int readShardPadding = defaultReadShardPadding();

    @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
    protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize();

    @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
    protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();

    @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
    protected int assemblyRegionPadding = defaultAssemblyRegionPadding();

    @Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
    protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();

    @Advanced
    @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc = "Minimum probability for a locus to be considered active.", optional = true)
    protected double activeProbThreshold = defaultActiveProbThreshold();

    @Advanced
    @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc = "Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
    protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance();

Review comment: there's a lot of validation of these in …
Reply: Agreed. Leaving for another PR.
    /**
     * @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line
     */
    protected abstract int defaultReadShardSize();

    /**
     * @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line
     */
    protected abstract int defaultReadShardPadding();

    /**
     * @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
     */
    protected abstract int defaultMinAssemblyRegionSize();

    /**
     * @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
     */
    protected abstract int defaultMaxAssemblyRegionSize();

    /**
     * @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
     */
    protected abstract int defaultAssemblyRegionPadding();

    /**
     * @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
     */
    protected abstract int defaultMaxReadsPerAlignmentStart();

    /**
     * @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
     */
    protected abstract double defaultActiveProbThreshold();

    /**
     * @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
     */
    protected abstract int defaultMaxProbPropagationDistance();

    @Argument(doc = "whether to use the shuffle implementation or not", shortName = "shuffle", fullName = "shuffle", optional = true)
    public boolean shuffle = false;
    @Override
    public final boolean requiresReads() { return true; }

    @Override
    public final boolean requiresReference() { return true; }

    @Override
    public List<ReadFilter> getDefaultReadFilters() {
        final List<ReadFilter> defaultFilters = new ArrayList<>(2);
        defaultFilters.add(new WellformedReadFilter());
        defaultFilters.add(new ReadFilterLibrary.MappedReadFilter());
        return defaultFilters;
    }

Review comment: missing …
Reply: Done.
    /**
     * @return The evaluator to be used to determine whether each locus is active or not. Must be implemented by tool authors.
     * The results of this per-locus evaluator are used to determine the bounds of each active and inactive region.
     */
    public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();

Review comment: This might want to mention that it will be called once per shard, which may be expensive if this is an expensive operation. I found that, practically, I had to reuse the assembly region evaluator because initializing a HaplotypeCallerEngine is expensive. I ended up making the downstream call that used the assemblyRegionEvaluator a mapPartitions in order to reduce the number of instantiations of the engine. Alternatively, it could be serialized, but I ran into issues serializing the HaplotypeCallerEngine since it does its own file access.
Reply: As it stands it's being serialized for each task. I suggest leaving it, and addressing any problems when using this for the HaplotypeCaller.
Reply: That sounds fine. We'll measure it when we get to it.
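To make the partition-level reuse described in the thread above concrete, here is a sketch of the pattern (purely illustrative, not code from this PR: ExpensiveEvaluator, its evaluate() method, and evaluatePerPartition are hypothetical stand-ins for something like a HaplotypeCallerEngine-backed evaluator):

    // Sketch: construct the expensive evaluator once per partition via mapPartitions,
    // rather than once per element (or serializing one instance into every task closure).
    static JavaRDD<String> evaluatePerPartition(JavaRDD<AssemblyRegionWalkerContext> contexts) {
        return contexts.mapPartitions(iterator -> {
            final ExpensiveEvaluator evaluator = new ExpensiveEvaluator(); // built once for the whole partition
            final List<String> results = new ArrayList<>();
            while (iterator.hasNext()) {
                results.add(evaluator.evaluate(iterator.next()));
            }
            return results.iterator();
        });
    }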
    private List<ShardBoundary> intervalShards;

    /**
     * Note that this sets {@code intervalShards} as a side effect, in order to add padding to the intervals.
     */
    @Override
    protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals) {
        SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
        List<SimpleInterval> intervals = rawIntervals == null ? IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals;
        intervalShards = intervals.stream()
                .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
                .collect(Collectors.toList());
        List<SimpleInterval> paddedIntervalsForReads =
                intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList());
        return paddedIntervalsForReads;
    }

Review comment: Could you add a comment to this method calling out the fact that it sets intervalShards as a side effect? People might gloss over that since editIntervals isn't usually expected to have any side effects.
Reply: Done.
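As a concrete reading of the sharding arithmetic (illustrative only, and assuming divideIntervalIntoShards produces consecutive windows of at most readShardSize bases, each padded by readShardPadding and clipped to the contig): an interval chr1:1-10000 with readShardSize = 5000 and readShardPadding = 100 would yield two shard boundaries, chr1:1-5000 padded to chr1:1-5100 (clipped at the contig start) and chr1:5001-10000 padded to chr1:4901-10100, while the interval handed back for the read traversal would be the padded original, here chr1:1-10100.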
    /**
     * Loads assembly regions and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
     *
     * If no intervals were specified, returns all the assembly regions.
     *
     * @return all assembly regions as a {@link JavaRDD}, bounded by intervals if specified.
     */
    protected JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(JavaSparkContext ctx) {
        SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
        JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
        Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
        Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
        return shardedReads.flatMap(getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(),
                assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance));
    }
    private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
            final Broadcast<ReferenceMultiSource> bReferenceSource,
            final Broadcast<FeatureManager> bFeatureManager,
            final SAMSequenceDictionary sequenceDictionary,
            final SAMFileHeader header,
            final AssemblyRegionEvaluator evaluator,
            final int minAssemblyRegionSize,
            final int maxAssemblyRegionSize,
            final int assemblyRegionPadding,
            final double activeProbThreshold,
            final int maxProbPropagationDistance) {
        return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
            SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
            SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);

            ReferenceDataSource reference = bReferenceSource == null ? null :
                    new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary);
            FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();

            ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
            FeatureContext featureContext = new FeatureContext(features, paddedInterval);

            final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(shardedRead,
                    header, referenceContext, featureContext, evaluator,
                    minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold,
                    maxProbPropagationDistance);
            return StreamSupport.stream(assemblyRegions.spliterator(), false).map(assemblyRegion ->
                    new AssemblyRegionWalkerContext(assemblyRegion,
                            new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
                            new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator();
        };
    }

Review comment: Does this really work to get features on a remote node? Aren't the features all loaded from local files in the current implementation?
Reply: It will do with samtools/htsjdk#724. I've tested on a cluster with this change. Any chance of a review there? :)
    @Override
    protected void runTool(JavaSparkContext ctx) {
        processAssemblyRegions(getAssemblyRegions(ctx), ctx);
    }

    /**
     * Process the assembly regions and write output. Must be implemented by subclasses.
     *
     * @param rdd a distributed collection of {@link AssemblyRegionWalkerContext}
     * @param ctx our Spark context
     */
    protected abstract void processAssemblyRegions(JavaRDD<AssemblyRegionWalkerContext> rdd, JavaSparkContext ctx);
}
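For a sense of the entrypoints, here is a minimal concrete subclass. This is a hedged sketch only, not part of the PR: the class name, default values, trivial evaluator, and printing logic are all invented, the usual GATK command-line annotations and imports are omitted, and it assumes AssemblyRegionEvaluator is a single-method interface taking an AlignmentContext, ReferenceContext, and FeatureContext and returning an ActivityProfileState.

    // Illustrative subclass: every name and value below is invented for the sketch.
    public final class ExampleAssemblyRegionWalkerSpark extends AssemblyRegionWalkerSpark {
        private static final long serialVersionUID = 1L;

        @Override protected int defaultReadShardSize() { return 5000; }
        @Override protected int defaultReadShardPadding() { return 100; }
        @Override protected int defaultMinAssemblyRegionSize() { return 50; }
        @Override protected int defaultMaxAssemblyRegionSize() { return 300; }
        @Override protected int defaultAssemblyRegionPadding() { return 100; }
        @Override protected int defaultMaxReadsPerAlignmentStart() { return 50; }
        @Override protected double defaultActiveProbThreshold() { return 0.002; }
        @Override protected int defaultMaxProbPropagationDistance() { return 50; }

        @Override
        public AssemblyRegionEvaluator assemblyRegionEvaluator() {
            // Trivial evaluator for the sketch: call every locus active with probability 1.
            return (locusPileup, referenceContext, featureContext) ->
                    new ActivityProfileState(new SimpleInterval(locusPileup), 1.0);
        }

        @Override
        protected void processAssemblyRegions(JavaRDD<AssemblyRegionWalkerContext> rdd, JavaSparkContext ctx) {
            // Toy action: print the span of each assembly region on the driver.
            rdd.map(context -> context.getAssemblyRegion().getSpan().toString())
                    .collect()
                    .forEach(System.out::println);
        }
    }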
New file: IntervalWalkerContext.java
@@ -0,0 +1,41 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.hellbender.engine.AssemblyRegion;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.utils.SimpleInterval;

/**
 * Encapsulates a {@link SimpleInterval} with the reads that overlap it (the {@link ReadsContext}),
 * and its {@link ReferenceContext} and {@link FeatureContext}.
 */
public class IntervalWalkerContext {
    private final SimpleInterval interval;
    private final ReadsContext readsContext;
    private final ReferenceContext referenceContext;
    private final FeatureContext featureContext;

    public IntervalWalkerContext(SimpleInterval interval, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
        this.interval = interval;
        this.readsContext = readsContext;
        this.referenceContext = referenceContext;
        this.featureContext = featureContext;
    }

    public SimpleInterval getInterval() {
        return interval;
    }

    public ReadsContext getReadsContext() {
        return readsContext;
    }

    public ReferenceContext getReferenceContext() {
        return referenceContext;
    }

    public FeatureContext getFeatureContext() {
        return featureContext;
    }
}
Review comment: Could you add some explanation about how you're expected to subclass this? Just an explanation of the entrypoint and necessary methods.
Reply: Done.