From 266ac444d877b80757e4236db1d762b66de09749 Mon Sep 17 00:00:00 2001 From: lkirk Date: Wed, 18 Dec 2024 14:34:46 -0600 Subject: [PATCH] C and Python API for two-way two-locus stats This PR implements the C and Python API for computing two-way two-locus statistics. The algorithm is identical to the python version, except during testing I uncovered a small issue with normalisation. We need to handle the case where sample sets are of different sizes. The fix for this was to average the normalisation factor for each sample set. Test coverage has been added to cover C, low-level python and some high-level tests. --- c/tests/test_stats.c | 439 +++++++++++++++++++++++++-------- c/tskit/trees.c | 264 ++++++++++++++++++-- c/tskit/trees.h | 83 +++++-- python/_tskitmodule.c | 144 +++++++++++ python/tests/test_ld_matrix.py | 218 ++++++++++++---- python/tests/test_lowlevel.py | 208 ++++++++++++++++ python/tskit/trees.py | 82 +++++- 7 files changed, 1222 insertions(+), 216 deletions(-) diff --git a/c/tests/test_stats.c b/c/tests/test_stats.c index b3515ef2c5..a14f204e55 100644 --- a/c/tests/test_stats.c +++ b/c/tests/test_stats.c @@ -2637,12 +2637,15 @@ test_paper_ex_two_site(void) tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); + double truth_three_index_tuples[27] = { 1, 1, NAN, 0.1111111111111111, + 0.1111111111111111, NAN, 0.1111111111111111, 0.1111111111111111, NAN, + 0.1111111111111111, 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1, + 0.1111111111111111, 0.1111111111111111, NAN, 1, 1, 1, 1, 1, 1 }; - tsk_size_t sample_set_sizes[3]; - tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t sample_set_sizes[3], num_index_tuples; + tsk_id_t sample_sets[ts.num_samples * 3], index_tuples[2 * 3] = { 0, 1, 0, 0, 0, 2 }; tsk_size_t num_sites = ts.tables->sites.num_rows; - tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); - tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_id_t *sites = tsk_malloc(num_sites * sizeof(*sites)); // First sample set contains all of the samples sample_set_sizes[0] = ts.num_samples; @@ -2651,14 +2654,13 @@ test_paper_ex_two_site(void) sample_sets[s] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { - row_sites[s] = (tsk_id_t) s; - col_sites[s] = (tsk_id_t) s; + sites[s] = (tsk_id_t) s; } result_size = num_sites * num_sites; tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_one_set); @@ -2672,7 +2674,7 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size * num_sample_sets, result, truth_two_sets); @@ -2686,15 +2688,48 @@ test_paper_ex_two_site(void) tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan( result_size * num_sample_sets, result, truth_three_sets); + // Two-way stats: we'll reuse all sample sets from the first 3 tests + num_sample_sets = 3; + + num_index_tuples = 1; + // We'll compute r2 between sample set 0 and sample set 1 + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, sites, NULL, num_sites, sites, NULL, + 0, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_one_set); + + // Compare sample sets [(0, 1), (0, 0)] + num_index_tuples = 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, sites, NULL, num_sites, sites, NULL, + 0, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_two_sets); + + // Compare sample sets [(0, 1), (0, 0), (0, 2)] + num_index_tuples = 3; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, sites, NULL, num_sites, sites, NULL, + 0, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan( + result_size * num_index_tuples, result, truth_three_index_tuples); + tsk_treeseq_free(&ts); - tsk_safe_free(row_sites); - tsk_safe_free(col_sites); + tsk_safe_free(sites); } static void @@ -2705,42 +2740,44 @@ test_paper_ex_two_branch(void) double result[27]; tsk_size_t i, result_size, num_sample_sets; tsk_flags_t options = 0; - double truth_one_set[9] - = { 0.001066666666666695, -0.00012666666666665688, -0.0001266666666666534, - -0.00012666666666665688, 6.016666666665456e-05, 6.016666666665629e-05, - -0.0001266666666666534, 6.016666666665629e-05, 6.016666666665629e-05 }; - double truth_two_sets[18] - = { 0.001066666666666695, 0.001066666666666695, -0.00012666666666665688, - -0.00012666666666665688, -0.0001266666666666534, -0.0001266666666666534, - -0.00012666666666665688, -0.00012666666666665688, 6.016666666665456e-05, - 6.016666666665456e-05, 6.016666666665629e-05, 6.016666666665629e-05, - -0.0001266666666666534, -0.0001266666666666534, 6.016666666665629e-05, - 6.016666666665629e-05, 6.016666666665629e-05, 6.016666666665629e-05 }; - double truth_three_sets[27] = { 0.001066666666666695, 0.001066666666666695, NAN, - -0.00012666666666665688, -0.00012666666666665688, NAN, -0.0001266666666666534, - -0.0001266666666666534, NAN, -0.00012666666666665688, -0.00012666666666665688, - NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665629e-05, - 6.016666666665629e-05, NAN, -0.0001266666666666534, -0.0001266666666666534, NAN, - 6.016666666665629e-05, 6.016666666665629e-05, NAN, 6.016666666665629e-05, - 6.016666666665629e-05, NAN }; - double truth_positions_subset_1[12] = { 0.001066666666666695, 0.001066666666666695, - NAN, 0.001066666666666695, 0.001066666666666695, NAN, 0.001066666666666695, - 0.001066666666666695, NAN, 0.001066666666666695, 0.001066666666666695, NAN }; - double truth_positions_subset_2[12] = { 6.016666666665456e-05, 6.016666666665456e-05, - NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, - 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; - double truth_positions_subset_3[12] = { 6.016666666665456e-05, 6.016666666665456e-05, - NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN, 6.016666666665456e-05, - 6.016666666665456e-05, NAN, 6.016666666665456e-05, 6.016666666665456e-05, NAN }; + double truth_one_set[9] = { 0.008890640625, 0.004624203125, 0.005215703125, + 0.004624203125, 0.003737578125, 0.004377078125, 0.005215703125, + 0.004377078124999999, 0.005160578124999998 }; + double truth_two_sets[18] = { 0.008890640625, 0.008890640625, 0.004624203125, + 0.004624203125, 0.005215703125, 0.005215703125, 0.004624203125, 0.004624203125, + 0.003737578125, 0.003737578125, 0.004377078125, 0.004377078125, 0.005215703125, + 0.005215703125, 0.004377078124999999, 0.004377078124999999, 0.005160578124999998, + 0.005160578124999998 }; + double truth_three_sets[27] + = { 0.008890640625, 0.008890640625, 0.007225, 0.004624203125000001, + 0.004624203125, 0.007225, 0.005215703125000002, 0.005215703125, 0.008585, + 0.004624203125, 0.004624203125, 0.007225, 0.003737578125, 0.003737578125, + 0.007225, 0.004377078125, 0.004377078125, 0.008585, 0.005215703125, + 0.005215703125, 0.008585, 0.004377078124999999, 0.004377078124999999, + 0.008585, 0.005160578124999998, 0.005160578124999998, 0.010201 }; + double truth_positions_subset_1[12] = { 0.008890640625, 0.008890640625, 0.007225, + 0.008890640625, 0.008890640625, 0.007225, 0.008890640625, 0.008890640625, + 0.007225, 0.008890640625, 0.008890640625, 0.007225 }; + double truth_positions_subset_2[12] = { 0.003737578125, 0.003737578125, 0.007225, + 0.003737578125, 0.003737578125, 0.007225, 0.003737578125, 0.003737578125, + 0.007225, 0.003737578125, 0.003737578125, 0.007225 }; + double truth_positions_subset_3[12] = { 0.005160578125, 0.005160578125, 0.010201, + 0.005160578125, 0.005160578125, 0.010201, 0.005160578125, 0.005160578125, + 0.010201, 0.005160578125, 0.005160578125, 0.010201 }; + double truth_three_index_tuples[27] = { 0.008890640625, 0.008890640625, 0.0039125, + 0.004624203125, 0.004624203125, 0.0038125, 0.005215703125, 0.005215703125, + 0.0045725, 0.004624203125, 0.004624203125, 0.0038125, 0.003737578125, + 0.003737578125, 0.0040125, 0.004377078125, 0.004377078125, 0.0048525, + 0.005215703125, 0.005215703125, 0.0045725, 0.004377078125, 0.004377078125, + 0.0048525, 0.005160578125, 0.005160578125, 0.0058845 }; tsk_treeseq_from_text(&ts, 10, paper_ex_nodes, paper_ex_edges, NULL, paper_ex_sites, paper_ex_mutations, paper_ex_individuals, NULL, 0); - tsk_size_t sample_set_sizes[3]; - tsk_id_t sample_sets[ts.num_samples * 3]; + tsk_size_t sample_set_sizes[3], num_index_tuples; + tsk_id_t sample_sets[ts.num_samples * 3], index_tuples[2 * 3] = { 0, 1, 0, 0, 0, 2 }; tsk_size_t num_trees = ts.num_trees; - double *row_positions = tsk_malloc(num_trees * sizeof(*row_positions)); - double *col_positions = tsk_malloc(num_trees * sizeof(*col_positions)); + double *positions = tsk_malloc(num_trees * sizeof(*positions)); double positions_subset_1[2] = { 0., 0.1 }; double positions_subset_2[2] = { 2., 6. }; double positions_subset_3[2] = { 9., 9.999 }; @@ -2752,16 +2789,15 @@ test_paper_ex_two_branch(void) sample_sets[i] = (tsk_id_t) i; } for (i = 0; i < num_trees; i++) { - row_positions[i] = ts.breakpoints[i]; - col_positions[i] = ts.breakpoints[i]; + positions[i] = ts.breakpoints[i]; } options |= TSK_STAT_BRANCH; result_size = num_trees * num_trees * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_trees, + NULL, positions, num_trees, NULL, positions, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_one_set); @@ -2774,8 +2810,8 @@ test_paper_ex_two_branch(void) result_size = num_trees * num_trees * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_trees, + NULL, positions, num_trees, NULL, positions, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_two_sets); @@ -2788,35 +2824,70 @@ test_paper_ex_two_branch(void) result_size = num_trees * num_trees * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_trees, NULL, row_positions, num_trees, NULL, col_positions, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_trees, + NULL, positions, num_trees, NULL, positions, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_three_sets); result_size = 4 * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - NULL, positions_subset_1, 2, NULL, positions_subset_1, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, NULL, + positions_subset_1, 2, NULL, positions_subset_1, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_1); result_size = 4 * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - NULL, positions_subset_2, 2, NULL, positions_subset_2, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, NULL, + positions_subset_2, 2, NULL, positions_subset_2, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_2); result_size = 4 * num_sample_sets; tsk_memset(result, 0, sizeof(*result) * result_size); - ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, - NULL, positions_subset_3, 2, NULL, positions_subset_3, options, result); + ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, 2, NULL, + positions_subset_3, 2, NULL, positions_subset_3, options, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal_nan(result_size, result, truth_positions_subset_3); + // Two-way stats: we'll reuse all sample sets from the first 3 tests + num_sample_sets = 3; + result_size = num_trees * num_trees; + + num_index_tuples = 1; + // We'll compute D2 between sample set 0 and sample set 1 + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_trees, NULL, positions, num_trees, NULL, + positions, options, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_one_set); + + // Compare sample sets [(0, 1), (0, 0)] + num_index_tuples = 2; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_trees, NULL, positions, num_trees, NULL, + positions, options, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size * num_index_tuples, result, truth_two_sets); + + // Compare sample sets [(0, 1), (0, 0), (0, 2)] + num_index_tuples = 3; + tsk_memset(result, 0, sizeof(*result) * result_size * num_index_tuples); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_trees, NULL, positions, num_trees, NULL, + positions, options, result); + + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal_nan( + result_size * num_index_tuples, result, truth_three_index_tuples); + tsk_treeseq_free(&ts); - tsk_safe_free(row_positions); - tsk_safe_free(col_positions); + tsk_safe_free(positions); + tsk_safe_free(positions); } static void @@ -2853,8 +2924,8 @@ test_two_site_correlated_multiallelic(void) "0 10 16 13\n" "0 10 16 15\n" "10 20 16 15\n"; - const char *sites = "7 A\n" - "13 G\n"; + const char *tree_sites = "7 A\n" + "13 G\n"; const char *mutations = "0 15 T -1\n" "0 14 G 0\n" "1 15 T -1\n" @@ -2877,71 +2948,133 @@ test_two_site_correlated_multiallelic(void) 0.003387017561686057, 0.003387017561686057 }; double truth_pi2[4] = { 0.04579247743399549, 0.04579247743399549, 0.04579247743399549, 0.0457924774339955 }; + double truth_D2_unbiased[4] = { 0.026455026455026454, 0.026455026455026454, + 0.026455026455026454, 0.026455026455026454 }; + double truth_Dz_unbiased[4] = { -0.008818342151675485, -0.008818342151675485, + -0.008818342151675485, -0.008818342151675485 }; + double truth_pi2_unbiased[4] = { 0.0582010582010582, 0.0582010582010582, + 0.0582010582010582, 0.0582010582010582 }; + double truth_D2_unbiased_disjoint[4] = { 0.007407407407407407, 0.007407407407407407, + 0.007407407407407407, 0.007407407407407407 }; - tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + tsk_treeseq_from_text( + &ts, 20, nodes, edges, NULL, tree_sites, mutations, NULL, NULL, 0); tsk_size_t num_sample_sets = 1; - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; - tsk_id_t sample_sets[ts.num_samples]; + tsk_size_t sample_set_sizes[2] = { ts.num_samples, ts.num_samples }; + tsk_id_t sample_sets[ts.num_samples * 2]; tsk_size_t num_sites = ts.tables->sites.num_rows; - tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); - tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); + tsk_id_t *sites = tsk_malloc(num_sites * sizeof(*sites)); result_size = num_sites * num_sites; double result[result_size]; + // Two sample sets for multipop at the bottom, only presenting one to single pop + // results for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; + sample_sets[s + ts.num_samples] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { - row_sites[s] = (tsk_id_t) s; - col_sites[s] = (tsk_id_t) s; + sites[s] = (tsk_id_t) s; } tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_Dz_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_pi2_unbiased); + + // We'll compute r2 between sample set 0 and sample set 1 + num_sample_sets = 2; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_r2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2); + + // perfectly overlapping sample sets will produce a result equal to the single + // population case + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + // two disjoint sample sets with 5 and 4 samples {0,1,2,3,4}{5,6,7,8} + sample_set_sizes[0] = 5; + sample_set_sizes[1] = 4; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 1 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased_disjoint); + tsk_treeseq_free(&ts); - tsk_safe_free(row_sites); - tsk_safe_free(col_sites); + tsk_safe_free(sites); } static void @@ -2988,8 +3121,8 @@ test_two_site_uncorrelated_multiallelic(void) "10 20 23 18,20\n" "0 10 16 14,15\n" "10 20 24 22,23\n"; - const char *sites = "7 A\n" - "13 G\n"; + const char *tree_sites = "7 A\n" + "13 G\n"; const char *mutations = "0 15 T -1\n" "0 12 G 0\n" "1 23 T -1\n" @@ -3007,72 +3140,134 @@ test_two_site_uncorrelated_multiallelic(void) double truth_Dz[4] = { 0.0, 0.0, 0.0, 0.0 }; double truth_pi2[4] = { 0.04938271604938272, 0.04938271604938272, 0.04938271604938272, 0.04938271604938272 }; + double truth_D2_unbiased[4] = { 0.027777777777777776, -0.009259259259259259, + -0.009259259259259259, 0.027777777777777776 }; + double truth_Dz_unbiased[4] = { -0.015873015873015872, 0.005291005291005289, + 0.005291005291005289, -0.015873015873015872 }; + double truth_pi2_unbiased[4] = { 0.06349206349206349, 0.06216931216931215, + 0.06216931216931215, 0.06349206349206349 }; + double truth_D2_unbiased_disjoint[4] = { 0.008333333333333333, + -0.0027777777777777775, -0.0027777777777777775, 0.03518518518518518 }; - tsk_treeseq_from_text(&ts, 20, nodes, edges, NULL, sites, mutations, NULL, NULL, 0); + tsk_treeseq_from_text( + &ts, 20, nodes, edges, NULL, tree_sites, mutations, NULL, NULL, 0); tsk_size_t s; tsk_size_t num_sample_sets = 1; tsk_size_t num_sites = ts.tables->sites.num_rows; - tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); - tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; - tsk_id_t sample_sets[ts.num_samples]; + tsk_id_t *sites = tsk_malloc(num_sites * sizeof(*sites)); + tsk_size_t sample_set_sizes[2] = { ts.num_samples, ts.num_samples }; + tsk_id_t sample_sets[ts.num_samples * 2]; tsk_size_t result_size = num_sites * num_sites; double result[result_size]; + // Two sample sets for multipop at the bottom, only presenting one to single pop + // results for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; + sample_sets[s + ts.num_samples] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { - row_sites[s] = (tsk_id_t) s; - col_sites[s] = (tsk_id_t) s; + sites[s] = (tsk_id_t) s; } tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r2); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_D_prime(&ts, num_sample_sets, sample_set_sizes, sample_sets, - num_sites, row_sites, NULL, num_sites, col_sites, NULL, 0, result); + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_D_prime); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_r(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_r); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_Dz(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_Dz); tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); ret = tsk_treeseq_pi2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, - row_sites, NULL, num_sites, col_sites, NULL, 0, result); + sites, NULL, num_sites, sites, NULL, 0, result); CU_ASSERT_EQUAL_FATAL(ret, 0); assert_arrays_almost_equal(result_size, result, truth_pi2); + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_D2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_Dz_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_Dz_unbiased); + + tsk_memset(result, 0, sizeof(*result) * result_size * num_sample_sets); + ret = tsk_treeseq_pi2_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_sites, sites, NULL, num_sites, sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_pi2_unbiased); + + // We'll compute r2 between sample set 0 and sample set 1 + num_sample_sets = 2; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_r2); + + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, 1, + (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2); + + // perfectly overlapping sample sets will produce a result equal to the single + // population case + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 0 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased); + + // two disjoint sample sets with 5 and 4 samples {0,1,2,3,4}{5,6,7,8} + sample_set_sizes[0] = 5; + sample_set_sizes[1] = 4; + tsk_memset(result, 0, sizeof(*result) * result_size); + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + 1, (tsk_id_t[2]){ 0, 1 }, num_sites, sites, NULL, num_sites, sites, NULL, 0, + result); + CU_ASSERT_EQUAL_FATAL(ret, 0); + assert_arrays_almost_equal(result_size, result, truth_D2_unbiased_disjoint); + tsk_treeseq_free(&ts); - tsk_safe_free(row_sites); - tsk_safe_free(col_sites); + tsk_safe_free(sites); } static void @@ -3145,7 +3340,7 @@ test_two_site_backmutation(void) } static void -test_two_locus_site_all_stats(void) +test_two_locus_branch_all_stats(void) { int ret; tsk_treeseq_t ts; @@ -3167,12 +3362,7 @@ test_two_locus_site_all_stats(void) "0 2 20 13\n0 2 20 16\n2 10 21 13\n6 10 21 18\n0 6 21 19\n" "0 2 21 20\n"; - double truth_D[16] = { -6.938893903907228e-18, 5.551115123125783e-17, - 4.85722573273506e-17, 2.7755575615628914e-17, 1.0408340855860843e-17, - 8.326672684688674e-17, 7.979727989493313e-17, 6.938893903907228e-17, - -2.42861286636753e-17, 4.163336342344337e-17, 2.42861286636753e-17, - 4.163336342344337e-17, 1.3877787807814457e-17, 5.551115123125783e-17, - 2.0816681711721685e-17, 2.7755575615628914e-17 }; + double truth_D[16] = { 0 }; double truth_D2[16] = { 0.21949755999999998, 0.1867003599999999, 0.18798699999999988, 0.18941379999999983, 0.18670035999999995, 0.21159555999999993, 0.21257979999999996, 0.21222580000000005, 0.187987, 0.21257979999999996, @@ -3355,9 +3545,11 @@ test_two_locus_stat_input_errors(void) tsk_size_t num_sites = ts.tables->sites.num_rows; tsk_id_t *row_sites = tsk_malloc(num_sites * sizeof(*row_sites)); tsk_id_t *col_sites = tsk_malloc(num_sites * sizeof(*col_sites)); - tsk_size_t sample_set_sizes[1] = { ts.num_samples }; + tsk_size_t sample_set_sizes[2] = { ts.num_samples, ts.num_samples }; tsk_size_t num_sample_sets = 1; - tsk_id_t sample_sets[ts.num_samples]; + tsk_id_t index_tuples[2] = { 0 }; + tsk_size_t num_index_tuples = 1; + tsk_id_t sample_sets[ts.num_samples * 2]; // need 2 sample sets for multipop double positions[10] = { 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9 }; double bad_col_positions[2] = { 0., 0. }; // used in 1 test to cover column check double result[100]; @@ -3365,17 +3557,25 @@ test_two_locus_stat_input_errors(void) for (s = 0; s < ts.num_samples; s++) { sample_sets[s] = (tsk_id_t) s; + sample_sets[s + ts.num_samples] = (tsk_id_t) s; } for (s = 0; s < num_sites; s++) { row_sites[s] = (tsk_id_t) s; col_sites[s] = (tsk_id_t) s; } + // begin with the happy path + ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, + row_sites, NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); - sample_set_sizes[0] = ts.num_samples; - num_sample_sets = 1; - for (s = 0; s < ts.num_samples; s++) { - sample_sets[s] = (tsk_id_t) s; - } + ret = tsk_treeseq_two_locus_count_stat(&ts, num_sample_sets, sample_set_sizes, + sample_sets, 0, NULL, NULL, NULL, num_sites, row_sites, NULL, num_sites, + col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_RESULT_DIMS); + + ret = tsk_treeseq_r2(&ts, 1, sample_set_sizes, sample_sets, num_sites, row_sites, + NULL, num_sites, col_sites, NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, 0); sample_sets[1] = 0; ret = tsk_treeseq_r2(&ts, num_sample_sets, sample_set_sizes, sample_sets, num_sites, @@ -3478,6 +3678,27 @@ test_two_locus_stat_input_errors(void) positions, 10, NULL, positions, TSK_STAT_NODE, result); CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_UNSUPPORTED_STAT_MODE); + num_sample_sets = 2; + num_index_tuples = 0; + ret = tsk_treeseq_r2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, row_sites, NULL, num_sites, col_sites, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_INDEX_TUPLES); + + num_sample_sets = 1; + num_index_tuples = 1; + ret = tsk_treeseq_D2_ij(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, row_sites, NULL, num_sites, col_sites, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_INSUFFICIENT_SAMPLE_SETS); + + num_sample_sets = 2; + index_tuples[0] = 2; + ret = tsk_treeseq_D2_ij_unbiased(&ts, num_sample_sets, sample_set_sizes, sample_sets, + num_index_tuples, index_tuples, num_sites, row_sites, NULL, num_sites, col_sites, + NULL, 0, result); + CU_ASSERT_EQUAL_FATAL(ret, TSK_ERR_BAD_SAMPLE_SET_INDEX); + tsk_treeseq_free(&ts); tsk_safe_free(row_sites); tsk_safe_free(col_sites); @@ -3817,7 +4038,7 @@ main(int argc, char **argv) { "test_two_site_uncorrelated_multiallelic", test_two_site_uncorrelated_multiallelic }, { "test_two_site_backmutation", test_two_site_backmutation }, - { "test_two_locus_site_all_stats", test_two_locus_site_all_stats }, + { "test_two_locus_site_all_stats", test_two_locus_branch_all_stats }, { "test_paper_ex_two_site_subset", test_paper_ex_two_site_subset }, { "test_two_locus_stat_input_errors", test_two_locus_stat_input_errors }, diff --git a/c/tskit/trees.c b/c/tskit/trees.c index 2b778f195b..dd286b4fcc 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2225,7 +2225,7 @@ get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state, } static int -norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights, +norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) { sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; @@ -2233,7 +2233,7 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights, double n; tsk_size_t k; - for (k = 0; k < state_dim; k++) { + for (k = 0; k < result_dim; k++) { weight_row = GET_2D_ROW(hap_weights, 3, k); n = (double) args.sample_set_sizes[k]; // TODO: what to do when n = 0 @@ -2243,12 +2243,38 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights, } static int -norm_total_weighted(tsk_size_t state_dim, const double *TSK_UNUSED(hap_weights), +norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights, + tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *weight_row; + double ni, nj, wAB_i, wAB_j; + tsk_id_t i, j; + tsk_size_t k; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + ni = (double) args.sample_set_sizes[i]; + nj = (double) args.sample_set_sizes[j]; + weight_row = GET_2D_ROW(hap_weights, 3, i); + wAB_i = weight_row[0]; + weight_row = GET_2D_ROW(hap_weights, 3, j); + wAB_j = weight_row[0]; + + result[k] = (wAB_i / ni / 2) + (wAB_j / nj / 2); + } + + return 0; +} + +static int +norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights), tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params)) { tsk_size_t k; - for (k = 0; k < state_dim; k++) { + for (k = 0; k < result_dim; k++) { result[k] = 1 / (double) (n_a * n_b); } return 0; @@ -2268,9 +2294,6 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n) } } -typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, tsk_size_t n_a, - tsk_size_t n_b, double *result, void *params); - static int compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, @@ -2290,14 +2313,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] tsk_size_t k, mut_a, mut_b; - tsk_size_t row_len = num_b_alleles * state_dim; + tsk_size_t result_row_len = num_b_alleles * result_dim; tsk_size_t w_A = 0, w_B = 0, w_AB = 0; uint8_t polarised_val = polarised ? 1 : 0; double *hap_weight_row; double *result_tmp_row; double *weights = tsk_malloc(3 * state_dim * sizeof(*weights)); - double *norm = tsk_malloc(state_dim * sizeof(*norm)); - double *result_tmp = tsk_malloc(row_len * num_a_alleles * sizeof(*result_tmp)); + double *norm = tsk_malloc(result_dim * sizeof(*norm)); + double *result_tmp + = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp)); tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples)); tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples)); @@ -2327,7 +2351,7 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, } for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) { - result_tmp_row = GET_2D_ROW(result_tmp, row_len, mut_a); + result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a); for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) { tsk_bit_array_get_row(site_a_state, mut_a, &A_samples); tsk_bit_array_get_row(site_b_state, mut_b, &B_samples); @@ -2352,15 +2376,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, if (ret != 0) { goto out; } - ret = norm_f(state_dim, weights, num_a_alleles - polarised_val, + ret = norm_f(result_dim, weights, num_a_alleles - polarised_val, num_b_alleles - polarised_val, norm, f_params); if (ret != 0) { goto out; } - for (k = 0; k < state_dim; k++) { + for (k = 0; k < result_dim; k++) { result[k] += result_tmp_row[k] * norm[k]; } - result_tmp_row += state_dim; // Advance to the next column + result_tmp_row += result_dim; // Advance to the next column } } @@ -2538,8 +2562,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, get_site_row_col_indices( n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx); - // We rely on n_sites to allocate these arrays, they're initialized to NULL for safe - // deallocation if the previous allocation fails + // We rely on n_sites to allocate these arrays, which are initialized + // to NULL for safe deallocation if the previous allocation fails num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles)); site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets)); if (num_alleles == NULL || site_offsets == NULL) { @@ -3195,7 +3219,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di return ret; } -static int +int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f, @@ -3209,7 +3233,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl tsk_bit_array_t sample_sets_bits; bool stat_site = !!(options & TSK_STAT_SITE); bool stat_branch = !!(options & TSK_STAT_BRANCH); - // double default_windows[] = { 0, self->tables->sequence_length }; tsk_size_t state_dim = num_sample_sets; sample_count_stat_params_t f_params = { .sample_sets = sample_sets, .num_sample_sets = num_sample_sets, @@ -3232,17 +3255,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES); goto out; } - // TODO: impossible until we implement branch/windows - // if (result_dim < 1) { - // ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); - // goto out; - // } ret = tsk_treeseq_check_sample_sets( self, num_sample_sets, sample_set_sizes, sample_sets); if (ret != 0) { goto out; } - tsk_bug_assert(state_dim > 0); + if (result_dim < 1) { + ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS); + goto out; + } ret = sample_sets_to_bit_array( self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits); if (ret != 0) { @@ -4781,6 +4802,201 @@ tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, return ret; } +static int +D2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + double n; + tsk_size_t k; + tsk_id_t i, j; + double p_A, p_B, p_AB, p_Ab, p_aB, D_i, D_j; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + + n = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_i = p_AB - (p_A * p_B); + + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_j = p_AB - (p_A * p_B); + + result[k] = D_i * D_j; + } + + return 0; +} + +int +tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +out: + return ret; +} + +static int +D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + const double *state_row; + tsk_size_t k; + tsk_id_t i, j; + double n_i, n_j; + double w_AB_i, w_Ab_i, w_aB_i, w_ab_i; + double w_AB_j, w_Ab_j, w_aB_j, w_ab_j; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + if (i == j) { + // We require disjoint sample sets because we test equality here + n_i = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + w_AB_i = state_row[0]; + w_Ab_i = state_row[1]; + w_aB_i = state_row[2]; + w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i); + result[k] = (w_AB_i * (w_AB_i - 1) * w_ab_i * (w_ab_i - 1) + + w_Ab_i * (w_Ab_i - 1) * w_aB_i * (w_aB_i - 1) + - 2 * w_AB_i * w_Ab_i * w_aB_i * w_ab_i) + / n_i / (n_i - 1) / (n_i - 2) / (n_i - 3); + } + + else { + n_i = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + w_AB_i = state_row[0]; + w_Ab_i = state_row[1]; + w_aB_i = state_row[2]; + w_ab_i = n_i - (w_AB_i + w_Ab_i + w_aB_i); + + n_j = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + w_AB_j = state_row[0]; + w_Ab_j = state_row[1]; + w_aB_j = state_row[2]; + w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j); + + result[k] = (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) + * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) / n_i / (n_i - 1) / n_j + / (n_j - 1); + } + } + + return 0; +} + +int +tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func, + norm_total_weighted, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +out: + return ret; +} + +static int +r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state, + tsk_size_t result_dim, double *result, void *params) +{ + sample_count_stat_params_t args = *(sample_count_stat_params_t *) params; + double n; + const double *state_row; + tsk_size_t k; + tsk_id_t i, j; + double p_AB, p_Ab, p_aB, p_A, p_B, D_i, D_j, denom_i, denom_j; + + for (k = 0; k < result_dim; k++) { + i = args.set_indexes[2 * k]; + j = args.set_indexes[2 * k + 1]; + + n = (double) args.sample_set_sizes[i]; + state_row = GET_2D_ROW(state, 3, i); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_i = p_AB - (p_A * p_B); + denom_i = sqrt(p_A * p_B * (1 - p_A) * (1 - p_B)); + + n = (double) args.sample_set_sizes[j]; + state_row = GET_2D_ROW(state, 3, j); + p_AB = state_row[0] / n; + p_Ab = state_row[1] / n; + p_aB = state_row[2] / n; + p_A = p_AB + p_Ab; + p_B = p_AB + p_aB; + D_j = p_AB - (p_A * p_B); + denom_j = sqrt(p_A * p_B * (1 - p_A) * (1 - p_B)); + + result[k] = (D_i * D_j) / (denom_i * denom_j); + } + return 0; +} + +int +tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result) +{ + int ret = 0; + ret = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples); + if (ret != 0) { + goto out; + } + ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes, + sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func, + norm_hap_weighted_ij, num_rows, row_sites, row_positions, num_cols, col_sites, + col_positions, options, result); +out: + return ret; +} + /*********************************** * Three way stats ***********************************/ diff --git a/c/tskit/trees.h b/c/tskit/trees.h index ac4100f7b0..040e88f86a 100644 --- a/c/tskit/trees.h +++ b/c/tskit/trees.h @@ -981,15 +981,17 @@ typedef int general_stat_func_t(tsk_size_t state_dim, const double *state, int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const double *W, tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -// TODO: expose this externally? -/* int tsk_treeseq_two_locus_general_stat(const tsk_treeseq_t *self, */ -/* tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, */ -/* const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, - */ -/* general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t num_left_windows, */ -/* const double *left_windows, tsk_size_t num_right_windows, */ -/* const double *right_windows, tsk_flags_t options, tsk_size_t num_result, */ -/* double *result); */ + +typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a, + tsk_size_t n_b, double *result, void *params); + +int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t result_dim, const tsk_id_t *set_indexes, + general_stat_func_t *f, norm_func_t *norm_f, tsk_size_t out_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t out_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); /* One way weighted stats */ @@ -1063,24 +1065,6 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self, const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes, tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_Y2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, - const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, - tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, - const double *windows, tsk_flags_t options, double *result); -int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, - tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, - const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, - const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, - tsk_flags_t options, double *result); - typedef int two_locus_count_stat_method(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites, @@ -1138,6 +1122,51 @@ int tsk_treeseq_pi2_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_se const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, double *result); +typedef int k_way_two_locus_count_stat_method(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_rows, const tsk_id_t *row_sites, + const double *row_positions, tsk_size_t num_cols, const tsk_id_t *col_sites, + const double *col_positions, tsk_flags_t options, double *result); + +/* Two way sample set stats */ + +int tsk_treeseq_divergence(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_Y2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_windows, + const double *windows, tsk_flags_t options, double *result); +int tsk_treeseq_genetic_relatedness(const tsk_treeseq_t *self, + tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, + const tsk_id_t *sample_sets, tsk_size_t num_index_tuples, + const tsk_id_t *index_tuples, tsk_size_t num_windows, const double *windows, + tsk_flags_t options, double *result); +int tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); +int tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, + const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, + tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows, + const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols, + const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options, + double *result); + /* Three way sample set stats */ int tsk_treeseq_Y3(const tsk_treeseq_t *self, tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets, diff --git a/python/_tskitmodule.c b/python/_tskitmodule.c index cab445e5d0..611d4c0141 100644 --- a/python/_tskitmodule.c +++ b/python/_tskitmodule.c @@ -10659,6 +10659,138 @@ TreeSequence_Dz_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kw return TreeSequence_ld_matrix(self, args, kwds, tsk_treeseq_Dz_unbiased); } +static PyObject * +TreeSequence_k_way_ld_matrix(TreeSequence *self, PyObject *args, PyObject *kwds, + npy_intp tuple_size, k_way_two_locus_count_stat_method *method) +{ + PyObject *ret = NULL; + static char *kwlist[] = { "sample_set_sizes", "sample_sets", "indexes", "row_sites", + "col_sites", "row_positions", "column_positions", "mode", NULL }; + + PyObject *row_sites = NULL, *col_sites = NULL, *row_positions = NULL, + *col_positions = NULL, *sample_set_sizes = NULL, *sample_sets = NULL, + *indexes = NULL; + PyArrayObject *row_sites_array = NULL, *col_sites_array = NULL, + *row_positions_array = NULL, *col_positions_array = NULL, + *sample_sets_array = NULL, *sample_set_sizes_array = NULL, + *indexes_array = NULL, *result_matrix = NULL; + tsk_id_t *row_sites_parsed = NULL, *col_sites_parsed = NULL; + double *row_positions_parsed = NULL, *col_positions_parsed = NULL; + npy_intp *shape, result_dim[3] = { 0, 0, 0 }; + char *mode = NULL; + tsk_size_t num_sample_sets, num_set_index_tuples; + tsk_flags_t options = 0; + int err; + + if (TreeSequence_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO|OOOOs", kwlist, &sample_set_sizes, + &sample_sets, &indexes, &row_sites, &col_sites, &row_positions, + &col_positions, &mode)) { + goto out; + } + if (parse_stats_mode(mode, &options) != 0) { + goto out; + } + if (parse_sample_sets(sample_set_sizes, &sample_set_sizes_array, sample_sets, + &sample_sets_array, &num_sample_sets) + != 0) { + goto out; + } + + if (options & TSK_STAT_SITE) { + if (row_positions != Py_None || col_positions != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify positions in site mode"); + goto out; + } + row_sites_array = parse_sites(self, row_sites, &(result_dim[0])); + col_sites_array = parse_sites(self, col_sites, &(result_dim[1])); + if (row_sites_array == NULL || col_sites_array == NULL) { + goto out; + } + row_sites_parsed = PyArray_DATA(row_sites_array); + col_sites_parsed = PyArray_DATA(col_sites_array); + } else if (options & TSK_STAT_BRANCH) { + if (row_sites != Py_None || col_sites != Py_None) { + PyErr_SetString(PyExc_ValueError, "Cannot specify sites in branch mode"); + goto out; + } + row_positions_array = parse_positions(self, row_positions, &(result_dim[0])); + col_positions_array = parse_positions(self, col_positions, &(result_dim[1])); + if (col_positions_array == NULL || row_positions_array == NULL) { + goto out; + } + row_positions_parsed = PyArray_DATA(row_positions_array); + col_positions_parsed = PyArray_DATA(col_positions_array); + } + + indexes_array = (PyArrayObject *) PyArray_FROMANY( + indexes, NPY_INT32, 2, 2, NPY_ARRAY_IN_ARRAY); + if (indexes_array == NULL) { + goto out; + } + shape = PyArray_DIMS(indexes_array); + if (shape[0] < 1 || shape[1] != tuple_size) { + PyErr_Format( + PyExc_ValueError, "indexes must be a k x %d array.", (int) tuple_size); + goto out; + } + num_set_index_tuples = shape[0]; + + result_dim[2] = num_set_index_tuples; + result_matrix = (PyArrayObject *) PyArray_ZEROS(3, result_dim, NPY_FLOAT64, 0); + if (result_matrix == NULL) { + PyErr_NoMemory(); + goto out; + } + + // clang-format off + Py_BEGIN_ALLOW_THREADS + err = method(self->tree_sequence, num_sample_sets, + PyArray_DATA(sample_set_sizes_array), PyArray_DATA(sample_sets_array), + num_set_index_tuples, PyArray_DATA(indexes_array), result_dim[0], + row_sites_parsed, row_positions_parsed, result_dim[1], col_sites_parsed, + col_positions_parsed, options, PyArray_DATA(result_matrix)); + Py_END_ALLOW_THREADS + // clang-format on + if (err != 0) + { + handle_library_error(err); + goto out; + } + ret = (PyObject *) result_matrix; + result_matrix = NULL; +out: + Py_XDECREF(row_sites_array); + Py_XDECREF(col_sites_array); + Py_XDECREF(row_positions_array); + Py_XDECREF(col_positions_array); + Py_XDECREF(sample_sets_array); + Py_XDECREF(sample_set_sizes_array); + Py_XDECREF(indexes_array); + Py_XDECREF(result_matrix); + return ret; +} + +static PyObject * +TreeSequence_D2_ij_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_D2_ij); +} + +static PyObject * +TreeSequence_D2_ij_unbiased_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_D2_ij_unbiased); +} + +static PyObject * +TreeSequence_r2_ij_matrix(TreeSequence *self, PyObject *args, PyObject *kwds) +{ + return TreeSequence_k_way_ld_matrix(self, args, kwds, 2, tsk_treeseq_r2_ij); +} + static PyObject * TreeSequence_get_num_mutations(TreeSequence *self) { @@ -11765,6 +11897,18 @@ static PyMethodDef TreeSequence_methods[] = { .ml_meth = (PyCFunction) TreeSequence_pi2_unbiased_matrix, .ml_flags = METH_VARARGS | METH_KEYWORDS, .ml_doc = "Computes the unbiased pi2 matrix." }, + { .ml_name = "D2_ij_matrix", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way D^2 matrix." }, + { .ml_name = "D2_ij_unbiased_matrix", + .ml_meth = (PyCFunction) TreeSequence_D2_ij_unbiased_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way unbiased D^2 matrix." }, + { .ml_name = "r2_ij_matrix", + .ml_meth = (PyCFunction) TreeSequence_r2_ij_matrix, + .ml_flags = METH_VARARGS | METH_KEYWORDS, + .ml_doc = "Computes the two-way r^2 matrix." }, { NULL } /* Sentinel */ }; diff --git a/python/tests/test_ld_matrix.py b/python/tests/test_ld_matrix.py index 6c88c7bd67..fc28e07ef9 100644 --- a/python/tests/test_ld_matrix.py +++ b/python/tests/test_ld_matrix.py @@ -244,6 +244,41 @@ def norm_hap_weighted( result[k] = hap_weights[0, k] / n +def norm_hap_weighted_ij( + result_dim: int, + hap_weights: np.ndarray, + n_a: int, + n_b: int, + result: np.ndarray, + params: Dict[str, Any], +) -> None: + """ + Create a vector of normalizing coefficients, length of the number of + index tuples. Each allele's statistic will be weighted by the average + of the proportion of AB haplotypes in each population present in the + index tuple. + + :param result_dim: Number of dimensions in output. Dependent on arity of stat. + :param hap_weights: Proportion of each two-locus haplotype. + :param n_a: Number of alleles at the A locus. + :param n_b: Number of alleles at the B locus. + :param result: Result vector to store the normalizing coefficients in. + :param params: Params of summary function. + """ + del n_a, n_b # handle unused params + sample_set_sizes = params["sample_set_sizes"] + set_indexes = params["set_indexes"] + + for k in range(result_dim): + i = set_indexes[k][0] + j = set_indexes[k][1] + ni = sample_set_sizes[i] + nj = sample_set_sizes[j] + wAB_i = hap_weights[0, i] + wAB_j = hap_weights[0, j] + result[k] = (wAB_i / ni / 2) + (wAB_j / nj / 2) + + def norm_total_weighted( result_dim: int, hap_weights: np.ndarray, @@ -523,7 +558,6 @@ def compute_general_two_site_stat_result( result_tmp = np.zeros(result_dim, np.float64) polarised_val = 1 if polarised else 0 - for mut_a in range(polarised_val, num_row_alleles): a = int(mut_a + row_site_offset) for mut_b in range(polarised_val, num_col_alleles): @@ -1001,18 +1035,18 @@ def r2_ij_summary_func( i = set_indexes[k][0] j = set_indexes[k][1] n = sample_set_sizes[i] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n + p_AB = state[0, i] / n + p_Ab = state[1, i] / n + p_aB = state[2, i] / n p_A = p_AB + p_Ab p_B = p_AB + p_aB D_i = p_AB - (p_A * p_B) denom_i = np.sqrt(p_A * p_B * (1 - p_A) * (1 - p_B)) n = sample_set_sizes[j] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n + p_AB = state[0, j] / n + p_Ab = state[1, j] / n + p_aB = state[2, j] / n p_A = p_AB + p_Ab p_B = p_AB + p_aB D_j = p_AB - (p_A * p_B) @@ -1249,17 +1283,17 @@ def D2_ij_summary_func( j = set_indexes[k][1] n = sample_set_sizes[i] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n + p_AB = state[0, i] / n + p_Ab = state[1, i] / n + p_aB = state[2, i] / n p_A = p_AB + p_Ab p_B = p_AB + p_aB D_i = p_AB - (p_A * p_B) n = sample_set_sizes[j] - p_AB = state[0, k] / n - p_Ab = state[1, k] / n - p_aB = state[2, k] / n + p_AB = state[0, j] / n + p_Ab = state[1, j] / n + p_aB = state[2, j] / n p_A = p_AB + p_Ab p_B = p_AB + p_aB D_j = p_AB - (p_A * p_B) @@ -1287,17 +1321,18 @@ def D2_ij_unbiased_summary_func( w_Ab = state[1, i] w_aB = state[2, i] w_ab = n - (w_AB + w_Ab + w_aB) - result[k] = ( - ( - w_AB * (w_AB - 1) * w_ab * (w_ab - 1) - + w_Ab * (w_Ab - 1) * w_aB * (w_aB - 1) - - 2 * w_AB * w_Ab * w_aB * w_ab + with suppress_overflow_div0_warning(): + result[k] = ( + ( + w_AB * (w_AB - 1) * w_ab * (w_ab - 1) + + w_Ab * (w_Ab - 1) * w_aB * (w_aB - 1) + - 2 * w_AB * w_Ab * w_aB * w_ab + ) + / n + / (n - 1) + / (n - 2) + / (n - 3) ) - / n - / (n - 1) - / (n - 2) - / (n - 3) - ) else: n_i = sample_set_sizes[i] w_AB_i = state[0, i] @@ -1311,14 +1346,15 @@ def D2_ij_unbiased_summary_func( w_aB_j = state[2, j] w_ab_j = n_j - (w_AB_j + w_Ab_j + w_aB_j) - result[k] = ( - (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) - * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) - / n_i - / (n_i - 1) - / n_j - / (n_j - 1) - ) + with suppress_overflow_div0_warning(): + result[k] = ( + (w_Ab_i * w_aB_i - w_AB_i * w_ab_i) + * (w_Ab_j * w_aB_j - w_AB_j * w_ab_j) + / n_i + / (n_i - 1) + / n_j + / (n_j - 1) + ) SUMMARY_FUNCS = { @@ -1351,7 +1387,7 @@ def D2_ij_unbiased_summary_func( D2_unbiased_summary_func: norm_total_weighted, Dz_unbiased_summary_func: norm_total_weighted, pi2_unbiased_summary_func: norm_total_weighted, - r2_ij_summary_func: norm_hap_weighted, + r2_ij_summary_func: norm_hap_weighted_ij, D2_ij_summary_func: norm_total_weighted, D2_ij_unbiased_summary_func: norm_total_weighted, } @@ -1373,9 +1409,11 @@ def D2_ij_unbiased_summary_func( } -def check_set_indexes(num_sets: int, num_set_indexes: int, set_indexes: np.ndarray): - for i in range(len(set_indexes)): - for j in range(num_set_indexes): +def check_set_indexes( + num_sets: int, num_set_indexes: int, tuple_size: int, set_indexes: np.ndarray +): + for i in range(num_set_indexes): + for j in range(tuple_size): if set_indexes[i, j] < 0 or set_indexes[i, j] >= num_sets: raise ValueError(f"Bad sample set index: {set_indexes[i, j]}") @@ -1393,7 +1431,7 @@ def check_sample_stat_inputs( ) if num_index_tuples < 1: raise ValueError(f"Insufficient number of index tuples: {num_index_tuples}") - check_set_indexes(num_sample_sets, num_index_tuples, index_tuples) + check_set_indexes(num_sample_sets, num_index_tuples, tuple_size, index_tuples) def ld_matrix( @@ -1761,6 +1799,17 @@ def test_input_validation(): with pytest.raises(ValueError, match="must be a length 1 or 2 list"): ts.ld_matrix(positions=[], mode="branch") + with pytest.raises( + ValueError, match="Sample sets must contain at least one element" + ): + ts.ld_matrix(sample_sets=[[1, 2, 3], []], indexes=[]) + with pytest.raises( + ValueError, match="Indexes must be convertable to a 2D numpy array" + ): + ts.ld_matrix( + sample_sets=[ts.samples(), ts.samples()], indexes=[[1, 2, 3], [2, 3, 4]] + ) + @dataclass class TreeState: @@ -2152,6 +2201,49 @@ def test_branch_ld_matrix_2pop_sample_sets_unbiased(ts, sample_set, stat): ) +def gen_dims_test_cases(ts, mode): + ss = ts.samples() + dim = ts.num_sites if mode == "site" else ts.num_trees + base = (dim, dim) + return [ + {"name": f"{mode}_default", "ld_params": {"mode": mode}, "shape": base}, + { + "name": f"{mode}_dim_drop", + "ld_params": {"mode": mode, "sample_sets": ss}, + "shape": base, + }, + { + "name": f"{mode}_no_dim_drop", + "ld_params": {"mode": mode, "sample_sets": [ss]}, + "shape": (1, *base), + }, + { + "name": f"{mode}_two_sample_sets", + "ld_params": {"mode": mode, "sample_sets": [ss, ss]}, + "shape": (2, *base), + }, + { + "name": f"{mode}_two_way_dim_drop", + "ld_params": {"mode": mode, "sample_sets": [ss, ss], "indexes": (0, 1)}, + "shape": base, + }, + { + "name": f"{mode}_two_way_no_dim_drop", + "ld_params": {"mode": mode, "sample_sets": [ss, ss], "indexes": [(0, 1)]}, + "shape": (1, *base), + }, + { + "name": f"{mode}_two_way_three_set_indexes", + "ld_params": { + "mode": mode, + "sample_sets": [ss, ss], + "indexes": [(0, 0), (0, 1), (1, 1)], + }, + "shape": (3, *base), + }, + ] + + def get_test_dims_test_cases(): test_cases = { "empty_tree", @@ -2161,16 +2253,48 @@ def get_test_dims_test_cases(): "internal_nodes_samples", "mixed_internal_leaf_samples", } - return [t for t in get_example_tree_sequences() if t.id in test_cases] + for ts_case in [t for t in get_example_tree_sequences() if t.id in test_cases]: + ts = ts_case.values[0] + for dim_case in gen_dims_test_cases(ts, "site"): + name = "_".join([dim_case["name"], ts_case.id]) + yield pytest.param(ts, dim_case["ld_params"], dim_case["shape"], id=name) + for dim_case in gen_dims_test_cases(ts, "branch"): + name = "_".join([dim_case["name"], ts_case.id]) + yield pytest.param(ts, dim_case["ld_params"], dim_case["shape"], id=name) -@pytest.mark.parametrize("ts", get_test_dims_test_cases()) -def test_dims(ts): - ss = ts.samples() - assert ld_matrix(ts).ndim == 2 - assert ld_matrix(ts, sample_sets=ss).ndim == 2 - assert ld_matrix(ts, sample_sets=[ss]).ndim == 3 - assert ld_matrix(ts, sample_sets=[ss, ss]).ndim == 3 - assert ld_matrix(ts, sample_sets=[ss, ss], indexes=(0, 0)).ndim == 2 - assert ld_matrix(ts, sample_sets=[ss, ss], indexes=[(0, 0)]).ndim == 3 - assert ld_matrix(ts, sample_sets=[ss, ss], indexes=[(0, 0), (0, 1)]).ndim == 3 +@pytest.mark.parametrize("ts,params,shape", get_test_dims_test_cases()) +def test_dims(ts, params, shape): + assert ts.ld_matrix(**params).shape == ld_matrix(ts, **params).shape == shape + + +@pytest.mark.parametrize("ts,sample_sets", get_test_branch_2pop_test_cases()) +@pytest.mark.parametrize("stat", sorted(TWO_WAY_SUMMARY_FUNCS.keys())) +def test_two_way_branch_ld_matrix(ts, sample_sets, stat): + np.testing.assert_array_almost_equal( + ld_matrix(ts, sample_sets=sample_sets, indexes=[(0, 0), (0, 1), (1, 1)]), + ts.ld_matrix(sample_sets=sample_sets, indexes=[(0, 0), (0, 1), (1, 1)]), + ) + + +@pytest.mark.parametrize( + "ts", + [ + ts + for ts in get_example_tree_sequences() + if ts.id not in {"no_samples", "empty_ts"} + ], +) +@pytest.mark.parametrize( + "stat", + sorted(TWO_WAY_SUMMARY_FUNCS.keys()), +) +def test_two_way_site_ld_matrix(ts, stat): + np.testing.assert_array_almost_equal( + ld_matrix(ts, stat=stat), ts.ld_matrix(stat=stat) + ) + ss = [ts.samples()] * 3 + np.testing.assert_array_almost_equal( + ld_matrix(ts, stat=stat, sample_sets=ss, indexes=[(0, 0), (0, 1), (1, 1)]), + ts.ld_matrix(stat=stat, sample_sets=ss, indexes=[(0, 0), (0, 1), (1, 1)]), + ) diff --git a/python/tests/test_lowlevel.py b/python/tests/test_lowlevel.py index 9cc8206cb1..d6416c22b8 100644 --- a/python/tests/test_lowlevel.py +++ b/python/tests/test_lowlevel.py @@ -1711,6 +1711,214 @@ def test_ld_matrix(self, stat_method_name): with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): stat_method(ss_sizes, ss, col_sites, row_sites, None, None, "node") + @pytest.mark.parametrize( + "stat_method_name", + [ + "D2_ij_matrix", + "r2_ij_matrix", + "D2_ij_unbiased_matrix", + ], + ) + def test_ld_matrix_multipop(self, stat_method_name): + ts = self.get_example_tree_sequence(10) + stat_method = getattr(ts, stat_method_name) + + num_samples = len(ts.get_samples()) + ss = np.hstack([ts.get_samples(), ts.get_samples()]) # sample sets + ss_sizes = np.array([num_samples, num_samples], dtype=np.uint32) + indexes = [(0, 0), (0, 1)] + row_sites = np.arange(ts.get_num_sites(), dtype=np.int32) + col_sites = row_sites + row_pos = ts.get_breakpoints()[:-1] + col_pos = row_pos + row_pos_list = list(map(float, ts.get_breakpoints()[:-1])) + col_pos_list = row_pos_list + row_sites_list = list(range(ts.get_num_sites())) + col_sites_list = row_sites_list + + # happy path + a = stat_method(ss_sizes, ss, indexes, row_sites, col_sites, None, None, "site") + assert a.shape == (10, 10, 2) + a = stat_method( + ss_sizes, ss, indexes, row_sites_list, col_sites_list, None, None, "site" + ) + assert a.shape == (10, 10, 2) + a = stat_method(ss_sizes, ss, indexes, None, None, None, None, "site") + assert a.shape == (10, 10, 2) + + a = stat_method(ss_sizes, ss, indexes, None, None, row_pos, col_pos, "branch") + assert a.shape == (2, 2, 2) + a = stat_method( + ss_sizes, ss, indexes, None, None, row_pos_list, col_pos_list, "branch" + ) + assert a.shape == (2, 2, 2) + a = stat_method(ss_sizes, ss, indexes, None, None, None, None, "branch") + assert a.shape == (2, 2, 2) + + # CPython API errors + with pytest.raises(ValueError, match="Sum of sample_set_sizes"): + bad_ss = np.array([], dtype=np.int32) + stat_method( + ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(TypeError, match="cast array data"): + bad_ss = np.array(ts.get_samples(), dtype=np.uint32) + stat_method( + ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(ValueError, match="Unrecognised stats mode"): + stat_method(ss_sizes, ss, indexes, row_sites, col_sites, None, None, "bla") + with pytest.raises(TypeError, match="at most"): + stat_method( + ss_sizes, ss, indexes, row_sites, col_sites, None, None, "site", "abc" + ) + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(ValueError, match="invalid literal"): + bad_sites = ["abadsite", 0, 3, 2] + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [None, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(TypeError): + bad_sites = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(TypeError, match="Cannot cast array data"): + bad_sites = np.array([0, 1, 2], dtype=np.uint32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0.1, 0.2, 2.0] + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(TypeError): + bad_pos = [{}, 0.1, 0.2, 2.0] + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(ValueError): + bad_pos = ["abadpos", 0, 3, 2] + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises(TypeError): + bad_pos = [{}, 0, 3, 2] + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises(ValueError, match="Cannot specify sites in branch mode"): + stat_method( + ss_sizes, ss, indexes, row_sites, col_sites, None, None, "branch" + ) + with pytest.raises(ValueError, match="Cannot specify positions in site mode"): + stat_method(ss_sizes, ss, indexes, None, None, row_pos, col_pos, "site") + # C API errors + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_SITES"): + bad_sites = np.array([1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_SITES"): + bad_sites = np.array([1, 1, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, bad_sites, col_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_SITE_OUT_OF_BOUNDS"): + bad_sites = np.array([-1, 0, 2], dtype=np.int32) + stat_method(ss_sizes, ss, indexes, row_sites, bad_sites, None, None, "site") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_STAT_UNSORTED_POSITIONS"): + bad_pos = np.array([0.7, 0, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS" + ): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises( + tskit.LibraryError, match="TSK_ERR_STAT_DUPLICATE_POSITIONS" + ): + bad_pos = np.array([0.7, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, bad_pos, col_pos, "branch") + with pytest.raises(tskit.LibraryError, match="TSK_ERR_POSITION_OUT_OF_BOUNDS"): + bad_pos = np.array([-0.1, 0.7, 0.8], dtype=np.float64) + stat_method(ss_sizes, ss, indexes, None, None, row_pos, bad_pos, "branch") + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([0], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises( + _tskit.LibraryError, match="TSK_ERR_INSUFFICIENT_SAMPLE_SETS" + ): + bad_ss = np.array([], dtype=np.int32) + bad_ss_sizes = np.array([0], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + bad_ss = np.array([1000, 1000], dtype=np.int32) + bad_ss_sizes = np.array([1, 1], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_NODE_OUT_OF_BOUNDS"): + bad_ss = np.array([1000, 1000], dtype=np.int32) + bad_ss_sizes = np.array([1, 1], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + bad_ss = np.array([1, 1, 2, 3], dtype=np.int32) + bad_ss_sizes = np.array([2, 2], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_DUPLICATE_SAMPLE"): + bad_ss = np.array([1, 1, 2, 3], dtype=np.int32) + bad_ss_sizes = np.array([2, 2], dtype=np.uint32) + stat_method( + bad_ss_sizes, bad_ss, indexes, None, None, row_pos, col_pos, "branch" + ) + with pytest.raises(ValueError, match="indexes must be a"): + bad_indexes = np.array([[0, 0, 1, 1], [0, 0, 1, 1]], dtype=np.int32) + stat_method( + ss_sizes, ss, bad_indexes, row_sites, col_sites, None, None, "site" + ) + with pytest.raises(_tskit.LibraryError, match="TSK_ERR_UNSUPPORTED_STAT_MODE"): + stat_method(ss_sizes, ss, indexes, col_sites, row_sites, None, None, "node") + def test_kc_distance_errors(self): ts1 = self.get_example_tree_sequence(10) with pytest.raises(TypeError): diff --git a/python/tskit/trees.py b/python/tskit/trees.py index e740181f81..77ccbf4838 100644 --- a/python/tskit/trees.py +++ b/python/tskit/trees.py @@ -8095,6 +8095,52 @@ def __two_locus_sample_set_stat( return result + def __k_way_two_locus_sample_set_stat( + self, + ll_method, + k, + sample_sets, + indexes=None, + sites=None, + positions=None, + mode=None, + ): + sample_set_sizes = np.array( + [len(sample_set) for sample_set in sample_sets], dtype=np.uint32 + ) + if np.any(sample_set_sizes == 0): + raise ValueError("Sample sets must contain at least one element") + flattened = util.safe_np_int_cast(np.hstack(sample_sets), np.int32) + row_sites, col_sites = self.parse_sites(sites) + row_positions, col_positions = self.parse_positions(positions) + drop_dimension = False + indexes = util.safe_np_int_cast(indexes, np.int32) + if len(indexes.shape) == 1: + indexes = indexes.reshape((1, indexes.shape[0])) + drop_dimension = True + if len(indexes.shape) != 2 or indexes.shape[1] != k: + raise ValueError( + "Indexes must be convertable to a 2D numpy array with {} " + "columns".format(k) + ) + result = ll_method( + sample_set_sizes, + flattened, + indexes, + row_sites, + col_sites, + row_positions, + col_positions, + mode, + ) + if drop_dimension: + result = result.reshape(result.shape[:2]) + else: + # Orient the data so that the first dimension is the sample set. + # With this orientation, we get one LD matrix per sample set. + result = result.swapaxes(0, 2).swapaxes(1, 2) + return result + def __k_way_sample_set_stat( self, ll_method, @@ -10627,9 +10673,15 @@ def impute_unknown_mutations_time( return mutations_time def ld_matrix( - self, sample_sets=None, sites=None, positions=None, mode="site", stat="r2" + self, + sample_sets=None, + sites=None, + positions=None, + mode="site", + stat="r2", + indexes=None, ): - stats = { + one_way_stats = { "D": self._ll_tree_sequence.D_matrix, "D2": self._ll_tree_sequence.D2_matrix, "r2": self._ll_tree_sequence.r2_matrix, @@ -10641,20 +10693,32 @@ def ld_matrix( "D2_unbiased": self._ll_tree_sequence.D2_unbiased_matrix, "pi2_unbiased": self._ll_tree_sequence.pi2_unbiased_matrix, } - + two_way_stats = { + "D2": self._ll_tree_sequence.D2_ij_matrix, + "D2_unbiased": self._ll_tree_sequence.D2_ij_unbiased_matrix, + "r2": self._ll_tree_sequence.r2_ij_matrix, + } + stats = one_way_stats if indexes is None else two_way_stats try: - two_locus_stat = stats[stat] + stat_func = stats[stat] except KeyError: raise ValueError( f"Unknown two-locus statistic '{stat}', we support: {list(stats.keys())}" ) + if indexes is not None: + return self.__k_way_two_locus_sample_set_stat( + stat_func, + 2, + sample_sets, + indexes=indexes, + sites=sites, + positions=positions, + mode=mode, + ) + return self.__two_locus_sample_set_stat( - two_locus_stat, - sample_sets, - sites=sites, - positions=positions, - mode=mode, + stat_func, sample_sets, sites=sites, positions=positions, mode=mode ) def sample_nodes_by_ploidy(self, ploidy):