Skip to content

Commit 348ff17

Browse files
committed
Updated the normalize method and added tests for it.
1 parent d6cf88e commit 348ff17

File tree

2 files changed

+88
-6
lines changed

2 files changed

+88
-6
lines changed

geopyspark/geotrellis/layer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,20 +1198,33 @@ def get_min_max(self):
11981198
min_max = self.srdd.getMinMax()
11991199
return (min_max._1(), min_max._2())
12001200

1201-
def normalize(self, old_min, old_max, new_min, new_max):
1201+
def normalize(self, new_min, new_max, old_min=None, old_max=None):
12021202
"""Finds the min value that is contained within the given geometry.
12031203
1204+
Note:
1205+
If ``old_max - old_min <= 0`` or ``new_max - new_min <= 0``, then the normalization
1206+
will fail.
1207+
12041208
Args:
1205-
old_min (float): Old minimum.
1206-
old_max (float): Old maximum.
1207-
new_min (float): New minimum to normalize to.
1208-
new_max (float): New maximum to normalize to.
1209+
old_min (int or float, optional): Old minimum. If not given, then the minimum value
1210+
of this layer will be used.
1211+
old_max (int or float, optional): Old maximum. If not given, then the minimum value
1212+
of this layer will be used.
1213+
new_min (int or float): New minimum to normalize to.
1214+
new_max (int or float): New maximum to normalize to.
12091215
12101216
Returns:
12111217
:class:`~geopyspark.geotrellis.rdd.TiledRasterLayer`
12121218
"""
12131219

1214-
srdd = self.srdd.normalize(old_min, old_max, new_min, new_max)
1220+
if not old_min and not old_max:
1221+
old_min, old_max = self.get_min_max()
1222+
elif not old_min:
1223+
old_min = self.get_min_max()[0]
1224+
elif not old_max:
1225+
old_max = self.get_min_max()[1]
1226+
1227+
srdd = self.srdd.normalize(float(old_min), float(old_max), float(new_min), float(new_max))
12151228

12161229
return TiledRasterLayer(self.layer_type, srdd)
12171230

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import os
2+
import unittest
3+
import numpy as np
4+
5+
import pytest
6+
7+
from geopyspark.geotrellis import SpatialKey, Tile
8+
from geopyspark.tests.base_test_class import BaseTestClass
9+
from geopyspark.geotrellis.layer import TiledRasterLayer
10+
from geopyspark.geotrellis.constants import LayerType
11+
12+
13+
class NormalizeTest(BaseTestClass):
14+
cells = np.array([[
15+
[1.0, 1.0, 1.0, 1.0, 1.0],
16+
[1.0, 1.0, 1.0, 1.0, 1.0],
17+
[1.0, 1.0, 1.0, 1.0, 1.0],
18+
[1.0, 1.0, 1.0, 1.0, 1.0],
19+
[1.0, 1.0, 1.0, 1.0, 0.0]]])
20+
21+
layer = [(SpatialKey(0, 0), Tile(cells + 0, 'FLOAT', -1.0)),
22+
(SpatialKey(1, 0), Tile(cells + 1, 'FLOAT', -1.0,)),
23+
(SpatialKey(0, 1), Tile(cells + 2, 'FLOAT', -1.0,)),
24+
(SpatialKey(1, 1), Tile(cells + 3, 'FLOAT', -1.0,))]
25+
26+
rdd = BaseTestClass.pysc.parallelize(layer)
27+
28+
extent = {'xmin': 0.0, 'ymin': 0.0, 'xmax': 33.0, 'ymax': 33.0}
29+
layout = {'layoutCols': 2, 'layoutRows': 2, 'tileCols': 5, 'tileRows': 5}
30+
metadata = {'cellType': 'float32ud-1.0',
31+
'extent': extent,
32+
'crs': '+proj=longlat +datum=WGS84 +no_defs ',
33+
'bounds': {
34+
'minKey': {'col': 0, 'row': 0},
35+
'maxKey': {'col': 1, 'row': 1}},
36+
'layoutDefinition': {
37+
'extent': extent,
38+
'tileLayout': {'tileCols': 5, 'tileRows': 5, 'layoutCols': 2, 'layoutRows': 2}}}
39+
40+
raster_rdd = TiledRasterLayer.from_numpy_rdd(LayerType.SPATIAL, rdd, metadata)
41+
42+
@pytest.fixture(autouse=True)
43+
def tearDown(self):
44+
yield
45+
BaseTestClass.pysc._gateway.close()
46+
47+
def test_normalize_all_parameters(self):
48+
normalized = self.raster_rdd.normalize(old_min=0.0, old_max=4.0, new_min=5.0, new_max=10.0)
49+
50+
self.assertEqual(normalized.get_min_max(), (5.0, 10.0))
51+
52+
def test_normalize_no_optinal_parameters(self):
53+
normalized = self.raster_rdd.normalize(new_min=5.0, new_max=10.0)
54+
55+
self.assertEqual(normalized.get_min_max(), (5.0, 10.0))
56+
57+
def test_normalize_old_min(self):
58+
normalized = self.raster_rdd.normalize(old_min=-1, new_min=5.0, new_max=10.0)
59+
60+
self.assertEqual(normalized.get_min_max(), (6.0, 10.0))
61+
62+
def test_normalize_old_max(self):
63+
normalized = self.raster_rdd.normalize(old_max=5.0, new_min=5.0, new_max=10.0)
64+
65+
self.assertEqual(normalized.get_min_max(), (5.0, 9.0))
66+
67+
68+
if __name__ == "__main__":
69+
unittest.main()

0 commit comments

Comments
 (0)