Skip to content

Commit 51a8ffa

Browse files
committed
Use sgkit.distarray for gwas_linear_regression
1 parent d522c0f commit 51a8ffa

File tree

5 files changed

+17
-11
lines changed

5 files changed

+17
-11
lines changed

.github/workflows/cubed.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,4 @@ jobs:
3030
3131
- name: Test with pytest
3232
run: |
33-
pytest -v sgkit/tests/test_{aggregation,hwe}.py -k 'test_count_call_alleles or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed
33+
pytest -v sgkit/tests/test_{aggregation,association,hwe}.py -k 'test_count_call_alleles or test_gwas_linear_regression or test_hwep or test_sample_stats or (test_count_variant_alleles and not test_count_variant_alleles__chunked[call_genotype]) or (test_variant_stats and not test_variant_stats__chunks[chunks2-False])' --use-cubed

sgkit/distarray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ def astype(x, dtype, /, *, copy=True): # pragma: no cover
1414
if not copy and dtype == x.dtype:
1515
return x
1616
return x.astype(dtype=dtype, copy=copy)
17+
18+
# dask.array has no `concat` function, which the array API standard requires, so alias it to `concatenate`
19+
concat = concatenate # noqa: F405

sgkit/stats/association.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from dataclasses import dataclass
22
from typing import Hashable, Optional, Sequence, Union
33

4-
import dask.array as da
54
import numpy as np
6-
from dask.array import Array, stats
5+
from scipy import stats
76
from xarray import Dataset, concat
87

8+
import sgkit.distarray as da
9+
from sgkit.distarray import Array
10+
911
from .. import variables
1012
from ..typing import ArrayLike
1113
from ..utils import conditional_merge_datasets, create_dataset
@@ -78,18 +80,18 @@ def linear_regression(
7880
# from projection require no extra terms in variance
7981
# estimate for loop covariates (columns of G), which is
8082
# only true when an intercept is present.
81-
XLPS = (XLP**2).sum(axis=0, keepdims=True).T
83+
XLPS = da.sum(XLP**2, axis=0, keepdims=True).T
8284
assert XLPS.shape == (n_loop_covar, 1)
8385
B = (XLP.T @ YP) / XLPS
8486
assert B.shape == (n_loop_covar, n_outcome)
8587

8688
# Compute residuals for each loop covariate and outcome separately
8789
YR = YP[:, np.newaxis, :] - XLP[..., np.newaxis] * B[np.newaxis, ...]
8890
assert YR.shape == (n_obs, n_loop_covar, n_outcome)
89-
RSS = (YR**2).sum(axis=0)
91+
RSS = da.sum(YR**2, axis=0)
9092
assert RSS.shape == (n_loop_covar, n_outcome)
9193
# Get t-statistics for coefficient estimates
92-
T = B / np.sqrt(RSS / dof / XLPS)
94+
T = B / da.sqrt(RSS / dof / XLPS)
9395
assert T.shape == (n_loop_covar, n_outcome)
9496

9597
# Match to p-values
@@ -102,7 +104,8 @@ def linear_regression(
102104
dtype="float64",
103105
)
104106
assert P.shape == (n_loop_covar, n_outcome)
105-
P = np.asarray(P, like=T)
107+
if hasattr(T, "__array_function__"):
108+
P = np.asarray(P, like=T)
106109
return LinearRegressionResult(beta=B, t_value=T, p_value=P)
107110

108111

@@ -216,7 +219,7 @@ def gwas_linear_regression(
216219
else:
217220
X = da.asarray(concat_2d(ds[list(covariates)], dims=("samples", "covariates")))
218221
if add_intercept:
219-
X = da.concatenate([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
222+
X = da.concat([da.ones((X.shape[0], 1), dtype=X.dtype), X], axis=1)
220223
# Note: dask qr decomp (used by lstsq) requires no chunking in one
221224
# dimension, and because dim 0 will be far greater than the number
222225
# of covariates for the large majority of use cases, chunking

sgkit/stats/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def assert_array_shape(x: ArrayLike, *args: int) -> None:
104104

105105

106106
def map_blocks_asnumpy(x: Array) -> Array:
107-
if da.utils.is_cupy_type(x._meta): # pragma: no cover
107+
if hasattr(x, "_meta") and da.utils.is_cupy_type(x._meta): # pragma: no cover
108108
import cupy as cp # type: ignore[import]
109109

110110
x = x.map_blocks(cp.asnumpy)

sgkit/tests/test_association.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from pathlib import Path
33
from typing import Any, Dict, List, Optional, Sequence, Tuple
44

5-
import dask.array as da
65
import numpy as np
76
import pandas as pd
87
import pytest
@@ -11,6 +10,7 @@
1110
from pandas import DataFrame
1211
from xarray import Dataset
1312

13+
import sgkit.distarray as da
1414
from sgkit.stats.association import (
1515
gwas_linear_regression,
1616
linear_regression,
@@ -263,7 +263,7 @@ def test_gwas_linear_regression__scalar_vars(ds: xr.Dataset) -> None:
263263
res_list = gwas_linear_regression(
264264
ds, dosage="dosage", covariates=["covar_0"], traits=["trait_0"]
265265
)
266-
xr.testing.assert_equal(res_scalar, res_list)
266+
xr.testing.assert_allclose(res_scalar, res_list)
267267

268268

269269
def test_gwas_linear_regression__raise_on_no_intercept_and_empty_covariates():

0 commit comments

Comments
 (0)