Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 108 additions & 76 deletions c/tskit/trees.c
Original file line number Diff line number Diff line change
Expand Up @@ -2313,61 +2313,98 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n)
}
}

// Scratch space shared by the two-locus statistic routines. It is filled in
// by two_locus_work_init and released by two_locus_work_free.
typedef struct {
// Allocated by two_locus_work_init with room for 3 * state_dim doubles.
double *weights;
// Allocated by two_locus_work_init with room for result_dim doubles.
double *norm;
// Allocated by two_locus_work_init with room for
// result_dim * max_alleles * max_alleles doubles.
double *result_tmp;
// Each bit array below is set up by two_locus_work_init with
// tsk_bit_array_init(..., num_samples, 1).
tsk_bit_array_t AB_samples;
// The ss_ prefix refers to a sample set.
tsk_bit_array_t ss_A_samples;
tsk_bit_array_t ss_B_samples;
tsk_bit_array_t ss_AB_samples;
} two_locus_work_t;

static int
compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles,
tsk_size_t num_b_alleles, tsk_size_t num_samples, tsk_size_t state_dim,
const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f,
sample_count_stat_params_t *f_params, norm_func_t *norm_f, bool polarised,
double *result)
two_locus_work_init(tsk_size_t max_alleles, tsk_size_t result_dim, tsk_size_t state_dim,
tsk_size_t num_samples, two_locus_work_t *out)
{
int ret = 0;
tsk_bit_array_t A_samples, B_samples;
// ss_ prefix refers to a sample set
tsk_bit_array_t ss_row;
tsk_bit_array_t ss_A_samples, ss_B_samples, ss_AB_samples, AB_samples;
// Sample sets and b sites are rows, a sites are columns
// b1 b2 b3
// a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
// a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
// a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
tsk_size_t k, mut_a, mut_b;
tsk_size_t result_row_len = num_b_alleles * result_dim;
tsk_size_t w_A = 0, w_B = 0, w_AB = 0;
uint8_t polarised_val = polarised ? 1 : 0;
double *hap_weight_row;
double *result_tmp_row;
double *weights = tsk_malloc(3 * state_dim * sizeof(*weights));
double *norm = tsk_malloc(result_dim * sizeof(*norm));
double *result_tmp
= tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp));
out->weights = tsk_malloc(3 * state_dim * sizeof(*out->weights));
out->norm = tsk_malloc(result_dim * sizeof(*out->norm));
out->result_tmp
= tsk_malloc(result_dim * max_alleles * max_alleles * sizeof(*out->result_tmp));

tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples));
tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples));
tsk_memset(&ss_AB_samples, 0, sizeof(ss_AB_samples));
tsk_memset(&AB_samples, 0, sizeof(AB_samples));
tsk_memset(&out->ss_A_samples, 0, sizeof(out->ss_A_samples));
tsk_memset(&out->ss_B_samples, 0, sizeof(out->ss_B_samples));
tsk_memset(&out->ss_AB_samples, 0, sizeof(out->ss_AB_samples));
tsk_memset(&out->AB_samples, 0, sizeof(out->AB_samples));

if (weights == NULL || norm == NULL || result_tmp == NULL) {
if (out->weights == NULL || out->norm == NULL || out->result_tmp == NULL) {
ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
goto out;
}

ret = tsk_bit_array_init(&ss_A_samples, num_samples, 1);
ret = tsk_bit_array_init(&out->AB_samples, num_samples, 1);
if (ret != 0) {
goto out;
}
ret = tsk_bit_array_init(&ss_B_samples, num_samples, 1);
ret = tsk_bit_array_init(&out->ss_A_samples, num_samples, 1);
if (ret != 0) {
goto out;
}
ret = tsk_bit_array_init(&ss_AB_samples, num_samples, 1);
ret = tsk_bit_array_init(&out->ss_B_samples, num_samples, 1);
if (ret != 0) {
goto out;
}
ret = tsk_bit_array_init(&AB_samples, num_samples, 1);
ret = tsk_bit_array_init(&out->ss_AB_samples, num_samples, 1);
if (ret != 0) {
goto out;
}
out:
return ret;
}

