
Add spark versions of walker classes #2256

Merged: 3 commits, Jan 11, 2017
@@ -19,7 +19,7 @@
*/
public final class ReadsContext implements Iterable<GATKRead> {

private final ReadsDataSource dataSource;
private final GATKDataSource<GATKRead> dataSource;

private final SimpleInterval interval;

@@ -39,7 +39,7 @@ public ReadsContext() {
* @param dataSource backing source of reads data (may be null)
* @param interval interval over which to query (may be null)
*/
public ReadsContext( final ReadsDataSource dataSource, final SimpleInterval interval ) {
public ReadsContext( final GATKDataSource<GATKRead> dataSource, final SimpleInterval interval ) {
this.dataSource = dataSource;
this.interval = interval;
}
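
Swapping the concrete ReadsDataSource for the GATKDataSource<GATKRead> interface means a ReadsContext can be backed by any queryable source of reads, which is presumably what the Spark walkers in this PR rely on. A minimal sketch, assuming GATKDataSource<T> extends Iterable<T> and declares a query(SimpleInterval) method, with InMemoryReadsSource as a purely illustrative name:

```java
import java.util.Iterator;
import java.util.List;

import org.broadinstitute.hellbender.engine.GATKDataSource;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;

// Hypothetical in-memory read source; not part of this PR.
public final class InMemoryReadsSource implements GATKDataSource<GATKRead> {
    private final List<GATKRead> reads;

    public InMemoryReadsSource(final List<GATKRead> reads) {
        this.reads = reads;
    }

    @Override
    public Iterator<GATKRead> iterator() {
        return reads.iterator();
    }

    @Override
    public Iterator<GATKRead> query(final SimpleInterval interval) {
        // GATKRead is a Locatable, so SimpleInterval.overlaps can filter directly
        return reads.stream().filter(interval::overlaps).iterator();
    }
}
```

A ReadsContext over such a source then works as before: new ReadsContext(new InMemoryReadsSource(reads), interval).
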
@@ -2,10 +2,12 @@

import htsjdk.variant.variantcontext.VariantContext;

import java.io.Serializable;
import java.util.function.Predicate;

@FunctionalInterface
public interface VariantFilter extends Predicate<VariantContext>{
public interface VariantFilter extends Predicate<VariantContext>, Serializable {
static final long serialVersionUID = 1L;

//HACK: These methods are a hack to get the type system to accept compositions of VariantFilters.
default VariantFilter and(VariantFilter filter ) { return Predicate.super.and(filter)::test; }
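
Extending Serializable here matters because composed filters need to be captured in Spark closures and shipped to executors; the and/or defaults return method-reference-backed VariantFilters, which stay serializable. A hedged sketch of the intended use, where keepBiallelicUnfiltered is an illustrative helper and variants an assumed RDD:

```java
import htsjdk.variant.variantcontext.VariantContext;
import org.apache.spark.api.java.JavaRDD;

public final class VariantFilterExample {
    // Hedged sketch; 'variants' is an assumed RDD of VariantContext.
    static JavaRDD<VariantContext> keepBiallelicUnfiltered(final JavaRDD<VariantContext> variants) {
        final VariantFilter notFiltered = vc -> !vc.isFiltered();
        final VariantFilter biallelic = VariantContext::isBiallelic;
        // The composition is itself a Serializable VariantFilter, so Spark can
        // serialize it as part of the task closure.
        final VariantFilter combined = notFiltered.and(biallelic);
        return variants.filter(combined::test);
    }
}
```
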
@@ -0,0 +1,32 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.hellbender.engine.AssemblyRegion;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;

/**
* Encapsulates an {@link AssemblyRegion} with its {@link ReferenceContext} and {@link FeatureContext}.
*/
public final class AssemblyRegionWalkerContext {
private final AssemblyRegion assemblyRegion;
private final ReferenceContext referenceContext;
private final FeatureContext featureContext;

public AssemblyRegionWalkerContext(AssemblyRegion assemblyRegion, ReferenceContext referenceContext, FeatureContext featureContext) {
this.assemblyRegion = assemblyRegion;
this.referenceContext = referenceContext;
this.featureContext = featureContext;
}

public AssemblyRegion getAssemblyRegion() {
return assemblyRegion;
}

public ReferenceContext getReferenceContext() {
return referenceContext;
}

public FeatureContext getFeatureContext() {
return featureContext;
}
}
@@ -0,0 +1,199 @@
package org.broadinstitute.hellbender.engine.spark;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.broadcast.Broadcast;
import org.broadinstitute.barclay.argparser.Advanced;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.hellbender.engine.*;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary;
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.utils.IntervalUtils;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.read.GATKRead;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
* A Spark version of {@link AssemblyRegionWalker}. Subclasses should implement {@link #processAssemblyRegions(JavaRDD, JavaSparkContext)}
* and operate on the passed in RDD.
*/
Member:

Could you add some explanation about how you're expected to subclass this? Just an explanation of the entrypoint and the necessary methods.

Contributor Author:

Done.

public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool {
Member:

I think you might want to add an abstract method that makes it obvious how to extend this class, i.e. abstract void processAssemblyRegions(Tuple3<...> ...). That brings it more in line with the walkers, which basically force you to implement the function you need.

Contributor Author:

Done

private static final long serialVersionUID = 1L;

@Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true)
protected int readShardSize = defaultReadShardSize();

@Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true)
protected int readShardPadding = defaultReadShardPadding();

@Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true)
protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize();

@Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true)
protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize();

@Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true)
protected int assemblyRegionPadding = defaultAssemblyRegionPadding();

@Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true)
protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart();

@Advanced
@Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true)
protected double activeProbThreshold = defaultActiveProbThreshold();

@Advanced
@Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true)
Member:

There's a lot of validation of these in AssemblyRegionWalker.onStartup that isn't ported here. I think we need to abstract this whole chunk of code, because it's getting spread all over the place. My thought is to define a ShardingArgumentCollection interface that users could subclass to define the defaults they need. That would enable sharing defaults between tools in a reasonable way. We could also move the validation code into it so that it's not duplicated. Maybe that should be a separate PR, though.

Contributor Author:

Agreed. Leaving for another PR.
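
For concreteness, a hedged sketch of the ShardingArgumentCollection idea proposed above; the interface name, accessors, and checks are assumptions about a possible follow-up, not code in this PR:

```java
import java.io.Serializable;

// Hypothetical interface; a follow-up PR could move the @Argument fields and
// the AssemblyRegionWalker.onStartup validation behind something like this.
public interface ShardingArgumentCollection extends Serializable {
    int readShardSize();
    int readShardPadding();
    int minAssemblyRegionSize();
    int maxAssemblyRegionSize();
    int assemblyRegionPadding();

    // Shared validation so each walker doesn't re-implement it (illustrative checks only)
    default void validate() {
        if (minAssemblyRegionSize() <= 0 || maxAssemblyRegionSize() <= 0) {
            throw new IllegalArgumentException("assembly region sizes must be positive");
        }
        if (minAssemblyRegionSize() > maxAssemblyRegionSize()) {
            throw new IllegalArgumentException("minAssemblyRegionSize must be <= maxAssemblyRegionSize");
        }
        if (readShardPadding() < assemblyRegionPadding()) {
            throw new IllegalArgumentException("read shards must have at least as much padding as assembly regions");
        }
    }
}
```
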

protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance();

/**
* @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line
*/
protected abstract int defaultReadShardSize();

/**
* @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultReadShardPadding();

/**
* @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMinAssemblyRegionSize();

/**
* @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxAssemblyRegionSize();

/**
* @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line
*/
protected abstract int defaultAssemblyRegionPadding();

/**
* @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxReadsPerAlignmentStart();

/**
* @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line
*/
protected abstract double defaultActiveProbThreshold();

/**
* @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line
*/
protected abstract int defaultMaxProbPropagationDistance();

@Argument(doc = "whether to use the shuffle implementation or not", shortName = "shuffle", fullName = "shuffle", optional = true)
public boolean shuffle = false;

@Override
public final boolean requiresReads() { return true; }

@Override
public final boolean requiresReference() { return true; }

@Override
public List<ReadFilter> getDefaultReadFilters() {
Member:

Missing @Override.

Contributor Author:

Done
final List<ReadFilter> defaultFilters = new ArrayList<>(2);
defaultFilters.add(new WellformedReadFilter());
defaultFilters.add(new ReadFilterLibrary.MappedReadFilter());
return defaultFilters;
}

/**
* @return The evaluator to be used to determine whether each locus is active or not. Must be implemented by tool authors.
* The results of this per-locus evaluator are used to determine the bounds of each active and inactive region.
*/
public abstract AssemblyRegionEvaluator assemblyRegionEvaluator();
Member:

This might want to mention that it will be called once per shard, which may be expensive if this is an expensive operation. I found that, practically, I had to reuse the assembly region evaluator, because initializing a HaplotypeCallerEngine is expensive. I ended up making the downstream call that used the assemblyRegionEvaluator a mapPartitions in order to reduce the number of instantiations of the engine. Alternatively it could be serialized, but I ran into issues serializing the HaplotypeCallerEngine since it does its own file access.

Contributor Author:

As it stands, it's being serialized for each task. I suggest leaving it and addressing any problems when using this for the HaplotypeCaller.

Member:

That sounds fine. We'll measure it when we get to it.
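
A hedged sketch of the mapPartitions workaround described above, which builds the expensive evaluator once per partition rather than serializing one per task; createExpensiveEvaluator and describe are hypothetical names:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.broadinstitute.hellbender.engine.AssemblyRegionEvaluator;

public final class PerPartitionEvaluationSketch {
    // Hedged sketch: amortize an expensive evaluator across a whole partition.
    static JavaRDD<String> describeRegionsPerPartition(final JavaRDD<AssemblyRegionWalkerContext> contexts) {
        return contexts.mapPartitions(iterator -> {
            // Hypothetical factory, e.g. one that builds a HaplotypeCallerEngine
            // once from broadcast inputs instead of once per element.
            final AssemblyRegionEvaluator evaluator = createExpensiveEvaluator();
            final List<String> out = new ArrayList<>();
            while (iterator.hasNext()) {
                // Hypothetical per-region work that reuses the shared evaluator
                out.add(describe(evaluator, iterator.next()));
            }
            return out.iterator();
        });
    }
}
```
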


private List<ShardBoundary> intervalShards;

/**
* Note that this sets {@code intervalShards} as a side effect, in order to add padding to the intervals.
*/
@Override
protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals) {
Member:

Could you add a comment to this method calling out the fact that it sets intervalShards as a side effect? People might gloss over that, since editIntervals isn't usually expected to have any side effects.

Contributor Author:

Done

SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
List<SimpleInterval> intervals = rawIntervals == null ? IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals;
intervalShards = intervals.stream()
.flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream())
.collect(Collectors.toList());
List<SimpleInterval> paddedIntervalsForReads =
intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList());
return paddedIntervalsForReads;
}

/**
* Loads assembly regions and the corresponding reference and features into a {@link JavaRDD} for the intervals specified.
*
* If no intervals were specified, returns all the assembly regions.
*
* @return all assembly regions as a {@link JavaRDD}, bounded by intervals if specified.
*/
protected JavaRDD<AssemblyRegionWalkerContext> getAssemblyRegions(JavaSparkContext ctx) {
SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary();
JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle);
Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null;
Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features);
return shardedReads.flatMap(getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(),
assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance));
}

private static FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext> getAssemblyRegionsFunction(
final Broadcast<ReferenceMultiSource> bReferenceSource,
final Broadcast<FeatureManager> bFeatureManager,
final SAMSequenceDictionary sequenceDictionary,
final SAMFileHeader header,
final AssemblyRegionEvaluator evaluator,
final int minAssemblyRegionSize,
final int maxAssemblyRegionSize,
final int assemblyRegionPadding,
final double activeProbThreshold,
final int maxProbPropagationDistance) {
return (FlatMapFunction<Shard<GATKRead>, AssemblyRegionWalkerContext>) shardedRead -> {
SimpleInterval paddedInterval = shardedRead.getPaddedInterval();
SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary);

ReferenceDataSource reference = bReferenceSource == null ? null :
new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary);
FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue();
Member:

Does this really work to get features on a remote node? Aren't the features all loaded from local files in the current implementation?

Contributor Author:

It will with samtools/htsjdk#724. I've tested on a cluster with this change. Any chance of a review there? :)

ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval);
FeatureContext featureContext = new FeatureContext(features, paddedInterval);

final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(shardedRead,
header, referenceContext, featureContext, evaluator,
minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold,
maxProbPropagationDistance);
return StreamSupport.stream(assemblyRegions.spliterator(), false).map(assemblyRegion ->
new AssemblyRegionWalkerContext(assemblyRegion,
new ReferenceContext(reference, assemblyRegion.getExtendedSpan()),
new FeatureContext(features, assemblyRegion.getExtendedSpan()))).iterator();
};
}

@Override
protected void runTool(JavaSparkContext ctx) {
processAssemblyRegions(getAssemblyRegions(ctx), ctx);
}

/**
* Process the assembly regions and write output. Must be implemented by subclasses.
*
* @param rdd a distributed collection of {@link AssemblyRegionWalkerContext}
* @param ctx our Spark context
*/
protected abstract void processAssemblyRegions(JavaRDD<AssemblyRegionWalkerContext> rdd, JavaSparkContext ctx);

}
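
To make the extension point concrete, a hedged sketch of a minimal subclass follows. The class name, default values, and the toy evaluator and action are illustrative assumptions, as are the import paths and the assumption that AssemblyRegionEvaluator is a functional interface over (AlignmentContext, ReferenceContext, FeatureContext):

```java
package org.broadinstitute.hellbender.engine.spark;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.broadinstitute.hellbender.engine.AssemblyRegionEvaluator;
import org.broadinstitute.hellbender.utils.SimpleInterval;
import org.broadinstitute.hellbender.utils.activityprofile.ActivityProfileState;

// Hypothetical example tool; names and defaults are illustrative only.
public final class ExampleAssemblyRegionWalkerSpark extends AssemblyRegionWalkerSpark {
    private static final long serialVersionUID = 1L;

    @Override protected int defaultReadShardSize() { return 5000; }
    @Override protected int defaultReadShardPadding() { return 100; }
    @Override protected int defaultMinAssemblyRegionSize() { return 50; }
    @Override protected int defaultMaxAssemblyRegionSize() { return 300; }
    @Override protected int defaultAssemblyRegionPadding() { return 100; }
    @Override protected int defaultMaxReadsPerAlignmentStart() { return 0; } // 0 disables downsampling
    @Override protected double defaultActiveProbThreshold() { return 0.002; }
    @Override protected int defaultMaxProbPropagationDistance() { return 50; }

    @Override
    public AssemblyRegionEvaluator assemblyRegionEvaluator() {
        // Toy evaluator that marks every locus active; a real tool supplies
        // something meaningful (e.g. HaplotypeCaller's activity model).
        return (locusPileup, referenceContext, featureContext) ->
                new ActivityProfileState(new SimpleInterval(locusPileup), 1.0);
    }

    @Override
    protected void processAssemblyRegions(JavaRDD<AssemblyRegionWalkerContext> rdd, JavaSparkContext ctx) {
        // Trivial action: just count the regions produced.
        // (logger is inherited from the tool base class.)
        logger.info("Assembly regions: " + rdd.count());
    }
}
```
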
@@ -13,6 +13,8 @@
import org.broadinstitute.hellbender.cmdline.argumentcollections.*;
import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource;
import org.broadinstitute.hellbender.engine.datasources.ReferenceWindowFunctions;
import org.broadinstitute.hellbender.engine.FeatureDataSource;
import org.broadinstitute.hellbender.engine.FeatureManager;
import org.broadinstitute.hellbender.engine.filters.ReadFilter;
import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter;
import org.broadinstitute.hellbender.engine.spark.datasources.ReadsSparkSink;
@@ -93,6 +95,7 @@ public abstract class GATKSparkTool extends SparkCommandLineProgram {
private ReferenceMultiSource referenceSource;
private SAMSequenceDictionary referenceDictionary;
private List<SimpleInterval> intervals;
protected FeatureManager features;

/**
* Return the list of GATKCommandLinePluginDescriptor objects to be used for this CLP.
@@ -355,6 +358,7 @@ protected void runPipeline( JavaSparkContext sparkContext ) {
private void initializeToolInputs(final JavaSparkContext sparkContext) {
initializeReference();
initializeReads(sparkContext); // reference must be initialized before reads
initializeFeatures();
initializeIntervals();
}

@@ -394,6 +398,21 @@ private void initializeReference() {
}
}

/**
* Initialize our source of Feature data (or set it to null if no Feature argument(s) were provided).
*
* Package-private so that engine classes can access it, but concrete tool child classes cannot.
* May be overridden by traversals that require custom initialization of Feature data sources.
*
* By default, this method initializes the FeatureManager to use the lookahead cache of {@link FeatureDataSource#DEFAULT_QUERY_LOOKAHEAD_BASES} bases.
*/
void initializeFeatures() {
features = new FeatureManager(this);
if ( features.isEmpty() ) { // No available sources of Features discovered for this tool
features = null;
}
}

/**
* Loads our intervals using the best available sequence dictionary (as returned by {@link #getBestAvailableSequenceDictionary})
* to parse/verify them. Does nothing if no intervals were specified.
@@ -0,0 +1,41 @@
package org.broadinstitute.hellbender.engine.spark;

import org.broadinstitute.hellbender.engine.AssemblyRegion;
import org.broadinstitute.hellbender.engine.FeatureContext;
import org.broadinstitute.hellbender.engine.ReadsContext;
import org.broadinstitute.hellbender.engine.ReferenceContext;
import org.broadinstitute.hellbender.utils.SimpleInterval;

/**
* Encapsulates a {@link SimpleInterval} with the reads that overlap it (the {@link ReadsContext}),
* along with its {@link ReferenceContext} and {@link FeatureContext}.
*/
public class IntervalWalkerContext {
private final SimpleInterval interval;
private final ReadsContext readsContext;
private final ReferenceContext referenceContext;
private final FeatureContext featureContext;

public IntervalWalkerContext(SimpleInterval interval, ReadsContext readsContext, ReferenceContext referenceContext, FeatureContext featureContext) {
this.interval = interval;
this.readsContext = readsContext;
this.referenceContext = referenceContext;
this.featureContext = featureContext;
}

public SimpleInterval getInterval() {
return interval;
}

public ReadsContext getReadsContext() {
return readsContext;
}

public ReferenceContext getReferenceContext() {
return referenceContext;
}

public FeatureContext getFeatureContext() {
return featureContext;
}
}
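
A hedged usage sketch for these contexts; summarizeIntervals is an illustrative helper, and it relies on ReadsContext implementing Iterable<GATKRead>, as shown earlier in this diff:

```java
import java.util.stream.StreamSupport;

import org.apache.spark.api.java.JavaRDD;

public final class IntervalWalkerContextSketch {
    // Hypothetical consumer of the per-interval contexts; sketch only.
    static void summarizeIntervals(final JavaRDD<IntervalWalkerContext> rdd) {
        rdd.foreach(context -> {
            // ReadsContext implements Iterable<GATKRead>, so we can stream it
            final long overlappingReads =
                    StreamSupport.stream(context.getReadsContext().spliterator(), false).count();
            System.out.println(context.getInterval() + ": " + overlappingReads + " overlapping reads");
        });
    }
}
```
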