Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions scallops/features/normalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,9 @@ def normalize_features(
"""

mad_scale = _convert_scale(mad_scale)
x_data = _anndata_to_xr(data)
xdata = _anndata_to_xr(data)
if normalize_groups is not None:
group_result = x_data.groupby(normalize_groups).map(
group_result = xdata.groupby(normalize_groups).map(
lambda x: _normalize_group(
x,
reference_query=reference_query,
Expand All @@ -238,7 +238,7 @@ def normalize_features(
)

result = _normalize_group(
x_data,
xdata,
reference_query=reference_query,
normalize=normalize,
n_neighbors=n_neighbors,
Expand Down
82 changes: 82 additions & 0 deletions scallops/features/preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from collections.abc import Sequence

import anndata
import dask
import dask.array as da
import numpy as np
from array_api_compat import get_namespace
from sklearn.preprocessing import PowerTransformer

from scallops.features.util import _anndata_to_xr, _slice_anndata


def transform_features_yj(
    adata: anndata.AnnData, by: str | Sequence | None = None
) -> anndata.AnnData:
    """Apply a Yeo-Johnson power transform to every feature.

    :param adata: AnnData object
    :param by: Column(s) in `adata.obs` to stratify by.
    :return: Transformed AnnData object
    """

    def _fit_transform(values):
        # sklearn fits one lambda per column, so this is a per-feature transform.
        return PowerTransformer(method="yeo-johnson").fit_transform(values)

    def _apply(group):
        arr = group.data
        if isinstance(arr, da.Array):
            # map_blocks fits the transformer per block, so every block must
            # span all observations: collapse axis 0 to a single chunk.
            row_chunk, *col_chunks = arr.chunksize
            if row_chunk != arr.shape[0]:
                arr = arr.rechunk((-1, *col_chunks))
            arr = da.map_blocks(
                _fit_transform, arr, meta=np.array((), dtype=np.float64)
            )
        else:
            arr = _fit_transform(arr)
        return group.copy(data=arr, deep=False)

    xdata = _anndata_to_xr(adata, by)
    if by is None:
        # No stratification: transform the whole array, row order unchanged.
        return anndata.AnnData(
            X=_apply(xdata).data,
            obs=adata.obs.copy(),
            var=adata.var.copy(),
        )

    # Grouped transform; groupby may reorder rows, so realign obs via the
    # "obs" coordinate of the result.
    grouped = xdata.groupby(by).map(_apply)
    return anndata.AnnData(
        X=grouped.data,
        obs=adata.obs.loc[grouped.coords["obs"].values],
        var=adata.var.copy(),
    )


def filter_data(
    adata: anndata.AnnData,
    max_fraction_nans: float | None = 0.25,
    min_variance: float | None = 0.1,
) -> anndata.AnnData:
    """Filter cells by NaN fraction, then filter features by variance.

    :param adata: AnnData object
    :param max_fraction_nans: Keep cells with <= `max_fraction_nans` missing values
    :param min_variance: Keep features with variance >= `min_variance`
    :return: Filtered AnnData object
    """
    xp = get_namespace(adata.X)

    keep_cells = None
    if max_fraction_nans is not None:
        nan_limit = int(adata.shape[1] * max_fraction_nans)
        keep_cells = xp.isnan(adata.X).sum(axis=1) <= nan_limit

    keep_features = None
    if min_variance is not None:
        # Variance is computed on the surviving cells only, so NaN-heavy
        # cells removed above do not poison the feature filter.
        matrix = adata.X if keep_cells is None else adata.X[keep_cells]
        keep_features = xp.var(matrix, axis=0) >= min_variance

    if isinstance(adata.X, da.Array):
        # Materialize both lazy masks in a single dask pass.
        keep_features, keep_cells = dask.compute(keep_features, keep_cells)
    return _slice_anndata(adata, keep_cells, keep_features)
104 changes: 104 additions & 0 deletions scallops/tests/test_features_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import anndata
import dask.array as da
import numpy as np
import pandas as pd
import pytest
from sklearn.preprocessing import PowerTransformer

from scallops.features.preprocessing import filter_data, transform_features_yj


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.features
def test_filter_data(use_dask):
adata = anndata.AnnData(
X=da.arange(8, chunks=(1,)).reshape((4, 2))
if use_dask
else np.arange(8).reshape((4, 2)),
obs=pd.DataFrame(
data=dict(
pert=["pert1", "pert2", "pert1", "pert2"],
well=["well1", "well2", "well1", "well2"],
)
),
var=pd.DataFrame(index=["gene1", "gene2"]),
)
adata.X = adata.X.astype(np.float32)
adata.X[1, 0] = 100
adata.X[0, 0] = np.nan
# np.var(adata.X, axis=0) array([nan, 5.], dtype=float32)
d = filter_data(adata, max_fraction_nans=0, min_variance=None)
# d.X.var(axis=0) # array([2006.2222 , 2.6666667]
assert d.shape == (3, 2)
assert filter_data(adata, max_fraction_nans=None, min_variance=0).shape == (4, 1)
assert filter_data(adata, max_fraction_nans=0, min_variance=5).shape == (3, 1)


@pytest.mark.parametrize("by", [None, ["pert", "well"], ["well"]])
@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.features
def test_transform_features_yj(by, use_dask):
adata = anndata.AnnData(
X=da.arange(8, chunks=(1,)).reshape((4, 2))
if use_dask
else np.arange(8).reshape((4, 2)),
obs=pd.DataFrame(
data=dict(
pert=["pert1", "pert2", "pert1", "pert2"],
well=["well1", "well2", "well1", "well2"],
)
),
var=pd.DataFrame(index=["gene1", "gene2"]),
)
adata2 = adata.copy()
if isinstance(adata2.X, da.Array):
adata2.X = adata2.X.compute()
df = adata2.to_df().join(adata2.obs)

if by is not None:
grouped = df.groupby(by)

def single_group(x):
x = x.copy()
x["gene1"] = (
PowerTransformer(method="yeo-johnson")
.fit_transform(x["gene1"].values.reshape(-1, 1))
.squeeze()
)
x["gene2"] = (
PowerTransformer(method="yeo-johnson")
.fit_transform(x["gene2"].values.reshape(-1, 1))
.squeeze()
)
return x

df = grouped.apply(single_group, include_groups=False).reset_index()

else:
df["gene1"] = (
PowerTransformer(method="yeo-johnson")
.fit_transform(df["gene1"].values.reshape(-1, 1))
.squeeze()
)
df["gene2"] = (
PowerTransformer(method="yeo-johnson")
.fit_transform(df["gene2"].values.reshape(-1, 1))
.squeeze()
)
df = df.reset_index(drop=True)
columns_drop = df.columns[df.columns.str.startswith("level_")]
if len(columns_drop) > 0:
df = df.drop(columns_drop, axis=1)

adata_transformed = transform_features_yj(adata, by=by)

if isinstance(adata_transformed.X, da.Array):
adata_transformed.X = adata_transformed.X.compute()
df_test = (
adata_transformed.to_df()
.join(adata_transformed.obs)
.sort_values(["pert", "well"])
)
df_test = df_test.sort_values(["pert", "well"]).reset_index(drop=True)
df = df.sort_values(["pert", "well"]).reset_index(drop=True)
pd.testing.assert_frame_equal(df_test[df.columns], df)