diff --git a/c/tskit/trees.c b/c/tskit/trees.c index a536b08910..7bbf0feb8d 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2313,61 +2313,98 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n) } } +typedef struct { + double *weights; + double *norm; + double *result_tmp; + tsk_bit_array_t AB_samples; + tsk_bit_array_t ss_A_samples; + tsk_bit_array_t ss_B_samples; + tsk_bit_array_t ss_AB_samples; +} two_locus_work_t; + static int -compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, - const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, - tsk_size_t num_b_alleles, tsk_size_t num_samples, tsk_size_t state_dim, - const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, bool polarised, - double *result) +two_locus_work_init(tsk_size_t max_alleles, tsk_size_t result_dim, tsk_size_t state_dim, + tsk_size_t num_samples, two_locus_work_t *out) { int ret = 0; - tsk_bit_array_t A_samples, B_samples; - // ss_ prefix refers to a sample set - tsk_bit_array_t ss_row; - tsk_bit_array_t ss_A_samples, ss_B_samples, ss_AB_samples, AB_samples; - // Sample sets and b sites are rows, a sites are columns - // b1 b2 b3 - // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - tsk_size_t k, mut_a, mut_b; - tsk_size_t result_row_len = num_b_alleles * result_dim; - tsk_size_t w_A = 0, w_B = 0, w_AB = 0; - uint8_t polarised_val = polarised ? 1 : 0; - double *hap_weight_row; - double *result_tmp_row; - double *weights = tsk_malloc(3 * state_dim * sizeof(*weights)); - double *norm = tsk_malloc(result_dim * sizeof(*norm)); - double *result_tmp - = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp)); + out->weights = tsk_malloc(3 * state_dim * sizeof(*out->weights)); + out->norm = tsk_malloc(result_dim * sizeof(*out->norm)); + out->result_tmp + = tsk_malloc(result_dim * max_alleles * max_alleles * sizeof(*out->result_tmp)); - tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples)); - tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples)); - tsk_memset(&ss_AB_samples, 0, sizeof(ss_AB_samples)); - tsk_memset(&AB_samples, 0, sizeof(AB_samples)); + tsk_memset(&out->ss_A_samples, 0, sizeof(out->ss_A_samples)); + tsk_memset(&out->ss_B_samples, 0, sizeof(out->ss_B_samples)); + tsk_memset(&out->ss_AB_samples, 0, sizeof(out->ss_AB_samples)); + tsk_memset(&out->AB_samples, 0, sizeof(out->AB_samples)); - if (weights == NULL || norm == NULL || result_tmp == NULL) { + if (out->weights == NULL || out->norm == NULL || out->result_tmp == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_bit_array_init(&ss_A_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->AB_samples, num_samples, 1); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&ss_B_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->ss_A_samples, num_samples, 1); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&ss_AB_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->ss_B_samples, num_samples, 1); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&AB_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->ss_AB_samples, num_samples, 1); if (ret != 0) { goto out; } +out: + return ret; +} + +static void +two_locus_work_free(two_locus_work_t *work) +{ + tsk_safe_free(work->weights); + tsk_safe_free(work->norm); + tsk_safe_free(work->result_tmp); + tsk_bit_array_free(&work->AB_samples); + tsk_bit_array_free(&work->ss_A_samples); + tsk_bit_array_free(&work->ss_B_samples); + tsk_bit_array_free(&work->ss_AB_samples); +} + +static int +compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, + const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, + tsk_size_t num_b_alleles, tsk_size_t state_dim, const tsk_bit_array_t *sample_sets, + tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, + norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) +{ + int ret = 0; + tsk_bit_array_t A_samples, B_samples; + // ss_ prefix refers to a sample set + tsk_bit_array_t ss_row; + // Sample sets and b sites are rows, a sites are columns + // b1 b2 b3 + // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + tsk_size_t k, mut_a, mut_b; + tsk_size_t result_row_len = num_b_alleles * result_dim; + tsk_size_t w_A = 0, w_B = 0, w_AB = 0; + uint8_t polarised_val = polarised ? 1 : 0; + double *hap_weight_row; + double *result_tmp_row; + + double *norm = work->norm; + double *weights = work->weights; + double *result_tmp = work->result_tmp; + tsk_bit_array_t AB_samples = work->AB_samples; + tsk_bit_array_t ss_A_samples = work->ss_A_samples; + tsk_bit_array_t ss_B_samples = work->ss_B_samples; + tsk_bit_array_t ss_AB_samples = work->ss_AB_samples; for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) { result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a); @@ -2408,13 +2445,6 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, } out: - tsk_safe_free(weights); - tsk_safe_free(norm); - tsk_safe_free(result_tmp); - tsk_bit_array_free(&ss_A_samples); - tsk_bit_array_free(&ss_B_samples); - tsk_bit_array_free(&ss_AB_samples); - tsk_bit_array_free(&AB_samples); return ret; } @@ -2563,14 +2593,15 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_bit_array_t allele_samples, c_state, r_state; bool polarised = false; tsk_id_t *sites; - tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx; + tsk_size_t r, c, s, max_alleles, n_alleles, n_sites, *row_idx, *col_idx; double *result_row; const tsk_size_t num_samples = self->num_samples; tsk_size_t *num_alleles = NULL, *site_offsets = NULL; tsk_size_t result_row_len = n_cols * result_dim; + two_locus_work_t work; + tsk_memset(&work, 0, sizeof(work)); tsk_memset(&allele_samples, 0, sizeof(allele_samples)); - sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites)); row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx)); col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx)); @@ -2589,11 +2620,20 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - n_alleles = 0; + max_alleles = 0; for (s = 0; s < n_sites; s++) { site_offsets[s] = n_alleles; n_alleles += self->site_mutations_length[sites[s]] + 1; + if (self->site_mutations_length[sites[s]] > max_alleles) { + max_alleles = self->site_mutations_length[sites[s]]; + } + } + max_alleles++; + + ret = two_locus_work_init(max_alleles, result_dim, state_dim, num_samples, &work); + if (ret != 0) { + goto out; } ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles); if (ret != 0) { @@ -2615,8 +2655,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state); tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state); ret = compute_general_two_site_stat_result(&r_state, &c_state, - num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim, - sample_sets, result_dim, f, f_params, norm_f, polarised, + num_alleles[row_idx[r]], num_alleles[col_idx[c]], state_dim, sample_sets, + result_dim, f, f_params, norm_f, polarised, &work, &(result_row[c * result_dim])); if (ret != 0) { goto out; @@ -2630,6 +2670,7 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_safe_free(col_idx); tsk_safe_free(num_alleles); tsk_safe_free(site_offsets); + two_locus_work_free(&work); tsk_bit_array_free(&allele_samples); return ret; } @@ -2970,41 +3011,27 @@ static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, double *result) + sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, + double *result) { int ret = 0; double a_len, b_len; double *restrict B_branch_len = B_state->branch_len; - double *weights = NULL, *weights_row, *result_tmp = NULL; + double *weights_row; tsk_size_t n, k, a_row, b_row; - tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp; + tsk_bit_array_t A_samples, B_samples; const double *restrict A_branch_len = A_state->branch_len; const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples; const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples; - tsk_size_t num_samples = ts->num_samples; tsk_size_t num_nodes = ts->tables->nodes.num_rows; + double *weights = work->weights; + double *result_tmp = work->result_tmp; + tsk_bit_array_t AB_samples = work->AB_samples; + b_len = B_branch_len[c] * sign; if (b_len == 0) { return ret; } - - tsk_memset(&AB_samples, 0, sizeof(AB_samples)); - tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp)); - - weights = tsk_calloc(3 * state_dim, sizeof(*weights)); - result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); - if (weights == NULL || result_tmp == NULL) { - ret = tsk_trace_error(TSK_ERR_NO_MEMORY); - goto out; - } - ret = tsk_bit_array_init(&AB_samples, num_samples, 1); - if (ret != 0) { - goto out; - } - ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1); - if (ret != 0) { - goto out; - } for (n = 0; n < num_nodes; n++) { a_len = A_branch_len[n]; if (a_len == 0) { @@ -3032,10 +3059,6 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, } } out: - tsk_safe_free(weights); - tsk_safe_free(result_tmp); - tsk_bit_array_free(&AB_samples); - tsk_bit_array_free(&B_samples_tmp); return ret; } @@ -3052,8 +3075,16 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; const tsk_size_t num_nodes = ts->tables->nodes.num_rows; tsk_bit_array_t updates, row, ec_row, *r_samples = r_state->node_samples; + const tsk_size_t num_samples = ts->num_samples; + two_locus_work_t work; + tsk_memset(&work, 0, sizeof(work)); tsk_memset(&updates, 0, sizeof(updates)); + // only two alleles are possible for branch stats + ret = two_locus_work_init(2, result_dim, state_dim, num_samples, &work); + if (ret != 0) { + goto out; + } ret = tsk_bit_array_init(&updates, num_nodes, 1); if (ret != 0) { goto out; @@ -3081,8 +3112,8 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update( - ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, result); + compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + result_dim, -1, f, f_params, &work, result); } // Remove samples under nodes from removed edges to parent nodes for (j = 0; j < r_state->n_edges_out; j++) { @@ -3126,11 +3157,12 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update( - ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, result); + compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + result_dim, +1, f, f_params, &work, result); } out: tsk_safe_free(updated_nodes); + two_locus_work_free(&work); tsk_bit_array_free(&updates); return ret; }