Skip to content

bin_df_cols leave input df unchanged #192

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymatviz/process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def count_elements(
values: ElemValues,
count_mode: ElemCountMode = ElemCountMode.composition,
exclude_elements: Sequence[str] = (),
fill_value: float | None = 0,
fill_value: float | None = None,
) -> pd.Series:
"""Count element occurrence in list of formula strings or dict-like compositions.
If passed values are already a map from element symbol to counts, ensure the
Expand Down
1 change: 1 addition & 0 deletions pymatviz/ptable/ptable_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ def ptable_heatmap_plotly(
font_size=font_size,
width=1000,
height=500,
title=dict(x=0.4, y=0.95),
)

if color_bar.get("orientation") == "h":
Expand Down
21 changes: 13 additions & 8 deletions pymatviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,28 +196,34 @@ def bin_df_cols(
verbose (bool): If True, report df length reduction. Defaults to True.

Returns:
pd.DataFrame: Binned DataFrame.
pd.DataFrame: Binned DataFrame with original index name and values.
"""
# Create a copy of the input DataFrame to avoid modifying the original
df_in = df_in.copy()

if isinstance(n_bins, int):
# broadcast integer n_bins to all bin_by_cols
n_bins = [n_bins] * len(bin_by_cols)

if len(bin_by_cols) != len(n_bins):
raise ValueError(f"{len(bin_by_cols)=} != {len(n_bins)=}")

index_name = df_in.index.name

cut_cols = [f"{col}_bins" for col in bin_by_cols]
for col, bins, cut_col in zip(bin_by_cols, n_bins, cut_cols):
df_in[cut_col] = pd.cut(df_in[col].values, bins=bins)

if df_in.index.name not in df_in:
# Preserve the original index
orig_index_name = df_in.index.name or "index"
# Reset index so it participates in groupby. If the index name is already in the
# columns, it already participates and will be set back to the index at the end.
if orig_index_name not in df_in:
df_in = df_in.reset_index()

group = df_in.groupby([*cut_cols, *group_by_cols], observed=True)
group = df_in.groupby(by=[*cut_cols, *group_by_cols], observed=True)

df_bin = group.first().dropna()
df_bin[bin_counts_col] = group.size()
df_bin = df_bin.reset_index()

if verbose:
print( # noqa: T201
Expand All @@ -234,9 +240,8 @@ def bin_df_cols(
density = gaussian_kde(xy_binned.astype(float))
df_bin[density_col] = density / density.sum() * len(values)

if index_name is None:
return df_bin
return df_bin.reset_index().set_index(index_name)
# Set the index back to the original index name
return df_bin.set_index(orig_index_name)


@contextmanager
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ dependencies = [
"plotly>=5.23",
"pymatgen>=2024.7.18",
"scikit-learn>=1.5",
"scipy>=1.14",
"scipy>=1.13,<1.14",
]

[project.optional-dependencies]
Expand Down
16 changes: 7 additions & 9 deletions tests/test_process_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
def test_count_elements(count_mode: ElemCountMode, counts: dict[str, float]) -> None:
series = count_elements(["Fe2 O3"] * 5 + ["Fe4 P4 O16"] * 3, count_mode=count_mode)
expected = pd.Series(counts, index=df_ptable.index, name="count").fillna(0)
expected = pd.Series(counts, index=df_ptable.index, name="count")
pd.testing.assert_series_equal(series, expected, check_dtype=False)


Expand All @@ -47,7 +47,7 @@ def test_count_elements_composition_objects() -> None:
series = count_elements(compositions, count_mode=ElemCountMode.composition)
expected = pd.Series(
{"Fe": 22, "O": 63, "P": 12}, index=df_ptable.index, name="count"
).fillna(0)
)
pd.testing.assert_series_equal(series, expected, check_dtype=False)


Expand All @@ -58,7 +58,7 @@ def test_count_elements_composition_objects_fractional() -> None:
)
expected = pd.Series(
{"Fe": 2.5, "O": 5, "P": 0.5}, index=df_ptable.index, name="count"
).fillna(0)
)
pd.testing.assert_series_equal(series, expected, check_dtype=False)


Expand All @@ -67,16 +67,14 @@ def test_count_elements_composition_objects_reduced() -> None:
series = count_elements(compositions, count_mode=ElemCountMode.reduced_composition)
expected = pd.Series(
{"Fe": 13, "O": 27, "P": 3}, index=df_ptable.index, name="count"
).fillna(0)
)
pd.testing.assert_series_equal(series, expected, check_dtype=False)


