Skip to content

Commit 1e8b70f

Browse files
furqaankhan authored and jiayuasu committed
[SEDONA-724] Fix RS_ZonalStats and RS_ZonalStatsAll edge case bug (#1871)
* fix: RS_ZonalStats and RS_ZonalStatsAll edge case
* fix: spotless
* change NaN to null
* fix spotless
* change all NaNs to nulls
1 parent f6ef773 commit 1e8b70f

File tree

3 files changed

+117
-33
lines changed

3 files changed

+117
-33
lines changed

common/src/main/java/org/apache/sedona/common/raster/RasterBandAccessors.java

Lines changed: 54 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ public static long getCount(GridCoverage2D raster, int band) {
9999
* @return An array with all the stats for the region
100100
* @throws FactoryException
101101
*/
102-
public static double[] getZonalStatsAll(
102+
public static Double[] getZonalStatsAll(
103103
GridCoverage2D raster,
104104
Geometry roi,
105105
int band,
@@ -114,18 +114,35 @@ public static double[] getZonalStatsAll(
114114
DescriptiveStatistics stats = (DescriptiveStatistics) objects.get(0);
115115
double[] pixelData = (double[]) objects.get(1);
116116

117+
// Shortcut for an edge case where ROI barely intersects with raster's extent, but it doesn't
118+
// intersect with the centroid of the pixel.
119+
// This happens when allTouched parameter is false.
120+
if (pixelData.length == 0) {
121+
return new Double[] {0.0, null, null, null, null, null, null, null, null};
122+
}
123+
117124
// order of stats
118125
// count, sum, mean, median, mode, stddev, variance, min, max
119-
double[] result = new double[9];
120-
result[0] = stats.getN();
121-
result[1] = stats.getSum();
122-
result[2] = stats.getMean();
123-
result[3] = stats.getPercentile(50);
126+
Double[] result = new Double[9];
127+
result[0] = (double) stats.getN();
128+
if (stats.getN() == 0) {
129+
result[1] = null;
130+
} else {
131+
result[1] = stats.getSum();
132+
}
133+
double mean = stats.getMean();
134+
result[2] = Double.isNaN(mean) ? null : mean;
135+
double median = stats.getPercentile(50);
136+
result[3] = Double.isNaN(median) ? null : median;
124137
result[4] = zonalMode(pixelData);
125-
result[5] = stats.getStandardDeviation();
126-
result[6] = stats.getVariance();
127-
result[7] = stats.getMin();
128-
result[8] = stats.getMax();
138+
double stdDev = stats.getStandardDeviation();
139+
result[5] = Double.isNaN(stdDev) ? null : stats.getStandardDeviation();
140+
double variance = stats.getVariance();
141+
result[6] = Double.isNaN(variance) ? null : variance;
142+
double min = stats.getMin();
143+
result[7] = Double.isNaN(min) ? null : min;
144+
double max = stats.getMax();
145+
result[8] = Double.isNaN(max) ? null : max;
129146

130147
return result;
131148
}
@@ -139,7 +156,7 @@ public static double[] getZonalStatsAll(
139156
* @return An array with all the stats for the region
140157
* @throws FactoryException
141158
*/
142-
public static double[] getZonalStatsAll(
159+
public static Double[] getZonalStatsAll(
143160
GridCoverage2D raster, Geometry roi, int band, boolean allTouched, boolean excludeNoData)
144161
throws FactoryException {
145162
return getZonalStatsAll(raster, roi, band, allTouched, excludeNoData, true);
@@ -153,7 +170,7 @@ public static double[] getZonalStatsAll(
153170
* @return An array with all the stats for the region, excludeNoData is set to true
154171
* @throws FactoryException
155172
*/
156-
public static double[] getZonalStatsAll(
173+
public static Double[] getZonalStatsAll(
157174
GridCoverage2D raster, Geometry roi, int band, boolean allTouched) throws FactoryException {
158175
return getZonalStatsAll(raster, roi, band, allTouched, true);
159176
}
@@ -165,7 +182,7 @@ public static double[] getZonalStatsAll(
165182
* @return An array with all the stats for the region, excludeNoData is set to true
166183
* @throws FactoryException
167184
*/
168-
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band)
185+
public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int band)
169186
throws FactoryException {
170187
return getZonalStatsAll(raster, roi, band, false);
171188
}
@@ -177,7 +194,7 @@ public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi, int
177194
* set to 1
178195
* @throws FactoryException
179196
*/
180-
public static double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi)
197+
public static Double[] getZonalStatsAll(GridCoverage2D raster, Geometry roi)
181198
throws FactoryException {
182199
return getZonalStatsAll(raster, roi, 1);
183200
}
@@ -213,26 +230,36 @@ public static Double getZonalStats(
213230

214231
switch (statType.toLowerCase()) {
215232
case "sum":
216-
return stats.getSum();
233+
if (pixelData.length == 0) {
234+
return null;
235+
} else {
236+
return stats.getSum();
237+
}
217238
case "average":
218239
case "avg":
219240
case "mean":
220-
return stats.getMean();
241+
double mean = stats.getMean();
242+
return Double.isNaN(mean) ? null : mean;
221243
case "count":
222244
return (double) stats.getN();
223245
case "max":
224-
return stats.getMax();
246+
double max = stats.getMax();
247+
return Double.isNaN(max) ? null : max;
225248
case "min":
226-
return stats.getMin();
249+
double min = stats.getMin();
250+
return Double.isNaN(min) ? null : min;
227251
case "stddev":
228252
case "sd":
229-
return stats.getStandardDeviation();
253+
double stdDev = stats.getStandardDeviation();
254+
return Double.isNaN(stdDev) ? null : stdDev;
230255
case "median":
231-
return stats.getPercentile(50);
256+
double median = stats.getPercentile(50);
257+
return Double.isNaN(median) ? null : median;
232258
case "mode":
233259
return zonalMode(pixelData);
234260
case "variance":
235-
return stats.getVariance();
261+
double variance = stats.getVariance();
262+
return Double.isNaN(variance) ? null : variance;
236263
default:
237264
throw new IllegalArgumentException(
238265
"Please select from the accepted options. Some of the valid options are sum, mean, stddev, etc.");
@@ -310,8 +337,13 @@ public static Double getZonalStats(GridCoverage2D raster, Geometry roi, String s
310337
* @return Mode of the pixel values. If there is multiple with same occurrence, then the largest
311338
* value will be returned.
312339
*/
313-
private static double zonalMode(double[] pixelData) {
340+
private static Double zonalMode(double[] pixelData) {
314341
double[] modes = StatUtils.mode(pixelData);
342+
// Return null when the ROI and the raster's extent overlap, but there is no pixel data.
343+
// This behavior only happens when allTouched parameter is false.
344+
if (modes.length == 0) {
345+
return null;
346+
}
315347
return modes[modes.length - 1];
316348
}
317349

common/src/test/java/org/apache/sedona/common/raster/RasterBandAccessorsTest.java

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import org.locationtech.jts.geom.Geometry;
2929
import org.locationtech.jts.io.ParseException;
3030
import org.opengis.referencing.FactoryException;
31-
import org.opengis.referencing.operation.TransformException;
3231

3332
public class RasterBandAccessorsTest extends RasterTestBase {
3433

@@ -84,6 +83,31 @@ public void testBandNoDataValueIllegalBand() throws FactoryException, IOExceptio
8483
assertEquals("Provided band index 2 is not present in the raster", exception.getMessage());
8584
}
8685

86+
@Test
87+
public void testZonalStatsIntersectingNoPixelData() throws FactoryException, ParseException {
88+
double[][] pixelsValues =
89+
new double[][] {
90+
new double[] {
91+
3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 28, 45, 24, 99, 22, 18,
92+
98, 10
93+
}
94+
};
95+
GridCoverage2D raster =
96+
RasterConstructors.makeNonEmptyRaster(1, "", 5, 5, 1, -1, 1, -1, 0, 0, 0, pixelsValues);
97+
Geometry extent =
98+
Constructors.geomFromWKT(
99+
"POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))",
100+
0);
101+
102+
Double actualZonalStats = RasterBandAccessors.getZonalStats(raster, extent, "mode");
103+
assertNull(actualZonalStats);
104+
105+
String actualZonalStatsAll =
106+
Arrays.toString(RasterBandAccessors.getZonalStatsAll(raster, extent));
107+
String expectedZonalStatsAll = "[0.0, null, null, null, null, null, null, null, null]";
108+
assertEquals(expectedZonalStatsAll, actualZonalStatsAll);
109+
}
110+
87111
@Test
88112
public void testZonalStats() throws FactoryException, ParseException, IOException {
89113
GridCoverage2D raster =
@@ -182,15 +206,17 @@ public void testZonalStatsWithNoData() throws IOException, FactoryException, Par
182206
}
183207

184208
@Test
185-
public void testZonalStatsAll()
186-
throws IOException, FactoryException, ParseException, TransformException {
209+
public void testZonalStatsAll() throws IOException, FactoryException, ParseException {
187210
GridCoverage2D raster =
188211
rasterFromGeoTiff(resourceFolder + "raster_geotiff_color/FAA_UTM18N_NAD83.tif");
189212
String polygon =
190213
"POLYGON ((-8673439.6642 4572993.5327, -8673155.5737 4563873.2099, -8701890.3259 4562931.7093, -8682522.8735 4572703.8908, -8673439.6642 4572993.5327))";
191214
Geometry geom = Constructors.geomFromWKT(polygon, 3857);
192215

193-
double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false);
216+
double[] actual =
217+
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false))
218+
.mapToDouble(Double::doubleValue)
219+
.toArray();
194220
double[] expected =
195221
new double[] {
196222
185953.0,
@@ -209,16 +235,19 @@ public void testZonalStatsAll()
209235
Constructors.geomFromWKT(
210236
"POLYGON ((-77.96672569800863073 37.91971182746296876, -77.9688630154902711 37.89620133516485367, -77.93936803424354309 37.90517806858776595, -77.96672569800863073 37.91971182746296876))",
211237
0);
212-
actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false);
238+
actual =
239+
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, false, false))
240+
.mapToDouble(Double::doubleValue)
241+
.toArray();
213242
assertNotNull(actual);
214243

215244
Geometry nonIntersectingGeom =
216245
Constructors.geomFromWKT(
217246
"POLYGON ((-78.22106647832458748 37.76411511479908967, -78.20183062098976734 37.72863564460374874, -78.18088490966962922 37.76753482276972562, -78.22106647832458748 37.76411511479908967))",
218247
0);
219-
actual =
248+
Double[] actualNull =
220249
RasterBandAccessors.getZonalStatsAll(raster, nonIntersectingGeom, 1, false, false, true);
221-
assertNull(actual);
250+
assertNull(actualNull);
222251
assertThrows(
223252
IllegalArgumentException.class,
224253
() ->
@@ -227,15 +256,17 @@ public void testZonalStatsAll()
227256
}
228257

229258
@Test
230-
public void testZonalStatsAllWithNoData()
231-
throws IOException, FactoryException, ParseException, TransformException {
259+
public void testZonalStatsAllWithNoData() throws IOException, FactoryException, ParseException {
232260
GridCoverage2D raster =
233261
rasterFromGeoTiff(resourceFolder + "raster/raster_with_no_data/test5.tiff");
234262
String polygon =
235263
"POLYGON((-167.750000 87.750000, -155.250000 87.750000, -155.250000 40.250000, -180.250000 40.250000, -167.750000 87.750000))";
236264
Geometry geom = Constructors.geomFromWKT(polygon, RasterAccessors.srid(raster));
237265

238-
double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true);
266+
double[] actual =
267+
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true))
268+
.mapToDouble(Double::doubleValue)
269+
.toArray();
239270
double[] expected =
240271
new double[] {
241272
14249.0,
@@ -265,7 +296,10 @@ public void testZonalStatsAllWithEmptyRaster() throws FactoryException, ParseExc
265296
// Testing implicit CRS transformation
266297
Geometry geom = Constructors.geomFromWKT("POLYGON((2 -2, 2 -6, 6 -6, 6 -2, 2 -2))", 0);
267298

268-
double[] actual = RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true);
299+
double[] actual =
300+
Arrays.stream(RasterBandAccessors.getZonalStatsAll(raster, geom, 1, false, true))
301+
.mapToDouble(Double::doubleValue)
302+
.toArray();
269303
double[] expected = new double[] {13.0, 114.0, 8.7692, 9.0, 11.0, 4.7285, 22.3589, 1.0, 16.0};
270304
assertArrayEquals(expected, actual, FP_TOLERANCE);
271305
}

spark/common/src/test/scala/org/apache/sedona/sql/rasteralgebraTest.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1623,6 +1623,24 @@ class rasteralgebraTest extends TestBaseScala with BeforeAndAfter with GivenWhen
16231623
assertTrue(expectedSummary4.equals(actualSummary4))
16241624
}
16251625

1626+
it("Passed RS_ZonalStats edge case") {
1627+
val df = sparkSession.sql("""
1628+
|with data as (
1629+
| SELECT array(3, 7, 5, 40, 61, 70, 60, 80, 27, 55, 35, 44, 21, 36, 53, 54, 86, 28, 45, 24, 99, 22, 18, 98, 10) as pixels,
1630+
| ST_GeomFromWKT('POLYGON ((5.822754 -6.620957, 6.965332 -6.620957, 6.965332 -5.834616, 5.822754 -5.834616, 5.822754 -6.620957))', 4326) as geom
1631+
|)
1632+
|
1633+
|SELECT RS_SetSRID(RS_AddBandFromArray(RS_MakeEmptyRaster(1, "D", 5, 5, 1, -1, 1), pixels, 1), 4326) as raster, geom FROM data
1634+
|""".stripMargin)
1635+
1636+
val actual = df.selectExpr("RS_ZonalStats(raster, geom, 1, 'mode')").first().get(0)
1637+
assertNull(actual)
1638+
1639+
val statsDf = df.selectExpr("RS_ZonalStatsAll(raster, geom) as stats")
1640+
val actualBoolean = statsDf.selectExpr("isNull(stats.mode)").first().getAs[Boolean](0)
1641+
assertTrue(actualBoolean)
1642+
}
1643+
16261644
it("Passed RS_ZonalStats") {
16271645
var df = sparkSession.read
16281646
.format("binaryFile")

0 commit comments

Comments
 (0)