diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 1aa06e5b03..3be4073971 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2298,8 +2298,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset, } static int -norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2315,8 +2316,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, } static int -norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2341,8 +2343,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), - tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) +norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim), + const double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; double norm = 1 / (double) (n_a * n_b); @@ -2411,8 +2414,8 @@ static int compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim, - tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, - norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) + tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, + bool polarised, two_locus_work_t *restrict work, double *result) { int ret = 0; // Sample sets and b sites are rows, a sites are columns @@ -2445,7 +2448,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, if (ret != 0) { goto out; } - ret = norm_f(result_dim, weights, num_a_alleles - is_polarised, + ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised, num_b_alleles - is_polarised, norm, f_params); if (ret != 0) { goto out; @@ -2463,9 +2466,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, static int compute_general_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, - tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; tsk_size_t k; @@ -2653,9 +2655,8 @@ static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, - const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites, + tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { int ret = 0; tsk_bitset_t allele_samples, allele_sample_sets; @@ -3089,9 +3090,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index) static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, - tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; double a_len, b_len; @@ -3141,8 +3141,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, static int compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, - iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, - tsk_size_t result_dim, tsk_size_t state_dim, double *result) + iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim, + tsk_size_t state_dim, double *result) { int ret = 0; tsk_id_t e, c, ec, p, *updated_nodes = NULL; @@ -3188,8 +3188,11 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + ret = compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, &work, result); + if (ret != 0) { + goto out; + } } // Remove samples under nodes from removed edges to parent nodes for (j = 0; j < r_state->n_edges_out; j++) { @@ -3229,8 +3232,11 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + ret = compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, &work, result); + if (ret != 0) { + goto out; + } } out: tsk_safe_free(updated_nodes); @@ -3243,9 +3249,9 @@ static int tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), - tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, - const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) + void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows, + const double *row_positions, tsk_size_t n_cols, const double *col_positions, + tsk_flags_t TSK_UNUSED(options), double *result) { int ret = 0; int r, c; @@ -3384,11 +3390,12 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s return ret; } +/* Called directly by C python interface `two_locus_count_stat` */ int -tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, +tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { @@ -3398,10 +3405,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); tsk_size_t state_dim = num_sample_sets; - sample_count_stat_params_t f_params = { .sample_sets = sample_sets, - .num_sample_sets = num_sample_sets, - .sample_set_sizes = sample_set_sizes, - .set_indexes = set_indexes }; // We do not support two-locus node stats if (!!(options & TSK_STAT_NODE)) { @@ -3441,9 +3444,10 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); - } else if (stat_branch) { + } else { + tsk_bug_assert(stat_branch); ret = check_positions( row_positions, out_rows, tsk_treeseq_get_sequence_length(self)); if (ret != 0) { @@ -3455,13 +3459,31 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_positions, out_cols, col_positions, options, result); } out: return ret; } +/* Called by summary functions implemented in C */ +int +tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) +{ + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; + return tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + row_sites, row_positions, out_cols, col_sites, col_positions, options, result); +} + /*********************************** * Allele frequency spectrum ***********************************/ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 84480ed96e..2bf1a26cc9 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1036,8 +1036,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); +typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, + tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params); int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, @@ -1120,6 +1120,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 0e0c1c5ed5..fbb148097f 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7916,6 +7916,7 @@ parse_sites(TreeSequence *self, PyObject *sites, npy_intp *out_dim) if (array == NULL) { goto out; } + PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE); *out_dim = PyArray_DIM(array, 0); } @@ -7940,12 +7941,289 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) if (array == NULL) { goto out; } + PyArray_CLEARFLAGS(array, NPY_ARRAY_WRITEABLE); *out_dim = PyArray_DIM(array, 0); } out: return array; } +typedef struct { + PyArrayObject *sample_set_sizes; + PyObject *summary_func; + PyObject *norm_func; +} two_locus_general_stat_params; + +static int +general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim, + tsk_size_t n_a, tsk_size_t n_b, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *n_a_scalar = NULL; + PyArrayObject *n_b_scalar = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + two_locus_general_stat_params *tl_params = params; + PyObject *summary_func = tl_params->norm_func; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + npy_intp X_dims[2] = { 3, K }; + npy_intp X_strides[2] = { sizeof(double), sizeof(double) * 3 }; + + // Create a read only view of X as a numpy array. X is transposed from its + // native memory layout (K, 3) -> (3, K). More detailed comment below. + X_array = (PyArrayObject *) PyArray_New(&PyArray_Type, 2, X_dims, NPY_FLOAT64, + X_strides, (void *) X, -1, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL); + if (X_array == NULL) { + goto out; + } + n_a_scalar + = (PyArrayObject *) PyArray_Scalar(&n_a, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_a_scalar == NULL) { + goto out; + } + n_b_scalar + = (PyArrayObject *) PyArray_Scalar(&n_b, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_b_scalar == NULL) { + goto out; + } + arglist = Py_BuildValue("OOOO", X_array, ss_sizes, n_a_scalar, n_b_scalar); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(summary_func, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + Py_XDECREF(n_a_scalar); + Py_XDECREF(n_b_scalar); + return ret; +} + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t result_dim, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + two_locus_general_stat_params *tl_params = params; + PyObject *summary_func = tl_params->summary_func; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + npy_intp X_dims[2] = { 3, K }; + npy_intp X_strides[2] = { sizeof(double), sizeof(double) * 3 }; + + // Create a transposed, read only view of X as a numpy array. The native + // memory layout of X is (K, 3), we wrap it in a numpy array with dimensions + // (3, K), creating row arrays of haplotype counts so that we can easily + // decompose the results. For example: `pAB, pAb, paB = X / n` which works + // with K>1. Itemsize is -1 because we specify the dtype. NB: we do not set + // NPY_ARRAY_WRITEABLE, so X_array is read only. + X_array = (PyArrayObject *) PyArray_New(&PyArray_Type, 2, X_dims, NPY_FLOAT64, + X_strides, (void *) X, -1, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_ALIGNED, NULL); + if (X_array == NULL) { + goto out; + } + arglist = Py_BuildValue("OO", X_array, ss_sizes); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(summary_func, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by summary function callback is %d dimensional; must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { + PyErr_Format(PyExc_ValueError, + "Array returned by summary function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + return ret; +} + +static PyObject * +TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", + "norm_func", "output_dim", "polarised", "row_sites", "col_sites", + "row_positions", "column_positions", "mode", NULL }; + two_locus_general_stat_params *params; + PyObject *summary_func = NULL; + PyObject *norm_func = NULL; + unsigned int output_dim; + PyObject *sample_set_sizes = NULL; + PyObject *sample_sets = NULL; + PyObject *row_sites = NULL; + PyObject *col_sites = NULL; + PyObject *row_positions = NULL; + PyObject *col_positions = NULL; + char *mode = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *row_sites_array = NULL; + PyArrayObject *col_sites_array = NULL; + PyArrayObject *row_positions_array = NULL; + PyArrayObject *col_positions_array = NULL; + PyArrayObject *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL; + tsk_id_t *col_sites_parsed = NULL; + double *row_positions_parsed = NULL; + double *col_positions_parsed = NULL; + npy_intp result_dim[3] = { 0, 0, 0 }; + tsk_size_t num_sample_sets; + tsk_flags_t options = 0; + int polarised = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &norm_func, &output_dim, + &polarised, &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + Py_XINCREF(summary_func); + Py_XINCREF(norm_func); + goto out; + } + Py_INCREF(summary_func); + Py_INCREF(norm_func); + if (!PyCallable_Check(summary_func)) { + PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); + goto out; + } + if (!PyCallable_Check(norm_func)) { + PyErr_SetString(PyExc_TypeError, "norm_func must be callable"); + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + PyArray_CLEARFLAGS(sample_set_sizes_array, NPY_ARRAY_WRITEABLE); + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); + } + + result_dim[2] = output_dim; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + + params = &(two_locus_general_stat_params) { + .sample_set_sizes = sample_set_sizes_array, + .summary_func = summary_func, + .norm_func = norm_func, + }; + err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + output_dim, general_two_locus_count_stat_func, params, + general_two_locus_norm_func, result_dim[0], row_sites_parsed, + row_positions_parsed, result_dim[1], col_sites_parsed, col_positions_parsed, + options, PyArray_DATA(result_matrix)); + + if (err == TSK_PYTHON_CALLBACK_ERROR) { + goto out; + } else if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(summary_func); + Py_XDECREF(norm_func); + Py_XDECREF(row_sites_array); + Py_XDECREF(col_sites_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_matrix); + return ret; +} + static PyObject * TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_count_stat_method *method) @@ -8831,6 +9109,11 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_general_stat, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Runs the general stats algorithm for a given summary function." }, + { .ml_name = "two_locus_count_stat", + .ml_meth = (PyCFunction) TreeSequence_two_locus_count_stat, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc + = "Runs the general two locus stats algorithm for a given summary function." }, { .ml_name = "diversity", .ml_meth = (PyCFunction) TreeSequence_diversity, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 4d6e47ddcc..4a071ab00f 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2063,38 +2063,36 @@ def compute_branch_stat( ts for ts in get_example_tree_sequences() if ts.id - not in { - "no_samples", - "empty_ts", - # We must skip these cases so that tests run in a reasonable - # amount of time. To get more complete testing, these filters - # can be commented out. (runtime ~1hr) - "gap_0", - "gap_0.1", - "gap_0.5", - "gap_0.75", - "n=2_m=32_rho=0", - "n=10_m=1_rho=0", - "n=10_m=1_rho=0.1", - "n=10_m=2_rho=0", - "n=10_m=2_rho=0.1", - "n=10_m=32_rho=0", - "n=10_m=32_rho=0.1", - "n=10_m=32_rho=0.5", + in { + # We run only these cases so that tests run in a reasonable + # amount of time. All examples takes ~1hr. + "decapitate_recomb", + "gap_at_end", + "all_nodes_samples", + "internal_nodes_samples", + "mixed_internal_leaf_samples", + "bottleneck_n=3_mutated", + "bottleneck_n=10_mutated", + "rev_node_order", + "empty_tree", + "n=3_m=2_rho=0.5", + "n=3_m=32_rho=0", + "n=3_m=32_rho=0.1", + "n=2_m=1_rho=0", + "n=2_m=1_rho=0.1", + "n=2_m=1_rho=0.5", + "n=2_m=2_rho=0", + "n=2_m=2_rho=0.1", + "n=2_m=2_rho=0.5", + "n=2_m=32_rho=0.1", + "n=2_m=32_rho=0.5", + "n=3_m=1_rho=0", + "n=3_m=1_rho=0.5", + "n=3_m=2_rho=0", + "n=10_m=1_rho=0.5", + "n=10_m=2_rho=0.5", # we keep one n=100 case to ensure bit arrays are working - "n=100_m=1_rho=0.1", - "n=100_m=1_rho=0.5", - "n=100_m=2_rho=0", - "n=100_m=2_rho=0.1", - "n=100_m=2_rho=0.5", - "n=100_m=32_rho=0", - "n=100_m=32_rho=0.1", - "n=100_m=32_rho=0.5", - "all_fields", - "back_mutations", - "multichar", - "multichar_no_metadata", - "bottleneck_n=100_mutated", + "n=100_m=1_rho=0", } ], ) @@ -2398,3 +2396,391 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params) np.testing.assert_allclose((result * norm).sum(), expected) + + +class GeneralStatFuncs: + """ + Summary functions take X, n as parameters where X is a matrix of haplotype + counts per sample set and n is a vector of sample set sizes. X has shape (3, k) + and n has shape (k, ), where k is the number of sample sets. The rows of X + contain haplotype counts for AB, Ab, aB (capitalized == derived). + + X: shape=(3, k) + sample sets + count AB [[ #ss1, #ss2, ... ] + count Ab [ #ss1, #ss2, ... ] + count aB [ #ss1, #ss2, ... ]] + + n: shape=(k, ) + [ #ss1, #ss2, ... ] + """ + + @staticmethod + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + @staticmethod + def D2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return (pAB - (pA * pB)) ** 2 + + @staticmethod + def r2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D**2 / denom + + @staticmethod + def r(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D / np.sqrt(denom) + + @staticmethod + def D_prime(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = np.vstack( + [ + np.min([pA * (1 - pB), (1 - pA) * pB], axis=0), + np.min([pA * pB, (1 - pA) * (1 - pB)], axis=0), + ] + ) + with suppress_overflow_div0_warning(): + return D / denom[(D < 0).astype(int), range(len(D))] + + @staticmethod + def Dz(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return D * (1 - 2 * pA) * (1 - 2 * pB) + + @staticmethod + def pi2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pA * (1 - pA) * pB * (1 - pB) + + @staticmethod + def D2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return ( + ((aB**2) * (Ab - 1) * Ab) + + ((ab - 1) * ab * (AB - 1) * AB) + - (aB * Ab * (Ab + (2 * ab * AB) - 1)) + ) / (n * (n - 1) * (n - 2) * (n - 3)) + + @staticmethod + def Dz_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return ( + (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) + - ((AB * ab) * (AB + ab - Ab - aB - 2)) + - ((Ab * aB) * (Ab + aB - AB - ab - 2)) + ) / (n * (n - 1) * (n - 2) * (n - 3)) + + @staticmethod + def pi2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return ( + ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) + - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) + - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) + ) / (n * (n - 1) * (n - 2) * (n - 3)) + + # Two-way statistics have the _ij suffix. + @staticmethod + def r2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + # keepdims preserves the output shape of (1, ) + D2_ij = np.prod(pAB - (pA * pB), keepdims=True) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB)), keepdims=True) + with suppress_overflow_div0_warning(): + return D2_ij / denom + + @staticmethod + def D2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return np.prod(pAB - (pA * pB), keepdims=True) + + @staticmethod + def D2_ij_unbiased(X, n): + """The identity of the sample sets is up to the user.""" + AB, Ab, aB = X + ab = n - X.sum(0) + return [ + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1)) + ] + + @staticmethod + def D2_ii_ij_jj_unbiased(X, n): + """ + Multiple stats can be computed from the same data. The identity of the + sample sets is up to the user. This function assumes two sample sets. + """ + AB, Ab, aB = X + ab = n - X.sum(0) + + # unbiased estimator for equal sample sets + ii, jj = ( + AB * (AB - 1) * ab * (ab - 1) + + Ab * (Ab - 1) * aB * (aB - 1) + - 2 * AB * Ab * aB * ab + ) / (n * (n - 1) * (n - 2) * (n - 3)) + # unbiased estimator for disjoint sample sets + ij = ( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / (n[0] * (n[0] - 1) * n[1] * (n[1] - 1)) + ) + return [ii, ij, jj] + + +class GeneralStatNormFuncs: + @staticmethod + def hap_norm(X, n, nA, nB): + """Stat from 1 sample set -> 1 result""" + return X[0] / n + + @staticmethod + def k_way_hap_norm(X, n, nA, nB): + """Stat from k sample sets -> 1 result""" + return X[0].sum(keepdims=True) / n.sum() + + @staticmethod + def assert_no_norm_func(*_): + """Normalisation is not required in branch mode and with biallelic sites.""" + raise Exception("Normalisation function should not be called") + + @classmethod + def choose(cls, stat, mode, ts): + """ + Choose norm function based on stat, mode, presence of multiallelic sites + """ + is_multiallelic = max({len(s.mutations) for s in ts.sites()}) > 1 + match (stat, mode, is_multiallelic): + case ("r2", "site", True): + return cls.hap_norm + case ("r2_ij", "site", True): + return cls.k_way_hap_norm + case (_, "branch", _): # branch stats do not need a norm func + return cls.assert_no_norm_func + case (_, _, False): # biallelic sites do not need a norm func + return cls.assert_no_norm_func + case _: # total_norm is default (1 / (nA * nB)). handles multi-way stats + return None + + +@pytest.fixture(scope="module") +def ts_10_samp_with_sites_fixture(): + ts = tsutil.get_sim_example( + sample_size=10, + sequence_length=15, + recombination_rate=0.1, + mutation_rate=0.1, + seed=123, + ) + assert ts.num_sites > 0, "sites are required" + assert ts.num_samples == 10 # Samples directly indexed in tests below + assert max({len(s.mutations) for s in ts.sites()}) == 1, "sites must be biallelic" + return ts + + +@pytest.fixture(scope="module") +def ts_multiallelic_fixture(): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + 2, recombination_rate=0.1, sequence_length=100, random_seed=123 + ), + rate=0.1, + random_seed=123, + ) + # Need at least 4 samples to test unbiased statistics + assert ts.num_samples >= 4, "At least 4 samples required" + assert max({len(s.mutations) for s in ts.sites()}) > 1, ( + "At least one multiallelic site required" + ) + return ts + + +@pytest.fixture(scope="module") +def ts_no_sites_fixture(): + ts = msprime.sim_ancestry( + 2, recombination_rate=0.1, sequence_length=100, random_seed=123 + ) + assert ts.num_sites == 0 + return ts + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize( + "ts", + [ts for ts in get_example_tree_sequences() if ts.id in {"no_samples", "empty_ts"}], +) +def test_general_empty_ts(mode, ts): + with pytest.raises(ValueError, match="at least one element"): + ts.two_locus_count_stat([ts.samples()], GeneralStatFuncs.D, 1, mode=mode) + + +def test_general_no_sites(ts_no_sites_fixture): + ts = ts_no_sites_fixture + ldg = ts.two_locus_count_stat([ts.samples()], GeneralStatFuncs.D, 1) + np.testing.assert_array_equal(ldg, np.zeros((1, 0, 0), np.float64)) + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +def test_general_output_dimensions(mode, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + norm_f = GeneralStatNormFuncs.choose("D", mode, ts) + samples = ts.samples() + expected_dims = dict( + site=(1, ts.num_sites, ts.num_sites), branch=(1, ts.num_trees, ts.num_trees) + )[mode] + result = ts.two_locus_count_stat( + samples, GeneralStatFuncs.D, 1, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + # we expect that dims are the same with `samples` or `[samples]` + result = ts.two_locus_count_stat( + [samples], GeneralStatFuncs.D, 1, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + + expected_dims = dict( + site=(2, ts.num_sites, ts.num_sites), branch=(2, ts.num_trees, ts.num_trees) + )[mode] + result = ts.two_locus_count_stat( + [samples, samples], GeneralStatFuncs.D, 2, mode=mode, norm_f=norm_f + ) + assert result.shape == expected_dims + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_multi_sample_set(mode, stat, ts_10_samp_with_sites_fixture): + ts = ts_10_samp_with_sites_fixture + norm_f = GeneralStatNormFuncs.choose(stat, mode, ts) + sample_sets = [ts.samples()[0:5], ts.samples()[5:10]] + ldg = ts.two_locus_count_stat( + sample_sets, + getattr(GeneralStatFuncs, stat), + 2, + norm_f=norm_f, + mode=mode, + ) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat, mode=mode) + np.testing.assert_array_almost_equal(ldg, ld) + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) +def test_general_two_way(mode, stat, ts_10_samp_with_sites_fixture): + ts = ts_10_samp_with_sites_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_f = GeneralStatNormFuncs.choose(stat, mode, ts) + sample_sets = [ts.samples()[0:5], ts.samples()[5:10]] + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_f, mode=mode) + ld = ts.ld_matrix( + sample_sets=sample_sets, + stat=stat.replace("_ij", ""), + indexes=[(0, 1)], + mode=mode, + ) + np.testing.assert_array_almost_equal(ldg, ld) + + +# NB: multiallelic testing only needed for sites. branches are biallelic. + + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_multiallelic(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) + polarised = POLARIZATION[SUMMARY_FUNCS[stat]] + ldg = ts.two_locus_count_stat( + [ts.samples()], general_func, 1, norm_f=norm_f, polarised=polarised + ) + ld = ts.ld_matrix(stat=stat) + # ld_matrix drops dims, expand for comparison + np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) + + +@pytest.mark.parametrize("stat", SUMMARY_FUNCS.keys()) +def test_general_one_way_multiallelic_multi_sample_set(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) + polarised = POLARIZATION[SUMMARY_FUNCS[stat]] + sample_sets = [ts.samples(), ts.samples()] + ldg = ts.two_locus_count_stat( + sample_sets, general_func, 2, norm_f=norm_f, polarised=polarised + ) + ld = ts.ld_matrix(stat=stat, sample_sets=sample_sets) + np.testing.assert_array_almost_equal(ldg, ld) + + +@pytest.mark.parametrize("stat", ["r2_ij", "D2_ij", "D2_ij_unbiased"]) +def test_general_two_way_multiallelic(stat, ts_multiallelic_fixture): + ts = ts_multiallelic_fixture + general_func = getattr(GeneralStatFuncs, stat) + norm_f = GeneralStatNormFuncs.choose(stat, "site", ts) + sample_sets = [ts.samples(), ts.samples()] + ldg = ts.two_locus_count_stat(sample_sets, general_func, 1, norm_f=norm_f) + ld = ts.ld_matrix( + stat=stat.replace("_ij", ""), indexes=(0, 1), sample_sets=sample_sets + ) + # ld_matrix drops dims, expand for comparison + np.testing.assert_array_almost_equal(ldg, np.expand_dims(ld, 0)) + + +@pytest.mark.parametrize("mode", ["site", "branch"]) +def test_general_multi_outputs(mode): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + 4, recombination_rate=0.1, sequence_length=35, random_seed=123 + ), + rate=0.1, + random_seed=123, + ) + assert ts.num_samples == 8, "8 samples are required" + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + A = ts.samples()[0:4] + B = ts.samples()[4:] + + norm_f = GeneralStatNormFuncs.choose("D2_unbiased", mode, ts) + general_func = GeneralStatFuncs.D2_ii_ij_jj_unbiased + ldg = ts.two_locus_count_stat([A, B], general_func, 3, mode=mode, norm_f=norm_f) + ld = ts.ld_matrix( + [A, B], stat="D2_unbiased", indexes=[(0, 0), (0, 1), (1, 1)], mode=mode + ) + np.testing.assert_array_almost_equal(ldg, ld) diff --git a/python/tests/test_python_c.py b/python/tests/test_python_c.py index 15f9967f3f..09370a9a78 100644 --- a/python/tests/test_python_c.py +++ b/python/tests/test_python_c.py @@ -138,6 +138,23 @@ def get_example_migration_tree_sequence(self): ) return ts.ll_tree_sequence + def get_example_tree_sequence_multiallelic(self, sample_size=10): + ts = msprime.sim_mutations( + msprime.sim_ancestry( + sample_size, + recombination_rate=0.1, + sequence_length=100, + ploidy=1, + random_seed=123, + ), + rate=0.1, + random_seed=123, + ) + assert max({len(s.mutations) for s in ts.sites()}) > 2, ( + "At least one multiallelic site required" + ) + return ts.ll_tree_sequence + def verify_iterator(self, iterator): """ Checks that the specified non-empty iterator implements the @@ -1987,6 +2004,246 @@ def test_ld_matrix_multipop(self, stat_method_name): with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") + def test_two_locus_count_stat(self): + """Test two_locus_count_stat on biallelic data (no norm function)""" + ts = self.get_example_tree_sequence(10) + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + col_pos = row_pos + row_sites_list = list(range(ts.get_num_sites())) + col_sites_list = row_sites_list + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(*_): + raise Exception # norm function will not be used + + method = ts.two_locus_count_stat + site_args = row_sites, col_sites, None, None, "site" + branch_args = None, None, row_pos, col_pos, "branch" + # happy path + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_args) + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # happy path - sample sets as lists are also valid + site_list_args = row_sites_list, col_sites_list, None, None, "site" + branch_list_args = None, None, row_pos_list, col_pos_list, "branch" + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_list_args) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *branch_list_args) + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # happy path - default values for site and position lists + a = method( + ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "site" + ) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + a = method( + ss_sizes, ss, stat_func, norm_func, 1, True, None, None, None, None, "branch" + ) + assert a.shape == (ts.get_num_trees(), ts.get_num_trees(), 1) + # CPython API errors + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + bad_ss = np.array([], dtype=np.int32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="cast array data"): + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + method(ss_sizes, bad_ss, stat_func, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="Unrecognised stats mode"): + bad_args = row_sites, col_sites, None, None, "bla" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_args) + with pytest.raises(TypeError, match="at most"): + method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args, "extraarg") + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + bad_site_args = None, None, row_pos, col_pos, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + bad_branch_args = row_sites, col_sites, None, None, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(TypeError, match="summary_func must be callable"): + method(ss_sizes, ss, "uncallable", norm_func, 1, True, *site_args) + with pytest.raises(TypeError, match="norm_func must be callable"): + method(ss_sizes, ss, stat_func, "uncallable", 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*must be 1D"): + method(ss_sizes, ss, lambda *_: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="summary function.*length 2; must be 1"): + method(ss_sizes, ss, lambda *_: [1, 2], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="could not convert string to float"): + method(ss_sizes, ss, lambda *_: ["nonfloat"], norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="assignment destination is read-only"): + + def bad_stat_func(X, n): + X[0] = [1] + return [1] + + method(ss_sizes, ss, bad_stat_func, norm_func, 1, True, *site_args) + # Exceptions within stat_func are correctly raised. + for exception in [ValueError, TypeError]: + + def stat_func_except(*_): + raise exception("test") # noqa: B023 + + with pytest.raises(exception, match="test"): + method(ss_sizes, ss, stat_func_except, norm_func, 1, True, *site_args) + # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_BAD_RESULT_DIMS"): + method(ss_sizes, ss, stat_func, norm_func, 0, True, *site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = bad_sites, col_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + bad_site_args = row_sites, bad_sites, None, None, "site" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_site_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS"): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, bad_pos, col_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + bad_branch_args = None, None, row_pos, bad_pos, "branch" + method(ss_sizes, ss, stat_func, norm_func, 1, True, *bad_branch_args) + + def test_two_locus_count_stat_multialleliic(self): + """ + Test two_locus_count_stat on multiallelic sites to test the behavior of + the norm function. + """ + ts = self.get_example_tree_sequence_multiallelic() + + def stat_func(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + def norm_func(X, n, nA, nB): + return X[0].sum(keepdims=True) / n.sum() + + ss = ts.get_samples() # sample sets + ss_sizes = np.array([len(ss)], dtype=np.uint32) + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + method = ts.two_locus_count_stat + site_args = row_sites, col_sites, None, None, "site" + + # happy path + a = method(ss_sizes, ss, stat_func, norm_func, 1, True, *site_args) + assert a.shape == (ts.get_num_sites(), ts.get_num_sites(), 1) + # CPython API errors + with pytest.raises(ValueError, match="norm function.*must be 1D"): + method(ss_sizes, ss, stat_func, lambda *_: 1, 1, True, *site_args) + with pytest.raises( + TypeError, match="takes 1 positional argument but 2 were given" + ): + method(ss_sizes, ss, lambda _: 1, norm_func, 1, True, *site_args) + with pytest.raises(ValueError, match="norm function.*length 2; must be 1"): + method(ss_sizes, ss, stat_func, lambda *_: [1, 2], 1, True, *site_args) + with pytest.raises( + TypeError, match="takes 1 positional argument but 4 were given" + ): + method(ss_sizes, ss, stat_func, lambda _: [1, 2], 1, True, *site_args) + with pytest.raises(ValueError, match="could not convert string to float"): + method(ss_sizes, ss, stat_func, lambda *_: ["nonfloat"], 1, True, *site_args) + # Exceptions within stat_func are correctly raised. + for exception in [ValueError, TypeError]: + + def norm_func_except(*_): + raise exception("test") # noqa: B023 + + with pytest.raises(exception, match="test"): + method(ss_sizes, ss, stat_func, norm_func_except, 1, True, *site_args) + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): diff --git a/python/tests/tsutil.py b/python/tests/tsutil.py index 0037e06391..48ff72d0d7 100644 --- a/python/tests/tsutil.py +++ b/python/tests/tsutil.py @@ -2463,6 +2463,28 @@ def get_back_mutation_examples(): yield insert_branch_mutations(ts) +@functools.lru_cache +def get_sim_example( + sample_size, sequence_length, recombination_rate, mutation_rate, seed +): + recomb_map = msprime.RecombinationMap.uniform_map( + sequence_length, recombination_rate + ) + ts = msprime.simulate( + recombination_map=recomb_map, + mutation_rate=mutation_rate, + random_seed=seed, + population_configurations=[ + msprime.PopulationConfiguration(sample_size), + msprime.PopulationConfiguration(0), + ], + migration_matrix=[[0, 1], [1, 0]], + ) + ts = insert_random_ploidy_individuals(ts, 4, seed=seed) + ts = add_random_metadata(ts, seed=seed) + return ts + + def make_example_tree_sequences(custom_max=None): yield from get_decapitated_examples(custom_max=custom_max) yield from get_gap_examples(custom_max=custom_max) @@ -2475,22 +2497,14 @@ def make_example_tree_sequences(custom_max=None): for n in n_list: for m in [1, 2, 32]: for rho in [0, 0.1, 0.5]: - recomb_map = msprime.RecombinationMap.uniform_map(m, rho, num_loci=m) - ts = msprime.simulate( - recombination_map=recomb_map, + ts = get_sim_example( + sample_size=n, + sequence_length=m, + recombination_rate=rho, mutation_rate=0.1, - random_seed=seed, - population_configurations=[ - msprime.PopulationConfiguration(n), - msprime.PopulationConfiguration(0), - ], - migration_matrix=[[0, 1], [1, 0]], - ) - ts = insert_random_ploidy_individuals(ts, 4, seed=seed) - yield ( - f"n={n}_m={m}_rho={rho}", - add_random_metadata(ts, seed=seed), + seed=seed, ) + yield (f"n={n}_m={m}_rho={rho}", ts) seed += 1 for name, ts in get_bottleneck_examples(custom_max=custom_max): yield ( diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 45d2da59e0..bb69d70c3a 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8249,19 +8249,7 @@ def parse_positions(self, positions): ) return row_positions, col_positions - def __two_locus_sample_set_stat( - self, - ll_method, - sample_sets, - sites=None, - positions=None, - mode=None, - ): - if sample_sets is None: - sample_sets = self.samples() - row_sites, col_sites = self.parse_sites(sites) - row_positions, col_positions = self.parse_positions(positions) - + def __convert_sample_sets(self, sample_sets): # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. drop_dimension = False @@ -8283,7 +8271,23 @@ def __two_locus_sample_set_stat( raise ValueError("Sample sets must contain at least one element") flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + return drop_dimension, flattened, sample_set_sizes + def __two_locus_sample_set_stat( + self, + ll_method, + sample_sets, + sites=None, + positions=None, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) result = ll_method( sample_set_sizes, flattened, @@ -10927,6 +10931,118 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time + def two_locus_count_stat( + self, + sample_sets, + f, + result_dim, + norm_f=None, + polarised=False, + sites=None, + positions=None, + mode="site", + ): + """ + Compute two-locus statistics with a user-defined python function that + operates on haplotype counts. TODO: reference modes in two-locus docs. + On each pair of sites or trees, the summary function is called with + haplotype counts for all provided sample sets. The summary function + (``f``) must accept two parameters: ``X``, a matrix with shape (3, k) + and ``n``, a vector with shape (k,), where k is the number of sample + sets provided. ``X`` is a read-only matrix whose rows contain haplotype + counts (AB, Ab, aB) per sample set and ``n`` is a vector of sample set + sizes. ``f`` must return a list of results with length ``result_dim``. + + What follows is an example of computing ``D`` from a tree sequence + (TODO: cite two-locus docs for more details). We convert counts to + proportions, then compute ``D``, returning a numpy array with length + equal to the number of ``result_dim``. + + .. code-block:: python + + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + The summary function is called for each pair of sites or trees, + producing results that must be combined when multiallelic sites are + present (``site`` mode only), summary function results must + need to be normalised in order to be aggragated for all pairs of alleles + between both sites. Branch statistics and biallelic sites do not require + any normalisation, ``norm_f`` is only called if one of the two sites + under consideration is multiallelic. TODO: reference two-locus docs for + further information about normalisation. ``norm_f`` is a normalisation + function that must accept four parameters: ``X`` and ``n`` are the same + inputs that ``f`` accepts, along with ``nA`` and ``nB``, which hold the + count of ``A`` alleles and ``B`` alleles. For example, if ``A`` is + biallelic and ``B`` is triallelic, ``nA=2`` and ``nB=3``. ``f`` must + return a list of results with length ``result_dim``. The default + normalisation function is identical to ``total_norm`` shown in the + example below. ``hap_norm`` is required for normalising + :math:`r^2`. Both of these examples return a numpy array with length + equal to the number of ``result_dim``. + + .. code-block:: python + + def total_norm(X, n, nA, nB): + [1 / (nA * nB)] * result_dim + + def hap_norm(X, n, nA, nB): + X[0] / n + + A simple call (without specifying normalisation) would look like this + + .. code-block::python + + ts.two_locus_count_stat([ts.samples()], D, 1, polarised=True) + + :param list sample_sets: A list of lists of Node IDs, specifying the + groups of nodes to compute the statistic with. + :param f: A function that takes two arguments - a two-dimensional array + with shape (3, k) and a one-dimensional array with shape (k, ) where + k is the number of sample sets. + :param int result_dim: The length of ``f`` and ``norm_f``'s return value. + :param norm_f: A function that takes four arguments - the first two are + the same as ``f``, the second two are scalars representing the + number of A and B alleles, respectively. If ``None``, then defaults + to the "total" normalization described above. + :param bool polarised: Whether to leave the ancestral state out of + computations: see :ref:`sec_stats` for more details. + :param list sites: TODO: two-locus docs + :param list positions: TODO: two-locus docs + :param str mode: A string giving the "type" of the statistic to be + computed (defaults to "site"). + :return: A ndarray with shape equal to (TODO: reference two-locus docs, + no dimension dropping shape=(k, m, m) where k=result_dim, + m=num_sites or num_trees). + """ + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) + if norm_f is None: + # produce the same number of dims as result dimensions with [val] * dim + def norm_f(X, n, nA, nB): + return [1 / (nA * nB)] * result_dim + + result = self._ll_tree_sequence.two_locus_count_stat( + sample_set_sizes, + sample_sets, + f, + norm_f, + result_dim, + polarised, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + # Orient the data so that the first dimension is the result_dim so that + # we get one LD matrix per result dimension + return np.moveaxis(result, -1, 0) + def ld_matrix( self, sample_sets=None,