Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
439 changes: 330 additions & 109 deletions c/tests/test_stats.c

Large diffs are not rendered by default.

264 changes: 240 additions & 24 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2225,15 +2225,15 @@ get_allele_samples(const tsk_site_t *site, const tsk_bit_array_t *state,
}

static int
norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights,
norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
double n;
tsk_size_t k;

for (k = 0; k < state_dim; k++) {
for (k = 0; k < result_dim; k++) {
weight_row = GET_2D_ROW(hap_weights, 3, k);
n = (double) args.sample_set_sizes[k];
// TODO: what to do when n = 0
Expand All @@ -2243,12 +2243,38 @@ norm_hap_weighted(tsk_size_t state_dim, const double *hap_weights,
}

static int
norm_total_weighted(tsk_size_t state_dim, const double *TSK_UNUSED(hap_weights),
norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
    tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
{
    /* Normalisation for two-sample-set statistics. For each output slot the
     * pair of sample set indexes is read from params->set_indexes, and the
     * norm is the mean of (first column of that set's hap_weights row) divided
     * by (that set's size), taken over the two sets.
     * NOTE(review): the divisions are unguarded; this assumes both sample sets
     * are non-empty -- confirm that callers validate this. */
    sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
    const double *row_first, *row_second;
    double size_first, size_second;
    tsk_id_t first, second;
    tsk_size_t pair;

    for (pair = 0; pair < result_dim; pair++) {
        /* set_indexes holds two entries per output slot */
        first = args.set_indexes[2 * pair];
        second = args.set_indexes[2 * pair + 1];

        size_first = (double) args.sample_set_sizes[first];
        size_second = (double) args.sample_set_sizes[second];

        /* hap_weights is addressed as rows of 3 doubles, one row per set */
        row_first = GET_2D_ROW(hap_weights, 3, first);
        row_second = GET_2D_ROW(hap_weights, 3, second);

        result[pair] = (row_first[0] / size_first / 2) + (row_second[0] / size_second / 2);
    }

    return 0;
}

static int
norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights),
tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
{
tsk_size_t k;

for (k = 0; k < state_dim; k++) {
for (k = 0; k < result_dim; k++) {
result[k] = 1 / (double) (n_a * n_b);
}
return 0;
Expand All @@ -2268,9 +2294,6 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n)
}
}

typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights, tsk_size_t n_a,
tsk_size_t n_b, double *result, void *params);

static int
compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles,
Expand All @@ -2290,14 +2313,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
// a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
// a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
tsk_size_t k, mut_a, mut_b;
tsk_size_t row_len = num_b_alleles * state_dim;
tsk_size_t result_row_len = num_b_alleles * result_dim;
tsk_size_t w_A = 0, w_B = 0, w_AB = 0;
uint8_t polarised_val = polarised ? 1 : 0;
double *hap_weight_row;
double *result_tmp_row;
double *weights = tsk_malloc(3 * state_dim * sizeof(*weights));
double *norm = tsk_malloc(state_dim * sizeof(*norm));
double *result_tmp = tsk_malloc(row_len * num_a_alleles * sizeof(*result_tmp));
double *norm = tsk_malloc(result_dim * sizeof(*norm));
double *result_tmp
= tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp));

tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples));
tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples));
Expand Down Expand Up @@ -2327,7 +2351,7 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
}

