Skip to content

Commit b63f960

Browse files
holtskinnercopybara-github
authored andcommitted
fix: LVM - Update Video.load_from_file() to support storage.googleapis.com links
PiperOrigin-RevId: 649149724
1 parent a6f68df commit b63f960

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

tests/unit/aiplatform/test_vision_models.py

+72
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,12 @@ def generate_video_from_gcs_uri(
184184
return ga_vision_models.Video.load_from_file(gcs_uri)
185185

186186

187+
def generate_video_from_storage_url(
188+
gcs_uri: str = "https://storage.googleapis.com/cloud-samples-data/vertex-ai-vision/highway_vehicles.mp4",
189+
) -> ga_vision_models.Video:
190+
return ga_vision_models.Video.load_from_file(gcs_uri)
191+
192+
187193
@pytest.mark.usefixtures("google_auth_mock")
188194
class TestImageGenerationModels:
189195
"""Unit tests for the image generation models."""
@@ -1215,6 +1221,72 @@ def test_video_embedding_model_with_only_video(self):
12151221
assert not embedding_response.text_embedding
12161222
assert not embedding_response.image_embedding
12171223

1224+
def test_video_embedding_model_with_storage_url(self):
1225+
aiplatform.init(
1226+
project=_TEST_PROJECT,
1227+
location=_TEST_LOCATION,
1228+
)
1229+
with mock.patch.object(
1230+
target=model_garden_service_client.ModelGardenServiceClient,
1231+
attribute="get_publisher_model",
1232+
return_value=gca_publisher_model.PublisherModel(
1233+
_IMAGE_EMBEDDING_PUBLISHER_MODEL_DICT
1234+
),
1235+
) as mock_get_publisher_model:
1236+
model = preview_vision_models.MultiModalEmbeddingModel.from_pretrained(
1237+
"multimodalembedding@001"
1238+
)
1239+
1240+
mock_get_publisher_model.assert_called_once_with(
1241+
name="publishers/google/models/multimodalembedding@001",
1242+
retry=base._DEFAULT_RETRY,
1243+
)
1244+
1245+
test_video_embeddings = [
1246+
ga_vision_models.VideoEmbedding(
1247+
start_offset_sec=0,
1248+
end_offset_sec=7,
1249+
embedding=[0, 7],
1250+
)
1251+
]
1252+
1253+
gca_predict_response = gca_prediction_service.PredictResponse()
1254+
gca_predict_response.predictions.append(
1255+
{
1256+
"videoEmbeddings": [
1257+
{
1258+
"startOffsetSec": test_video_embeddings[0].start_offset_sec,
1259+
"endOffsetSec": test_video_embeddings[0].end_offset_sec,
1260+
"embedding": test_video_embeddings[0].embedding,
1261+
}
1262+
]
1263+
}
1264+
)
1265+
1266+
video = generate_video_from_storage_url()
1267+
1268+
with mock.patch.object(
1269+
target=prediction_service_client.PredictionServiceClient,
1270+
attribute="predict",
1271+
return_value=gca_predict_response,
1272+
):
1273+
embedding_response = model.get_embeddings(video=video)
1274+
1275+
assert (
1276+
embedding_response.video_embeddings[0].embedding
1277+
== test_video_embeddings[0].embedding
1278+
)
1279+
assert (
1280+
embedding_response.video_embeddings[0].start_offset_sec
1281+
== test_video_embeddings[0].start_offset_sec
1282+
)
1283+
assert (
1284+
embedding_response.video_embeddings[0].end_offset_sec
1285+
== test_video_embeddings[0].end_offset_sec
1286+
)
1287+
assert not embedding_response.text_embedding
1288+
assert not embedding_response.image_embedding
1289+
12181290
def test_video_embedding_model_with_video_and_text(self):
12191291
aiplatform.init(
12201292
project=_TEST_PROJECT,

vertexai/vision_models/_vision_models.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def __init__(
187187
video_bytes: Optional[bytes] = None,
188188
gcs_uri: Optional[str] = None,
189189
):
190-
"""Creates an `Image` object.
190+
"""Creates a `Video` object.
191191
192192
Args:
193193
video_bytes: Video file bytes. Video can be in AVI, FLV, MKV, MOV,
@@ -211,9 +211,20 @@ def load_from_file(location: str) -> "Video":
211211
Returns:
212212
Loaded video as an `Video` object.
213213
"""
214-
if location.startswith("gs://"):
214+
parsed_url = urllib.parse.urlparse(location)
215+
if (
216+
parsed_url.scheme == "https"
217+
and parsed_url.netloc == "storage.googleapis.com"
218+
):
219+
parsed_url = parsed_url._replace(
220+
scheme="gs", netloc="", path=f"/{urllib.parse.unquote(parsed_url.path)}"
221+
)
222+
location = urllib.parse.urlunparse(parsed_url)
223+
224+
if parsed_url.scheme == "gs":
215225
return Video(gcs_uri=location)
216226

227+
# Load video from local path
217228
video_bytes = pathlib.Path(location).read_bytes()
218229
video = Video(video_bytes=video_bytes)
219230
return video

0 commit comments

Comments
 (0)