Skip to content

Add AVIF decoder (Part 1- this is not public or available yet) #8596

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ option(WITH_JPEG "Enable features requiring LibJPEG." ON)
# untested. Since building from cmake is very low pri anyway, this is OK. If
# you're a user and you need this, please open an issue (and a PR!).
option(WITH_WEBP "Enable features requiring LibWEBP." OFF)
# Same here: AVIF support in the cmake build is untested, so it is OFF by default.
option(WITH_AVIF "Enable features requiring LibAVIF." OFF)

if(WITH_CUDA)
enable_language(CUDA)
Expand Down Expand Up @@ -41,6 +43,11 @@ if (WITH_WEBP)
find_package(WEBP REQUIRED)
endif()

# Optional AVIF support. AVIF_FOUND is the preprocessor symbol that the C++
# sources test (`#if AVIF_FOUND`) to decide whether the libavif-backed decoder
# is compiled in. find_package() is REQUIRED, so configuration stops if the
# AVIF package cannot be located.
if(WITH_AVIF)
  add_definitions(-DAVIF_FOUND)
  find_package(AVIF REQUIRED)
endif()

function(CUDA_CONVERT_FLAGS EXISTING_TARGET)
get_property(old_flags TARGET ${EXISTING_TARGET} PROPERTY INTERFACE_COMPILE_OPTIONS)
if(NOT "${old_flags}" STREQUAL "")
Expand Down Expand Up @@ -117,6 +124,10 @@ if (WITH_WEBP)
target_link_libraries(${PROJECT_NAME} PRIVATE ${WEBP_LIBRARIES})
endif()

# When AVIF is enabled, link the libraries reported by find_package(AVIF)
# privately into the ${PROJECT_NAME} target.
if(WITH_AVIF)
  target_link_libraries(${PROJECT_NAME} PRIVATE ${AVIF_LIBRARIES})
endif()

set_target_properties(${PROJECT_NAME} PROPERTIES
EXPORT_NAME TorchVision
INSTALL_RPATH ${TORCH_INSTALL_PREFIX}/lib)
Expand All @@ -135,6 +146,10 @@ if (WITH_WEBP)
include_directories(${WEBP_INCLUDE_DIRS})
endif()

# When AVIF is enabled, add the header locations reported by
# find_package(AVIF) to the include search path.
if(WITH_AVIF)
  include_directories(${AVIF_INCLUDE_DIRS})
endif()

