|
| 1 | +package org.broadinstitute.hellbender.engine.spark; |
| 2 | + |
| 3 | +import com.google.common.base.Function; |
| 4 | +import com.google.common.collect.Iterators; |
| 5 | +import htsjdk.samtools.SAMFileHeader; |
| 6 | +import htsjdk.samtools.SAMSequenceDictionary; |
| 7 | +import org.apache.spark.api.java.JavaRDD; |
| 8 | +import org.apache.spark.api.java.JavaSparkContext; |
| 9 | +import org.apache.spark.api.java.function.FlatMapFunction; |
| 10 | +import org.apache.spark.broadcast.Broadcast; |
| 11 | +import org.broadinstitute.hellbender.cmdline.Advanced; |
| 12 | +import org.broadinstitute.hellbender.cmdline.Argument; |
| 13 | +import org.broadinstitute.hellbender.engine.*; |
| 14 | +import org.broadinstitute.hellbender.engine.datasources.ReferenceMultiSource; |
| 15 | +import org.broadinstitute.hellbender.engine.filters.ReadFilter; |
| 16 | +import org.broadinstitute.hellbender.engine.filters.ReadFilterLibrary; |
| 17 | +import org.broadinstitute.hellbender.engine.filters.WellformedReadFilter; |
| 18 | +import org.broadinstitute.hellbender.utils.IntervalUtils; |
| 19 | +import org.broadinstitute.hellbender.utils.SimpleInterval; |
| 20 | +import org.broadinstitute.hellbender.utils.read.GATKRead; |
| 21 | +import scala.Tuple3; |
| 22 | + |
| 23 | +import javax.annotation.Nullable; |
| 24 | +import java.util.ArrayList; |
| 25 | +import java.util.List; |
| 26 | +import java.util.stream.Collectors; |
| 27 | + |
| 28 | +/** |
| 29 | + * A Spark version of {@link AssemblyRegionWalker}. |
| 30 | + */ |
| 31 | +public abstract class AssemblyRegionWalkerSpark extends GATKSparkTool { |
| 32 | + private static final long serialVersionUID = 1L; |
| 33 | + |
| 34 | + @Argument(fullName="readShardSize", shortName="readShardSize", doc = "Maximum size of each read shard, in bases. For good performance, this should be much larger than the maximum assembly region size.", optional = true) |
| 35 | + protected int readShardSize = defaultReadShardSize(); |
| 36 | + |
| 37 | + @Argument(fullName="readShardPadding", shortName="readShardPadding", doc = "Each read shard has this many bases of extra context on each side. Read shards must have as much or more padding than assembly regions.", optional = true) |
| 38 | + protected int readShardPadding = defaultReadShardPadding(); |
| 39 | + |
| 40 | + @Argument(fullName = "minAssemblyRegionSize", shortName = "minAssemblyRegionSize", doc = "Minimum size of an assembly region", optional = true) |
| 41 | + protected int minAssemblyRegionSize = defaultMinAssemblyRegionSize(); |
| 42 | + |
| 43 | + @Argument(fullName = "maxAssemblyRegionSize", shortName = "maxAssemblyRegionSize", doc = "Maximum size of an assembly region", optional = true) |
| 44 | + protected int maxAssemblyRegionSize = defaultMaxAssemblyRegionSize(); |
| 45 | + |
| 46 | + @Argument(fullName = "assemblyRegionPadding", shortName = "assemblyRegionPadding", doc = "Number of additional bases of context to include around each assembly region", optional = true) |
| 47 | + protected int assemblyRegionPadding = defaultAssemblyRegionPadding(); |
| 48 | + |
| 49 | + @Argument(fullName = "maxReadsPerAlignmentStart", shortName = "maxReadsPerAlignmentStart", doc = "Maximum number of reads to retain per alignment start position. Reads above this threshold will be downsampled. Set to 0 to disable.", optional = true) |
| 50 | + protected int maxReadsPerAlignmentStart = defaultMaxReadsPerAlignmentStart(); |
| 51 | + |
| 52 | + @Advanced |
| 53 | + @Argument(fullName = "activeProbabilityThreshold", shortName = "activeProbabilityThreshold", doc="Minimum probability for a locus to be considered active.", optional = true) |
| 54 | + protected double activeProbThreshold = defaultActiveProbThreshold(); |
| 55 | + |
| 56 | + @Advanced |
| 57 | + @Argument(fullName = "maxProbPropagationDistance", shortName = "maxProbPropagationDistance", doc="Upper limit on how many bases away probability mass can be moved around when calculating the boundaries between active and inactive assembly regions", optional = true) |
| 58 | + protected int maxProbPropagationDistance = defaultMaxProbPropagationDistance(); |
| 59 | + |
| 60 | + /** |
| 61 | + * @return Default value for the {@link #readShardSize} parameter, if none is provided on the command line |
| 62 | + */ |
| 63 | + protected abstract int defaultReadShardSize(); |
| 64 | + |
| 65 | + /** |
| 66 | + * @return Default value for the {@link #readShardPadding} parameter, if none is provided on the command line |
| 67 | + */ |
| 68 | + protected abstract int defaultReadShardPadding(); |
| 69 | + |
| 70 | + /** |
| 71 | + * @return Default value for the {@link #minAssemblyRegionSize} parameter, if none is provided on the command line |
| 72 | + */ |
| 73 | + protected abstract int defaultMinAssemblyRegionSize(); |
| 74 | + |
| 75 | + /** |
| 76 | + * @return Default value for the {@link #maxAssemblyRegionSize} parameter, if none is provided on the command line |
| 77 | + */ |
| 78 | + protected abstract int defaultMaxAssemblyRegionSize(); |
| 79 | + |
| 80 | + /** |
| 81 | + * @return Default value for the {@link #assemblyRegionPadding} parameter, if none is provided on the command line |
| 82 | + */ |
| 83 | + protected abstract int defaultAssemblyRegionPadding(); |
| 84 | + |
| 85 | + /** |
| 86 | + * @return Default value for the {@link #maxReadsPerAlignmentStart} parameter, if none is provided on the command line |
| 87 | + */ |
| 88 | + protected abstract int defaultMaxReadsPerAlignmentStart(); |
| 89 | + |
| 90 | + /** |
| 91 | + * @return Default value for the {@link #activeProbThreshold} parameter, if none is provided on the command line |
| 92 | + */ |
| 93 | + protected abstract double defaultActiveProbThreshold(); |
| 94 | + |
| 95 | + /** |
| 96 | + * @return Default value for the {@link #maxProbPropagationDistance} parameter, if none is provided on the command line |
| 97 | + */ |
| 98 | + protected abstract int defaultMaxProbPropagationDistance(); |
| 99 | + |
| 100 | + @Argument(doc = "whether to use the shuffle implementation or not", shortName = "shuffle", fullName = "shuffle", optional = true) |
| 101 | + public boolean shuffle = false; |
| 102 | + |
| 103 | + @Override |
| 104 | + public final boolean requiresReads() { return true; } |
| 105 | + |
| 106 | + @Override |
| 107 | + public final boolean requiresReference() { return true; } |
| 108 | + |
| 109 | + public List<ReadFilter> getDefaultReadFilters() { |
| 110 | + final List<ReadFilter> defaultFilters = new ArrayList<>(2); |
| 111 | + defaultFilters.add(new WellformedReadFilter()); |
| 112 | + defaultFilters.add(new ReadFilterLibrary.MappedReadFilter()); |
| 113 | + return defaultFilters; |
| 114 | + } |
| 115 | + |
| 116 | + /** |
| 117 | + * @return The evaluator to be used to determine whether each locus is active or not. Must be implemented by tool authors. |
| 118 | + * The results of this per-locus evaluator are used to determine the bounds of each active and inactive region. |
| 119 | + */ |
| 120 | + public abstract AssemblyRegionEvaluator assemblyRegionEvaluator(); |
| 121 | + |
| 122 | + private List<ShardBoundary> intervalShards; |
| 123 | + |
| 124 | + @Override |
| 125 | + protected List<SimpleInterval> editIntervals(List<SimpleInterval> rawIntervals) { |
| 126 | + SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary(); |
| 127 | + List<SimpleInterval> intervals = rawIntervals == null ? IntervalUtils.getAllIntervalsForReference(sequenceDictionary) : rawIntervals; |
| 128 | + intervalShards = intervals.stream() |
| 129 | + .flatMap(interval -> Shard.divideIntervalIntoShards(interval, readShardSize, readShardPadding, sequenceDictionary).stream()) |
| 130 | + .collect(Collectors.toList()); |
| 131 | + List<SimpleInterval> paddedIntervalsForReads = |
| 132 | + intervals.stream().map(interval -> interval.expandWithinContig(readShardPadding, sequenceDictionary)).collect(Collectors.toList()); |
| 133 | + return paddedIntervalsForReads; |
| 134 | + } |
| 135 | + |
| 136 | + /** |
| 137 | + * Loads assembly regions and the corresponding reference and features into a {@link JavaRDD} for the intervals specified. |
| 138 | + * |
| 139 | + * If no intervals were specified, returns all the assembly regions. |
| 140 | + * |
| 141 | + * @return all assembly regions as a {@link JavaRDD}, bounded by intervals if specified. |
| 142 | + */ |
| 143 | + public JavaRDD<Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>> getAssemblyRegions(JavaSparkContext ctx) { |
| 144 | + SAMSequenceDictionary sequenceDictionary = getBestAvailableSequenceDictionary(); |
| 145 | + JavaRDD<Shard<GATKRead>> shardedReads = SparkSharder.shard(ctx, getReads(), GATKRead.class, sequenceDictionary, intervalShards, readShardSize, shuffle); |
| 146 | + Broadcast<ReferenceMultiSource> bReferenceSource = hasReference() ? ctx.broadcast(getReference()) : null; |
| 147 | + Broadcast<FeatureManager> bFeatureManager = features == null ? null : ctx.broadcast(features); |
| 148 | + return shardedReads.flatMap(getAssemblyRegionsFunction(bReferenceSource, bFeatureManager, sequenceDictionary, getHeaderForReads(), |
| 149 | + assemblyRegionEvaluator(), minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, maxProbPropagationDistance)); |
| 150 | + } |
| 151 | + |
| 152 | + private static FlatMapFunction<Shard<GATKRead>, Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>> getAssemblyRegionsFunction( |
| 153 | + final Broadcast<ReferenceMultiSource> bReferenceSource, |
| 154 | + final Broadcast<FeatureManager> bFeatureManager, |
| 155 | + final SAMSequenceDictionary sequenceDictionary, |
| 156 | + final SAMFileHeader header, |
| 157 | + final AssemblyRegionEvaluator evaluator, |
| 158 | + final int minAssemblyRegionSize, |
| 159 | + final int maxAssemblyRegionSize, |
| 160 | + final int assemblyRegionPadding, |
| 161 | + final double activeProbThreshold, |
| 162 | + final int maxProbPropagationDistance) { |
| 163 | + return (FlatMapFunction<Shard<GATKRead>, Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>>) shardedRead -> { |
| 164 | + SimpleInterval paddedInterval = shardedRead.getPaddedInterval(); |
| 165 | + SimpleInterval assemblyRegionPaddedInterval = paddedInterval.expandWithinContig(assemblyRegionPadding, sequenceDictionary); |
| 166 | + |
| 167 | + ReferenceDataSource reference = bReferenceSource == null ? null : |
| 168 | + new ReferenceMemorySource(bReferenceSource.getValue().getReferenceBases(null, assemblyRegionPaddedInterval), sequenceDictionary); |
| 169 | + FeatureManager features = bFeatureManager == null ? null : bFeatureManager.getValue(); |
| 170 | + ReferenceContext referenceContext = new ReferenceContext(reference, paddedInterval); |
| 171 | + FeatureContext featureContext = new FeatureContext(features, paddedInterval); |
| 172 | + |
| 173 | + final Iterable<AssemblyRegion> assemblyRegions = AssemblyRegion.createFromReadShard(shardedRead, |
| 174 | + header, referenceContext, featureContext, evaluator, |
| 175 | + minAssemblyRegionSize, maxAssemblyRegionSize, assemblyRegionPadding, activeProbThreshold, |
| 176 | + maxProbPropagationDistance); |
| 177 | + return Iterators.transform(assemblyRegions.iterator(), new Function<AssemblyRegion, Tuple3<AssemblyRegion, ReferenceContext, FeatureContext>>() { |
| 178 | + @Nullable |
| 179 | + @Override |
| 180 | + public Tuple3<AssemblyRegion, ReferenceContext, FeatureContext> apply(@Nullable AssemblyRegion assemblyRegion) { |
| 181 | + return new Tuple3<>(assemblyRegion, |
| 182 | + new ReferenceContext(reference, assemblyRegion.getExtendedSpan()), |
| 183 | + new FeatureContext(features, assemblyRegion.getExtendedSpan())); |
| 184 | + } |
| 185 | + }); |
| 186 | + }; |
| 187 | + } |
| 188 | + |
| 189 | +} |
0 commit comments