for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) {
result_tmp_row = GET_2D_ROW(result_tmp, row_len, mut_a);
result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a);
for (mut_b = polarised_val; mut_b < num_b_alleles; mut_b++) {
tsk_bit_array_get_row(site_a_state, mut_a, &A_samples);
tsk_bit_array_get_row(site_b_state, mut_b, &B_samples);
Expand All @@ -2352,15 +2376,15 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
if (ret != 0) {
goto out;
}
ret = norm_f(state_dim, weights, num_a_alleles - polarised_val,
ret = norm_f(result_dim, weights, num_a_alleles - polarised_val,
num_b_alleles - polarised_val, norm, f_params);
if (ret != 0) {
goto out;
}
for (k = 0; k < state_dim; k++) {
for (k = 0; k < result_dim; k++) {
result[k] += result_tmp_row[k] * norm[k];
}
result_tmp_row += state_dim; // Advance to the next column
result_tmp_row += result_dim; // Advance to the next column
}
}

Expand Down Expand Up @@ -2538,8 +2562,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
get_site_row_col_indices(
n_rows, row_sites, n_cols, col_sites, sites, &n_sites, row_idx, col_idx);

// We rely on n_sites to allocate these arrays, they're initialized to NULL for safe
// deallocation if the previous allocation fails
// We rely on n_sites to allocate these arrays, which are initialized
// to NULL for safe deallocation if the previous allocation fails
num_alleles = tsk_malloc(n_sites * sizeof(*num_alleles));
site_offsets = tsk_malloc(n_sites * sizeof(*site_offsets));
if (num_alleles == NULL || site_offsets == NULL) {
Expand Down Expand Up @@ -3195,7 +3219,7 @@ tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_di
return ret;
}

static int
int
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
Expand All @@ -3209,7 +3233,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
tsk_bit_array_t sample_sets_bits;
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
// double default_windows[] = { 0, self->tables->sequence_length };
tsk_size_t state_dim = num_sample_sets;
sample_count_stat_params_t f_params = { .sample_sets = sample_sets,
.num_sample_sets = num_sample_sets,
Expand All @@ -3232,17 +3255,15 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
ret = tsk_trace_error(TSK_ERR_MULTIPLE_STAT_MODES);
goto out;
}
// TODO: impossible until we implement branch/windows
// if (result_dim < 1) {
// ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
// goto out;
// }
ret = tsk_treeseq_check_sample_sets(
self, num_sample_sets, sample_set_sizes, sample_sets);
if (ret != 0) {
goto out;
}
tsk_bug_assert(state_dim > 0);
if (result_dim < 1) {
ret = tsk_trace_error(TSK_ERR_BAD_RESULT_DIMS);
goto out;
}
ret = sample_sets_to_bit_array(
self, sample_set_sizes, sample_sets, num_sample_sets, &sample_sets_bits);
if (ret != 0) {
Expand Down Expand Up @@ -4781,6 +4802,201 @@ tsk_treeseq_f2(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
return ret;
}

/* Summary function for the two-sample-set D^2 statistic.
 *
 * For each output slot k the pair of sample set indexes is read from
 * params->set_indexes[2k] and [2k + 1]. For each of the two sets, the three
 * columns of its state row are divided by the set size and treated as the
 * AB, Ab and aB haplotype frequencies, from which D = p_AB - p_A * p_B is
 * formed. result[k] is the product of the two D values.
 *
 * Always returns 0. */
static int
D2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
    tsk_size_t result_dim, double *result, void *params)
{
    sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
    const double *counts;
    double set_size, freq_AB, freq_Ab, freq_aB, freq_A, freq_B;
    double D[2];
    tsk_size_t pair, side;
    tsk_id_t set;

    for (pair = 0; pair < result_dim; pair++) {
        /* side 0 is the first set of the pair, side 1 the second */
        for (side = 0; side < 2; side++) {
            set = args.set_indexes[2 * pair + side];
            set_size = (double) args.sample_set_sizes[set];
            counts = GET_2D_ROW(state, 3, set);

            freq_AB = counts[0] / set_size;
            freq_Ab = counts[1] / set_size;
            freq_aB = counts[2] / set_size;
            freq_A = freq_AB + freq_Ab;
            freq_B = freq_AB + freq_aB;
            D[side] = freq_AB - (freq_A * freq_B);
        }
        result[pair] = D[0] * D[1];
    }

    return 0;
}

/* Two-sample-set D^2 two-locus statistic.
 *
 * Checks the index tuples with check_sample_stat_inputs (passing 2 as its
 * second argument) and, if that succeeds, delegates to
 * tsk_treeseq_two_locus_count_stat with D2_ij_summary_func as the summary
 * function and norm_total_weighted as the normalisation. num_index_tuples is
 * passed as the result dimension.
 *
 * Returns 0 on success, otherwise the error code from whichever call failed. */
int
tsk_treeseq_D2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
    tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
    const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
    const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
    double *result)
{
    int ret
        = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);

    if (ret == 0) {
        ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
            sample_sets, num_index_tuples, index_tuples, D2_ij_summary_func,
            norm_total_weighted, num_rows, row_sites, row_positions, num_cols,
            col_sites, col_positions, options, result);
    }
    return ret;
}

/* Summary function for the unbiased two-sample-set D^2 statistic.
 *
 * For each output slot k the pair (i, j) is read from params->set_indexes.
 * The three columns of a set's state row are used as the AB, Ab and aB
 * haplotype counts; the ab count is the set size minus their sum.
 *
 * When i == j the single-set formula is used, which divides by
 * n (n - 1) (n - 2) (n - 3). Otherwise the cross-set formula is used, which
 * divides by n_i (n_i - 1) n_j (n_j - 1). Neither division is guarded here.
 *
 * Always returns 0. */