/*
 * Release everything owned by a two_locus_work_t: the four bit arrays via
 * tsk_bit_array_free and the three double buffers via tsk_safe_free.
 * The struct itself is not freed.
 */
static void
two_locus_work_free(two_locus_work_t *work)
{
    // Bit arrays first, then the plain buffers; the releases are independent
    // of one another.
    tsk_bit_array_free(&work->ss_AB_samples);
    tsk_bit_array_free(&work->ss_B_samples);
    tsk_bit_array_free(&work->ss_A_samples);
    tsk_bit_array_free(&work->AB_samples);

    tsk_safe_free(work->result_tmp);
    tsk_safe_free(work->norm);
    tsk_safe_free(work->weights);
}

static int
compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles,
tsk_size_t num_b_alleles, tsk_size_t state_dim, const tsk_bit_array_t *sample_sets,
tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params,
norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result)
{
int ret = 0;
tsk_bit_array_t A_samples, B_samples;
// ss_ prefix refers to a sample set
tsk_bit_array_t ss_row;
// Sample sets and b sites are rows, a sites are columns
// b1 b2 b3
// a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
// a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
// a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3]
tsk_size_t k, mut_a, mut_b;
tsk_size_t result_row_len = num_b_alleles * result_dim;
tsk_size_t w_A = 0, w_B = 0, w_AB = 0;
uint8_t polarised_val = polarised ? 1 : 0;
double *hap_weight_row;
double *result_tmp_row;

double *norm = work->norm;
double *weights = work->weights;
double *result_tmp = work->result_tmp;
tsk_bit_array_t AB_samples = work->AB_samples;
tsk_bit_array_t ss_A_samples = work->ss_A_samples;
tsk_bit_array_t ss_B_samples = work->ss_B_samples;
tsk_bit_array_t ss_AB_samples = work->ss_AB_samples;

for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) {
result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a);
Expand Down Expand Up @@ -2408,13 +2445,6 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state,
}

out:
tsk_safe_free(weights);
tsk_safe_free(norm);
tsk_safe_free(result_tmp);
tsk_bit_array_free(&ss_A_samples);
tsk_bit_array_free(&ss_B_samples);
tsk_bit_array_free(&ss_AB_samples);
tsk_bit_array_free(&AB_samples);
return ret;
}

Expand Down Expand Up @@ -2563,14 +2593,15 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_bit_array_t allele_samples, c_state, r_state;
bool polarised = false;
tsk_id_t *sites;
tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx;
tsk_size_t r, c, s, max_alleles, n_alleles, n_sites, *row_idx, *col_idx;
double *result_row;
const tsk_size_t num_samples = self->num_samples;
tsk_size_t *num_alleles = NULL, *site_offsets = NULL;
tsk_size_t result_row_len = n_cols * result_dim;
two_locus_work_t work;

tsk_memset(&work, 0, sizeof(work));
tsk_memset(&allele_samples, 0, sizeof(allele_samples));

sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites));
row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx));
col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx));
Expand All @@ -2589,11 +2620,20 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
goto out;
}

n_alleles = 0;
max_alleles = 0;
for (s = 0; s < n_sites; s++) {
site_offsets[s] = n_alleles;
n_alleles += self->site_mutations_length[sites[s]] + 1;
if (self->site_mutations_length[sites[s]] > max_alleles) {
max_alleles = self->site_mutations_length[sites[s]];
}
}
max_alleles++;

