Skip to content

Commit 9da2c8f

Browse files
authored
BUG: Require sample weights to sum to less than 1 when replace = True (#61582)
1 parent b876c67 commit 9da2c8f

File tree

6 files changed

+47
-7
lines changed

6 files changed

+47
-7
lines changed

doc/source/user_guide/indexing.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ to have different probabilities, you can pass the ``sample`` function sampling w
700700
701701
s = pd.Series([0, 1, 2, 3, 4, 5])
702702
example_weights = [0, 0, 0.2, 0.2, 0.2, 0.4]
703-
s.sample(n=3, weights=example_weights)
703+
s.sample(n=2, weights=example_weights)
704704
705705
# Weights will be re-normalized automatically
706706
example_weights2 = [0.5, 0, 0, 0, 0, 0]
@@ -714,7 +714,7 @@ as a string.
714714
715715
df2 = pd.DataFrame({'col1': [9, 8, 7, 6],
716716
'weight_column': [0.5, 0.4, 0.1, 0]})
717-
df2.sample(n=3, weights='weight_column')
717+
df2.sample(n=2, weights='weight_column')
718718
719719
``sample`` also allows users to sample columns instead of rows using the ``axis`` argument.
720720

doc/source/whatsnew/v0.16.1.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ facilitate replication. (:issue:`2419`)
196196
197197
# weights are accepted.
198198
example_weights = [0, 0, 0.2, 0.2, 0.2, 0.4]
199-
example_series.sample(n=3, weights=example_weights)
199+
example_series.sample(n=2, weights=example_weights)
200200
201201
# weights will also be normalized if they do not sum to one,
202202
# and missing values will be treated as zeros.
@@ -210,7 +210,7 @@ when sampling from rows.
210210
.. ipython:: python
211211
212212
df = pd.DataFrame({"col1": [9, 8, 7, 6], "weight_column": [0.5, 0.4, 0.1, 0]})
213-
df.sample(n=3, weights="weight_column")
213+
df.sample(n=2, weights="weight_column")
214214
215215
216216
.. _whatsnew_0161.enhancements.string:

doc/source/whatsnew/v3.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,7 @@ Other
911911
- Bug in :meth:`DataFrame.query` where using duplicate column names led to a ``TypeError``. (:issue:`59950`)
912912
- Bug in :meth:`DataFrame.query` which raised an exception or produced incorrect results when expressions contained backtick-quoted column names containing the hash character ``#``, backticks, or characters that fall outside the ASCII range (U+0001..U+007F). (:issue:`59285`) (:issue:`49633`)
913913
- Bug in :meth:`DataFrame.query` which raised an exception when querying integer column names using backticks. (:issue:`60494`)
914+
- Bug in :meth:`DataFrame.sample` with ``replace=False`` and ``(n * max(weights) / sum(weights)) > 1``, the method would return biased results. Now raises ``ValueError``. (:issue:`61516`)
914915
- Bug in :meth:`DataFrame.shift` where passing a ``freq`` on a DataFrame with no columns did not shift the index correctly. (:issue:`60102`)
915916
- Bug in :meth:`DataFrame.sort_index` when passing ``axis="columns"`` and ``ignore_index=True`` and ``ascending=False`` not returning a :class:`RangeIndex` columns (:issue:`57293`)
916917
- Bug in :meth:`DataFrame.sort_values` where sorting by a column explicitly named ``None`` raised a ``KeyError`` instead of sorting by the column as expected. (:issue:`61512`)

pandas/core/generic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5814,6 +5814,8 @@ def sample(
58145814
If weights do not sum to 1, they will be normalized to sum to 1.
58155815
Missing values in the weights column will be treated as zero.
58165816
Infinite values not allowed.
5817+
When replace = False will not allow ``(n * max(weights) / sum(weights)) > 1``
5818+
in order to avoid biased results. See the Notes below for more details.
58175819
random_state : int, array-like, BitGenerator, np.random.RandomState, np.random.Generator, optional
58185820
If int, array-like, or BitGenerator, seed for random number generator.
58195821
If np.random.RandomState or np.random.Generator, use as given.
@@ -5850,6 +5852,11 @@ def sample(
58505852
-----
58515853
If `frac` > 1, `replacement` should be set to `True`.
58525854
5855+
When replace = False will not allow ``(n * max(weights) / sum(weights)) > 1``,
5856+
since that would cause results to be biased. E.g. sampling 2 items without replacement
5857+
with weights [100, 1, 1] would yield two last items in 1/2 of cases, instead of 1/102.
5858+
This is similar to specifying `n=4` without replacement on a Series with 3 elements.
5859+
58535860
Examples
58545861
--------
58555862
>>> df = pd.DataFrame(

pandas/core/sample.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ def sample(
150150
else:
151151
raise ValueError("Invalid weights: weights sum to zero")
152152

153+
assert weights is not None # for mypy
154+
if not replace and size * weights.max() > 1:
155+
raise ValueError(
156+
"Weighted sampling cannot be achieved with replace=False. Either "
157+
"set replace=True or use smaller weights. See the docstring of "
158+
"sample for details."
159+
)
160+
153161
return random_state.choice(obj_len, size=size, replace=replace, p=weights).astype(
154162
np.intp, copy=False
155163
)

pandas/tests/frame/methods/test_sample.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ def test_sample_invalid_weight_lengths(self, obj):
113113
with pytest.raises(ValueError, match=msg):
114114
obj.sample(n=3, weights=[0.5] * 11)
115115

116-
with pytest.raises(ValueError, match="Fewer non-zero entries in p than size"):
117-
obj.sample(n=4, weights=Series([0, 0, 0.2]))
118-
119116
def test_sample_negative_weights(self, obj):
120117
# Check won't accept negative weights
121118
bad_weights = [-0.1] * 10
@@ -137,6 +134,33 @@ def test_sample_inf_weights(self, obj):
137134
with pytest.raises(ValueError, match=msg):
138135
obj.sample(n=3, weights=weights_with_ninf)
139136

137+
def test_sample_unit_probabilities_raises(self, obj):
138+
# GH#61516
139+
high_variance_weights = [1] * 10
140+
high_variance_weights[0] = 100
141+
msg = (
142+
"Weighted sampling cannot be achieved with replace=False. Either "
143+
"set replace=True or use smaller weights. See the docstring of "
144+
"sample for details."
145+
)
146+
with pytest.raises(ValueError, match=msg):
147+
obj.sample(n=2, weights=high_variance_weights, replace=False)
148+
149+
def test_sample_unit_probabilities_edge_case_do_not_raise(self, obj):
150+
# GH#61516
151+
# edge case, n*max(weights)/sum(weights) == 1
152+
edge_variance_weights = [1] * 10
153+
edge_variance_weights[0] = 9
154+
# should not raise
155+
obj.sample(n=2, weights=edge_variance_weights, replace=False)
156+
157+
def test_sample_unit_normal_probabilities_do_not_raise(self, obj):
158+
# GH#61516
159+
low_variance_weights = [1] * 10
160+
low_variance_weights[0] = 8
161+
# should not raise
162+
obj.sample(n=2, weights=low_variance_weights, replace=False)
163+
140164
def test_sample_zero_weights(self, obj):
141165
# All zeros raises errors
142166

0 commit comments

Comments
 (0)