Skip to content

Fix circular mean by always storing and using the weighted one #142208

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Status: merged — 4 commits merged on Apr 4, 2025
Show file tree
Hide file tree
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 52 additions & 34 deletions homeassistant/components/recorder/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,13 @@ def query_circular_mean(table: type[StatisticsBase]) -> tuple[Label, Label]:
# in Python.
# https://en.wikipedia.org/wiki/Circular_mean
radians = func.radians(table.mean)
weighted_sum_sin = func.sum(func.sin(radians) * table.mean_weight)
weighted_sum_cos = func.sum(func.cos(radians) * table.mean_weight)
weight = func.sqrt(
func.power(func.sum(func.sin(radians) * table.mean_weight), 2)
+ func.power(func.sum(func.cos(radians) * table.mean_weight), 2)
func.power(weighted_sum_sin, 2) + func.power(weighted_sum_cos, 2)
)
return (
func.degrees(
func.atan2(func.sum(func.sin(radians)), func.sum(func.cos(radians)))
).label("mean"),
func.degrees(func.atan2(weighted_sum_sin, weighted_sum_cos)).label("mean"),
weight.label("mean_weight"),
)

Expand Down Expand Up @@ -240,18 +239,20 @@ def mean(values: list[float]) -> float | None:
RAD_TO_DEG = 180 / math.pi


def weighted_circular_mean(values: Iterable[tuple[float, float]]) -> float:
"""Return the weighted circular mean of the values."""
sin_sum = sum(math.sin(x * DEG_TO_RAD) * weight for x, weight in values)
cos_sum = sum(math.cos(x * DEG_TO_RAD) * weight for x, weight in values)
return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
def weighted_circular_mean(
values: Iterable[tuple[float, float]],
) -> tuple[float, float]:
"""Return the weighted circular mean and the weight of the values."""
weighted_sin_sum, weighted_cos_sum = 0.0, 0.0
for x, weight in values:
rad_x = x * DEG_TO_RAD
weighted_sin_sum += math.sin(rad_x) * weight
weighted_cos_sum += math.cos(rad_x) * weight


def circular_mean(values: list[float]) -> float:
"""Return the circular mean of the values."""
sin_sum = sum(math.sin(x * DEG_TO_RAD) for x in values)
cos_sum = sum(math.cos(x * DEG_TO_RAD) for x in values)
return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
return (
(RAD_TO_DEG * math.atan2(weighted_sin_sum, weighted_cos_sum)) % 360,
math.sqrt(weighted_sin_sum**2 + weighted_cos_sum**2),
)


_LOGGER = logging.getLogger(__name__)
Expand Down Expand Up @@ -300,6 +301,7 @@ class StatisticsRow(BaseStatisticsRow, total=False):
min: float | None
max: float | None
mean: float | None
mean_weight: float | None
change: float | None


