Skip to content

Commit 5172832

Browse files
authored
Merge pull request #352 from ecmwf/develop
v1.0.36
2 parents 00c2610 + a40541a commit 5172832

File tree

8 files changed

+243
-26
lines changed

8 files changed

+243
-26
lines changed

.gitignore

+2-1
Original file line number | Diff line number | Diff line change
@@ -27,4 +27,5 @@ newest-polytope-venv
2727
serializedTree
2828
new_polytope_venv
2929
*.json
30-
venv_python3_11
30+
venv_python3_11
31+
tests/data

polytope_feature/datacube/backends/fdb.py

+13-8
Original file line number | Diff line number | Diff line change
@@ -227,17 +227,22 @@ def nearest_lat_lon_search(self, requests):
227227
first_ax_name = requests.children[0].axis.name
228228
second_ax_name = requests.children[0].children[0].axis.name
229229

230-
if first_ax_name not in self.nearest_search.keys() or second_ax_name not in self.nearest_search.keys():
230+
axes_in_nearest_search = [
231+
first_ax_name not in self.nearest_search.keys(),
232+
second_ax_name not in self.nearest_search.keys(),
233+
]
234+
235+
if all(not item for item in axes_in_nearest_search):
231236
raise Exception("nearest point search axes are wrong")
232237

233238
second_ax = requests.children[0].children[0].axis
239+
nearest_pts = self.nearest_search.get(first_ax_name, None)
240+
if nearest_pts is None:
241+
nearest_pts = self.nearest_search[second_ax_name]
234242

235-
nearest_pts = [
236-
[lat_val, second_ax._remap_val_to_axis_range(lon_val)]
237-
for (lat_val, lon_val) in zip(
238-
self.nearest_search[first_ax_name][0], self.nearest_search[second_ax_name][0]
239-
)
240-
]
243+
transformed_nearest_pts = []
244+
for point in nearest_pts:
245+
transformed_nearest_pts.append([point[0], second_ax._remap_val_to_axis_range(point[1])])
241246

242247
found_latlon_pts = []
243248
for lat_child in requests.children:
@@ -246,7 +251,7 @@ def nearest_lat_lon_search(self, requests):
246251

247252
# now find the nearest lat lon to the points requested
248253
nearest_latlons = []
249-
for pt in nearest_pts:
254+
for pt in transformed_nearest_pts:
250255
nearest_latlon = nearest_pt(found_latlon_pts, pt)
251256
nearest_latlons.append(nearest_latlon)
252257

polytope_feature/engine/hullslicer.py

+16-5
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from ..datacube.backends.datacube import Datacube
99
from ..datacube.datacube_axis import UnsliceableDatacubeAxis
1010
from ..datacube.tensor_index_tree import TensorIndexTree
11-
from ..shapes import ConvexPolytope
11+
from ..shapes import ConvexPolytope, Product
1212
from ..utility.combinatorics import group, tensor_product
1313
from ..utility.exceptions import UnsliceableShapeError
1414
from ..utility.geometry import lerp
@@ -76,8 +76,6 @@ def find_values_between(self, polytope, ax, node, datacube, lower, upper):
7676
upper = ax.from_float(upper + tol)
7777
flattened = node.flatten()
7878
method = polytope.method
79-
if method == "nearest":
80-
datacube.nearest_search[ax.name] = polytope.points
8179

8280
# NOTE: caching
8381
# Create a coupled_axes list inside of datacube and add to it during axis formation, then here
@@ -214,7 +212,11 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]):
214212

215213
# Convert the polytope points to float type to support triangulation and interpolation
216214
for p in polytopes:
217-
self._unique_continuous_points(p, datacube)
215+
if isinstance(p, Product):
216+
for poly in p.polytope():
217+
self._unique_continuous_points(poly, datacube)
218+
else:
219+
self._unique_continuous_points(p, datacube)
218220

