Skip to content
Open
79 changes: 46 additions & 33 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2297,8 +2297,9 @@ get_allele_samples(const tsk_site_t *site, tsk_size_t site_offset,
}

static int
norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
norm_hap_weighted(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights,
tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b),
double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
Expand All @@ -2314,8 +2315,9 @@ norm_hap_weighted(tsk_size_t result_dim, const double *hap_weights,
}

static int
norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b), double *result, void *params)
norm_hap_weighted_ij(tsk_size_t TSK_UNUSED(state_dim), const double *hap_weights,
tsk_size_t result_dim, tsk_size_t TSK_UNUSED(n_a), tsk_size_t TSK_UNUSED(n_b),
double *result, void *params)
{
sample_count_stat_params_t args = *(sample_count_stat_params_t *) params;
const double *weight_row;
Expand All @@ -2340,8 +2342,9 @@ norm_hap_weighted_ij(tsk_size_t result_dim, const double *hap_weights,
}

static int
norm_total_weighted(tsk_size_t result_dim, const double *TSK_UNUSED(hap_weights),
tsk_size_t n_a, tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
norm_total_weighted(tsk_size_t TSK_UNUSED(state_dim),
const double *TSK_UNUSED(hap_weights), tsk_size_t result_dim, tsk_size_t n_a,
tsk_size_t n_b, double *result, void *TSK_UNUSED(params))
{
tsk_size_t k;
double norm = 1 / (double) (n_a * n_b);
Expand Down Expand Up @@ -2410,8 +2413,8 @@ static int
compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off,
tsk_size_t num_a_alleles, tsk_size_t num_b_alleles, tsk_size_t state_dim,
tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params,
norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result)
tsk_size_t result_dim, general_stat_func_t *f, void *f_params, norm_func_t *norm_f,
bool polarised, two_locus_work_t *restrict work, double *result)
{
int ret = 0;
// Sample sets and b sites are rows, a sites are columns
Expand Down Expand Up @@ -2444,7 +2447,7 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
if (ret != 0) {
goto out;
}
ret = norm_f(result_dim, weights, num_a_alleles - is_polarised,
ret = norm_f(state_dim, weights, result_dim, num_a_alleles - is_polarised,
num_b_alleles - is_polarised, norm, f_params);
if (ret != 0) {
goto out;
Expand All @@ -2462,9 +2465,8 @@ compute_general_normed_two_site_stat_result(const tsk_bitset_t *state,
static int
compute_general_two_site_stat_result(const tsk_bitset_t *state,
const tsk_size_t *allele_counts, tsk_size_t a_off, tsk_size_t b_off,
tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
double *result)
tsk_size_t state_dim, tsk_size_t result_dim, general_stat_func_t *f, void *f_params,
two_locus_work_t *restrict work, double *result)
{
int ret = 0;
tsk_size_t k;
Expand Down Expand Up @@ -2652,9 +2654,8 @@ static int
tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, norm_func_t *norm_f, tsk_size_t n_rows,
const tsk_id_t *row_sites, tsk_size_t n_cols, const tsk_id_t *col_sites,
tsk_flags_t options, double *result)
void *f_params, norm_func_t *norm_f, tsk_size_t n_rows, const tsk_id_t *row_sites,
tsk_size_t n_cols, const tsk_id_t *col_sites, tsk_flags_t options, double *result)
{
int ret = 0;
tsk_bitset_t allele_samples, allele_sample_sets;
Expand Down Expand Up @@ -3088,9 +3089,8 @@ advance_collect_edges(iter_state *s, tsk_id_t index)
static int
compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim,
tsk_size_t result_dim, int sign, general_stat_func_t *f,
sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
double *result)
tsk_size_t result_dim, int sign, general_stat_func_t *f, void *f_params,
two_locus_work_t *restrict work, double *result)
{
int ret = 0;
double a_len, b_len;
Expand Down Expand Up @@ -3140,8 +3140,8 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,

static int
compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
iter_state *r_state, general_stat_func_t *f, sample_count_stat_params_t *f_params,
tsk_size_t result_dim, tsk_size_t state_dim, double *result)
iter_state *r_state, general_stat_func_t *f, void *f_params, tsk_size_t result_dim,
tsk_size_t state_dim, double *result)
{
int ret = 0;
tsk_id_t e, c, ec, p, *updated_nodes = NULL;
Expand Down Expand Up @@ -3242,9 +3242,9 @@ static int
tsk_treeseq_two_branch_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, norm_func_t *TSK_UNUSED(norm_f),
tsk_size_t n_rows, const double *row_positions, tsk_size_t n_cols,
const double *col_positions, tsk_flags_t TSK_UNUSED(options), double *result)
void *f_params, norm_func_t *TSK_UNUSED(norm_f), tsk_size_t n_rows,
const double *row_positions, tsk_size_t n_cols, const double *col_positions,
tsk_flags_t TSK_UNUSED(options), double *result)
{
int ret = 0;
int r, c;
Expand Down Expand Up @@ -3384,10 +3384,10 @@ check_sample_set_dups(tsk_size_t num_sample_sets, const tsk_size_t *sample_set_s
}

int
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
const double *col_positions, tsk_flags_t options, double *result)
{
Expand All @@ -3397,10 +3397,6 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
bool stat_site = !!(options & TSK_STAT_SITE);
bool stat_branch = !!(options & TSK_STAT_BRANCH);
tsk_size_t state_dim = num_sample_sets;
sample_count_stat_params_t f_params = { .sample_sets = sample_sets,
.num_sample_sets = num_sample_sets,
.sample_set_sizes = sample_set_sizes,
.set_indexes = set_indexes };

// We do not support two-locus node stats
if (!!(options & TSK_STAT_NODE)) {
Expand Down Expand Up @@ -3440,7 +3436,7 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
goto out;
}
ret = tsk_treeseq_two_site_count_stat(self, state_dim, num_sample_sets,
sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows,
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
row_sites, out_cols, col_sites, options, result);
} else if (stat_branch) {
ret = check_positions(
Expand All @@ -3454,13 +3450,30 @@ tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sampl
goto out;
}
ret = tsk_treeseq_two_branch_count_stat(self, state_dim, num_sample_sets,
sample_set_sizes, sample_sets, result_dim, f, &f_params, norm_f, out_rows,
sample_set_sizes, sample_sets, result_dim, f, f_params, norm_f, out_rows,
row_positions, out_cols, col_positions, options, result);
}
out:
return ret;
}