Expand Down Expand Up @@ -1023,7 +1025,7 @@ def _reduce_statistics(
_want_sum = "sum" in types
for statistic_id, stat_list in stats.items():
max_values: list[float] = []
mean_values: list[float] = []
mean_values: list[tuple[float, float]] = []
min_values: list[float] = []
prev_stat: StatisticsRow = stat_list[0]
fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds}
Expand All @@ -1039,12 +1041,15 @@ def _reduce_statistics(
}
if _want_mean:
row["mean"] = None
row["mean_weight"] = None
if mean_values:
match metadata[statistic_id][1]["mean_type"]:
case StatisticMeanType.ARITHMETIC:
row["mean"] = mean(mean_values)
row["mean"] = mean([x[0] for x in mean_values])
case StatisticMeanType.CIRCULAR:
row["mean"] = circular_mean(mean_values)
row["mean"], row["mean_weight"] = (
weighted_circular_mean(mean_values)
)
mean_values.clear()
if _want_min:
row["min"] = min(min_values) if min_values else None
Expand All @@ -1063,7 +1068,8 @@ def _reduce_statistics(
max_values.append(_max)
if _want_mean:
if (_mean := statistic.get("mean")) is not None:
mean_values.append(_mean)
_mean_weight = statistic.get("mean_weight") or 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the get here needed to handle rows created before this patch was merged? If yes, I think it's better to just wipe the rows so we don't need to deal with this forever.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the column mean_weight is defined as float | None, we also need to handle the case where it's None. This case is when the mean_type is arithmetic.

An alternative could be:

diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py
index 80c0028ef7a..e168cbd92d6 100644
--- a/homeassistant/components/recorder/statistics.py
+++ b/homeassistant/components/recorder/statistics.py
@@ -1025,7 +1025,8 @@ def _reduce_statistics(
     _want_sum = "sum" in types
     for statistic_id, stat_list in stats.items():
         max_values: list[float] = []
-        mean_values: list[tuple[float, float]] = []
+        arithmetic_mean_values: list[float] = []
+        circular_mean_values: list[tuple[float, float]] = []
         min_values: list[float] = []
         prev_stat: StatisticsRow = stat_list[0]
         fake_entry: StatisticsRow = {"start": stat_list[-1]["start"] + period_seconds}
@@ -1042,15 +1043,14 @@ def _reduce_statistics(
                 if _want_mean:
                     row["mean"] = None
                     row["mean_weight"] = None
-                    if mean_values:
-                        match metadata[statistic_id][1]["mean_type"]:
-                            case StatisticMeanType.ARITHMETIC:
-                                row["mean"] = mean([x[0] for x in mean_values])
-                            case StatisticMeanType.CIRCULAR:
-                                row["mean"], row["mean_weight"] = (
-                                    weighted_circular_mean(mean_values)
-                                )
-                    mean_values.clear()
+                    if arithmetic_mean_values:
+                        row["mean"] = mean(arithmetic_mean_values)
+                    if circular_mean_values:
+                        row["mean"], row["mean_weight"] = weighted_circular_mean(
+                            circular_mean_values
+                        )
+                    arithmetic_mean_values.clear()
+                    circular_mean_values.clear()
                 if _want_min:
                     row["min"] = min(min_values) if min_values else None
                     min_values.clear()
@@ -1068,8 +1068,14 @@ def _reduce_statistics(
                 max_values.append(_max)
             if _want_mean:
                 if (_mean := statistic.get("mean")) is not None:
-                    _mean_weight = statistic.get("mean_weight") or 0.0
-                    mean_values.append((_mean, _mean_weight))
+                    match metadata[statistic_id][1]["mean_type"]:
+                        case StatisticMeanType.ARITHMETIC:
+                            arithmetic_mean_values.append(_mean)
+                        case StatisticMeanType.CIRCULAR:
+                            if (
+                                _mean_weight := statistic.get("mean_weight")
+                            ) is not None:
+                                circular_mean_values.append((_mean, _mean_weight))
             if _want_min and (_min := statistic.get("min")) is not None:
                 min_values.append(_min)
             prev_stat = statistic

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As the column mean_weight is defined as float | None

But we ourselves insert the rows, and for circular mean we know we always insert the mean and the weight.
But I guess it makes sense to not blow up if there's invalid data in the database?

I like the proposed new version better, then we don't convert an invalid None to a valid 0.0

mean_values.append((_mean, _mean_weight))
if _want_min and (_min := statistic.get("min")) is not None:
min_values.append(_min)
prev_stat = statistic
Expand Down Expand Up @@ -1385,7 +1391,7 @@ def _get_max_mean_min_statistic(
match metadata[1]["mean_type"]:
case StatisticMeanType.CIRCULAR:
if circular_means := max_mean_min["circular_means"]:
mean_value = weighted_circular_mean(circular_means)
mean_value = weighted_circular_mean(circular_means)[0]
case StatisticMeanType.ARITHMETIC:
if (mean_value := max_mean_min.get("mean_acc")) is not None and (
duration := max_mean_min.get("duration")
Expand Down Expand Up @@ -1739,12 +1745,12 @@ def statistic_during_period(


_type_column_mapping = {
"last_reset": "last_reset_ts",
"max": "max",
"mean": "mean",
"min": "min",
"state": "state",
"sum": "sum",
"last_reset": ("last_reset_ts",),
"max": ("max",),
"mean": ("mean", "mean_weight"),
"min": ("min",),
"state": ("state",),
"sum": ("sum",),
}


Expand All @@ -1756,12 +1762,13 @@ def _generate_select_columns_for_types_stmt(
track_on: list[str | None] = [
table.__tablename__, # type: ignore[attr-defined]
]
for key, column in _type_column_mapping.items():
if key in types:
columns = columns.add_columns(getattr(table, column))
track_on.append(column)
else:
track_on.append(None)
for key, type_columns in _type_column_mapping.items():
for column in type_columns:
if key in types:
columns = columns.add_columns(getattr(table, column))
track_on.append(column)
else:
track_on.append(None)
return lambda_stmt(lambda: columns, track_on=track_on)


Expand Down Expand Up @@ -1944,6 +1951,12 @@ def _statistics_during_period_with_session(
hass, session, start_time, units, _types, table, metadata, result
)

# filter out mean_weight as it is only needed to reduce statistics
# and not needed in the result
for stats_rows in result.values():
for row in stats_rows:
row.pop("mean_weight", None)

# Return statistics combined with metadata
return result

Expand Down Expand Up @@ -2391,7 +2404,12 @@ def _sorted_statistics_to_dict(
field_map["last_reset"] = field_map.pop("last_reset_ts")
sum_idx = field_map["sum"] if "sum" in types else None
sum_only = len(types) == 1 and sum_idx is not None
row_mapping = tuple((key, field_map[key]) for key in types if key in field_map)
row_mapping = tuple(
(column, field_map[column])
for key in types
for column in ({key, *_type_column_mapping.get(key, ())})
if column in field_map
)
# Append all statistic entries, and optionally do unit conversion
table_duration_seconds = table.duration.total_seconds()
for meta_id, db_rows in stats_by_meta_id.items():
Expand Down
4 changes: 2 additions & 2 deletions homeassistant/components/sensor/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def _time_weighted_arithmetic_mean(

def _time_weighted_circular_mean(
fstates: list[tuple[float, State]], start: datetime.datetime, end: datetime.datetime
) -> float:
) -> tuple[float, float]:
"""Calculate a time weighted circular mean.

The circular mean is calculated by weighting the states by duration in seconds between
Expand Down Expand Up @@ -623,7 +623,7 @@ def compile_statistics( # noqa: C901
valid_float_states, start, end
)
case StatisticMeanType.CIRCULAR:
stat["mean"] = _time_weighted_circular_mean(
stat["mean"], stat["mean_weight"] = _time_weighted_circular_mean(
valid_float_states, start, end
)

Expand Down
46 changes: 22 additions & 24 deletions tests/components/sensor/test_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4508,23 +4508,19 @@ def _weighted_average(seq, i, last_state):
duration += dur
return total / duration

def _time_weighted_circular_mean(values: list[tuple[float, int]]):
def _weighted_circular_mean(
values: Iterable[tuple[float, float]],
) -> tuple[float, float]:
sin_sum = 0
cos_sum = 0
for x, dur in values:
sin_sum += math.sin(x * DEG_TO_RAD) * dur
cos_sum += math.cos(x * DEG_TO_RAD) * dur
for x, weight in values:
sin_sum += math.sin(x * DEG_TO_RAD) * weight
cos_sum += math.cos(x * DEG_TO_RAD) * weight

return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360

def _circular_mean(values: list[float]) -> float:
sin_sum = 0
cos_sum = 0
for x in values:
sin_sum += math.sin(x * DEG_TO_RAD)
cos_sum += math.cos(x * DEG_TO_RAD)

return (RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360
return (
(RAD_TO_DEG * math.atan2(sin_sum, cos_sum)) % 360,
math.sqrt(sin_sum**2 + cos_sum**2),
)

def _min(seq, last_state):
if last_state is None:
Expand Down Expand Up @@ -4631,7 +4627,7 @@ def _sum(seq, last_state, last_sum):
values = [(seq, durations[j]) for j, seq in enumerate(seq)]
if (state := last_states["sensor.test5"]) is not None:
values.append((state, 5))
expected_means["sensor.test5"].append(_time_weighted_circular_mean(values))
expected_means["sensor.test5"].append(_weighted_circular_mean(values))
last_states["sensor.test5"] = seq[-1]

start += timedelta(minutes=5)
Expand Down Expand Up @@ -4733,15 +4729,17 @@ def _sum(seq, last_state, last_sum):
start = zero
end = zero + timedelta(minutes=5)
for i in range(24):
for entity_id in (
"sensor.test1",
"sensor.test2",
"sensor.test3",
"sensor.test4",
"sensor.test5",
for entity_id, mean_extractor in (
("sensor.test1", lambda x: x),
("sensor.test2", lambda x: x),
("sensor.test3", lambda x: x),
("sensor.test4", lambda x: x),
("sensor.test5", lambda x: x[0]),
):
expected_average = (
expected_means[entity_id][i] if entity_id in expected_means else None
mean_extractor(expected_means[entity_id][i])
if entity_id in expected_means
else None
)
expected_minimum = (
expected_minima[entity_id][i] if entity_id in expected_minima else None
Expand Down Expand Up @@ -4772,7 +4770,7 @@ def _sum(seq, last_state, last_sum):
assert stats == expected_stats

def verify_stats(
period: Literal["5minute", "day", "hour", "week", "month"],
period: Literal["hour", "day", "week", "month"],
start: datetime,
next_datetime: Callable[[datetime], datetime],
) -> None:
Expand All @@ -4791,7 +4789,7 @@ def verify_stats(
("sensor.test2", mean),
("sensor.test3", mean),
("sensor.test4", mean),
("sensor.test5", _circular_mean),
("sensor.test5", lambda x: _weighted_circular_mean(x)[0]),
):
expected_average = (
mean_fn(expected_means[entity_id][i * 12 : (i + 1) * 12])
Expand Down
Loading