219221
groups, input_axes = group(polytopes)
220222
datacube.validate(input_axes)
@@ -233,7 +235,16 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]):
233235
new_c.extend(combi)
234236
else:
235237
new_c.append(combi)
236-
r["unsliced_polytopes"] = set(new_c)
238+
# NOTE TODO: here some of the polys in new_c can be a Product shape instead of a ConvexPolytope
239+
# -> need to go through the polytopes in new_c and replace the Products with their sub-ConvexPolytopes
240+
final_polys = []
241+
for poly in new_c:
242+
if isinstance(poly, Product):
243+
final_polys.extend(poly.polytope())
244+
else:
245+
final_polys.append(poly)
246+
# r["unsliced_polytopes"] = set(new_c)
247+
r["unsliced_polytopes"] = set(final_polys)
237248
current_nodes = [r]
238249
for ax in datacube.axes.values():
239250
next_nodes = []

polytope_feature/polytope.py

+7
Original file line number | Diff line number | Diff line change
@@ -64,6 +64,13 @@ def retrieve(self, request: Request, method="standard"):
6464
"""Higher-level API which takes a request and uses it to slice the datacube"""
6565
logging.info("Starting request for %s ", self.context)
6666
self.datacube.check_branching_axes(request)
67+
for polytope in request.polytopes():
68+
method = polytope.method
69+
if method == "nearest":
70+
if self.datacube.nearest_search.get(polytope.axes()[0], None) is None:
71+
self.datacube.nearest_search[polytope.axes()[0]] = polytope.values
72+
else:
73+
self.datacube.nearest_search[polytope.axes()[0]].append(polytope.values[0])
6774
request_tree = self.engine.extract(self.datacube, request.polytopes())
6875
logging.info("Created request tree for %s ", self.context)
6976
self.datacube.get(request_tree, self.context)

polytope_feature/shapes.py

+47-8
Original file line number | Diff line number | Diff line change
@@ -61,6 +61,42 @@ def polytope(self):
6161
return [self]
6262

6363

64+
class Product(Shape):
    """Shape that 'multiplies' several shapes together to obtain a higher-dimensional shape.

    Every factor must be defined on its own axes: an axis may not appear in more than
    one factor. The product's axes are the factors' axes concatenated in the order the
    factors were passed, and its polytopes are the factors' polytopes concatenated in
    that same order.
    """

    def __init__(self, *polytopes, method, value):
        """Build the product of ``polytopes``.

        polytopes: factor shapes; each must provide ``axes()``, ``polytope()`` and an
            ``is_orthogonal`` attribute.
        method: stored unchanged as ``self.method``.
        value: stored unchanged as ``self.values``.

        Raises AssertionError if two factors share an axis.
        """
        # Collect the axes of all factors, keeping the order in which they were supplied.
        all_axes = []
        for poly in polytopes:
            all_axes.extend(poly.axes())

        # Check there weren't any duplicates in the polytopes' axes
        assert len(set(all_axes)) == len(all_axes), "Product factors must not share axes"

        # Keep insertion order rather than round-tripping through a set: iteration order
        # of a set of strings can change between interpreter runs (hash randomisation),
        # which would make axes()[0] refer to a different axis from one run to the next.
        self._axes = all_axes

        # Flatten every factor into its constituent polytopes.
        self._polytopes = []
        for poly in polytopes:
            self._polytopes.extend(poly.polytope())

        self.is_in_union = False
        self.method = method
        self.values = value

        # The product is orthogonal only when every factor is orthogonal.
        self.is_orthogonal = all(poly.is_orthogonal for poly in polytopes)

    def add_to_union(self):
        """Flag this shape as being part of a union."""
        self.is_in_union = True

    def axes(self):
        """Return the axes of the product, in the order the factors supplied them."""
        return self._axes

    def polytope(self):
        """Return the flattened list of the factors' polytopes."""
        return self._polytopes
98+
99+
64100
# This is the only shape which can slice on axes without a discretizer or interpolator
65101
class Select(Shape):
66102
"""Matches several discrete values"""
@@ -89,19 +125,22 @@ def __init__(self, axes, values, method=None):
89125
self._axes = axes
90126
self.values = values
91127
self.method = method
92-
self.polytopes = []
93-
if method == "nearest":
94-
assert len(self.values) == 1
95-
for i in range(len(axes)):
96-
polytope_points = [v[i] for v in self.values]
97-
self.polytopes.extend(
98-
[ConvexPolytope([axes[i]], [[point]], self.method, is_orthogonal=True) for point in polytope_points]
99-
)
128+
assert len(values) == 1
100129

