# scallops/features/preprocessing.py
#
# NOTE(review): the same change set also touches scallops/features/normalize.py,
# where the local ``x_data`` inside ``normalize_features`` is renamed to
# ``xdata`` (three references); that function is not reproduced here.
from collections.abc import Sequence

import anndata
import dask
import dask.array as da
import numpy as np
from array_api_compat import get_namespace
from sklearn.preprocessing import PowerTransformer

from scallops.features.util import _anndata_to_xr, _slice_anndata


def transform_features_yj(
    adata: anndata.AnnData, by: str | Sequence | None = None
) -> anndata.AnnData:
    """Transform features using yeo-johnson transform

    A fresh ``PowerTransformer(method="yeo-johnson")`` is fitted and applied
    to the data. When `by` is given, the fit/transform is done separately for
    every group.

    :param adata: AnnData object
    :param by: Column(s) in `adata.obs` to stratify by. If `None`, all
        observations are transformed together.
    :return: Transformed AnnData object. With `by`, `obs` rows follow the
        order of the grouped result's ``obs`` coordinate; otherwise `obs` is a
        copy of `adata.obs`. `var` is always a copy of `adata.var`.
    """

    def _transform_block(x):
        # A new transformer per call, so nothing is shared between groups.
        return PowerTransformer(method="yeo-johnson").fit_transform(x)

    def _transform_feature_group(x):
        d = x.data
        if isinstance(d, da.Array):
            # The transformer is fitted per block, so every block has to hold
            # all rows of the group: collapse axis 0 into a single chunk.
            chunks = list(d.chunksize)
            if chunks[0] != d.shape[0]:
                chunks[0] = -1
                d = d.rechunk(tuple(chunks))
            d = da.map_blocks(
                _transform_block, d, meta=np.array((), dtype=np.float64)
            )
        else:
            d = _transform_block(d)
        # Shallow copy: reuse coordinates, swap in the transformed values.
        return x.copy(data=d, deep=False)

    xdata = _anndata_to_xr(adata, by)
    if by is not None:
        result = xdata.groupby(by).map(_transform_feature_group)
        return anndata.AnnData(
            X=result.data,
            obs=adata.obs.loc[result.coords["obs"].values],
            var=adata.var.copy(),
        )

    return anndata.AnnData(
        X=_transform_feature_group(xdata).data,
        obs=adata.obs.copy(),
        var=adata.var.copy(),
    )


def filter_data(
    adata: anndata.AnnData,
    max_fraction_nans: float | None = 0.25,
    min_variance: float | None = 0.1,
) -> anndata.AnnData:
    """Filter cells using `max_fraction_nans` then filter features using `min_variance`

    The variance of each feature is computed on the cells that survive the
    NaN filter. A feature whose variance is NaN does not satisfy the
    ``>= min_variance`` comparison and is therefore removed.

    :param adata: AnnData object
    :param max_fraction_nans: Keep cells with <= `max_fraction_nans` missing
        values. `None` disables the cell filter.
    :param min_variance: Keep features with variance >= `min_variance`.
        `None` disables the feature filter.
    :return: Filtered AnnData object
    """
    xp = get_namespace(adata.X)
    keep_cells = None
    keep_features = None
    if max_fraction_nans is not None:
        nan_counts_per_cell = xp.isnan(adata.X).sum(axis=1)
        # Compare the fraction itself rather than int(n_features * fraction):
        # the product is subject to float rounding and int() truncates, e.g.
        # int(100 * 0.29) == 28, which would drop cells with exactly 29% NaNs.
        # max(..., 1) keeps the zero-feature case well defined (0 / 1 == 0).
        n_features = max(adata.shape[1], 1)
        keep_cells = nan_counts_per_cell / n_features <= max_fraction_nans
    if min_variance is not None:
        values = adata.X if keep_cells is None else adata.X[keep_cells]
        keep_features = xp.var(values, axis=0) >= min_variance
    if isinstance(adata.X, da.Array):
        # Evaluate both lazy masks in a single pass over the graph.
        keep_features, keep_cells = dask.compute(keep_features, keep_cells)
    return _slice_anndata(adata, keep_cells, keep_features)
# scallops/tests/test_features_preprocessing.py
import anndata
import dask.array as da
import numpy as np
import pandas as pd
import pytest
from sklearn.preprocessing import PowerTransformer

from scallops.features.preprocessing import filter_data, transform_features_yj


def _make_adata(use_dask):
    """Build the 4 cells x 2 genes fixture shared by the tests below."""
    if use_dask:
        values = da.arange(8, chunks=(1,)).reshape((4, 2))
    else:
        values = np.arange(8).reshape((4, 2))
    obs = pd.DataFrame(
        data=dict(
            pert=["pert1", "pert2", "pert1", "pert2"],
            well=["well1", "well2", "well1", "well2"],
        )
    )
    return anndata.AnnData(
        X=values, obs=obs, var=pd.DataFrame(index=["gene1", "gene2"])
    )


def _yeo_johnson(values):
    """Reference yeo-johnson transform of a single 1-D column."""
    transformer = PowerTransformer(method="yeo-johnson")
    return transformer.fit_transform(values.reshape(-1, 1)).squeeze()


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.features
def test_filter_data(use_dask):
    """Check the cell filter, the feature filter, and both combined."""
    adata = _make_adata(use_dask)
    adata.X = adata.X.astype(np.float32)
    adata.X[1, 0] = 100
    adata.X[0, 0] = np.nan
    # Per-feature variance over all four cells: [nan, 5.]

    # Cell filter only: the single cell holding a NaN is removed.
    # Per-feature variance of the remaining cells: [2006.2222, 2.6666667]
    cells_only = filter_data(adata, max_fraction_nans=0, min_variance=None)
    assert cells_only.shape == (3, 2)

    # Feature filter only.
    features_only = filter_data(adata, max_fraction_nans=None, min_variance=0)
    assert features_only.shape == (4, 1)

    # Both filters.
    both = filter_data(adata, max_fraction_nans=0, min_variance=5)
    assert both.shape == (3, 1)


@pytest.mark.parametrize("by", [None, ["pert", "well"], ["well"]])
@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.features
def test_transform_features_yj(by, use_dask):
    """Compare `transform_features_yj` with a pandas/sklearn reference."""
    genes = ("gene1", "gene2")
    sort_keys = ["pert", "well"]
    adata = _make_adata(use_dask)

    # Reference frame, built from an in-memory copy of the fixture.
    reference = adata.copy()
    if isinstance(reference.X, da.Array):
        reference.X = reference.X.compute()
    df = reference.to_df().join(reference.obs)

    if by is None:
        for gene in genes:
            df[gene] = _yeo_johnson(df[gene].values)
    else:

        def single_group(x):
            x = x.copy()
            for gene in genes:
                x[gene] = _yeo_johnson(x[gene].values)
            return x

        df = df.groupby(by).apply(single_group, include_groups=False).reset_index()

    df = df.reset_index(drop=True)
    # The grouped path leaves the original row index behind as "level_*".
    columns_drop = df.columns[df.columns.str.startswith("level_")]
    if len(columns_drop) > 0:
        df = df.drop(columns_drop, axis=1)

    adata_transformed = transform_features_yj(adata, by=by)
    if isinstance(adata_transformed.X, da.Array):
        adata_transformed.X = adata_transformed.X.compute()

    df_test = adata_transformed.to_df().join(adata_transformed.obs)
    df_test = df_test.sort_values(sort_keys).reset_index(drop=True)
    df = df.sort_values(sort_keys).reset_index(drop=True)
    pd.testing.assert_frame_equal(df_test[df.columns], df)