def test_count_elements_composition_objects_occurrence() -> None:
compositions = [Composition("Fe2O3")] * 5 + [Composition("Fe4P4O16")] * 3
series = count_elements(compositions, count_mode=ElemCountMode.occurrence)
expected = pd.Series(
{"Fe": 8, "O": 8, "P": 3}, index=df_ptable.index, name="count"
).fillna(0)
expected = pd.Series({"Fe": 8, "O": 8, "P": 3}, index=df_ptable.index, name="count")
pd.testing.assert_series_equal(series, expected, check_dtype=False)


Expand All @@ -87,7 +85,7 @@ def test_count_elements_mixed_input() -> None:
{"Fe": 6, "O": 21, "P": 4, "Li": 1, "Co": 1, "Na": 1, "Cl": 1},
index=df_ptable.index,
name="count",
).fillna(0)
)
pd.testing.assert_series_equal(series, expected, check_dtype=False)


Expand All @@ -98,7 +96,7 @@ def test_count_elements_exclude_elements() -> None:
)
expected = pd.Series(
{"O": 63}, index=df_ptable.index.drop(["Fe", "P"]), name="count"
).fillna(0)
)
pd.testing.assert_series_equal(series, expected, check_dtype=False)


Expand Down
49 changes: 34 additions & 15 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import copy
import re
from copy import deepcopy
from typing import Any, Literal
Expand Down Expand Up @@ -119,8 +120,8 @@ def test_df_to_arrays_strict() -> None:
"verbose, density_col, expected_n_rows",
[
(["A"], [], 2, [2], True, "", 2),
(["A", "B"], [], 2, [2, 2], True, "kde", 4),
(["A", "B"], [], [2, 3], [2, 3], False, "kde", 6),
(["A", "B"], [], 2, [2, 2], True, "kde_bin_counts", 4),
(["A", "B"], [], [2, 3], [2, 3], False, "kde_bin_counts", 6),
(["A"], ["B"], 2, [2], False, "", 30),
],
)
Expand All @@ -135,7 +136,14 @@ def test_bin_df_cols(
df_float: pd.DataFrame,
) -> None:
idx_col = "index"
# don't move this below df_float.copy() line
df_float.index.name = idx_col

# keep copy of original DataFrame to assert it is not modified
# using copy.deepcopy instead of df.copy(deep=True) for extra sensitivity;
# the pandas docstring notes DataFrame.copy is not as deep as copy.deepcopy
df_float_orig = copy.deepcopy(df_float)

bin_counts_col = "bin_counts"
df_binned = bin_df_cols(
df_float,
Expand All @@ -147,24 +155,35 @@ def test_bin_df_cols(
density_col=density_col,
)

assert len(df_binned) == expected_n_rows
assert len(df_binned) <= len(df_float)
assert df_binned.index.name == idx_col

# ensure binned DataFrame has a minimum set of expected columns
expected_cols = {bin_counts_col, *df_float, *(f"{col}_bins" for col in bin_by_cols)}
assert {*df_binned} >= expected_cols
assert len(df_binned) == expected_n_rows
assert (
{*df_binned} >= expected_cols
), f"{set(df_binned)=}\n{expected_cols=},\n{bin_by_cols=}\n{group_by_cols=}"

# validate the number of unique bins for each binned column
df_grouped = (
df_float.reset_index(names=idx_col)
.groupby([*[f"{c}_bins" for c in bin_by_cols], *group_by_cols])
.first()
.dropna()
)
for col, expected in zip(bin_by_cols, expected_n_bins):
binned_col = f"{col}_bins"
assert binned_col in df_grouped.index.names
for col, n_bins_expec in zip(bin_by_cols, expected_n_bins):
assert df_binned[f"{col}_bins"].nunique() == n_bins_expec

# ensure original DataFrame is not modified
pd.testing.assert_frame_equal(df_float, df_float_orig)

# Check that the index values of df_binned are a subset of df_float
assert set(df_binned.index).issubset(set(df_float.index))

uniq_bins = df_grouped.index.get_level_values(binned_col).nunique()
assert uniq_bins == expected
# Check that bin_counts column exists and contains only integers
assert bin_counts_col in df_binned
assert df_binned[bin_counts_col].dtype in [int, "int64"]

# If density column is specified, check if it exists
if density_col:
assert density_col in df_binned
else:
assert density_col not in df_binned


def test_bin_df_cols_raises() -> None:
Expand Down
Loading