From f7bda8b8d6be1ce02fedfce7ed48210e21eca7bf Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 26 Oct 2025 18:57:49 -0500 Subject: [PATCH 01/11] initial stab at a general matrix (no normalisation) --- c/tskit/trees.c | 66 ++++++++------ c/tskit/trees.h | 7 ++ python/_tskitmodule.c | 202 ++++++++++++++++++++++++++++++++++++++++++ python/tskit/trees.py | 82 ++++++++++++----- 4 files changed, 307 insertions(+), 50 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index f00bb83d28..fdc316f8d6 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2410,8 +2410,8 @@ static int compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim, - tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, - norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) + tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f, + bool polarised, two_locus_work_t *restrict work, double *result) { int ret = 0; // Sample sets and b sites are rows, a sites are columns @@ -2462,9 +2462,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, static int compute_general_two_site_stat_result(const tsk_bitset_t *state, const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off, - tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; tsk_size_t k; @@ -2652,9 +2651,8 @@ static int tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows, - const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites, - tsk_flags_t options, double *result) + void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites, + tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result) { int ret = 0; tsk_bitset_t allele_samples, allele_sample_sets; @@ -3088,9 +3086,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index) static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, - tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, - double *result) + tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params, + two_locus_work_t *restrict work, double *result) { int ret = 0; double a_len, b_len; @@ -3140,8 +3137,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, static int compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, - iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params, - tsk_size_t result_dim, tsk_size_t state_dim, double *result) + iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim, + tsk_size_t state_dim, double *result) { int ret = 0; tsk_id_t e, c, ec, p, *updated_nodes = NULL; @@ -3242,9 +3239,9 @@ static int tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f), - tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols, - const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result) + void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows, + const double *row_positions, tsk_size_t n_cols, const double *col_positions, + tsk_flags_t TSK_UNUSED(options), double *result) { int ret = 0; int r, c; @@ -3384,10 +3381,10 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s } int -tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, - norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, +tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result) { @@ -3397,10 +3394,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); tsk_size_t state_dim = num_sample_sets; - sample_count_stat_params_t f_params = { .sample_sets = sample_sets, - .num_sample_sets = num_sample_sets, - .sample_set_sizes = sample_set_sizes, - .set_indexes = set_indexes }; // We do not support two-locus node stats if (!!(options & TSK_STAT_NODE)) { @@ -3440,7 +3433,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_sites, out_cols, col_sites, options, result); } else if (stat_branch) { ret = check_positions( @@ -3454,13 +3447,30 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl goto out; } ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets, - sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows, row_positions, out_cols, col_positions, options, result); } out: return ret; } +int +tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, + norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result) +{ + sample_count_stat_params_t f_params = { .sample_sets = sample_sets, + .num_sample_sets = num_sample_sets, + .sample_set_sizes = sample_set_sizes, + .set_indexes = set_indexes }; + return tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets, + sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows, + row_sites, row_positions, out_cols, col_sites, col_positions, options, result); +} + /*********************************** * Allele frequency spectrum ***********************************/ @@ -8697,8 +8707,8 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, for (k = offsets[b]; k < offsets[b + 1]; k++) { u = A[j]; v = A[k]; - /* Only increment the upper triangle to (hopefully) improve memory - * access patterns */ + /* Only increment the upper triangle to (hopefully) improve + * memory access patterns */ if (u > v) { u = A[k]; v = A[j]; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index 21495edbf7..a919d396b6 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -1071,6 +1071,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, + void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 78cb9f7c8e..33ed084031 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7796,6 +7796,203 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) return array; } +typedef struct { + PyArrayObject *sample_set_sizes; + PyObject *callable; +} two_locus_general_stat_params; + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + two_locus_general_stat_params *tl_params = params; + PyObject *callable = tl_params->callable; + PyArrayObject *sample_set_sizes = tl_params->sample_set_sizes; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + npy_intp X_dims[2] = { K, 3 }; + // Convert "n" to a column array + PyArray_Dims n_dims = { (npy_intp[2]){ PyArray_DIMS(sample_set_sizes)[0], 1 }, 2 }; + npy_intp *Y_dims; + + // Create a read only view of X as a numpy array + X_array = (PyArrayObject *) PyArray_SimpleNewFromData( + 2, X_dims, NPY_FLOAT64, (void *) X); + if (X_array == NULL) { + goto out; + } + sample_set_sizes + = (PyArrayObject *) PyArray_Newshape(sample_set_sizes, &n_dims, NPY_CORDER); + + PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + arglist = Py_BuildValue("OO", X_array, sample_set_sizes); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(callable, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by general_stat callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + Y_dims = PyArray_DIMS(Y_array); + if (Y_dims[0] != (npy_intp) M) { + PyErr_Format(PyExc_ValueError, + "Array returned by general_stat callback is of length %d; " + "must be %d", + Y_dims[0], M); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), M * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + return ret; +} + +static PyObject * +TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", + "output_dim", "polarised", "row_sites", "col_sites", "row_positions", + "column_positions", "mode", NULL }; + two_locus_general_stat_params *params; + PyObject *summary_func = NULL; + unsigned int output_dim; + PyObject *sample_set_sizes = NULL; + PyObject *sample_sets = NULL; + PyObject *row_sites = NULL; + PyObject *col_sites = NULL; + PyObject *row_positions = NULL; + PyObject *col_positions = NULL; + char *mode = NULL; + PyArrayObject *sample_set_sizes_array = NULL; + PyArrayObject *sample_sets_array = NULL; + PyArrayObject *row_sites_array = NULL; + PyArrayObject *col_sites_array = NULL; + PyArrayObject *row_positions_array = NULL; + PyArrayObject *col_positions_array = NULL; + PyArrayObject *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL; + tsk_id_t *col_sites_parsed = NULL; + double *row_positions_parsed = NULL; + double *col_positions_parsed = NULL; + npy_intp result_dim[3] = { 0, 0, 0 }; + tsk_size_t num_sample_sets; + tsk_flags_t options = 0; + int polarised = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, + &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + Py_XINCREF(summary_func); + goto out; + } + Py_INCREF(summary_func); + if (!PyCallable_Check(summary_func)) { + PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (polarised) { + options |= TSK_STAT_POLARISED; + } + + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + PyArray_CLEARFLAGS(sample_set_sizes_array, NPY_ARRAY_WRITEABLE); + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); + } + + result_dim[2] = num_sample_sets; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + + params = &(two_locus_general_stat_params){ + .sample_set_sizes = sample_set_sizes_array, + .callable = summary_func, + }; + // TODO: deal with null norm func, need general stat. + err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + output_dim, general_two_locus_count_stat_func, params, NULL, result_dim[0], + row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, + col_positions_parsed, options, PyArray_DATA(result_matrix)); + + if (err == TSK_PYTHON_CALLBACK_ERROR) { + goto out; + } else if (err != 0) { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(summary_func); + Py_XDECREF(row_sites_array); + Py_XDECREF(col_sites_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(result_matrix); + return ret; +} + static PyObject * TreeSequence_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, two_locus_count_stat_method *method) @@ -8680,6 +8877,11 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_general_stat, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Runs the general stats algorithm for a given summary function." }, + { .ml_name = "two_locus_count_stat", + .ml_meth = (PyCFunction) TreeSequence_two_locus_count_stat, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc + = "Runs the general two locus stats algorithm for a given summary function." }, { .ml_name = "diversity", .ml_meth = (PyCFunction) TreeSequence_diversity, .ml_flags = METH_VARARGS | METH_KEYWORDS, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 9904b60a98..35c6693371 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -696,8 +696,7 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 " - "and is ignored", + "The sample_counts option is not supported since 0.2.4 and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6889,7 +6888,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" ) return "\n".join(output) + "\n" @@ -8177,19 +8176,7 @@ def parse_positions(self, positions): ) return row_positions, col_positions - def __two_locus_sample_set_stat( - self, - ll_method, - sample_sets, - sites=None, - positions=None, - mode=None, - ): - if sample_sets is None: - sample_sets = self.samples() - row_sites, col_sites = self.parse_sites(sites) - row_positions, col_positions = self.parse_positions(positions) - + def __convert_sample_sets(self, sample_sets): # First try to convert to a 1D numpy array. If we succeed, then we strip off # the corresponding dimension from the output. drop_dimension = False @@ -8211,7 +8198,23 @@ def __two_locus_sample_set_stat( raise ValueError("Sample sets must contain at least one element") flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + return drop_dimension, flattened, sample_set_sizes + def __two_locus_sample_set_stat( + self, + ll_method, + sample_sets, + sites=None, + positions=None, + mode=None, + ): + if sample_sets is None: + sample_sets = self.samples() + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) result = ll_method( sample_set_sizes, flattened, @@ -9281,9 +9284,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert ( - time_windows[0] < time_windows[1] - ), "The second argument should be larger." + assert time_windows[0] < time_windows[1], ( + "The second argument should be larger." + ) tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), @@ -9351,9 +9354,9 @@ def _rand_pow_range_finder( """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ - assert ( - num_vectors >= rank > 0 - ), "num_vectors should not be smaller than rank" + assert num_vectors >= rank > 0, ( + "num_vectors should not be smaller than rank" + ) for _ in range(depth): Q = np.linalg.qr(Q)[0] Q = operator(Q) @@ -10831,6 +10834,41 @@ def impute_unknown_mutations_time( mutations_time[unknown] = self.nodes_time[self.mutations_node[unknown]] return mutations_time + def two_locus_count_stat( + self, + sample_sets, + f, + result_dim, + polarised=False, + sites=None, + positions=None, + mode="site", + ): + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( + sample_sets + ) + result = self._ll_tree_sequence.two_locus_count_stat( + sample_set_sizes, + sample_sets, + f, + result_dim, + polarised, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[:2]) + else: + # Orient the data so that the first dimension is the sample set. + # With this orientation, we get one LD matrix per sample set. + result = result.swapaxes(0, 2).swapaxes(1, 2) + return result + def ld_matrix( self, sample_sets=None, From 0f3a3341668ce0392cc39cc5b2880234cf6f22d6 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sun, 26 Oct 2025 18:59:39 -0500 Subject: [PATCH 02/11] added dimension dropping, but I think transposing is better -- we don't have to add a dimension at the end for scalar operations --- python/_tskitmodule.c | 37 +++++++++++++++++++++++-------------- python/tskit/trees.py | 4 +++- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index 33ed084031..e650f855d1 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7797,6 +7797,7 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) } typedef struct { + bool drop_dimensions; PyArrayObject *sample_set_sizes; PyObject *callable; } two_locus_general_stat_params; @@ -7806,29 +7807,33 @@ general_two_locus_count_stat_func( tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; - two_locus_general_stat_params *tl_params = params; - PyObject *callable = tl_params->callable; - PyArrayObject *sample_set_sizes = tl_params->sample_set_sizes; PyObject *arglist = NULL; PyObject *result = NULL; PyArrayObject *X_array = NULL; PyArrayObject *Y_array = NULL; - npy_intp X_dims[2] = { K, 3 }; - // Convert "n" to a column array - PyArray_Dims n_dims = { (npy_intp[2]){ PyArray_DIMS(sample_set_sizes)[0], 1 }, 2 }; + two_locus_general_stat_params *tl_params = params; + PyObject *callable = tl_params->callable; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + bool drop = (K == 1 && tl_params->drop_dimensions); + // Convert "n" to a column array -- reshape(-1, K) or a scalar if K=1 and drop=True + PyArray_Dims ss_sizes_dims = (drop ? (PyArray_Dims){ (npy_intp[1]){ 1 }, 0 } + : (PyArray_Dims){ (npy_intp[2]){ K, 1 }, 2 }); + int X_ndims = drop ? 1 : 2; + npy_intp *X_dims = drop ? (npy_intp[1]){ 3 } : (npy_intp[2]){ K, 3 }; npy_intp *Y_dims; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - 2, X_dims, NPY_FLOAT64, (void *) X); + X_ndims, X_dims, NPY_FLOAT64, (void *) X); if (X_array == NULL) { goto out; } - sample_set_sizes - = (PyArrayObject *) PyArray_Newshape(sample_set_sizes, &n_dims, NPY_CORDER); - + ss_sizes = (PyArrayObject *) PyArray_Newshape(ss_sizes, &ss_sizes_dims, NPY_CORDER); + if (ss_sizes == NULL) { + goto out; + } PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); - arglist = Py_BuildValue("OO", X_array, sample_set_sizes); + arglist = Py_BuildValue("OO", X_array, ss_sizes); if (arglist == NULL) { goto out; } @@ -7864,6 +7869,7 @@ general_two_locus_count_stat_func( Py_XDECREF(arglist); Py_XDECREF(result); Py_XDECREF(Y_array); + Py_XDECREF(ss_sizes); return ret; } @@ -7873,7 +7879,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", "output_dim", "polarised", "row_sites", "col_sites", "row_positions", - "column_positions", "mode", NULL }; + "column_positions", "mode", "drop_dimensions", NULL }; two_locus_general_stat_params *params; PyObject *summary_func = NULL; unsigned int output_dim; @@ -7898,15 +7904,17 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; + int drop_dimensions = 0; int polarised = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|s", kwlist, + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|si", kwlist, &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, - &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { + &row_sites, &col_sites, &row_positions, &col_positions, &mode, + &drop_dimensions)) { Py_XINCREF(summary_func); goto out; } @@ -7965,6 +7973,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * params = &(two_locus_general_stat_params){ .sample_set_sizes = sample_set_sizes_array, .callable = summary_func, + .drop_dimensions = drop_dimensions, }; // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 35c6693371..0cfacf51fa 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10843,6 +10843,7 @@ def two_locus_count_stat( sites=None, positions=None, mode="site", + drop_dimensions=True, ): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) @@ -10851,7 +10852,7 @@ def two_locus_count_stat( ) result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, - sample_sets, + flattened, f, result_dim, polarised, @@ -10860,6 +10861,7 @@ def two_locus_count_stat( row_positions, col_positions, mode, + drop_dimensions, ) if drop_dimension: result = result.reshape(result.shape[:2]) From 4f906a6163e437c52715ed1286f721e23338443b Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 12:31:28 -0600 Subject: [PATCH 03/11] finalize and add tests for single and multipop --- python/_tskitmodule.c | 153 +++++++++++++++----- python/tests/test_ld_matrix.py | 255 +++++++++++++++++++++++++++++++++ python/tskit/trees.py | 22 ++- 3 files changed, 380 insertions(+), 50 deletions(-) diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index e650f855d1..eda96260da 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7797,47 +7797,123 @@ parse_positions(TreeSequence *self, PyObject *positions, npy_intp *out_dim) } typedef struct { - bool drop_dimensions; PyArrayObject *sample_set_sizes; - PyObject *callable; + PyObject *summary_func; + PyObject *norm_func; } two_locus_general_stat_params; static int -general_two_locus_count_stat_func( - tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *params) +general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n_a, + tsk_size_t n_b, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; PyObject *arglist = NULL; PyObject *result = NULL; + PyArrayObject *n_a_scalar = NULL; + PyArrayObject *n_b_scalar = NULL; PyArrayObject *X_array = NULL; PyArrayObject *Y_array = NULL; two_locus_general_stat_params *tl_params = params; - PyObject *callable = tl_params->callable; + PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - bool drop = (K == 1 && tl_params->drop_dimensions); - // Convert "n" to a column array -- reshape(-1, K) or a scalar if K=1 and drop=True - PyArray_Dims ss_sizes_dims = (drop ? (PyArray_Dims){ (npy_intp[1]){ 1 }, 0 } - : (PyArray_Dims){ (npy_intp[2]){ K, 1 }, 2 }); - int X_ndims = drop ? 1 : 2; - npy_intp *X_dims = drop ? (npy_intp[1]){ 3 } : (npy_intp[2]){ K, 3 }; - npy_intp *Y_dims; + npy_intp X_dims[2] = { result_dim, 3 }; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( - X_ndims, X_dims, NPY_FLOAT64, (void *) X); + 2, X_dims, NPY_FLOAT64, (void *) X); if (X_array == NULL) { goto out; } - ss_sizes = (PyArrayObject *) PyArray_Newshape(ss_sizes, &ss_sizes_dims, NPY_CORDER); - if (ss_sizes == NULL) { + PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + // Transpose into column arrays, so that we can easily decompose the results + X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + if (X_array == NULL) { + goto out; + } + n_a_scalar + = (PyArrayObject *) PyArray_Scalar(&n_a, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_a_scalar == NULL) { + goto out; + } + n_b_scalar + = (PyArrayObject *) PyArray_Scalar(&n_b, PyArray_DescrFromType(NPY_INT64), NULL); + if (n_b_scalar == NULL) { + goto out; + } + arglist = Py_BuildValue("OOOO", X_array, ss_sizes, n_a_scalar, n_b_scalar); + if (arglist == NULL) { + goto out; + } + result = PyObject_CallObject(summary_func, arglist); + if (result == NULL) { + goto out; + } + Y_array = (PyArrayObject *) PyArray_FromAny( + result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL); + if (Y_array == NULL) { + goto out; + } + if (PyArray_NDIM(Y_array) != 1) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is %d dimensional; " + "must be 1D", + (int) PyArray_NDIM(Y_array)); + goto out; + } + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { + PyErr_Format(PyExc_ValueError, + "Array returned by norm function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); + goto out; + } + /* Copy the contents of the return Y array into Y */ + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); + ret = 0; +out: + Py_XDECREF(X_array); + Py_XDECREF(arglist); + Py_XDECREF(result); + Py_XDECREF(Y_array); + Py_XDECREF(n_a_scalar); + Py_XDECREF(n_b_scalar); + return ret; +} + +static int +general_two_locus_count_stat_func( + tsk_size_t K, const double *X, tsk_size_t result_dim, double *Y, void *params) +{ + int ret = TSK_PYTHON_CALLBACK_ERROR; + PyObject *arglist = NULL; + PyObject *result = NULL; + PyArrayObject *X_array = NULL; + PyArrayObject *Y_array = NULL; + two_locus_general_stat_params *tl_params = params; + PyObject *summary_func = tl_params->summary_func; + PyArrayObject *ss_sizes = tl_params->sample_set_sizes; + npy_intp X_dims[2] = { K, 3 }; + + // Create a read only view of X as a numpy array + X_array = (PyArrayObject *) PyArray_SimpleNewFromData( + 2, X_dims, NPY_FLOAT64, (void *) X); + if (X_array == NULL) { goto out; } PyArray_CLEARFLAGS(X_array, NPY_ARRAY_WRITEABLE); + // Transpose into column arrays, so that we can easily decompose the results + // For example: pAB, pAb, paB = X / n + // which works with K>1. In addition, the data is not reordered, meaning + // that the data is still oriented where samples are rows, meaning that + // we'll preserve data locality in ops over samples. + X_array = (PyArrayObject *) PyArray_Transpose(X_array, NULL); + if (X_array == NULL) { + goto out; + } arglist = Py_BuildValue("OO", X_array, ss_sizes); if (arglist == NULL) { goto out; } - result = PyObject_CallObject(callable, arglist); + result = PyObject_CallObject(summary_func, arglist); if (result == NULL) { goto out; } @@ -7848,28 +7924,25 @@ general_two_locus_count_stat_func( } if (PyArray_NDIM(Y_array) != 1) { PyErr_Format(PyExc_ValueError, - "Array returned by general_stat callback is %d dimensional; " + "Array returned by summary function callback is %d dimensional; " "must be 1D", (int) PyArray_NDIM(Y_array)); goto out; } - Y_dims = PyArray_DIMS(Y_array); - if (Y_dims[0] != (npy_intp) M) { + if (PyArray_DIM(Y_array, 0) != (npy_intp) result_dim) { PyErr_Format(PyExc_ValueError, - "Array returned by general_stat callback is of length %d; " - "must be %d", - Y_dims[0], M); + "Array returned by summary function callback is of length %d; must be %d", + PyArray_DIM(Y_array, 0), result_dim); goto out; } /* Copy the contents of the return Y array into Y */ - memcpy(Y, PyArray_DATA(Y_array), M * sizeof(*Y)); + memcpy(Y, PyArray_DATA(Y_array), result_dim * sizeof(*Y)); ret = 0; out: Py_XDECREF(X_array); Py_XDECREF(arglist); Py_XDECREF(result); Py_XDECREF(Y_array); - Py_XDECREF(ss_sizes); return ret; } @@ -7878,10 +7951,11 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * { PyObject *ret = NULL; static char *kwlist[] = { "sample_set_sizes", "sample_sets", "summary_func", - "output_dim", "polarised", "row_sites", "col_sites", "row_positions", - "column_positions", "mode", "drop_dimensions", NULL }; + "norm_func", "output_dim", "polarised", "row_sites", "col_sites", + "row_positions", "column_positions", "mode", NULL }; two_locus_general_stat_params *params; PyObject *summary_func = NULL; + PyObject *norm_func = NULL; unsigned int output_dim; PyObject *sample_set_sizes = NULL; PyObject *sample_sets = NULL; @@ -7904,25 +7978,29 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * npy_intp result_dim[3] = { 0, 0, 0 }; tsk_size_t num_sample_sets; tsk_flags_t options = 0; - int drop_dimensions = 0; int polarised = 0; int err; if (TreeSequence_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOIiOOOO|si", kwlist, - &sample_set_sizes, &sample_sets, &summary_func, &output_dim, &polarised, - &row_sites, &col_sites, &row_positions, &col_positions, &mode, - &drop_dimensions)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOOOIiOOOO|s", kwlist, + &sample_set_sizes, &sample_sets, &summary_func, &norm_func, &output_dim, + &polarised, &row_sites, &col_sites, &row_positions, &col_positions, &mode)) { Py_XINCREF(summary_func); + Py_XINCREF(norm_func); goto out; } Py_INCREF(summary_func); + Py_INCREF(norm_func); if (!PyCallable_Check(summary_func)) { PyErr_SetString(PyExc_TypeError, "summary_func must be callable"); goto out; } + if (!PyCallable_Check(norm_func)) { + PyErr_SetString(PyExc_TypeError, "norm_func must be callable"); + goto out; + } if (parse_stats_mode(mode, &options) != 0) { goto out; } @@ -7963,7 +8041,7 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * col_positions_parsed = PyArray_DATA(col_positions_array); } - result_dim[2] = num_sample_sets; + result_dim[2] = output_dim; result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); if (result_matrix == NULL) { PyErr_NoMemory(); @@ -7972,15 +8050,16 @@ TreeSequence_two_locus_count_stat(TreeSequence *self, PyObject *args, PyObject * params = &(two_locus_general_stat_params){ .sample_set_sizes = sample_set_sizes_array, - .callable = summary_func, - .drop_dimensions = drop_dimensions, + .summary_func = summary_func, + .norm_func = norm_func, }; // TODO: deal with null norm func, need general stat. err = tsk_treeseq_two_locus_count_general_stat(self->tree_sequence, num_sample_sets, PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), - output_dim, general_two_locus_count_stat_func, params, NULL, result_dim[0], - row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, - col_positions_parsed, options, PyArray_DATA(result_matrix)); + output_dim, general_two_locus_count_stat_func, params, + general_two_locus_norm_func, result_dim[0], row_sites_parsed, + row_positions_parsed, result_dim[1], col_sites_parsed, col_positions_parsed, + options, PyArray_DATA(result_matrix)); if (err == TSK_PYTHON_CALLBACK_ERROR) { goto out; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index fa992adec8..b26160b140 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,6 +22,7 @@ """ Test cases for two-locus statistics """ + import contextlib import io from dataclasses import dataclass @@ -2408,3 +2409,257 @@ def test_multipopulation_r2_varying_unequal_set_sizes(genotypes, sample_sets, ex norm_hap_weighted_ij(1, state, max(a) + 1, max(b) + 1, norm[i, j], params) np.testing.assert_allclose((result * norm).sum(), expected) + + +class GeneralStatFuncs: + """ + functions take X, n as parameters where + + X: shape=(3, #ss) + sample sets + count AB [[ ] + count Ab [ ] + count aB [ ]] + + n: shape=(#ss, ) + [ ] + """ + + @staticmethod + def D(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pAB - (pA * pB) + + @staticmethod + def D2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return (pAB - (pA * pB)) ** 2 + + @staticmethod + def r2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D**2 / denom + + @staticmethod + def r(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = pA * pB * (1 - pA) * (1 - pB) + with suppress_overflow_div0_warning(): + return D / np.sqrt(denom) + + @staticmethod + def D_prime(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + denom = np.vstack( + [ + np.min([pA * (1 - pB), (1 - pA) * pB], axis=0), + np.min([pA * pB, (1 - pA) * (1 - pB)], axis=0), + ] + ) + with suppress_overflow_div0_warning(): + return D / denom[(D < 0).astype(int), range(len(D))] + + @staticmethod + def Dz(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return D * (1 - 2 * pA) * (1 - 2 * pB) + + @staticmethod + def pi2(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + return pA * (1 - pA) * pB * (1 - pB) + + @staticmethod + def D2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((aB**2) * (Ab - 1) * Ab) + + ((ab - 1) * ab * (AB - 1) * AB) + - (aB * Ab * (Ab + (2 * ab * AB) - 1)) + ) + + @staticmethod + def Dz_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + (((AB * ab) - (Ab * aB)) * (aB + ab - AB - Ab) * (Ab + ab - AB - aB)) + - ((AB * ab) * (AB + ab - Ab - aB - 2)) + - ((Ab * aB) * (Ab + aB - AB - ab - 2)) + ) + + @staticmethod + def pi2_unbiased(X, n): + AB, Ab, aB = X + ab = n - X.sum(0) + return (1 / (n * (n - 1) * (n - 2) * (n - 3))) * ( + ((AB + Ab) * (aB + ab) * (AB + aB) * (Ab + ab)) + - ((AB * ab) * (AB + ab + (3 * Ab) + (3 * aB) - 1)) + - ((Ab * aB) * (Ab + aB + (3 * AB) + (3 * ab) - 1)) + ) + + @staticmethod + def r2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = np.prod(pAB - (pA * pB)) + denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) + with suppress_overflow_div0_warning(): + return np.expand_dims(D / denom, axis=0) + + @staticmethod + def D2_ij(X, n): + pAB, pAb, paB = X / n + pA = pAb + pAB + pB = paB + pAB + D = pAB - (pA * pB) + return np.expand_dims(np.prod(D), axis=0) + + @staticmethod + def D2_ij_unbiased(X, n): + """ + NB: the two sample sets must be disjoint + we have no way for testing equality + """ + AB, Ab, aB = X + ab = n - X.sum(0) + return np.expand_dims( + (Ab[0] * aB[0] - AB[0] * ab[0]) + * (Ab[1] * aB[1] - AB[1] * ab[1]) + / n[0] + / (n[0] - 1) + / n[1] + / (n[1] - 1), + axis=0, + ) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "D", + ), + (ts, "D2"), + (ts, "r2"), + (ts, "r"), + (ts, "D_prime"), + (ts, "Dz"), + (ts, "pi2"), + (ts, "D2_unbiased"), + (ts, "Dz_unbiased"), + (ts, "pi2_unbiased"), + ], +) +def test_general_two_locus_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) + ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) + np.testing.assert_equal(ldg, ld) + + +@pytest.mark.parametrize( + "ts,stat", + [ + ( + ts := [ + p for p in get_example_tree_sequences() if p.id == "n=100_m=32_rho=0.5" + ][0].values[0], + "r2_ij", + ), + (ts, "D2_ij"), + (ts, "D2_ij_unbiased"), + ], +) +def test_general_two_locus_two_way_site_stat(ts, stat): + sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] + ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 1) + ld = ts.ld_matrix( + sample_sets=sample_sets, stat=stat.replace("_ij", ""), indexes=(0, 1) + ) + np.testing.assert_allclose(ldg, ld) + + +@pytest.mark.parametrize( + "stat", + [ + "D", + "D2", + "r2", + "r", + "D_prime", + "Dz", + "pi2", + "D2_unbiased", + "Dz_unbiased", + "pi2_unbiased", + ], +) +def test_general_one_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2": + result = ts.two_locus_count_stat( + [ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", "D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples()], func, 1) + np.testing.assert_allclose(ts.ld_matrix(stat=stat), result) + + +@pytest.mark.parametrize( + "stat", + [ + "r2_ij", + "D2_ij", + "D2_ij_unbiased", + ], +) +def test_general_two_way_two_locus_stat_multiallelic(stat): + (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values + func = getattr(GeneralStatFuncs, stat) + if stat == "r2_ij": + result = ts.two_locus_count_stat( + [ts.samples(), ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n + ) + elif stat in {"D", "r", "D_prime"}: + result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + else: + # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) + np.testing.assert_allclose( + ts.ld_matrix( + stat=stat.replace("_ij", ""), + indexes=(0, 1), + sample_sets=[ts.samples(), ts.samples()], + ), + result, + ) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 0cfacf51fa..147294b370 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -10839,21 +10839,20 @@ def two_locus_count_stat( sample_sets, f, result_dim, + norm_f=lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0), polarised=False, sites=None, positions=None, mode="site", - drop_dimensions=True, ): row_sites, col_sites = self.parse_sites(sites) row_positions, col_positions = self.parse_positions(positions) - drop_dimension, flattened, sample_set_sizes = self.__convert_sample_sets( - sample_sets - ) + _, sample_sets, sample_set_sizes = self.__convert_sample_sets(sample_sets) result = self._ll_tree_sequence.two_locus_count_stat( sample_set_sizes, - flattened, + sample_sets, f, + norm_f, result_dim, polarised, row_sites, @@ -10861,15 +10860,12 @@ def two_locus_count_stat( row_positions, col_positions, mode, - drop_dimensions, ) - if drop_dimension: - result = result.reshape(result.shape[:2]) - else: - # Orient the data so that the first dimension is the sample set. - # With this orientation, we get one LD matrix per sample set. - result = result.swapaxes(0, 2).swapaxes(1, 2) - return result + if result_dim == 1: # drop dimension + return result.reshape(result.shape[:2]) + # Orient the data so that the first dimension is the sample set so that + # we get one LD matrix per sample set. + return result.swapaxes(0, 2).swapaxes(1, 2) def ld_matrix( self, From 5d26756a56c39308073cee685a4129d066282d03 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 12:48:58 -0600 Subject: [PATCH 04/11] reformat jitter --- python/tskit/trees.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/python/tskit/trees.py b/python/tskit/trees.py index 147294b370..503216a53d 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -696,7 +696,8 @@ def __init__( options = 0 if sample_counts is not None: warnings.warn( - "The sample_counts option is not supported since 0.2.4 and is ignored", + "The sample_counts option is not supported since 0.2.4 " + "and is ignored", RuntimeWarning, stacklevel=4, ) @@ -6888,7 +6889,7 @@ def to_macs(self): bytes_genotypes[:] = lookup[variant.genotypes] genotypes = bytes_genotypes.tobytes().decode() output.append( - f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t{genotypes}" + f"SITE:\t{variant.index}\t{variant.position / m}\t0.0\t" f"{genotypes}" ) return "\n".join(output) + "\n" @@ -9284,9 +9285,9 @@ def pca( if time_windows is None: tree_sequence_low, tree_sequence_high = None, self else: - assert time_windows[0] < time_windows[1], ( - "The second argument should be larger." - ) + assert ( + time_windows[0] < time_windows[1] + ), "The second argument should be larger." tree_sequence_low, tree_sequence_high = ( self.decapitate(time_windows[0]), self.decapitate(time_windows[1]), @@ -9354,9 +9355,9 @@ def _rand_pow_range_finder( """ Algorithm 9 in https://arxiv.org/pdf/2002.01387 """ - assert num_vectors >= rank > 0, ( - "num_vectors should not be smaller than rank" - ) + assert ( + num_vectors >= rank > 0 + ), "num_vectors should not be smaller than rank" for _ in range(depth): Q = np.linalg.qr(Q)[0] Q = operator(Q) From a8b4e3b315e2b734413c7ed2a169d0e2f1290508 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 12:50:58 -0600 Subject: [PATCH 05/11] one more reformat jitter --- c/tskit/trees.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index fdc316f8d6..cc2c548e64 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -8707,8 +8707,8 @@ update_site_divergence(const tsk_variant_t *var, const tsk_id_t *restrict A, for (k = offsets[b]; k < offsets[b + 1]; k++) { u = A[j]; v = A[k]; - /* Only increment the upper triangle to (hopefully) improve - * memory access patterns */ + /* Only increment the upper triangle to (hopefully) improve memory + * access patterns */ if (u > v) { u = A[k]; v = A[j]; From d3a054a1a30d5ee119e23f1ab8027dedca8a40ca Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 12:52:36 -0600 Subject: [PATCH 06/11] one more one more reformat jitter --- python/tests/test_ld_matrix.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index b26160b140..58696ead41 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -22,7 +22,6 @@ """ Test cases for two-locus statistics """ - import contextlib import io from dataclasses import dataclass From 7a93dc5cd296f93bdc91ccbfb0f02c480234f2a9 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:16:12 -0600 Subject: [PATCH 07/11] turns out, the general norm function needs to know the state_dims --- c/tskit/trees.c | 17 ++++++++++------- c/tskit/trees.h | 4 ++-- python/_tskitmodule.c | 6 +++--- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index cc2c548e64..1be8e2639c 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2297,8 +2297,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset, } static int -norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2314,8 +2315,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, } static int -norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, - tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights, + tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), + double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; const double *weight_row; @@ -2340,8 +2342,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), - tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) +norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim), + const double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; double norm = 1 / (double) (n_a * n_b); @@ -2444,7 +2447,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state, if (ret != 0) { goto out; } - ret = norm_f(result_dim, weights, num_a_alleles - is_polarised, + ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised, num_b_alleles - is_polarised, norm, f_params); if (ret != 0) { goto out; diff --git a/c/tskit/trees.h b/c/tskit/trees.h index a919d396b6..79462efba4 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -987,8 +987,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); +typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, + tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params); int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index eda96260da..6336cd1826 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -7803,8 +7803,8 @@ typedef struct { } two_locus_general_stat_params; static int -general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n_a, - tsk_size_t n_b, double *Y, void *params) +general_two_locus_norm_func(tsk_size_t K, const double *X, tsk_size_t result_dim, + tsk_size_t n_a, tsk_size_t n_b, double *Y, void *params) { int ret = TSK_PYTHON_CALLBACK_ERROR; PyObject *arglist = NULL; @@ -7816,7 +7816,7 @@ general_two_locus_norm_func(tsk_size_t result_dim, const double *X, tsk_size_t n two_locus_general_stat_params *tl_params = params; PyObject *summary_func = tl_params->norm_func; PyArrayObject *ss_sizes = tl_params->sample_set_sizes; - npy_intp X_dims[2] = { result_dim, 3 }; + npy_intp X_dims[2] = { K, 3 }; // Create a read only view of X as a numpy array X_array = (PyArrayObject *) PyArray_SimpleNewFromData( From d269d744117c012495d1bf661eb91f4bf9c5a008 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:18:40 -0600 Subject: [PATCH 08/11] fix up a bit of naming in general test funcs, remove unneeded branch, fix norm func for r2_ij --- python/tests/test_ld_matrix.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 58696ead41..448cf0177f 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2523,18 +2523,17 @@ def r2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D = np.prod(pAB - (pA * pB)) + D2_ij = np.prod(pAB - (pA * pB)) denom = np.prod(np.sqrt(pA * pB * (1 - pA) * (1 - pB))) with suppress_overflow_div0_warning(): - return np.expand_dims(D / denom, axis=0) + return np.expand_dims(D2_ij / denom, axis=0) @staticmethod def D2_ij(X, n): pAB, pAb, paB = X / n pA = pAb + pAB pB = paB + pAB - D = pAB - (pA * pB) - return np.expand_dims(np.prod(D), axis=0) + return np.expand_dims(np.prod(pAB - (pA * pB)), axis=0) @staticmethod def D2_ij_unbiased(X, n): @@ -2646,11 +2645,8 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - result = ts.two_locus_count_stat( - [ts.samples(), ts.samples()], func, 1, lambda X, n, nA, nB: X[0] / n - ) - elif stat in {"D", "r", "D_prime"}: - result = ts.two_locus_count_stat([ts.samples()], func, 1, polarised=True) + norm_f = lambda X, n, nA, nB: np.expand_dims(X[0].sum() / n.sum(), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1) From 3a932ae83d1800356b93e9b1f7299223f873b7e2 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:21:12 -0600 Subject: [PATCH 09/11] flake8 does not like assigning lambdas to variables --- python/tests/test_ld_matrix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 448cf0177f..ca0a4f6e49 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2645,7 +2645,8 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": - norm_f = lambda X, n, nA, nB: np.expand_dims(X[0].sum() / n.sum(), axis=0) + def norm_f(X, n, nA, nB): + return np.expand_dims(X[0].sum() / n.sum(), axis=0) result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) From 52be81f9192da7d8fa2bacb35c2fb40a42a89a47 Mon Sep 17 00:00:00 2001 From: lkirk Date: Thu, 4 Dec 2025 16:23:53 -0600 Subject: [PATCH 10/11] and black doesn't like that --- python/tests/test_ld_matrix.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index ca0a4f6e49..8428a100e7 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2645,8 +2645,10 @@ def test_general_two_way_two_locus_stat_multiallelic(stat): (ts,) = {t.id: t for t in get_example_tree_sequences()}["all_fields"].values func = getattr(GeneralStatFuncs, stat) if stat == "r2_ij": + def norm_f(X, n, nA, nB): return np.expand_dims(X[0].sum() / n.sum(), axis=0) + result = ts.two_locus_count_stat([ts.samples(), ts.samples()], func, 1, norm_f) else: # default norm func is lambda X, n, nA, nB: np.expand_dims(1 / (nA * nB), axis=0) From be14dd0ddaf391265a4e2fef860dee3100716659 Mon Sep 17 00:00:00 2001 From: lkirk Date: Sat, 6 Dec 2025 18:29:12 -0600 Subject: [PATCH 11/11] do not test equality, this was useful on my local machine but is problematic in practice --- python/tests/test_ld_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 8428a100e7..bd09302abb 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -2578,7 +2578,7 @@ def test_general_two_locus_site_stat(ts, stat): sample_sets = [ts.samples()[0:50], ts.samples()[50:100]] ldg = ts.two_locus_count_stat(sample_sets, getattr(GeneralStatFuncs, stat), 2) ld = ts.ld_matrix(sample_sets=sample_sets, stat=stat) - np.testing.assert_equal(ldg, ld) + np.testing.assert_allclose(ldg, ld) @pytest.mark.parametrize(