101130
def axes(self):
102131
return self._axes
103132

104133
def polytope(self):
    """Return one Product per entry of ``self.values``.

    Each Product multiplies together one single-point, orthogonal ConvexPolytope
    per axis in ``self._axes``. The resulting list is also stored on
    ``self.polytopes`` before being returned.
    """
    # TODO (carried over from the original): revisit whether this should return a
    # Product of the two 1D selects instead.
    self.polytopes = [
        Product(
            *(
                ConvexPolytope([axis_name], [[point[idx]]], self.method, is_orthogonal=True)
                for idx, axis_name in enumerate(self._axes)
            ),
            method=self.method,
            value=[point],
        )
        for point in self.values
    ]
    return self.polytopes
106145

107146
def __repr__(self):

polytope_feature/version.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
__version__ = "1.0.35"
1+
__version__ = "1.0.36"

tests/test_point_shape.py

+9-3
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44

55
from polytope_feature.engine.hullslicer import HullSlicer
66
from polytope_feature.polytope import Polytope, Request
7-
from polytope_feature.shapes import Point, Select
7+
from polytope_feature.shapes import Point, Select, Union
88

99

1010
class TestSlicing3DXarrayDatacube:
@@ -30,20 +30,26 @@ def test_point(self):
3030
assert result.leaves[0].axis.name == "level"
3131

3232
def test_multiple_points(self):
33-
request = Request(Point(["step", "level"], [[3, 10], [3, 12]]), Select("date", ["2000-01-01"]))
33+
# request = Request(Point(["step", "level"], [[3, 10], [3, 12]]), Select("date", ["2000-01-01"]))
34+
request = Request(
35+
Union(["step", "level"], Point(["step", "level"], [[3, 10]]), Point(["step", "level"], [[3, 12]])),
36+
Select("date", ["2000-01-01"]),
37+
)
3438
result = self.API.retrieve(request)
3539
result.pprint()
36-
assert len(result.leaves) == 1
40+
assert len(result.leaves) == 2
3741
assert result.leaves[0].axis.name == "level"
3842

3943
def test_point_surrounding_step(self):
4044
request = Request(Point(["step", "level"], [[2, 10]], method="surrounding"), Select("date", ["2000-01-01"]))
4145
result = self.API.retrieve(request)
46+
result.pprint()
4247
assert len(result.leaves) == 1
4348
assert np.shape(result.leaves[0].result[1]) == (1, 2, 3)
4449

4550
def test_point_surrounding_exact_step(self):
4651
request = Request(Point(["step", "level"], [[3, 10]], method="surrounding"), Select("date", ["2000-01-01"]))
4752
result = self.API.retrieve(request)
53+
result.pprint()
4854
assert len(result.leaves) == 1
4955
assert np.shape(result.leaves[0].result[1]) == (1, 3, 3)

tests/test_point_union.py