set(TORCHVISION_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchVision" CACHE STRING "install path for TorchVisionConfig.cmake")

configure_package_config_file(cmake/TorchVisionConfig.cmake.in
Expand Down
17 changes: 17 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
USE_PNG = os.getenv("TORCHVISION_USE_PNG", "1") == "1"
USE_JPEG = os.getenv("TORCHVISION_USE_JPEG", "1") == "1"
USE_WEBP = os.getenv("TORCHVISION_USE_WEBP", "1") == "1"
USE_AVIF = os.getenv("TORCHVISION_USE_AVIF", "0") == "1" # TODO enable by default!
USE_NVJPEG = os.getenv("TORCHVISION_USE_NVJPEG", "1") == "1"
NVCC_FLAGS = os.getenv("NVCC_FLAGS", None)
# Note: the GPU video decoding stuff used to be called "video codec", which
Expand Down Expand Up @@ -49,6 +50,7 @@
print(f"{USE_PNG = }")
print(f"{USE_JPEG = }")
print(f"{USE_WEBP = }")
print(f"{USE_AVIF = }")
print(f"{USE_NVJPEG = }")
print(f"{NVCC_FLAGS = }")
print(f"{USE_CPU_VIDEO_DECODER = }")
Expand Down Expand Up @@ -332,6 +334,21 @@ def make_image_extension():
else:
warnings.warn("Building torchvision without WEBP support")

if USE_AVIF:
avif_found, avif_include_dir, avif_library_dir = find_library(header="avif/avif.h")
if avif_found:
print("Building torchvision with AVIF support")
print(f"{avif_include_dir = }")
print(f"{avif_library_dir = }")
if avif_include_dir is not None and avif_library_dir is not None:
# if those are None it means they come from standard paths that are already in the search paths, which we don't need to re-add.
include_dirs.append(avif_include_dir)
library_dirs.append(avif_library_dir)
libraries.append("avif")
define_macros += [("AVIF_FOUND", 1)]
else:
warnings.warn("Building torchvision without AVIF support")

if USE_NVJPEG and torch.cuda.is_available():
nvjpeg_found = CUDA_HOME is not None and (Path(CUDA_HOME) / "include/nvjpeg.h").exists()

Expand Down
Binary file added test/assets/fakedata/logos/rgb_pytorch.avif
Binary file not shown.
15 changes: 14 additions & 1 deletion test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_decode_avif,
decode_gif,
decode_image,
decode_jpeg,
Expand Down Expand Up @@ -873,7 +874,7 @@ def test_decode_gif_webp_errors(decode_fun):
decode_fun(encoded_data[::2])
if decode_fun is decode_gif:
expected_match = re.escape("DGifOpenFileName() failed - 103")
else:
elif decode_fun is decode_webp:
expected_match = "WebPDecodeRGB failed."
with pytest.raises(RuntimeError, match=expected_match):
decode_fun(encoded_data)
Expand All @@ -890,5 +891,17 @@ def test_decode_webp(decode_fun, scripted):
assert img[None].is_contiguous(memory_format=torch.channels_last)


@pytest.mark.xfail(reason="AVIF support not enabled yet.")
@pytest.mark.parametrize("decode_fun", (_decode_avif, decode_image))
@pytest.mark.parametrize("scripted", (False, True))
def test_decode_avif(decode_fun, scripted):
    """Decode the AVIF asset found in FAKEDATA_DIR, eagerly or through TorchScript.

    The decoded image is expected to be 3 x 100 x 100 and laid out as
    channels-last once a batch dimension is added.
    """
    avif_path = next(get_images(FAKEDATA_DIR, ".avif"))
    avif_bytes = read_file(avif_path)

    # Optionally exercise the TorchScript-compiled version of the decoder.
    decoder = torch.jit.script(decode_fun) if scripted else decode_fun

    decoded = decoder(avif_bytes)
    assert decoded.shape == (3, 100, 100)
    assert decoded[None].is_contiguous(memory_format=torch.channels_last)


if __name__ == "__main__":
pytest.main([__file__])
92 changes: 92 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_avif.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include "decode_avif.h"

#include <cstring> // memset
#include <memory> // std::unique_ptr

#if AVIF_FOUND
#include "avif/avif.h"
#endif // AVIF_FOUND

namespace vision {
namespace image {

#if !AVIF_FOUND
// Stub used when torchvision is built without libavif: it always raises.
torch::Tensor decode_avif(const torch::Tensor& data) {
  TORCH_CHECK(
      false, "decode_avif: torchvision not compiled with libavif support");
}
#else

// This normally comes from avif_cxx.h, but it's not always present when
// installing libavif. So we just copy/paste it here.
struct UniquePtrDeleter {
  void operator()(avifDecoder* decoder) const {
    avifDecoderDestroy(decoder);
  }
};
// Owning handle: the decoder is destroyed on every exit path, including the
// exceptions thrown by TORCH_CHECK below.
using DecoderPtr = std::unique_ptr<avifDecoder, UniquePtrDeleter>;

// Decodes a single 8-bit AVIF image.
//
// `encoded_data` must be a contiguous, 1-dimensional uint8 tensor holding the
// raw bytes of the file. The result is a uint8 RGB tensor of shape 3 x H x W,
// obtained by permuting an H x W x 3 buffer (i.e. channels-last memory).
//
// Raises (through TORCH_CHECK) when the input tensor is malformed, when any
// libavif call fails, when the file holds a number of images other than one,
// or when the bit depth is greater than 8.
torch::Tensor decode_avif(const torch::Tensor& encoded_data) {
  // This is based on
  // https://github.com/AOMediaCodec/libavif/blob/main/examples/avif_example_decode_memory.c
  // Refer there for more detail about what each function does, and which
  // structure/data is available after which call.

  TORCH_CHECK(encoded_data.is_contiguous(), "Input tensor must be contiguous.");
  TORCH_CHECK(
      encoded_data.dtype() == torch::kU8,
      "Input tensor must have uint8 data type, got ",
      encoded_data.dtype());
  TORCH_CHECK(
      encoded_data.dim() == 1,
      "Input tensor must be 1-dimensional, got ",
      encoded_data.dim(),
      " dims.");

  DecoderPtr decoder(avifDecoderCreate());
  TORCH_CHECK(decoder != nullptr, "Failed to create avif decoder.");

  // Point the decoder at the tensor's bytes.
  auto result = avifDecoderSetIOMemory(
      decoder.get(), encoded_data.data_ptr<uint8_t>(), encoded_data.numel());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderSetIOMemory failed: ",
      avifResultToString(result));

  // Parse the container; imageCount and the image depth are read right after.
  result = avifDecoderParse(decoder.get());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderParse failed: ",
      avifResultToString(result));
  TORCH_CHECK(
      decoder->imageCount == 1, "Avif file contains more than one image");
  TORCH_CHECK(
      decoder->image->depth <= 8,
      "avif images with bitdepth > 8 are not supported");

  // Decode the (only) frame.
  result = avifDecoderNextImage(decoder.get());
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifDecoderNextImage failed: ",
      avifResultToString(result));

  // Output buffer in HWC order; libavif writes the RGB pixels straight into it.
  auto out = torch::empty(
      {decoder->image->height, decoder->image->width, 3}, torch::kUInt8);

  avifRGBImage rgb;
  memset(&rgb, 0, sizeof(rgb));
  avifRGBImageSetDefaults(&rgb, decoder->image);
  rgb.format = AVIF_RGB_FORMAT_RGB;
  rgb.pixels = out.data_ptr<uint8_t>();
  // Rows are tightly packed in the tensor: no padding between them.
  rgb.rowBytes = rgb.width * avifRGBImagePixelSize(&rgb);

  result = avifImageYUVToRGB(decoder->image, &rgb);
  TORCH_CHECK(
      result == AVIF_RESULT_OK,
      "avifImageYUVToRGB failed: ",
      avifResultToString(result));

  return out.permute({2, 0, 1}); // return CHW, channels-last
}
#endif // AVIF_FOUND

} // namespace image
} // namespace vision
11 changes: 11 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_avif.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#pragma once

