
issue: keras-cv MixedPrecision in KPL #1784


Closed · innat opened this issue May 14, 2023 · 8 comments · Fixed by #1860

@innat (Contributor) commented May 14, 2023

Describe the issue

Environment: TF 2.12, KerasCV 0.5

I've tried to use mixed precision in an object detection pipeline, but it gives an error:

from tensorflow import keras

import keras_cv

keras.mixed_precision.set_global_policy('mixed_float16')

# `preprocessor` is defined earlier in the gist (omitted here)
augmenter = keras.Sequential(
    layers=[
        preprocessor,
        keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh"),
        keras_cv.layers.RandomChannelShift(value_range=(0, 255), factor=0.2),
        keras_cv.layers.MixUp(alpha=0.4),
        keras_cv.layers.Mosaic(bounding_box_format="xywh"),
    ]
)
ValueError: Error in map_fn:
      Expected `fn` to return a:
        RaggedTensorSpec(TensorShape([None, 4]), tf.float16, 1, tf.int64)
      But it returned a:
        RaggedTensorSpec(TensorShape([None, 4]), tf.float32, 1, tf.int64)

Gist (link).

@jbischof (Contributor)

@innat is this an object detection issue? I don't see any modeling code in your gist.

If this can be localized to compatibility of the augmentations themselves, let's center the issue there.

@innat changed the title from "issue: keras-cv MixedPrecision in object detection" to "issue: keras-cv MixedPrecision in KPL" on May 14, 2023
@barrypitman (Contributor) commented May 30, 2023

I'm running into the same issue. It seems like several augmentation layers don't support mixed precision.

import tensorflow as tf
from tensorflow import keras

import keras_cv

keras.mixed_precision.set_global_policy("mixed_float16")

img = tf.random.uniform(
    shape=(3, 512, 512, 3), minval=0, maxval=255, dtype=tf.float32
)

layer = keras_cv.layers.Resizing(
    bounding_box_format="xywh",
    width=640,
    height=640,
    pad_to_aspect_ratio=True,
)

bounding_boxes = {
    "classes": tf.ones((3, 5), dtype=tf.float32),
    "boxes": tf.random.uniform(shape=(3, 5, 4), minval=0, maxval=512, dtype=tf.float32),
}
layer({"images": img, "bounding_boxes": bounding_boxes})

Running this gives the following error:

Error in map_fn:
  Expected `fn` to return a:
    RaggedTensorSpec(TensorShape([None, 4]), tf.float16, 1, tf.int64)
  But it returned a:
    RaggedTensorSpec(TensorShape([5, 4]), tf.float32, 1, tf.int64)
    (value=<tf.RaggedTensor [[419.0625, 513.75, 33.75, 126.25],
 [528.125, 376.25, 111.875, 263.75],
 [126.640625, 639.375, 513.3594, 0.625],
 [515.625, 131.25, 124.375, 18.34961],
 [445.0, 495.0, 73.24219, 145.0]]>)
  To fix, update the `fn_output_signature` (or `dtype`) argument to `map_fn`.

Call arguments received by layer 'resizing' (type Resizing):
  • inputs={'images': 'tf.Tensor(shape=(640, 640, 3), dtype=float32)', 'bounding_boxes': {'classes': 'tf.Tensor(shape=(5,), dtype=float32)', 'boxes': '<tf.RaggedTensor [[419.0625, 513.75, 33.75, 126.25],\n [528.125, 376.25, 111.875, 263.75],\n [126.640625, 639.375, 513.3594, 0.625],\n [515.625, 131.25, 124.375, 18.34961],\n [445.0, 495.0, 73.24219, 145.0]]>'}}

It seems like RandomRotation, Resizing, JitteredResize, etc. have similar issues. I think it's an object detection issue because you don't run into the same problems with classification inputs. I see that there is a WithMixedPrecisionTest, but it only tests those layers with classification labels as input.

@ianstenbit (Contributor)

In general, these types of issues are caused by the usage of dtype constants in layers.
For Resizing, this is a likely culprit because the underlying implementation of smart_resize seems to always use FP32.
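
To see the failure mode concretely, here is a minimal sketch (illustrative, not keras-cv source): tf.image.resize, which smart_resize wraps, returns float32 even for float16 inputs, so a layer that forwards the result unchanged breaks the mixed-precision policy unless it casts back.

import tensorflow as tf

images = tf.random.uniform((1, 512, 512, 3), dtype=tf.float16)
resized = tf.image.resize(images, (640, 640))
print(resized.dtype)  # float32 -- the float16 input dtype is not preserved
resized = tf.cast(resized, images.dtype)  # cast back to the policy's compute dtype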

For the others, one place to start is to specify compute_dtype as the dtype whenever we call bounding_box.convert_format, which uses an implicit default of FP32.
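
A hedged sketch of that fix pattern (written standalone here; inside a layer one would pass self.compute_dtype): convert_format accepts a dtype argument, and passing the policy's compute dtype keeps the boxes in float16 under mixed_float16 instead of the float32 default.

import tensorflow as tf
from tensorflow import keras

import keras_cv

keras.mixed_precision.set_global_policy("mixed_float16")
compute_dtype = keras.mixed_precision.global_policy().compute_dtype  # "float16"

boxes = tf.random.uniform((2, 3, 4), maxval=512.0, dtype=tf.float16)
converted = keras_cv.bounding_box.convert_format(
    boxes,
    source="xywh",
    target="xyxy",
    dtype=compute_dtype,  # without this, the result comes back as float32
)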

We definitely should be testing bounding box augmentation in the WithMixedPrecisionTest.

@jaygala223 (Contributor)

Hi @ianstenbit, I would like to work on this. I understand that the issue is about conflicting dtypes.

> We definitely should be testing bounding box augmentation in the WithMixedPrecisionTest.

How do I do this?

Also, I checked the implementation of smart_resize, and it does use FP32 (link).

@ianstenbit (Contributor)

> Hi @ianstenbit, I would like to work on this. I understand that the issue is about conflicting dtypes. How do I do this?
>
> Also, I checked the implementation of smart_resize, and it does use FP32 (link).

For testing boxes in the WithMixedPrecisionTest, we'll need to add bounding boxes as inputs here.

Here is an example of a similar test that includes bboxes.
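
For reference, a rough sketch of what such a test case could look like (layer choice, names, and shapes are illustrative, not the actual WithMixedPrecisionTest code):

import tensorflow as tf
from tensorflow import keras

import keras_cv

keras.mixed_precision.set_global_policy("mixed_float16")

layer = keras_cv.layers.RandomFlip(mode="horizontal", bounding_box_format="xywh")
inputs = {
    "images": tf.random.uniform((2, 64, 64, 3), maxval=255.0),
    "bounding_boxes": {
        "classes": tf.zeros((2, 3)),
        "boxes": tf.random.uniform((2, 3, 4), maxval=64.0),
    },
}
# With the fix in place, this should run without the map_fn dtype mismatch.
outputs = layer(inputs)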

@innat (Contributor, Author) commented Feb 19, 2024

@divyashreepathihalli Could you please take a look at this issue? The reported problem isn't solved by the above fix.

ValueError: Exception encountered when calling layer 'resizing' (type Resizing).

in user code:

    File "/usr/local/lib/python3.10/dist-packages/keras_cv/layers/preprocessing/base_image_augmentation_layer.py", line 435, in call  *
        outputs = self._format_output(
    File "/usr/local/lib/python3.10/dist-packages/keras_cv/layers/preprocessing/resizing.py", line 382, in _batch_augment  *
        return self._resize_with_pad(inputs)
    File "/usr/local/lib/python3.10/dist-packages/keras_cv/layers/preprocessing/resizing.py", line 285, in _resize_with_pad  *
        fn_output_signature=fn_output_signature,

    ValueError: Error in map_fn:
      Expected `fn` to return a:
        RaggedTensorSpec(TensorShape([None, 4]), tf.float16, 1, tf.int64)
      But it returned a:
        RaggedTensorSpec(TensorShape([None, None]), tf.float32, 1, tf.int64)
        (value=tf.RaggedTensor(values=Tensor("sequential_1/sequential/resizing/map/while/RaggedFromTensor/Reshape:0", shape=(None,), dtype=float32), row_splits=Tensor("sequential_1/sequential/resizing/map/while/RaggedFromTensor/RaggedFromUniformRowLength/RowPartitionFromUniformRowLength/mul:0", shape=(None,), dtype=int64)))
      To fix, update the `fn_output_signature` (or `dtype`) argument to `map_fn`.


Call arguments received by layer 'resizing' (type Resizing):
  • inputs={'images': 'tf.RaggedTensor(values=tf.RaggedTensor(values=Tensor("sequential_1/sequential/resizing/Cast_2:0", shape=(None, 3), dtype=float16), row_splits=Tensor("RaggedFromVariant_2/RaggedTensorFromVariant:1", shape=(None,), dtype=int64)), row_splits=Tensor("RaggedFromVariant_2/RaggedTensorFromVariant:0", shape=(9,), dtype=int64))', 'bounding_boxes': {'classes': 'tf.RaggedTensor(values=Tensor("sequential_1/sequential/resizing/Cast_1:0", shape=(None,), dtype=float16), row_splits=Tensor("RaggedFromVariant_1/RaggedTensorFromVariant:0", shape=(9,), dtype=int64))', 'boxes': 'tf.RaggedTensor(values=Tensor("sequential_1/sequential/resizing/Cast:0", shape=(None, 4), dtype=float16), row_splits=Tensor("RaggedFromVariant/RaggedTensorFromVariant:0", shape=(9,), dtype=int64))'}}

@innat (Contributor, Author) commented Feb 19, 2024

@james77777778 Any plans to progress this ticket?

@james77777778 (Contributor)

> @james77777778 Any plans to progress this ticket?

@innat

I'm afraid I don't have the bandwidth for this ticket.
The development of the preprocessing layers has changed since the migration to Keras 3.

BTW, it is a bit odd to me that the preprocessing layers still depend on tf.