+148
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,148 @@
1+
import pandas as pd
2+
import pytest
3+
4+
from polytope_feature.engine.hullslicer import HullSlicer
5+
from polytope_feature.polytope import Polytope, Request
6+
from polytope_feature.shapes import Point, Select, Span, Union
7+
8+
9+
class TestSlicingFDBDatacube:
    """Slicing tests for unions of latitude/longitude points on a GribJump datacube.

    The tests are marked ``fdb`` and import ``pygribjump`` lazily inside each test.
    """

    def setup_method(self, method):
        # Datacube options shared by every test: the per-axis transformations, the fixed
        # part of the request path, and the axes listed in the compression config.
        axis_config = [
            {"axis_name": "number", "transformations": [{"name": "type_change", "type": "int"}]},
            {"axis_name": "step", "transformations": [{"name": "type_change", "type": "int"}]},
            {
                "axis_name": "date",
                "transformations": [{"name": "merge", "other_axis": "time", "linkers": ["T", "00"]}],
            },
            {
                "axis_name": "values",
                "transformations": [
                    {"name": "mapper", "type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}
                ],
            },
            {"axis_name": "latitude", "transformations": [{"name": "reverse", "is_reverse": True}]},
            {"axis_name": "longitude", "transformations": [{"name": "cyclic", "range": [0, 360]}]},
        ]
        pre_path = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper"}
        compressed_axes = [
            "longitude",
            "latitude",
            "levtype",
            "step",
            "date",
            "domain",
            "expver",
            "param",
            "class",
            "stream",
            "type",
        ]
        self.options = {
            "axis_config": axis_config,
            "pre_path": pre_path,
            "compressed_axes_config": compressed_axes,
        }

    @staticmethod
    def _common_selects():
        # The non-spatial part of the request, identical in every test of this class.
        return [
            Select("step", [0]),
            Select("levtype", ["sfc"]),
            Span("date", pd.Timestamp("20230625T120000"), pd.Timestamp("20230626T120000")),
            Select("domain", ["g"]),
            Select("expver", ["0001"]),
            Select("param", ["167"]),
            Select("class", ["od"]),
            Select("stream", ["oper"]),
            Select("type", ["an"]),
        ]

    @staticmethod
    def _latlon_points(coords, method):
        # One Point per (lat, lon) pair; each Point gets its own fresh axes list.
        return [Point(["latitude", "longitude"], [list(pair)], method=method) for pair in coords]

    # Testing different shapes
    @pytest.mark.fdb
    def test_fdb_datacube(self):
        import pygribjump as gj

        nearest_coords = [
            (20, 20),
            (0, 0),
            (0, 20),
            (25, 30),
            (-30, 90),
            (-60, -30),
            (-15, -45),
            (20, 0),
        ]
        request = Request(
            *self._common_selects(),
            Union(["latitude", "longitude"], *self._latlon_points(nearest_coords, "nearest")),
        )

        self.fdbdatacube = gj.GribJump()
        self.slicer = HullSlicer()
        self.API = Polytope(
            datacube=self.fdbdatacube,
            engine=self.slicer,
            options=self.options,
        )
        result = self.API.retrieve(request)
        result.pprint()
        # One leaf is expected per requested point.
        assert len(result.leaves) == 8

    @pytest.mark.fdb
    def test_fdb_datacube_surrounding(self):
        import pygribjump as gj

        surrounding_coords = [(25, 30), (-15, -45)]
        request = Request(
            *self._common_selects(),
            Union(["latitude", "longitude"], *self._latlon_points(surrounding_coords, "surrounding")),
        )

        self.fdbdatacube = gj.GribJump()
        self.slicer = HullSlicer()
        self.API = Polytope(
            datacube=self.fdbdatacube,
            engine=self.slicer,
            options=self.options,
        )
        result = self.API.retrieve(request)
        result.pprint()
        assert len(result.leaves) == 4
        # Total number of result entries summed over all leaves.
        tot_leaves = sum(len(leaf.result) for leaf in result.leaves)
        assert tot_leaves == 9
119+
120+
# @pytest.mark.fdb
121+
# def test_fdb_datacube_mix_methods(self):
122+
# import pygribjump as gj
123+
124+
# request = Request(
125+
# Select("step", [0]),
126+
# Select("levtype", ["sfc"]),
127+
# Span("date", pd.Timestamp("20230625T120000"), pd.Timestamp("20230626T120000")),
128+
# Select("domain", ["g"]),
129+
# Select("expver", ["0001"]),
130+
# Select("param", ["167"]),
131+
# Select("class", ["od"]),
132+
# Select("stream", ["oper"]),
133+
# Select("type", ["an"]),
134+
# Union(["latitude", "longitude"],
135+
# Point(["latitude", "longitude"], [[25, 30]], method="nearest"),
136+
# Point(["latitude", "longitude"], [[-15, -45]], method="surrounding"))
137+
# )
138+
139+
# self.fdbdatacube = gj.GribJump()
140+
# self.slicer = HullSlicer()
141+
# self.API = Polytope(
142+
# datacube=self.fdbdatacube,
143+
# engine=self.slicer,
144+
# options=self.options,
145+
# )
146+
# result = self.API.retrieve(request)
147+
# result.pprint()
148+
# assert len(result.leaves) == 6

0 commit comments

Comments
 (0)