diff --git a/movement/io/load_poses.py b/movement/io/load_poses.py
index fd38055fe..8eb83b35a 100644
--- a/movement/io/load_poses.py
+++ b/movement/io/load_poses.py
@@ -1,7 +1,7 @@
 """Load pose tracking data from various frameworks into ``movement``."""
 
 from pathlib import Path
-from typing import Literal
+from typing import Any, Literal
 
 import h5py
 import numpy as np
@@ -40,7 +40,7 @@ def from_numpy(
         Array of shape (n_frames, n_keypoints, n_individuals) containing
         the point-wise confidence scores. It will be converted to a
         :class:`xarray.DataArray` object named "confidence".
-        If None (default), the scores will be set to an array of NaNs.
+        If None (default), no confidence data variable is included.
     individual_names : list of str, optional
         List of unique names for the individuals in the video. If None
         (default), the individuals will be named "individual_0",
@@ -59,8 +59,8 @@ def from_numpy(
     Returns
     -------
     xarray.Dataset
-        ``movement`` dataset containing the pose tracks, confidence scores,
-        and associated metadata.
+        ``movement`` dataset containing the pose tracks, confidence scores
+        (if provided), and associated metadata.
 
     Examples
     --------
@@ -94,7 +94,11 @@
 def from_file(
     file_path: Path | str,
     source_software: Literal[
-        "DeepLabCut", "SLEAP", "LightningPose", "Anipose"
+        "DeepLabCut",
+        "SLEAP",
+        "LightningPose",
+        "Anipose",
+        "animovement",
     ],
     fps: float | None = None,
     **kwargs,
@@ -106,10 +110,17 @@ def from_file(
     file_path : pathlib.Path or str
         Path to the file containing predicted poses. The file format must
         be among those supported by the ``from_dlc_file()``,
-        ``from_slp_file()`` or ``from_lp_file()`` functions. One of these
-        these functions will be called internally, based on
+        ``from_slp_file()``, ``from_lp_file()``, ``from_anipose_file()``,
+        or ``from_animovement_file()`` functions. One of these
+        functions will be called internally, based on
         the value of ``source_software``.
-    source_software : "DeepLabCut", "SLEAP", "LightningPose", or "Anipose"
+    source_software : Literal[
+        "DeepLabCut",
+        "SLEAP",
+        "LightningPose",
+        "Anipose",
+        "animovement",
+    ]
         The source software of the file.
     fps : float, optional
         The number of frames per second in the video. If None (default),
@@ -130,6 +141,7 @@ def from_file(
     movement.io.load_poses.from_sleap_file
     movement.io.load_poses.from_lp_file
     movement.io.load_poses.from_anipose_file
+    movement.io.load_poses.from_animovement_file
 
     Examples
     --------
@@ -147,6 +159,8 @@ def from_file(
         return from_lp_file(file_path, fps)
     elif source_software == "Anipose":
         return from_anipose_file(file_path, fps, **kwargs)
+    elif source_software == "animovement":
+        return from_animovement_file(file_path, fps)
     else:
         raise logger.error(
             ValueError(f"Unsupported source software: {source_software}")
@@ -506,18 +520,26 @@ def _ds_from_sleap_labels_file(
     file = ValidHDF5(file_path, expected_datasets=["pred_points", "metadata"])
     labels = read_labels(file.path.as_posix())
     tracks_with_scores = _sleap_labels_to_numpy(labels)
-    individual_names = [track.name for track in labels.tracks] or None
-    if individual_names is None:
+
+    individual_names: list[str] = (
+        [track.name for track in labels.tracks]
+        if labels.tracks
+        else ["individual_0"]
+    )
+    if not labels.tracks:
         logger.warning(
             f"Could not find SLEAP Track in {file.path}. "
" "Assuming single-individual dataset and assigning " "default individual name." ) + + keypoint_names: list[str] = [kp.name for kp in labels.skeletons[0].nodes] + + # Explicit type assertions for mypy + individual_names = cast(list[str], individual_names) + keypoint_names = cast(list[str], keypoint_names) + return from_numpy( position_array=tracks_with_scores[:, :-1, :, :], confidence_array=tracks_with_scores[:, -1, :, :], individual_names=individual_names, - keypoint_names=[kp.name for kp in labels.skeletons[0].nodes], + keypoint_names=keypoint_names, fps=fps, source_software="SLEAP", ) @@ -559,15 +586,19 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray: lfs = [lf for lf in labels.labeled_frames if lf.video == labels.videos[0]] # Figure out frame index range frame_idxs = [lf.frame_idx for lf in lfs] - first_frame = min(0, min(frame_idxs)) - last_frame = max(0, max(frame_idxs)) + first_frame = min(0, min(frame_idxs)) if frame_idxs else 0 + last_frame = max(0, max(frame_idxs)) if frame_idxs else 0 n_tracks = len(labels.tracks) or 1 # If no tracks, assume 1 individual individuals = labels.tracks or [None] skeleton = labels.skeletons[-1] # Assume project only uses last skeleton n_nodes = len(skeleton.nodes) n_frames = int(last_frame - first_frame + 1) - tracks = np.full((n_frames, 3, n_nodes, n_tracks), np.nan, dtype="float32") + + # Initialize tracks array with explicit type + tracks: np.ndarray = np.full( + (n_frames, 3, n_nodes, n_tracks), np.nan, dtype=np.float32 + ) for lf in lfs: i = int(lf.frame_idx - first_frame) @@ -583,12 +614,18 @@ def _sleap_labels_to_numpy(labels: Labels) -> np.ndarray: # Use user-labelled instance if available if user_track_instances: inst = user_track_instances[-1] - tracks[i, ..., j] = np.hstack( - (inst.numpy(), np.full((n_nodes, 1), np.nan)) - ).T + points = inst.numpy() + for k in range(n_nodes): + tracks[i, 0, k, j] = points[k, 0] # x-coordinate + tracks[i, 1, k, j] = points[k, 1] # y-coordinate + tracks[i, 2, k, j] = np.nan # No scores for user instances elif predicted_track_instances: inst = predicted_track_instances[-1] - tracks[i, ..., j] = inst.numpy(scores=True).T + points = inst.numpy(scores=True) + for k in range(n_nodes): + tracks[i, 0, k, j] = points[k, 0] # x-coordinate + tracks[i, 1, k, j] = points[k, 1] # y-coordinate + tracks[i, 2, k, j] = points[k, 2] # confidence score return tracks @@ -670,8 +707,8 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: Returns ------- xarray.Dataset - ``movement`` dataset containing the pose tracks, confidence scores, - and associated metadata. + ``movement`` dataset containing the pose tracks, confidence scores + (if provided), and associated metadata. 
""" n_frames = data.position_array.shape[0] @@ -693,14 +730,18 @@ def _ds_from_valid_data(data: ValidPosesDataset) -> xr.Dataset: dataset_attrs["time_unit"] = time_unit DIM_NAMES = ValidPosesDataset.DIM_NAMES + # Initialize data_vars dictionary with position + data_vars = { + "position": xr.DataArray(data.position_array, dims=DIM_NAMES), + } + # Add confidence only if confidence_array is provided + if data.confidence_array is not None: + data_vars["confidence"] = xr.DataArray( + data.confidence_array, dims=DIM_NAMES[:1] + DIM_NAMES[2:] + ) # Convert data to an xarray.Dataset return xr.Dataset( - data_vars={ - "position": xr.DataArray(data.position_array, dims=DIM_NAMES), - "confidence": xr.DataArray( - data.confidence_array, dims=DIM_NAMES[:1] + DIM_NAMES[2:] - ), - }, + data_vars=data_vars, coords={ DIM_NAMES[0]: time_coords, DIM_NAMES[1]: ["x", "y", "z"][:n_space], @@ -716,42 +757,13 @@ def from_anipose_style_df( fps: float | None = None, individual_name: str = "individual_0", ) -> xr.Dataset: - """Create a ``movement`` poses dataset from an Anipose 3D dataframe. - - Parameters - ---------- - df : pd.DataFrame - Anipose triangulation dataframe - fps : float, optional - The number of frames per second in the video. If None (default), - the ``time`` coordinates will be in frame units. - individual_name : str, optional - Name of the individual, by default "individual_0" - - Returns - ------- - xarray.Dataset - ``movement`` dataset containing the pose tracks, confidence scores, - and associated metadata. - - - Notes - ----- - Reshape dataframe with columns keypoint1_x, keypoint1_y, keypoint1_z, - keypoint1_score,keypoint2_x, keypoint2_y, keypoint2_z, - keypoint2_score...to array of positions with dimensions - time, space, keypoints, individuals, and array of confidence (from scores) - with dimensions time, keypoints, individuals. - - """ - keypoint_names = sorted( + """Create a ``movement`` poses dataset from an Anipose 3D dataframe.""" + keypoint_names: list[str] = sorted( list( set( - [ - col.rsplit("_", 1)[0] - for col in df.columns - if any(col.endswith(f"_{s}") for s in ["x", "y", "z"]) - ] + col.rsplit("_", 1)[0] + for col in df.columns + if any(col.endswith(f"_{s}") for s in ["x", "y", "z"]) ) ) ) @@ -769,7 +781,7 @@ def from_anipose_style_df( position_array[:, j, i, 0] = df[f"{kp}_{coord}"] confidence_array[:, i, 0] = df[f"{kp}_score"] - individual_names = [individual_name] + individual_names: list[str] = [individual_name] return from_numpy( position_array=position_array, @@ -822,3 +834,156 @@ def from_anipose_file( return from_anipose_style_df( anipose_df, fps=fps, individual_name=individual_name ) + + +def from_tidy_df( + df: pd.DataFrame, + fps: float | None = None, + source_software: str = "animovement", +) -> xr.Dataset: + """Create a ``movement`` poses dataset from a tidy DataFrame. + + Parameters + ---------- + df : pandas.DataFrame + Tidy DataFrame containing pose tracks and confidence scores. + Expected columns: 'frame', 'track_id', 'keypoint', 'x', 'y', + and optionally 'confidence'. + fps : float, optional + The number of frames per second in the video. If None (default), + the ``time`` coordinates will be in frame numbers. + source_software : str, optional + Name of the pose estimation software or package from which the + data originate. Defaults to "animovement". + + Returns + ------- + xarray.Dataset + ``movement`` dataset containing the pose tracks, confidence scores, + and associated metadata. 
+
+    Notes
+    -----
+    The DataFrame must have at least the following columns:
+
+    - 'frame': integer, the frame number (time index)
+    - 'track_id': string or integer, the individual ID
+    - 'keypoint': string, the keypoint name
+    - 'x': float, x-coordinate
+    - 'y': float, y-coordinate
+    - 'confidence': float, optional, point-wise confidence scores
+
+    Examples
+    --------
+    >>> import pandas as pd
+    >>> from movement.io import load_poses
+    >>> df = pd.DataFrame(
+    ...     {
+    ...         "frame": [0, 0, 1, 1],
+    ...         "track_id": ["ind1", "ind1", "ind1", "ind1"],
+    ...         "keypoint": ["nose", "tail", "nose", "tail"],
+    ...         "x": [100.0, 150.0, 101.0, 151.0],
+    ...         "y": [200.0, 250.0, 201.0, 251.0],
+    ...         "confidence": [0.9, 0.8, 0.85, 0.75],
+    ...     }
+    ... )
+    >>> ds = load_poses.from_tidy_df(df, fps=30)
+
+    """
+    # Validate DataFrame columns
+    required_columns = {"frame", "track_id", "keypoint", "x", "y"}
+    if not required_columns.issubset(df.columns):
+        missing = required_columns - set(df.columns)
+        raise ValueError(f"DataFrame missing required columns: {missing}")
+
+    # Ensure correct data types
+    df = df.astype(
+        {
+            "frame": int,
+            "track_id": str,
+            "keypoint": str,
+            "x": float,
+            "y": float,
+        }
+    )
+
+    # Get unique values for coordinates
+    time: np.ndarray[Any, np.dtype[np.int_]] = np.sort(df["frame"].unique())
+    individuals: np.ndarray[Any, np.dtype[np.str_]] = df["track_id"].unique()
+    keypoints: np.ndarray[Any, np.dtype[np.str_]] = df["keypoint"].unique()
+    n_frames = len(time)
+    n_individuals = len(individuals)
+    n_keypoints = len(keypoints)
+
+    # Initialize position and confidence arrays
+    position_array = np.full(
+        (n_frames, 2, n_keypoints, n_individuals), np.nan, dtype=float
+    )
+    confidence_array = (
+        np.full((n_frames, n_keypoints, n_individuals), np.nan, dtype=float)
+        if "confidence" in df.columns
+        else None
+    )
+
+    # Fill the arrays row by row
+    for _idx, row in df.iterrows():
+        t_idx = np.nonzero(time == row["frame"])[0][0]
+        i_idx = np.nonzero(individuals == row["track_id"])[0][0]
+        k_idx = np.nonzero(keypoints == row["keypoint"])[0][0]
+        position_array[t_idx, 0, k_idx, i_idx] = row["x"]
+        position_array[t_idx, 1, k_idx, i_idx] = row["y"]
+        if confidence_array is not None and "confidence" in row:
+            confidence_array[t_idx, k_idx, i_idx] = row["confidence"]
+
+    individual_names: list[str] = list(individuals)
+    keypoint_names: list[str] = list(keypoints)
+
+    return from_numpy(
+        position_array=position_array,
+        confidence_array=confidence_array,
+        individual_names=individual_names,
+        keypoint_names=keypoint_names,
+        fps=fps,
+        source_software=source_software,
+    )
+
+
+def from_animovement_file(
+    file_path: Path | str,
+    fps: float | None = None,
+) -> xr.Dataset:
+    """Create a ``movement`` poses dataset from an animovement Parquet file.
+
+    Parameters
+    ----------
+    file_path : pathlib.Path or str
+        Path to the Parquet file containing pose tracks in tidy format.
+    fps : float, optional
+        The number of frames per second in the video. If None (default),
+        the ``time`` coordinates will be in frame numbers.
+
+    Returns
+    -------
+    xarray.Dataset
+        ``movement`` dataset containing the pose tracks, confidence scores
+        (if present), and associated metadata.
+
+    Examples
+    --------
+    >>> from movement.io import load_poses
+    >>> ds = load_poses.from_animovement_file("path/to/file.parquet", fps=30)
+
+    """
+    file = ValidFile(
+        file_path,
+        expected_permission="r",
+        expected_suffix=[".parquet"],
+    )
+    # Load Parquet file into a tidy DataFrame
+    df = pd.read_parquet(file.path)
+    # Convert to an xarray Dataset
+    ds = from_tidy_df(df, fps=fps, source_software="animovement")
+    # Add metadata as attrs
+    ds.attrs["source_file"] = file.path.as_posix()
+    logger.info(f"Loaded pose tracks from {file.path}:\n{ds}")
+    return ds
diff --git a/movement/io/save_poses.py b/movement/io/save_poses.py
index e65bd481e..b5ea96131 100644
--- a/movement/io/save_poses.py
+++ b/movement/io/save_poses.py
@@ -361,6 +361,117 @@ def to_sleap_analysis_file(ds: xr.Dataset, file_path: str | Path) -> None:
     logger.info(f"Saved poses dataset to {file.path}.")
 
 
+def to_tidy_df(ds: xr.Dataset) -> pd.DataFrame:
+    """Convert a ``movement`` dataset to a tidy pandas DataFrame.
+
+    Parameters
+    ----------
+    ds : xarray.Dataset
+        ``movement`` dataset containing pose tracks, confidence scores,
+        and associated metadata.
+
+    Returns
+    -------
+    pandas.DataFrame
+        Tidy DataFrame with columns: 'frame', 'track_id', 'keypoint',
+        'x', 'y', and 'confidence' (if available in the dataset).
+
+    Notes
+    -----
+    The output DataFrame is in a tidy format where each row represents a
+    single observation (one keypoint for one individual at one frame).
+
+    The columns are:
+
+    - 'frame': integer, the frame number (time index)
+    - 'track_id': string, the individual ID
+    - 'keypoint': string, the keypoint name
+    - 'x': float, x-coordinate
+    - 'y': float, y-coordinate
+    - 'confidence': float, point-wise confidence scores (if present)
+
+    Examples
+    --------
+    >>> from movement.io import save_poses, load_poses
+    >>> ds = load_poses.from_sleap_file("path/to/file_sleap.analysis.h5")
+    >>> df = save_poses.to_tidy_df(ds)
+
+    """
+    _validate_dataset(ds)
+
+    # Compute integer frame indices from the time coordinates
+    fps = getattr(ds, "fps", None)
+    if fps is not None:
+        frame_idxs = np.rint(ds.time.values * fps).astype(int)
+    else:
+        frame_idxs = ds.time.values.astype(int)
+
+    # Stack data to create the tidy format
+    position = (
+        ds["position"]
+        .stack(obs=["time", "individuals", "keypoints"])
+        .transpose("obs", "space")
+    )
+
+    # Map each observation to its frame index
+    time_indices = position.indexes["obs"].get_level_values("time")
+    frame_values = frame_idxs[np.searchsorted(ds.time.values, time_indices)]
+
+    # Create the DataFrame
+    df = pd.DataFrame(
+        {
+            "frame": frame_values,
+            "track_id": position.individuals.values,
+            "keypoint": position.keypoints.values,
+            "x": position.sel(space="x").values,
+            "y": position.sel(space="y").values,
+        }
+    )
+
+    # Add confidence only if present and not all NaN
+    if "confidence" in ds and not ds["confidence"].isnull().all():
+        confidence = ds["confidence"].stack(
+            obs=["time", "individuals", "keypoints"]
+        )
+        df["confidence"] = confidence.values
+
+    logger.info("Converted poses dataset to tidy DataFrame.")
+    return df.reset_index(drop=True)
+
+
+def to_animovement_file(ds: xr.Dataset, file_path: str | Path) -> None:
+    """Save a ``movement`` dataset to an animovement Parquet file.
+
+    Parameters
+    ----------
+    ds : xarray.Dataset
+        ``movement`` dataset containing pose tracks, confidence scores,
+        and associated metadata.
+    file_path : pathlib.Path or str
+        Path to the file to save the poses to. The file extension
+        must be .parquet.
+
+    Notes
+    -----
+    The dataset is first converted to a tidy DataFrame using
+    :func:`movement.io.save_poses.to_tidy_df` and then written to a
+    Parquet file via :meth:`pandas.DataFrame.to_parquet`.
+
+    Examples
+    --------
+    >>> from movement.io import save_poses, load_poses
+    >>> ds = load_poses.from_sleap_file("path/to/file_sleap.analysis.h5")
+    >>> save_poses.to_animovement_file(ds, "path/to/file.parquet")
+
+    """
+    file = _validate_file_path(file_path, expected_suffix=[".parquet"])
+    _validate_dataset(ds)
+
+    # Convert to a tidy DataFrame, then save to Parquet
+    df = to_tidy_df(ds)
+    df.to_parquet(file.path, index=False)
+    logger.info(f"Saved poses dataset to {file.path}.")
+
+
 def _remove_unoccupied_tracks(ds: xr.Dataset):
     """Remove tracks that are completely unoccupied from the dataset.
 
diff --git a/pyproject.toml b/pyproject.toml
index 1692131e7..b91c30e39 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,6 +16,7 @@ dependencies = [
   "numpy",
   "pandas",
   "h5py",
+  "pyarrow",
   "attrs",
   "pooch",
   "tqdm",
diff --git a/tests/test_parquet_io.py b/tests/test_parquet_io.py
new file mode 100644
index 000000000..003105a7f
--- /dev/null
+++ b/tests/test_parquet_io.py
@@ -0,0 +1,334 @@
+"""Integration tests for Parquet I/O in movement.io."""
+
+import numpy as np
+import pandas as pd
+import pytest
+import xarray as xr
+
+from movement.io import load_poses, save_poses
+
+
+@pytest.fixture
+def sample_dataset():
+    """Create a sample xarray Dataset for testing."""
+    rng = np.random.default_rng(seed=10)
+    # 10 frames, 2D space, 3 keypoints, 2 individuals
+    position_array = rng.random((10, 2, 3, 2))
+    confidence_array = np.ones((10, 3, 2)) * 0.9
+    return load_poses.from_numpy(
+        position_array=position_array,
+        confidence_array=confidence_array,
+        individual_names=["ind1", "ind2"],
+        keypoint_names=["nose", "tail", "spine"],
+        fps=30,
+        source_software="test",
+    )
+
+
+@pytest.fixture
+def sample_tidy_df():
+    """Create a sample tidy DataFrame for testing."""
+    rng = np.random.default_rng(seed=12)
+    data = {
+        "frame": [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
+        "track_id": ["ind1", "ind1", "ind1", "ind2", "ind2", "ind2"] * 2,
+        "keypoint": ["nose", "tail", "spine"] * 4,
+        "x": rng.random(12),
+        "y": rng.random(12),
+        "confidence": np.ones(12) * 0.9,
+    }
+    return pd.DataFrame(data)
+
+
+def test_to_tidy_df(sample_dataset):
+    """Test conversion of xarray Dataset to tidy DataFrame."""
+    df = save_poses.to_tidy_df(sample_dataset)
+
+    # Check columns
+    expected_columns = {
+        "frame",
+        "track_id",
+        "keypoint",
+        "x",
+        "y",
+        "confidence",
+    }
+    assert set(df.columns) == expected_columns, (
+        "Unexpected columns in tidy DataFrame"
+    )
+
+    # Check data types
+    assert df["frame"].dtype == int, "Frame column should be integer"
+    assert df["track_id"].dtype == object, (
+        "Track_id column should be string/object"
+    )
+    assert df["keypoint"].dtype == object, (
+        "Keypoint column should be string/object"
+    )
+    assert df["x"].dtype == float, "X column should be float"
+    assert df["y"].dtype == float, "Y column should be float"
+    assert df["confidence"].dtype == float, (
+        "Confidence column should be float"
+    )
+
+    # Check shape
+    expected_rows = (
+        sample_dataset.sizes["time"]
+        * sample_dataset.sizes["individuals"]
+        * sample_dataset.sizes["keypoints"]
+    )
+    assert len(df) == expected_rows, (
+        f"Expected {expected_rows} rows, got {len(df)}"
+    )
+
+    # Verify frame indices are integers starting from 0
+    expected_frames = np.repeat(
+        np.arange(sample_dataset.sizes["time"]),
+        sample_dataset.sizes["individuals"]
+        * sample_dataset.sizes["keypoints"],
+    )
+    assert np.array_equal(df["frame"].values, expected_frames), (
+        "Frame indices should be integers starting from 0"
+    )
+
+
+def test_from_tidy_df(sample_tidy_df):
+    """Test conversion of tidy DataFrame to xarray Dataset."""
+    ds = load_poses.from_tidy_df(sample_tidy_df, fps=30)
+
+    # Check dataset structure
+    assert isinstance(ds, xr.Dataset), "Output should be an xarray Dataset"
+    assert set(ds.dims) == {"time", "space", "keypoints", "individuals"}, (
+        "Unexpected dimensions"
+    )
+    assert set(ds.data_vars) == {"position", "confidence"}, (
+        "Unexpected data variables"
+    )
+
+    # Check coordinates
+    assert ds.sizes["time"] == 2, "Expected 2 frames"
+    assert ds.sizes["individuals"] == 2, "Expected 2 individuals"
+    assert ds.sizes["keypoints"] == 3, "Expected 3 keypoints"
"Expected 3 keypoints" + assert ds.sizes["space"] == 2, "Expected 2D space" + + +def test_round_trip_dataframe(sample_dataset): + """Test round-trip conversion: Dataset -> tidy DataFrame -> Dataset.""" + df = save_poses.to_tidy_df(sample_dataset) + ds_roundtrip = load_poses.from_tidy_df( + df, fps=sample_dataset.attrs.get("fps") + ) + + # Compare datasets + xr.testing.assert_allclose( + ds_roundtrip["position"], sample_dataset["position"] + ) + xr.testing.assert_allclose( + ds_roundtrip["confidence"], sample_dataset["confidence"] + ) + assert ds_roundtrip.attrs["fps"] == sample_dataset.attrs["fps"], ( + "FPS metadata mismatch" + ) + assert set(ds_roundtrip.coords["individuals"].values) == set( + sample_dataset.coords["individuals"].values + ) + assert set(ds_roundtrip.coords["keypoints"].values) == set( + sample_dataset.coords["keypoints"].values + ) + + +def test_round_trip_parquet(sample_dataset, tmp_path): + """Test round-trip conversion: Dataset -> Parquet -> Dataset.""" + file_path = tmp_path / "test.parquet" + save_poses.to_animovement_file(sample_dataset, file_path) + ds_roundtrip = load_poses.from_animovement_file( + file_path, fps=sample_dataset.attrs.get("fps") + ) + + # Compare datasets + xr.testing.assert_allclose( + ds_roundtrip["position"], sample_dataset["position"] + ) + xr.testing.assert_allclose( + ds_roundtrip["confidence"], sample_dataset["confidence"] + ) + assert ds_roundtrip.attrs["fps"] == sample_dataset.attrs["fps"], ( + "FPS metadata mismatch" + ) + assert set(ds_roundtrip.coords["individuals"].values) == set( + sample_dataset.coords["individuals"].values + ) + assert set(ds_roundtrip.coords["keypoints"].values) == set( + sample_dataset.coords["keypoints"].values + ) + + +def test_to_tidy_df_no_confidence(): + """Test to_tidy_df with a dataset lacking confidence scores.""" + rng = np.random.default_rng(seed=5) + position_array = rng.random((5, 2, 2, 1)) + ds = load_poses.from_numpy( + position_array=position_array, + individual_names=["ind1"], + keypoint_names=["nose", "tail"], + fps=25, + ) + df = save_poses.to_tidy_df(ds) + + # Check columns (no confidence) + expected_columns = {"frame", "track_id", "keypoint", "x", "y"} + assert set(df.columns) == expected_columns, ( + "Unexpected columns in tidy DataFrame" + ) + assert len(df) == 5 * 1 * 2, "Incorrect number of rows" + + +def test_from_tidy_df_missing_columns(sample_tidy_df): + """Test from_tidy_df with missing required columns.""" + invalid_df = sample_tidy_df.drop(columns=["x"]) + with pytest.raises( + ValueError, match="DataFrame missing required columns: {'x'}" + ): + load_poses.from_tidy_df(invalid_df) + + +def test_from_animovement_file_invalid_extension(tmp_path): + """Test from_animovement_file with incorrect file extension.""" + invalid_file = tmp_path / "test.csv" + invalid_file.write_text("dummy") + with pytest.raises( + ValueError, + match=r"Expected file with suffix\(es\) \['.parquet'\] " + r"but got suffix .csv", + ): + load_poses.from_animovement_file(invalid_file) + + +def test_to_animovement_file_invalid_extension(tmp_path, sample_dataset): + """Test to_animovement_file with incorrect file extension.""" + invalid_file = tmp_path / "test.csv" + with pytest.raises( + ValueError, + match=r"Expected file with suffix\(es\) \['.parquet'\] " + r"but got suffix .csv", + ): + save_poses.to_animovement_file(sample_dataset, invalid_file) + + +def test_empty_dataset(): + """Test handling of empty dataset.""" + empty_ds = load_poses.from_numpy( + position_array=np.empty((0, 2, 0, 0)), + 
+        confidence_array=np.empty((0, 0, 0)),
+        individual_names=[],
+        keypoint_names=[],
+        fps=30,
+    )
+    df = save_poses.to_tidy_df(empty_ds)
+    assert df.empty, "Tidy DataFrame should be empty for empty dataset"
+
+    ds_roundtrip = load_poses.from_tidy_df(df, fps=30)
+    assert ds_roundtrip.sizes["time"] == 0, (
+        "Round-trip dataset should have zero frames"
+    )
+
+
+def test_from_file_animovement(sample_dataset, tmp_path):
+    """Test from_file with source_software='animovement'."""
+    file_path = tmp_path / "test.parquet"
+    save_poses.to_animovement_file(sample_dataset, file_path)
+    ds = load_poses.from_file(
+        file_path,
+        source_software="animovement",
+        fps=sample_dataset.attrs.get("fps"),
+    )
+
+    # Verify the loaded dataset
+    xr.testing.assert_allclose(ds["position"], sample_dataset["position"])
+    xr.testing.assert_allclose(ds["confidence"], sample_dataset["confidence"])
+    assert ds.attrs["fps"] == sample_dataset.attrs["fps"], (
+        "FPS metadata mismatch"
+    )
+    assert set(ds.coords["individuals"].values) == set(
+        sample_dataset.coords["individuals"].values
+    )
+    assert set(ds.coords["keypoints"].values) == set(
+        sample_dataset.coords["keypoints"].values
+    )
+
+
+def test_to_tidy_df_float_time():
+    """Test to_tidy_df with non-integer float time coordinates."""
+    rng = np.random.default_rng(seed=15)
+    # 5 frames, 2D space, 2 keypoints, 1 individual
+    position_array = rng.random((5, 2, 2, 1))
+    ds = load_poses.from_numpy(
+        position_array=position_array,
+        confidence_array=np.ones((5, 2, 1)) * 0.8,
+        individual_names=["ind1"],
+        keypoint_names=["nose", "tail"],
+        fps=10,
+        source_software="test",
+    )
+    # Explicitly set time coordinates to non-integer floats
+    ds = ds.assign_coords(time=np.array([1.5, 2.5, 3.5, 4.5, 5.5]))
+    df = save_poses.to_tidy_df(ds)
+
+    # Check columns
+    expected_columns = {
+        "frame",
+        "track_id",
+        "keypoint",
+        "x",
+        "y",
+        "confidence",
+    }
+    assert set(df.columns) == expected_columns, (
+        "Unexpected columns in tidy DataFrame"
+    )
+    assert df["frame"].dtype == int, "Frame column should be integer"
+    assert len(df) == 5 * 1 * 2, "Incorrect number of rows"
+    # Verify frame indices match scaled time values
+    expected_frames = np.repeat(np.array([15, 25, 35, 45, 55]), 2)
+    assert np.array_equal(df["frame"].values, expected_frames), (
+        "Frame indices should match scaled time values"
+    )
+
+
+def test_to_tidy_df_no_fps():
+    """Test to_tidy_df with fps=None and integer time coordinates."""
+    rng = np.random.default_rng(seed=15)
+    # 5 frames, 2D space, 2 keypoints, 1 individual
+    position_array = rng.random((5, 2, 2, 1))
+    ds = load_poses.from_numpy(
+        position_array=position_array,
+        confidence_array=np.ones((5, 2, 1)) * 0.8,
+        individual_names=["ind1"],
+        keypoint_names=["nose", "tail"],
+        fps=None,
+        source_software="test",
+    )
+    # Explicitly set time coordinates to integers
+    ds = ds.assign_coords(time=np.array([0, 1, 2, 3, 4]))
+    df = save_poses.to_tidy_df(ds)
+
+    # Check columns
+    expected_columns = {
+        "frame",
+        "track_id",
+        "keypoint",
+        "x",
+        "y",
+        "confidence",
+    }
+    assert set(df.columns) == expected_columns, (
+        "Unexpected columns in tidy DataFrame"
+    )
+    assert df["frame"].dtype == int, "Frame column should be integer"
+    assert len(df) == 5 * 1 * 2, "Incorrect number of rows"
+    # Verify frame indices match time values
+    expected_frames = np.repeat(np.array([0, 1, 2, 3, 4]), 2)
+    assert np.array_equal(df["frame"].values, expected_frames), (
+        "Frame indices should match time values"
+    )
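
Note (not part of the patch): a minimal round-trip sketch of the API this diff adds, assuming the patch is applied and pyarrow is installed; the output path is illustrative.

    import numpy as np

    from movement.io import load_poses, save_poses

    # Build a small dataset: 10 frames, 2-D space, 2 keypoints, 1 individual
    rng = np.random.default_rng(seed=42)
    ds = load_poses.from_numpy(
        position_array=rng.random((10, 2, 2, 1)),
        confidence_array=np.full((10, 2, 1), 0.9),
        individual_names=["individual_0"],
        keypoint_names=["nose", "tail"],
        fps=30,
    )

    # Dataset -> tidy Parquet file -> Dataset
    save_poses.to_animovement_file(ds, "poses.parquet")  # illustrative path
    ds_back = load_poses.from_animovement_file("poses.parquet", fps=30)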