/*
 * Two-locus count statistic for summary functions that take their parameters
 * as a sample_count_stat_params_t.
 *
 * Bundles the sample-set description (sample_sets, num_sample_sets,
 * sample_set_sizes) together with set_indexes into a
 * sample_count_stat_params_t and hands its address to
 * tsk_treeseq_two_locus_count_general_stat as the opaque f_params argument.
 * Every other argument is forwarded unchanged.
 *
 * Returns whatever status tsk_treeseq_two_locus_count_general_stat returns.
 */
int
tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self, tsk_size_t num_sample_sets,
    const tsk_size_t *sample_set_sizes, const tsk_id_t *sample_sets,
    tsk_size_t result_dim, const tsk_id_t *set_indexes, general_stat_func_t *f,
    norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
    const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
    const double *col_positions, tsk_flags_t options, double *result)
{
    int ret;
    /* Automatic storage: the parameter bundle is only valid for the duration
     * of the delegated call below. */
    sample_count_stat_params_t count_params = {
        .set_indexes = set_indexes,
        .sample_set_sizes = sample_set_sizes,
        .num_sample_sets = num_sample_sets,
        .sample_sets = sample_sets,
    };

    ret = tsk_treeseq_two_locus_count_general_stat(self, num_sample_sets,
        sample_set_sizes, sample_sets, result_dim, f, &count_params, norm_f, out_rows,
        row_sites, row_positions, out_cols, col_sites, col_positions, options, result);
    return ret;
}

/***********************************
* Allele frequency spectrum
***********************************/
Expand Down
11 changes: 9 additions & 2 deletions c/tskit/trees.h
Original file line number Diff line number Diff line change
Expand Up @@ -987,8 +987,8 @@ int tsk_treeseq_general_stat(const tsk_treeseq_t *self, tsk_size_t K, const doub
tsk_size_t M, general_stat_func_t *f, void *f_params, tsk_size_t num_windows,
const double *windows, tsk_flags_t options, double *result);

typedef int norm_func_t(tsk_size_t result_dim, const double *hap_weights, tsk_size_t n_a,
tsk_size_t n_b, double *result, void *params);
/*
 * Normalisation callback used by the two-locus statistics.
 *
 * The normalised two-site computation invokes it as
 *     norm_f(state_dim, weights, result_dim, n_a, n_b, norm, f_params)
 * where:
 *   state_dim   - dimension of the state passed alongside the weights.
 *   hap_weights - the weights array computed by the caller (read-only).
 *   result_dim  - dimension of the result.
 *   n_a, n_b    - allele counts for the A and B sites; the two-site caller
 *                 subtracts its is_polarised value from each before passing.
 *   result      - output array written by the callback
 *                 (presumably result_dim entries - TODO confirm).
 *   params      - the same opaque pointer that is passed as f_params to the
 *                 summary function; each implementation casts it to the
 *                 parameter type it expects.
 *
 * A non-zero return value makes the caller abandon the computation and
 * propagate that value.
 */
typedef int norm_func_t(tsk_size_t state_dim, const double *hap_weights,
tsk_size_t result_dim, tsk_size_t n_a, tsk_size_t n_b, double *result, void *params);

int tsk_treeseq_two_locus_count_stat(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
Expand Down Expand Up @@ -1071,6 +1071,13 @@ typedef int general_sample_stat_method(const tsk_treeseq_t *self,
const tsk_id_t *sample_sets, tsk_size_t num_indexes, const tsk_id_t *indexes,
tsk_size_t num_windows, const double *windows, tsk_flags_t options, double *result);

/*
 * General form of the two-locus count statistic in which the caller supplies
 * the opaque parameter block for the summary function.
 *
 * f_params is not interpreted here: it is forwarded unchanged to the site or
 * branch computation selected by options. tsk_treeseq_two_locus_count_stat is
 * a wrapper that builds a sample_count_stat_params_t and passes its address
 * as f_params.
 *
 * options selects the mode: with TSK_STAT_SITE the row_sites/col_sites arrays
 * are used, with TSK_STAT_BRANCH the row_positions/col_positions arrays are
 * used. Two-locus node statistics (TSK_STAT_NODE) are not supported.
 * The state dimension used internally is num_sample_sets.
 *
 * Returns an int status propagated from the input checks and the selected
 * computation (presumably 0 on success and a tskit error code otherwise -
 * TODO confirm against the library's error conventions).
 */
int tsk_treeseq_two_locus_count_general_stat(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
void *f_params, norm_func_t *norm_f, tsk_size_t out_rows, const tsk_id_t *row_sites,
const double *row_positions, tsk_size_t out_cols, const tsk_id_t *col_sites,
const double *col_positions, tsk_flags_t options, double *result);

typedef int two_locus_count_stat_method(const tsk_treeseq_t *self,
tsk_size_t num_sample_sets, const tsk_size_t *sample_set_sizes,
const tsk_id_t *sample_sets, tsk_size_t num_rows, const tsk_id_t *row_sites,
Expand Down
Loading
Loading