static int
D2_ij_unbiased_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
    tsk_size_t result_dim, double *result, void *params)
{
    sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
    const double *counts;
    tsk_size_t pair;
    tsk_id_t first, second;
    double n1, n2;
    double AB1, Ab1, aB1, ab1;
    double AB2, Ab2, aB2, ab2;

    for (pair = 0; pair < result_dim; pair++) {
        first = args.set_indexes[2 * pair];
        second = args.set_indexes[2 * pair + 1];

        /* Both formulas need the haplotype counts of the first set */
        n1 = (double) args.sample_set_sizes[first];
        counts = GET_2D_ROW(state, 3, first);
        AB1 = counts[0];
        Ab1 = counts[1];
        aB1 = counts[2];
        ab1 = n1 - (AB1 + Ab1 + aB1);

        if (first == second) {
            /* Only index equality is tested, so distinct sample sets are
             * required to be disjoint for the cross-set branch to apply. */
            result[pair] = (AB1 * (AB1 - 1) * ab1 * (ab1 - 1)
                               + Ab1 * (Ab1 - 1) * aB1 * (aB1 - 1)
                               - 2 * AB1 * Ab1 * aB1 * ab1)
                           / n1 / (n1 - 1) / (n1 - 2) / (n1 - 3);
            continue;
        }

        n2 = (double) args.sample_set_sizes[second];
        counts = GET_2D_ROW(state, 3, second);
        AB2 = counts[0];
        Ab2 = counts[1];
        aB2 = counts[2];
        ab2 = n2 - (AB2 + Ab2 + aB2);

        result[pair] = (Ab1 * aB1 - AB1 * ab1) * (Ab2 * aB2 - AB2 * ab2) / n1 / (n1 - 1)
                       / n2 / (n2 - 1);
    }

    return 0;
}

/* Unbiased two-sample-set D^2 two-locus statistic.
 *
 * Checks the index tuples with check_sample_stat_inputs (passing 2 as its
 * second argument) and, if that succeeds, delegates to
 * tsk_treeseq_two_locus_count_stat with D2_ij_unbiased_summary_func as the
 * summary function and norm_total_weighted as the normalisation.
 * num_index_tuples is passed as the result dimension.
 *
 * Returns 0 on success, otherwise the error code from whichever call failed. */
int
tsk_treeseq_D2_ij_unbiased(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
    tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
    const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
    const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
    double *result)
{
    int ret
        = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);

    if (ret == 0) {
        ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
            sample_sets, num_index_tuples, index_tuples, D2_ij_unbiased_summary_func,
            norm_total_weighted, num_rows, row_sites, row_positions, num_cols,
            col_sites, col_positions, options, result);
    }
    return ret;
}

/* Summary function for the two-sample-set r^2 statistic.
 *
 * For each output slot k the pair of sample set indexes is read from
 * params->set_indexes[2k] and [2k + 1]. For each of the two sets, the three
 * columns of its state row are divided by the set size and treated as the
 * AB, Ab and aB haplotype frequencies, giving D = p_AB - p_A * p_B and
 * scale = sqrt(p_A * p_B * (1 - p_A) * (1 - p_B)). result[k] is
 * (D_i * D_j) / (scale_i * scale_j); the division is not guarded against a
 * zero scale.
 *
 * Always returns 0. */
static int
r2_ij_summary_func(tsk_size_t TSK_UNUSED(state_dim), const double *state,
    tsk_size_t result_dim, double *result, void *params)
{
    sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
    const double *counts;
    double set_size, freq_AB, freq_Ab, freq_aB, freq_A, freq_B;
    double D[2], scale[2];
    tsk_size_t pair, side;
    tsk_id_t set;

    for (pair = 0; pair < result_dim; pair++) {
        /* side 0 is the first set of the pair, side 1 the second */
        for (side = 0; side < 2; side++) {
            set = args.set_indexes[2 * pair + side];
            set_size = (double) args.sample_set_sizes[set];
            counts = GET_2D_ROW(state, 3, set);

            freq_AB = counts[0] / set_size;
            freq_Ab = counts[1] / set_size;
            freq_aB = counts[2] / set_size;
            freq_A = freq_AB + freq_Ab;
            freq_B = freq_AB + freq_aB;
            D[side] = freq_AB - (freq_A * freq_B);
            scale[side] = sqrt(freq_A * freq_B * (1 - freq_A) * (1 - freq_B));
        }
        result[pair] = (D[0] * D[1]) / (scale[0] * scale[1]);
    }
    return 0;
}

/* Two-sample-set r^2 two-locus statistic.
 *
 * Checks the index tuples with check_sample_stat_inputs (passing 2 as its
 * second argument) and, if that succeeds, delegates to
 * tsk_treeseq_two_locus_count_stat with r2_ij_summary_func as the summary
 * function and norm_hap_weighted_ij as the normalisation. num_index_tuples is
 * passed as the result dimension.
 *
 * Returns 0 on success, otherwise the error code from whichever call failed. */
int
tsk_treeseq_r2_ij(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
    tsk_size_t num_index_tuples, const tsk_id_t *index_tuples, tsk_size_t num_rows,
    const tsk_id_t *row_sites, const double *row_positions, tsk_size_t num_cols,
    const tsk_id_t *col_sites, const double *col_positions, tsk_flags_t options,
    double *result)
{
    int ret
        = check_sample_stat_inputs(num_sample_sets, 2, num_index_tuples, index_tuples);

    if (ret == 0) {
        ret = tsk_treeseq_two_locus_count_stat(self, num_sample_sets, sample_set_sizes,
            sample_sets, num_index_tuples, index_tuples, r2_ij_summary_func,
            norm_hap_weighted_ij, num_rows, row_sites, row_positions, num_cols,
            col_sites, col_positions, options, result);
    }
    return ret;
}

/***********************************
* Three way stats
***********************************/
Expand Down
Loading
Loading