Skip to content

Commit 775e295

Browse files
committed
allow label GeoTIFFs to be un-georeferenced to unbreak potsdam example
1 parent 5ac4e6c commit 775e295

File tree

3 files changed

+44
-3
lines changed

3 files changed

+44
-3
lines changed

rastervision_core/rastervision/core/data/label_source/semantic_segmentation_label_source.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,10 @@ def __getitem__(self, key: Any) -> Any:
101101
return self.get_label_arr(key)
102102
else:
103103
return super().__getitem__(key)
104+
105+
def __repr__(self):
106+
arg_keys = ['raster_source', 'class_config', 'bbox']
107+
arg_vals = [getattr(self, k) for k in arg_keys]
108+
arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
109+
arg_str = ', '.join(arg_strs)
110+
return f'{type(self).__name__}({arg_str})'

rastervision_core/rastervision/core/data/raster_source/rasterio_source.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,13 @@ def __getitem__(self, key: Any) -> 'np.ndarray':
192192

193193
chip = self.get_chip(window, bands=c, out_shape=out_shape)
194194
return chip
195+
196+
def __repr__(self):
197+
arg_keys = [
198+
'uris', 'channel_order', 'bbox', 'raster_transformers',
199+
'allow_streaming', 'tmp_dir'
200+
]
201+
arg_vals = [getattr(self, k) for k in arg_keys]
202+
arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
203+
arg_str = ', '.join(arg_strs)
204+
return f'{type(self).__name__}({arg_str})'

rastervision_core/rastervision/core/data/utils/misc.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,17 +103,41 @@ def match_bboxes(raster_source: 'RasterSource',
103103
label_source (LabelSource | LabelStore): Source of labels for a
104104
scene. Can be a ``LabelStore``.
105105
"""
106+
from rastervision.core.data import (RasterioCRSTransformer,
107+
SemanticSegmentationLabelSource)
106108
crs_tf_img = raster_source.crs_transformer
107109
crs_tf_label = label_source.crs_transformer
108110
bbox_img_map = crs_tf_img.pixel_to_map(raster_source.bbox)
111+
112+
# For SS, if a label file is not georeferenced but is the same size as
113+
# the corresponding raster source (as is the case in the Potsdam dataset),
114+
# implicitly assume that they are aligned.
115+
if isinstance(label_source,
116+
SemanticSegmentationLabelSource) and isinstance(
117+
crs_tf_label, RasterioCRSTransformer):
118+
if crs_tf_label.image_crs is None:
119+
if raster_source.extent != label_source.extent:
120+
raise ValueError(
121+
f'Label source ({label_source}) is not georeferenced and '
122+
f'has a different extent ({label_source.extent}) than the '
123+
f"corresponding raster source's ({raster_source}) extent "
124+
f'({raster_source.extent}).')
125+
log.warning(
126+
f'Label source ({label_source}) is not georeferenced but has '
127+
f'the same extent ({label_source.extent}) as the '
128+
f'corresponding raster source ({raster_source}). '
129+
'Will assume they are aligned.')
130+
return
131+
109132
if label_source.bbox is not None:
110133
bbox_label_map = crs_tf_label.pixel_to_map(label_source.bbox)
111134
if not bbox_img_map.intersects(bbox_label_map):
112135
rs_cls = type(raster_source).__name__
113136
ls_cls = type(label_source).__name__
114-
log.warning(f'{rs_cls} bbox ({bbox_img_map}) does '
115-
f'not intersect with {ls_cls} bbox '
116-
f'({bbox_label_map}).')
137+
raise ValueError(f'{rs_cls} bbox ({bbox_img_map}) does '
138+
f'not intersect with {ls_cls} bbox '
139+
f'({bbox_label_map}).')
140+
117141
# set LabelStore bbox to RasterSource bbox
118142
bbox_label_pixel = crs_tf_label.map_to_pixel(bbox_img_map)
119143
label_source.set_bbox(bbox_label_pixel)

0 commit comments

Comments
 (0)