Skip to content

[SEDONA-724] Fix RS_ZonalStats and RS_ZonalStatsAll edge case bug #1871

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ public static long getCount(GridCoverage2D raster, int band) {
* @return An array with all the stats for the region
* @throws FactoryException
*/
public static double[] getZonalStatsAll(
public static Double[] getZonalStatsAll(
GridCoverage2D raster,
Geometry roi,
int band,
Expand All @@ -114,18 +114,35 @@ public static double[] getZonalStatsAll(
DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
double[] pixelData = (double[]) objects.get(1);

// Shortcut for an edge case where ROI barely intersects with raster's extent, but it doesn't
// intersect with the centroid of the pixel.
// This happens when allTouched parameter is false.
if (pixelData.length == 0) {
return new Double[] {0.0, null, null, null, null, null, null, null, null};
}

// order of stats
// count, sum, mean, median, mode, stddev, variance, min, max
double[] result = new double[9];
result[0] = stats.getN();
result[1] = stats.getSum();
result[2] = stats.getMean();
result[3] = stats.getPercentile(50);
Double[] result = new Double[9];
result[0] = (double) stats.getN();
if (stats.getN() == 0) {
result[1] = null;
} else {
result[1] = stats.getSum();
}
double mean = stats.getMean();
result[2] = Double.isNaN(mean) ? null : mean;
double median = stats.getPercentile(50);
result[3] = Double.isNaN(median) ? null : median;
result[4] = zonalMode(pixelData);
result[5] = stats.getStandardDeviation();
result[6] = stats.getVariance();
result[7] = stats.getMin();
result[8] = stats.getMax();
double stdDev = stats.getStandardDeviation();
result[5] = Double.isNaN(stdDev) ? null : stats.getStandardDeviation();
double variance = stats.getVariance();
result[6] = Double.isNaN(variance) ? null : variance;
double min = stats.getMin();
result[7] = Double.isNaN(min) ? null : min;
double max = stats.getMax();
result[8] = Double.isNaN(max) ? null : max;

return result;
}
Expand All @@ -139,7 +156,7 @@ public static double[] getZonalStatsAll(
* @return An array with all the stats for the region
* @throws FactoryException
*/
public static double[] getZonalStatsAll(
public static Double[] getZonalStatsAll(
GridCoverage2D raster, Geometry roi, int band, boolean allTouched, boolean excludeNoData)
throws FactoryException {
return getZonalStatsAll(raster, roi, band, allTouched, excludeNoData, true);
Expand All @@ -153,7 +170,7 @@ public static double[] getZonalStatsAll(
* @return An array with all the stats for the region, excludeNoData is set to true
* @throws FactoryException
*/
public static double[] getZonalStatsAll(
public static Double[] getZonalStatsAll(
GridCoverage2D raster, Geometry roi, int band, boolean allTouched) throws FactoryException {
return getZonalStatsAll(raster, roi, band, allTouched, true);
}
Expand All @@ -165,7 +182,7 @@ public static double[] getZonalStatsAll(
* @return An array with all the stats for the region, excludeNoData is set to true
* @throws FactoryException
*/
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band)
public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band)
throws FactoryException {
return getZonalStatsAll(raster, roi, band, false);
}
Expand All @@ -177,7 +194,7 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int
* set to 1
* @throws FactoryException
*/
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi)
public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi)
throws FactoryException {
return getZonalStatsAll(raster, roi, 1);
}
Expand Down Expand Up @@ -213,26 +230,36 @@ public static Double getZonalStats(

switch (statType.toLowerCase()) {
case "sum":
return stats.getSum();
if (pixelData.length == 0) {
return null;
} else {
return stats.getSum();
}
case "average":
case "avg":
case "mean":
return stats.getMean();
double mean = stats.getMean();
return Double.isNaN(mean) ? null : mean;
case "count":
return (double) stats.getN();
case "max":
return stats.getMax();
double max = stats.getMax();
return Double.isNaN(max) ? null : max;
case "min":
return stats.getMin();
double min = stats.getMin();
return Double.isNaN(min) ? null : min;
case "stddev":
case "sd":
return stats.getStandardDeviation();
double stdDev = stats.getStandardDeviation();
return Double.isNaN(stdDev) ? null : stdDev;
case "median":
return stats.getPercentile(50);
double median = stats.getPercentile(50);
return Double.isNaN(median) ? null : median;
case "mode":
return zonalMode(pixelData);
case "variance":
return stats.getVariance();
double variance = stats.getVariance();
return Double.isNaN(variance) ? null : variance;
default:
throw new IllegalArgumentException(
"Please select from the accepted options. Some of the valid options are sum, mean, stddev, etc.");
Expand Down Expand Up @@ -310,8 +337,13 @@ public static Double getZonalStats(GridCoverage2D raster, Geometry roi, String s
* @return Mode of the pixel values. If there is multiple with same occurrence, then the largest
* value will be returned.
*/
private static double zonalMode(double[] pixelData) {
private static Double zonalMode(double[] pixelData) {
double[] modes = StatUtils.mode(pixelData);
// Return NaN when ROI and raster's extent overlap, but there's no pixel data.
// This behavior only happens when allTouched parameter is false.
if (modes.length == 0) {
return null;
}
return modes[modes.length - 1];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.io.ParseException;
import org.opengis.referencing.FactoryException;
import org.opengis.referencing.operation.TransformException;

public class RasterBandAccessorsTest extends RasterTestBase {

Expand Down Expand Up @@ -84,6 +83,31 @@ public void testBandNoDataValueIllegalBand() throws FactoryException, IOExceptio
assertEquals("Provided band index 2 is not present in the raster", exception.getMessage());
}

@Test
public void testZonalStatsIntersectingNoPixelData() throws FactoryException, ParseException {
double[][] pixelsValues =
new double[][] {
new double[] {
3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 28, 45, 24, 99, 22, 18,
98, 10
}
};
GridCoverage2D raster =
RasterConstructors.makeNonEmptyRaster(1, "", 5, 5, 1, -1, 1, -1, 0, 0, 0, pixelsValues);
Geometry extent =
Constructors.geomFromWKT(
"POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))",
0);

Double actualZonalStats = RasterBandAccessors.getZonalStats(raster, extent, "mode");
assertNull(actualZonalStats);

String actualZonalStatsAll =
Arrays.toString(RasterBandAccessors.getZonalStatsAll(raster, extent));
String expectedZonalStatsAll = "[0.0, null, null, null, null, null, null, null, null]";
assertEquals(expectedZonalStatsAll, actualZonalStatsAll);
}

@Test
public void testZonalStats() throws FactoryException, ParseException, IOException {
GridCoverage2D raster =
Expand Down Expand Up @@ -182,15 +206,17 @@ public void testZonalStatsWithNoData() throws IOException, FactoryException, Par
}

@Test
public void testZonalStatsAll()
throws IOException, FactoryException, ParseException, TransformException {
public void testZonalStatsAll() throws IOException, FactoryException, ParseException {
GridCoverage2D raster =
rasterFromGeoTiff(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif");
String polygon =
"POLYGON ((-8673439.6642 4572993.5327, -8673155.5737 4563873.2099, -8701890.3259 4562931.7093, -8682522.8735 4572703.8908, -8673439.6642 4572993.5327))";
Geometry geom = Constructors.geomFromWKT(polygon, 3857);

double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false);
double[] actual =
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false))
.mapToDouble(Double::doubleValue)
.toArray();
double[] expected =
new double[] {
185953.0,
Expand All @@ -209,16 +235,19 @@ public void testZonalStatsAll()
Constructors.geomFromWKT(
"POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))",
0);
actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false);
actual =
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false))
.mapToDouble(Double::doubleValue)
.toArray();
assertNotNull(actual);

Geometry nonIntersectingGeom =
Constructors.geomFromWKT(
"POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))",
0);
actual =
Double[] actualNull =
RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, false, true);
assertNull(actual);
assertNull(actualNull);
assertThrows(
IllegalArgumentException.class,
() ->
Expand All @@ -227,15 +256,17 @@ public void testZonalStatsAll()
}

@Test
public void testZonalStatsAllWithNoData()
throws IOException, FactoryException, ParseException, TransformException {
public void testZonalStatsAllWithNoData() throws IOException, FactoryException, ParseException {
GridCoverage2D raster =
rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");
String polygon =
"POLYGON((-167.750000 87.750000, -155.250000 87.750000, -155.250000 40.250000, -180.250000 40.250000, -167.750000 87.750000))";
Geometry geom = Constructors.geomFromWKT(polygon, RasterAccessors.srid(raster));

double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true);
double[] actual =
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true))
.mapToDouble(Double::doubleValue)
.toArray();
double[] expected =
new double[] {
14249.0,
Expand Down Expand Up @@ -265,7 +296,10 @@ public void testZonalStatsAllWithEmptyRaster() throws FactoryException, ParseExc
// Testing implicit CRS transformation
Geometry geom = Constructors.geomFromWKT("POLYGON((2 -2, 2 -6, 6 -6, 6 -2, 2 -2))", 0);

double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true);
double[] actual =
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true))
.mapToDouble(Double::doubleValue)
.toArray();
double[] expected = new double[] {13.0, 114.0, 8.7692, 9.0, 11.0, 4.7285, 22.3589, 1.0, 16.0};
assertArrayEquals(expected, actual, FP_TOLERANCE);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,24 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
assertTrue(expectedSummary4.equals(actualSummary4))
}

it("Passed RS_ZonalStats edge case") {
val df = sparkSession.sql("""
|with data as (
| SELECT array(3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 28, 45, 24, 99, 22, 18, 98, 10) as pixels,
| ST_GeomFromWKT('POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))', 4326) as geom
|)
|
|SELECT RS_SetSRID(RS_AddBandFromArray(RS_MakeEmptyRaster(1, "D", 5, 5, 1, -1, 1), pixels, 1), 4326) as raster, geom FROM data
|""".stripMargin)

val actual = df.selectExpr("RS_ZonalStats(raster, geom, 1, 'mode')").first().get(0)
assertNull(actual)

val statsDf = df.selectExpr("RS_ZonalStatsAll(raster, geom) as stats")
val actualBoolean = statsDf.selectExpr("isNull(stats.mode)").first().getAs[Boolean](0)
assertTrue(actualBoolean)
}

it("Passed RS_ZonalStats") {
var df = sparkSession.read
.format("binaryFile")
Expand Down
Loading