#include <torch/types.h>

namespace vision {
namespace image {

// Decodes an AVIF image from `data`, a contiguous 1-dimensional uint8 tensor
// holding the encoded bytes, and returns the decoded uint8 RGB image as a
// 3 x H x W tensor. Raises (via TORCH_CHECK) on malformed input, on decoding
// failure, or when torchvision was compiled without libavif support.
C10_EXPORT torch::Tensor decode_avif(const torch::Tensor& data);

} // namespace image
} // namespace vision
13 changes: 13 additions & 0 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "decode_image.h"

#include "decode_avif.h"
#include "decode_gif.h"
#include "decode_jpeg.h"
#include "decode_png.h"
Expand Down Expand Up @@ -48,6 +49,18 @@ torch::Tensor decode_image(
return decode_gif(data);
}

// We assume the signature of an avif file is
// 0000 0020 6674 7970 6176 6966
// xxxx xxxx f t y p a v i f
// We only check for the "ftyp avif" part.
// This is probably not perfect, but hopefully this should cover most files.
const uint8_t avif_signature[8] = {
0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66}; // == "ftypavif"
TORCH_CHECK(data.numel() >= 12, err_msg);
if ((memcmp(avif_signature, datap + 4, 8) == 0)) {
return decode_avif(data);
}

const uint8_t webp_signature_begin[4] = {0x52, 0x49, 0x46, 0x46}; // == "RIFF"
const uint8_t webp_signature_end[7] = {
0x57, 0x45, 0x42, 0x50, 0x56, 0x50, 0x38}; // == "WEBPVP8"
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ static auto registry =
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_jpeg)
.op("image::decode_webp", &decode_webp)
.op("image::decode_avif", &decode_avif)
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
Expand Down
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "cpu/decode_avif.h"
#include "cpu/decode_gif.h"
#include "cpu/decode_image.h"
#include "cpu/decode_jpeg.h"
Expand Down
2 changes: 2 additions & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
"decode_image",
"decode_jpeg",
"decode_png",
"decode_webp",
"decode_gif",
"encode_jpeg",
"encode_png",
"read_file",
Expand Down
8 changes: 8 additions & 0 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,11 @@ def decode_webp(
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_webp)
return torch.ops.image.decode_webp(input)


def _decode_avif(
    input: torch.Tensor,
) -> torch.Tensor:
    """Decode an AVIF image into a tensor (private helper, not part of the public API yet).

    Args:
        input (Tensor[1]): a one dimensional contiguous uint8 tensor containing
            the raw bytes of the AVIF image.

    Returns:
        Tensor: the result of the ``image::decode_avif`` operator, i.e. the
        decoded uint8 RGB image with shape ``(3, H, W)``.
    """
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        # Log usage under this function's own name (not decode_webp's).
        _log_api_usage_once(_decode_avif)
    return torch.ops.image.decode_avif(input)
Loading