ret = two_locus_work_init(max_alleles, result_dim, state_dim, num_samples, &work);
if (ret != 0) {
goto out;
}
ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles);
if (ret != 0) {
Expand All @@ -2615,8 +2655,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state);
tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state);
ret = compute_general_two_site_stat_result(&r_state, &c_state,
num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim,
sample_sets, result_dim, f, f_params, norm_f, polarised,
num_alleles[row_idx[r]], num_alleles[col_idx[c]], state_dim, sample_sets,
result_dim, f, f_params, norm_f, polarised, &work,
&(result_row[c * result_dim]));
if (ret != 0) {
goto out;
Expand All @@ -2630,6 +2670,7 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim,
tsk_safe_free(col_idx);
tsk_safe_free(num_alleles);
tsk_safe_free(site_offsets);
two_locus_work_free(&work);
tsk_bit_array_free(&allele_samples);
return ret;
}
Expand Down Expand Up @@ -2970,41 +3011,27 @@ static int
compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim,
tsk_size_t result_dim, int sign, general_stat_func_t *f,
sample_count_stat_params_t *f_params, double *result)
sample_count_stat_params_t *f_params, two_locus_work_t *restrict work,
double *result)
{
int ret = 0;
double a_len, b_len;
double *restrict B_branch_len = B_state->branch_len;
double *weights = NULL, *weights_row, *result_tmp = NULL;
double *weights_row;
tsk_size_t n, k, a_row, b_row;
tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp;
tsk_bit_array_t A_samples, B_samples;
const double *restrict A_branch_len = A_state->branch_len;
const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples;
const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples;
tsk_size_t num_samples = ts->num_samples;
tsk_size_t num_nodes = ts->tables->nodes.num_rows;
double *weights = work->weights;
double *result_tmp = work->result_tmp;
tsk_bit_array_t AB_samples = work->AB_samples;

b_len = B_branch_len[c] * sign;
if (b_len == 0) {
return ret;
}

tsk_memset(&AB_samples, 0, sizeof(AB_samples));
tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp));

weights = tsk_calloc(3 * state_dim, sizeof(*weights));
result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp));
if (weights == NULL || result_tmp == NULL) {
ret = tsk_trace_error(TSK_ERR_NO_MEMORY);
goto out;
}
ret = tsk_bit_array_init(&AB_samples, num_samples, 1);
if (ret != 0) {
goto out;
}
ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1);
if (ret != 0) {
goto out;
}
for (n = 0; n < num_nodes; n++) {
a_len = A_branch_len[n];
if (a_len == 0) {
Expand Down Expand Up @@ -3032,10 +3059,6 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c,
}
}
out:
tsk_safe_free(weights);
tsk_safe_free(result_tmp);
tsk_bit_array_free(&AB_samples);
tsk_bit_array_free(&B_samples_tmp);
return ret;
}

Expand All @@ -3052,8 +3075,16 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
const tsk_id_t *restrict edges_parent = ts->tables->edges.parent;
const tsk_size_t num_nodes = ts->tables->nodes.num_rows;
tsk_bit_array_t updates, row, ec_row, *r_samples = r_state->node_samples;
const tsk_size_t num_samples = ts->num_samples;
two_locus_work_t work;

tsk_memset(&work, 0, sizeof(work));
tsk_memset(&updates, 0, sizeof(updates));
// only two alleles are possible for branch stats
ret = two_locus_work_init(2, result_dim, state_dim, num_samples, &work);
if (ret != 0) {
goto out;
}
ret = tsk_bit_array_init(&updates, num_nodes, 1);
if (ret != 0) {
goto out;
Expand Down Expand Up @@ -3081,8 +3112,8 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
while (n_updates != 0) {
n_updates--;
c = updated_nodes[n_updates];
compute_two_tree_branch_state_update(
ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, result);
compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim,
result_dim, -1, f, f_params, &work, result);
}
// Remove samples under nodes from removed edges to parent nodes
for (j = 0; j < r_state->n_edges_out; j++) {
Expand Down Expand Up @@ -3126,11 +3157,12 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state,
while (n_updates != 0) {
n_updates--;
c = updated_nodes[n_updates];
compute_two_tree_branch_state_update(
ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, result);
compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim,
result_dim, +1, f, f_params, &work, result);
}
out:
tsk_safe_free(updated_nodes);
two_locus_work_free(&work);
tsk_bit_array_free(&updates);
return ret;
}
Expand Down
Loading