From fc7bb999f66bf7ef53351a1d60aee2a9da357d5e Mon Sep 17 00:00:00 2001 From: lkirk Date: Wed, 18 Dec 2024 14:34:46 -0600 Subject: [PATCH] Two-locus malloc optimizations This revision moves all malloc operations out of the hot loop in two-locus statistics, instead providing pre-allocated regions of memory that the two-locus framework will use to perform work. Instead of simply passing each pre-allocated array into each function call, we introduce a simple structure called `two_locus_work_t`, which stores the statistical results, and provides temporary arrays for storing the normalisation constants. Setup and teardown methods for this work structure are provided. All test (python and C) are passing and valgrind reports no memory leaks. --- c/tskit/trees.c | 184 ++++++++++++++++++++++++++++-------------------- 1 file changed, 108 insertions(+), 76 deletions(-) diff --git a/c/tskit/trees.c b/c/tskit/trees.c index a536b08910..7bbf0feb8d 100644 --- a/c/tskit/trees.c +++ b/c/tskit/trees.c @@ -2313,61 +2313,98 @@ get_all_samples_bits(tsk_bit_array_t *all_samples, tsk_size_t n) } } +typedef struct { + double *weights; + double *norm; + double *result_tmp; + tsk_bit_array_t AB_samples; + tsk_bit_array_t ss_A_samples; + tsk_bit_array_t ss_B_samples; + tsk_bit_array_t ss_AB_samples; +} two_locus_work_t; + static int -compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, - const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, - tsk_size_t num_b_alleles, tsk_size_t num_samples, tsk_size_t state_dim, - const tsk_bit_array_t *sample_sets, tsk_size_t result_dim, general_stat_func_t *f, - sample_count_stat_params_t *f_params, norm_func_t *norm_f, bool polarised, - double *result) +two_locus_work_init(tsk_size_t max_alleles, tsk_size_t result_dim, tsk_size_t state_dim, + tsk_size_t num_samples, two_locus_work_t *out) { int ret = 0; - tsk_bit_array_t A_samples, B_samples; - // ss_ prefix refers to a sample set - tsk_bit_array_t ss_row; - tsk_bit_array_t ss_A_samples, ss_B_samples, ss_AB_samples, AB_samples; - // Sample sets and b sites are rows, a sites are columns - // b1 b2 b3 - // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] - tsk_size_t k, mut_a, mut_b; - tsk_size_t result_row_len = num_b_alleles * result_dim; - tsk_size_t w_A = 0, w_B = 0, w_AB = 0; - uint8_t polarised_val = polarised ? 1 : 0; - double *hap_weight_row; - double *result_tmp_row; - double *weights = tsk_malloc(3 * state_dim * sizeof(*weights)); - double *norm = tsk_malloc(result_dim * sizeof(*norm)); - double *result_tmp - = tsk_malloc(result_row_len * num_a_alleles * sizeof(*result_tmp)); + out->weights = tsk_malloc(3 * state_dim * sizeof(*out->weights)); + out->norm = tsk_malloc(result_dim * sizeof(*out->norm)); + out->result_tmp + = tsk_malloc(result_dim * max_alleles * max_alleles * sizeof(*out->result_tmp)); - tsk_memset(&ss_A_samples, 0, sizeof(ss_A_samples)); - tsk_memset(&ss_B_samples, 0, sizeof(ss_B_samples)); - tsk_memset(&ss_AB_samples, 0, sizeof(ss_AB_samples)); - tsk_memset(&AB_samples, 0, sizeof(AB_samples)); + tsk_memset(&out->ss_A_samples, 0, sizeof(out->ss_A_samples)); + tsk_memset(&out->ss_B_samples, 0, sizeof(out->ss_B_samples)); + tsk_memset(&out->ss_AB_samples, 0, sizeof(out->ss_AB_samples)); + tsk_memset(&out->AB_samples, 0, sizeof(out->AB_samples)); - if (weights == NULL || norm == NULL || result_tmp == NULL) { + if (out->weights == NULL || out->norm == NULL || out->result_tmp == NULL) { ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - ret = tsk_bit_array_init(&ss_A_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->AB_samples, num_samples, 1); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&ss_B_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->ss_A_samples, num_samples, 1); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&ss_AB_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->ss_B_samples, num_samples, 1); if (ret != 0) { goto out; } - ret = tsk_bit_array_init(&AB_samples, num_samples, 1); + ret = tsk_bit_array_init(&out->ss_AB_samples, num_samples, 1); if (ret != 0) { goto out; } +out: + return ret; +} + +static void +two_locus_work_free(two_locus_work_t *work) +{ + tsk_safe_free(work->weights); + tsk_safe_free(work->norm); + tsk_safe_free(work->result_tmp); + tsk_bit_array_free(&work->AB_samples); + tsk_bit_array_free(&work->ss_A_samples); + tsk_bit_array_free(&work->ss_B_samples); + tsk_bit_array_free(&work->ss_AB_samples); +} + +static int +compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, + const tsk_bit_array_t *site_b_state, tsk_size_t num_a_alleles, + tsk_size_t num_b_alleles, tsk_size_t state_dim, const tsk_bit_array_t *sample_sets, + tsk_size_t result_dim, general_stat_func_t *f, sample_count_stat_params_t *f_params, + norm_func_t *norm_f, bool polarised, two_locus_work_t *restrict work, double *result) +{ + int ret = 0; + tsk_bit_array_t A_samples, B_samples; + // ss_ prefix refers to a sample set + tsk_bit_array_t ss_row; + // Sample sets and b sites are rows, a sites are columns + // b1 b2 b3 + // a1 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a2 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + // a3 [s1, s2, s3] [s1, s2, s3] [s1, s2, s3] + tsk_size_t k, mut_a, mut_b; + tsk_size_t result_row_len = num_b_alleles * result_dim; + tsk_size_t w_A = 0, w_B = 0, w_AB = 0; + uint8_t polarised_val = polarised ? 1 : 0; + double *hap_weight_row; + double *result_tmp_row; + + double *norm = work->norm; + double *weights = work->weights; + double *result_tmp = work->result_tmp; + tsk_bit_array_t AB_samples = work->AB_samples; + tsk_bit_array_t ss_A_samples = work->ss_A_samples; + tsk_bit_array_t ss_B_samples = work->ss_B_samples; + tsk_bit_array_t ss_AB_samples = work->ss_AB_samples; for (mut_a = polarised_val; mut_a < num_a_alleles; mut_a++) { result_tmp_row = GET_2D_ROW(result_tmp, result_row_len, mut_a); @@ -2408,13 +2445,6 @@ compute_general_two_site_stat_result(const tsk_bit_array_t *site_a_state, } out: - tsk_safe_free(weights); - tsk_safe_free(norm); - tsk_safe_free(result_tmp); - tsk_bit_array_free(&ss_A_samples); - tsk_bit_array_free(&ss_B_samples); - tsk_bit_array_free(&ss_AB_samples); - tsk_bit_array_free(&AB_samples); return ret; } @@ -2563,14 +2593,15 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_bit_array_t allele_samples, c_state, r_state; bool polarised = false; tsk_id_t *sites; - tsk_size_t r, c, s, n_alleles, n_sites, *row_idx, *col_idx; + tsk_size_t r, c, s, max_alleles, n_alleles, n_sites, *row_idx, *col_idx; double *result_row; const tsk_size_t num_samples = self->num_samples; tsk_size_t *num_alleles = NULL, *site_offsets = NULL; tsk_size_t result_row_len = n_cols * result_dim; + two_locus_work_t work; + tsk_memset(&work, 0, sizeof(work)); tsk_memset(&allele_samples, 0, sizeof(allele_samples)); - sites = tsk_malloc(self->tables->sites.num_rows * sizeof(*sites)); row_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*row_idx)); col_idx = tsk_malloc(self->tables->sites.num_rows * sizeof(*col_idx)); @@ -2589,11 +2620,20 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, ret = tsk_trace_error(TSK_ERR_NO_MEMORY); goto out; } - n_alleles = 0; + max_alleles = 0; for (s = 0; s < n_sites; s++) { site_offsets[s] = n_alleles; n_alleles += self->site_mutations_length[sites[s]] + 1; + if (self->site_mutations_length[sites[s]] > max_alleles) { + max_alleles = self->site_mutations_length[sites[s]]; + } + } + max_alleles++; + + ret = two_locus_work_init(max_alleles, result_dim, state_dim, num_samples, &work); + if (ret != 0) { + goto out; } ret = tsk_bit_array_init(&allele_samples, num_samples, n_alleles); if (ret != 0) { @@ -2615,8 +2655,8 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_bit_array_get_row(&allele_samples, site_offsets[row_idx[r]], &r_state); tsk_bit_array_get_row(&allele_samples, site_offsets[col_idx[c]], &c_state); ret = compute_general_two_site_stat_result(&r_state, &c_state, - num_alleles[row_idx[r]], num_alleles[col_idx[c]], num_samples, state_dim, - sample_sets, result_dim, f, f_params, norm_f, polarised, + num_alleles[row_idx[r]], num_alleles[col_idx[c]], state_dim, sample_sets, + result_dim, f, f_params, norm_f, polarised, &work, &(result_row[c * result_dim])); if (ret != 0) { goto out; @@ -2630,6 +2670,7 @@ tsk_treeseq_two_site_count_stat(const tsk_treeseq_t *self, tsk_size_t state_dim, tsk_safe_free(col_idx); tsk_safe_free(num_alleles); tsk_safe_free(site_offsets); + two_locus_work_free(&work); tsk_bit_array_free(&allele_samples); return ret; } @@ -2970,41 +3011,27 @@ static int compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, const iter_state *A_state, const iter_state *B_state, tsk_size_t state_dim, tsk_size_t result_dim, int sign, general_stat_func_t *f, - sample_count_stat_params_t *f_params, double *result) + sample_count_stat_params_t *f_params, two_locus_work_t *restrict work, + double *result) { int ret = 0; double a_len, b_len; double *restrict B_branch_len = B_state->branch_len; - double *weights = NULL, *weights_row, *result_tmp = NULL; + double *weights_row; tsk_size_t n, k, a_row, b_row; - tsk_bit_array_t A_samples, B_samples, AB_samples, B_samples_tmp; + tsk_bit_array_t A_samples, B_samples; const double *restrict A_branch_len = A_state->branch_len; const tsk_bit_array_t *restrict A_state_samples = A_state->node_samples; const tsk_bit_array_t *restrict B_state_samples = B_state->node_samples; - tsk_size_t num_samples = ts->num_samples; tsk_size_t num_nodes = ts->tables->nodes.num_rows; + double *weights = work->weights; + double *result_tmp = work->result_tmp; + tsk_bit_array_t AB_samples = work->AB_samples; + b_len = B_branch_len[c] * sign; if (b_len == 0) { return ret; } - - tsk_memset(&AB_samples, 0, sizeof(AB_samples)); - tsk_memset(&B_samples_tmp, 0, sizeof(B_samples_tmp)); - - weights = tsk_calloc(3 * state_dim, sizeof(*weights)); - result_tmp = tsk_calloc(result_dim, sizeof(*result_tmp)); - if (weights == NULL || result_tmp == NULL) { - ret = tsk_trace_error(TSK_ERR_NO_MEMORY); - goto out; - } - ret = tsk_bit_array_init(&AB_samples, num_samples, 1); - if (ret != 0) { - goto out; - } - ret = tsk_bit_array_init(&B_samples_tmp, num_samples, 1); - if (ret != 0) { - goto out; - } for (n = 0; n < num_nodes; n++) { a_len = A_branch_len[n]; if (a_len == 0) { @@ -3032,10 +3059,6 @@ compute_two_tree_branch_state_update(const tsk_treeseq_t *ts, tsk_id_t c, } } out: - tsk_safe_free(weights); - tsk_safe_free(result_tmp); - tsk_bit_array_free(&AB_samples); - tsk_bit_array_free(&B_samples_tmp); return ret; } @@ -3052,8 +3075,16 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; const tsk_size_t num_nodes = ts->tables->nodes.num_rows; tsk_bit_array_t updates, row, ec_row, *r_samples = r_state->node_samples; + const tsk_size_t num_samples = ts->num_samples; + two_locus_work_t work; + tsk_memset(&work, 0, sizeof(work)); tsk_memset(&updates, 0, sizeof(updates)); + // only two alleles are possible for branch stats + ret = two_locus_work_init(2, result_dim, state_dim, num_samples, &work); + if (ret != 0) { + goto out; + } ret = tsk_bit_array_init(&updates, num_nodes, 1); if (ret != 0) { goto out; @@ -3081,8 +3112,8 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update( - ts, c, l_state, r_state, state_dim, result_dim, -1, f, f_params, result); + compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + result_dim, -1, f, f_params, &work, result); } // Remove samples under nodes from removed edges to parent nodes for (j = 0; j < r_state->n_edges_out; j++) { @@ -3126,11 +3157,12 @@ compute_two_tree_branch_stat(const tsk_treeseq_t *ts, const iter_state *l_state, while (n_updates != 0) { n_updates--; c = updated_nodes[n_updates]; - compute_two_tree_branch_state_update( - ts, c, l_state, r_state, state_dim, result_dim, +1, f, f_params, result); + compute_two_tree_branch_state_update(ts, c, l_state, r_state, state_dim, + result_dim, +1, f, f_params, &work, result); } out: tsk_safe_free(updated_nodes); + two_locus_work_free(&work); tsk_bit_array_free(&updates); return ret; }