Skip to content

Unbreak ISPRS Potsdam example #2319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,10 @@ def __getitem__(self, key: Any) -> Any:
return self.get_label_arr(key)
else:
return super().__getitem__(key)

def __repr__(self):
arg_keys = ['raster_source', 'class_config', 'bbox']
arg_vals = [getattr(self, k) for k in arg_keys]
arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
arg_str = ', '.join(arg_strs)
return f'{type(self).__name__}({arg_str})'
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,13 @@ def __getitem__(self, key: Any) -> 'np.ndarray':

chip = self.get_chip(window, bands=c, out_shape=out_shape)
return chip

def __repr__(self):
arg_keys = [
'uris', 'channel_order', 'bbox', 'raster_transformers',
'allow_streaming', 'tmp_dir'
]
arg_vals = [getattr(self, k) for k in arg_keys]
arg_strs = [f'{k}={v!r}' for k, v in zip(arg_keys, arg_vals)]
arg_str = ', '.join(arg_strs)
return f'{type(self).__name__}({arg_str})'
30 changes: 27 additions & 3 deletions rastervision_core/rastervision/core/data/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,41 @@ def match_bboxes(raster_source: 'RasterSource',
label_source (LabelSource | LabelStore): Source of labels for a
scene. Can be a ``LabelStore``.
"""
from rastervision.core.data import (RasterioCRSTransformer,
SemanticSegmentationLabelSource)
crs_tf_img = raster_source.crs_transformer
crs_tf_label = label_source.crs_transformer
bbox_img_map = crs_tf_img.pixel_to_map(raster_source.bbox)

# For SS, if a label file is not georeferenced but is the same size as
# the corresponding raster source (as is the case in the Potsdam dataset),
# implicitly assume that they are aligned.
if isinstance(label_source,
SemanticSegmentationLabelSource) and isinstance(
crs_tf_label, RasterioCRSTransformer):
if crs_tf_label.image_crs is None:
if raster_source.extent != label_source.extent:
raise ValueError(
f'Label source ({label_source}) is not georeferenced and '
f'has a different extent ({label_source.extent}) than the '
f"corresponding raster source's ({raster_source}) extent "
f'({raster_source.extent}).')
log.warning(
f'Label source ({label_source}) is not georeferenced but has '
f'the same extent ({label_source.extent}) as the '
f'corresponding raster source ({raster_source}). '
'Will assume they are aligned.')
return

if label_source.bbox is not None:
bbox_label_map = crs_tf_label.pixel_to_map(label_source.bbox)
if not bbox_img_map.intersects(bbox_label_map):
rs_cls = type(raster_source).__name__
ls_cls = type(label_source).__name__
log.warning(f'{rs_cls} bbox ({bbox_img_map}) does '
f'not intersect with {ls_cls} bbox '
f'({bbox_label_map}).')
raise ValueError(f'{rs_cls} bbox ({bbox_img_map}) does '
f'not intersect with {ls_cls} bbox '
f'({bbox_label_map}).')

# set LabelStore bbox to RasterSource bbox
bbox_label_pixel = crs_tf_label.map_to_pixel(bbox_img_map)
label_source.set_bbox(bbox_label_pixel)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_config(runner,
external_model: bool = True,
augment: bool = False,
nochip: bool = True,
allow_streaming: bool = False,
num_epochs: int = 10,
batch_sz: int = 8,
test: bool = False) -> SemanticSegmentationConfig:
Expand All @@ -69,6 +70,8 @@ def get_config(runner,
training instead of from pre-generated chips. The analyze and chip
commands should not be run, if this is set to True. Defaults to
True.
allow_streaming (bool): If True, read directly from remote files
instead of downloading them. Defaults to False.
num_epochs (int): Number of epochs to train for.
batch_sz (int): Batch size.
test (bool): If True, does the following simplifications:
Expand Down Expand Up @@ -142,7 +145,9 @@ def make_scene(id) -> SceneConfig:
label_uri = label_crop_uri

raster_source = RasterioSourceConfig(
uris=[raster_uri], channel_order=channel_order)
uris=[raster_uri],
channel_order=channel_order,
allow_streaming=allow_streaming)

# Using with_rgb_class_map because label TIFFs have classes encoded as
# RGB colors.
Expand All @@ -151,7 +156,8 @@ def make_scene(id) -> SceneConfig:
uris=[label_uri],
transformers=[
RGBClassTransformerConfig(class_config=class_config)
]))
],
allow_streaming=allow_streaming))

# URI will be injected by scene config.
# Using rgb=True because we want prediction TIFFs to be in
Expand Down