From 150888455bebe660a18c152a196d97af6fe2f243 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 26 Oct 2025 18:57:49 -0500 Subject: [PATCH 01/33] initial stab at a general matrix (no normalisation) --- c/tskit/trees.c | 66 ++++++++------ c/tskit/trees.h | 7 ++ python/_tskitmodule.c | 202 ++++++++++++++++++++++++++++++++++++++++++ python/tskit/trees.py | 65 +++++++++++--- 4 files changed, 299 insertions(+), 41 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1aa06e5b03..f7805857ca 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2411,8 +2411,8 @@ static int compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim, - tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, - norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) + tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, + bool polarised, two_locus_work_t *restrict work, double *result) { int ret = 0; // Sample sets and b sites are rows, a sites are columns @@ -2463,9 +2463,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, static int compute_general_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, - tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; tsk_size_t k; @@ -2653,9 +2652,8 @@ static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, - const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites, + tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { int ret = 0; tsk_bitset_t allele_samples, allele_sample_sets; @@ -3089,9 +3087,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index) static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, - tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; double a_len, b_len; @@ -3141,8 +3138,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, static int compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, - iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, - tsk_size_t result_dim, tsk_size_t state_dim, double *result) + iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim, + tsk_size_t state_dim, double *result) { int ret = 0; tsk_id_t e, c, ec, p, *updated_nodes = NULL; @@ -3243,9 +3240,9 @@ static int tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), - tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, - const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) + void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows, + const double *row_positions, tsk_size_t n_cols, const double *col_positions, + tsk_flags_t TSK_UNUSED(options), double *result) { int ret = 0; int r, c; @@ -3385,10 +3382,10 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s } int -tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, +tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { @@ -3398,10 +3395,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); tsk_size_t state_dim = num_sample_sets; - sample_count_stat_params_t f_params = { .sample_sets = sample_sets, - .num_sample_sets = num_sample_sets, - .sample_set_sizes = sample_set_sizes, - .set_indexes = set_indexes }; // We do not support two-locus node stats if (!!(options & TSK_STAT_NODE)) { @@ -3441,7 +3434,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); } else if (stat_branch) { ret = check_positions( @@ -3455,13 +3448,30 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_positions, out_cols, col_positions, options, result); } out: return ret; } +int +tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) +{ + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; + return tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + row_sites, row_positions, out_cols, col_sites, col_positions, options, result); +} + /*********************************** * Allele frequency spectrum ***********************************/ @@ -8697,8 +8707,8 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, for (k = offsets[b]; k < offsets[b + 1]; k++) { u = A[j]; v = A[k]; - /* Only increment the upper triangle to (hopefully) improve memory - * access patterns */ + /* Only increment the upper triangle to (hopefully) improve + * memory access patterns */ if (u > v) { u = A[k]; v = A[j]; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 84480ed96e..acc15c9aac 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1120,6 +1120,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 0e0c1c5ed5..afde032847 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7946,6 +7946,203 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) return array; } +typedef struct { + PyArrayObject *sample_set_sizes; + PyObject *callable; +} two_locus_general_stat_params; + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + two_locus_general_stat_params *tl_params = params; + PyObject *callable = tl_params->callable; + PyArrayObject *sample_set_sizes = tl_params->sample_set_sizes; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + npy_intp X_dims[2] = { K, 3 }; + // Convert "n" to a column array + PyArray_Dims n_dims = { (npy_intp[2]){ PyArray_DIMS(sample_set_sizes)[0], 1 }, 2 }; + npy_intp *Y_dims; + + // Create a read only view of X as a numpy array + X_array = (PyArrayObject *) PyArray_SimpleNewFromData( + 2, X_dims, NPY_FLOAT64, (void *) X); + if (X_array == NULL) { + goto out; + } + sample_set_sizes + = (PyArrayObject *) PyArray_Newshape(sample_set_sizes, &n_dims, NPY_CORDER); + + PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + arglist = Py_BuildValue("OO", X_array, sample_set_sizes); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(callable, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by general_stat callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + Y_dims = PyArray_DIMS(Y_array); + if (Y_dims[0] != (npy_intp) M) { + PyErr_Format(PyExc_ValueError, + "Array returned by general_stat callback is of length %d; " + "must be %d", + Y_dims[0], M); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), M * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + return ret; +} + +static PyObject * +TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", + "output_dim", "polarised", "row_sites", "col_sites", "row_positions", + "column_positions", "mode", NULL }; + two_locus_general_stat_params *params; + PyObject *summary_func = NULL; + unsigned int output_dim; + PyObject *sample_set_sizes = NULL; + PyObject *sample_sets = NULL; + PyObject *row_sites = NULL; + PyObject *col_sites = NULL; + PyObject *row_positions = NULL; + PyObject *col_positions = NULL; + char *mode = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *row_sites_array = NULL; + PyArrayObject *col_sites_array = NULL; + PyArrayObject *row_positions_array = NULL; + PyArrayObject *col_positions_array = NULL; + PyArrayObject *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL; + tsk_id_t *col_sites_parsed = NULL; + double *row_positions_parsed = NULL; + double *col_positions_parsed = NULL; + npy_intp result_dim[3] = { 0, 0, 0 }; + tsk_size_t num_sample_sets; + tsk_flags_t options = 0; + int polarised = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, + &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + Py_XINCREF(summary_func); + goto out; + } + Py_INCREF(summary_func); + if (!PyCallable_Check(summary_func)) { + PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + PyArray_CLEARFLAGS(sample_set_sizes_array, NPY_ARRAY_WRITEABLE); + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); + } + + result_dim[2] = num_sample_sets; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + + params = &(two_locus_general_stat_params){ + .sample_set_sizes = sample_set_sizes_array, + .callable = summary_func, + }; + // TODO: deal with null norm func, need general stat. + err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + output_dim, general_two_locus_count_stat_func, params, NULL, result_dim[0], + row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, + col_positions_parsed, options, PyArray_DATA(result_matrix)); + + if (err == TSK_PYTHON_CALLBACK_ERROR) { + goto out; + } else if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(summary_func); + Py_XDECREF(row_sites_array); + Py_XDECREF(col_sites_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_matrix); + return ret; +} + static PyObject * TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_count_stat_method *method) @@ -8831,6 +9028,11 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_general_stat, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Runs the general stats algorithm for a given summary function." }, + { .ml_name = "two_locus_count_stat", + .ml_meth = (PyCFunction) TreeSequence_two_locus_count_stat, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc + = "Runs the general two locus stats algorithm for a given summary function." }, { .ml_name = "diversity", .ml_meth = (PyCFunction) TreeSequence_diversity, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 45d2da59e0..a09cf8e741 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8249,19 +8249,7 @@ def parse_positions(self, positions): ) return row_positions, col_positions - def __two_locus_sample_set_stat( - self, - ll_method, - sample_sets, - sites=None, - positions=None, - mode=None, - ): - if sample_sets is None: - sample_sets = self.samples() - row_sites, col_sites = self.parse_sites(sites) - row_positions, col_positions = self.parse_positions(positions) - + def __convert_sample_sets(self, sample_sets): # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. drop_dimension = False @@ -8283,7 +8271,23 @@ def __two_locus_sample_set_stat( raise ValueError("Sample sets must contain at least one element") flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + return drop_dimension, flattened, sample_set_sizes + def __two_locus_sample_set_stat( + self, + ll_method, + sample_sets, + sites=None, + positions=None, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) result = ll_method( sample_set_sizes, flattened, @@ -10927,6 +10931,41 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time + def two_locus_count_stat( + self, + sample_sets, + f, + result_dim, + polarised=False, + sites=None, + positions=None, + mode="site", + ): + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) + result = self._ll_tree_sequence.two_locus_count_stat( + sample_set_sizes, + sample_sets, + f, + result_dim, + polarised, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[:2]) + else: + # Orient the data so that the first dimension is the sample set. + # With this orientation, we get one LD matrix per sample set. + result = result.swapaxes(0, 2).swapaxes(1, 2) + return result + def ld_matrix( self, sample_sets=None, From b144f8d55ba3bd8e5f13ee34a4020aa8197efeca Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 26 Oct 2025 18:59:39 -0500 Subject: [PATCH 02/33] added dimension dropping, but I think transposing is better -- we don't have to add a dimension at the end for scalar operations --- python/_tskitmodule.c | 37 +++++++++++++++++++++++-------------- python/tskit/trees.py | 4 +++- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index afde032847..ed50577ee8 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7947,6 +7947,7 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) } typedef struct { + bool drop_dimensions; PyArrayObject *sample_set_sizes; PyObject *callable; } two_locus_general_stat_params; @@ -7956,29 +7957,33 @@ general_two_locus_count_stat_func( tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; - two_locus_general_stat_params *tl_params = params; - PyObject *callable = tl_params->callable; - PyArrayObject *sample_set_sizes = tl_params->sample_set_sizes; PyObject *arglist = NULL; PyObject *result = NULL; PyArrayObject *X_array = NULL; PyArrayObject *Y_array = NULL; - npy_intp X_dims[2] = { K, 3 }; - // Convert "n" to a column array - PyArray_Dims n_dims = { (npy_intp[2]){ PyArray_DIMS(sample_set_sizes)[0], 1 }, 2 }; + two_locus_general_stat_params *tl_params = params; + PyObject *callable = tl_params->callable; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + bool drop = (K == 1 && tl_params->drop_dimensions); + // Convert "n" to a column array -- reshape(-1, K) or a scalar if K=1 and drop=True + PyArray_Dims ss_sizes_dims = (drop ? (PyArray_Dims){ (npy_intp[1]){ 1 }, 0 } + : (PyArray_Dims){ (npy_intp[2]){ K, 1 }, 2 }); + int X_ndims = drop ? 1 : 2; + npy_intp *X_dims = drop ? (npy_intp[1]){ 3 } : (npy_intp[2]){ K, 3 }; npy_intp *Y_dims; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); + X_ndims, X_dims, NPY_FLOAT64, (void *) X); if (X_array == NULL) { goto out; } - sample_set_sizes - = (PyArrayObject *) PyArray_Newshape(sample_set_sizes, &n_dims, NPY_CORDER); - + ss_sizes = (PyArrayObject *) PyArray_Newshape(ss_sizes, &ss_sizes_dims, NPY_CORDER); + if (ss_sizes == NULL) { + goto out; + } PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - arglist = Py_BuildValue("OO", X_array, sample_set_sizes); + arglist = Py_BuildValue("OO", X_array, ss_sizes); if (arglist == NULL) { goto out; } @@ -8014,6 +8019,7 @@ general_two_locus_count_stat_func( Py_XDECREF(arglist); Py_XDECREF(result); Py_XDECREF(Y_array); + Py_XDECREF(ss_sizes); return ret; } @@ -8023,7 +8029,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", "output_dim", "polarised", "row_sites", "col_sites", "row_positions", - "column_positions", "mode", NULL }; + "column_positions", "mode", "drop_dimensions", NULL }; two_locus_general_stat_params *params; PyObject *summary_func = NULL; unsigned int output_dim; @@ -8048,15 +8054,17 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; + int drop_dimensions = 0; int polarised = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|s", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|si", kwlist, &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, - &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + &row_sites, &col_sites, &row_positions, &col_positions, &mode, + &drop_dimensions)) { Py_XINCREF(summary_func); goto out; } @@ -8115,6 +8123,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * params = &(two_locus_general_stat_params){ .sample_set_sizes = sample_set_sizes_array, .callable = summary_func, + .drop_dimensions = drop_dimensions, }; // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index a09cf8e741..a858f7ff0e 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10940,6 +10940,7 @@ def two_locus_count_stat( sites=None, positions=None, mode="site", + drop_dimensions=True, ): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) @@ -10948,7 +10949,7 @@ def two_locus_count_stat( ) result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, - sample_sets, + flattened, f, result_dim, polarised, @@ -10957,6 +10958,7 @@ def two_locus_count_stat( row_positions, col_positions, mode, + drop_dimensions, ) if drop_dimension: result = result.reshape(result.shape[:2]) From 92b422aaeaaab7abb3241538058ad94dd0b927ec Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 12:31:28 -0600 Subject: [PATCH 03/33] finalize and add tests for single and multipop --- c/tskit/trees.c | 4 +- python/_tskitmodule.c | 153 +++++++++++++++----- python/tests/test_ld_matrix.py | 255 ++++++++++++++++++++++++++++++++- python/tskit/trees.py | 33 ++--- 4 files changed, 387 insertions(+), 58 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index f7805857ca..a8f6e168f9 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -8707,8 +8707,8 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, for (k = offsets[b]; k < offsets[b + 1]; k++) { u = A[j]; v = A[k]; - /* Only increment the upper triangle to (hopefully) improve - * memory access patterns */ + /* Only increment the upper triangle to (hopefully) improve memory + * access patterns */ if (u > v) { u = A[k]; v = A[j]; diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index ed50577ee8..c6bca34ca9 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7947,47 +7947,123 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) } typedef struct { - bool drop_dimensions; PyArrayObject *sample_set_sizes; - PyObject *callable; + PyObject *summary_func; + PyObject *norm_func; } two_locus_general_stat_params; static int -general_two_locus_count_stat_func( - tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) +general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n_a, + tsk_size_t n_b, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; PyObject *arglist = NULL; PyObject *result = NULL; + PyArrayObject *n_a_scalar = NULL; + PyArrayObject *n_b_scalar = NULL; PyArrayObject *X_array = NULL; PyArrayObject *Y_array = NULL; two_locus_general_stat_params *tl_params = params; - PyObject *callable = tl_params->callable; + PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - bool drop = (K == 1 && tl_params->drop_dimensions); - // Convert "n" to a column array -- reshape(-1, K) or a scalar if K=1 and drop=True - PyArray_Dims ss_sizes_dims = (drop ? (PyArray_Dims){ (npy_intp[1]){ 1 }, 0 } - : (PyArray_Dims){ (npy_intp[2]){ K, 1 }, 2 }); - int X_ndims = drop ? 1 : 2; - npy_intp *X_dims = drop ? (npy_intp[1]){ 3 } : (npy_intp[2]){ K, 3 }; - npy_intp *Y_dims; + npy_intp X_dims[2] = { result_dim, 3 }; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - X_ndims, X_dims, NPY_FLOAT64, (void *) X); + 2, X_dims, NPY_FLOAT64, (void *) X); if (X_array == NULL) { goto out; } - ss_sizes = (PyArrayObject *) PyArray_Newshape(ss_sizes, &ss_sizes_dims, NPY_CORDER); - if (ss_sizes == NULL) { + PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + // Transpose into column arrays, so that we can easily decompose the results + X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + if (X_array == NULL) { + goto out; + } + n_a_scalar + = (PyArrayObject *) PyArray_Scalar(&n_a, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_a_scalar == NULL) { + goto out; + } + n_b_scalar + = (PyArrayObject *) PyArray_Scalar(&n_b, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_b_scalar == NULL) { + goto out; + } + arglist = Py_BuildValue("OOOO", X_array, ss_sizes, n_a_scalar, n_b_scalar); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(summary_func, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + Py_XDECREF(n_a_scalar); + Py_XDECREF(n_b_scalar); + return ret; +} + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t result_dim, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + two_locus_general_stat_params *tl_params = params; + PyObject *summary_func = tl_params->summary_func; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + npy_intp X_dims[2] = { K, 3 }; + + // Create a read only view of X as a numpy array + X_array = (PyArrayObject *) PyArray_SimpleNewFromData( + 2, X_dims, NPY_FLOAT64, (void *) X); + if (X_array == NULL) { goto out; } PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + // Transpose into column arrays, so that we can easily decompose the results + // For example: pAB, pAb, paB = X / n + // which works with K>1. In addition, the data is not reordered, meaning + // that the data is still oriented where samples are rows, meaning that + // we'll preserve data locality in ops over samples. + X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + if (X_array == NULL) { + goto out; + } arglist = Py_BuildValue("OO", X_array, ss_sizes); if (arglist == NULL) { goto out; } - result = PyObject_CallObject(callable, arglist); + result = PyObject_CallObject(summary_func, arglist); if (result == NULL) { goto out; } @@ -7998,28 +8074,25 @@ general_two_locus_count_stat_func( } if (PyArray_NDIM(Y_array) != 1) { PyErr_Format(PyExc_ValueError, - "Array returned by general_stat callback is %d dimensional; " + "Array returned by summary function callback is %d dimensional; " "must be 1D", (int) PyArray_NDIM(Y_array)); goto out; } - Y_dims = PyArray_DIMS(Y_array); - if (Y_dims[0] != (npy_intp) M) { + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { PyErr_Format(PyExc_ValueError, - "Array returned by general_stat callback is of length %d; " - "must be %d", - Y_dims[0], M); + "Array returned by summary function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); goto out; } /* Copy the contents of the return Y array into Y */ - memcpy(Y, PyArray_DATA(Y_array), M * sizeof(*Y)); + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); ret = 0; out: Py_XDECREF(X_array); Py_XDECREF(arglist); Py_XDECREF(result); Py_XDECREF(Y_array); - Py_XDECREF(ss_sizes); return ret; } @@ -8028,10 +8101,11 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * { PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", - "output_dim", "polarised", "row_sites", "col_sites", "row_positions", - "column_positions", "mode", "drop_dimensions", NULL }; + "norm_func", "output_dim", "polarised", "row_sites", "col_sites", + "row_positions", "column_positions", "mode", NULL }; two_locus_general_stat_params *params; PyObject *summary_func = NULL; + PyObject *norm_func = NULL; unsigned int output_dim; PyObject *sample_set_sizes = NULL; PyObject *sample_sets = NULL; @@ -8054,25 +8128,29 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; - int drop_dimensions = 0; int polarised = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|si", kwlist, - &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, - &row_sites, &col_sites, &row_positions, &col_positions, &mode, - &drop_dimensions)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &norm_func, &output_dim, + &polarised, &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { Py_XINCREF(summary_func); + Py_XINCREF(norm_func); goto out; } Py_INCREF(summary_func); + Py_INCREF(norm_func); if (!PyCallable_Check(summary_func)) { PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); goto out; } + if (!PyCallable_Check(norm_func)) { + PyErr_SetString(PyExc_TypeError, "norm_func must be callable"); + goto out; + } if (parse_stats_mode(mode, &options) != 0) { goto out; } @@ -8113,7 +8191,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * col_positions_parsed = PyArray_DATA(col_positions_array); } - result_dim[2] = num_sample_sets; + result_dim[2] = output_dim; result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); if (result_matrix == NULL) { PyErr_NoMemory(); @@ -8122,15 +8200,16 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * params = &(two_locus_general_stat_params){ .sample_set_sizes = sample_set_sizes_array, - .callable = summary_func, - .drop_dimensions = drop_dimensions, + .summary_func = summary_func, + .norm_func = norm_func, }; // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - output_dim, general_two_locus_count_stat_func, params, NULL, result_dim[0], - row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, - col_positions_parsed, options, PyArray_DATA(result_matrix)); + output_dim, general_two_locus_count_stat_func, params, + general_two_locus_norm_func, result_dim[0], row_sites_parsed, + row_positions_parsed, result_dim[1], col_sites_parsed, col_positions_parsed, + options, PyArray_DATA(result_matrix)); if (err == TSK_PYTHON_CALLBACK_ERROR) { goto out; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 4d6e47ddcc..e784a9628b 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,7 +22,6 @@ """ Test cases for two-locus statistics """ - import contextlib import io from collections.abc import Callable, Generator @@ -2398,3 +2397,257 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params) np.testing.assert_allclose((result * norm).sum(), expected) + + +class GeneralStatFuncs: + """ + functions take X, n as parameters where + + X: shape=(3, #ss) + sample sets + count AB [[ ] + count Ab [ ] + count aB [ ]] + + n: shape=(#ss, ) + [ ] + """ + + @staticmethod + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + @staticmethod + def D2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return (pAB - (pA * pB)) ** 2 + + @staticmethod + def r2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D**2 / denom + + @staticmethod + def r(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D / np.sqrt(denom) + + @staticmethod + def D_prime(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = np.vstack( + [ + np.min([pA * (1 - pB), (1 - pA) * pB], axis=0), + np.min([pA * pB, (1 - pA) * (1 - pB)], axis=0), + ] + ) + with suppress_overflow_div0_warning(): + return D / denom[(D < 0).astype(int), range(len(D))] + + @staticmethod + def Dz(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return D * (1 - 2 * pA) * (1 - 2 * pB) + + @staticmethod + def pi2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pA * (1 - pA) * pB * (1 - pB) + + @staticmethod + def D2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((aB**2) * (Ab - 1) * Ab) + + ((ab - 1) * ab * (AB - 1) * AB) + - (aB * Ab * (Ab + (2 * ab * AB) - 1)) + ) + + @staticmethod + def Dz_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) + - ((AB * ab) * (AB + ab - Ab - aB - 2)) + - ((Ab * aB) * (Ab + aB - AB - ab - 2)) + ) + + @staticmethod + def pi2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) + - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) + - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) + ) + + @staticmethod + def r2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = np.prod(pAB - (pA * pB)) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) + with suppress_overflow_div0_warning(): + return np.expand_dims(D / denom, axis=0) + + @staticmethod + def D2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return np.expand_dims(np.prod(D), axis=0) + + @staticmethod + def D2_ij_unbiased(X, n): + """ + NB: the two sample sets must be disjoint + we have no way for testing equality + """ + AB, Ab, aB = X + ab = n - X.sum(0) + return np.expand_dims( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / n[0] + / (n[0] - 1) + / n[1] + / (n[1] - 1), + axis=0, + ) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "D", + ), + (ts, "D2"), + (ts, "r2"), + (ts, "r"), + (ts, "D_prime"), + (ts, "Dz"), + (ts, "pi2"), + (ts, "D2_unbiased"), + (ts, "Dz_unbiased"), + (ts, "pi2_unbiased"), + ], +) +def test_general_two_locus_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) + np.testing.assert_equal(ldg, ld) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "r2_ij", + ), + (ts, "D2_ij"), + (ts, "D2_ij_unbiased"), + ], +) +def test_general_two_locus_two_way_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) + ld = ts.ld_matrix( + sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) + ) + np.testing.assert_allclose(ldg, ld) + + +@pytest.mark.parametrize( + "stat", + [ + "D", + "D2", + "r2", + "r", + "D_prime", + "Dz", + "pi2", + "D2_unbiased", + "Dz_unbiased", + "pi2_unbiased", + ], +) +def test_general_one_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2": + result = ts.two_locus_count_stat( + [ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", "D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples()], func, 1) + np.testing.assert_allclose(ts.ld_matrix(stat=stat), result) + + +@pytest.mark.parametrize( + "stat", + [ + "r2_ij", + "D2_ij", + "D2_ij_unbiased", + ], +) +def test_general_two_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2_ij": + result = ts.two_locus_count_stat( + [ts.samples(), ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", "D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) + np.testing.assert_allclose( + ts.ld_matrix( + stat=stat.replace("_ij", ""), + indexes=(0, 1), + sample_sets=[ts.samples(), ts.samples()], + ), + result, + ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index a858f7ff0e..c23d98b9cc 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -696,7 +696,8 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 and is ignored", + "The sample_counts option is not supported since 0.2.4 " + "and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6945,7 +6946,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" ) return "\n".join(output) + "\n" @@ -9391,9 +9392,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert time_windows[0] < time_windows[1], ( - "The second argument should be larger." - ) + assert ( + time_windows[0] < time_windows[1] + ), "The second argument should be larger." tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), @@ -10936,21 +10937,20 @@ def two_locus_count_stat( sample_sets, f, result_dim, + norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0), polarised=False, sites=None, positions=None, mode="site", - drop_dimensions=True, ): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) - drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( - sample_sets - ) + _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, - flattened, + sample_sets, f, + norm_f, result_dim, polarised, row_sites, @@ -10958,15 +10958,12 @@ def two_locus_count_stat( row_positions, col_positions, mode, - drop_dimensions, ) - if drop_dimension: - result = result.reshape(result.shape[:2]) - else: - # Orient the data so that the first dimension is the sample set. - # With this orientation, we get one LD matrix per sample set. - result = result.swapaxes(0, 2).swapaxes(1, 2) - return result + if result_dim == 1: # drop dimension + return result.reshape(result.shape[:2]) + # Orient the data so that the first dimension is the sample set so that + # we get one LD matrix per sample set. + return result.swapaxes(0, 2).swapaxes(1, 2) def ld_matrix( self, From 8018550dfecd28dd35dfa665719446ac4307e448 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:16:12 -0600 Subject: [PATCH 04/33] turns out, the general norm function needs to know the state_dims --- c/tskit/trees.c | 17 ++++++++++------- c/tskit/trees.h | 4 ++-- python/_tskitmodule.c | 6 +++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index a8f6e168f9..cccf56a8be 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2298,8 +2298,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset, } static int -norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2315,8 +2316,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, } static int -norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2341,8 +2343,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), - tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) +norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim), + const double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; double norm = 1 / (double) (n_a * n_b); @@ -2445,7 +2448,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, if (ret != 0) { goto out; } - ret = norm_f(result_dim, weights, num_a_alleles - is_polarised, + ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised, num_b_alleles - is_polarised, norm, f_params); if (ret != 0) { goto out; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index acc15c9aac..2bf1a26cc9 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1036,8 +1036,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); +typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, + tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params); int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c6bca34ca9..3ad4229186 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7953,8 +7953,8 @@ typedef struct { } two_locus_general_stat_params; static int -general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n_a, - tsk_size_t n_b, double *Y, void *params) +general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim, + tsk_size_t n_a, tsk_size_t n_b, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; PyObject *arglist = NULL; @@ -7966,7 +7966,7 @@ general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { result_dim, 3 }; + npy_intp X_dims[2] = { K, 3 }; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( From 03511419f9bf52fac663273a4e2a8a54ddd7c695 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:18:40 -0600 Subject: [PATCH 05/33] fix up a bit of naming in general test funcs, remove unneeded branch, fix norm func for r2_ij --- python/tests/test_ld_matrix.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index e784a9628b..6ba04ceb41 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2512,18 +2512,17 @@ def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D = np.prod(pAB - (pA * pB)) + D2_ij = np.prod(pAB - (pA * pB)) denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) with suppress_overflow_div0_warning(): - return np.expand_dims(D / denom, axis=0) + return np.expand_dims(D2_ij / denom, axis=0) @staticmethod def D2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D = pAB - (pA * pB) - return np.expand_dims(np.prod(D), axis=0) + return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0) @staticmethod def D2_ij_unbiased(X, n): @@ -2635,11 +2634,8 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - result = ts.two_locus_count_stat( - [ts.samples(), ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n - ) - elif stat in {"D", "r", "D_prime"}: - result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + norm_f = lambda X, n, nA, nB: np.expand_dims(X[0].sum() / n.sum(), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) From ad920d0e8ecdc20e6f085ab14bc2895dc6e937e8 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:21:12 -0600 Subject: [PATCH 06/33] flake8 does not like assigning lambdas to variables --- python/tests/test_ld_matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 6ba04ceb41..e230bce601 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2634,7 +2634,8 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - norm_f = lambda X, n, nA, nB: np.expand_dims(X[0].sum() / n.sum(), axis=0) + def norm_f(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) From 8d22e8ea301cc4020d921877f5017d92d711a596 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:23:53 -0600 Subject: [PATCH 07/33] and black doesn't like that --- python/tests/test_ld_matrix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index e230bce601..524287e9be 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2634,8 +2634,10 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": + def norm_f(X, n, nA, nB): return np.expand_dims(X[0].sum() / n.sum(), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) From e5fcd0e9f2b11cd7167dd33fa83b8d09ef9fff45 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sat, 6 Dec 2025 18:29:12 -0600 Subject: [PATCH 08/33] do not test equality, this was useful on my local machine but is problematic in practice --- python/tests/test_ld_matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 524287e9be..b1b2cd29b6 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,6 +22,7 @@ """ Test cases for two-locus statistics """ + import contextlib import io from collections.abc import Callable, Generator @@ -2567,7 +2568,7 @@ def test_general_two_locus_site_stat(ts, stat): sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) - np.testing.assert_equal(ldg, ld) + np.testing.assert_allclose(ldg, ld) @pytest.mark.parametrize( From 6bb867f4dcf848a6eafe5cad8b168eff5f863129 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 16:33:19 -0500 Subject: [PATCH 09/33] lowlevel tests --- python/tests/test_python_c.py | 165 ++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 15f9967f3f..80861f83d5 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -1987,6 +1987,171 @@ def test_ld_matrix_multipop(self, stat_method_name): with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") + def test_two_locus_count_stat(self): + ts = self.get_example_tree_sequence(10) + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + col_pos = row_pos + row_sites_list = list(range(ts.get_num_sites())) + col_sites_list = row_sites_list + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) + + method = ts.two_locus_count_stat + + site_args = row_sites, col_sites, None, None, "site" + branch_args = None, None, row_pos, col_pos, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (10, 10, 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args) + assert a.shape == (2, 2, 1) + site_list_args = row_sites_list, col_sites_list, None, None, "site" + branch_list_args = None, None, row_pos_list, col_pos_list, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) + assert a.shape == (10, 10, 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) + assert a.shape == (2, 2, 1) + # CPython API errors + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + bad_ss = np.array([], dtype=np.int32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="cast array data"): + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="Unrecognised stats mode"): + bad_args = row_sites, col_sites, None, None, "bla" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_args) + with pytest.raises(TypeError, match="at most"): + method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args, "extraarg") + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + bad_site_args = None, None, row_pos, col_pos, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + bad_branch_args = row_sites, col_sites, None, None, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError, match="summary_func must be callable"): + method(ss_sizes, ss, "uncallable", norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="norm_func must be callable"): + method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*must be 1D"): + method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="length 2; must be 1"): + method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) + # TODO: Cannot test without multiallelic sites + # with pytest.raises(ValueError, match="summary function.*must be 1D"): + # method(ss_sizes, ss, stat_func, lambda a, b, c, d: 1, 1, True, *site_args) + # with pytest.raises(ValueError, match="length 2; must be 1"): + # method(ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args) + # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): From 26bebf5b70f09cea6884feef768e1e68e29c90bc Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 16:47:41 -0500 Subject: [PATCH 10/33] relax diff requirements (macos failure) --- python/tests/test_ld_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index b1b2cd29b6..28c4ad3ff5 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2568,7 +2568,7 @@ def test_general_two_locus_site_stat(ts, stat): sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) - np.testing.assert_allclose(ldg, ld) + np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize( From 667dc653c1924eeda72b1202c5f322515a95780a Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 17:08:05 -0500 Subject: [PATCH 11/33] relax diff requirements (macos failure) -- previous commit fixed one --- python/tests/test_ld_matrix.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 28c4ad3ff5..953f542f30 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2590,7 +2590,7 @@ def test_general_two_locus_two_way_site_stat(ts, stat): ld = ts.ld_matrix( sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) ) - np.testing.assert_allclose(ldg, ld) + np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize( @@ -2620,7 +2620,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat): else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples()], func, 1) - np.testing.assert_allclose(ts.ld_matrix(stat=stat), result) + np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) @pytest.mark.parametrize( @@ -2643,7 +2643,7 @@ def norm_f(X, n, nA, nB): else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) - np.testing.assert_allclose( + np.testing.assert_array_almost_equal( ts.ld_matrix( stat=stat.replace("_ij", ""), indexes=(0, 1), From 8ea085289880c91f6eb489ac1033c70f62eed8b7 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 9 Mar 2026 17:09:14 -0500 Subject: [PATCH 12/33] new formatting tools, fix lint --- python/_tskitmodule.c | 2 +- python/tests/test_python_c.py | 4 +++- python/tskit/trees.py | 11 +++++------ 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 3ad4229186..09f29e6de3 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8198,7 +8198,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * goto out; } - params = &(two_locus_general_stat_params){ + params = &(two_locus_general_stat_params) { .sample_set_sizes = sample_set_sizes_array, .summary_func = summary_func, .norm_func = norm_func, diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 80861f83d5..7181964509 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2101,7 +2101,9 @@ def norm_func(X, n, nA, nB): # with pytest.raises(ValueError, match="summary function.*must be 1D"): # method(ss_sizes, ss, stat_func, lambda a, b, c, d: 1, 1, True, *site_args) # with pytest.raises(ValueError, match="length 2; must be 1"): - # method(ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args) + # method( + # ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args + # ) # C API errors with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index c23d98b9cc..6fc2bc0b4d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -696,8 +696,7 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 " - "and is ignored", + "The sample_counts option is not supported since 0.2.4 and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6946,7 +6945,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" ) return "\n".join(output) + "\n" @@ -9392,9 +9391,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert ( - time_windows[0] < time_windows[1] - ), "The second argument should be larger." + assert time_windows[0] < time_windows[1], ( + "The second argument should be larger." + ) tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), From 2ddf6d03c70b274a472cd10cf05a5a977660ccfc Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 10 Mar 2026 10:16:26 -0500 Subject: [PATCH 13/33] remove TODOs, old comment and tested elsewhere --- python/_tskitmodule.c | 1 - python/tests/test_python_c.py | 7 ------- 2 files changed, 8 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 09f29e6de3..c772270de6 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -8203,7 +8203,6 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * .summary_func = summary_func, .norm_func = norm_func, }; - // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), output_dim, general_two_locus_count_stat_func, params, diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 7181964509..625d8f9bcb 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2097,13 +2097,6 @@ def norm_func(X, n, nA, nB): method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) with pytest.raises(ValueError, match="length 2; must be 1"): method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) - # TODO: Cannot test without multiallelic sites - # with pytest.raises(ValueError, match="summary function.*must be 1D"): - # method(ss_sizes, ss, stat_func, lambda a, b, c, d: 1, 1, True, *site_args) - # with pytest.raises(ValueError, match="length 2; must be 1"): - # method( - # ss_sizes, ss, stat_func, lambda a, b, c, d: [1, 2], 1, True, *site_args - # ) # C API errors with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) From 680f9c59c4b133682b9a4bf5ea570ac1c6fc90d3 Mon Sep 17 00:00:00 2001 From: peter Date: Sun, 15 Mar 2026 08:05:10 -0700 Subject: [PATCH 14/33] make testing more clear --- python/tests/test_ld_matrix.py | 84 ++++++++++++++++++---------------- python/tests/tsutil.py | 42 +++++++++++------ 2 files changed, 73 insertions(+), 53 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 953f542f30..2160a360c0 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2063,38 +2063,36 @@ def compute_branch_stat( ts for ts in get_example_tree_sequences() if ts.id - not in { - "no_samples", - "empty_ts", - # We must skip these cases so that tests run in a reasonable - # amount of time. To get more complete testing, these filters - # can be commented out. (runtime ~1hr) - "gap_0", - "gap_0.1", - "gap_0.5", - "gap_0.75", - "n=2_m=32_rho=0", - "n=10_m=1_rho=0", - "n=10_m=1_rho=0.1", - "n=10_m=2_rho=0", - "n=10_m=2_rho=0.1", - "n=10_m=32_rho=0", - "n=10_m=32_rho=0.1", - "n=10_m=32_rho=0.5", + in { + # We run only these cases so that tests run in a reasonable + # amount of time. All examples takes ~1hr. + "decapitate_recomb", + "gap_at_end", + "all_nodes_samples", + "internal_nodes_samples", + "mixed_internal_leaf_samples", + "bottleneck_n=3_mutated", + "bottleneck_n=10_mutated", + "rev_node_order", + "empty_tree", + "n=3_m=2_rho=0.5", + "n=3_m=32_rho=0", + "n=3_m=32_rho=0.1", + "n=2_m=1_rho=0", + "n=2_m=1_rho=0.1", + "n=2_m=1_rho=0.5", + "n=2_m=2_rho=0", + "n=2_m=2_rho=0.1", + "n=2_m=2_rho=0.5", + "n=2_m=32_rho=0.1", + "n=2_m=32_rho=0.5", + "n=3_m=1_rho=0", + "n=3_m=1_rho=0.5", + "n=3_m=2_rho=0", + "n=10_m=1_rho=0.5", + "n=10_m=2_rho=0.5", # we keep one n=100 case to ensure bit arrays are working - "n=100_m=1_rho=0.1", - "n=100_m=1_rho=0.5", - "n=100_m=2_rho=0", - "n=100_m=2_rho=0.1", - "n=100_m=2_rho=0.5", - "n=100_m=32_rho=0", - "n=100_m=32_rho=0.1", - "n=100_m=32_rho=0.5", - "all_fields", - "back_mutations", - "multichar", - "multichar_no_metadata", - "bottleneck_n=100_mutated", + "n=100_m=1_rho=0", } ], ) @@ -2548,9 +2546,13 @@ def D2_ij_unbiased(X, n): "ts,stat", [ ( - ts := [ - p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" - ][0].values[0], + ts := tsutil.get_sim_example( + sample_size=100, + sequence_length=32, + recombination_rate=0.5, + mutation_rate=0.1, + seed=123, + ), "D", ), (ts, "D2"), @@ -2575,9 +2577,13 @@ def test_general_two_locus_site_stat(ts, stat): "ts,stat", [ ( - ts := [ - p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" - ][0].values[0], + ts := tsutil.get_sim_example( + sample_size=100, + sequence_length=32, + recombination_rate=0.5, + mutation_rate=0.1, + seed=123, + ), "r2_ij", ), (ts, "D2_ij"), @@ -2609,7 +2615,7 @@ def test_general_two_locus_two_way_site_stat(ts, stat): ], ) def test_general_one_way_two_locus_stat_multiallelic(stat): - (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + ts = tsutil.all_fields_ts() func = getattr(GeneralStatFuncs, stat) if stat == "r2": result = ts.two_locus_count_stat( @@ -2632,7 +2638,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat): ], ) def test_general_two_way_two_locus_stat_multiallelic(stat): - (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + ts = tsutil.all_fields_ts() func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 0037e06391..48ff72d0d7 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -2463,6 +2463,28 @@ def get_back_mutation_examples(): yield insert_branch_mutations(ts) +@functools.lru_cache +def get_sim_example( + sample_size, sequence_length, recombination_rate, mutation_rate, seed +): + recomb_map = msprime.RecombinationMap.uniform_map( + sequence_length, recombination_rate + ) + ts = msprime.simulate( + recombination_map=recomb_map, + mutation_rate=mutation_rate, + random_seed=seed, + population_configurations=[ + msprime.PopulationConfiguration(sample_size), + msprime.PopulationConfiguration(0), + ], + migration_matrix=[[0, 1], [1, 0]], + ) + ts = insert_random_ploidy_individuals(ts, 4, seed=seed) + ts = add_random_metadata(ts, seed=seed) + return ts + + def make_example_tree_sequences(custom_max=None): yield from get_decapitated_examples(custom_max=custom_max) yield from get_gap_examples(custom_max=custom_max) @@ -2475,22 +2497,14 @@ def make_example_tree_sequences(custom_max=None): for n in n_list: for m in [1, 2, 32]: for rho in [0, 0.1, 0.5]: - recomb_map = msprime.RecombinationMap.uniform_map(m, rho, num_loci=m) - ts = msprime.simulate( - recombination_map=recomb_map, + ts = get_sim_example( + sample_size=n, + sequence_length=m, + recombination_rate=rho, mutation_rate=0.1, - random_seed=seed, - population_configurations=[ - msprime.PopulationConfiguration(n), - msprime.PopulationConfiguration(0), - ], - migration_matrix=[[0, 1], [1, 0]], - ) - ts = insert_random_ploidy_individuals(ts, 4, seed=seed) - yield ( - f"n={n}_m={m}_rho={rho}", - add_random_metadata(ts, seed=seed), + seed=seed, ) + yield (f"n={n}_m={m}_rho={rho}", ts) seed += 1 for name, ts in get_bottleneck_examples(custom_max=custom_max): yield ( From 5e984bec3aa04a499f3af37036dea3ede8c87463 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 15 Mar 2026 21:45:28 -0500 Subject: [PATCH 15/33] preserve native dimensions instead of expanding at the end --- python/tests/test_ld_matrix.py | 43 ++++++++++++++++------------------ python/tskit/trees.py | 2 +- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 2160a360c0..e7133af0b9 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2511,34 +2511,30 @@ def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D2_ij = np.prod(pAB - (pA * pB)) - denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) + D2_ij = np.prod(pAB - (pA * pB), keepdims=True) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True) with suppress_overflow_div0_warning(): - return np.expand_dims(D2_ij / denom, axis=0) + return D2_ij / denom @staticmethod def D2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0) + return np.prod(pAB - (pA * pB), keepdims=True) @staticmethod def D2_ij_unbiased(X, n): - """ - NB: the two sample sets must be disjoint - we have no way for testing equality - """ + """NB: We use double brackets here to preserve the output shape of (1,)""" AB, Ab, aB = X ab = n - X.sum(0) - return np.expand_dims( - (Ab[0] * aB[0] - AB[0] * ab[0]) - * (Ab[1] * aB[1] - AB[1] * ab[1]) - / n[0] - / (n[0] - 1) - / n[1] - / (n[1] - 1), - axis=0, + return ( + (Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]]) + * (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]]) + / n[[0]] + / (n[[0]] - 1) + / n[[1]] + / (n[[1]] - 1) ) @@ -2624,7 +2620,7 @@ def test_general_one_way_two_locus_stat_multiallelic(stat): elif stat in {"D", "r", "D_prime"}: result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) else: - # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` result = ts.two_locus_count_stat([ts.samples()], func, 1) np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) @@ -2641,13 +2637,14 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): ts = tsutil.all_fields_ts() func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - - def norm_f(X, n, nA, nB): - return np.expand_dims(X[0].sum() / n.sum(), axis=0) - - result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) + result = ts.two_locus_count_stat( + [ts.samples(), ts.samples()], + func, + 1, + lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum(), + ) else: - # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) np.testing.assert_array_almost_equal( ts.ld_matrix( diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 6fc2bc0b4d..23e38c8e6e 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10936,7 +10936,7 @@ def two_locus_count_stat( sample_sets, f, result_dim, - norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0), + norm_f=lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,], polarised=False, sites=None, positions=None, From e08abf42b1c0254de8cbe4d580d26172cf434e33 Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 16 Mar 2026 17:33:00 -0500 Subject: [PATCH 16/33] Update tests according to Peters's feedback *Python tests* Overhaul python testing of the general stat functions. Remove the dependence on the example tree sequences, opting instead to simulate a couple of examples directly. Use these simulated trees in test fixtures, scoped at the module level. This streamlines the test parameterization a lot. Use the single stat site names from the summary function definitions. *CPython tests* Add a multiallelic tree sequence to test normalisation function validation and errors. Remove one more occurrence of `np.expand_dims`. *trees.c* Remove the unnecessary branch in tsk_treeseq_two_locus_count_general_stat, improving the code coverage. *trees.py* Default normalisation function can be None, applying default at runtime. Simplifies calling code and is more in line with the rest of the API. --- c/tskit/trees.c | 29 +++--- python/tests/test_ld_matrix.py | 163 +++++++++++++-------------------- python/tests/test_python_c.py | 93 +++++++++++++++++-- python/tskit/trees.py | 4 +- 4 files changed, 165 insertions(+), 124 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index cccf56a8be..0f8a10c182 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3439,21 +3439,22 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - } else if (stat_branch) { - ret = check_positions( - row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = check_positions( - col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, - row_positions, out_cols, col_positions, options, result); + goto out; + } + tsk_bug_assert(stat_branch); + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, + row_positions, out_cols, col_positions, options, result); out: return ret; } diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index e7133af0b9..bb459ee7ca 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2538,55 +2538,49 @@ def D2_ij_unbiased(X, n): ) -@pytest.mark.parametrize( - "ts,stat", - [ - ( - ts := tsutil.get_sim_example( - sample_size=100, - sequence_length=32, - recombination_rate=0.5, - mutation_rate=0.1, - seed=123, - ), - "D", +@pytest.fixture(scope="module") +def ts_100_samp_with_sites_fixture(): + ts = tsutil.get_sim_example( + sample_size=100, + sequence_length=32, + recombination_rate=0.5, + mutation_rate=0.1, + seed=123, + ) + assert ts.num_sites > 0, "sites are required" + assert ts.num_samples == 100, "100 samples are required" + return ts + + +@pytest.fixture(scope="module") +def ts_multiallelic_fixture(): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + 2, recombination_rate=0.1, sequence_length=100, random_seed=123 ), - (ts, "D2"), - (ts, "r2"), - (ts, "r"), - (ts, "D_prime"), - (ts, "Dz"), - (ts, "pi2"), - (ts, "D2_unbiased"), - (ts, "Dz_unbiased"), - (ts, "pi2_unbiased"), - ], -) -def test_general_two_locus_site_stat(ts, stat): + rate=0.1, + random_seed=123, + ) + # Need at least 4 samples to test unbiased statistics + assert ts.num_samples >= 4, "At least 4 samples required" + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + return ts + + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): + ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) np.testing.assert_array_almost_equal(ldg, ld) -@pytest.mark.parametrize( - "ts,stat", - [ - ( - ts := tsutil.get_sim_example( - sample_size=100, - sequence_length=32, - recombination_rate=0.5, - mutation_rate=0.1, - seed=123, - ), - "r2_ij", - ), - (ts, "D2_ij"), - (ts, "D2_ij_unbiased"), - ], -) -def test_general_two_locus_two_way_site_stat(ts, stat): +@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) +def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixture): + ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) ld = ts.ld_matrix( @@ -2595,62 +2589,31 @@ def test_general_two_locus_two_way_site_stat(ts, stat): np.testing.assert_array_almost_equal(ldg, ld) -@pytest.mark.parametrize( - "stat", - [ - "D", - "D2", - "r2", - "r", - "D_prime", - "Dz", - "pi2", - "D2_unbiased", - "Dz_unbiased", - "pi2_unbiased", - ], -) -def test_general_one_way_two_locus_stat_multiallelic(stat): - ts = tsutil.all_fields_ts() - func = getattr(GeneralStatFuncs, stat) - if stat == "r2": - result = ts.two_locus_count_stat( - [ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n - ) - elif stat in {"D", "r", "D_prime"}: - result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) - else: - # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` - result = ts.two_locus_count_stat([ts.samples()], func, 1) - np.testing.assert_array_almost_equal(ts.ld_matrix(stat=stat), result) - - -@pytest.mark.parametrize( - "stat", - [ - "r2_ij", - "D2_ij", - "D2_ij_unbiased", - ], -) -def test_general_two_way_two_locus_stat_multiallelic(stat): - ts = tsutil.all_fields_ts() - func = getattr(GeneralStatFuncs, stat) - if stat == "r2_ij": - result = ts.two_locus_count_stat( - [ts.samples(), ts.samples()], - func, - 1, - lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum(), - ) - else: - # default norm func is `lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]` - result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) - np.testing.assert_array_almost_equal( - ts.ld_matrix( - stat=stat.replace("_ij", ""), - indexes=(0, 1), - sample_sets=[ts.samples(), ts.samples()], - ), - result, +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + polarised = POLARIZATION[SUMMARY_FUNCS[stat]] + ldg = ts.two_locus_count_stat( + [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised + ) + ld = ts.ld_matrix(stat=stat) + np.testing.assert_array_almost_equal(ld, ldg) + + +@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) +def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_func = ( + (lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum()) + if stat == "r2_ij" + else None + ) + sample_sets = [ts.samples(), ts.samples()] + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_func) + ld = ts.ld_matrix( + stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets ) + np.testing.assert_array_almost_equal(ld, ldg) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 625d8f9bcb..d9a5f79be7 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -138,6 +138,23 @@ def get_example_migration_tree_sequence(self): ) return ts.ll_tree_sequence + def get_example_tree_sequence_multiallelic(self, sample_size=10): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + sample_size, + recombination_rate=0.1, + sequence_length=100, + ploidy=1, + random_seed=123, + ), + rate=0.1, + random_seed=123, + ) + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + return ts.ll_tree_sequence + def verify_iterator(self, iterator): """ Checks that the specified non-empty iterator implements the @@ -1989,6 +2006,12 @@ def test_ld_matrix_multipop(self, stat_method_name): def test_two_locus_count_stat(self): ts = self.get_example_tree_sequence(10) + # Multiallelic test case to test norm function + ts_multi = self.get_example_tree_sequence_multiallelic() + assert (ts.get_samples() == ts_multi.get_samples()).all(), ( + "biallelic and multiallelic test case are expected " + "to have the same sample nodes" + ) ss = ts.get_samples() # sample sets ss_sizes = np.array([len(ss)], dtype=np.uint32) row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) @@ -2007,10 +2030,9 @@ def stat_func(X, n): return pAB - (pA * pB) def norm_func(X, n, nA, nB): - return np.expand_dims(X[0].sum() / n.sum(), axis=0) - - method = ts.two_locus_count_stat + return X[0].sum(keepdims=True) / n.sum() + method = ts.two_locus_count_stat # most tests on biallelic site_args = row_sites, col_sites, None, None, "site" branch_args = None, None, row_pos, col_pos, "branch" a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) @@ -2019,10 +2041,20 @@ def norm_func(X, n, nA, nB): assert a.shape == (2, 2, 1) site_list_args = row_sites_list, col_sites_list, None, None, "site" branch_list_args = None, None, row_pos_list, col_pos_list, "branch" + + # happy path a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) - assert a.shape == (10, 10, 1) + assert a.shape == (10, 10, 1) # ts has 10 sites a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) - assert a.shape == (2, 2, 1) + assert a.shape == (2, 2, 1) # ts has 2 trees + a = ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" + ) + assert a.shape == (56, 56, 1) # ts has 56 sites + a = ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" + ) + assert a.shape == (48, 48, 1) # ts has 48 trees # CPython API errors with pytest.raises(ValueError, match="Sum of sample_set_sizes"): bad_ss = np.array([], dtype=np.int32) @@ -2094,10 +2126,55 @@ def norm_func(X, n, nA, nB): with pytest.raises(TypeError, match="norm_func must be callable"): method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args) with pytest.raises(ValueError, match="summary function.*must be 1D"): - method(ss_sizes, ss, lambda a, b: 1, norm_func, 1, True, *site_args) - with pytest.raises(ValueError, match="length 2; must be 1"): - method(ss_sizes, ss, lambda a, b: [1, 2], norm_func, 1, True, *site_args) + method(ss_sizes, ss, lambda *_: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*length 2; must be 1"): + method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="could not convert string to float"): + method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="norm function.*must be 1D"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args + ) + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + ts_multi.two_locus_count_stat( + ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args + ) + with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args + ) + with pytest.raises( + TypeError, match="takes 1 positional argument but 4 were given" + ): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args + ) + with pytest.raises(ValueError, match="could not convert string to float"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args + ) + # Exceptions within stat_func and norm_func are correctly raised. + for exception in [ValueError, TypeError]: + + def stat_func_except(*_): + raise exception("test") + + def norm_func_except(*_): + raise exception("test") + + with pytest.raises(exception, match="test"): + method( + ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_list_args + ) + with pytest.raises(exception, match="test"): + ts_multi.two_locus_count_stat( + ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_list_args + ) # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"): + method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_list_args) with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) bad_site_args = bad_sites, col_sites, None, None, "site" diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 23e38c8e6e..ee9d4e1114 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10936,7 +10936,7 @@ def two_locus_count_stat( sample_sets, f, result_dim, - norm_f=lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,], + norm_f=None, polarised=False, sites=None, positions=None, @@ -10949,7 +10949,7 @@ def two_locus_count_stat( sample_set_sizes, sample_sets, f, - norm_f, + norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]), result_dim, polarised, row_sites, From bd0a1a5683c4d8505a0ec26c8fcbf89da32c3dad Mon Sep 17 00:00:00 2001 From: lkirk Date: Mon, 16 Mar 2026 17:52:04 -0500 Subject: [PATCH 17/33] msprime produces different trees on macos (same seed) --- python/tests/test_python_c.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index d9a5f79be7..ccd73ac3d4 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2050,7 +2050,12 @@ def norm_func(X, n, nA, nB): a = ts_multi.two_locus_count_stat( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" ) - assert a.shape == (56, 56, 1) # ts has 56 sites + import platform + + if platform.system() == "Darwin": + assert a.shape == (54, 54, 1) # ts has 54 sites on macos? + else: + assert a.shape == (56, 56, 1) # ts has 56 sites a = ts_multi.two_locus_count_stat( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" ) @@ -2159,10 +2164,10 @@ def norm_func(X, n, nA, nB): for exception in [ValueError, TypeError]: def stat_func_except(*_): - raise exception("test") + raise exception("test") # noqa: B023 def norm_func_except(*_): - raise exception("test") + raise exception("test") # noqa: B023 with pytest.raises(exception, match="test"): method( From 4c04ff3291c75d4189cd0af83b60de9977a5114c Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 13:07:12 -0500 Subject: [PATCH 18/33] Clean up python C tests Use the number of sites and trees reported by the tree sequence instead of hard coded values. This has the benefit of being more readable, communicating intent (review comment from Peter). Split the multiallelic and biallelic test cases, they're getting messy. Now I can explicitly assert that the norm_func is not run for biallelic sites and for branch stats. Also gets rid of awkward assertions about sample sets. --- python/tests/test_python_c.py | 128 ++++++++++++++++++---------------- 1 file changed, 68 insertions(+), 60 deletions(-) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index ccd73ac3d4..9f8822b2f6 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2005,13 +2005,8 @@ def test_ld_matrix_multipop(self, stat_method_name): stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") def test_two_locus_count_stat(self): + """Test two_locus_count_stat on biallelic data (no norm function)""" ts = self.get_example_tree_sequence(10) - # Multiallelic test case to test norm function - ts_multi = self.get_example_tree_sequence_multiallelic() - assert (ts.get_samples() == ts_multi.get_samples()).all(), ( - "biallelic and multiallelic test case are expected " - "to have the same sample nodes" - ) ss = ts.get_samples() # sample sets ss_sizes = np.array([len(ss)], dtype=np.uint32) row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) @@ -2029,37 +2024,33 @@ def stat_func(X, n): pB = paB + pAB return pAB - (pA * pB) - def norm_func(X, n, nA, nB): - return X[0].sum(keepdims=True) / n.sum() + def norm_func(*_): + raise Exception # norm function will not be used - method = ts.two_locus_count_stat # most tests on biallelic + method = ts.two_locus_count_stat site_args = row_sites, col_sites, None, None, "site" branch_args = None, None, row_pos, col_pos, "branch" + # happy path a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) - assert a.shape == (10, 10, 1) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args) - assert a.shape == (2, 2, 1) + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # happy path - sample sets as lists are also valid site_list_args = row_sites_list, col_sites_list, None, None, "site" branch_list_args = None, None, row_pos_list, col_pos_list, "branch" - - # happy path a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) - assert a.shape == (10, 10, 1) # ts has 10 sites + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) - assert a.shape == (2, 2, 1) # ts has 2 trees - a = ts_multi.two_locus_count_stat( + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # happy path - default array filling + a = method( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" ) - import platform - - if platform.system() == "Darwin": - assert a.shape == (54, 54, 1) # ts has 54 sites on macos? - else: - assert a.shape == (56, 56, 1) # ts has 56 sites - a = ts_multi.two_locus_count_stat( + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + a = method( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" ) - assert a.shape == (48, 48, 1) # ts has 48 trees + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) # CPython API errors with pytest.raises(ValueError, match="Sum of sample_set_sizes"): bad_ss = np.array([], dtype=np.int32) @@ -2136,50 +2127,17 @@ def norm_func(X, n, nA, nB): method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) with pytest.raises(ValueError, match="could not convert string to float"): method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) - with pytest.raises(ValueError, match="norm function.*must be 1D"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args - ) - with pytest.raises( - TypeError, match="takes 1 positional argument but 2 were given" - ): - ts_multi.two_locus_count_stat( - ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args - ) - with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args - ) - with pytest.raises( - TypeError, match="takes 1 positional argument but 4 were given" - ): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args - ) - with pytest.raises(ValueError, match="could not convert string to float"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args - ) - # Exceptions within stat_func and norm_func are correctly raised. + # Exceptions within stat_func are correctly raised. for exception in [ValueError, TypeError]: def stat_func_except(*_): raise exception("test") # noqa: B023 - def norm_func_except(*_): - raise exception("test") # noqa: B023 - - with pytest.raises(exception, match="test"): - method( - ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_list_args - ) with pytest.raises(exception, match="test"): - ts_multi.two_locus_count_stat( - ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_list_args - ) + method(ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_args) # C API errors with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"): - method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_list_args) + method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_args) with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): bad_sites = np.array([1, 0, 2], dtype=np.int32) bad_site_args = bad_sites, col_sites, None, None, "site" @@ -2229,6 +2187,56 @@ def norm_func_except(*_): bad_branch_args = None, None, row_pos, bad_pos, "branch" method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + def test_two_locus_count_stat_multialleliic(self): + """ + Test two_locus_count_stat on multiallelic sites to test the behavior of + the norm function. + """ + ts = self.get_example_tree_sequence_multiallelic() + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(X, n, nA, nB): + return X[0].sum(keepdims=True) / n.sum() + + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + method = ts.two_locus_count_stat + site_args = row_sites, col_sites, None, None, "site" + + # happy path + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + # CPython API errors + with pytest.raises(ValueError, match="norm function.*must be 1D"): + method(ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args) + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + method(ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): + method(ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args) + with pytest.raises( + TypeError, match="takes 1 positional argument but 4 were given" + ): + method(ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args) + with pytest.raises(ValueError, match="could not convert string to float"): + method(ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args) + # Exceptions within stat_func are correctly raised. + for exception in [ValueError, TypeError]: + + def norm_func_except(*_): + raise exception("test") # noqa: B023 + + with pytest.raises(exception, match="test"): + method(ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_args) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): From 3f4c0ed1e49357abb303dfd97e7af1e18c7fe623 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 17:42:26 -0500 Subject: [PATCH 19/33] Add/refine tests, draft docstring Clean up dimension handling around summary functions and normalisation. There is a slight speed advantage (according to a microbenchmark) and a huge readability advantage to simply returning [value]. I keep all computations specifying `keepdims`, but remove list indexing (i.e. `AB[[0]]`) in favor of returning a list with a single scalar. It turns out that vectorised numpy functions are actually slower in some cases because the data we're operating on is so small. Finally, fix the default normalisation function so that it works both on one-way and two-way statistics. Users will still need to specify `hap_norm` when appropriate (and a special case of `hap_norm` for two-way stats). Per Peter's comment, I investigated dimension dropping and indeed, general stats don't drop dimensions so I removed the dimension dropping code. However, we return a matrix of `(m, m, k)` and we want `(k, m, m)`, so `np.moveaxis` is still needed. Added tests: * Multiallelic multi sample-set. This tests operations on two sample sets for multiallelic data (which excercises the norm function with multiple sample sets). This test highlighted the slight changes needed to the default normalisation function. * Multi outputs. This test mimics a two-way stat called on multiple indexes. It shows and tests the ability to compute multiple statistics from the same haplotype counts matrix (which is especially useful with the explosion of possible summary functions in three-way, four-way stats). In our biallelic test case, I also assert that the normalisation function is never called and add a note about polarisation. Finally, I add a draft docstring, but to complete this I think that the two-locus docs are required. Also, I'd like to add some general documentation. --- python/tests/test_ld_matrix.py | 123 ++++++++++++++++++++++++++------- python/tskit/trees.py | 76 ++++++++++++++++++-- 2 files changed, 170 insertions(+), 29 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index bb459ee7ca..411149aa35 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2400,16 +2400,19 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex class GeneralStatFuncs: """ - functions take X, n as parameters where + Summary functions take X, n as parameters where X is a matrix of haplotype + counts per sample set and n is a vector of sample set sizes. X has shape (3, k) + and n has shape (k, ), where k is the number of sample sets. The rows of X + contain haplotype counts for AB, Ab, aB (capitalized == derived). - X: shape=(3, #ss) + X: shape=(3, k) sample sets - count AB [[ ] - count Ab [ ] - count aB [ ]] + count AB [[ #ss1, #ss2, ... ] + count Ab [ #ss1, #ss2, ... ] + count aB [ #ss1, #ss2, ... ]] - n: shape=(#ss, ) - [ ] + n: shape=(k, ) + [ #ss1, #ss2, ... ] """ @staticmethod @@ -2480,37 +2483,39 @@ def pi2(X, n): def D2_unbiased(X, n): AB, Ab, aB = X ab = n - X.sum(0) - return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + return ( ((aB**2) * (Ab - 1) * Ab) + ((ab - 1) * ab * (AB - 1) * AB) - (aB * Ab * (Ab + (2 * ab * AB) - 1)) - ) + ) / (n * (n - 1) * (n - 2) * (n - 3)) @staticmethod def Dz_unbiased(X, n): AB, Ab, aB = X ab = n - X.sum(0) - return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + return ( (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) - ((AB * ab) * (AB + ab - Ab - aB - 2)) - ((Ab * aB) * (Ab + aB - AB - ab - 2)) - ) + ) / (n * (n - 1) * (n - 2) * (n - 3)) @staticmethod def pi2_unbiased(X, n): AB, Ab, aB = X ab = n - X.sum(0) - return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + return ( ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) - ) + ) / (n * (n - 1) * (n - 2) * (n - 3)) + # Two-way statistics have the _ij suffix. @staticmethod def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB + # keepdims preserves the output shape of (1, ) D2_ij = np.prod(pAB - (pA * pB), keepdims=True) denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True) with suppress_overflow_div0_warning(): @@ -2525,17 +2530,37 @@ def D2_ij(X, n): @staticmethod def D2_ij_unbiased(X, n): - """NB: We use double brackets here to preserve the output shape of (1,)""" + """The identity of the sample sets is up to the user.""" AB, Ab, aB = X ab = n - X.sum(0) - return ( - (Ab[[0]] * aB[[0]] - AB[[0]] * ab[[0]]) - * (Ab[[1]] * aB[[1]] - AB[[1]] * ab[[1]]) - / n[[0]] - / (n[[0]] - 1) - / n[[1]] - / (n[[1]] - 1) + return [ + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1)) + ] + + @staticmethod + def D2_ii_ij_jj_unbiased(X, n): + """ + Multiple stats can be computed from the same data. The identity of the + sample sets is up to the user. This function assumes two sample sets. + """ + AB, Ab, aB = X + ab = n - X.sum(0) + + # unbiased estimator for equal sample sets + ii, jj = ( + AB * (AB - 1) * ab * (ab - 1) + + Ab * (Ab - 1) * aB * (aB - 1) + - 2 * AB * Ab * aB * ab + ) / (n * (n - 1) * (n - 2) * (n - 3)) + # unbiased estimator for disjoint sample sets + ij = ( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1)) ) + return [ii, ij, jj] @pytest.fixture(scope="module") @@ -2573,7 +2598,17 @@ def ts_multiallelic_fixture(): def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] - ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) + + # In addition to not needing a normalisation function, normalisation is also + # not required because these sites are biallelic. + def assert_no_norm_func(*_): + raise Exception( + "Normalisation function should not be called for biallelic sites" + ) + + ldg = ts.two_locus_count_stat( + sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func + ) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) np.testing.assert_array_almost_equal(ldg, ld) @@ -2584,7 +2619,7 @@ def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixtur sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) ld = ts.ld_matrix( - sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) + sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=[(0, 1)] ) np.testing.assert_array_almost_equal(ldg, ld) @@ -2599,7 +2634,24 @@ def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised ) ld = ts.ld_matrix(stat=stat) - np.testing.assert_array_almost_equal(ld, ldg) + # ld_matrix drops dims, expand for comparison + np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) + + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_two_locus_stat_multiallelic_multi_sample_set( + stat, ts_multiallelic_fixture +): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + polarised = POLARIZATION[SUMMARY_FUNCS[stat]] + sample_sets = [ts.samples(), ts.samples()] + ldg = ts.two_locus_count_stat( + sample_sets, general_func, 2, norm_f=norm_func, polarised=polarised + ) + ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets) + np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) @@ -2616,4 +2668,25 @@ def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu ld = ts.ld_matrix( stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets ) - np.testing.assert_array_almost_equal(ld, ldg) + # ld_matrix drops dims, expand for comparison + np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) + + +def test_general_two_locus_multi_outputs(): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + 4, recombination_rate=0.1, sequence_length=100, random_seed=123 + ), + rate=0.1, + random_seed=123, + ) + assert ts.num_samples == 8, "8 samples are required" + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + A = ts.samples()[0:4] + B = ts.samples()[4:] + + ldg = ts.two_locus_count_stat([A, B], GeneralStatFuncs.D2_ii_ij_jj_unbiased, 3) + ld = ts.ld_matrix([A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)]) + np.testing.assert_array_almost_equal(ldg, ld) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index ee9d4e1114..62ddfc1967 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10942,6 +10942,75 @@ def two_locus_count_stat( positions=None, mode="site", ): + """ + Compute two-locus statistics with a user-defined python function that + operates on haplotype counts. TODO: reference modes in two-locus docs. + On each pair of sites or trees, the summary function is provided with + ``X``, a matrix with shape (3, k) and ``n``, a vector with shape (k,), + where k is the number of sample sets provided. ``X`` is a read-only + matrix whose rows contain haplotype counts per sample set (counts of AB, + Ab, aB) and ``n`` is a vector of sample set sizes. + + .. note:: + Because we are operating on very small matrices/vectors, vectorised + operations are often times slower than operations on scalars. Simply + returning ``[value]`` can be faster than returning + ``value[np.newaxis,]`` or ``np.expand_dims(value, 0)``. + + What follows is an example of computing ``D`` from a tree sequence. Many + more examples can be found in the test suite + ``test_ld_matrix.py::GeneralStatsFuncs``. Let's begin with our summary + function, ``D``. We convert counts to proportions, then compute ``D``, + returning a numpy array with length equal to the number of sample sets. + + .. code-block:: python + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + ``norm_f`` is a normalisation function used to combine all computed + statistics for multiallelic allele pairs (TODO: see two-locus + docs). Biallelic sites do not require any normalisation (in fact, the + normalisation function is never called for biallelic sites). If one of + either site A or site B is multiallelic, then the normalisation function + will be called. The default normalisation function is identical to + ``total_norm`` shown in the example below. ``hap_norm`` is required for + normalising :math:`r^2`. Both of these examples return a numpy array + with length equal to the number of sample sets (for one-way stats). + + .. code-block:: python + def total_norm(X, n, nA, nB): + [1 / (nA * nB)] * result_dim + + def hap_norm(X, n, nA, nB): + X[0] / n + + A simple call (without specifying normalisation) would look like this + + .. code-block::python + ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True) + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + :param f: A function that takes two arguments - a two-dimensional array + with shape (3, k) and a one-dimensional array with shape (k, ) where + k is the number of sample sets. + :param int result_dim: The length of ``f`` and ``norm_f``'s return value. + :param norm_f: A function that takes four arguments - the first two are + the same as ``f``, the second two are scalars representing the + number of A and B alleles, respectively. + :param bool polarised: Whether to leave the ancestral state out of + computations: see :ref:`sec_stats` for more details. + :param list sites: TODO: two-locus docs + :param list positions: TODO: two-locus docs + :param str mode: A string giving the "type" of the statistic to be + computed (defaults to "site"). + :return: A ndarray with shape equal to (TODO: reference two-locus docs, + no dimension dropping shape=(k, m, m) where k=num_sample_sets, + m=num_sites or num_trees). + """ row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) @@ -10949,7 +11018,8 @@ def two_locus_count_stat( sample_set_sizes, sample_sets, f, - norm_f or (lambda X, n, nA, nB: 1 / (nA * nB)[np.newaxis,]), + # produce the same number of dims as output dimensions + norm_f or (lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim), result_dim, polarised, row_sites, @@ -10958,11 +11028,9 @@ def two_locus_count_stat( col_positions, mode, ) - if result_dim == 1: # drop dimension - return result.reshape(result.shape[:2]) # Orient the data so that the first dimension is the sample set so that # we get one LD matrix per sample set. - return result.swapaxes(0, 2).swapaxes(1, 2) + return np.moveaxis(result, -1, 0) def ld_matrix( self, From 6685399f83ef7da4a83f279530b3e00647b11014 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 17:57:00 -0500 Subject: [PATCH 20/33] regain test coverage for default sample sets --- python/tests/test_ld_matrix.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 411149aa35..abb7eea3cd 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2594,18 +2594,33 @@ def ts_multiallelic_fixture(): return ts +def assert_no_norm_func(*_): + """Used in biallelic tests""" + raise Exception("Normalisation function should not be called for biallelic sites") + + @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): +def test_general_two_locus_site_stat_default_sample_sets( + stat, ts_100_samp_with_sites_fixture +): ts = ts_100_samp_with_sites_fixture - sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] - # In addition to not needing a normalisation function, normalisation is also # not required because these sites are biallelic. - def assert_no_norm_func(*_): - raise Exception( - "Normalisation function should not be called for biallelic sites" - ) + ldg = ts.two_locus_count_stat( + [ts.samples()], getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func + ) + ld = ts.ld_matrix(stat=stat) # use default sample sets + np.testing.assert_array_almost_equal(ldg, ld) + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_two_locus_site_stat_two_sample_sets( + stat, ts_100_samp_with_sites_fixture +): + ts = ts_100_samp_with_sites_fixture + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + # In addition to not needing a normalisation function, normalisation is also + # not required because these sites are biallelic. ldg = ts.two_locus_count_stat( sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func ) From 70cd6b4186afc9321184137953cdeac212f11bb3 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:07:09 -0500 Subject: [PATCH 21/33] Revert "regain test coverage for default sample sets" This reverts commit 6685399f83ef7da4a83f279530b3e00647b11014. --- python/tests/test_ld_matrix.py | 29 +++++++---------------------- 1 file changed, 7 insertions(+), 22 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index abb7eea3cd..411149aa35 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2594,33 +2594,18 @@ def ts_multiallelic_fixture(): return ts -def assert_no_norm_func(*_): - """Used in biallelic tests""" - raise Exception("Normalisation function should not be called for biallelic sites") - - -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat_default_sample_sets( - stat, ts_100_samp_with_sites_fixture -): - ts = ts_100_samp_with_sites_fixture - # In addition to not needing a normalisation function, normalisation is also - # not required because these sites are biallelic. - ldg = ts.two_locus_count_stat( - [ts.samples()], getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func - ) - ld = ts.ld_matrix(stat=stat) # use default sample sets - np.testing.assert_array_almost_equal(ldg, ld) - - @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat_two_sample_sets( - stat, ts_100_samp_with_sites_fixture -): +def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): ts = ts_100_samp_with_sites_fixture sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + # In addition to not needing a normalisation function, normalisation is also # not required because these sites are biallelic. + def assert_no_norm_func(*_): + raise Exception( + "Normalisation function should not be called for biallelic sites" + ) + ldg = ts.two_locus_count_stat( sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func ) From be997cf0e1083c79681666ba3e88e5086f78eef1 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:33:30 -0500 Subject: [PATCH 22/33] update comment about result dimension --- python/tskit/trees.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 62ddfc1967..582e264885 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -11029,7 +11029,7 @@ def hap_norm(X, n, nA, nB): mode, ) # Orient the data so that the first dimension is the sample set so that - # we get one LD matrix per sample set. + # we get one LD matrix per result dimension return np.moveaxis(result, -1, 0) def ld_matrix( From 9481921706d7928e29d0ff812f3754166a900649 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:39:36 -0500 Subject: [PATCH 23/33] be more explicit about setting the default norm function --- python/tskit/trees.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 582e264885..944d9ae909 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -11014,12 +11014,14 @@ def hap_norm(X, n, nA, nB): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) + if norm_f is None: + # produce the same number of dims as output dimensions with [val] * dim + norm_f = lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, sample_sets, f, - # produce the same number of dims as output dimensions - norm_f or (lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim), + norm_f, result_dim, polarised, row_sites, From 7387461ecfccf42db5bba2d4018ec47388c3e9c6 Mon Sep 17 00:00:00 2001 From: lkirk Date: Tue, 17 Mar 2026 18:41:37 -0500 Subject: [PATCH 24/33] linting does not like assigning lambdas to variables --- python/tskit/trees.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 944d9ae909..9af49ddb88 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -11016,7 +11016,9 @@ def hap_norm(X, n, nA, nB): _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) if norm_f is None: # produce the same number of dims as output dimensions with [val] * dim - norm_f = lambda X, n, nA, nB: [1 / (nA * nB)] * result_dim + def norm_f(X, n, nA, nB): + return [1 / (nA * nB)] * result_dim + result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, sample_sets, From f5b1940ebf6ddd60aa998f37e949792e7653740f Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 15:03:11 -0500 Subject: [PATCH 25/33] add an else statement to improve readability (review) --- c/tskit/trees.c | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 0f8a10c182..e567ef1acd 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3440,21 +3440,22 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); goto out; + } else { + tsk_bug_assert(stat_branch); + ret = check_positions( + row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = check_positions( + col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, + row_positions, out_cols, col_positions, options, result); } - tsk_bug_assert(stat_branch); - ret = check_positions( - row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = check_positions( - col_positions, out_cols, tsk_treeseq_get_sequence_length(self)); - if (ret != 0) { - goto out; - } - ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, - row_positions, out_cols, col_positions, options, result); out: return ret; } From 6d8a81c3ee31d08b5fd668334a9bdcb6a291b148 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 16:25:34 -0500 Subject: [PATCH 26/33] add a few more tests (review) Add a test for behavior on empty tree sequences (no samples, no edges, no sites). Add a "no sites" fixture. Include branch stat testing. Tune branch stat test runtime by reducing the size of `ts_100_samp_with_sites_fixture`, now named `ts_10_samp_with_sites_fixture`. Add explicit testing for output dimensions and assert that the norm func is not called on trees with only biallelic sites and in branch mode. Add a GeneralStatNormFuncs class to explicitly document possible normalisation functions and in what situations they will be used. Tune size of tree sequence in `test_general_multi_outputs` so that the test runs in a reasonable amount of time in branch mode. --- python/tests/test_ld_matrix.py | 178 +++++++++++++++++++++++++-------- 1 file changed, 136 insertions(+), 42 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 411149aa35..4a071ab00f 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2563,17 +2563,53 @@ def D2_ii_ij_jj_unbiased(X, n): return [ii, ij, jj] +class GeneralStatNormFuncs: + @staticmethod + def hap_norm(X, n, nA, nB): + """Stat from 1 sample set -> 1 result""" + return X[0] / n + + @staticmethod + def k_way_hap_norm(X, n, nA, nB): + """Stat from k sample sets -> 1 result""" + return X[0].sum(keepdims=True) / n.sum() + + @staticmethod + def assert_no_norm_func(*_): + """Normalisation is not required in branch mode and with biallelic sites.""" + raise Exception("Normalisation function should not be called") + + @classmethod + def choose(cls, stat, mode, ts): + """ + Choose norm function based on stat, mode, presence of multiallelic sites + """ + is_multiallelic = max({len(s.mutations) for s in ts.sites()}) > 1 + match (stat, mode, is_multiallelic): + case ("r2", "site", True): + return cls.hap_norm + case ("r2_ij", "site", True): + return cls.k_way_hap_norm + case (_, "branch", _): # branch stats do not need a norm func + return cls.assert_no_norm_func + case (_, _, False): # biallelic sites do not need a norm func + return cls.assert_no_norm_func + case _: # total_norm is default (1 / (nA * nB)). handles multi-way stats + return None + + @pytest.fixture(scope="module") -def ts_100_samp_with_sites_fixture(): +def ts_10_samp_with_sites_fixture(): ts = tsutil.get_sim_example( - sample_size=100, - sequence_length=32, - recombination_rate=0.5, + sample_size=10, + sequence_length=15, + recombination_rate=0.1, mutation_rate=0.1, seed=123, ) assert ts.num_sites > 0, "sites are required" - assert ts.num_samples == 100, "100 samples are required" + assert ts.num_samples == 10 # Samples directly indexed in tests below + assert max({len(s.mutations) for s in ts.sites()}) == 1, "sites must be biallelic" return ts @@ -2588,50 +2624,109 @@ def ts_multiallelic_fixture(): ) # Need at least 4 samples to test unbiased statistics assert ts.num_samples >= 4, "At least 4 samples required" - assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + assert max({len(s.mutations) for s in ts.sites()}) > 1, ( "At least one multiallelic site required" ) return ts -@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_two_locus_site_stat(stat, ts_100_samp_with_sites_fixture): - ts = ts_100_samp_with_sites_fixture - sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] +@pytest.fixture(scope="module") +def ts_no_sites_fixture(): + ts = msprime.sim_ancestry( + 2, recombination_rate=0.1, sequence_length=100, random_seed=123 + ) + assert ts.num_sites == 0 + return ts + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize( + "ts", + [ts for ts in get_example_tree_sequences() if ts.id in {"no_samples", "empty_ts"}], +) +def test_general_empty_ts(mode, ts): + with pytest.raises(ValueError, match="at least one element"): + ts.two_locus_count_stat([ts.samples()], GeneralStatFuncs.D, 1, mode=mode) + + +def test_general_no_sites(ts_no_sites_fixture): + ts = ts_no_sites_fixture + ldg = ts.two_locus_count_stat([ts.samples()], GeneralStatFuncs.D, 1) + np.testing.assert_array_equal(ldg, np.zeros((1, 0, 0), np.float64)) + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +def test_general_output_dimensions(mode, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + norm_f = GeneralStatNormFuncs.choose("D", mode, ts) + samples = ts.samples() + expected_dims = dict( + site=(1, ts.num_sites, ts.num_sites), branch=(1, ts.num_trees, ts.num_trees) + )[mode] + result = ts.two_locus_count_stat( + samples, GeneralStatFuncs.D, 1, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + # we expect that dims are the same with `samples` or `[samples]` + result = ts.two_locus_count_stat( + [samples], GeneralStatFuncs.D, 1, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + + expected_dims = dict( + site=(2, ts.num_sites, ts.num_sites), branch=(2, ts.num_trees, ts.num_trees) + )[mode] + result = ts.two_locus_count_stat( + [samples, samples], GeneralStatFuncs.D, 2, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims - # In addition to not needing a normalisation function, normalisation is also - # not required because these sites are biallelic. - def assert_no_norm_func(*_): - raise Exception( - "Normalisation function should not be called for biallelic sites" - ) +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_multi_sample_set(mode, stat, ts_10_samp_with_sites_fixture): + ts = ts_10_samp_with_sites_fixture + norm_f = GeneralStatNormFuncs.choose(stat, mode, ts) + sample_sets = [ts.samples()[0:5], ts.samples()[5:10]] ldg = ts.two_locus_count_stat( - sample_sets, getattr(GeneralStatFuncs, stat), 2, norm_f=assert_no_norm_func + sample_sets, + getattr(GeneralStatFuncs, stat), + 2, + norm_f=norm_f, + mode=mode, ) - ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat, mode=mode) np.testing.assert_array_almost_equal(ldg, ld) +@pytest.mark.parametrize("mode", ["site", "branch"]) @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) -def test_general_two_locus_two_way_site_stat(stat, ts_100_samp_with_sites_fixture): - ts = ts_100_samp_with_sites_fixture - sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] - ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) +def test_general_two_way(mode, stat, ts_10_samp_with_sites_fixture): + ts = ts_10_samp_with_sites_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_f = GeneralStatNormFuncs.choose(stat, mode, ts) + sample_sets = [ts.samples()[0:5], ts.samples()[5:10]] + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_f, mode=mode) ld = ts.ld_matrix( - sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=[(0, 1)] + sample_sets=sample_sets, + stat=stat.replace("_ij", ""), + indexes=[(0, 1)], + mode=mode, ) np.testing.assert_array_almost_equal(ldg, ld) +# NB: multiallelic testing only needed for sites. branches are biallelic. + + @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): +def test_general_one_way_multiallelic(stat, ts_multiallelic_fixture): ts = ts_multiallelic_fixture general_func = getattr(GeneralStatFuncs, stat) - norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) polarised = POLARIZATION[SUMMARY_FUNCS[stat]] ldg = ts.two_locus_count_stat( - [ts.samples()], general_func, 1, norm_f=norm_func, polarised=polarised + [ts.samples()], general_func, 1, norm_f=norm_f, polarised=polarised ) ld = ts.ld_matrix(stat=stat) # ld_matrix drops dims, expand for comparison @@ -2639,32 +2734,26 @@ def test_general_one_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu @pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) -def test_general_one_way_two_locus_stat_multiallelic_multi_sample_set( - stat, ts_multiallelic_fixture -): +def test_general_one_way_multiallelic_multi_sample_set(stat, ts_multiallelic_fixture): ts = ts_multiallelic_fixture general_func = getattr(GeneralStatFuncs, stat) - norm_func = (lambda X, n, nA, nB: X[0] / n) if stat == "r2" else None + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) polarised = POLARIZATION[SUMMARY_FUNCS[stat]] sample_sets = [ts.samples(), ts.samples()] ldg = ts.two_locus_count_stat( - sample_sets, general_func, 2, norm_f=norm_func, polarised=polarised + sample_sets, general_func, 2, norm_f=norm_f, polarised=polarised ) ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets) np.testing.assert_array_almost_equal(ldg, ld) @pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) -def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixture): +def test_general_two_way_multiallelic(stat, ts_multiallelic_fixture): ts = ts_multiallelic_fixture general_func = getattr(GeneralStatFuncs, stat) - norm_func = ( - (lambda X, n, nA, nB: X[0].sum(keepdims=True) / n.sum()) - if stat == "r2_ij" - else None - ) + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) sample_sets = [ts.samples(), ts.samples()] - ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_func) + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_f) ld = ts.ld_matrix( stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets ) @@ -2672,10 +2761,11 @@ def test_general_two_way_two_locus_stat_multiallelic(stat, ts_multiallelic_fixtu np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) -def test_general_two_locus_multi_outputs(): +@pytest.mark.parametrize("mode", ["site", "branch"]) +def test_general_multi_outputs(mode): ts = msprime.sim_mutations( msprime.sim_ancestry( - 4, recombination_rate=0.1, sequence_length=100, random_seed=123 + 4, recombination_rate=0.1, sequence_length=35, random_seed=123 ), rate=0.1, random_seed=123, @@ -2687,6 +2777,10 @@ def test_general_two_locus_multi_outputs(): A = ts.samples()[0:4] B = ts.samples()[4:] - ldg = ts.two_locus_count_stat([A, B], GeneralStatFuncs.D2_ii_ij_jj_unbiased, 3) - ld = ts.ld_matrix([A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)]) + norm_f = GeneralStatNormFuncs.choose("D2_unbiased", mode, ts) + general_func = GeneralStatFuncs.D2_ii_ij_jj_unbiased + ldg = ts.two_locus_count_stat([A, B], general_func, 3, mode=mode, norm_f=norm_f) + ld = ts.ld_matrix( + [A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)], mode=mode + ) np.testing.assert_array_almost_equal(ldg, ld) From f854aa12d1735f699bddb4d6d4399c5bdbbd4576 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 17:13:01 -0500 Subject: [PATCH 27/33] Add some minimal documentation about the purpose of the two entrypoints --- c/tskit/trees.c | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index e567ef1acd..515162c157 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3384,6 +3384,7 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s return ret; } +/* Called directly by C python interface `two_locus_count_stat` */ int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, @@ -3439,7 +3440,6 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - goto out; } else { tsk_bug_assert(stat_branch); ret = check_positions( @@ -3460,6 +3460,7 @@ tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, return ret; } +/* Called by summary functions implemented in C */ int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, From 8c49863252d30fb000cfc6d8d1a1c95f2c6ef5c2 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 17:13:41 -0500 Subject: [PATCH 28/33] Test explicitly that our internal data is read only --- python/tests/test_python_c.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 9f8822b2f6..3f7546e4c8 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2127,6 +2127,13 @@ def norm_func(*_): method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) with pytest.raises(ValueError, match="could not convert string to float"): method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="assignment destination is read-only"): + + def bad_stat_func(X, n): + X[0] = [1] + return [1] + + method(ss_sizes, ss, bad_stat_func, norm_func, 1, True, *site_args) # Exceptions within stat_func are correctly raised. for exception in [ValueError, TypeError]: From 37b14209c10b55ea4a2d9b7a7ffe1144ae8afa29 Mon Sep 17 00:00:00 2001 From: lkirk Date: Fri, 20 Mar 2026 17:20:21 -0500 Subject: [PATCH 29/33] Fix memory leak; more readonly arrays The transpose operation was creating intermediate data that was not being garbage collected, resulting in a rather obvious memory leak. To mitigate this, I opt to wrap the data in a numpy array that is already transposed. The original data is natively laid out with shape (K,3), by creating a numpy array with shape (3,K) and strides (8,8*K), we can avoid an intermediate transpose operation altogether. After leak-checking again, the memory leak is gone. I also add a Py_XDECREF to remove the reference to `norm_func`, though in my leak checking I don't actually see any difference in RSS or heap size. Finally, mark a few more arrays as read-only, since the C functions that accept them as input annotate these arrays as `const`. --- python/_tskitmodule.c | 48 +++++++++++++++++++------------------------ 1 file changed, 21 insertions(+), 27 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index c772270de6..fbb148097f 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7916,6 +7916,7 @@ parse_sites(TreeSequence *self, PyObject *sites, npy_intp *out_dim) if (array == NULL) { goto out; } + PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE); *out_dim = PyArray_DIM(array, 0); } @@ -7940,6 +7941,7 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) if (array == NULL) { goto out; } + PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE); *out_dim = PyArray_DIM(array, 0); } out: @@ -7966,17 +7968,13 @@ general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { K, 3 }; + npy_intp X_dims[2] = { 3, K }; + npy_intp X_strides[2] = { sizeof(double), sizeof(double) * 3 }; - // Create a read only view of X as a numpy array - X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); - if (X_array == NULL) { - goto out; - } - PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - // Transpose into column arrays, so that we can easily decompose the results - X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + // Create a read only view of X as a numpy array. X is transposed from its + // native memory layout (K, 3) -> (3, K). More detailed comment below. + X_array = (PyArrayObject *) PyArray_New(&PyArray_Type, 2, X_dims, NPY_FLOAT64, + X_strides, (void *) X, -1, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL); if (X_array == NULL) { goto out; } @@ -8041,21 +8039,17 @@ general_two_locus_count_stat_func( two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->summary_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { K, 3 }; - - // Create a read only view of X as a numpy array - X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); - if (X_array == NULL) { - goto out; - } - PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - // Transpose into column arrays, so that we can easily decompose the results - // For example: pAB, pAb, paB = X / n - // which works with K>1. In addition, the data is not reordered, meaning - // that the data is still oriented where samples are rows, meaning that - // we'll preserve data locality in ops over samples. - X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + npy_intp X_dims[2] = { 3, K }; + npy_intp X_strides[2] = { sizeof(double), sizeof(double) * 3 }; + + // Create a transposed, read only view of X as a numpy array. The native + // memory layout of X is (K, 3), we wrap it in a numpy array with dimensions + // (3, K), creating row arrays of haplotype counts so that we can easily + // decompose the results. For example: `pAB, pAb, paB = X / n` which works + // with K>1. Itemsize is -1 because we specify the dtype. NB: we do not set + // NPY_ARRAY_WRITEABLE, so X_array is read only. + X_array = (PyArrayObject *) PyArray_New(&PyArray_Type, 2, X_dims, NPY_FLOAT64, + X_strides, (void *) X, -1, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL); if (X_array == NULL) { goto out; } @@ -8074,8 +8068,7 @@ general_two_locus_count_stat_func( } if (PyArray_NDIM(Y_array) != 1) { PyErr_Format(PyExc_ValueError, - "Array returned by summary function callback is %d dimensional; " - "must be 1D", + "Array returned by summary function callback is %d dimensional; must be 1D", (int) PyArray_NDIM(Y_array)); goto out; } @@ -8220,6 +8213,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * result_matrix = NULL; out: Py_XDECREF(summary_func); + Py_XDECREF(norm_func); Py_XDECREF(row_sites_array); Py_XDECREF(col_sites_array); Py_XDECREF(row_positions_array); From 279cc72ad6562efcea1b0f953ebe6bedf8c2cf53 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sat, 21 Mar 2026 11:11:06 -0500 Subject: [PATCH 30/33] Return on summary function error (bug) `compute_two_tree_branch_stat` did not check the error returned by the summary function (which is the return value of `compute_two_tree_branch_state_update`.). I caught this because the python callback was setting an exception in the summary function and the python runtime was complaining about an exception being set, despite a successful return status. This also means that failing summary functions would (eventually) be caught, but the code would continue to run. The C python tests did not catch this because we would eventually raise the correct exception. --- c/tskit/trees.c | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 515162c157..3be4073971 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -3188,8 +3188,11 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + ret = compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, &work, result); + if (ret != 0) { + goto out; + } } // Remove samples under nodes from removed edges to parent nodes for (j = 0; j < r_state->n_edges_out; j++) { @@ -3229,8 +3232,11 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + ret = compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, &work, result); + if (ret != 0) { + goto out; + } } out: tsk_safe_free(updated_nodes); From 7378e42724a5db53950f1dc79c67026feeb2ed57 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 22 Mar 2026 15:15:29 -0500 Subject: [PATCH 31/33] Incorporate Peter's improvement to the test comments --- python/tests/test_python_c.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 3f7546e4c8..09370a9a78 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -2042,7 +2042,7 @@ def norm_func(*_): assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) - # happy path - default array filling + # happy path - default values for site and position lists a = method( ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" ) From 2fa87cd07765e37f8db927940c7a82b20bda822c Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 22 Mar 2026 16:27:57 -0500 Subject: [PATCH 32/33] Update docstring (feedback from Peter) Provide more precise requirements for `f` and `norm_f` and give some basic understanding of what these functions are and when normalisation will be required. Attempt to fix syntax errors by adding a newline in code blocks. Clarify output dimensions in the code comments (though this might need to change since I think we'll remove the `np.moveaxis` call at the end. --- python/tskit/trees.py | 67 ++++++++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9af49ddb88..b379e7d2ad 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10945,42 +10945,47 @@ def two_locus_count_stat( """ Compute two-locus statistics with a user-defined python function that operates on haplotype counts. TODO: reference modes in two-locus docs. - On each pair of sites or trees, the summary function is provided with - ``X``, a matrix with shape (3, k) and ``n``, a vector with shape (k,), - where k is the number of sample sets provided. ``X`` is a read-only - matrix whose rows contain haplotype counts per sample set (counts of AB, - Ab, aB) and ``n`` is a vector of sample set sizes. - - .. note:: - Because we are operating on very small matrices/vectors, vectorised - operations are often times slower than operations on scalars. Simply - returning ``[value]`` can be faster than returning - ``value[np.newaxis,]`` or ``np.expand_dims(value, 0)``. - - What follows is an example of computing ``D`` from a tree sequence. Many - more examples can be found in the test suite - ``test_ld_matrix.py::GeneralStatsFuncs``. Let's begin with our summary - function, ``D``. We convert counts to proportions, then compute ``D``, - returning a numpy array with length equal to the number of sample sets. + On each pair of sites or trees, the summary function is called with + haplotype counts for all provided sample sets. The summary function + (``f``) must accept two parameters: ``X``, a matrix with shape (3, k) + and ``n``, a vector with shape (k,), where k is the number of sample + sets provided. ``X`` is a read-only matrix whose rows contain haplotype + counts (AB, Ab, aB) per sample set and ``n`` is a vector of sample set + sizes. ``f`` must return a list of results with length ``result_dim``. + + What follows is an example of computing ``D`` from a tree sequence + (TODO: cite two-locus docs for more details). We convert counts to + proportions, then compute ``D``, returning a numpy array with length + equal to the number of ``result_dim``s. .. code-block:: python + def D(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB return pAB - (pA * pB) - ``norm_f`` is a normalisation function used to combine all computed - statistics for multiallelic allele pairs (TODO: see two-locus - docs). Biallelic sites do not require any normalisation (in fact, the - normalisation function is never called for biallelic sites). If one of - either site A or site B is multiallelic, then the normalisation function - will be called. The default normalisation function is identical to - ``total_norm`` shown in the example below. ``hap_norm`` is required for - normalising :math:`r^2`. Both of these examples return a numpy array - with length equal to the number of sample sets (for one-way stats). + The summary function is called for each pair of sites or trees, + producing results that must be combined when multiallelic sites are + present (``site`` mode only), summary function results must + need to be normalised in order to be aggragated for all pairs of alleles + between both sites. Branch statistics and biallelic sites do not require + any normalisation, ``norm_f`` is only called if one of the two sites + under consideration is multiallelic. TODO: reference two-locus docs for + further information about normalisation. ``norm_f`` is a normalisation + function that must accept four parameters: ``X`` and ``n`` are the same + inputs that ``f`` accepts, along with ``nA`` and ``nB``, which hold the + count of ``A`` alleles and ``B`` alleles. For example, if ``A`` is + biallelic and ``B`` is triallelic, ``nA=2`` and ``nB=3``. ``f`` must + return a list of results with length ``result_dim``. The default + normalisation function is identical to ``total_norm`` shown in the + example below. ``hap_norm`` is required for normalising + :math:`r^2`. Both of these examples return a numpy array with length + equal to the number of ``result_dim``s. .. code-block:: python + def total_norm(X, n, nA, nB): [1 / (nA * nB)] * result_dim @@ -10990,6 +10995,7 @@ def hap_norm(X, n, nA, nB): A simple call (without specifying normalisation) would look like this .. code-block::python + ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True) :param list sample_sets: A list of lists of Node IDs, specifying the @@ -11000,7 +11006,8 @@ def hap_norm(X, n, nA, nB): :param int result_dim: The length of ``f`` and ``norm_f``'s return value. :param norm_f: A function that takes four arguments - the first two are the same as ``f``, the second two are scalars representing the - number of A and B alleles, respectively. + number of A and B alleles, respectively. If ``None``, then defaults + to the "total" normalization described above. :param bool polarised: Whether to leave the ancestral state out of computations: see :ref:`sec_stats` for more details. :param list sites: TODO: two-locus docs @@ -11008,14 +11015,14 @@ def hap_norm(X, n, nA, nB): :param str mode: A string giving the "type" of the statistic to be computed (defaults to "site"). :return: A ndarray with shape equal to (TODO: reference two-locus docs, - no dimension dropping shape=(k, m, m) where k=num_sample_sets, + no dimension dropping shape=(k, m, m) where k=result_dim, m=num_sites or num_trees). """ row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) if norm_f is None: - # produce the same number of dims as output dimensions with [val] * dim + # produce the same number of dims as result dimensions with [val] * dim def norm_f(X, n, nA, nB): return [1 / (nA * nB)] * result_dim @@ -11032,7 +11039,7 @@ def norm_f(X, n, nA, nB): col_positions, mode, ) - # Orient the data so that the first dimension is the sample set so that + # Orient the data so that the first dimension is the result_dim so that # we get one LD matrix per result dimension return np.moveaxis(result, -1, 0) From 824faaa650acbf7ad6b05aa40b2e93dd966338b4 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 22 Mar 2026 17:48:49 -0500 Subject: [PATCH 33/33] turns out the documentation build doesn't like ``result_dim``s changing to ``result_dim`` fixes the docs. --- python/tskit/trees.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index b379e7d2ad..bb69d70c3a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10956,7 +10956,7 @@ def two_locus_count_stat( What follows is an example of computing ``D`` from a tree sequence (TODO: cite two-locus docs for more details). We convert counts to proportions, then compute ``D``, returning a numpy array with length - equal to the number of ``result_dim``s. + equal to the number of ``result_dim``. .. code-block:: python @@ -10982,7 +10982,7 @@ def D(X, n): normalisation function is identical to ``total_norm`` shown in the example below. ``hap_norm`` is required for normalising :math:`r^2`. Both of these examples return a numpy array with length - equal to the number of ``result_dim``s. + equal to the number of ``result_dim``. .. code-block:: python