From 7ef65c7f450ff6e24eac62b98d0e9b74b9828797 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 13 Apr 2023 23:00:41 +0100 Subject: [PATCH 01/42] Initial rough work on getting ancstor_matcher decoupled from tree_sequence_builder and workign directly on the tree sequence instead. --- lib/ancestor_matcher.c | 950 +++++++++++++++++++++++++++++++++++++++++ lib/tsinfer.h | 60 +++ 2 files changed, 1010 insertions(+) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index ee992264..a3e68539 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1011,3 +1011,953 @@ ancestor_matcher_get_total_memory(ancestor_matcher_t *self) return total; } + +/* NEW IMPLEMENTATION */ + +static void +lshmm_check_state(lshmm_t *self) +{ + int num_likelihoods; + int j; + tsk_id_t u; + + /* Check the properties of the likelihood map */ + for (j = 0; j < self->num_likelihood_nodes; j++) { + u = self->likelihood_nodes[j]; + assert(self->likelihood[u] >= 0 && self->likelihood[u] <= 2); + } + /* Make sure that there are no other non null likelihoods in the array */ + num_likelihoods = 0; + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + if (self->likelihood[u] >= 0) { + num_likelihoods++; + } + if (is_nonzero_root(u, self->parent, self->left_child)) { + assert(self->likelihood[u] == NONZERO_ROOT_LIKELIHOOD); + } + assert(self->allelic_state[u] == TSK_NULL); + } + assert(num_likelihoods == self->num_likelihood_nodes); +} + +int +lshmm_print_state(lshmm_t *self, FILE *out) +{ + int j, k; + tsk_id_t u; + + fprintf(out, "Ancestor matcher state\n"); + fprintf(out, "site\trecomb_rate\tmut_rate\n"); + for (j = 0; j < (int) self->num_sites; j++) { + fprintf( + out, "%d\t%f\t%f\n", j, self->recombination_rate[j], self->mismatch_rate[j]); + } + fprintf(out, "tree = \n"); + fprintf(out, "id\tparent\tlchild\trchild\tlsib\trsib\tlikelihood\n"); + for (j = 0; j < (int) self->num_nodes; j++) { + fprintf(out, "%d\t%d\t%d\t%d\t%d\t%d\t%f\n", (int) j, self->parent[j], + self->left_child[j], self->right_child[j], self->left_sib[j], + self->right_sib[j], self->likelihood[j]); + } + fprintf(out, "likelihood nodes\n"); + /* Check the properties of the likelihood map */ + for (j = 0; j < self->num_likelihood_nodes; j++) { + u = self->likelihood_nodes[j]; + fprintf(out, "\t%d -> %f\n", u, self->likelihood[u]); + } + fprintf(out, "traceback\n"); + for (j = 0; j < (int) self->num_sites; j++) { + fprintf(out, "\t%d:%d (%d)\t", (int) j, self->max_likelihood_node[j], + self->traceback[j].size); + for (k = 0; k < self->traceback[j].size; k++) { + fprintf(out, "(%d, %d)", self->traceback[j].node[k], + self->traceback[j].recombination_required[k]); + } + fprintf(out, "\n"); + } + tsk_blkalloc_print_state(&self->traceback_allocator, out); + + /* lshmm_check_state(self); */ + return 0; +} + +static int +lshmm_copy_edge_indexes(lshmm_t *self) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t k; + edge_t e; + const tsk_id_t *restrict I = self->ts->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = self->ts->tables->indexes.edge_removal_order; + const double *restrict edges_right = self->ts->tables->edges.right; + const double *restrict edges_left = self->ts->tables->edges.left; + const tsk_id_t *restrict edges_child = self->ts->tables->edges.child; + const tsk_id_t *restrict edges_parent = self->ts->tables->edges.parent; + + for (j = 0; j < self->num_edges; j++) { + k = I[j]; + /* TODO check that the edges can be cast */ + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent 
= edges_parent[k]; + e.child = edges_child[k]; + self->left_index_edges[j] = e; + + k = O[j]; + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->right_index_edges[j] = e; + } + return ret; +} + +static int +lshmm_copy_mutation_data(lshmm_t *self) +{ + int ret = 0; + tsk_site_t site; + tsk_size_t j; + + for (j = 0; j < self->num_sites; j++) { + ret = tsk_treeseq_get_site(self->ts, (tsk_id_t) j, &site); + if (ret != 0) { + goto out; + } + if (site.mutations_length > 1) { + ret = TSI_ERR_GENERIC; + goto out; + } + self->sites.num_alleles[j] = site.mutations_length; + /* Next, allocate the corresponding mutation object here and link + * to it. */ + delibaretely not compiling + } +out: + return ret; +} + +int +lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_rate, + const double *mismatch_rate, unsigned int precision, tsk_flags_t flags) +{ + int ret = 0; + + memset(self, 0, sizeof(lshmm_t)); + /* All allocs for arrays related to nodes are done in expand_nodes */ + self->flags = flags; + self->precision = precision; + self->ts = ts; + self->num_nodes = tsk_treeseq_get_num_nodes(ts); + self->num_sites = tsk_treeseq_get_num_sites(ts); + self->recombination_rate + = malloc(self->num_sites * sizeof(*self->recombination_rate)); + self->mismatch_rate = malloc(self->num_sites * sizeof(*self->mismatch_rate)); + self->output.max_size = self->num_sites; /* We can probably make this smaller */ + self->traceback = calloc(self->num_sites, sizeof(node_state_list_t)); + self->max_likelihood_node = malloc(self->num_sites * sizeof(tsk_id_t)); + self->output.left = malloc(self->output.max_size * sizeof(tsk_id_t)); + self->output.right = malloc(self->output.max_size * sizeof(tsk_id_t)); + self->output.parent = malloc(self->output.max_size * sizeof(tsk_id_t)); + self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); + self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); + self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); + self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); + if (self->recombination_rate == NULL || self->mismatch_rate == NULL + || self->traceback == NULL || self->max_likelihood_node == NULL + || self->output.left == NULL || self->output.right == NULL + || self->output.parent == NULL || self->left_index_edges == NULL + || self->right_index_edges == NULL || self->sites.mutations == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + /* Alloc in 64MiB blocks. 
*/ + self->traceback_block_size = 64 * 1024 * 1024; + /* If the traceback allocator is using more than 2GiB of RAM free it, so + * that other threads can use the memory */ + self->traceback_realloc_size = 2L * 1024L * 1024L * 1024L; + ret = tsk_blkalloc_init(&self->traceback_allocator, self->traceback_block_size); + if (ret != 0) { + goto out; + } + memcpy(self->recombination_rate, recombination_rate, + self->num_sites * sizeof(*self->recombination_rate)); + memcpy(self->mismatch_rate, mismatch_rate, + self->num_sites * sizeof(*self->mismatch_rate)); + ret = lshmm_copy_edge_indexes(self); + if (ret != 0) { + goto out; + } + ret = lshmm_copy_mutation_data(self); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +int +lshmm_free(lshmm_t *self) +{ + tsi_safe_free(self->recombination_rate); + tsi_safe_free(self->mismatch_rate); + tsi_safe_free(self->parent); + tsi_safe_free(self->left_child); + tsi_safe_free(self->right_child); + tsi_safe_free(self->left_sib); + tsi_safe_free(self->right_sib); + tsi_safe_free(self->recombination_required); + tsi_safe_free(self->likelihood); + tsi_safe_free(self->likelihood_cache); + tsi_safe_free(self->likelihood_nodes); + tsi_safe_free(self->likelihood_nodes_tmp); + tsi_safe_free(self->allelic_state); + tsi_safe_free(self->max_likelihood_node); + tsi_safe_free(self->traceback); + tsi_safe_free(self->output.left); + tsi_safe_free(self->output.right); + tsi_safe_free(self->output.parent); + tsk_safe_free(self->left_index_edges); + tsk_safe_free(self->right_index_edges); + tsk_safe_free(self->sites.mutations); + tsk_safe_free(self->sites.num_alleles); + tsk_blkalloc_free(&self->traceback_allocator); + return 0; +} + +static int +lshmm_delete_likelihood(lshmm_t *self, const tsk_id_t node, double *restrict L) +{ + /* Remove the specified node from the list of nodes */ + int j, k; + tsk_id_t *restrict L_nodes = self->likelihood_nodes; + + k = 0; + for (j = 0; j < self->num_likelihood_nodes; j++) { + L_nodes[k] = L_nodes[j]; + if (L_nodes[j] != node) { + k++; + } + } + assert(self->num_likelihood_nodes == k + 1); + self->num_likelihood_nodes = k; + L[node] = NULL_LIKELIHOOD; + return 0; +} + +/* Store the recombination_required state in the traceback */ +static int WARN_UNUSED +lshmm_store_traceback(lshmm_t *self, const tsk_id_t site_id) +{ + int ret = 0; + tsk_id_t u; + int j; + int8_t *restrict list_R; + tsk_id_t *restrict list_node; + node_state_list_t *restrict list; + node_state_list_t *restrict T = self->traceback; + const tsk_id_t *restrict nodes = self->likelihood_nodes; + const int8_t *restrict R = self->recombination_required; + const int num_likelihood_nodes = self->num_likelihood_nodes; + bool match; + + /* Check to see if the previous site has the same recombination_required. If so, + * we can reuse the same list. 
*/ + match = false; + if (site_id > 0) { + list = &T[site_id - 1]; + if (list->size == num_likelihood_nodes) { + list_node = list->node; + list_R = list->recombination_required; + match = true; + for (j = 0; j < num_likelihood_nodes; j++) { + if (list_node[j] != nodes[j] || list_R[j] != R[nodes[j]]) { + match = false; + break; + } + } + } + } + + if (match) { + T[site_id].size = T[site_id - 1].size; + T[site_id].node = T[site_id - 1].node; + T[site_id].recombination_required = T[site_id - 1].recombination_required; + } else { + list_node = tsk_blkalloc_get(&self->traceback_allocator, + (size_t) num_likelihood_nodes * sizeof(tsk_id_t)); + list_R = tsk_blkalloc_get( + &self->traceback_allocator, (size_t) num_likelihood_nodes * sizeof(int8_t)); + if (list_node == NULL || list_R == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + T[site_id].node = list_node; + T[site_id].recombination_required = list_R; + T[site_id].size = num_likelihood_nodes; + for (j = 0; j < num_likelihood_nodes; j++) { + u = nodes[j]; + list_node[j] = u; + list_R[j] = R[u]; + } + } + self->total_traceback_size += (size_t) num_likelihood_nodes; +out: + return ret; +} + +/* Sets the specified allelic state array to reflect the mutations at the + * specified site. */ +static inline void +lshmm_set_allelic_state( + lshmm_t *self, const tsk_id_t site, allele_t *restrict allelic_state) +{ + mutation_list_node_t *mutation; + + /* FIXME assuming that 0 is always the ancestral state */ + allelic_state[0] = 0; + + for (mutation = self->sites.mutations[site]; mutation != NULL; + mutation = mutation->next) { + allelic_state[mutation->node] = mutation->derived_state; + } +} + +/* Resets the allelic state at this site to NULL. */ +static inline void +lshmm_unset_allelic_state( + lshmm_t *self, const tsk_id_t site, allele_t *restrict allelic_state) +{ + mutation_list_node_t *mutation; + + allelic_state[0] = NULL_NODE; + for (mutation = self->sites.mutations[site]; mutation != NULL; + mutation = mutation->next) { + allelic_state[mutation->node] = TSK_NULL; + } +} + +static int WARN_UNUSED +lshmm_update_site_likelihood_values(lshmm_t *self, const tsk_id_t site, + const allele_t state, const tsk_id_t *restrict parent, double *restrict L) +{ + int ret = 0; + const int num_likelihood_nodes = self->num_likelihood_nodes; + const tsk_id_t *restrict L_nodes = self->likelihood_nodes; + allele_t *restrict allelic_state = self->allelic_state; + int8_t *restrict recombination_required = self->recombination_required; + int j; + tsk_id_t u, v, max_L_node; + double max_L, p_last, p_no_recomb, p_recomb, p_t, p_e; + const double rho = self->recombination_rate[site]; + const double mu = self->mismatch_rate[site]; + const double n = (double) self->num_nodes; + const double num_alleles = 2; + /* = (double) self->tree_sequence_builder->sites.num_alleles[site]; */ + printf("FIXME get num alleles at site\n"); + + if (state >= num_alleles) { + ret = TSI_ERR_BAD_HAPLOTYPE_ALLELE; + goto out; + } + + lshmm_set_allelic_state(self, site, allelic_state); + + max_L = -1; + max_L_node = NULL_NODE; + assert(num_likelihood_nodes > 0); + /* printf("likelihoods for node=%d, n=%d\n", mutation_node, + * self->num_likelihood_nodes); */ + for (j = 0; j < num_likelihood_nodes; j++) { + u = L_nodes[j]; + /* Get the allelic state at u. */ + /* TODO we can cache the states here to save some time. 
One nice way we could + * do the caching is to save the L_node index in the allelic_state array as + * we traverse upwards, and then keep an array of the L_node states which + * we then look up. This would save a second upward traversal to mark the + * array after we've found the state value. */ + v = u; + while (allelic_state[v] == TSK_NULL) { + v = parent[v]; + } + p_last = L[u]; + p_no_recomb = p_last * (1 - rho + rho / n); + p_recomb = rho / n; + recombination_required[u] = false; + if (p_no_recomb > p_recomb) { + p_t = p_no_recomb; + } else { + p_t = p_recomb; + recombination_required[u] = true; + } + p_e = mu; + if (allelic_state[v] == state || state == TSK_MISSING_DATA) { + p_e = 1 - (num_alleles - 1) * mu; + } + L[u] = p_t * p_e; + + if (L[u] > max_L) { + max_L = L[u]; + max_L_node = u; + } + } + /* lshmm_print_state(self, stdout); */ + if (max_L <= 0) { + if (mu <= 0 || mu >= 1) { + ret = TSI_ERR_MATCH_IMPOSSIBLE_EXTREME_MUTATION_PROBA; + goto out; + } + if (rho == 0) { + ret = TSI_ERR_MATCH_IMPOSSIBLE_ZERO_RECOMB_PRECISION; + goto out; + } + ret = TSI_ERR_MATCH_IMPOSSIBLE; + goto out; + } + assert(max_L_node != NULL_NODE); + self->max_likelihood_node[site] = max_L_node; + + /* Renormalise the likelihoods. */ + for (j = 0; j < num_likelihood_nodes; j++) { + u = L_nodes[j]; + L[u] = tsk_round(L[u] / max_L, self->precision); + } + lshmm_unset_allelic_state(self, site, allelic_state); +out: + return ret; +} + +static int WARN_UNUSED +lshmm_coalesce_likelihoods(lshmm_t *self, const tsk_id_t *restrict parent, + double *restrict L, double *restrict L_cache) +{ + int ret = 0; + double L_p; + tsk_id_t u, v, p; + tsk_id_t *restrict cached_paths = self->likelihood_nodes_tmp; + const int old_num_likelihood_nodes = self->num_likelihood_nodes; + tsk_id_t *restrict L_nodes = self->likelihood_nodes; + int j, num_cached_paths, num_likelihood_nodes; + + num_cached_paths = 0; + num_likelihood_nodes = 0; + assert(old_num_likelihood_nodes > 0); + for (j = 0; j < old_num_likelihood_nodes; j++) { + u = L_nodes[j]; + p = parent[u]; + if (p != NULL_NODE) { + cached_paths[num_cached_paths] = p; + num_cached_paths++; + v = p; + while ( + likely(L[v] == NULL_LIKELIHOOD) && likely(L_cache[v] == CACHE_UNSET)) { + v = parent[v]; + } + L_p = L_cache[v]; + if (unlikely(L_p == CACHE_UNSET)) { + L_p = L[v]; + } + /* Fill in the L cache */ + v = p; + while ( + likely(L[v] == NULL_LIKELIHOOD) && likely(L_cache[v] == CACHE_UNSET)) { + L_cache[v] = L_p; + v = parent[v]; + } + /* If the likelihood for the parent is equal to the child we can + * delete the child likelihood */ + if (L[u] == L_p) { + L[u] = NULL_LIKELIHOOD; + } + } + if (L[u] >= 0) { + L_nodes[num_likelihood_nodes] = L_nodes[j]; + num_likelihood_nodes++; + } + } + /* lshmm_print_state(self, stdout); */ + assert(num_likelihood_nodes > 0); + + self->num_likelihood_nodes = num_likelihood_nodes; + /* Reset the L cache */ + for (j = 0; j < num_cached_paths; j++) { + v = cached_paths[j]; + while (likely(v != NULL_NODE) && likely(L_cache[v] != CACHE_UNSET)) { + L_cache[v] = CACHE_UNSET; + v = parent[v]; + } + } + + return ret; +} + +static int +lshmm_update_site_state(lshmm_t *self, const tsk_id_t site, const allele_t state, + tsk_id_t *restrict parent, double *restrict L, double *restrict L_cache) +{ + int ret = 0; + mutation_list_node_t *mutation = NULL; + tsk_id_t u; + + assert(self->num_likelihood_nodes > 0); + + if (self->flags & TSI_EXTENDED_CHECKS) { + lshmm_check_state(self); + } + for (mutation = self->sites.mutations[site]; mutation != NULL; + mutation 
= mutation->next) { + /* Insert a new L-value for the mutation node if needed */ + if (L[mutation->node] == NULL_LIKELIHOOD) { + u = mutation->node; + while (L[u] == NULL_LIKELIHOOD) { + u = parent[u]; + assert(u != NULL_NODE); + } + L[mutation->node] = L[u]; + self->likelihood_nodes[self->num_likelihood_nodes] = mutation->node; + self->num_likelihood_nodes++; + } + } + ret = lshmm_update_site_likelihood_values(self, site, state, parent, L); + if (ret != 0) { + goto out; + } + ret = lshmm_store_traceback(self, site); + if (ret != 0) { + goto out; + } + ret = lshmm_coalesce_likelihoods(self, parent, L, L_cache); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +static void +lshmm_reset_tree(lshmm_t *self) +{ + memset(self->parent, 0xff, self->num_nodes * sizeof(*self->parent)); + memset(self->left_child, 0xff, self->num_nodes * sizeof(*self->left_child)); + memset(self->right_child, 0xff, self->num_nodes * sizeof(*self->right_child)); + memset(self->left_sib, 0xff, self->num_nodes * sizeof(*self->left_sib)); + memset(self->right_sib, 0xff, self->num_nodes * sizeof(*self->right_sib)); + memset(self->recombination_required, 0xff, + self->num_nodes * sizeof(*self->recombination_required)); +} + +static int +lshmm_reset(lshmm_t *self) +{ + int ret = 0; + + memset(self->allelic_state, 0xff, self->num_nodes * sizeof(*self->allelic_state)); + + if (self->traceback_allocator.total_size > self->traceback_realloc_size) { + tsk_blkalloc_free(&self->traceback_allocator); + ret = tsk_blkalloc_init(&self->traceback_allocator, self->traceback_block_size); + if (ret != 0) { + goto out; + } + } else { + ret = tsk_blkalloc_reset(&self->traceback_allocator); + if (ret != 0) { + goto out; + } + } + self->total_traceback_size = 0; + self->num_likelihood_nodes = 0; + lshmm_reset_tree(self); +out: + return ret; +} + +/* Resets the recombination_required array from the traceback at the specified site. + */ +static inline void +lshmm_set_recombination_required( + lshmm_t *self, tsk_id_t site, int8_t *restrict recombination_required) +{ + int j; + const int8_t *restrict R = self->traceback[site].recombination_required; + const tsk_id_t *restrict node = self->traceback[site].node; + const int size = self->traceback[site].size; + + /* We always set recombination_required for node 0 to false for the cases + * where no recombination is needed at a particular site (which are + * encoded by a traceback of size 0) */ + recombination_required[0] = 0; + for (j = 0; j < size; j++) { + recombination_required[node[j]] = R[j]; + } +} + +/* Unsets the likelihood array from the traceback at the specified site. 
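+ * (what is reset here is the recombination_required flags that were marked from
+ * the traceback; the likelihood array itself is left untouched)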
+ */ +static inline void +lshmm_unset_recombination_required( + lshmm_t *self, tsk_id_t site, int8_t *restrict recombination_required) +{ + int j; + const tsk_id_t *restrict node = self->traceback[site].node; + const int size = self->traceback[site].size; + + for (j = 0; j < size; j++) { + recombination_required[node[j]] = -1; + } + recombination_required[0] = -1; +} + +static int WARN_UNUSED +lshmm_run_traceback(lshmm_t *self, tsk_id_t start, tsk_id_t end, + allele_t *TSK_UNUSED(haplotype), allele_t *match) +{ + int ret = 0; + tsk_id_t l; + edge_t edge; + tsk_id_t u, v, max_likelihood_node; + tsk_id_t left, right, pos; + tsk_id_t *restrict parent = self->parent; + allele_t *restrict allelic_state = self->allelic_state; + int8_t *restrict recombination_required = self->recombination_required; + const edge_t *restrict in = self->right_index_edges; + const edge_t *restrict out = self->left_index_edges; + int_fast32_t in_index = (int_fast32_t) self->num_edges - 1; + int_fast32_t out_index = (int_fast32_t) self->num_edges - 1; + + /* Prepare for the traceback and get the memory ready for recording + * the output edges. */ + self->output.size = 0; + self->output.right[self->output.size] = end; + self->output.parent[self->output.size] = NULL_NODE; + + max_likelihood_node = self->max_likelihood_node[end - 1]; + assert(max_likelihood_node != NULL_NODE); + self->output.parent[self->output.size] = max_likelihood_node; + assert(self->output.parent[self->output.size] != NULL_NODE); + + /* Now go through the trees in reverse and run the traceback */ + memset(parent, 0xff, self->num_nodes * sizeof(*parent)); + memset( + recombination_required, 0xff, self->num_nodes * sizeof(*recombination_required)); + pos = (tsk_id_t) self->num_sites; + + while (pos > start) { + while (out_index >= 0 && out[out_index].left == pos) { + edge = out[out_index]; + out_index--; + parent[edge.child] = NULL_NODE; + } + while (in_index >= 0 && in[in_index].right == pos) { + edge = in[in_index]; + in_index--; + parent[edge.child] = edge.parent; + } + right = pos; + left = 0; + if (out_index >= 0) { + left = TSK_MAX(left, out[out_index].left); + } + if (in_index >= 0) { + left = TSK_MAX(left, in[in_index].right); + } + pos = left; + + /* The tree is ready; perform the traceback at each site in this tree */ + assert(left < right); + for (l = TSK_MIN(right, end) - 1; l >= (int) TSK_MAX(left, start); l--) { + lshmm_set_allelic_state(self, l, allelic_state); + u = self->output.parent[self->output.size]; + v = u; + while (allelic_state[v] == TSK_NULL) { + v = parent[v]; + } + match[l] = allelic_state[v]; + lshmm_unset_allelic_state(self, l, allelic_state); + + /* Mark the traceback nodes on the tree */ + lshmm_set_recombination_required(self, l, recombination_required); + + /* Traverse up the tree from the current node. The first marked node that we + * meed tells us whether we need to recombine */ + while (u != 0 && recombination_required[u] == -1) { + u = parent[u]; + assert(u != NULL_NODE); + } + if (recombination_required[u] && l > start) { + max_likelihood_node = self->max_likelihood_node[l - 1]; + assert(max_likelihood_node != NULL_NODE); + self->output.left[self->output.size] = l; + self->output.size++; + assert(self->output.size < self->output.max_size); + /* Start the next output edge */ + self->output.right[self->output.size] = l; + self->output.parent[self->output.size] = max_likelihood_node; + } + /* Unset the values in the tree for the next site. 
*/ + lshmm_unset_recombination_required(self, l, recombination_required); + } + } + + self->output.left[self->output.size] = start; + self->output.size++; + assert(self->output.right[self->output.size - 1] != start); + return ret; +} + +static int +lshmm_run_forwards_match( + lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype) +{ + int ret = 0; + tsk_id_t site; + edge_t edge; + tsk_id_t u, root, last_root; + double L_child = 0; + /* Use the restrict keyword here to try to improve performance by avoiding + * unecessary loads. We must be very careful to to ensure that all references + * to this memory for the duration of this function is through these variables. + */ + double *restrict L = self->likelihood; + double *restrict L_cache = self->likelihood_cache; + tsk_id_t *restrict parent = self->parent; + tsk_id_t *restrict left_child = self->left_child; + tsk_id_t *restrict right_child = self->right_child; + tsk_id_t *restrict left_sib = self->left_sib; + tsk_id_t *restrict right_sib = self->right_sib; + tsk_id_t pos, left, right; + const edge_t *restrict in = self->left_index_edges; + const edge_t *restrict out = self->right_index_edges; + const int_fast32_t M = (tsk_id_t) self->num_edges; + int_fast32_t in_index, out_index, l, remove_start; + + /* Load the tree for start */ + left = 0; + pos = 0; + in_index = 0; + out_index = 0; + right = (tsk_id_t) self->num_sites; + if (in_index < M && start < in[in_index].left) { + right = in[in_index].left; + } + + /* TODO there's probably quite a big gain to made here by seeking + * directly to the tree that we're interested in rather than just + * building the trees sequentially */ + while (in_index < M && out_index < M && in[in_index].left <= start) { + while (out_index < M && out[out_index].right == pos) { + remove_edge( + out[out_index], parent, left_child, right_child, left_sib, right_sib); + out_index++; + } + while (in_index < M && in[in_index].left == pos) { + insert_edge( + in[in_index], parent, left_child, right_child, left_sib, right_sib); + in_index++; + } + left = pos; + right = (tsk_id_t) self->num_sites; + if (in_index < M) { + right = TSK_MIN(right, in[in_index].left); + } + if (out_index < M) { + right = TSK_MIN(right, out[out_index].right); + } + pos = right; + } + + /* Insert the initial likelihoods. 
All non-zero roots are marked with a + * special value so we can identify them when the enter the tree */ + L_cache[0] = CACHE_UNSET; + for (u = 0; u < (tsk_id_t) self->num_nodes; u++) { + L_cache[u] = CACHE_UNSET; + if (parent[u] != NULL_NODE) { + L[u] = NULL_LIKELIHOOD; + } else { + L[u] = NONZERO_ROOT_LIKELIHOOD; + } + } + if (self->flags & TSI_EXTENDED_CHECKS) { + lshmm_check_state(self); + } + last_root = 0; + if (left_child[0] != NULL_NODE) { + last_root = left_child[0]; + assert(right_sib[last_root] == NULL_NODE); + } + L[last_root] = 1.0; + self->likelihood_nodes[0] = last_root; + self->num_likelihood_nodes = 1; + + remove_start = out_index; + while (left < end) { + assert(left < right); + + /* Remove the likelihoods for any nonzero roots that have just left + * the tree */ + for (l = remove_start; l < out_index; l++) { + edge = out[l]; + if (unlikely(is_nonzero_root(edge.child, parent, left_child))) { + if (L[edge.child] >= 0) { + lshmm_delete_likelihood(self, edge.child, L); + } + L[edge.child] = NONZERO_ROOT_LIKELIHOOD; + } + if (unlikely(is_nonzero_root(edge.parent, parent, left_child))) { + if (L[edge.parent] >= 0) { + lshmm_delete_likelihood(self, edge.parent, L); + } + L[edge.parent] = NONZERO_ROOT_LIKELIHOOD; + } + } + + root = 0; + if (left_child[0] != NULL_NODE) { + root = left_child[0]; + assert(right_sib[root] == NULL_NODE); + } + if (root != last_root) { + if (last_root == 0) { + lshmm_delete_likelihood(self, last_root, L); + L[last_root] = NONZERO_ROOT_LIKELIHOOD; + } + if (L[root] == NONZERO_ROOT_LIKELIHOOD) { + L[root] = 0; + self->likelihood_nodes[self->num_likelihood_nodes] = root; + self->num_likelihood_nodes++; + } + last_root = root; + } + + if (self->flags & TSI_EXTENDED_CHECKS) { + lshmm_check_state(self); + } + for (site = TSK_MAX(left, start); site < TSK_MIN(right, end); site++) { + ret = lshmm_update_site_state( + self, site, haplotype[site], parent, L, L_cache); + if (ret != 0) { + goto out; + } + } + + /* Move on to the next tree */ + remove_start = out_index; + while (out_index < M && out[out_index].right == right) { + edge = out[out_index]; + out_index++; + remove_edge(edge, parent, left_child, right_child, left_sib, right_sib); + assert(L[edge.child] != NONZERO_ROOT_LIKELIHOOD); + if (L[edge.child] == NULL_LIKELIHOOD) { + u = edge.parent; + while (likely(L[u] == NULL_LIKELIHOOD) + && likely(L_cache[u] == CACHE_UNSET)) { + u = parent[u]; + } + L_child = L_cache[u]; + if (unlikely(L_child == CACHE_UNSET)) { + L_child = L[u]; + } + assert(L_child >= 0); + u = edge.parent; + /* Fill in the cache by traversing back upwards */ + while (likely(L[u] == NULL_LIKELIHOOD) + && likely(L_cache[u] == CACHE_UNSET)) { + L_cache[u] = L_child; + u = parent[u]; + } + L[edge.child] = L_child; + self->likelihood_nodes[self->num_likelihood_nodes] = edge.child; + self->num_likelihood_nodes++; + } + } + /* reset the L cache */ + for (l = remove_start; l < out_index; l++) { + edge = out[l]; + u = edge.parent; + while (likely(L_cache[u] != CACHE_UNSET)) { + L_cache[u] = CACHE_UNSET; + u = parent[u]; + } + } + + left = right; + while (in_index < M && in[in_index].left == left) { + edge = in[in_index]; + in_index++; + insert_edge(edge, parent, left_child, right_child, left_sib, right_sib); + /* Insert zero likelihoods for any nonzero roots that have entered + * the tree. Note we don't bother trying to compress the tree here + * because this will be done for the next site anyway. 
*/ + if (unlikely( + edge.parent != 0 && L[edge.parent] == NONZERO_ROOT_LIKELIHOOD)) { + L[edge.parent] = 0; + self->likelihood_nodes[self->num_likelihood_nodes] = edge.parent; + self->num_likelihood_nodes++; + } + if (unlikely(L[edge.child] == NONZERO_ROOT_LIKELIHOOD)) { + L[edge.child] = 0; + self->likelihood_nodes[self->num_likelihood_nodes] = edge.child; + self->num_likelihood_nodes++; + } + } + right = (tsk_id_t) self->num_sites; + if (in_index < M) { + right = TSK_MIN(right, in[in_index].left); + } + if (out_index < M) { + right = TSK_MIN(right, out[out_index].right); + } + } +out: + return ret; +} + +int +lshmm_find_path(lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype, + allele_t *matched_haplotype, size_t *num_output_edges, tsk_id_t **left_output, + tsk_id_t **right_output, tsk_id_t **parent_output) +{ + int ret = 0; + + ret = lshmm_reset(self); + if (ret != 0) { + goto out; + } + ret = lshmm_run_forwards_match(self, start, end, haplotype); + if (ret != 0) { + goto out; + } + ret = lshmm_run_traceback(self, start, end, haplotype, matched_haplotype); + if (ret != 0) { + goto out; + } + /* Reset some memory for the next call */ + memset( + self->traceback + start, 0, ((size_t)(end - start)) * sizeof(*self->traceback)); + memset(self->max_likelihood_node + start, 0xff, + ((size_t)(end - start)) * sizeof(*self->max_likelihood_node)); + + *left_output = self->output.left; + *right_output = self->output.right; + *parent_output = self->output.parent; + *num_output_edges = self->output.size; +out: + return ret; +} + +double +lshmm_get_mean_traceback_size(lshmm_t *self) +{ + return (double) self->total_traceback_size / ((double) self->num_sites); +} + +size_t +lshmm_get_total_memory(lshmm_t *self) +{ + size_t total = self->traceback_allocator.total_size; + /* TODO add contributions from other objects */ + + return total; +} diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 628a5519..36659616 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -209,6 +209,56 @@ typedef struct { } output; } ancestor_matcher_t; +typedef struct { + tsk_flags_t flags; + const tsk_treeseq_t *ts; + tsk_size_t num_sites; + tsk_size_t num_nodes; + tsk_size_t num_edges; + /* FIXME Copying these in here as a quick way of getting the code working + * again. However, the memory cost might actually be worth it, needs checking + */ + edge_t *left_index_edges; + edge_t *right_index_edges; + /* Copying this in here for simplicity. */ + struct { + mutation_list_node_t **mutations; + tsk_size_t *num_alleles; + } sites; + /* Input LS model rates */ + unsigned int precision; + double *recombination_rate; + double *mismatch_rate; + /* The quintuply linked tree */ + tsk_id_t *parent; + tsk_id_t *left_child; + tsk_id_t *right_child; + tsk_id_t *left_sib; + tsk_id_t *right_sib; + double *likelihood; + double *likelihood_cache; + allele_t *allelic_state; + int num_likelihood_nodes; + /* At each site, record a node with the maximum likelihood. */ + tsk_id_t *max_likelihood_node; + /* Used during traceback to map nodes where recombination is required. 
*/ + int8_t *recombination_required; + tsk_id_t *likelihood_nodes_tmp; + tsk_id_t *likelihood_nodes; + node_state_list_t *traceback; + tsk_blkalloc_t traceback_allocator; + size_t total_traceback_size; + size_t traceback_block_size; + size_t traceback_realloc_size; + struct { + tsk_id_t *left; + tsk_id_t *right; + tsk_id_t *parent; + size_t size; + size_t max_size; + } output; +} lshmm_t; + int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples, size_t num_sites, int mmap_fd, int flags); int ancestor_builder_free(ancestor_builder_t *self); @@ -267,6 +317,16 @@ int tree_sequence_builder_dump_edges(tree_sequence_builder_t *self, tsk_id_t *le int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t *site, tsk_id_t *node, allele_t *derived_state, tsk_id_t *parent); +int lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_rate, + const double *mismatch_rate, unsigned int precision, tsk_flags_t flags); +int lshmm_free(lshmm_t *self); +int lshmm_find_path(lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype, + allele_t *matched_haplotype, size_t *num_output_edges, tsk_id_t **left_output, + tsk_id_t **right_output, tsk_id_t **parent_output); +int lshmm_print_state(lshmm_t *self, FILE *out); +double lshmm_get_mean_traceback_size(lshmm_t *self); +size_t lshmm_get_total_memory(lshmm_t *self); + int packbits(const allele_t *restrict source, size_t len, uint8_t *restrict dest); void unpackbits(const uint8_t *restrict source, size_t len, allele_t *restrict dest); From 540265acac90f548c9b13d368777959a7da6f513 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 25 Apr 2023 22:25:47 +0100 Subject: [PATCH 02/42] Initial dump of required HMM code --- tests/test_ls_hmm.py | 505 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 505 insertions(+) create mode 100644 tests/test_ls_hmm.py diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py new file mode 100644 index 00000000..08ae35a5 --- /dev/null +++ b/tests/test_ls_hmm.py @@ -0,0 +1,505 @@ +""" +Tests for the haplotype matching algorithm. +""" +import numpy as np +import pytest +import tskit + + +def example_binary(n, L): + tables = tskit.TableCollection(L) + tables.nodes.add_row(time=1) + tables.nodes.add_row(time=0) + tables.edges.add_row(0, L, parent=0, child=1) + tree = tskit.Tree.generate_balanced(n, span=L) + binary_tables = tree.tree_sequence.dump_tables() + binary_tables.nodes.time += 1 + tables.nodes.time += np.max(binary_tables.nodes.time) + 1 + binary_tables.edges.child += len(tables.nodes) + binary_tables.edges.parent += len(tables.nodes) + for node in binary_tables.nodes: + tables.nodes.append(node) + for edge in binary_tables.edges: + tables.edges.append(edge) + # FIXME brittle + tables.edges.add_row(0, L, parent=1, child=tree.root + 2) + tables.sort() + return tables.tree_sequence() + + +# Special values used to indicate compressed paths and nodes that are +# not present in the current tree. 
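+# COMPRESSED marks a node whose likelihood has been compressed onto its nearest
+# ancestor that carries one; NONZERO_ROOT marks roots other than node 0, so that
+# they can be recognised when they (re)enter the tree.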
+COMPRESSED = -1 +NONZERO_ROOT = -2 + + +class AncestorMatcher: + def __init__( + self, + ts, + recombination=None, + mismatch=None, + precision=None, + extended_checks=False, + ): + self.ts = ts + self.mismatch = mismatch + self.recombination = recombination + self.precision = precision + self.extended_checks = extended_checks + self.num_sites = ts.num_sites + self.parent = None + self.left_child = None + self.right_sib = None + self.traceback = None + self.max_likelihood_node = None + self.likelihood = None + self.likelihood_nodes = None + self.allelic_state = None + self.total_memory = 0 + + def print_state(self): + # TODO - don't crash when self.max_likelihood_node or self.traceback == None + print("Ancestor matcher state") + print("max_L_node\ttraceback") + for site_index in range(self.num_sites): + print( + site_index, + self.max_likelihood_node[site_index], + self.traceback[site_index], + sep="\t", + ) + + def is_root(self, u): + return self.parent[u] == tskit.NULL + + def check_likelihoods(self): + assert len(set(self.likelihood_nodes)) == len(self.likelihood_nodes) + # Every value in L_nodes must be positive. + for u in self.likelihood_nodes: + assert self.likelihood[u] >= 0 + for u, v in enumerate(self.likelihood): + # Every non-negative value in L should be in L_nodes + if v >= 0: + assert u in self.likelihood_nodes + # Roots other than 0 should have v == -2 + if u != 0 and self.is_root(u) and self.left_child[u] == -1: + # print("root: u = ", u, self.parent[u], self.left_child[u]) + assert v == -2 + + def set_allelic_state(self, site): + """ + Sets the allelic state array to reflect the mutations at this site. + """ + # We know that 0 is always a root. + # FIXME assuming for now that the ancestral state is always zero. + self.allelic_state[0] = 0 + for node, state in self.tree_sequence_builder.mutations[site]: + self.allelic_state[node] = state + + def unset_allelic_state(self, site): + """ + Sets the allelic state values for this site back to null. + """ + # We know that 0 is always a root. + self.allelic_state[0] = -1 + for node, _ in self.tree_sequence_builder.mutations[site]: + self.allelic_state[node] = -1 + assert np.all(self.allelic_state == -1) + + def update_site(self, site, haplotype_state): + n = self.tree_sequence_builder.num_match_nodes + rho = self.recombination[site] + mu = self.mismatch[site] + num_alleles = self.tree_sequence_builder.num_alleles[site] + assert haplotype_state < num_alleles + + self.set_allelic_state(site) + + for node, _ in self.tree_sequence_builder.mutations[site]: + # Insert an new L-value for the mutation node if needed. + if self.likelihood[node] == COMPRESSED: + u = node + while self.likelihood[u] == COMPRESSED: + u = self.parent[u] + self.likelihood[node] = self.likelihood[u] + self.likelihood_nodes.append(node) + + max_L = -1 + max_L_node = -1 + for u in self.likelihood_nodes: + # Get the allelic_state at u. TODO we can cache these states to + # avoid some upward traversals. 
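+ # Walk up the tree until we reach a node whose state was set by
+ # set_allelic_state; node 0 always carries the ancestral state, so the
+ # loop is guaranteed to terminate.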
+ v = u + while self.allelic_state[v] == -1: + v = self.parent[v] + assert v != -1 + + p_last = self.likelihood[u] + p_no_recomb = p_last * (1 - rho + rho / n) + p_recomb = rho / n + recombination_required = False + if p_no_recomb > p_recomb: + p_t = p_no_recomb + else: + p_t = p_recomb + recombination_required = True + self.traceback[site][u] = recombination_required + p_e = mu + if haplotype_state in (tskit.MISSING_DATA, self.allelic_state[v]): + p_e = 1 - (num_alleles - 1) * mu + self.likelihood[u] = p_t * p_e + + if self.likelihood[u] > max_L: + max_L = self.likelihood[u] + max_L_node = u + + if max_L == 0: + if mu == 0: + raise _tsinfer.MatchImpossible( + "Trying to match non-existent allele with zero mismatch rate" + ) + elif mu == 1: + raise _tsinfer.MatchImpossible( + "Match impossible: mismatch prob=1 & no haplotype with other allele" + ) + elif rho == 0: + raise _tsinfer.MatchImpossible( + "Matching failed with recombination=0, potentially due to " + "rounding issues. Try increasing the precision value" + ) + raise AssertionError("Unexpected matching failure") + + for u in self.likelihood_nodes: + x = self.likelihood[u] / max_L + self.likelihood[u] = round(x, self.precision) + + self.max_likelihood_node[site] = max_L_node + self.unset_allelic_state(site) + self.compress_likelihoods() + + def compress_likelihoods(self): + L_cache = np.zeros_like(self.likelihood) - 1 + cached_paths = [] + old_likelihood_nodes = list(self.likelihood_nodes) + self.likelihood_nodes.clear() + for u in old_likelihood_nodes: + # We need to find the likelihood of the parent of u. If this is + # the same as u, we can delete it. + if not self.is_root(u): + p = self.parent[u] + cached_paths.append(p) + v = p + while self.likelihood[v] == -1 and L_cache[v] == -1: + v = self.parent[v] + L_p = L_cache[v] + if L_p == -1: + L_p = self.likelihood[v] + # Fill in the L cache + v = p + while self.likelihood[v] == -1 and L_cache[v] == -1: + L_cache[v] = L_p + v = self.parent[v] + + if self.likelihood[u] == L_p: + # Delete u from the map + self.likelihood[u] = -1 + if self.likelihood[u] >= 0: + self.likelihood_nodes.append(u) + # Reset the L cache + for u in cached_paths: + v = u + while v != -1 and L_cache[v] != -1: + L_cache[v] = -1 + v = self.parent[v] + assert np.all(L_cache == -1) + + def remove_edge(self, edge): + p = edge.parent + c = edge.child + lsib = self.left_sib[c] + rsib = self.right_sib[c] + if lsib == tskit.NULL: + self.left_child[p] = rsib + else: + self.right_sib[lsib] = rsib + if rsib == tskit.NULL: + self.right_child[p] = lsib + else: + self.left_sib[rsib] = lsib + self.parent[c] = tskit.NULL + self.left_sib[c] = tskit.NULL + self.right_sib[c] = tskit.NULL + + def insert_edge(self, edge): + p = edge.parent + c = edge.child + self.parent[c] = p + u = self.right_child[p] + if u == tskit.NULL: + self.left_child[p] = c + self.left_sib[c] = tskit.NULL + self.right_sib[c] = tskit.NULL + else: + self.right_sib[u] = c + self.left_sib[c] = u + self.right_sib[c] = tskit.NULL + self.right_child[p] = c + + def is_nonzero_root(self, u): + return u != 0 and self.is_root(u) and self.left_child[u] == -1 + + def find_path(self, h, start, end, match): + Il = self.tree_sequence_builder.left_index + Ir = self.tree_sequence_builder.right_index + M = len(Il) + n = self.tree_sequence_builder.num_nodes + m = self.tree_sequence_builder.num_sites + self.parent = np.zeros(n, dtype=int) - 1 + self.left_child = np.zeros(n, dtype=int) - 1 + self.right_child = np.zeros(n, dtype=int) - 1 + self.left_sib = np.zeros(n, dtype=int) - 
1 + self.right_sib = np.zeros(n, dtype=int) - 1 + self.traceback = [{} for _ in range(m)] + self.max_likelihood_node = np.zeros(m, dtype=int) - 1 + self.allelic_state = np.zeros(n, dtype=int) - 1 + + self.likelihood = np.full(n, NONZERO_ROOT, dtype=float) + self.likelihood_nodes = [] + L_cache = np.zeros_like(self.likelihood) - 1 + + # print("MATCH: start=", start, "end = ", end, "h = ", h) + j = 0 + k = 0 + left = 0 + pos = 0 + right = m + if j < M and start < Il.peekitem(j)[1].left: + right = Il.peekitem(j)[1].left + while j < M and k < M and Il.peekitem(j)[1].left <= start: + while Ir.peekitem(k)[1].right == pos: + self.remove_edge(Ir.peekitem(k)[1]) + k += 1 + while j < M and Il.peekitem(j)[1].left == pos: + self.insert_edge(Il.peekitem(j)[1]) + j += 1 + left = pos + right = m + if j < M: + right = min(right, Il.peekitem(j)[1].left) + if k < M: + right = min(right, Ir.peekitem(k)[1].right) + pos = right + assert left < right + + for u in range(n): + if not self.is_root(u): + self.likelihood[u] = -1 + + last_root = 0 + if self.left_child[0] != -1: + last_root = self.left_child[0] + assert self.right_sib[last_root] == -1 + self.likelihood_nodes.append(last_root) + self.likelihood[last_root] = 1 + + remove_start = k + while left < end: + # print("START OF TREE LOOP", left, right) + # print("L:", {u: self.likelihood[u] for u in self.likelihood_nodes}) + assert left < right + for site_index in range(remove_start, k): + edge = Ir.peekitem(site_index)[1] + for u in [edge.parent, edge.child]: + if self.is_nonzero_root(u): + self.likelihood[u] = NONZERO_ROOT + if u in self.likelihood_nodes: + self.likelihood_nodes.remove(u) + root = 0 + if self.left_child[0] != -1: + root = self.left_child[0] + assert self.right_sib[root] == -1 + + if root != last_root: + if last_root == 0: + self.likelihood[last_root] = NONZERO_ROOT + self.likelihood_nodes.remove(last_root) + if self.likelihood[root] == NONZERO_ROOT: + self.likelihood[root] = 0 + self.likelihood_nodes.append(root) + last_root = root + + if self.extended_checks: + self.check_likelihoods() + for site in range(max(left, start), min(right, end)): + self.update_site(site, h[site]) + + remove_start = k + while k < M and Ir.peekitem(k)[1].right == right: + edge = Ir.peekitem(k)[1] + self.remove_edge(edge) + k += 1 + if self.likelihood[edge.child] == -1: + # If the child has an L value, traverse upwards until we + # find the parent that carries it. To avoid repeated traversals + # along the same path we make a cache of the L values. + u = edge.parent + while self.likelihood[u] == -1 and L_cache[u] == -1: + u = self.parent[u] + L_child = L_cache[u] + if L_child == -1: + L_child = self.likelihood[u] + # Fill in the L_cache + u = edge.parent + while self.likelihood[u] == -1 and L_cache[u] == -1: + L_cache[u] = L_child + u = self.parent[u] + self.likelihood[edge.child] = L_child + self.likelihood_nodes.append(edge.child) + # Clear the L cache + for site_index in range(remove_start, k): + edge = Ir.peekitem(site_index)[1] + u = edge.parent + while L_cache[u] != -1: + L_cache[u] = -1 + u = self.parent[u] + assert np.all(L_cache == -1) + + left = right + while j < M and Il.peekitem(j)[1].left == left: + edge = Il.peekitem(j)[1] + self.insert_edge(edge) + j += 1 + # There's no point in compressing the likelihood tree here as we'll be + # doing it after we update the first site anyway. 
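+ # Give a zero likelihood to any nonzero roots that the inserted edge has
+ # just brought back into the tree, so they are tracked from here onwards.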
+ for u in [edge.parent, edge.child]: + if u != 0 and self.likelihood[u] == NONZERO_ROOT: + self.likelihood[u] = 0 + self.likelihood_nodes.append(u) + right = m + if j < M: + right = min(right, Il.peekitem(j)[1].left) + if k < M: + right = min(right, Ir.peekitem(k)[1].right) + + return self.run_traceback(start, end, match) + + def run_traceback(self, start, end, match): + Il = self.tree_sequence_builder.left_index + Ir = self.tree_sequence_builder.right_index + M = len(Il) + u = self.max_likelihood_node[end - 1] + output_edge = Edge(right=end, parent=u) + output_edges = [output_edge] + recombination_required = ( + np.zeros(self.tree_sequence_builder.num_nodes, dtype=int) - 1 + ) + + # Now go back through the trees. + j = M - 1 + k = M - 1 + # Construct the matched haplotype + match[:] = 0 + match[:start] = tskit.MISSING_DATA + match[end:] = tskit.MISSING_DATA + # Reset the tree. + self.parent[:] = -1 + self.left_child[:] = -1 + self.right_child[:] = -1 + self.left_sib[:] = -1 + self.right_sib[:] = -1 + + pos = self.tree_sequence_builder.num_sites + while pos > start: + # print("Top of loop: pos = ", pos) + while k >= 0 and Il.peekitem(k)[1].left == pos: + edge = Il.peekitem(k)[1] + self.remove_edge(edge) + k -= 1 + while j >= 0 and Ir.peekitem(j)[1].right == pos: + edge = Ir.peekitem(j)[1] + self.insert_edge(edge) + j -= 1 + right = pos + left = 0 + if k >= 0: + left = max(left, Il.peekitem(k)[1].left) + if j >= 0: + left = max(left, Ir.peekitem(j)[1].right) + pos = left + + assert left < right + for site_index in range(min(right, end) - 1, max(left, start) - 1, -1): + u = output_edge.parent + self.set_allelic_state(site_index) + v = u + while self.allelic_state[v] == -1: + v = self.parent[v] + match[site_index] = self.allelic_state[v] + self.unset_allelic_state(site_index) + + for u, recombine in self.traceback[site_index].items(): + # Mark the traceback nodes on the tree. + recombination_required[u] = recombine + # Now traverse up the tree from the current node. The first marked node + # we meet tells us whether we need to recombine. + u = output_edge.parent + while u != 0 and recombination_required[u] == -1: + u = self.parent[u] + if recombination_required[u] and site_index > start: + output_edge.left = site_index + u = self.max_likelihood_node[site_index - 1] + output_edge = Edge(right=site_index, parent=u) + output_edges.append(output_edge) + # Reset the nodes in the recombination tree. + for u in self.traceback[site_index].keys(): + recombination_required[u] = -1 + output_edge.left = start + + self.mean_traceback_size = sum(len(t) for t in self.traceback) / self.num_sites + + left = np.zeros(len(output_edges), dtype=np.uint32) + right = np.zeros(len(output_edges), dtype=np.uint32) + parent = np.zeros(len(output_edges), dtype=np.int32) + for j, e in enumerate(output_edges): + assert e.left >= start + assert e.right <= end + # TODO this does happen in the C code, so if it ever happends in a Python + # instance we need to pop the last edge off the list. Or, see why we're + # generating it in the first place. + assert e.left < e.right + left[j] = e.left + right[j] = e.right + parent[j] = e.parent + + return left, right, parent + + +class TestSingleBalancedTreeExample: + # 5.00┊ 0 ┊ + # ┊ ┃ ┊ + # 4.00┊ 1 ┊ + # ┊ ┃ ┊ + # 3.00┊ 8 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 6 7 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 2 3 4 5 ┊ + # 0 4 + + @staticmethod + def ts(): + tables = example_binary(4, 4).dump_tables() + # Add a site for each sample with a single mutation above that sample. 
+ for j in range(4): + tables.sites.add_row(j, "0") + tables.mutations.add_row(site=j, node=2 + j, derived_state="1") + return tables.tree_sequence() + + def test_something(self): + ts = self.ts() + print(ts.draw_text()) + am = AncestorMatcher(ts) + print(am) + match = [0 for _ in range(4)] + am.find_path([1, 0, 0, 0], 0, 4, match) From 6f066cc062e9b08cbc8be214609b4e33752d1145 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 25 Apr 2023 22:45:35 +0100 Subject: [PATCH 03/42] First steps - isolated code works with shim Builder class --- tests/test_ls_hmm.py | 63 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 08ae35a5..9c5f7b8d 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -1,10 +1,23 @@ """ Tests for the haplotype matching algorithm. """ +import collections +import dataclasses + import numpy as np -import pytest +import sortedcontainers import tskit +import _tsinfer + + +@dataclasses.dataclass +class Edge: + left: float = dataclasses.field(default=None) + right: float = dataclasses.field(default=None) + parent: int = dataclasses.field(default=None) + child: int = dataclasses.field(default=None) + def example_binary(n, L): tables = tskit.TableCollection(L) @@ -33,19 +46,44 @@ def example_binary(n, L): NONZERO_ROOT = -2 +class TreeSequenceBuilder: + # Temporary dummy implementation to get things working. + def __init__(self, ts): + self.time = ts.nodes_time + self.num_nodes = ts.num_nodes + self.num_match_nodes = ts.num_nodes + self.num_sites = ts.num_sites + self.left_index = sortedcontainers.SortedDict() + self.right_index = sortedcontainers.SortedDict() + + for edge in ts.edges(): + self.left_index[(edge.left, self.time[edge.child], edge.child)] = edge + self.right_index[(edge.right, -self.time[edge.child], edge.child)] = edge + self.num_alleles = [var.num_alleles for var in ts.variants()] + + self.mutations = collections.defaultdict(list) + for site in ts.sites(): + for mutation in site.mutations: + # FIXME - should be allele index + self.mutations[site.id].append((mutation.node, 1)) + + class AncestorMatcher: def __init__( self, ts, - recombination=None, - mismatch=None, - precision=None, + # recombination=None, + # mismatch=None, + # precision=None, extended_checks=False, ): - self.ts = ts - self.mismatch = mismatch - self.recombination = recombination - self.precision = precision + self.tree_sequence_builder = TreeSequenceBuilder(ts) + self.recombination = np.zeros(ts.num_sites) + 1e-9 + self.mismatch = np.zeros(ts.num_sites) + # self.mismatch = mismatch + # self.recombination = recombination + # self.precision = precision + self.precision = 14 self.extended_checks = extended_checks self.num_sites = ts.num_sites self.parent = None @@ -498,8 +536,9 @@ def ts(): def test_something(self): ts = self.ts() - print(ts.draw_text()) am = AncestorMatcher(ts) - print(am) - match = [0 for _ in range(4)] - am.find_path([1, 0, 0, 0], 0, 4, match) + match = np.zeros(4, dtype=int) + left, right, parent = am.find_path([1, 0, 0, 0], 0, 4, match) + assert list(left) == [0] + assert list(right) == [4] + assert list(parent) == [2] From 936c338fca41f0ff3011929483217beafda4dcd7 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 26 Apr 2023 22:17:34 +0100 Subject: [PATCH 04/42] Some more tests --- tests/test_ls_hmm.py | 143 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 114 insertions(+), 29 deletions(-) diff --git a/tests/test_ls_hmm.py 
b/tests/test_ls_hmm.py index 9c5f7b8d..20a29747 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -3,8 +3,10 @@ """ import collections import dataclasses +import io import numpy as np +import pytest import sortedcontainers import tskit @@ -19,27 +21,40 @@ class Edge: child: int = dataclasses.field(default=None) -def example_binary(n, L): - tables = tskit.TableCollection(L) - tables.nodes.add_row(time=1) - tables.nodes.add_row(time=0) - tables.edges.add_row(0, L, parent=0, child=1) - tree = tskit.Tree.generate_balanced(n, span=L) - binary_tables = tree.tree_sequence.dump_tables() - binary_tables.nodes.time += 1 - tables.nodes.time += np.max(binary_tables.nodes.time) + 1 - binary_tables.edges.child += len(tables.nodes) - binary_tables.edges.parent += len(tables.nodes) - for node in binary_tables.nodes: +def add_vestigial_root(ts): + """ + Adds the nodes and edges required by tsinfer to the specified tree sequence + and returns it. + """ + if not ts.discrete_genome: + raise ValueError("Only discrete genome coords supported") + + base_tables = ts.dump_tables() + tables = base_tables.copy() + tables.nodes.clear() + t = ts.max_root_time + tables.nodes.add_row(time=t + 1) + num_additonal_nodes = len(tables.nodes) + tables.mutations.node += num_additonal_nodes + tables.edges.child += num_additonal_nodes + tables.edges.parent += num_additonal_nodes + for node in base_tables.nodes: tables.nodes.append(node) - for edge in binary_tables.edges: - tables.edges.append(edge) - # FIXME brittle - tables.edges.add_row(0, L, parent=1, child=tree.root + 2) + for tree in ts.trees(): + root = tree.root + num_additonal_nodes + tables.edges.add_row( + tree.interval.left, tree.interval.right, parent=0, child=root + ) + tables.edges.squash() tables.sort() return tables.tree_sequence() +def example_binary(n, L): + tree = tskit.Tree.generate_balanced(n, span=L) + return add_vestigial_root(tree.tree_sequence) + + # Special values used to indicate compressed paths and nodes that are # not present in the current tree. COMPRESSED = -1 @@ -56,7 +71,10 @@ def __init__(self, ts): self.left_index = sortedcontainers.SortedDict() self.right_index = sortedcontainers.SortedDict() - for edge in ts.edges(): + for tsk_edge in ts.edges(): + edge = Edge( + int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child + ) self.left_index[(edge.left, self.time[edge.child], edge.child)] = edge self.right_index[(edge.right, -self.time[edge.child], edge.child)] = edge self.num_alleles = [var.num_alleles for var in ts.variants()] @@ -514,15 +532,13 @@ def run_traceback(self, start, end, match): class TestSingleBalancedTreeExample: - # 5.00┊ 0 ┊ - # ┊ ┃ ┊ - # 4.00┊ 1 ┊ + # 4.00┊ 0 ┊ # ┊ ┃ ┊ - # 3.00┊ 8 ┊ + # 3.00┊ 7 ┊ # ┊ ┏━┻━┓ ┊ - # 2.00┊ 6 7 ┊ + # 2.00┊ 5 6 ┊ # ┊ ┏┻┓ ┏┻┓ ┊ - # 1.00┊ 2 3 4 5 ┊ + # 1.00┊ 1 2 3 4 ┊ # 0 4 @staticmethod @@ -531,14 +547,83 @@ def ts(): # Add a site for each sample with a single mutation above that sample. 
for j in range(4): tables.sites.add_row(j, "0") - tables.mutations.add_row(site=j, node=2 + j, derived_state="1") + tables.mutations.add_row(site=j, node=1 + j, derived_state="1") + return tables.tree_sequence() + + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): + ts = self.ts() + am = AncestorMatcher(ts) + m = 4 + match = np.zeros(m, dtype=int) + h = np.zeros(m) + h[j] = 1 + left, right, parent = am.find_path(h, 0, m, match) + assert list(left) == [0] + assert list(right) == [m] + assert list(parent) == [ts.samples()[j]] + + +class TestMultiTreeExample: + # 1.84┊ 0 ┊ 0 ┊ + # ┊ ┃ ┊ ┃ ┊ + # 0.84┊ 8 ┊ 8 ┊ + # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ + # 0.42┊ ┃ ┃ ┊ 7 ┃ ┊ + # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ + # 0.05┊ 6 ┃ ┊ ┃ ┃ ┃ ┊ + # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ + # 0.04┊ ┃ 5 ┃ ┊ ┃ ┃ 5 ┊ + # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ + # 0.00┊ 1 2 3 4 ┊ 1 4 2 3 ┊ + # 0 2 4 + @staticmethod + def ts(): + nodes = """\ + is_sample time + 0 1.838075 + 1 0.000000 + 1 0.000000 + 1 0.000000 + 1 0.000000 + 0 0.041304 + 0 0.045967 + 0 0.416719 + 0 0.838075 + """ + edges = """\ + left right parent child + 0.000000 4.000000 5 2 + 0.000000 4.000000 5 3 + 0.000000 2.000000 6 1 + 0.000000 2.000000 6 5 + 2.000000 4.000000 7 1 + 2.000000 4.000000 7 4 + 0.000000 2.000000 8 4 + 2.000000 4.000000 8 5 + 0.000000 2.000000 8 6 + 2.000000 4.000000 8 7 + 0.000000 4.000000 0 8 + """ + ts = tskit.load_text( + nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False + ) + tables = ts.dump_tables() + # Add a site for each sample with a single mutation above that sample. + for j in range(4): + tables.sites.add_row(j, "0") + tables.mutations.add_row(site=j, node=1 + j, derived_state="1") return tables.tree_sequence() - def test_something(self): + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): ts = self.ts() am = AncestorMatcher(ts) - match = np.zeros(4, dtype=int) - left, right, parent = am.find_path([1, 0, 0, 0], 0, 4, match) + m = 4 + match = np.zeros(m, dtype=int) + h = np.zeros(m) + h[j] = 1 + left, right, parent = am.find_path(h, 0, m, match) assert list(left) == [0] - assert list(right) == [4] - assert list(parent) == [2] + assert list(right) == [m] + assert list(parent) == [ts.samples()[j]] From 2e2b19243e3a0dc182c94478b30106ca89411257 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 26 Apr 2023 22:29:13 +0100 Subject: [PATCH 05/42] Remove TreeSequenceBuilder g --- tests/test_ls_hmm.py | 106 +++++++++++++++++++++++++------------------ 1 file changed, 63 insertions(+), 43 deletions(-) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 20a29747..f1c212d9 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -15,8 +15,8 @@ @dataclasses.dataclass class Edge: - left: float = dataclasses.field(default=None) - right: float = dataclasses.field(default=None) + left: int = dataclasses.field(default=None) + right: int = dataclasses.field(default=None) parent: int = dataclasses.field(default=None) child: int = dataclasses.field(default=None) @@ -61,31 +61,6 @@ def example_binary(n, L): NONZERO_ROOT = -2 -class TreeSequenceBuilder: - # Temporary dummy implementation to get things working. 
- def __init__(self, ts): - self.time = ts.nodes_time - self.num_nodes = ts.num_nodes - self.num_match_nodes = ts.num_nodes - self.num_sites = ts.num_sites - self.left_index = sortedcontainers.SortedDict() - self.right_index = sortedcontainers.SortedDict() - - for tsk_edge in ts.edges(): - edge = Edge( - int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child - ) - self.left_index[(edge.left, self.time[edge.child], edge.child)] = edge - self.right_index[(edge.right, -self.time[edge.child], edge.child)] = edge - self.num_alleles = [var.num_alleles for var in ts.variants()] - - self.mutations = collections.defaultdict(list) - for site in ts.sites(): - for mutation in site.mutations: - # FIXME - should be allele index - self.mutations[site.id].append((mutation.node, 1)) - - class AncestorMatcher: def __init__( self, @@ -95,7 +70,6 @@ def __init__( # precision=None, extended_checks=False, ): - self.tree_sequence_builder = TreeSequenceBuilder(ts) self.recombination = np.zeros(ts.num_sites) + 1e-9 self.mismatch = np.zeros(ts.num_sites) # self.mismatch = mismatch @@ -114,6 +88,26 @@ def __init__( self.allelic_state = None self.total_memory = 0 + # stuff that used to be in TreeSequenceBuilder + self.num_nodes = ts.num_nodes + self.num_match_nodes = ts.num_nodes + self.num_alleles = [var.num_alleles for var in ts.variants()] + self.mutations = collections.defaultdict(list) + for site in ts.sites(): + for mutation in site.mutations: + # FIXME - should be allele index + self.mutations[site.id].append((mutation.node, 1)) + + self.left_index = sortedcontainers.SortedDict() + self.right_index = sortedcontainers.SortedDict() + time = ts.nodes_time + for tsk_edge in ts.edges(): + edge = Edge( + int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child + ) + self.left_index[(edge.left, time[edge.child], edge.child)] = edge + self.right_index[(edge.right, -time[edge.child], edge.child)] = edge + def print_state(self): # TODO - don't crash when self.max_likelihood_node or self.traceback == None print("Ancestor matcher state") @@ -150,7 +144,7 @@ def set_allelic_state(self, site): # We know that 0 is always a root. # FIXME assuming for now that the ancestral state is always zero. self.allelic_state[0] = 0 - for node, state in self.tree_sequence_builder.mutations[site]: + for node, state in self.mutations[site]: self.allelic_state[node] = state def unset_allelic_state(self, site): @@ -159,20 +153,20 @@ def unset_allelic_state(self, site): """ # We know that 0 is always a root. self.allelic_state[0] = -1 - for node, _ in self.tree_sequence_builder.mutations[site]: + for node, _ in self.mutations[site]: self.allelic_state[node] = -1 assert np.all(self.allelic_state == -1) def update_site(self, site, haplotype_state): - n = self.tree_sequence_builder.num_match_nodes + n = self.num_match_nodes rho = self.recombination[site] mu = self.mismatch[site] - num_alleles = self.tree_sequence_builder.num_alleles[site] + num_alleles = self.num_alleles[site] assert haplotype_state < num_alleles self.set_allelic_state(site) - for node, _ in self.tree_sequence_builder.mutations[site]: + for node, _ in self.mutations[site]: # Insert an new L-value for the mutation node if needed. 
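            # If the node is currently compressed (it inherits its value from an
            # ancestor), walk up to the nearest ancestor carrying an explicit
            # likelihood and copy that value in, so the emission for this site
            # can be applied to the mutation node itself. The rest of
            # update_site is, in outline, the usual Li & Stephens step: each
            # uncompressed node keeps the larger of the no-recombination and
            # recombination transition probabilities, multiplied by the
            # emission term, and the values are then renormalised by the
            # maximum likelihood.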
if self.likelihood[node] == COMPRESSED: u = node @@ -306,11 +300,11 @@ def is_nonzero_root(self, u): return u != 0 and self.is_root(u) and self.left_child[u] == -1 def find_path(self, h, start, end, match): - Il = self.tree_sequence_builder.left_index - Ir = self.tree_sequence_builder.right_index + Il = self.left_index + Ir = self.right_index M = len(Il) - n = self.tree_sequence_builder.num_nodes - m = self.tree_sequence_builder.num_sites + n = self.num_nodes + m = self.num_sites self.parent = np.zeros(n, dtype=int) - 1 self.left_child = np.zeros(n, dtype=int) - 1 self.right_child = np.zeros(n, dtype=int) - 1 @@ -441,15 +435,13 @@ def find_path(self, h, start, end, match): return self.run_traceback(start, end, match) def run_traceback(self, start, end, match): - Il = self.tree_sequence_builder.left_index - Ir = self.tree_sequence_builder.right_index + Il = self.left_index + Ir = self.right_index M = len(Il) u = self.max_likelihood_node[end - 1] output_edge = Edge(right=end, parent=u) output_edges = [output_edge] - recombination_required = ( - np.zeros(self.tree_sequence_builder.num_nodes, dtype=int) - 1 - ) + recombination_required = np.zeros(self.num_nodes, dtype=int) - 1 # Now go back through the trees. j = M - 1 @@ -465,7 +457,7 @@ def run_traceback(self, start, end, match): self.left_sib[:] = -1 self.right_sib[:] = -1 - pos = self.tree_sequence_builder.num_sites + pos = self.num_sites while pos > start: # print("Top of loop: pos = ", pos) while k >= 0 and Il.peekitem(k)[1].left == pos: @@ -531,6 +523,10 @@ def run_traceback(self, start, end, match): return left, right, parent +# TODO the tests on these two classes are the same right now, should +# refactor. + + class TestSingleBalancedTreeExample: # 4.00┊ 0 ┊ # ┊ ┃ ┊ @@ -563,6 +559,18 @@ def test_match_sample(self, j): assert list(right) == [m] assert list(parent) == [ts.samples()[j]] + def test_switch_each_sample(self): + ts = self.ts() + am = AncestorMatcher(ts) + m = 4 + match = np.zeros(m, dtype=int) + h = np.zeros(m) + h[:] = 1 + left, right, parent = am.find_path(h, 0, m, match) + assert list(left) == [3, 2, 1, 0] + assert list(right) == [4, 3, 2, 1] + assert list(parent) == [4, 3, 2, 1] + class TestMultiTreeExample: # 1.84┊ 0 ┊ 0 ┊ @@ -627,3 +635,15 @@ def test_match_sample(self, j): assert list(left) == [0] assert list(right) == [m] assert list(parent) == [ts.samples()[j]] + + def test_switch_each_sample(self): + ts = self.ts() + am = AncestorMatcher(ts) + m = 4 + match = np.zeros(m, dtype=int) + h = np.zeros(m) + h[:] = 1 + left, right, parent = am.find_path(h, 0, m, match) + assert list(left) == [3, 2, 1, 0] + assert list(right) == [4, 3, 2, 1] + assert list(parent) == [4, 3, 2, 1] From 32d580d5f79186c275edc0ccbce1a7d7d6347e57 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 26 Apr 2023 22:42:10 +0100 Subject: [PATCH 06/42] Fake the sortedcontainers API for now --- tests/test_ls_hmm.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index f1c212d9..71736d04 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -7,7 +7,6 @@ import numpy as np import pytest -import sortedcontainers import tskit import _tsinfer @@ -61,6 +60,17 @@ def example_binary(n, L): NONZERO_ROOT = -2 +class FakeSortedSet: + def __init__(self, values): + self.values = values + + def peekitem(self, index): + return None, self.values[index] + + def __len__(self): + return len(self.values) + + class AncestorMatcher: def __init__( self, @@ -98,15 
+108,25 @@ def __init__( # FIXME - should be allele index self.mutations[site.id].append((mutation.node, 1)) - self.left_index = sortedcontainers.SortedDict() - self.right_index = sortedcontainers.SortedDict() - time = ts.nodes_time - for tsk_edge in ts.edges(): + Il = ts.tables.indexes.edge_insertion_order + values = [] + for j in Il: + tsk_edge = ts.edge(j) + edge = Edge( + int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child + ) + values.append(edge) + self.left_index = FakeSortedSet(values) + + Ir = ts.tables.indexes.edge_removal_order + values = [] + for j in Ir: + tsk_edge = ts.edge(j) edge = Edge( int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child ) - self.left_index[(edge.left, time[edge.child], edge.child)] = edge - self.right_index[(edge.right, -time[edge.child], edge.child)] = edge + values.append(edge) + self.right_index = FakeSortedSet(values) def print_state(self): # TODO - don't crash when self.max_likelihood_node or self.traceback == None From 230bffec58369e13e59e7a96679cfa2a5ea0fbd0 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 26 Apr 2023 22:57:38 +0100 Subject: [PATCH 07/42] Fully move to sorted lists of edges. --- tests/test_ls_hmm.py | 93 +++++++++++++++++++------------------------- 1 file changed, 41 insertions(+), 52 deletions(-) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 71736d04..4afe51d0 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -60,15 +60,15 @@ def example_binary(n, L): NONZERO_ROOT = -2 -class FakeSortedSet: - def __init__(self, values): - self.values = values - - def peekitem(self, index): - return None, self.values[index] - - def __len__(self): - return len(self.values) +def convert_edge_list(ts, order): + values = [] + for j in order: + tsk_edge = ts.edge(j) + edge = Edge( + int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child + ) + values.append(edge) + return values class AncestorMatcher: @@ -101,6 +101,15 @@ def __init__( # stuff that used to be in TreeSequenceBuilder self.num_nodes = ts.num_nodes self.num_match_nodes = ts.num_nodes + + # Store the edges in left and right order. + # NOTE: we should probably pull these out into a per-process shared + # resources, since it would be a good bit of extra overhead to store them + # per match thread. 
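
For reference, the two orderings built here carry the same information as the SortedDict
indexes kept by the TreeSequenceBuilder class removed earlier in this series. A rough
hand-built equivalent, shown only as an illustrative sketch (it reuses the Edge dataclass
defined at the top of this file and the sort keys of the removed code, and is not part of
the patch), would be:

def sorted_edge_lists(ts):
    # Edges in the order they are inserted during a left-to-right sweep:
    # by left endpoint, breaking ties with the child node's time.
    time = ts.nodes_time
    edges = [
        Edge(int(e.left), int(e.right), e.parent, e.child)
        for e in ts.edges()
    ]
    left_order = sorted(
        edges, key=lambda e: (e.left, time[e.child], e.child)
    )
    # Edges in the order they are removed during the sweep: by right
    # endpoint, with older children removed first (hence the negated time).
    right_order = sorted(
        edges, key=lambda e: (e.right, -time[e.child], e.child)
    )
    return left_order, right_order
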
+ self.left_index = convert_edge_list(ts, ts.tables.indexes.edge_insertion_order) + self.right_index = convert_edge_list(ts, ts.tables.indexes.edge_removal_order) + + # TODO update self.num_alleles = [var.num_alleles for var in ts.variants()] self.mutations = collections.defaultdict(list) for site in ts.sites(): @@ -108,26 +117,6 @@ def __init__( # FIXME - should be allele index self.mutations[site.id].append((mutation.node, 1)) - Il = ts.tables.indexes.edge_insertion_order - values = [] - for j in Il: - tsk_edge = ts.edge(j) - edge = Edge( - int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child - ) - values.append(edge) - self.left_index = FakeSortedSet(values) - - Ir = ts.tables.indexes.edge_removal_order - values = [] - for j in Ir: - tsk_edge = ts.edge(j) - edge = Edge( - int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child - ) - values.append(edge) - self.right_index = FakeSortedSet(values) - def print_state(self): # TODO - don't crash when self.max_likelihood_node or self.traceback == None print("Ancestor matcher state") @@ -344,21 +333,21 @@ def find_path(self, h, start, end, match): left = 0 pos = 0 right = m - if j < M and start < Il.peekitem(j)[1].left: - right = Il.peekitem(j)[1].left - while j < M and k < M and Il.peekitem(j)[1].left <= start: - while Ir.peekitem(k)[1].right == pos: - self.remove_edge(Ir.peekitem(k)[1]) + if j < M and start < Il[j].left: + right = Il[j].left + while j < M and k < M and Il[j].left <= start: + while Ir[k].right == pos: + self.remove_edge(Ir[k]) k += 1 - while j < M and Il.peekitem(j)[1].left == pos: - self.insert_edge(Il.peekitem(j)[1]) + while j < M and Il[j].left == pos: + self.insert_edge(Il[j]) j += 1 left = pos right = m if j < M: - right = min(right, Il.peekitem(j)[1].left) + right = min(right, Il[j].left) if k < M: - right = min(right, Ir.peekitem(k)[1].right) + right = min(right, Ir[k].right) pos = right assert left < right @@ -379,7 +368,7 @@ def find_path(self, h, start, end, match): # print("L:", {u: self.likelihood[u] for u in self.likelihood_nodes}) assert left < right for site_index in range(remove_start, k): - edge = Ir.peekitem(site_index)[1] + edge = Ir[site_index] for u in [edge.parent, edge.child]: if self.is_nonzero_root(u): self.likelihood[u] = NONZERO_ROOT @@ -405,8 +394,8 @@ def find_path(self, h, start, end, match): self.update_site(site, h[site]) remove_start = k - while k < M and Ir.peekitem(k)[1].right == right: - edge = Ir.peekitem(k)[1] + while k < M and Ir[k].right == right: + edge = Ir[k] self.remove_edge(edge) k += 1 if self.likelihood[edge.child] == -1: @@ -428,7 +417,7 @@ def find_path(self, h, start, end, match): self.likelihood_nodes.append(edge.child) # Clear the L cache for site_index in range(remove_start, k): - edge = Ir.peekitem(site_index)[1] + edge = Ir[site_index] u = edge.parent while L_cache[u] != -1: L_cache[u] = -1 @@ -436,8 +425,8 @@ def find_path(self, h, start, end, match): assert np.all(L_cache == -1) left = right - while j < M and Il.peekitem(j)[1].left == left: - edge = Il.peekitem(j)[1] + while j < M and Il[j].left == left: + edge = Il[j] self.insert_edge(edge) j += 1 # There's no point in compressing the likelihood tree here as we'll be @@ -448,9 +437,9 @@ def find_path(self, h, start, end, match): self.likelihood_nodes.append(u) right = m if j < M: - right = min(right, Il.peekitem(j)[1].left) + right = min(right, Il[j].left) if k < M: - right = min(right, Ir.peekitem(k)[1].right) + right = min(right, Ir[k].right) return self.run_traceback(start, end, 
match) @@ -480,20 +469,20 @@ def run_traceback(self, start, end, match): pos = self.num_sites while pos > start: # print("Top of loop: pos = ", pos) - while k >= 0 and Il.peekitem(k)[1].left == pos: - edge = Il.peekitem(k)[1] + while k >= 0 and Il[k].left == pos: + edge = Il[k] self.remove_edge(edge) k -= 1 - while j >= 0 and Ir.peekitem(j)[1].right == pos: - edge = Ir.peekitem(j)[1] + while j >= 0 and Ir[j].right == pos: + edge = Ir[j] self.insert_edge(edge) j -= 1 right = pos left = 0 if k >= 0: - left = max(left, Il.peekitem(k)[1].left) + left = max(left, Il[k].left) if j >= 0: - left = max(left, Ir.peekitem(j)[1].right) + left = max(left, Ir[j].right) pos = left assert left < right From d49d4a385ceeeaf91573cf12503f90cbd77f3e89 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 26 Apr 2023 23:47:31 +0100 Subject: [PATCH 08/42] Stuff --- lib/ancestor_matcher.c | 3 ++- lib/tests/tests.c | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index a3e68539..3370c2da 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1133,7 +1133,8 @@ lshmm_copy_mutation_data(lshmm_t *self) self->sites.num_alleles[j] = site.mutations_length; /* Next, allocate the corresponding mutation object here and link * to it. */ - delibaretely not compiling + /* delibaretely not compiling */ + printf("FIXME mutations not working\n"); } out: return ret; diff --git a/lib/tests/tests.c b/lib/tests/tests.c index 5ded462c..fffad717 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -985,6 +985,21 @@ test_packbits_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSI_ERR_ONE_BIT_NON_BINARY); } +static void +test_matching_simplest_tree_one_site(void) +{ + int ret = 0; + tsk_table_collection_t tables; + + ret = tsk_table_collection_init(&tables, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + tables.sequence_length = 1; + printf("ADD some topology and one site here\n"); + CU_ASSERT_FATAL(1 == 0); + + tsk_table_collection_free(&tables); +} + static void test_strerror(void) { @@ -1077,6 +1092,8 @@ main(int argc, char **argv) { "test_packbits_4", test_packbits_4 }, { "test_packbits_errors", test_packbits_errors }, + { "test_matching_simplest_tree_one_site", test_matching_simplest_tree_one_site }, + { "test_strerror", test_strerror }, CU_TEST_INFO_NULL, From 9ef666bfe03d0840d7cb2e6febe872b409533b5b Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 27 Apr 2023 22:19:48 +0100 Subject: [PATCH 09/42] Partial --- tests/test_ls_hmm.py | 150 +++++++++++++++++++++++-------------------- 1 file changed, 80 insertions(+), 70 deletions(-) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 4afe51d0..8744dde1 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -20,40 +20,6 @@ class Edge: child: int = dataclasses.field(default=None) -def add_vestigial_root(ts): - """ - Adds the nodes and edges required by tsinfer to the specified tree sequence - and returns it. 
- """ - if not ts.discrete_genome: - raise ValueError("Only discrete genome coords supported") - - base_tables = ts.dump_tables() - tables = base_tables.copy() - tables.nodes.clear() - t = ts.max_root_time - tables.nodes.add_row(time=t + 1) - num_additonal_nodes = len(tables.nodes) - tables.mutations.node += num_additonal_nodes - tables.edges.child += num_additonal_nodes - tables.edges.parent += num_additonal_nodes - for node in base_tables.nodes: - tables.nodes.append(node) - for tree in ts.trees(): - root = tree.root + num_additonal_nodes - tables.edges.add_row( - tree.interval.left, tree.interval.right, parent=0, child=root - ) - tables.edges.squash() - tables.sort() - return tables.tree_sequence() - - -def example_binary(n, L): - tree = tskit.Tree.generate_balanced(n, span=L) - return add_vestigial_root(tree.tree_sequence) - - # Special values used to indicate compressed paths and nodes that are # not present in the current tree. COMPRESSED = -1 @@ -71,23 +37,46 @@ def convert_edge_list(ts, order): return values +class MatcherIndexes: + def __init__(self, ts): + + self.num_nodes = ts.num_nodes + self.num_sites = ts.num_sites + + # Store the edges in left and right order. + self.left_index = convert_edge_list(ts, ts.tables.indexes.edge_insertion_order) + self.right_index = convert_edge_list(ts, ts.tables.indexes.edge_removal_order) + + # TODO update + self.num_alleles = [var.num_alleles for var in ts.variants()] + self.mutations = collections.defaultdict(list) + for site in ts.sites(): + if len(site.mutations) > 1: + raise ValueError("Only single mutations supported for now") + for mutation in site.mutations: + # FIXME - should be allele index + self.mutations[site.id].append((mutation.node, 1)) + + class AncestorMatcher: def __init__( self, - ts, + matcher_indexes, # recombination=None, # mismatch=None, # precision=None, extended_checks=False, ): - self.recombination = np.zeros(ts.num_sites) + 1e-9 - self.mismatch = np.zeros(ts.num_sites) + self.matcher_indexes = matcher_indexes + self.num_sites = matcher_indexes.num_sites + self.num_nodes = matcher_indexes.num_nodes + self.recombination = np.zeros(self.num_sites) + 1e-9 + self.mismatch = np.zeros(self.num_sites) # self.mismatch = mismatch # self.recombination = recombination # self.precision = precision self.precision = 14 self.extended_checks = extended_checks - self.num_sites = ts.num_sites self.parent = None self.left_child = None self.right_sib = None @@ -98,25 +87,6 @@ def __init__( self.allelic_state = None self.total_memory = 0 - # stuff that used to be in TreeSequenceBuilder - self.num_nodes = ts.num_nodes - self.num_match_nodes = ts.num_nodes - - # Store the edges in left and right order. - # NOTE: we should probably pull these out into a per-process shared - # resources, since it would be a good bit of extra overhead to store them - # per match thread. - self.left_index = convert_edge_list(ts, ts.tables.indexes.edge_insertion_order) - self.right_index = convert_edge_list(ts, ts.tables.indexes.edge_removal_order) - - # TODO update - self.num_alleles = [var.num_alleles for var in ts.variants()] - self.mutations = collections.defaultdict(list) - for site in ts.sites(): - for mutation in site.mutations: - # FIXME - should be allele index - self.mutations[site.id].append((mutation.node, 1)) - def print_state(self): # TODO - don't crash when self.max_likelihood_node or self.traceback == None print("Ancestor matcher state") @@ -162,15 +132,15 @@ def unset_allelic_state(self, site): """ # We know that 0 is always a root. 
self.allelic_state[0] = -1 - for node, _ in self.mutations[site]: + for node, _ in self.matcher_indexes.mutations[site]: self.allelic_state[node] = -1 assert np.all(self.allelic_state == -1) def update_site(self, site, haplotype_state): - n = self.num_match_nodes + n = self.num_nodes rho = self.recombination[site] mu = self.mismatch[site] - num_alleles = self.num_alleles[site] + num_alleles = self.matcher_indexes.num_alleles[site] assert haplotype_state < num_alleles self.set_allelic_state(site) @@ -309,8 +279,8 @@ def is_nonzero_root(self, u): return u != 0 and self.is_root(u) and self.left_child[u] == -1 def find_path(self, h, start, end, match): - Il = self.left_index - Ir = self.right_index + Il = self.matcher_indexes.left_index + Ir = self.matcher_indexes.right_index M = len(Il) n = self.num_nodes m = self.num_sites @@ -444,8 +414,8 @@ def find_path(self, h, start, end, match): return self.run_traceback(start, end, match) def run_traceback(self, start, end, match): - Il = self.left_index - Ir = self.right_index + Il = self.matcher_indexes.left_index + Ir = self.matcher_indexes.right_index M = len(Il) u = self.max_likelihood_node[end - 1] output_edge = Edge(right=end, parent=u) @@ -532,10 +502,53 @@ def run_traceback(self, start, end, match): return left, right, parent +def run_match(ts, h): + assert len(h) == ts.num_sites + matcher_indexes = MatcherIndexes(ts) + matcher = AncestorMatcher(matcher_indexes) + match = np.zeros(ts.num_sites, dtype=int) + left, right, parent = matcher.find_path(h, 0, ts.num_sites, match) + return left, right, parent, match + + # TODO the tests on these two classes are the same right now, should # refactor. +def add_vestigial_root(ts): + """ + Adds the nodes and edges required by tsinfer to the specified tree sequence + and returns it. 
+ """ + if not ts.discrete_genome: + raise ValueError("Only discrete genome coords supported") + + base_tables = ts.dump_tables() + tables = base_tables.copy() + tables.nodes.clear() + t = ts.max_root_time + tables.nodes.add_row(time=t + 1) + num_additonal_nodes = len(tables.nodes) + tables.mutations.node += num_additonal_nodes + tables.edges.child += num_additonal_nodes + tables.edges.parent += num_additonal_nodes + for node in base_tables.nodes: + tables.nodes.append(node) + for tree in ts.trees(): + root = tree.root + num_additonal_nodes + tables.edges.add_row( + tree.interval.left, tree.interval.right, parent=0, child=root + ) + tables.edges.squash() + tables.sort() + return tables.tree_sequence() + + +def example_binary(n, L): + tree = tskit.Tree.generate_balanced(n, span=L) + return add_vestigial_root(tree.tree_sequence) + + class TestSingleBalancedTreeExample: # 4.00┊ 0 ┊ # ┊ ┃ ┊ @@ -558,12 +571,8 @@ def ts(): @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): ts = self.ts() - am = AncestorMatcher(ts) - m = 4 - match = np.zeros(m, dtype=int) - h = np.zeros(m) - h[j] = 1 - left, right, parent = am.find_path(h, 0, m, match) + h = np.zeros(4) + left, right, parent, match = run_match(ts, h) assert list(left) == [0] assert list(right) == [m] assert list(parent) == [ts.samples()[j]] @@ -640,6 +649,7 @@ def test_match_sample(self, j): match = np.zeros(m, dtype=int) h = np.zeros(m) h[j] = 1 + left, right, parent, match = run_match(self.ts(), h) left, right, parent = am.find_path(h, 0, m, match) assert list(left) == [0] assert list(right) == [m] From 3394a81d46cce4fbcf322645573bcf1374d537c7 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 27 Apr 2023 22:26:58 +0100 Subject: [PATCH 10/42] Abstract out the MatcherIndexes class --- tests/test_ls_hmm.py | 33 ++++++++++++++++++--------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 8744dde1..a22550cf 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -22,8 +22,6 @@ class Edge: # Special values used to indicate compressed paths and nodes that are # not present in the current tree. -COMPRESSED = -1 -NONZERO_ROOT = -2 def convert_edge_list(ts, order): @@ -38,8 +36,11 @@ def convert_edge_list(ts, order): class MatcherIndexes: - def __init__(self, ts): + """ + The memory that can be shared between AncestorMatcher instances. + """ + def __init__(self, ts): self.num_nodes = ts.num_nodes self.num_sites = ts.num_sites @@ -58,6 +59,10 @@ def __init__(self, ts): self.mutations[site.id].append((mutation.node, 1)) +COMPRESSED = -1 +NONZERO_ROOT = -2 + + class AncestorMatcher: def __init__( self, @@ -123,7 +128,7 @@ def set_allelic_state(self, site): # We know that 0 is always a root. # FIXME assuming for now that the ancestral state is always zero. self.allelic_state[0] = 0 - for node, state in self.mutations[site]: + for node, state in self.matcher_indexes.mutations[site]: self.allelic_state[node] = state def unset_allelic_state(self, site): @@ -145,7 +150,7 @@ def update_site(self, site, haplotype_state): self.set_allelic_state(site) - for node, _ in self.mutations[site]: + for node, _ in self.matcher_indexes.mutations[site]: # Insert an new L-value for the mutation node if needed. 
if self.likelihood[node] == COMPRESSED: u = node @@ -572,22 +577,23 @@ def ts(): def test_match_sample(self, j): ts = self.ts() h = np.zeros(4) + h[j] = 1 left, right, parent, match = run_match(ts, h) assert list(left) == [0] - assert list(right) == [m] + assert list(right) == [4] assert list(parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, match) def test_switch_each_sample(self): ts = self.ts() - am = AncestorMatcher(ts) m = 4 - match = np.zeros(m, dtype=int) h = np.zeros(m) h[:] = 1 - left, right, parent = am.find_path(h, 0, m, match) + left, right, parent, match = run_match(ts, h) assert list(left) == [3, 2, 1, 0] assert list(right) == [4, 3, 2, 1] assert list(parent) == [4, 3, 2, 1] + np.testing.assert_array_equal(h, match) class TestMultiTreeExample: @@ -644,25 +650,22 @@ def ts(): @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): ts = self.ts() - am = AncestorMatcher(ts) m = 4 - match = np.zeros(m, dtype=int) h = np.zeros(m) h[j] = 1 left, right, parent, match = run_match(self.ts(), h) - left, right, parent = am.find_path(h, 0, m, match) assert list(left) == [0] assert list(right) == [m] assert list(parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, match) def test_switch_each_sample(self): ts = self.ts() - am = AncestorMatcher(ts) m = 4 - match = np.zeros(m, dtype=int) h = np.zeros(m) h[:] = 1 - left, right, parent = am.find_path(h, 0, m, match) + left, right, parent, match = run_match(ts, h) assert list(left) == [3, 2, 1, 0] assert list(right) == [4, 3, 2, 1] assert list(parent) == [4, 3, 2, 1] + np.testing.assert_array_equal(h, match) From 3c6f5c0de42cee05cf26e020bf4a55e79de30ba1 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 27 Apr 2023 22:53:27 +0100 Subject: [PATCH 11/42] Pull out stuff necessary for static match --- lib/ancestor_matcher.c | 136 +++++++++++++++++++++++++++++++++++++++++ lib/tsinfer.h | 22 +++++++ 2 files changed, 158 insertions(+) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 3370c2da..508e2abd 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1962,3 +1962,139 @@ lshmm_get_total_memory(lshmm_t *self) return total; } + +static int +matcher_indexes_copy_edge_indexes(matcher_indexes_t *self, const tsk_treeseq_t *ts) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t k; + edge_t e; + const tsk_id_t *restrict I = ts->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = ts->tables->indexes.edge_removal_order; + const double *restrict edges_right = ts->tables->edges.right; + const double *restrict edges_left = ts->tables->edges.left; + const tsk_id_t *restrict edges_child = ts->tables->edges.child; + const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; + + for (j = 0; j < self->num_edges; j++) { + k = I[j]; + /* TODO check that the edges can be cast */ + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->left_index_edges[j] = e; + + k = O[j]; + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->right_index_edges[j] = e; + } + return ret; +} + +static int WARN_UNUSED +matcher_indexes_add_mutation( + matcher_indexes_t *self, tsk_id_t site, tsk_id_t node, allele_t derived_state) +{ + int ret = 0; + mutation_list_node_t *list_node, *tail; + + list_node = tsk_blkalloc_get(&self->allocator, sizeof(mutation_list_node_t)); + if (list_node == NULL) { + ret = 
TSI_ERR_NO_MEMORY; + goto out; + } + list_node->node = node; + list_node->derived_state = derived_state; + list_node->next = NULL; + if (self->sites.mutations[site] == NULL) { + self->sites.mutations[site] = list_node; + } else { + tail = self->sites.mutations[site]; + while (tail->next != NULL) { + tail = tail->next; + } + tail->next = list_node; + } + self->num_mutations++; +out: + return ret; +} + +static int +matcher_indexes_copy_mutation_data(matcher_indexes_t *self, const tsk_treeseq_t *ts) +{ + int ret = 0; + tsk_site_t site; + tsk_size_t j, k; + + for (j = 0; j < self->num_sites; j++) { + ret = tsk_treeseq_get_site(ts, (tsk_id_t) j, &site); + if (ret != 0) { + goto out; + } + if (site.mutations_length > 1) { + ret = TSI_ERR_GENERIC; + goto out; + } + /* FIXME need to through this properly when we've got things working. */ + self->sites.num_alleles[j] = site.mutations_length + 1; + for (k = 0; k < site.mutations_length; k++) { + ret = matcher_indexes_add_mutation(self, site.id, site.mutations[k].node, 1); + if (ret != 0) { + goto out; + } + } + } +out: + return ret; +} + +int +matcher_indexes_alloc( + matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags) +{ + int ret = 0; + + self->flags = flags; + self->num_edges = tsk_treeseq_get_num_edges(ts); + self->num_nodes = tsk_treeseq_get_num_nodes(ts); + self->num_sites = tsk_treeseq_get_num_sites(ts); + self->num_mutations = tsk_treeseq_get_num_mutations(ts); + + self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); + self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); + self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); + self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); + if (self->left_index_edges == NULL || self->right_index_edges == NULL + || self->sites.mutations == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + + ret = matcher_indexes_copy_edge_indexes(self, ts); + if (ret != 0) { + goto out; + } + ret = matcher_indexes_copy_mutation_data(self, ts); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +int +matcher_indexes_free(matcher_indexes_t *self) +{ + tsk_safe_free(self->left_index_edges); + tsk_safe_free(self->right_index_edges); + tsk_safe_free(self->sites.mutations); + tsk_safe_free(self->sites.num_alleles); + tsk_blkalloc_free(&self->allocator); + return 0; +} diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 36659616..861ada0f 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -209,6 +209,24 @@ typedef struct { } output; } ancestor_matcher_t; +typedef struct { + tsk_flags_t flags; + size_t num_sites; + size_t num_nodes; + size_t num_mutations; + size_t num_edges; + struct { + mutation_list_node_t **mutations; + tsk_size_t *num_alleles; + } sites; + /* TODO add nodes struct */ + double *time; + uint32_t *node_flags; + edge_t *left_index_edges; + edge_t *right_index_edges; + tsk_blkalloc_t allocator; +} matcher_indexes_t; + typedef struct { tsk_flags_t flags; const tsk_treeseq_t *ts; @@ -317,6 +335,10 @@ int tree_sequence_builder_dump_edges(tree_sequence_builder_t *self, tsk_id_t *le int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t *site, tsk_id_t *node, allele_t *derived_state, tsk_id_t *parent); +int matcher_indexes_alloc( + matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags); +int matcher_indexes_free(matcher_indexes_t *self); + int lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_rate, 
const double *mismatch_rate, unsigned int precision, tsk_flags_t flags); int lshmm_free(lshmm_t *self); From b81f408bcc8226e9629c37b4739f6d673050ee94 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 27 Apr 2023 23:02:56 +0100 Subject: [PATCH 12/42] Compiling version without tree_sequence_builder --- lib/ancestor_matcher.c | 503 ++++++++++++++++++----------------------- lib/tsinfer.h | 43 ++-- 2 files changed, 234 insertions(+), 312 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 508e2abd..1d4d6088 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1012,10 +1012,146 @@ ancestor_matcher_get_total_memory(ancestor_matcher_t *self) return total; } -/* NEW IMPLEMENTATION */ +/* New implementation */ + +static int +matcher_indexes_copy_edge_indexes(matcher_indexes_t *self, const tsk_treeseq_t *ts) +{ + int ret = 0; + tsk_size_t j; + tsk_id_t k; + edge_t e; + const tsk_id_t *restrict I = ts->tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = ts->tables->indexes.edge_removal_order; + const double *restrict edges_right = ts->tables->edges.right; + const double *restrict edges_left = ts->tables->edges.left; + const tsk_id_t *restrict edges_child = ts->tables->edges.child; + const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; + + for (j = 0; j < self->num_edges; j++) { + k = I[j]; + /* TODO check that the edges can be cast */ + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->left_index_edges[j] = e; + + k = O[j]; + e.left = (tsk_id_t) edges_left[k]; + e.right = (tsk_id_t) edges_right[k]; + e.parent = edges_parent[k]; + e.child = edges_child[k]; + self->right_index_edges[j] = e; + } + return ret; +} + +static int WARN_UNUSED +matcher_indexes_add_mutation( + matcher_indexes_t *self, tsk_id_t site, tsk_id_t node, allele_t derived_state) +{ + int ret = 0; + mutation_list_node_t *list_node, *tail; + + list_node = tsk_blkalloc_get(&self->allocator, sizeof(mutation_list_node_t)); + if (list_node == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + list_node->node = node; + list_node->derived_state = derived_state; + list_node->next = NULL; + if (self->sites.mutations[site] == NULL) { + self->sites.mutations[site] = list_node; + } else { + tail = self->sites.mutations[site]; + while (tail->next != NULL) { + tail = tail->next; + } + tail->next = list_node; + } + self->num_mutations++; +out: + return ret; +} + +static int +matcher_indexes_copy_mutation_data(matcher_indexes_t *self, const tsk_treeseq_t *ts) +{ + int ret = 0; + tsk_site_t site; + tsk_size_t j, k; + + for (j = 0; j < self->num_sites; j++) { + ret = tsk_treeseq_get_site(ts, (tsk_id_t) j, &site); + if (ret != 0) { + goto out; + } + if (site.mutations_length > 1) { + ret = TSI_ERR_GENERIC; + goto out; + } + /* FIXME need to through this properly when we've got things working. 
*/ + self->sites.num_alleles[j] = site.mutations_length + 1; + for (k = 0; k < site.mutations_length; k++) { + ret = matcher_indexes_add_mutation(self, site.id, site.mutations[k].node, 1); + if (ret != 0) { + goto out; + } + } + } +out: + return ret; +} + +int +matcher_indexes_alloc( + matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags) +{ + int ret = 0; + + self->flags = flags; + self->num_edges = tsk_treeseq_get_num_edges(ts); + self->num_nodes = tsk_treeseq_get_num_nodes(ts); + self->num_sites = tsk_treeseq_get_num_sites(ts); + self->num_mutations = tsk_treeseq_get_num_mutations(ts); + + self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); + self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); + self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); + self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); + if (self->left_index_edges == NULL || self->right_index_edges == NULL + || self->sites.mutations == NULL) { + ret = TSI_ERR_NO_MEMORY; + goto out; + } + + ret = matcher_indexes_copy_edge_indexes(self, ts); + if (ret != 0) { + goto out; + } + ret = matcher_indexes_copy_mutation_data(self, ts); + if (ret != 0) { + goto out; + } +out: + return ret; +} + +int +matcher_indexes_free(matcher_indexes_t *self) +{ + tsk_safe_free(self->left_index_edges); + tsk_safe_free(self->right_index_edges); + tsk_safe_free(self->sites.mutations); + tsk_safe_free(self->sites.num_alleles); + tsk_blkalloc_free(&self->allocator); + return 0; +} static void -lshmm_check_state(lshmm_t *self) +ancestor_matcher2_check_state(ancestor_matcher2_t *self) { int num_likelihoods; int j; @@ -1041,7 +1177,7 @@ lshmm_check_state(lshmm_t *self) } int -lshmm_print_state(lshmm_t *self, FILE *out) +ancestor_matcher2_print_state(ancestor_matcher2_t *self, FILE *out) { int j, k; tsk_id_t u; @@ -1077,82 +1213,24 @@ lshmm_print_state(lshmm_t *self, FILE *out) } tsk_blkalloc_print_state(&self->traceback_allocator, out); - /* lshmm_check_state(self); */ + /* ancestor_matcher2_check_state(self); */ return 0; } -static int -lshmm_copy_edge_indexes(lshmm_t *self) -{ - int ret = 0; - tsk_size_t j; - tsk_id_t k; - edge_t e; - const tsk_id_t *restrict I = self->ts->tables->indexes.edge_insertion_order; - const tsk_id_t *restrict O = self->ts->tables->indexes.edge_removal_order; - const double *restrict edges_right = self->ts->tables->edges.right; - const double *restrict edges_left = self->ts->tables->edges.left; - const tsk_id_t *restrict edges_child = self->ts->tables->edges.child; - const tsk_id_t *restrict edges_parent = self->ts->tables->edges.parent; - - for (j = 0; j < self->num_edges; j++) { - k = I[j]; - /* TODO check that the edges can be cast */ - e.left = (tsk_id_t) edges_left[k]; - e.right = (tsk_id_t) edges_right[k]; - e.parent = edges_parent[k]; - e.child = edges_child[k]; - self->left_index_edges[j] = e; - - k = O[j]; - e.left = (tsk_id_t) edges_left[k]; - e.right = (tsk_id_t) edges_right[k]; - e.parent = edges_parent[k]; - e.child = edges_child[k]; - self->right_index_edges[j] = e; - } - return ret; -} - -static int -lshmm_copy_mutation_data(lshmm_t *self) -{ - int ret = 0; - tsk_site_t site; - tsk_size_t j; - - for (j = 0; j < self->num_sites; j++) { - ret = tsk_treeseq_get_site(self->ts, (tsk_id_t) j, &site); - if (ret != 0) { - goto out; - } - if (site.mutations_length > 1) { - ret = TSI_ERR_GENERIC; - goto out; - } - self->sites.num_alleles[j] = site.mutations_length; - /* Next, 
allocate the corresponding mutation object here and link - * to it. */ - /* delibaretely not compiling */ - printf("FIXME mutations not working\n"); - } -out: - return ret; -} - int -lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_rate, - const double *mismatch_rate, unsigned int precision, tsk_flags_t flags) +ancestor_matcher2_alloc(ancestor_matcher2_t *self, + const matcher_indexes_t *matcher_indexes, double *recombination_rate, + double *mismatch_rate, unsigned int precision, int flags) { int ret = 0; - memset(self, 0, sizeof(lshmm_t)); + memset(self, 0, sizeof(ancestor_matcher2_t)); /* All allocs for arrays related to nodes are done in expand_nodes */ self->flags = flags; self->precision = precision; - self->ts = ts; - self->num_nodes = tsk_treeseq_get_num_nodes(ts); - self->num_sites = tsk_treeseq_get_num_sites(ts); + self->max_nodes = 0; + self->matcher_indexes = matcher_indexes; + self->num_sites = matcher_indexes->num_sites; self->recombination_rate = malloc(self->num_sites * sizeof(*self->recombination_rate)); self->mismatch_rate = malloc(self->num_sites * sizeof(*self->mismatch_rate)); @@ -1162,15 +1240,10 @@ lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_ self->output.left = malloc(self->output.max_size * sizeof(tsk_id_t)); self->output.right = malloc(self->output.max_size * sizeof(tsk_id_t)); self->output.parent = malloc(self->output.max_size * sizeof(tsk_id_t)); - self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); - self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); - self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); - self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); if (self->recombination_rate == NULL || self->mismatch_rate == NULL || self->traceback == NULL || self->max_likelihood_node == NULL || self->output.left == NULL || self->output.right == NULL - || self->output.parent == NULL || self->left_index_edges == NULL - || self->right_index_edges == NULL || self->sites.mutations == NULL) { + || self->output.parent == NULL) { ret = TSI_ERR_NO_MEMORY; goto out; } @@ -1187,20 +1260,12 @@ lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_ self->num_sites * sizeof(*self->recombination_rate)); memcpy(self->mismatch_rate, mismatch_rate, self->num_sites * sizeof(*self->mismatch_rate)); - ret = lshmm_copy_edge_indexes(self); - if (ret != 0) { - goto out; - } - ret = lshmm_copy_mutation_data(self); - if (ret != 0) { - goto out; - } out: return ret; } int -lshmm_free(lshmm_t *self) +ancestor_matcher2_free(ancestor_matcher2_t *self) { tsi_safe_free(self->recombination_rate); tsi_safe_free(self->mismatch_rate); @@ -1220,16 +1285,13 @@ lshmm_free(lshmm_t *self) tsi_safe_free(self->output.left); tsi_safe_free(self->output.right); tsi_safe_free(self->output.parent); - tsk_safe_free(self->left_index_edges); - tsk_safe_free(self->right_index_edges); - tsk_safe_free(self->sites.mutations); - tsk_safe_free(self->sites.num_alleles); tsk_blkalloc_free(&self->traceback_allocator); return 0; } static int -lshmm_delete_likelihood(lshmm_t *self, const tsk_id_t node, double *restrict L) +ancestor_matcher2_delete_likelihood( + ancestor_matcher2_t *self, const tsk_id_t node, double *restrict L) { /* Remove the specified node from the list of nodes */ int j, k; @@ -1250,7 +1312,7 @@ lshmm_delete_likelihood(lshmm_t *self, const tsk_id_t node, double *restrict L) /* Store the 
recombination_required state in the traceback */ static int WARN_UNUSED -lshmm_store_traceback(lshmm_t *self, const tsk_id_t site_id) +ancestor_matcher2_store_traceback(ancestor_matcher2_t *self, const tsk_id_t site_id) { int ret = 0; tsk_id_t u; @@ -1312,15 +1374,15 @@ lshmm_store_traceback(lshmm_t *self, const tsk_id_t site_id) /* Sets the specified allelic state array to reflect the mutations at the * specified site. */ static inline void -lshmm_set_allelic_state( - lshmm_t *self, const tsk_id_t site, allele_t *restrict allelic_state) +ancestor_matcher2_set_allelic_state( + ancestor_matcher2_t *self, const tsk_id_t site, allele_t *restrict allelic_state) { mutation_list_node_t *mutation; /* FIXME assuming that 0 is always the ancestral state */ allelic_state[0] = 0; - for (mutation = self->sites.mutations[site]; mutation != NULL; + for (mutation = self->matcher_indexes->sites.mutations[site]; mutation != NULL; mutation = mutation->next) { allelic_state[mutation->node] = mutation->derived_state; } @@ -1328,21 +1390,22 @@ lshmm_set_allelic_state( /* Resets the allelic state at this site to NULL. */ static inline void -lshmm_unset_allelic_state( - lshmm_t *self, const tsk_id_t site, allele_t *restrict allelic_state) +ancestor_matcher2_unset_allelic_state( + ancestor_matcher2_t *self, const tsk_id_t site, allele_t *restrict allelic_state) { mutation_list_node_t *mutation; allelic_state[0] = NULL_NODE; - for (mutation = self->sites.mutations[site]; mutation != NULL; + for (mutation = self->matcher_indexes->sites.mutations[site]; mutation != NULL; mutation = mutation->next) { allelic_state[mutation->node] = TSK_NULL; } } static int WARN_UNUSED -lshmm_update_site_likelihood_values(lshmm_t *self, const tsk_id_t site, - const allele_t state, const tsk_id_t *restrict parent, double *restrict L) +ancestor_matcher2_update_site_likelihood_values(ancestor_matcher2_t *self, + const tsk_id_t site, const allele_t state, const tsk_id_t *restrict parent, + double *restrict L) { int ret = 0; const int num_likelihood_nodes = self->num_likelihood_nodes; @@ -1354,17 +1417,15 @@ lshmm_update_site_likelihood_values(lshmm_t *self, const tsk_id_t site, double max_L, p_last, p_no_recomb, p_recomb, p_t, p_e; const double rho = self->recombination_rate[site]; const double mu = self->mismatch_rate[site]; - const double n = (double) self->num_nodes; - const double num_alleles = 2; - /* = (double) self->tree_sequence_builder->sites.num_alleles[site]; */ - printf("FIXME get num alleles at site\n"); + const double n = (double) self->matcher_indexes->num_nodes; + const double num_alleles = (double) self->matcher_indexes->sites.num_alleles[site]; if (state >= num_alleles) { ret = TSI_ERR_BAD_HAPLOTYPE_ALLELE; goto out; } - lshmm_set_allelic_state(self, site, allelic_state); + ancestor_matcher2_set_allelic_state(self, site, allelic_state); max_L = -1; max_L_node = NULL_NODE; @@ -1404,7 +1465,7 @@ lshmm_update_site_likelihood_values(lshmm_t *self, const tsk_id_t site, max_L_node = u; } } - /* lshmm_print_state(self, stdout); */ + /* ancestor_matcher2_print_state(self, stdout); */ if (max_L <= 0) { if (mu <= 0 || mu >= 1) { ret = TSI_ERR_MATCH_IMPOSSIBLE_EXTREME_MUTATION_PROBA; @@ -1425,14 +1486,14 @@ lshmm_update_site_likelihood_values(lshmm_t *self, const tsk_id_t site, u = L_nodes[j]; L[u] = tsk_round(L[u] / max_L, self->precision); } - lshmm_unset_allelic_state(self, site, allelic_state); + ancestor_matcher2_unset_allelic_state(self, site, allelic_state); out: return ret; } static int WARN_UNUSED 
-lshmm_coalesce_likelihoods(lshmm_t *self, const tsk_id_t *restrict parent, - double *restrict L, double *restrict L_cache) +ancestor_matcher2_coalesce_likelihoods(ancestor_matcher2_t *self, + const tsk_id_t *restrict parent, double *restrict L, double *restrict L_cache) { int ret = 0; double L_p; @@ -1478,7 +1539,7 @@ lshmm_coalesce_likelihoods(lshmm_t *self, const tsk_id_t *restrict parent, num_likelihood_nodes++; } } - /* lshmm_print_state(self, stdout); */ + /* ancestor_matcher2_print_state(self, stdout); */ assert(num_likelihood_nodes > 0); self->num_likelihood_nodes = num_likelihood_nodes; @@ -1495,19 +1556,20 @@ lshmm_coalesce_likelihoods(lshmm_t *self, const tsk_id_t *restrict parent, } static int -lshmm_update_site_state(lshmm_t *self, const tsk_id_t site, const allele_t state, - tsk_id_t *restrict parent, double *restrict L, double *restrict L_cache) +ancestor_matcher2_update_site_state(ancestor_matcher2_t *self, const tsk_id_t site, + const allele_t state, tsk_id_t *restrict parent, double *restrict L, + double *restrict L_cache) { int ret = 0; - mutation_list_node_t *mutation = NULL; + mutation_list_node_t *mutation = self->matcher_indexes->sites.mutations[site]; tsk_id_t u; assert(self->num_likelihood_nodes > 0); if (self->flags & TSI_EXTENDED_CHECKS) { - lshmm_check_state(self); + ancestor_matcher2_check_state(self); } - for (mutation = self->sites.mutations[site]; mutation != NULL; + for (mutation = self->matcher_indexes->sites.mutations[site]; mutation != NULL; mutation = mutation->next) { /* Insert a new L-value for the mutation node if needed */ if (L[mutation->node] == NULL_LIKELIHOOD) { @@ -1521,15 +1583,15 @@ lshmm_update_site_state(lshmm_t *self, const tsk_id_t site, const allele_t state self->num_likelihood_nodes++; } } - ret = lshmm_update_site_likelihood_values(self, site, state, parent, L); + ret = ancestor_matcher2_update_site_likelihood_values(self, site, state, parent, L); if (ret != 0) { goto out; } - ret = lshmm_store_traceback(self, site); + ret = ancestor_matcher2_store_traceback(self, site); if (ret != 0) { goto out; } - ret = lshmm_coalesce_likelihoods(self, parent, L, L_cache); + ret = ancestor_matcher2_coalesce_likelihoods(self, parent, L, L_cache); if (ret != 0) { goto out; } @@ -1538,7 +1600,7 @@ lshmm_update_site_state(lshmm_t *self, const tsk_id_t site, const allele_t state } static void -lshmm_reset_tree(lshmm_t *self) +ancestor_matcher2_reset_tree(ancestor_matcher2_t *self) { memset(self->parent, 0xff, self->num_nodes * sizeof(*self->parent)); memset(self->left_child, 0xff, self->num_nodes * sizeof(*self->left_child)); @@ -1550,7 +1612,7 @@ lshmm_reset_tree(lshmm_t *self) } static int -lshmm_reset(lshmm_t *self) +ancestor_matcher2_reset(ancestor_matcher2_t *self) { int ret = 0; @@ -1570,7 +1632,7 @@ lshmm_reset(lshmm_t *self) } self->total_traceback_size = 0; self->num_likelihood_nodes = 0; - lshmm_reset_tree(self); + ancestor_matcher2_reset_tree(self); out: return ret; } @@ -1578,8 +1640,8 @@ lshmm_reset(lshmm_t *self) /* Resets the recombination_required array from the traceback at the specified site. 
*/ static inline void -lshmm_set_recombination_required( - lshmm_t *self, tsk_id_t site, int8_t *restrict recombination_required) +ancestor_matcher2_set_recombination_required( + ancestor_matcher2_t *self, tsk_id_t site, int8_t *restrict recombination_required) { int j; const int8_t *restrict R = self->traceback[site].recombination_required; @@ -1598,8 +1660,8 @@ lshmm_set_recombination_required( /* Unsets the likelihood array from the traceback at the specified site. */ static inline void -lshmm_unset_recombination_required( - lshmm_t *self, tsk_id_t site, int8_t *restrict recombination_required) +ancestor_matcher2_unset_recombination_required( + ancestor_matcher2_t *self, tsk_id_t site, int8_t *restrict recombination_required) { int j; const tsk_id_t *restrict node = self->traceback[site].node; @@ -1612,7 +1674,7 @@ lshmm_unset_recombination_required( } static int WARN_UNUSED -lshmm_run_traceback(lshmm_t *self, tsk_id_t start, tsk_id_t end, +ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, allele_t *TSK_UNUSED(haplotype), allele_t *match) { int ret = 0; @@ -1623,10 +1685,10 @@ lshmm_run_traceback(lshmm_t *self, tsk_id_t start, tsk_id_t end, tsk_id_t *restrict parent = self->parent; allele_t *restrict allelic_state = self->allelic_state; int8_t *restrict recombination_required = self->recombination_required; - const edge_t *restrict in = self->right_index_edges; - const edge_t *restrict out = self->left_index_edges; - int_fast32_t in_index = (int_fast32_t) self->num_edges - 1; - int_fast32_t out_index = (int_fast32_t) self->num_edges - 1; + const edge_t *restrict in = self->matcher_indexes->right_index_edges; + const edge_t *restrict out = self->matcher_indexes->left_index_edges; + int_fast32_t in_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; + int_fast32_t out_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; /* Prepare for the traceback and get the memory ready for recording * the output edges. */ @@ -1669,17 +1731,18 @@ lshmm_run_traceback(lshmm_t *self, tsk_id_t start, tsk_id_t end, /* The tree is ready; perform the traceback at each site in this tree */ assert(left < right); for (l = TSK_MIN(right, end) - 1; l >= (int) TSK_MAX(left, start); l--) { - lshmm_set_allelic_state(self, l, allelic_state); + ancestor_matcher2_set_allelic_state(self, l, allelic_state); u = self->output.parent[self->output.size]; v = u; while (allelic_state[v] == TSK_NULL) { v = parent[v]; } match[l] = allelic_state[v]; - lshmm_unset_allelic_state(self, l, allelic_state); + ancestor_matcher2_unset_allelic_state(self, l, allelic_state); /* Mark the traceback nodes on the tree */ - lshmm_set_recombination_required(self, l, recombination_required); + ancestor_matcher2_set_recombination_required( + self, l, recombination_required); /* Traverse up the tree from the current node. The first marked node that we * meed tells us whether we need to recombine */ @@ -1698,7 +1761,8 @@ lshmm_run_traceback(lshmm_t *self, tsk_id_t start, tsk_id_t end, self->output.parent[self->output.size] = max_likelihood_node; } /* Unset the values in the tree for the next site. 
*/ - lshmm_unset_recombination_required(self, l, recombination_required); + ancestor_matcher2_unset_recombination_required( + self, l, recombination_required); } } @@ -1709,8 +1773,8 @@ lshmm_run_traceback(lshmm_t *self, tsk_id_t start, tsk_id_t end, } static int -lshmm_run_forwards_match( - lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype) +ancestor_matcher2_run_forwards_match( + ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype) { int ret = 0; tsk_id_t site; @@ -1729,9 +1793,9 @@ lshmm_run_forwards_match( tsk_id_t *restrict left_sib = self->left_sib; tsk_id_t *restrict right_sib = self->right_sib; tsk_id_t pos, left, right; - const edge_t *restrict in = self->left_index_edges; - const edge_t *restrict out = self->right_index_edges; - const int_fast32_t M = (tsk_id_t) self->num_edges; + const edge_t *restrict in = self->matcher_indexes->left_index_edges; + const edge_t *restrict out = self->matcher_indexes->right_index_edges; + const int_fast32_t M = (tsk_id_t) self->matcher_indexes->num_edges; int_fast32_t in_index, out_index, l, remove_start; /* Load the tree for start */ @@ -1781,7 +1845,7 @@ lshmm_run_forwards_match( } } if (self->flags & TSI_EXTENDED_CHECKS) { - lshmm_check_state(self); + ancestor_matcher2_check_state(self); } last_root = 0; if (left_child[0] != NULL_NODE) { @@ -1802,13 +1866,13 @@ lshmm_run_forwards_match( edge = out[l]; if (unlikely(is_nonzero_root(edge.child, parent, left_child))) { if (L[edge.child] >= 0) { - lshmm_delete_likelihood(self, edge.child, L); + ancestor_matcher2_delete_likelihood(self, edge.child, L); } L[edge.child] = NONZERO_ROOT_LIKELIHOOD; } if (unlikely(is_nonzero_root(edge.parent, parent, left_child))) { if (L[edge.parent] >= 0) { - lshmm_delete_likelihood(self, edge.parent, L); + ancestor_matcher2_delete_likelihood(self, edge.parent, L); } L[edge.parent] = NONZERO_ROOT_LIKELIHOOD; } @@ -1821,7 +1885,7 @@ lshmm_run_forwards_match( } if (root != last_root) { if (last_root == 0) { - lshmm_delete_likelihood(self, last_root, L); + ancestor_matcher2_delete_likelihood(self, last_root, L); L[last_root] = NONZERO_ROOT_LIKELIHOOD; } if (L[root] == NONZERO_ROOT_LIKELIHOOD) { @@ -1833,10 +1897,10 @@ lshmm_run_forwards_match( } if (self->flags & TSI_EXTENDED_CHECKS) { - lshmm_check_state(self); + ancestor_matcher2_check_state(self); } for (site = TSK_MAX(left, start); site < TSK_MIN(right, end); site++) { - ret = lshmm_update_site_state( + ret = ancestor_matcher2_update_site_state( self, site, haplotype[site], parent, L, L_cache); if (ret != 0) { goto out; @@ -1916,21 +1980,22 @@ lshmm_run_forwards_match( } int -lshmm_find_path(lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype, - allele_t *matched_haplotype, size_t *num_output_edges, tsk_id_t **left_output, - tsk_id_t **right_output, tsk_id_t **parent_output) +ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, + allele_t *haplotype, allele_t *matched_haplotype, size_t *num_output_edges, + tsk_id_t **left_output, tsk_id_t **right_output, tsk_id_t **parent_output) { int ret = 0; - ret = lshmm_reset(self); + ret = ancestor_matcher2_reset(self); if (ret != 0) { goto out; } - ret = lshmm_run_forwards_match(self, start, end, haplotype); + ret = ancestor_matcher2_run_forwards_match(self, start, end, haplotype); if (ret != 0) { goto out; } - ret = lshmm_run_traceback(self, start, end, haplotype, matched_haplotype); + ret = ancestor_matcher2_run_traceback( + self, start, end, haplotype, matched_haplotype); if (ret != 0) { 
goto out; } @@ -1949,152 +2014,16 @@ lshmm_find_path(lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype } double -lshmm_get_mean_traceback_size(lshmm_t *self) +ancestor_matcher2_get_mean_traceback_size(ancestor_matcher2_t *self) { return (double) self->total_traceback_size / ((double) self->num_sites); } size_t -lshmm_get_total_memory(lshmm_t *self) +ancestor_matcher2_get_total_memory(ancestor_matcher2_t *self) { size_t total = self->traceback_allocator.total_size; /* TODO add contributions from other objects */ return total; } - -static int -matcher_indexes_copy_edge_indexes(matcher_indexes_t *self, const tsk_treeseq_t *ts) -{ - int ret = 0; - tsk_size_t j; - tsk_id_t k; - edge_t e; - const tsk_id_t *restrict I = ts->tables->indexes.edge_insertion_order; - const tsk_id_t *restrict O = ts->tables->indexes.edge_removal_order; - const double *restrict edges_right = ts->tables->edges.right; - const double *restrict edges_left = ts->tables->edges.left; - const tsk_id_t *restrict edges_child = ts->tables->edges.child; - const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; - - for (j = 0; j < self->num_edges; j++) { - k = I[j]; - /* TODO check that the edges can be cast */ - e.left = (tsk_id_t) edges_left[k]; - e.right = (tsk_id_t) edges_right[k]; - e.parent = edges_parent[k]; - e.child = edges_child[k]; - self->left_index_edges[j] = e; - - k = O[j]; - e.left = (tsk_id_t) edges_left[k]; - e.right = (tsk_id_t) edges_right[k]; - e.parent = edges_parent[k]; - e.child = edges_child[k]; - self->right_index_edges[j] = e; - } - return ret; -} - -static int WARN_UNUSED -matcher_indexes_add_mutation( - matcher_indexes_t *self, tsk_id_t site, tsk_id_t node, allele_t derived_state) -{ - int ret = 0; - mutation_list_node_t *list_node, *tail; - - list_node = tsk_blkalloc_get(&self->allocator, sizeof(mutation_list_node_t)); - if (list_node == NULL) { - ret = TSI_ERR_NO_MEMORY; - goto out; - } - list_node->node = node; - list_node->derived_state = derived_state; - list_node->next = NULL; - if (self->sites.mutations[site] == NULL) { - self->sites.mutations[site] = list_node; - } else { - tail = self->sites.mutations[site]; - while (tail->next != NULL) { - tail = tail->next; - } - tail->next = list_node; - } - self->num_mutations++; -out: - return ret; -} - -static int -matcher_indexes_copy_mutation_data(matcher_indexes_t *self, const tsk_treeseq_t *ts) -{ - int ret = 0; - tsk_site_t site; - tsk_size_t j, k; - - for (j = 0; j < self->num_sites; j++) { - ret = tsk_treeseq_get_site(ts, (tsk_id_t) j, &site); - if (ret != 0) { - goto out; - } - if (site.mutations_length > 1) { - ret = TSI_ERR_GENERIC; - goto out; - } - /* FIXME need to through this properly when we've got things working. 
*/ - self->sites.num_alleles[j] = site.mutations_length + 1; - for (k = 0; k < site.mutations_length; k++) { - ret = matcher_indexes_add_mutation(self, site.id, site.mutations[k].node, 1); - if (ret != 0) { - goto out; - } - } - } -out: - return ret; -} - -int -matcher_indexes_alloc( - matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags) -{ - int ret = 0; - - self->flags = flags; - self->num_edges = tsk_treeseq_get_num_edges(ts); - self->num_nodes = tsk_treeseq_get_num_nodes(ts); - self->num_sites = tsk_treeseq_get_num_sites(ts); - self->num_mutations = tsk_treeseq_get_num_mutations(ts); - - self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); - self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); - self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); - self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); - if (self->left_index_edges == NULL || self->right_index_edges == NULL - || self->sites.mutations == NULL) { - ret = TSI_ERR_NO_MEMORY; - goto out; - } - - ret = matcher_indexes_copy_edge_indexes(self, ts); - if (ret != 0) { - goto out; - } - ret = matcher_indexes_copy_mutation_data(self, ts); - if (ret != 0) { - goto out; - } -out: - return ret; -} - -int -matcher_indexes_free(matcher_indexes_t *self) -{ - tsk_safe_free(self->left_index_edges); - tsk_safe_free(self->right_index_edges); - tsk_safe_free(self->sites.mutations); - tsk_safe_free(self->sites.num_alleles); - tsk_blkalloc_free(&self->allocator); - return 0; -} diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 861ada0f..9c62c5c6 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -228,21 +228,11 @@ typedef struct { } matcher_indexes_t; typedef struct { - tsk_flags_t flags; - const tsk_treeseq_t *ts; - tsk_size_t num_sites; - tsk_size_t num_nodes; - tsk_size_t num_edges; - /* FIXME Copying these in here as a quick way of getting the code working - * again. However, the memory cost might actually be worth it, needs checking - */ - edge_t *left_index_edges; - edge_t *right_index_edges; - /* Copying this in here for simplicity. 
*/ - struct { - mutation_list_node_t **mutations; - tsk_size_t *num_alleles; - } sites; + int flags; + const matcher_indexes_t *matcher_indexes; + size_t num_nodes; + size_t num_sites; + size_t max_nodes; /* Input LS model rates */ unsigned int precision; double *recombination_rate; @@ -275,7 +265,7 @@ typedef struct { size_t size; size_t max_size; } output; -} lshmm_t; +} ancestor_matcher2_t; int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples, size_t num_sites, int mmap_fd, int flags); @@ -335,19 +325,22 @@ int tree_sequence_builder_dump_edges(tree_sequence_builder_t *self, tsk_id_t *le int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t *site, tsk_id_t *node, allele_t *derived_state, tsk_id_t *parent); +/* New impelementation */ + int matcher_indexes_alloc( matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags); int matcher_indexes_free(matcher_indexes_t *self); -int lshmm_alloc(lshmm_t *self, const tsk_treeseq_t *ts, const double *recombination_rate, - const double *mismatch_rate, unsigned int precision, tsk_flags_t flags); -int lshmm_free(lshmm_t *self); -int lshmm_find_path(lshmm_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype, - allele_t *matched_haplotype, size_t *num_output_edges, tsk_id_t **left_output, - tsk_id_t **right_output, tsk_id_t **parent_output); -int lshmm_print_state(lshmm_t *self, FILE *out); -double lshmm_get_mean_traceback_size(lshmm_t *self); -size_t lshmm_get_total_memory(lshmm_t *self); +int ancestor_matcher2_alloc(ancestor_matcher2_t *self, + const matcher_indexes_t *matcher_indexes, double *recombination_rate, + double *mismatch_rate, unsigned int precision, int flags); +int ancestor_matcher2_free(ancestor_matcher2_t *self); +int ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, + allele_t *haplotype, allele_t *matched_haplotype, size_t *num_output_edges, + tsk_id_t **left_output, tsk_id_t **right_output, tsk_id_t **parent_output); +int ancestor_matcher2_print_state(ancestor_matcher2_t *self, FILE *out); +double ancestor_matcher2_get_mean_traceback_size(ancestor_matcher2_t *self); +size_t ancestor_matcher2_get_total_memory(ancestor_matcher2_t *self); int packbits(const allele_t *restrict source, size_t len, uint8_t *restrict dest); void unpackbits(const uint8_t *restrict source, size_t len, allele_t *restrict dest); From 17b1d0fbdb72c98c5727f3123c53b6b30508aad0 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 28 Apr 2023 22:58:42 +0100 Subject: [PATCH 13/42] Partway through getting C tests running --- lib/ancestor_matcher.c | 31 ++++++++++++++----- lib/tests/tests.c | 67 ++++++++++++++++++++++++++++++++++++++---- lib/tsinfer.h | 3 +- tests/test_ls_hmm.py | 1 + 4 files changed, 86 insertions(+), 16 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 1d4d6088..da39dd8a 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -742,6 +742,7 @@ insert_edge(edge_t edge, tsk_id_t *restrict parent, tsk_id_t *restrict left_chil { const tsk_id_t p = edge.parent; const tsk_id_t c = edge.child; + assert(right_child != NULL); const tsk_id_t u = right_child[p]; parent[c] = p; @@ -1092,6 +1093,8 @@ matcher_indexes_copy_mutation_data(matcher_indexes_t *self, const tsk_treeseq_t ret = TSI_ERR_GENERIC; goto out; } + self->sites.mutations[j] = NULL; + /* FIXME need to through this properly when we've got things working. 
*/ self->sites.num_alleles[j] = site.mutations_length + 1; for (k = 0; k < site.mutations_length; k++) { @@ -1126,6 +1129,10 @@ matcher_indexes_alloc( ret = TSI_ERR_NO_MEMORY; goto out; } + ret = tsk_blkalloc_init(&self->allocator, 65536); + if (ret != 0) { + goto out; + } ret = matcher_indexes_copy_edge_indexes(self, ts); if (ret != 0) { @@ -1237,13 +1244,26 @@ ancestor_matcher2_alloc(ancestor_matcher2_t *self, self->output.max_size = self->num_sites; /* We can probably make this smaller */ self->traceback = calloc(self->num_sites, sizeof(node_state_list_t)); self->max_likelihood_node = malloc(self->num_sites * sizeof(tsk_id_t)); + /* TODO get rid of output and just provide pointers from client code. + * We're allocating num_sites anyway, so there's no memory saving. */ self->output.left = malloc(self->output.max_size * sizeof(tsk_id_t)); self->output.right = malloc(self->output.max_size * sizeof(tsk_id_t)); + self->output.parent = malloc(self->output.max_size * sizeof(tsk_id_t)); + + self->parent = malloc(self->num_nodes * sizeof(*self->parent)); + self->left_child = malloc(self->num_nodes * sizeof(*self->left_child)); + self->right_child = malloc(self->num_nodes * sizeof(*self->right_child)); + self->left_sib = malloc(self->num_nodes * sizeof(*self->left_sib)); + self->right_sib = malloc(self->num_nodes * sizeof(*self->right_sib)); + self->recombination_required + = malloc(self->num_nodes * sizeof(*self->recombination_required)); + if (self->recombination_rate == NULL || self->mismatch_rate == NULL || self->traceback == NULL || self->max_likelihood_node == NULL || self->output.left == NULL || self->output.right == NULL || self->output.parent == NULL) { + /* FIXME check tree allocs above */ ret = TSI_ERR_NO_MEMORY; goto out; } @@ -1675,7 +1695,7 @@ ancestor_matcher2_unset_recombination_required( static int WARN_UNUSED ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, - allele_t *TSK_UNUSED(haplotype), allele_t *match) + const allele_t *TSK_UNUSED(haplotype), allele_t *match) { int ret = 0; tsk_id_t l; @@ -1774,7 +1794,7 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i static int ancestor_matcher2_run_forwards_match( - ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, allele_t *haplotype) + ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, const allele_t *haplotype) { int ret = 0; tsk_id_t site; @@ -1981,8 +2001,7 @@ ancestor_matcher2_run_forwards_match( int ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, - allele_t *haplotype, allele_t *matched_haplotype, size_t *num_output_edges, - tsk_id_t **left_output, tsk_id_t **right_output, tsk_id_t **parent_output) + const allele_t *haplotype, allele_t *matched_haplotype) { int ret = 0; @@ -2005,10 +2024,6 @@ ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t memset(self->max_likelihood_node + start, 0xff, ((size_t)(end - start)) * sizeof(*self->max_likelihood_node)); - *left_output = self->output.left; - *right_output = self->output.right; - *parent_output = self->output.parent; - *num_output_edges = self->output.size; out: return ret; } diff --git a/lib/tests/tests.c b/lib/tests/tests.c index fffad717..510b0f07 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -31,6 +31,11 @@ #include +/* FIXME this needs to be updated somehow to allow the tests to be run from + * different directories, i.e., with ninja -C build test + */ +#define TEST_DATA_DIR "test_data" + /* Global variables used 
for test in state in the test suite */ char *_tmp_file_name; @@ -985,19 +990,69 @@ test_packbits_errors(void) CU_ASSERT_EQUAL_FATAL(ret, TSI_ERR_ONE_BIT_NON_BINARY); } +static int +run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, + allele_t *match, tsk_size_t *path_length, tsk_id_t *left, tsk_id_t *right, + tsk_id_t *parent) +{ + int ret; + ancestor_matcher2_t am; + matcher_indexes_t mi; + const size_t m = tsk_treeseq_get_num_sites(ts); + double *recombination_rate = calloc(m, sizeof(*recombination_rate)); + double *mutation_rate = calloc(m, sizeof(*mutation_rate)); + size_t j; + + CU_ASSERT_FATAL(recombination_rate != NULL); + CU_ASSERT_FATAL(mutation_rate != NULL); + for (j = 0; j < m; j++) { + mutation_rate[j] = mu; + recombination_rate[j] = rho; + } + + ret = matcher_indexes_alloc(&mi, ts, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + ret = ancestor_matcher2_alloc(&am, &mi, recombination_rate, mutation_rate, 14, 0); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + ret = ancestor_matcher2_find_path(&am, 0, (tsk_id_t) m, h, match); + CU_ASSERT_EQUAL_FATAL(ret, 0); + + *path_length = am.output.size; + for (j = 0; j < am.output.size; j++) { + left[j] = am.output.left[j]; + right[j] = am.output.right[j]; + parent[j] = am.output.parent[j]; + } + + /* ancestor_matcher2_print_state(&am, stdout); */ + + ancestor_matcher2_free(&am); + matcher_indexes_free(&mi); + free(recombination_rate); + free(mutation_rate); + + return 0; +} + static void test_matching_simplest_tree_one_site(void) { int ret = 0; - tsk_table_collection_t tables; + tsk_treeseq_t ts; + allele_t h[] = { 0, 0, 0, 0 }; + allele_t match[4]; + tsk_id_t left[4], right[4], parent[4]; + tsk_size_t path_length; - ret = tsk_table_collection_init(&tables, 0); + ret = tsk_treeseq_load(&ts, TEST_DATA_DIR "/single_tree_example.trees", 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - tables.sequence_length = 1; - printf("ADD some topology and one site here\n"); - CU_ASSERT_FATAL(1 == 0); + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_sites(&ts), 4); - tsk_table_collection_free(&tables); + h[0] = 1; + run_match(&ts, 1e-8, 0, h, match, &path_length, left, right, parent); + + tsk_treeseq_free(&ts); } static void diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 9c62c5c6..3063180f 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -336,8 +336,7 @@ int ancestor_matcher2_alloc(ancestor_matcher2_t *self, double *mismatch_rate, unsigned int precision, int flags); int ancestor_matcher2_free(ancestor_matcher2_t *self); int ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, - allele_t *haplotype, allele_t *matched_haplotype, size_t *num_output_edges, - tsk_id_t **left_output, tsk_id_t **right_output, tsk_id_t **parent_output); + const allele_t *haplotype, allele_t *matched_haplotype); int ancestor_matcher2_print_state(ancestor_matcher2_t *self, FILE *out); double ancestor_matcher2_get_mean_traceback_size(ancestor_matcher2_t *self); size_t ancestor_matcher2_get_total_memory(ancestor_matcher2_t *self); diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index a22550cf..d1080c1f 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -576,6 +576,7 @@ def ts(): @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): ts = self.ts() + ts.dump("single_tree_example.trees") h = np.zeros(4) h[j] = 1 left, right, parent, match = run_match(ts, h) From 2c59357a06ff2ecd36ffbb09bdac318cabe39e1c Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 28 Apr 2023 22:59:07 +0100 Subject: [PATCH 14/42] add test 
dataq --- lib/test_data/single_tree_example.trees | Bin 0 -> 6484 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 lib/test_data/single_tree_example.trees diff --git a/lib/test_data/single_tree_example.trees b/lib/test_data/single_tree_example.trees new file mode 100644 index 0000000000000000000000000000000000000000..991f94d02e9147833aaaec74775fa1e32c7ab917 GIT binary patch literal 6484 zcmbW5O>7%Q6vvm3LI|{kZxRqPZVwfz?Oi)@J`j+qfRLa{q^1Whn2mRwtg;_wXPu^$ z3Jx3qA;B?6BqW5yu^=H%<$$QffnKO`tU#Q4$*~9!%SsT+mM%a z%j-Adk0{aFCZ5E?;SgE=ApS9pA1fr40r_ug{P%|W(Z9nJD&7N>54>srCW%M?cMSW7 z=^fMf32Fr5!A~F2zZ&uAe~I`zbwHCb{}+gVn(W7kA2omQZxTO6{A-3h>gU8iMEpAO zTDb!v_V0(pW7$o4U#8dF;RC1M=uki+J?kw0`tI zAs+n)kND`5KEnQ~#(!kkfAsGwZT_Y_`u832v<)p&9{GPFo(5wc`L7+Ye)R88;-8=f zqkhz<_Z>3U?;qlyA;ag?f7s!#kH|muu(Hr#eQmgYh<}1We13q3UC#>{`Crla%ZB;U zpSOs|`wKkm$fu962gGClT&4W5>l|d*-_!D6H{`K@KGFDHLmvIzCLZtmiN{riI1f>; zKBE3FwfWnIJl5}$R{t{bdOlr7{XY_qWe1P`=p1C^zozk9hWRo78`}KAkLn-l`bm>>KS@ze$j{I_t>6O)YkUE*we&=+kkaabr-VK_1x6mFSX)M zhcnHk#oOyFB+ss!u!tNIe{DVO#7XF;wa5(_-EOjk`*Gw)9@`S>Wrmg?ah9Y)+lmv9 zCHrg%3*${U=v(|v-|M*nuXO@~-aMTjK{yLs=v zk^s)$=`j(x6|iV6U6(!MwJcSdh-WKX(xFQ58G!jIPe?;uXc3iXHx{FKFlAM(UgW2| z*YiC~p1<-PdzbjhvrC@W_lTeAxhGr?bjdu(g?xiKCl*z3^D*Tcep*cv#(0Lx|Bpl5 zqPa>(S%a)c_CWSR_5|-HaXa#I#04ik5A?%E-SCN|xM)*qxKdHl| zg&z>%H)|ZB$5B?NA9e#KtIiHlS=Xs$WfgR;XN{_3SDRVW0)I7e6Q8LW zZn!*+sp&cYdex~q<%{CNn#*s0SysiKuGHsN>h%S?zR(c!wVg(z0WZti`$nHww8aHW ze&6G}3btF1*iYC;V~f3LBZ|*O7PGy&*JRDgv@`2f+U-WCGS_Z6mDQ%}tTvh**R8Yv E0n9!t`~Uy| literal 0 HcmV?d00001 From 5869f0a4ef661d2596596c61d74f8510240008b2 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sun, 30 Apr 2023 22:50:22 +0100 Subject: [PATCH 15/42] C code working and some tests --- lib/ancestor_matcher.c | 15 +++- lib/test_data/multi_tree_example.trees | Bin 0 -> 6188 bytes lib/tests/tests.c | 104 +++++++++++++++++++++---- lib/tsinfer.h | 1 - tests/test_ls_hmm.py | 2 +- 5 files changed, 104 insertions(+), 18 deletions(-) create mode 100644 lib/test_data/multi_tree_example.trees diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index da39dd8a..d31a4c80 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1235,9 +1235,9 @@ ancestor_matcher2_alloc(ancestor_matcher2_t *self, /* All allocs for arrays related to nodes are done in expand_nodes */ self->flags = flags; self->precision = precision; - self->max_nodes = 0; self->matcher_indexes = matcher_indexes; self->num_sites = matcher_indexes->num_sites; + self->num_nodes = matcher_indexes->num_nodes; self->recombination_rate = malloc(self->num_sites * sizeof(*self->recombination_rate)); self->mismatch_rate = malloc(self->num_sites * sizeof(*self->mismatch_rate)); @@ -1248,7 +1248,6 @@ ancestor_matcher2_alloc(ancestor_matcher2_t *self, * We're allocating num_sites anyway, so there's no memory saving. 
*/ self->output.left = malloc(self->output.max_size * sizeof(tsk_id_t)); self->output.right = malloc(self->output.max_size * sizeof(tsk_id_t)); - self->output.parent = malloc(self->output.max_size * sizeof(tsk_id_t)); self->parent = malloc(self->num_nodes * sizeof(*self->parent)); @@ -1256,14 +1255,24 @@ ancestor_matcher2_alloc(ancestor_matcher2_t *self, self->right_child = malloc(self->num_nodes * sizeof(*self->right_child)); self->left_sib = malloc(self->num_nodes * sizeof(*self->left_sib)); self->right_sib = malloc(self->num_nodes * sizeof(*self->right_sib)); + self->likelihood = malloc(self->num_nodes * sizeof(*self->likelihood)); + self->allelic_state = malloc(self->num_nodes * sizeof(*self->allelic_state)); self->recombination_required = malloc(self->num_nodes * sizeof(*self->recombination_required)); + self->likelihood_cache = malloc(self->num_nodes * sizeof(*self->likelihood_cache)); + self->likelihood_nodes = malloc(self->num_nodes * sizeof(*self->likelihood_nodes)); + self->likelihood_nodes_tmp + = malloc(self->num_nodes * sizeof(*self->likelihood_nodes_tmp)); if (self->recombination_rate == NULL || self->mismatch_rate == NULL || self->traceback == NULL || self->max_likelihood_node == NULL + || self->parent == NULL || self->left_child == NULL || self->right_child == NULL + || self->left_sib == NULL || self->right_sib == NULL + || self->recombination_required == NULL || self->likelihood == NULL + || self->likelihood_cache == NULL || self->likelihood_nodes == NULL + || self->likelihood_nodes_tmp == NULL || self->allelic_state == NULL || self->output.left == NULL || self->output.right == NULL || self->output.parent == NULL) { - /* FIXME check tree allocs above */ ret = TSI_ERR_NO_MEMORY; goto out; } diff --git a/lib/test_data/multi_tree_example.trees b/lib/test_data/multi_tree_example.trees new file mode 100644 index 0000000000000000000000000000000000000000..b783471d908bdf4b09629fdd096211e14ce480b7 GIT binary patch literal 6188 zcmbW5OKjX!6oyTq5K>x(cTx}nQ3Qg9$xBKRf=rQEv7iVxY#7Y=PGV)o9&ArQ2#E>- zDhP>6-DFt_@lrQP6&qG`fe0!RVnG*dc!>?&WP=Ez75I$*-|?J`<4Y{L*XP{xo%7#w zJ+{+#o;>=@{_%t3S2&LISb83wNdNgC$5n{^^7`yzAMD1>!#3R8hHK$D4}KhW&ypA2 za{K`PKxR7Y;9-^ydolAH_*)czT`8GOu>PjvH!bV4fBS1$KGq}8ewN8#|0cmR{(dZk zb{S9qA;llJxp!9RW0v*VpXb5z z{Y9ShtfxoLcfj-hc>(n~uQ}SBe@)fzS@OJp-c|f@OP>8*1JCb;)8Lt(di6;CAE@=O zTk_n$vr7L4cwJB1)PD{-=FRJyEuj(J?{{c_E4b>)X)A+g6H=KdG@EO{=DK(TGnU&r@*5PjxA69E_gma$TNSH{#U{C z{#mos&-(9!&%a;6SJmhKo&nGAhtrlk^M9fAe*&KOXW9Cx|2w7sGw_@*<5~YAcs@VZ z!IyFLGymU8KY7kGUXQGQ05?AO?`O;U?EgLBsh@mR|EOmMJb%9q{*Jze0TtsT;8|}B z{GkfYcuOX(x0d~mmtA&DJISuPCUHGCafhvD*lx#0g*4;VvgsnS=SC(#K;$niC+#rm zx=B57y9Sr5CW`$q@B`1hl9rbWO+ScDl%(Ef7ysVm!g5^SU4tztmpcTKmScQ zPGvQEQMhUXH)y5eBhyNM3EMr)QV10h>6UbR#pv7#AT}@cO`5scF~L%@EIo_sgHlP9 z{ ziHwN{spbFw_|A}C?nzJSuV@ra@_u1&C64benUffa6-MG@PGV%8!th<0UREA6>A|3) z4#~6Ap=1A5hp@$%yAyvUPBvdJ8yPPQPx0wPl^JPjW+!`O0_7BbGWWYV*jf4gG@GV zuudbh#~S@$B?w;*oLMvD%`Ys>PR%XMOi#@j)1I22dw728$eh=nesQ5~+S9ZD0n#G2 AQ2+n{ literal 0 HcmV?d00001 diff --git a/lib/tests/tests.c b/lib/tests/tests.c index 510b0f07..d0e9d022 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -41,6 +41,10 @@ char *_tmp_file_name; FILE *_devnull; +/* FIXME add drawings and descriptions of the trees here */ +tsk_treeseq_t _single_tree_ex_ts; +tsk_treeseq_t _multi_tree_ex_ts; + static void dump_tree_sequence_builder( tree_sequence_builder_t *tsb, tsk_table_collection_t *tables, tsk_flags_t options) @@ -1017,6 +1021,7 @@ run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, ret = 
ancestor_matcher2_find_path(&am, 0, (tsk_id_t) m, h, match); CU_ASSERT_EQUAL_FATAL(ret, 0); + /* ancestor_matcher2_print_state(&am, stdout); */ *path_length = am.output.size; for (j = 0; j < am.output.size; j++) { @@ -1025,8 +1030,6 @@ run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, parent[j] = am.output.parent[j]; } - /* ancestor_matcher2_print_state(&am, stdout); */ - ancestor_matcher2_free(&am); matcher_indexes_free(&mi); free(recombination_rate); @@ -1036,23 +1039,75 @@ run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, } static void -test_matching_simplest_tree_one_site(void) +check_matching_single_site_match(const tsk_treeseq_t *ts) { - int ret = 0; - tsk_treeseq_t ts; allele_t h[] = { 0, 0, 0, 0 }; allele_t match[4]; + tsk_id_t j, left[4], right[4], parent[4]; + tsk_size_t path_length; + + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_sites(ts), 4); + + for (j = 0; j < 4; j++) { + memset(h, 0, sizeof(h)); + h[j] = 1; + run_match(ts, 1e-8, 0, h, match, &path_length, left, right, parent); + CU_ASSERT_EQUAL_FATAL(path_length, 1); + CU_ASSERT_EQUAL_FATAL(left[0], 0); + CU_ASSERT_EQUAL_FATAL(right[0], 4); + CU_ASSERT_EQUAL_FATAL(parent[0], j + 1); + } +} + +static void +test_matching_single_tree_single_site_match(void) +{ + check_matching_single_site_match(&_single_tree_ex_ts); +} + +static void +test_matching_multi_tree_single_site_match(void) +{ + check_matching_single_site_match(&_multi_tree_ex_ts); +} + +static void +check_matching_multi_switch(const tsk_treeseq_t *ts) +{ + allele_t h[] = { 1, 1, 1, 1 }; + allele_t match[4]; tsk_id_t left[4], right[4], parent[4]; tsk_size_t path_length; - ret = tsk_treeseq_load(&ts, TEST_DATA_DIR "/single_tree_example.trees", 0); - CU_ASSERT_EQUAL_FATAL(ret, 0); - CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_sites(&ts), 4); + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_num_sites(ts), 4); + CU_ASSERT_EQUAL_FATAL(tsk_treeseq_get_sequence_length(ts), 4); + + run_match(ts, 1e-8, 0, h, match, &path_length, left, right, parent); + CU_ASSERT_EQUAL_FATAL(path_length, 4); + CU_ASSERT_EQUAL_FATAL(left[3], 0); + CU_ASSERT_EQUAL_FATAL(right[3], 1); + CU_ASSERT_EQUAL_FATAL(parent[3], 1); + CU_ASSERT_EQUAL_FATAL(left[2], 1); + CU_ASSERT_EQUAL_FATAL(right[2], 2); + CU_ASSERT_EQUAL_FATAL(parent[2], 2); + CU_ASSERT_EQUAL_FATAL(left[1], 2); + CU_ASSERT_EQUAL_FATAL(right[1], 3); + CU_ASSERT_EQUAL_FATAL(parent[1], 3); + CU_ASSERT_EQUAL_FATAL(left[0], 3); + CU_ASSERT_EQUAL_FATAL(right[0], 4); + CU_ASSERT_EQUAL_FATAL(parent[0], 4); +} - h[0] = 1; - run_match(&ts, 1e-8, 0, h, match, &path_length, left, right, parent); +static void +test_matching_single_tree_multi_switch(void) +{ + check_matching_multi_switch(&_single_tree_ex_ts); +} - tsk_treeseq_free(&ts); +static void +test_matching_multi_tree_multi_switch(void) +{ + check_matching_multi_switch(&_multi_tree_ex_ts); } static void @@ -1074,11 +1129,13 @@ test_strerror(void) static int tsinfer_suite_init(void) { - int fd; + int ret, fd; static char template[] = "/tmp/tsi_c_test_XXXXXX"; _tmp_file_name = NULL; _devnull = NULL; + memset(&_single_tree_ex_ts, 0, sizeof(_single_tree_ex_ts)); + memset(&_multi_tree_ex_ts, 0, sizeof(_multi_tree_ex_ts)); _tmp_file_name = malloc(sizeof(template)); if (_tmp_file_name == NULL) { @@ -1094,6 +1151,18 @@ tsinfer_suite_init(void) if (_devnull == NULL) { return CUE_SINIT_FAILED; } + + ret = tsk_treeseq_load( + &_single_tree_ex_ts, TEST_DATA_DIR "/single_tree_example.trees", 0); + if (ret != 0) { + return CUE_SINIT_FAILED; + } + ret = 
tsk_treeseq_load( + &_multi_tree_ex_ts, TEST_DATA_DIR "/multi_tree_example.trees", 0); + if (ret != 0) { + return CUE_SINIT_FAILED; + } + return CUE_SUCCESS; } @@ -1107,6 +1176,8 @@ tsinfer_suite_cleanup(void) if (_devnull != NULL) { fclose(_devnull); } + tsk_treeseq_free(&_single_tree_ex_ts); + tsk_treeseq_free(&_multi_tree_ex_ts); return CUE_SUCCESS; } @@ -1147,7 +1218,14 @@ main(int argc, char **argv) { "test_packbits_4", test_packbits_4 }, { "test_packbits_errors", test_packbits_errors }, - { "test_matching_simplest_tree_one_site", test_matching_simplest_tree_one_site }, + { "test_matching_single_tree_single_site_match", + test_matching_single_tree_single_site_match }, + { "test_matching_multi_tree_single_site_match", + test_matching_multi_tree_single_site_match }, + { "test_matching_single_tree_multi_switch", + test_matching_single_tree_multi_switch }, + { "test_matching_multi_tree_multi_switch", + test_matching_multi_tree_multi_switch }, { "test_strerror", test_strerror }, diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 3063180f..8baf6766 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -232,7 +232,6 @@ typedef struct { const matcher_indexes_t *matcher_indexes; size_t num_nodes; size_t num_sites; - size_t max_nodes; /* Input LS model rates */ unsigned int precision; double *recombination_rate; diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index d1080c1f..5ae9be50 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -576,7 +576,6 @@ def ts(): @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): ts = self.ts() - ts.dump("single_tree_example.trees") h = np.zeros(4) h[j] = 1 left, right, parent, match = run_match(ts, h) @@ -651,6 +650,7 @@ def ts(): @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): ts = self.ts() + ts.dump("multi_tree_example.trees") m = 4 h = np.zeros(m) h[j] = 1 From 0297f36e15747f8a2f06c8dda5f8746db2fee9e1 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Mon, 1 May 2023 22:55:40 +0100 Subject: [PATCH 16/42] Factor out the "output" struct in matcher --- lib/ancestor_matcher.c | 52 ++++++++++++++++++------------------------ lib/tests/tests.c | 10 ++------ lib/tsinfer.h | 10 ++------ 3 files changed, 26 insertions(+), 46 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index d31a4c80..0bf6f057 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1241,14 +1241,8 @@ ancestor_matcher2_alloc(ancestor_matcher2_t *self, self->recombination_rate = malloc(self->num_sites * sizeof(*self->recombination_rate)); self->mismatch_rate = malloc(self->num_sites * sizeof(*self->mismatch_rate)); - self->output.max_size = self->num_sites; /* We can probably make this smaller */ self->traceback = calloc(self->num_sites, sizeof(node_state_list_t)); self->max_likelihood_node = malloc(self->num_sites * sizeof(tsk_id_t)); - /* TODO get rid of output and just provide pointers from client code. - * We're allocating num_sites anyway, so there's no memory saving. 
*/ - self->output.left = malloc(self->output.max_size * sizeof(tsk_id_t)); - self->output.right = malloc(self->output.max_size * sizeof(tsk_id_t)); - self->output.parent = malloc(self->output.max_size * sizeof(tsk_id_t)); self->parent = malloc(self->num_nodes * sizeof(*self->parent)); self->left_child = malloc(self->num_nodes * sizeof(*self->left_child)); @@ -1270,9 +1264,7 @@ ancestor_matcher2_alloc(ancestor_matcher2_t *self, || self->left_sib == NULL || self->right_sib == NULL || self->recombination_required == NULL || self->likelihood == NULL || self->likelihood_cache == NULL || self->likelihood_nodes == NULL - || self->likelihood_nodes_tmp == NULL || self->allelic_state == NULL - || self->output.left == NULL || self->output.right == NULL - || self->output.parent == NULL) { + || self->likelihood_nodes_tmp == NULL || self->allelic_state == NULL) { ret = TSI_ERR_NO_MEMORY; goto out; } @@ -1311,9 +1303,6 @@ ancestor_matcher2_free(ancestor_matcher2_t *self) tsi_safe_free(self->allelic_state); tsi_safe_free(self->max_likelihood_node); tsi_safe_free(self->traceback); - tsi_safe_free(self->output.left); - tsi_safe_free(self->output.right); - tsi_safe_free(self->output.parent); tsk_blkalloc_free(&self->traceback_allocator); return 0; } @@ -1704,7 +1693,9 @@ ancestor_matcher2_unset_recombination_required( static int WARN_UNUSED ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, - const allele_t *TSK_UNUSED(haplotype), allele_t *match) + const allele_t *TSK_UNUSED(haplotype), allele_t *restrict match, + size_t *path_length_out, tsk_id_t *restrict path_left, tsk_id_t *restrict path_right, + tsk_id_t *restrict path_parent) { int ret = 0; tsk_id_t l; @@ -1718,17 +1709,17 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i const edge_t *restrict out = self->matcher_indexes->left_index_edges; int_fast32_t in_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; int_fast32_t out_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; + size_t path_length = 0; /* Prepare for the traceback and get the memory ready for recording * the output edges. 
*/ - self->output.size = 0; - self->output.right[self->output.size] = end; - self->output.parent[self->output.size] = NULL_NODE; + path_right[path_length] = end; + path_parent[path_length] = NULL_NODE; max_likelihood_node = self->max_likelihood_node[end - 1]; assert(max_likelihood_node != NULL_NODE); - self->output.parent[self->output.size] = max_likelihood_node; - assert(self->output.parent[self->output.size] != NULL_NODE); + path_parent[path_length] = max_likelihood_node; + assert(path_parent[path_length] != NULL_NODE); /* Now go through the trees in reverse and run the traceback */ memset(parent, 0xff, self->num_nodes * sizeof(*parent)); @@ -1761,7 +1752,7 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i assert(left < right); for (l = TSK_MIN(right, end) - 1; l >= (int) TSK_MAX(left, start); l--) { ancestor_matcher2_set_allelic_state(self, l, allelic_state); - u = self->output.parent[self->output.size]; + u = path_parent[path_length]; v = u; while (allelic_state[v] == TSK_NULL) { v = parent[v]; @@ -1782,12 +1773,11 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i if (recombination_required[u] && l > start) { max_likelihood_node = self->max_likelihood_node[l - 1]; assert(max_likelihood_node != NULL_NODE); - self->output.left[self->output.size] = l; - self->output.size++; - assert(self->output.size < self->output.max_size); + path_left[path_length] = l; + path_length++; /* Start the next output edge */ - self->output.right[self->output.size] = l; - self->output.parent[self->output.size] = max_likelihood_node; + path_right[path_length] = l; + path_parent[path_length] = max_likelihood_node; } /* Unset the values in the tree for the next site. */ ancestor_matcher2_unset_recombination_required( @@ -1795,9 +1785,10 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i } } - self->output.left[self->output.size] = start; - self->output.size++; - assert(self->output.right[self->output.size - 1] != start); + path_left[path_length] = start; + path_length++; + assert(path_right[path_length - 1] != start); + *path_length_out = path_length; return ret; } @@ -2010,7 +2001,8 @@ ancestor_matcher2_run_forwards_match( int ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, - const allele_t *haplotype, allele_t *matched_haplotype) + const allele_t *haplotype, allele_t *matched_haplotype, size_t *path_length, + tsk_id_t *path_left, tsk_id_t *path_right, tsk_id_t *path_parent) { int ret = 0; @@ -2022,8 +2014,8 @@ ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t if (ret != 0) { goto out; } - ret = ancestor_matcher2_run_traceback( - self, start, end, haplotype, matched_haplotype); + ret = ancestor_matcher2_run_traceback(self, start, end, haplotype, matched_haplotype, + path_length, path_left, path_right, path_parent); if (ret != 0) { goto out; } diff --git a/lib/tests/tests.c b/lib/tests/tests.c index d0e9d022..314e8cc6 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -1019,17 +1019,11 @@ run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, ret = ancestor_matcher2_alloc(&am, &mi, recombination_rate, mutation_rate, 14, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); - ret = ancestor_matcher2_find_path(&am, 0, (tsk_id_t) m, h, match); + ret = ancestor_matcher2_find_path( + &am, 0, (tsk_id_t) m, h, match, path_length, left, right, parent); CU_ASSERT_EQUAL_FATAL(ret, 0); /* ancestor_matcher2_print_state(&am, stdout); */ - *path_length 
= am.output.size; - for (j = 0; j < am.output.size; j++) { - left[j] = am.output.left[j]; - right[j] = am.output.right[j]; - parent[j] = am.output.parent[j]; - } - ancestor_matcher2_free(&am); matcher_indexes_free(&mi); free(recombination_rate); diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 8baf6766..53138048 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -257,13 +257,6 @@ typedef struct { size_t total_traceback_size; size_t traceback_block_size; size_t traceback_realloc_size; - struct { - tsk_id_t *left; - tsk_id_t *right; - tsk_id_t *parent; - size_t size; - size_t max_size; - } output; } ancestor_matcher2_t; int ancestor_builder_alloc(ancestor_builder_t *self, size_t num_samples, @@ -335,7 +328,8 @@ int ancestor_matcher2_alloc(ancestor_matcher2_t *self, double *mismatch_rate, unsigned int precision, int flags); int ancestor_matcher2_free(ancestor_matcher2_t *self); int ancestor_matcher2_find_path(ancestor_matcher2_t *self, tsk_id_t start, tsk_id_t end, - const allele_t *haplotype, allele_t *matched_haplotype); + const allele_t *haplotype, allele_t *matched_haplotype, size_t *path_length, + tsk_id_t *path_left, tsk_id_t *path_right, tsk_id_t *path_parent); int ancestor_matcher2_print_state(ancestor_matcher2_t *self, FILE *out); double ancestor_matcher2_get_mean_traceback_size(ancestor_matcher2_t *self); size_t ancestor_matcher2_get_total_memory(ancestor_matcher2_t *self); From 698cc191ff7e1cfdc903ebd83aa5791fff9d2485 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 2 May 2023 22:06:00 +0100 Subject: [PATCH 17/42] Some basics for new Matcher infrastructure --- _tsinfermodule.c | 104 ++++++++++++++++++++++++++++++++++++++++ lib/tsinfer.h | 2 +- setup.py | 2 +- tests/test_low_level.py | 8 ++++ 4 files changed, 114 insertions(+), 2 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 0e6b62f7..028d99be 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -36,6 +36,17 @@ typedef struct { TreeSequenceBuilder *tree_sequence_builder; } AncestorMatcher; +typedef struct { + PyObject_HEAD + matcher_indexes_t *matcher_indexes; +} MatcherIndexes; + +typedef struct { + PyObject_HEAD + ancestor_matcher2_t *ancestor_matcher; + MatcherIndexes *matcher_indexes; +} AncestorMatcher2; + static void handle_library_error(int err) { @@ -1579,6 +1590,90 @@ static PyTypeObject AncestorMatcherType = { (initproc)AncestorMatcher_init, /* tp_init */ }; + +/*=================================================================== + * MatcherIndexes + *=================================================================== + */ + +static void +MatcherIndexes_dealloc(MatcherIndexes* self) +{ + if (self->matcher_indexes != NULL) { + matcher_indexes_free(self->matcher_indexes); + PyMem_Free(self->matcher_indexes); + self->matcher_indexes = NULL; + } + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + static char *kwlist[] = {"tree_sequence", NULL}; + + self->matcher_indexes = NULL; + /* FIXME */ + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist)) { + goto out; + } + + self->matcher_indexes = PyMem_Malloc(sizeof(*self->matcher_indexes)); + if (self->matcher_indexes == NULL) { + PyErr_NoMemory(); + goto out; + } + err = matcher_indexes_alloc(self->matcher_indexes, NULL, 0); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + return ret; +} +/* TODO update to c99 form */ +static PyTypeObject MatcherIndexesType = { + PyVarObject_HEAD_INIT(NULL, 0) + 
"_tsinfer.MatcherIndexes", /* tp_name */ + sizeof(MatcherIndexes), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)MatcherIndexes_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "MatcherIndexes objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + 0, /* tp_methods */ + 0, /* tp_members */ + 0, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)MatcherIndexes_init, /* tp_init */ +}; + /*=================================================================== * Module level code. *=================================================================== @@ -1650,6 +1745,15 @@ init_tsinfer(void) Py_INCREF(&TreeSequenceBuilderType); PyModule_AddObject(module, "TreeSequenceBuilder", (PyObject *) &TreeSequenceBuilderType); + + /* MatcherIndexes type */ + MatcherIndexesType.tp_new = PyType_GenericNew; + if (PyType_Ready(&MatcherIndexesType) < 0) { + INITERROR; + } + Py_INCREF(&MatcherIndexesType); + PyModule_AddObject(module, "MatcherIndexes", (PyObject *) &MatcherIndexesType); + TsinfLibraryError = PyErr_NewException("_tsinfer.LibraryError", NULL, NULL); Py_INCREF(TsinfLibraryError); PyModule_AddObject(module, "LibraryError", TsinfLibraryError); diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 53138048..bbc324e9 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -320,7 +320,7 @@ int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t /* New impelementation */ int matcher_indexes_alloc( - matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags); + matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t options); int matcher_indexes_free(matcher_indexes_t *self); int ancestor_matcher2_alloc(ancestor_matcher2_t *self, diff --git a/setup.py b/setup.py index f0579d28..d162e019 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ ] # We're not actually using very much of tskit at the moment, so # just build the stuff we need. -tsk_source_files = ["core.c"] +tsk_source_files = ["core.c", "tables.c", "trees.c"] kas_source_files = ["kastore.c"] sources = ( diff --git a/tests/test_low_level.py b/tests/test_low_level.py index 550a8994..30d3f8de 100644 --- a/tests/test_low_level.py +++ b/tests/test_low_level.py @@ -22,6 +22,7 @@ import sys import pytest +import tskit import _tsinfer @@ -151,3 +152,10 @@ def test_add_too_many_sites(self): assert str(record.value) == msg # TODO need tester methods for the remaining methonds in the class. 
+ + +class TestMatcherIndexes: + def test_single_tree(self): + ts = tskit.Tree.generate_balanced(4) + mi = _tsinfer.MatcherIndexes(ts) + print(mi) From 5f3e7d9dfec8b1fe4312989a66ca8605f079c772 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 2 May 2023 22:16:03 +0100 Subject: [PATCH 18/42] compiling with lwt interface --- _tsinfermodule.c | 4 ++++ lwt_interface | 1 + setup.py | 2 +- 3 files changed, 6 insertions(+), 1 deletion(-) create mode 120000 lwt_interface diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 028d99be..6a3ff5cb 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -20,6 +20,8 @@ static PyObject *TsinfLibraryError; static PyObject *TsinfMatchImpossible; +#include "tskit_lwt_interface.h" + typedef struct { PyObject_HEAD ancestor_builder_t *builder; @@ -1723,6 +1725,8 @@ init_tsinfer(void) /* Initialise numpy */ import_array(); + register_lwt_class(module); + /* AncestorBuilder type */ AncestorBuilderType.tp_new = PyType_GenericNew; if (PyType_Ready(&AncestorBuilderType) < 0) { diff --git a/lwt_interface b/lwt_interface new file mode 120000 index 00000000..30dc544d --- /dev/null +++ b/lwt_interface @@ -0,0 +1 @@ +git-submodules/tskit/python/lwt_interface/ \ No newline at end of file diff --git a/setup.py b/setup.py index d162e019..6c0f1a31 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ tskroot = os.path.join(libdir, "subprojects", "tskit") tskdir = os.path.join(tskroot, "tskit") kasdir = os.path.join(tskroot, "subprojects", "kastore") -includes = [libdir, tskroot, tskdir, kasdir] +includes = ["lwt_interface", libdir, tskroot, tskdir, kasdir] tsi_source_files = [ "ancestor_matcher.c", From 6c7748d0aac666ce4b90152f2186bb74d729f932 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 2 May 2023 22:34:56 +0100 Subject: [PATCH 19/42] partial update to use table collection --- _tsinfermodule.c | 9 +++++---- tests/test_low_level.py | 7 +++++-- tests/test_ls_hmm.py | 35 ++++++++++++++++++++--------------- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 6a3ff5cb..dc8b34a9 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -1614,11 +1614,12 @@ MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) { int ret = -1; int err; - static char *kwlist[] = {"tree_sequence", NULL}; + tsk_table_collection_t *tables; + static char *kwlist[] = {"tables", NULL}; self->matcher_indexes = NULL; - /* FIXME */ - if (!PyArg_ParseTupleAndKeywords(args, kwds, "O", kwlist)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!", kwlist, + &LightweightTableCollectionType, &tables)) { goto out; } @@ -1627,7 +1628,7 @@ MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) PyErr_NoMemory(); goto out; } - err = matcher_indexes_alloc(self->matcher_indexes, NULL, 0); + err = matcher_indexes_alloc(self->matcher_indexes, tables, 0); if (err != 0) { handle_library_error(err); goto out; diff --git a/tests/test_low_level.py b/tests/test_low_level.py index 30d3f8de..cd5ae559 100644 --- a/tests/test_low_level.py +++ b/tests/test_low_level.py @@ -156,6 +156,9 @@ def test_add_too_many_sites(self): class TestMatcherIndexes: def test_single_tree(self): - ts = tskit.Tree.generate_balanced(4) - mi = _tsinfer.MatcherIndexes(ts) + ts = tskit.Tree.generate_balanced(4).tree_sequence + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + mi = _tsinfer.MatcherIndexes(ll_tables) print(mi) diff --git 
a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 5ae9be50..d8f7a9b1 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -24,10 +24,10 @@ class Edge: # not present in the current tree. -def convert_edge_list(ts, order): +def convert_edge_list(edges, order): values = [] for j in order: - tsk_edge = ts.edge(j) + tsk_edge = edges[j] edge = Edge( int(tsk_edge.left), int(tsk_edge.right), tsk_edge.parent, tsk_edge.child ) @@ -40,23 +40,28 @@ class MatcherIndexes: The memory that can be shared between AncestorMatcher instances. """ - def __init__(self, ts): - self.num_nodes = ts.num_nodes - self.num_sites = ts.num_sites + def __init__(self, tables): + self.num_nodes = len(tables.nodes) + self.num_sites = len(tables.sites) # Store the edges in left and right order. - self.left_index = convert_edge_list(ts, ts.tables.indexes.edge_insertion_order) - self.right_index = convert_edge_list(ts, ts.tables.indexes.edge_removal_order) + self.left_index = convert_edge_list( + tables.edges, tables.indexes.edge_insertion_order + ) + self.right_index = convert_edge_list( + tables.edges, tables.indexes.edge_removal_order + ) - # TODO update - self.num_alleles = [var.num_alleles for var in ts.variants()] + # TODO fixme + self.num_alleles = np.zeros(self.num_sites, dtype=int) + 2 self.mutations = collections.defaultdict(list) - for site in ts.sites(): - if len(site.mutations) > 1: + last_site = -1 + for mutation in tables.mutations: + if last_site == mutation.site: raise ValueError("Only single mutations supported for now") - for mutation in site.mutations: - # FIXME - should be allele index - self.mutations[site.id].append((mutation.node, 1)) + # FIXME - should be allele index + self.mutations[mutation.site].append((mutation.node, 1)) + last_site = mutation.site COMPRESSED = -1 @@ -509,7 +514,7 @@ def run_traceback(self, start, end, match): def run_match(ts, h): assert len(h) == ts.num_sites - matcher_indexes = MatcherIndexes(ts) + matcher_indexes = MatcherIndexes(ts.tables) matcher = AncestorMatcher(matcher_indexes) match = np.zeros(ts.num_sites, dtype=int) left, right, parent = matcher.find_path(h, 0, ts.num_sites, match) From 211aae0be91fc50232066773d08753b36156d98b Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 2 May 2023 22:54:50 +0100 Subject: [PATCH 20/42] Convert C code to use tables --- lib/ancestor_matcher.c | 69 ++++++++++++++++++++++-------------------- lib/tests/tests.c | 2 +- lib/tsinfer.h | 2 +- 3 files changed, 39 insertions(+), 34 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 0bf6f057..acc1979d 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1016,18 +1016,19 @@ ancestor_matcher_get_total_memory(ancestor_matcher_t *self) /* New implementation */ static int -matcher_indexes_copy_edge_indexes(matcher_indexes_t *self, const tsk_treeseq_t *ts) +matcher_indexes_copy_edge_indexes( + matcher_indexes_t *self, const tsk_table_collection_t *tables) { int ret = 0; tsk_size_t j; tsk_id_t k; edge_t e; - const tsk_id_t *restrict I = ts->tables->indexes.edge_insertion_order; - const tsk_id_t *restrict O = ts->tables->indexes.edge_removal_order; - const double *restrict edges_right = ts->tables->edges.right; - const double *restrict edges_left = ts->tables->edges.left; - const tsk_id_t *restrict edges_child = ts->tables->edges.child; - const tsk_id_t *restrict edges_parent = ts->tables->edges.parent; + const tsk_id_t *restrict I = tables->indexes.edge_insertion_order; + const tsk_id_t *restrict O = tables->indexes.edge_removal_order; 
+ const double *restrict edges_right = tables->edges.right; + const double *restrict edges_left = tables->edges.left; + const tsk_id_t *restrict edges_child = tables->edges.child; + const tsk_id_t *restrict edges_parent = tables->edges.parent; for (j = 0; j < self->num_edges; j++) { k = I[j]; @@ -1078,47 +1079,51 @@ matcher_indexes_add_mutation( } static int -matcher_indexes_copy_mutation_data(matcher_indexes_t *self, const tsk_treeseq_t *ts) +matcher_indexes_copy_mutation_data( + matcher_indexes_t *self, const tsk_table_collection_t *tables) { int ret = 0; - tsk_site_t site; - tsk_size_t j, k; - - for (j = 0; j < self->num_sites; j++) { - ret = tsk_treeseq_get_site(ts, (tsk_id_t) j, &site); - if (ret != 0) { - goto out; - } - if (site.mutations_length > 1) { + tsk_size_t j; + tsk_id_t site, last_site; + const tsk_id_t *restrict mutations_site = tables->mutations.site; + const tsk_id_t *restrict mutations_node = tables->mutations.node; + const tsk_size_t total_mutations = tables->mutations.num_rows; + + last_site = -1; + for (j = 0; j < total_mutations; j++) { + site = mutations_site[j]; + if (site == last_site) { ret = TSI_ERR_GENERIC; goto out; } - self->sites.mutations[j] = NULL; - /* FIXME need to through this properly when we've got things working. */ - self->sites.num_alleles[j] = site.mutations_length + 1; - for (k = 0; k < site.mutations_length; k++) { - ret = matcher_indexes_add_mutation(self, site.id, site.mutations[k].node, 1); - if (ret != 0) { - goto out; - } + self->sites.mutations[site] = NULL; + /* FIXME */ + self->sites.num_alleles[site] = 2; + ret = matcher_indexes_add_mutation(self, site, mutations_node[j], 1); + if (ret != 0) { + goto out; } + last_site = site; } + out: return ret; } int matcher_indexes_alloc( - matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t flags) + matcher_indexes_t *self, const tsk_table_collection_t *tables, tsk_flags_t flags) { int ret = 0; self->flags = flags; - self->num_edges = tsk_treeseq_get_num_edges(ts); - self->num_nodes = tsk_treeseq_get_num_nodes(ts); - self->num_sites = tsk_treeseq_get_num_sites(ts); - self->num_mutations = tsk_treeseq_get_num_mutations(ts); + self->num_edges = tables->edges.num_rows; + self->num_nodes = tables->nodes.num_rows; + self->num_sites = tables->sites.num_rows; + /* FIXME this is used below by the code that adds in mutations in the linked + * list, so *don't* set from the tables */ + self->num_mutations = 0; self->left_index_edges = malloc(self->num_edges * sizeof(*self->left_index_edges)); self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); @@ -1134,11 +1139,11 @@ matcher_indexes_alloc( goto out; } - ret = matcher_indexes_copy_edge_indexes(self, ts); + ret = matcher_indexes_copy_edge_indexes(self, tables); if (ret != 0) { goto out; } - ret = matcher_indexes_copy_mutation_data(self, ts); + ret = matcher_indexes_copy_mutation_data(self, tables); if (ret != 0) { goto out; } diff --git a/lib/tests/tests.c b/lib/tests/tests.c index 314e8cc6..b576f7a8 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -1014,7 +1014,7 @@ run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, recombination_rate[j] = rho; } - ret = matcher_indexes_alloc(&mi, ts, 0); + ret = matcher_indexes_alloc(&mi, ts->tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); ret = ancestor_matcher2_alloc(&am, &mi, recombination_rate, mutation_rate, 14, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); diff --git a/lib/tsinfer.h b/lib/tsinfer.h index bbc324e9..fdb936da 100644 --- a/lib/tsinfer.h +++ 
b/lib/tsinfer.h @@ -320,7 +320,7 @@ int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t /* New impelementation */ int matcher_indexes_alloc( - matcher_indexes_t *self, const tsk_treeseq_t *ts, tsk_flags_t options); + matcher_indexes_t *self, const tsk_table_collection_t *tables, tsk_flags_t options); int matcher_indexes_free(matcher_indexes_t *self); int ancestor_matcher2_alloc(ancestor_matcher2_t *self, From 5d77002698718f8ba5bb10e36f25aacbab3e9d87 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 2 May 2023 23:03:36 +0100 Subject: [PATCH 21/42] Roughly working Python-C infrastructure --- _tsinfermodule.c | 12 ++++++++---- tests/test_ls_hmm.py | 8 ++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index dc8b34a9..ac603de1 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -1602,7 +1602,7 @@ static void MatcherIndexes_dealloc(MatcherIndexes* self) { if (self->matcher_indexes != NULL) { - matcher_indexes_free(self->matcher_indexes); + /* matcher_indexes_free(self->matcher_indexes); */ PyMem_Free(self->matcher_indexes); self->matcher_indexes = NULL; } @@ -1614,7 +1614,7 @@ MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) { int ret = -1; int err; - tsk_table_collection_t *tables; + LightweightTableCollection *tables; static char *kwlist[] = {"tables", NULL}; self->matcher_indexes = NULL; @@ -1622,13 +1622,16 @@ MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) &LightweightTableCollectionType, &tables)) { goto out; } + if (LightweightTableCollection_check_state(tables) != 0) { + goto out; + } - self->matcher_indexes = PyMem_Malloc(sizeof(*self->matcher_indexes)); + self->matcher_indexes = PyMem_Calloc(1, sizeof(*self->matcher_indexes)); if (self->matcher_indexes == NULL) { PyErr_NoMemory(); goto out; } - err = matcher_indexes_alloc(self->matcher_indexes, tables, 0); + err = matcher_indexes_alloc(self->matcher_indexes, tables->tables, 0); if (err != 0) { handle_library_error(err); goto out; @@ -1637,6 +1640,7 @@ MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) out: return ret; } + /* TODO update to c99 form */ static PyTypeObject MatcherIndexesType = { PyVarObject_HEAD_INIT(NULL, 0) diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index d8f7a9b1..289caa92 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -675,3 +675,11 @@ def test_switch_each_sample(self): assert list(right) == [4, 3, 2, 1] assert list(parent) == [4, 3, 2, 1] np.testing.assert_array_equal(h, match) + + def test_matcher_indexes(self): + ts = self.ts() + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + mi = _tsinfer.MatcherIndexes(ll_tables) + print(mi) From 5b4b02e883f622d94ee45d649616383266645ef8 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 3 May 2023 22:19:36 +0100 Subject: [PATCH 22/42] Add basic debug support to the MatcherIndexes --- _tsinfermodule.c | 120 ++++++++++++++++++++++++++++------------ lib/ancestor_matcher.c | 27 +++++++++ lib/tests/tests.c | 1 + lib/tsinfer.h | 4 +- tests/test_low_level.py | 18 ++++++ 5 files changed, 131 insertions(+), 39 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index ac603de1..487aaa63 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -62,6 +62,33 @@ handle_library_error(int err) } } +static FILE * +make_file(PyObject *fileobj, const char *mode) +{ + FILE *ret = NULL; + 
FILE *file = NULL; + int fileobj_fd, new_fd; + + fileobj_fd = PyObject_AsFileDescriptor(fileobj); + if (fileobj_fd == -1) { + goto out; + } + new_fd = dup(fileobj_fd); + if (new_fd == -1) { + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + file = fdopen(new_fd, mode); + if (file == NULL) { + (void) close(new_fd); + PyErr_SetFromErrno(PyExc_OSError); + goto out; + } + ret = file; +out: + return ret; +} + static int uint64_PyArray_converter(PyObject *in, PyObject **out) { @@ -1641,46 +1668,67 @@ MatcherIndexes_init(MatcherIndexes *self, PyObject *args, PyObject *kwds) return ret; } -/* TODO update to c99 form */ +static int +MatcherIndexes_check_state(MatcherIndexes *self) +{ + int ret = 0; + if (self->matcher_indexes == NULL) { + PyErr_SetString(PyExc_SystemError, "MatcherIndexes not initialised"); + ret = -1; + } + return ret; +} + + +static PyObject * +MatcherIndexes_print_state(MatcherIndexes *self, PyObject *args) +{ + PyObject *ret = NULL; + PyObject *fileobj; + FILE *file = NULL; + + if (MatcherIndexes_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "O", &fileobj)) { + goto out; + } + file = make_file(fileobj, "w"); + if (file == NULL) { + goto out; + } + matcher_indexes_print_state(self->matcher_indexes, file); + ret = Py_BuildValue(""); +out: + if (file != NULL) { + (void) fclose(file); + } + return ret; +} + + +static PyMethodDef MatcherIndexes_methods[] = { + {"print_state", (PyCFunction) MatcherIndexes_print_state, + METH_VARARGS, "Low-level debug method"}, + {NULL} /* Sentinel */ +}; + + static PyTypeObject MatcherIndexesType = { + // clang-format off PyVarObject_HEAD_INIT(NULL, 0) - "_tsinfer.MatcherIndexes", /* tp_name */ - sizeof(MatcherIndexes), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)MatcherIndexes_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - "MatcherIndexes objects", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - 0, /* tp_methods */ - 0, /* tp_members */ - 0, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)MatcherIndexes_init, /* tp_init */ + .tp_name = "_tsinfer.MatcherIndexes", + .tp_basicsize = sizeof(MatcherIndexes), + .tp_dealloc = (destructor) MatcherIndexes_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_doc = "MatcherIndexes objects", + .tp_methods = MatcherIndexes_methods, + .tp_init = (initproc) MatcherIndexes_init, + .tp_new = PyType_GenericNew, + // clang-format on }; + /*=================================================================== * Module level code. 
*=================================================================== diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index acc1979d..78ec3738 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1015,6 +1015,33 @@ ancestor_matcher_get_total_memory(ancestor_matcher_t *self) /* New implementation */ +int +matcher_indexes_print_state(const matcher_indexes_t *self, FILE *out) +{ + size_t j; + mutation_list_node_t *u; + + fprintf(out, "Matcher indexes state\n"); + fprintf(out, "flags = %d\n", (int) self->flags); + fprintf(out, "num_sites = %d\n", (int) self->num_sites); + fprintf(out, "num_nodes = %d\n", (int) self->num_nodes); + fprintf(out, "num_edges = %d\n", (int) self->num_edges); + + fprintf(out, "Mutations = \n"); + fprintf(out, "site\t(node, derived_state),...\n"); + for (j = 0; j < self->num_sites; j++) { + if (self->sites.mutations[j] != NULL) { + fprintf(out, "%d\t", (int) j); + for (u = self->sites.mutations[j]; u != NULL; u = u->next) { + fprintf(out, "(%d, %d) ", u->node, u->derived_state); + } + fprintf(out, "\n"); + } + } + + return 0; +} + static int matcher_indexes_copy_edge_indexes( matcher_indexes_t *self, const tsk_table_collection_t *tables) diff --git a/lib/tests/tests.c b/lib/tests/tests.c index b576f7a8..0421123f 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -1016,6 +1016,7 @@ run_match(const tsk_treeseq_t *ts, double rho, double mu, const allele_t *h, ret = matcher_indexes_alloc(&mi, ts->tables, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); + /* matcher_indexes_print_state(&mi, stdout); */ ret = ancestor_matcher2_alloc(&am, &mi, recombination_rate, mutation_rate, 14, 0); CU_ASSERT_EQUAL_FATAL(ret, 0); diff --git a/lib/tsinfer.h b/lib/tsinfer.h index fdb936da..7f7319ba 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -219,9 +219,6 @@ typedef struct { mutation_list_node_t **mutations; tsk_size_t *num_alleles; } sites; - /* TODO add nodes struct */ - double *time; - uint32_t *node_flags; edge_t *left_index_edges; edge_t *right_index_edges; tsk_blkalloc_t allocator; @@ -321,6 +318,7 @@ int tree_sequence_builder_dump_mutations(tree_sequence_builder_t *self, tsk_id_t int matcher_indexes_alloc( matcher_indexes_t *self, const tsk_table_collection_t *tables, tsk_flags_t options); +int matcher_indexes_print_state(const matcher_indexes_t *self, FILE *out); int matcher_indexes_free(matcher_indexes_t *self); int ancestor_matcher2_alloc(ancestor_matcher2_t *self, diff --git a/tests/test_low_level.py b/tests/test_low_level.py index cd5ae559..36fd5050 100644 --- a/tests/test_low_level.py +++ b/tests/test_low_level.py @@ -20,6 +20,7 @@ Integrity tests for the low-level module. 
""" import sys +import tempfile import pytest import tskit @@ -162,3 +163,20 @@ def test_single_tree(self): ll_tables.fromdict(tables.asdict()) mi = _tsinfer.MatcherIndexes(ll_tables) print(mi) + mi.print_state(sys.stdout) + + def test_print_state(self): + ts = tskit.Tree.generate_balanced(4).tree_sequence + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + mi = _tsinfer.MatcherIndexes(ll_tables) + with pytest.raises(TypeError): + mi.print_state() + + with tempfile.TemporaryFile("w+") as f: + mi.print_state(f) + f.seek(0) + output = f.read() + assert len(output) > 0 + assert "indexes" in output From b0546f69184ef4ece8f20b14aea8c0e1bae97f08 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 3 May 2023 22:47:38 +0100 Subject: [PATCH 23/42] Basic Python-C linkage works :hooray: --- _tsinfermodule.c | 346 ++++++++++++++++++++++++++++++++++++++++++- tests/test_ls_hmm.py | 48 ++++-- 2 files changed, 382 insertions(+), 12 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 487aaa63..4be7012f 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -1729,6 +1729,341 @@ static PyTypeObject MatcherIndexesType = { }; +/*=================================================================== + * AncestorMatcher2 + *=================================================================== + */ + +static int +AncestorMatcher2_check_state(AncestorMatcher2 *self) +{ + int ret = 0; + if (self->ancestor_matcher == NULL) { + PyErr_SetString(PyExc_SystemError, "AncestorMatcher2 not initialised"); + ret = -1; + } + return ret; +} + +static void +AncestorMatcher2_dealloc(AncestorMatcher2* self) +{ + if (self->ancestor_matcher != NULL) { + ancestor_matcher2_free(self->ancestor_matcher); + PyMem_Free(self->ancestor_matcher); + self->ancestor_matcher = NULL; + } + Py_XDECREF(self->matcher_indexes); + Py_TYPE(self)->tp_free((PyObject*)self); +} + +static int +AncestorMatcher2_init(AncestorMatcher2 *self, PyObject *args, PyObject *kwds) +{ + int ret = -1; + int err; + int extended_checks = 0; + static char *kwlist[] = {"matcher_indexes", "recombination", + "mismatch", "precision", "extended_checks", NULL}; + MatcherIndexes *matcher_indexes = NULL; + PyObject *recombination = NULL; + PyObject *mismatch = NULL; + PyArrayObject *recombination_array = NULL; + PyArrayObject *mismatch_array = NULL; + npy_intp *shape; + unsigned int precision = 22; + int flags = 0; + + self->ancestor_matcher = NULL; + self->matcher_indexes = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O!OO|Ii", kwlist, + &MatcherIndexesType, &matcher_indexes, + &recombination, &mismatch, &precision, + &extended_checks)) { + goto out; + } + self->matcher_indexes = matcher_indexes; + Py_INCREF(self->matcher_indexes); + if (MatcherIndexes_check_state(self->matcher_indexes) != 0) { + goto out; + } + + recombination_array = (PyArrayObject *) PyArray_FromAny(recombination, + PyArray_DescrFromType(NPY_FLOAT64), 1, 1, + NPY_ARRAY_IN_ARRAY, NULL); + if (recombination_array == NULL) { + goto out; + } + shape = PyArray_DIMS(recombination_array); + if (shape[0] != (npy_intp) matcher_indexes->matcher_indexes->num_sites) { + PyErr_SetString(PyExc_ValueError, + "Size of recombination array must be num_sites"); + goto out; + } + mismatch_array = (PyArrayObject *) PyArray_FromAny(mismatch, + PyArray_DescrFromType(NPY_FLOAT64), 1, 1, + NPY_ARRAY_IN_ARRAY, NULL); + if (mismatch_array == NULL) { + goto out; + } + shape = PyArray_DIMS(mismatch_array); + if (shape[0] 
!= (npy_intp) matcher_indexes->matcher_indexes->num_sites) { + PyErr_SetString(PyExc_ValueError, "Size of mismatch array must be num_sites"); + goto out; + } + + self->ancestor_matcher = PyMem_Malloc(sizeof(ancestor_matcher2_t)); + if (self->ancestor_matcher == NULL) { + PyErr_NoMemory(); + goto out; + } + if (extended_checks) { + flags = TSI_EXTENDED_CHECKS; + } + err = ancestor_matcher2_alloc(self->ancestor_matcher, + self->matcher_indexes->matcher_indexes, + PyArray_DATA(recombination_array), + PyArray_DATA(mismatch_array), + precision, flags); + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = 0; +out: + Py_XDECREF(recombination_array); + Py_XDECREF(mismatch_array); + return ret; +} + +static PyObject * +AncestorMatcher2_find_path(AncestorMatcher2 *self, PyObject *args, PyObject *kwds) +{ + int err; + PyObject *ret = NULL; + static char *kwlist[] = {"haplotype", "start", "end", "match", NULL}; + PyObject *haplotype = NULL; + PyArrayObject *haplotype_array = NULL; + PyObject *match = NULL; + PyArrayObject *match_array = NULL; + npy_intp *shape; + size_t num_edges; + int start, end; + PyArrayObject *left = NULL; + PyArrayObject *right = NULL; + PyArrayObject *parent = NULL; + npy_intp dims[1]; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OiiO!", kwlist, + &haplotype, &start, &end, &PyArray_Type, &match)) { + goto out; + } + haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_INT8, + NPY_ARRAY_IN_ARRAY); + if (haplotype_array == NULL) { + goto out; + } + if (PyArray_NDIM(haplotype_array) != 1) { + PyErr_SetString(PyExc_ValueError, "Dim != 1"); + goto out; + } + shape = PyArray_DIMS(haplotype_array); + if (shape[0] != (npy_intp) self->ancestor_matcher->num_sites) { + PyErr_SetString(PyExc_ValueError, "Incorrect size for input haplotype."); + goto out; + } + + match_array = (PyArrayObject *) PyArray_FROM_OTF(match, NPY_INT8, + NPY_ARRAY_INOUT_ARRAY); + if (match_array == NULL) { + goto out; + } + if (PyArray_NDIM(match_array) != 1) { + PyErr_SetString(PyExc_ValueError, "Dim != 1"); + goto out; + } + shape = PyArray_DIMS(match_array); + if (shape[0] != (npy_intp) self->ancestor_matcher->num_sites) { + PyErr_SetString(PyExc_ValueError, "input match wrong size"); + goto out; + } + dims[0] = self->ancestor_matcher->num_sites; + left = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_UINT32); + right = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_UINT32); + parent = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_INT32); + if (left == NULL || right == NULL || parent == NULL) { + goto out; + } + + Py_BEGIN_ALLOW_THREADS + err = ancestor_matcher2_find_path(self->ancestor_matcher, + (tsk_id_t) start, (tsk_id_t) end, (allele_t *) PyArray_DATA(haplotype_array), + (allele_t *) PyArray_DATA(match_array), + &num_edges, PyArray_DATA(left), PyArray_DATA(right), PyArray_DATA(parent)); + Py_END_ALLOW_THREADS + if (err != 0) { + handle_library_error(err); + goto out; + } + ret = Py_BuildValue("(kOOO)", (unsigned long) num_edges, left, right, parent); + if (ret == NULL) { + goto out; + } + left = NULL; + right = NULL; + parent = NULL; +out: + Py_XDECREF(haplotype_array); + Py_XDECREF(match_array); + Py_XDECREF(left); + Py_XDECREF(right); + Py_XDECREF(parent); + return ret; +} + +static PyObject * +AncestorMatcher2_get_traceback(AncestorMatcher2 *self, PyObject *args) +{ + PyObject *ret = NULL; + unsigned long site; + node_state_list_t *list; + PyObject *dict = NULL; + PyObject *key = NULL; + PyObject 
*value = NULL; + int j; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + if (!PyArg_ParseTuple(args, "k", &site)) { + goto out; + } + if (site >= self->ancestor_matcher->num_sites) { + PyErr_SetString(PyExc_ValueError, "site out of range"); + goto out; + } + dict = PyDict_New(); + if (dict == NULL) { + goto out; + } + list = &self->ancestor_matcher->traceback[site]; + for (j = 0; j < list->size; j++) { + key = Py_BuildValue("k", (unsigned long) list->node[j]); + value = Py_BuildValue("i", (int) list->recombination_required[j]); + if (key == NULL || value == NULL) { + goto out; + } + if (PyDict_SetItem(dict, key, value) != 0) { + goto out; + } + Py_DECREF(key); + key = NULL; + Py_DECREF(value); + value = NULL; + } + ret = dict; + dict = NULL; +out: + Py_XDECREF(key); + Py_XDECREF(value); + Py_XDECREF(dict); + return ret; +} + +static PyObject * +AncestorMatcher2_get_mean_traceback_size(AncestorMatcher2 *self, void *closure) +{ + PyObject *ret = NULL; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("d", ancestor_matcher2_get_mean_traceback_size( + self->ancestor_matcher)); +out: + return ret; +} + +static PyObject * +AncestorMatcher2_get_total_memory(AncestorMatcher2 *self, void *closure) +{ + PyObject *ret = NULL; + + if (AncestorMatcher2_check_state(self) != 0) { + goto out; + } + ret = Py_BuildValue("k", (unsigned long) + ancestor_matcher2_get_total_memory(self->ancestor_matcher)); +out: + return ret; +} + + +static PyMemberDef AncestorMatcher2_members[] = { + {NULL} /* Sentinel */ + +}; + +static PyGetSetDef AncestorMatcher2_getsetters[] = { + {"mean_traceback_size", (getter) AncestorMatcher2_get_mean_traceback_size, + NULL, "The mean size of the traceback per site."}, + {"total_memory", (getter) AncestorMatcher2_get_total_memory, + NULL, "The total amount of memory used by this matcher."}, + {NULL} /* Sentinel */ +}; + +static PyMethodDef AncestorMatcher2_methods[] = { + {"find_path", (PyCFunction) AncestorMatcher2_find_path, + METH_VARARGS|METH_KEYWORDS, + "Returns a best match path for the specified haplotype through the ancestors."}, + {"get_traceback", (PyCFunction) AncestorMatcher2_get_traceback, + METH_VARARGS, "Returns the traceback likelihood dictionary at the specified site."}, + {NULL} /* Sentinel */ +}; + +static PyTypeObject AncestorMatcher2Type = { + PyVarObject_HEAD_INIT(NULL, 0) + "_tsinfer.AncestorMatcher2", /* tp_name */ + sizeof(AncestorMatcher2), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)AncestorMatcher2_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_reserved */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT, /* tp_flags */ + "AncestorMatcher2 objects", /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + AncestorMatcher2_methods, /* tp_methods */ + AncestorMatcher2_members, /* tp_members */ + AncestorMatcher2_getsetters, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)AncestorMatcher2_init, /* tp_init */ +}; + + /*=================================================================== * Module level code. 
*=================================================================== @@ -1787,6 +2122,7 @@ init_tsinfer(void) } Py_INCREF(&AncestorBuilderType); PyModule_AddObject(module, "AncestorBuilder", (PyObject *) &AncestorBuilderType); + /* AncestorMatcher type */ AncestorMatcherType.tp_new = PyType_GenericNew; if (PyType_Ready(&AncestorMatcherType) < 0) { @@ -1794,6 +2130,7 @@ init_tsinfer(void) } Py_INCREF(&AncestorMatcherType); PyModule_AddObject(module, "AncestorMatcher", (PyObject *) &AncestorMatcherType); + /* TreeSequenceBuilder type */ TreeSequenceBuilderType.tp_new = PyType_GenericNew; if (PyType_Ready(&TreeSequenceBuilderType) < 0) { @@ -1802,7 +2139,6 @@ init_tsinfer(void) Py_INCREF(&TreeSequenceBuilderType); PyModule_AddObject(module, "TreeSequenceBuilder", (PyObject *) &TreeSequenceBuilderType); - /* MatcherIndexes type */ MatcherIndexesType.tp_new = PyType_GenericNew; if (PyType_Ready(&MatcherIndexesType) < 0) { @@ -1811,6 +2147,14 @@ init_tsinfer(void) Py_INCREF(&MatcherIndexesType); PyModule_AddObject(module, "MatcherIndexes", (PyObject *) &MatcherIndexesType); + /* AncestorMatcher2 type */ + AncestorMatcher2Type.tp_new = PyType_GenericNew; + if (PyType_Ready(&AncestorMatcher2Type) < 0) { + INITERROR; + } + Py_INCREF(&AncestorMatcher2Type); + PyModule_AddObject(module, "AncestorMatcher2", (PyObject *) &AncestorMatcher2Type); + TsinfLibraryError = PyErr_NewException("_tsinfer.LibraryError", NULL, NULL); Py_INCREF(TsinfLibraryError); PyModule_AddObject(module, "LibraryError", TsinfLibraryError); diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index 289caa92..ead1c1d0 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -72,20 +72,17 @@ class AncestorMatcher: def __init__( self, matcher_indexes, - # recombination=None, - # mismatch=None, - # precision=None, + recombination=None, + mismatch=None, + precision=None, extended_checks=False, ): self.matcher_indexes = matcher_indexes self.num_sites = matcher_indexes.num_sites self.num_nodes = matcher_indexes.num_nodes - self.recombination = np.zeros(self.num_sites) + 1e-9 - self.mismatch = np.zeros(self.num_sites) - # self.mismatch = mismatch - # self.recombination = recombination - # self.precision = precision - self.precision = 14 + self.mismatch = mismatch + self.recombination = recombination + self.precision = 22 self.extended_checks = extended_checks self.parent = None self.left_child = None @@ -513,11 +510,40 @@ def run_traceback(self, start, end, match): def run_match(ts, h): + h = h.astype(np.int8) assert len(h) == ts.num_sites + recombination = np.zeros(ts.num_sites) + 1e-9 + mismatch = np.zeros(ts.num_sites) + precision = 22 matcher_indexes = MatcherIndexes(ts.tables) - matcher = AncestorMatcher(matcher_indexes) - match = np.zeros(ts.num_sites, dtype=int) + matcher = AncestorMatcher( + matcher_indexes, + recombination=recombination, + mismatch=mismatch, + precision=precision, + ) + match = np.zeros(ts.num_sites, dtype=np.int8) left, right, parent = matcher.find_path(h, 0, ts.num_sites, match) + + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + mi = _tsinfer.MatcherIndexes(ll_tables) + match_c = np.zeros(ts.num_sites, dtype=np.int8) + am = _tsinfer.AncestorMatcher2( + mi, recombination=recombination, mismatch=mismatch, precision=precision + ) + path_len, left_c, right_c, parent_c = am.find_path(h, 0, ts.num_sites, match_c) + left_c = left_c[:path_len] + right_c = right_c[:path_len] + parent_c = parent_c[:path_len] + + 
assert path_len == len(left) + np.testing.assert_array_equal(left, left_c) + np.testing.assert_array_equal(right, right_c) + np.testing.assert_array_equal(parent, parent_c) + np.testing.assert_array_equal(match, match_c) + return left, right, parent, match From abd8dd01071eac1ffdde72e56a106dad727f00eb Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 3 May 2023 23:05:05 +0100 Subject: [PATCH 24/42] Basic high-level infrastructure for matching --- tests/test_ls_hmm.py | 18 ++++++------------ tsinfer/__init__.py | 1 + tsinfer/matching.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 12 deletions(-) create mode 100644 tsinfer/matching.py diff --git a/tests/test_ls_hmm.py b/tests/test_ls_hmm.py index ead1c1d0..90037339 100644 --- a/tests/test_ls_hmm.py +++ b/tests/test_ls_hmm.py @@ -10,6 +10,7 @@ import tskit import _tsinfer +import tsinfer @dataclasses.dataclass @@ -525,10 +526,11 @@ def run_match(ts, h): match = np.zeros(ts.num_sites, dtype=np.int8) left, right, parent = matcher.find_path(h, 0, ts.num_sites, match) - tables = ts.dump_tables() - ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) - ll_tables.fromdict(tables.asdict()) - mi = _tsinfer.MatcherIndexes(ll_tables) + # tables = ts.dump_tables() + # ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + # ll_tables.fromdict(tables.asdict()) + # mi = _tsinfer.MatcherIndexes(ll_tables) + mi = tsinfer.MatcherIndexes(ts) match_c = np.zeros(ts.num_sites, dtype=np.int8) am = _tsinfer.AncestorMatcher2( mi, recombination=recombination, mismatch=mismatch, precision=precision @@ -701,11 +703,3 @@ def test_switch_each_sample(self): assert list(right) == [4, 3, 2, 1] assert list(parent) == [4, 3, 2, 1] np.testing.assert_array_equal(h, match) - - def test_matcher_indexes(self): - ts = self.ts() - tables = ts.dump_tables() - ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) - ll_tables.fromdict(tables.asdict()) - mi = _tsinfer.MatcherIndexes(ll_tables) - print(mi) diff --git a/tsinfer/__init__.py b/tsinfer/__init__.py index aa9d27ba..56dd631b 100644 --- a/tsinfer/__init__.py +++ b/tsinfer/__init__.py @@ -39,4 +39,5 @@ from .eval_util import * # NOQA from .exceptions import * # NOQA from .constants import * # NOQA +from .matching import MatcherIndexes # NOQA from .cli import get_cli_parser # NOQA diff --git a/tsinfer/matching.py b/tsinfer/matching.py new file mode 100644 index 00000000..0e82176f --- /dev/null +++ b/tsinfer/matching.py @@ -0,0 +1,31 @@ +# +# Copyright (C) 2018-2023 University of Oxford +# +# This file is part of tsinfer. +# +# tsinfer is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# tsinfer is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with tsinfer. If not, see . 
+# +import _tsinfer + + +class MatcherIndexes(_tsinfer.MatcherIndexes): + def __init__(self, ts): + # TODO make this polymorphic to accept tables as well + tables = ts.dump_tables() + ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) + ll_tables.fromdict(tables.asdict()) + super().__init__(ll_tables) + + +# TODO add the high-level classes fronting the other class From 52d20e69bfd0497af70a27939b2ab0bb05d98c82 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 4 May 2023 21:54:44 +0100 Subject: [PATCH 25/42] Rename file --- tests/{test_ls_hmm.py => test_matching.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{test_ls_hmm.py => test_matching.py} (100%) diff --git a/tests/test_ls_hmm.py b/tests/test_matching.py similarity index 100% rename from tests/test_ls_hmm.py rename to tests/test_matching.py From ea9bcd33560fb6afc4f8ba5a969f4c9cb710235a Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 4 May 2023 22:37:20 +0100 Subject: [PATCH 26/42] Refactor the Matcher infrastructure --- _tsinfermodule.c | 88 +++++++++++------------------------------- tests/test_matching.py | 86 ++++++++++++++++------------------------- tsinfer/__init__.py | 2 +- tsinfer/matching.py | 42 +++++++++++++++++++- 4 files changed, 97 insertions(+), 121 deletions(-) diff --git a/_tsinfermodule.c b/_tsinfermodule.c index 4be7012f..25ae80dd 100644 --- a/_tsinfermodule.c +++ b/_tsinfermodule.c @@ -1713,7 +1713,6 @@ static PyMethodDef MatcherIndexes_methods[] = { {NULL} /* Sentinel */ }; - static PyTypeObject MatcherIndexesType = { // clang-format off PyVarObject_HEAD_INIT(NULL, 0) @@ -1841,24 +1840,23 @@ AncestorMatcher2_find_path(AncestorMatcher2 *self, PyObject *args, PyObject *kwd { int err; PyObject *ret = NULL; - static char *kwlist[] = {"haplotype", "start", "end", "match", NULL}; + static char *kwlist[] = {"haplotype", "start", "end", NULL}; PyObject *haplotype = NULL; PyArrayObject *haplotype_array = NULL; - PyObject *match = NULL; - PyArrayObject *match_array = NULL; npy_intp *shape; size_t num_edges; int start, end; PyArrayObject *left = NULL; PyArrayObject *right = NULL; PyArrayObject *parent = NULL; + PyArrayObject *match = NULL; npy_intp dims[1]; if (AncestorMatcher2_check_state(self) != 0) { goto out; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OiiO!", kwlist, - &haplotype, &start, &end, &PyArray_Type, &match)) { + if (!PyArg_ParseTupleAndKeywords(args, kwds, "Oii", kwlist, + &haplotype, &start, &end)) { goto out; } haplotype_array = (PyArrayObject *) PyArray_FROM_OTF(haplotype, NPY_INT8, @@ -1876,48 +1874,35 @@ AncestorMatcher2_find_path(AncestorMatcher2 *self, PyObject *args, PyObject *kwd goto out; } - match_array = (PyArrayObject *) PyArray_FROM_OTF(match, NPY_INT8, - NPY_ARRAY_INOUT_ARRAY); - if (match_array == NULL) { - goto out; - } - if (PyArray_NDIM(match_array) != 1) { - PyErr_SetString(PyExc_ValueError, "Dim != 1"); - goto out; - } - shape = PyArray_DIMS(match_array); - if (shape[0] != (npy_intp) self->ancestor_matcher->num_sites) { - PyErr_SetString(PyExc_ValueError, "input match wrong size"); - goto out; - } dims[0] = self->ancestor_matcher->num_sites; left = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_UINT32); right = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_UINT32); parent = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_INT32); - if (left == NULL || right == NULL || parent == NULL) { + match = (PyArrayObject *) PyArray_SimpleNew(1, dims, NPY_INT8); + if (left == NULL || right == NULL || parent == NULL || match == NULL) { 
goto out; } Py_BEGIN_ALLOW_THREADS err = ancestor_matcher2_find_path(self->ancestor_matcher, (tsk_id_t) start, (tsk_id_t) end, (allele_t *) PyArray_DATA(haplotype_array), - (allele_t *) PyArray_DATA(match_array), + (allele_t *) PyArray_DATA(match), &num_edges, PyArray_DATA(left), PyArray_DATA(right), PyArray_DATA(parent)); Py_END_ALLOW_THREADS if (err != 0) { handle_library_error(err); goto out; } - ret = Py_BuildValue("(kOOO)", (unsigned long) num_edges, left, right, parent); + ret = Py_BuildValue("(kOOOO)", (unsigned long) num_edges, left, right, parent, match); if (ret == NULL) { goto out; } left = NULL; right = NULL; parent = NULL; + match = NULL; out: - Py_XDECREF(haplotype_array); - Py_XDECREF(match_array); + Py_XDECREF(match); Py_XDECREF(left); Py_XDECREF(right); Py_XDECREF(parent); @@ -2002,11 +1987,6 @@ AncestorMatcher2_get_total_memory(AncestorMatcher2 *self, void *closure) } -static PyMemberDef AncestorMatcher2_members[] = { - {NULL} /* Sentinel */ - -}; - static PyGetSetDef AncestorMatcher2_getsetters[] = { {"mean_traceback_size", (getter) AncestorMatcher2_get_mean_traceback_size, NULL, "The mean size of the traceback per site."}, @@ -2024,46 +2004,22 @@ static PyMethodDef AncestorMatcher2_methods[] = { {NULL} /* Sentinel */ }; + static PyTypeObject AncestorMatcher2Type = { + // clang-format off PyVarObject_HEAD_INIT(NULL, 0) - "_tsinfer.AncestorMatcher2", /* tp_name */ - sizeof(AncestorMatcher2), /* tp_basicsize */ - 0, /* tp_itemsize */ - (destructor)AncestorMatcher2_dealloc, /* tp_dealloc */ - 0, /* tp_print */ - 0, /* tp_getattr */ - 0, /* tp_setattr */ - 0, /* tp_reserved */ - 0, /* tp_repr */ - 0, /* tp_as_number */ - 0, /* tp_as_sequence */ - 0, /* tp_as_mapping */ - 0, /* tp_hash */ - 0, /* tp_call */ - 0, /* tp_str */ - 0, /* tp_getattro */ - 0, /* tp_setattro */ - 0, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ - "AncestorMatcher2 objects", /* tp_doc */ - 0, /* tp_traverse */ - 0, /* tp_clear */ - 0, /* tp_richcompare */ - 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ - 0, /* tp_iternext */ - AncestorMatcher2_methods, /* tp_methods */ - AncestorMatcher2_members, /* tp_members */ - AncestorMatcher2_getsetters, /* tp_getset */ - 0, /* tp_base */ - 0, /* tp_dict */ - 0, /* tp_descr_get */ - 0, /* tp_descr_set */ - 0, /* tp_dictoffset */ - (initproc)AncestorMatcher2_init, /* tp_init */ + .tp_name = "_tsinfer.AncestorMatcher2", + .tp_basicsize = sizeof(AncestorMatcher2), + .tp_dealloc = (destructor) AncestorMatcher2_dealloc, + .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, + .tp_doc = "AncestorMatcher2 objects", + .tp_methods = AncestorMatcher2_methods, + .tp_getset = AncestorMatcher2_getsetters, + .tp_init = (initproc) AncestorMatcher2_init, + .tp_new = PyType_GenericNew, + // clang-format on }; - /*=================================================================== * Module level code. 
*=================================================================== diff --git a/tests/test_matching.py b/tests/test_matching.py index 90037339..412aa9a1 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -286,7 +286,7 @@ def insert_edge(self, edge): def is_nonzero_root(self, u): return u != 0 and self.is_root(u) and self.left_child[u] == -1 - def find_path(self, h, start, end, match): + def find_path(self, h, start, end): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index M = len(Il) @@ -419,9 +419,9 @@ def find_path(self, h, start, end, match): if k < M: right = min(right, Ir[k].right) - return self.run_traceback(start, end, match) + return self.run_traceback(start, end) - def run_traceback(self, start, end, match): + def run_traceback(self, start, end): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index M = len(Il) @@ -434,7 +434,7 @@ def run_traceback(self, start, end, match): j = M - 1 k = M - 1 # Construct the matched haplotype - match[:] = 0 + match = np.zeros(self.num_sites, dtype=np.int8) match[:start] = tskit.MISSING_DATA match[end:] = tskit.MISSING_DATA # Reset the tree. @@ -507,7 +507,8 @@ def run_traceback(self, start, end, match): right[j] = e.right parent[j] = e.parent - return left, right, parent + path = tsinfer.matching.Path(left, right, parent) + return tsinfer.matching.Match(path, match) def run_match(ts, h): @@ -523,30 +524,17 @@ def run_match(ts, h): mismatch=mismatch, precision=precision, ) - match = np.zeros(ts.num_sites, dtype=np.int8) - left, right, parent = matcher.find_path(h, 0, ts.num_sites, match) + match_py = matcher.find_path(h, 0, ts.num_sites) - # tables = ts.dump_tables() - # ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) - # ll_tables.fromdict(tables.asdict()) - # mi = _tsinfer.MatcherIndexes(ll_tables) mi = tsinfer.MatcherIndexes(ts) - match_c = np.zeros(ts.num_sites, dtype=np.int8) - am = _tsinfer.AncestorMatcher2( + am = tsinfer.AncestorMatcher2( mi, recombination=recombination, mismatch=mismatch, precision=precision ) - path_len, left_c, right_c, parent_c = am.find_path(h, 0, ts.num_sites, match_c) - left_c = left_c[:path_len] - right_c = right_c[:path_len] - parent_c = parent_c[:path_len] + match_c = am.find_match(h, 0, ts.num_sites) - assert path_len == len(left) - np.testing.assert_array_equal(left, left_c) - np.testing.assert_array_equal(right, right_c) - np.testing.assert_array_equal(parent, parent_c) - np.testing.assert_array_equal(match, match_c) + match_py.assert_equals(match_c) - return left, right, parent, match + return match_py # TODO the tests on these two classes are the same right now, should @@ -611,22 +599,20 @@ def test_match_sample(self, j): ts = self.ts() h = np.zeros(4) h[j] = 1 - left, right, parent, match = run_match(ts, h) - assert list(left) == [0] - assert list(right) == [4] - assert list(parent) == [ts.samples()[j]] - np.testing.assert_array_equal(h, match) + m = run_match(ts, h) + assert list(m.path.left) == [0] + assert list(m.path.right) == [4] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) def test_switch_each_sample(self): ts = self.ts() - m = 4 - h = np.zeros(m) - h[:] = 1 - left, right, parent, match = run_match(ts, h) - assert list(left) == [3, 2, 1, 0] - assert list(right) == [4, 3, 2, 1] - assert list(parent) == [4, 3, 2, 1] - np.testing.assert_array_equal(h, match) + h = np.ones(4) + m = run_match(ts, h) + assert list(m.path.left) == [3, 2, 1, 0] + assert 
list(m.path.right) == [4, 3, 2, 1] + assert list(m.path.parent) == [4, 3, 2, 1] + np.testing.assert_array_equal(h, m.matched_haplotype) class TestMultiTreeExample: @@ -683,23 +669,19 @@ def ts(): @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): ts = self.ts() - ts.dump("multi_tree_example.trees") - m = 4 - h = np.zeros(m) + h = np.zeros(4) h[j] = 1 - left, right, parent, match = run_match(self.ts(), h) - assert list(left) == [0] - assert list(right) == [m] - assert list(parent) == [ts.samples()[j]] - np.testing.assert_array_equal(h, match) + m = run_match(self.ts(), h) + assert list(m.path.left) == [0] + assert list(m.path.right) == [4] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) def test_switch_each_sample(self): ts = self.ts() - m = 4 - h = np.zeros(m) - h[:] = 1 - left, right, parent, match = run_match(ts, h) - assert list(left) == [3, 2, 1, 0] - assert list(right) == [4, 3, 2, 1] - assert list(parent) == [4, 3, 2, 1] - np.testing.assert_array_equal(h, match) + h = np.ones(4) + m = run_match(ts, h) + assert list(m.path.left) == [3, 2, 1, 0] + assert list(m.path.right) == [4, 3, 2, 1] + assert list(m.path.parent) == [4, 3, 2, 1] + np.testing.assert_array_equal(h, m.matched_haplotype) diff --git a/tsinfer/__init__.py b/tsinfer/__init__.py index 56dd631b..e96f36c1 100644 --- a/tsinfer/__init__.py +++ b/tsinfer/__init__.py @@ -39,5 +39,5 @@ from .eval_util import * # NOQA from .exceptions import * # NOQA from .constants import * # NOQA -from .matching import MatcherIndexes # NOQA +from .matching import MatcherIndexes, AncestorMatcher2 # NOQA from .cli import get_cli_parser # NOQA diff --git a/tsinfer/matching.py b/tsinfer/matching.py index 0e82176f..48c1418e 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -1,5 +1,5 @@ # -# Copyright (C) 2018-2023 University of Oxford +# Copyright (C) 2023 University of Oxford # # This file is part of tsinfer. # @@ -16,6 +16,10 @@ # You should have received a copy of the GNU General Public License # along with tsinfer. If not, see . 
# +import dataclasses + +import numpy as np + import _tsinfer @@ -28,4 +32,38 @@ def __init__(self, ts): super().__init__(ll_tables) -# TODO add the high-level classes fronting the other class +@dataclasses.dataclass +class Path: + left: np.ndarray + right: np.ndarray + parent: np.ndarray + + def __len__(self): + return len(self.left) + + def assert_equals(self, other): + np.testing.assert_array_equal(self.left, other.left) + np.testing.assert_array_equal(self.right, other.right) + np.testing.assert_array_equal(self.parent, other.parent) + + +@dataclasses.dataclass +class Match: + path: Path + matched_haplotype: np.ndarray + + def assert_equals(self, other): + self.path.assert_equals(other.path) + np.testing.assert_array_equal(self.matched_haplotype, other.matched_haplotype) + + +class AncestorMatcher2(_tsinfer.AncestorMatcher2): + def find_match(self, h, left, right): + path_len, left, right, parent, matched_haplotype = self.find_path( + h, left, right + ) + + left = left[:path_len] + right = right[:path_len] + parent = parent[:path_len] + return Match(Path(left, right, parent), matched_haplotype) From 864d22746df1b678e17941cbd4cb4f5dbc59ac51 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 4 May 2023 22:50:08 +0100 Subject: [PATCH 27/42] Improve class infrastructure --- tests/test_matching.py | 27 +++++++++++++++++++-------- tsinfer/matching.py | 7 +++---- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 412aa9a1..1896380b 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -4,6 +4,7 @@ import collections import dataclasses import io +import pickle import numpy as np import pytest @@ -11,6 +12,7 @@ import _tsinfer import tsinfer +from tsinfer import matching @dataclasses.dataclass @@ -507,8 +509,8 @@ def run_traceback(self, start, end): right[j] = e.right parent[j] = e.parent - path = tsinfer.matching.Path(left, right, parent) - return tsinfer.matching.Match(path, match) + path = matching.Path(left[::-1], right[::-1], parent[::-1]) + return matching.Match(path, match) def run_match(ts, h): @@ -609,9 +611,9 @@ def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) m = run_match(ts, h) - assert list(m.path.left) == [3, 2, 1, 0] - assert list(m.path.right) == [4, 3, 2, 1] - assert list(m.path.parent) == [4, 3, 2, 1] + assert list(m.path.left) == [0, 1, 2, 3] + assert list(m.path.right) == [1, 2, 3, 4] + assert list(m.path.parent) == [1, 2, 3, 4] np.testing.assert_array_equal(h, m.matched_haplotype) @@ -681,7 +683,16 @@ def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) m = run_match(ts, h) - assert list(m.path.left) == [3, 2, 1, 0] - assert list(m.path.right) == [4, 3, 2, 1] - assert list(m.path.parent) == [4, 3, 2, 1] + assert list(m.path.left) == [0, 1, 2, 3] + assert list(m.path.right) == [1, 2, 3, 4] + assert list(m.path.parent) == [1, 2, 3, 4] np.testing.assert_array_equal(h, m.matched_haplotype) + + +class TestMatchClassUtils: + def test_pickle(self): + m1 = matching.Match( + matching.Path(np.array([0]), np.array([1]), np.array([0])), np.array([0]) + ) + m2 = pickle.loads(pickle.dumps(m1)) + m1.assert_equals(m2) diff --git a/tsinfer/matching.py b/tsinfer/matching.py index 48c1418e..3b5891b2 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -62,8 +62,7 @@ def find_match(self, h, left, right): path_len, left, right, parent, matched_haplotype = self.find_path( h, left, right ) - - left = left[:path_len] - right = right[:path_len] - parent = parent[:path_len] + 
left = left[:path_len][::-1] + right = right[:path_len][::-1] + parent = parent[:path_len][::-1] return Match(Path(left, right, parent), matched_haplotype) From 6aa12ef8b9ba086e4cbd1bcd578fd120d313ca9c Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Thu, 4 May 2023 23:09:30 +0100 Subject: [PATCH 28/42] Add vestigial root automatically --- tests/test_matching.py | 17 +++++++++++------ tsinfer/matching.py | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 1896380b..62aad250 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -43,7 +43,10 @@ class MatcherIndexes: The memory that can be shared between AncestorMatcher instances. """ - def __init__(self, tables): + def __init__(self, in_tables): + ts = add_vestigial_root(in_tables.tree_sequence()) + tables = ts.dump_tables() + self.num_nodes = len(tables.nodes) self.num_sites = len(tables.sites) @@ -509,6 +512,8 @@ def run_traceback(self, start, end): right[j] = e.right parent[j] = e.parent + # Convert the parent node IDs back to original values + parent -= 1 path = matching.Path(left[::-1], right[::-1], parent[::-1]) return matching.Match(path, match) @@ -527,13 +532,11 @@ def run_match(ts, h): precision=precision, ) match_py = matcher.find_path(h, 0, ts.num_sites) - mi = tsinfer.MatcherIndexes(ts) am = tsinfer.AncestorMatcher2( mi, recombination=recombination, mismatch=mismatch, precision=precision ) match_c = am.find_match(h, 0, ts.num_sites) - match_py.assert_equals(match_c) return match_py @@ -578,6 +581,7 @@ def example_binary(n, L): class TestSingleBalancedTreeExample: + # FIXME remove root # 4.00┊ 0 ┊ # ┊ ┃ ┊ # 3.00┊ 7 ┊ @@ -589,11 +593,12 @@ class TestSingleBalancedTreeExample: @staticmethod def ts(): - tables = example_binary(4, 4).dump_tables() + # tables = example_binary(4, 4).dump_tables() + tables = tskit.Tree.generate_balanced(4, span=4).tree_sequence.dump_tables() # Add a site for each sample with a single mutation above that sample. for j in range(4): tables.sites.add_row(j, "0") - tables.mutations.add_row(site=j, node=1 + j, derived_state="1") + tables.mutations.add_row(site=j, node=j, derived_state="1") return tables.tree_sequence() @pytest.mark.parametrize("j", [0, 1, 2, 3]) @@ -613,7 +618,7 @@ def test_switch_each_sample(self): m = run_match(ts, h) assert list(m.path.left) == [0, 1, 2, 3] assert list(m.path.right) == [1, 2, 3, 4] - assert list(m.path.parent) == [1, 2, 3, 4] + assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) diff --git a/tsinfer/matching.py b/tsinfer/matching.py index 3b5891b2..ec09e71a 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -23,9 +23,42 @@ import _tsinfer +def add_vestigial_root(ts): + """ + Adds the nodes and edges required by tsinfer to the specified tree sequence + and returns it. 
+ """ + if not ts.discrete_genome: + raise ValueError("Only discrete genome coords supported") + + base_tables = ts.dump_tables() + tables = base_tables.copy() + tables.nodes.clear() + t = ts.max_root_time + tables.nodes.add_row(time=t + 1) + num_additonal_nodes = len(tables.nodes) + tables.mutations.node += num_additonal_nodes + tables.edges.child += num_additonal_nodes + tables.edges.parent += num_additonal_nodes + for node in base_tables.nodes: + tables.nodes.append(node) + for tree in ts.trees(): + root = tree.root + num_additonal_nodes + tables.edges.add_row( + tree.interval.left, tree.interval.right, parent=0, child=root + ) + tables.edges.squash() + # FIXME probably don't need to sort here most of the time, or at least we + # can just sort almost the end of the table. + tables.sort() + return tables.tree_sequence() + + class MatcherIndexes(_tsinfer.MatcherIndexes): def __init__(self, ts): # TODO make this polymorphic to accept tables as well + # This is very wasteful, but we can do better if it all basically works. + ts = add_vestigial_root(ts) tables = ts.dump_tables() ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) @@ -65,4 +98,6 @@ def find_match(self, h, left, right): left = left[:path_len][::-1] right = right[:path_len][::-1] parent = parent[:path_len][::-1] + # We added a 0-root everywhere above, so convert node IDs back + parent -= 1 return Match(Path(left, right, parent), matched_haplotype) From 62c534ca4f11578b5e443f18c56c03a7b04a9220 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 5 May 2023 21:59:24 +0100 Subject: [PATCH 29/42] Fix up tests to remove hard-coded virtual root --- lib/tests/tests.c | 21 ++++++++++- tests/test_matching.py | 82 +++++++++++------------------------------- 2 files changed, 40 insertions(+), 63 deletions(-) diff --git a/lib/tests/tests.c b/lib/tests/tests.c index 0421123f..72ee75b5 100644 --- a/lib/tests/tests.c +++ b/lib/tests/tests.c @@ -41,9 +41,28 @@ char *_tmp_file_name; FILE *_devnull; -/* FIXME add drawings and descriptions of the trees here */ tsk_treeseq_t _single_tree_ex_ts; +/* 3.00┊ 0 ┊ */ +/* ┊ ┃ ┊ */ +/* 2.00┊ 7 ┊ */ +/* ┊ ┏━┻━┓ ┊ */ +/* 1.00┊ 5 6 ┊ */ +/* ┊ ┏┻┓ ┏┻┓ ┊ */ +/* 0.00┊ 1 2 3 4 ┊ */ +/* 0 4 */ tsk_treeseq_t _multi_tree_ex_ts; +/* 1.84┊ 0 ┊ 0 ┊ */ +/* ┊ ┃ ┊ ┃ ┊ */ +/* 0.84┊ 8 ┊ 8 ┊ */ +/* ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ */ +/* 0.42┊ ┃ ┃ ┊ 7 ┃ ┊ */ +/* ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ */ +/* 0.05┊ 6 ┃ ┊ ┃ ┃ ┃ ┊ */ +/* ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ */ +/* 0.04┊ ┃ 5 ┃ ┊ ┃ ┃ 5 ┊ */ +/* ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ */ +/* 0.00┊ 1 2 3 4 ┊ 1 4 2 3 ┊ */ +/* 0 2 4 */ static void dump_tree_sequence_builder( diff --git a/tests/test_matching.py b/tests/test_matching.py index 62aad250..02e2debf 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -44,7 +44,7 @@ class MatcherIndexes: """ def __init__(self, in_tables): - ts = add_vestigial_root(in_tables.tree_sequence()) + ts = matching.add_vestigial_root(in_tables.tree_sequence()) tables = ts.dump_tables() self.num_nodes = len(tables.nodes) @@ -546,54 +546,16 @@ def run_match(ts, h): # refactor. -def add_vestigial_root(ts): - """ - Adds the nodes and edges required by tsinfer to the specified tree sequence - and returns it. 
- """ - if not ts.discrete_genome: - raise ValueError("Only discrete genome coords supported") - - base_tables = ts.dump_tables() - tables = base_tables.copy() - tables.nodes.clear() - t = ts.max_root_time - tables.nodes.add_row(time=t + 1) - num_additonal_nodes = len(tables.nodes) - tables.mutations.node += num_additonal_nodes - tables.edges.child += num_additonal_nodes - tables.edges.parent += num_additonal_nodes - for node in base_tables.nodes: - tables.nodes.append(node) - for tree in ts.trees(): - root = tree.root + num_additonal_nodes - tables.edges.add_row( - tree.interval.left, tree.interval.right, parent=0, child=root - ) - tables.edges.squash() - tables.sort() - return tables.tree_sequence() - - -def example_binary(n, L): - tree = tskit.Tree.generate_balanced(n, span=L) - return add_vestigial_root(tree.tree_sequence) - - class TestSingleBalancedTreeExample: - # FIXME remove root - # 4.00┊ 0 ┊ - # ┊ ┃ ┊ - # 3.00┊ 7 ┊ + # 3.00┊ 6 ┊ # ┊ ┏━┻━┓ ┊ - # 2.00┊ 5 6 ┊ + # 2.00┊ 4 5 ┊ # ┊ ┏┻┓ ┏┻┓ ┊ - # 1.00┊ 1 2 3 4 ┊ + # 1.00┊ 0 1 2 3 ┊ # 0 4 @staticmethod def ts(): - # tables = example_binary(4, 4).dump_tables() tables = tskit.Tree.generate_balanced(4, span=4).tree_sequence.dump_tables() # Add a site for each sample with a single mutation above that sample. for j in range(4): @@ -623,23 +585,20 @@ def test_switch_each_sample(self): class TestMultiTreeExample: - # 1.84┊ 0 ┊ 0 ┊ - # ┊ ┃ ┊ ┃ ┊ - # 0.84┊ 8 ┊ 8 ┊ + # 0.84┊ 7 ┊ 7 ┊ # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ - # 0.42┊ ┃ ┃ ┊ 7 ┃ ┊ + # 0.42┊ ┃ ┃ ┊ 6 ┃ ┊ # ┊ ┃ ┃ ┊ ┏┻┓ ┃ ┊ - # 0.05┊ 6 ┃ ┊ ┃ ┃ ┃ ┊ + # 0.05┊ 5 ┃ ┊ ┃ ┃ ┃ ┊ # ┊ ┏━┻┓ ┃ ┊ ┃ ┃ ┃ ┊ - # 0.04┊ ┃ 5 ┃ ┊ ┃ ┃ 5 ┊ + # 0.04┊ ┃ 4 ┃ ┊ ┃ ┃ 4 ┊ # ┊ ┃ ┏┻┓ ┃ ┊ ┃ ┃ ┏┻┓ ┊ - # 0.00┊ 1 2 3 4 ┊ 1 4 2 3 ┊ + # 0.00┊ 0 1 2 3 ┊ 0 3 1 2 ┊ # 0 2 4 @staticmethod def ts(): nodes = """\ is_sample time - 0 1.838075 1 0.000000 1 0.000000 1 0.000000 @@ -651,17 +610,16 @@ def ts(): """ edges = """\ left right parent child - 0.000000 4.000000 5 2 - 0.000000 4.000000 5 3 - 0.000000 2.000000 6 1 - 0.000000 2.000000 6 5 - 2.000000 4.000000 7 1 + 0.000000 4.000000 4 1 + 0.000000 4.000000 4 2 + 0.000000 2.000000 5 0 + 0.000000 2.000000 5 4 + 2.000000 4.000000 6 0 + 2.000000 4.000000 6 3 + 0.000000 2.000000 7 3 2.000000 4.000000 7 4 - 0.000000 2.000000 8 4 - 2.000000 4.000000 8 5 - 0.000000 2.000000 8 6 - 2.000000 4.000000 8 7 - 0.000000 4.000000 0 8 + 0.000000 2.000000 7 5 + 2.000000 4.000000 7 6 """ ts = tskit.load_text( nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False @@ -670,7 +628,7 @@ def ts(): # Add a site for each sample with a single mutation above that sample. 
for j in range(4): tables.sites.add_row(j, "0") - tables.mutations.add_row(site=j, node=1 + j, derived_state="1") + tables.mutations.add_row(site=j, node=j, derived_state="1") return tables.tree_sequence() @pytest.mark.parametrize("j", [0, 1, 2, 3]) @@ -690,7 +648,7 @@ def test_switch_each_sample(self): m = run_match(ts, h) assert list(m.path.left) == [0, 1, 2, 3] assert list(m.path.right) == [1, 2, 3, 4] - assert list(m.path.parent) == [1, 2, 3, 4] + assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) From 5c86a8414d1837851254d41a89f878c65de29984 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 5 May 2023 22:48:02 +0100 Subject: [PATCH 30/42] Work on making matcher work with edges not on site values --- tests/test_matching.py | 87 ++++++++++++++++++++++++++++++++---------- 1 file changed, 67 insertions(+), 20 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 02e2debf..aa89bb37 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -9,6 +9,7 @@ import numpy as np import pytest import tskit +import msprime import _tsinfer import tsinfer @@ -47,6 +48,7 @@ def __init__(self, in_tables): ts = matching.add_vestigial_root(in_tables.tree_sequence()) tables = ts.dump_tables() + self.sequence_length = tables.sequence_length self.num_nodes = len(tables.nodes) self.num_sites = len(tables.sites) @@ -60,6 +62,9 @@ def __init__(self, in_tables): # TODO fixme self.num_alleles = np.zeros(self.num_sites, dtype=int) + 2 + self.sites_position = np.zeros(ts.num_sites + 1, dtype=np.uint32) + self.sites_position[:-1] = tables.sites.position + self.sites_position[-1] = tables.sequence_length self.mutations = collections.defaultdict(list) last_site = -1 for mutation in tables.mutations: @@ -294,6 +299,8 @@ def is_nonzero_root(self, u): def find_path(self, h, start, end): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index + L = self.matcher_indexes.sequence_length + sites_position = self.matcher_indexes.sites_position M = len(Il) n = self.num_nodes m = self.num_sites @@ -314,11 +321,13 @@ def find_path(self, h, start, end): j = 0 k = 0 left = 0 + start_pos = sites_position[start] + end_pos = sites_position[end] pos = 0 - right = m - if j < M and start < Il[j].left: + right = L + if j < M and start_pos < Il[j].left: right = Il[j].left - while j < M and k < M and Il[j].left <= start: + while j < M and k < M and Il[j].left <= start_pos: while Ir[k].right == pos: self.remove_edge(Ir[k]) k += 1 @@ -326,7 +335,7 @@ def find_path(self, h, start, end): self.insert_edge(Il[j]) j += 1 left = pos - right = m + right = L if j < M: right = min(right, Il[j].left) if k < M: @@ -345,13 +354,17 @@ def find_path(self, h, start, end): self.likelihood_nodes.append(last_root) self.likelihood[last_root] = 1 + current_site = 0 + while sites_position[current_site] < left: + current_site += 1 + remove_start = k - while left < end: + while left < end_pos: # print("START OF TREE LOOP", left, right) # print("L:", {u: self.likelihood[u] for u in self.likelihood_nodes}) assert left < right - for site_index in range(remove_start, k): - edge = Ir[site_index] + for e in range(remove_start, k): + edge = Ir[e] for u in [edge.parent, edge.child]: if self.is_nonzero_root(u): self.likelihood[u] = NONZERO_ROOT @@ -373,8 +386,11 @@ def find_path(self, h, start, end): if self.extended_checks: self.check_likelihoods() - for site in range(max(left, start), min(right, end)): - self.update_site(site, h[site]) + + while left <= 
sites_position[current_site] < min(right, end_pos): + # print("update site", left, current_site, sites_position[current_site], right) + self.update_site(current_site, h[current_site]) + current_site += 1 remove_start = k while k < M and Ir[k].right == right: @@ -399,8 +415,8 @@ def find_path(self, h, start, end): self.likelihood[edge.child] = L_child self.likelihood_nodes.append(edge.child) # Clear the L cache - for site_index in range(remove_start, k): - edge = Ir[site_index] + for e in range(remove_start, k): + edge = Ir[e] u = edge.parent while L_cache[u] != -1: L_cache[u] = -1 @@ -429,6 +445,8 @@ def find_path(self, h, start, end): def run_traceback(self, start, end): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index + L = self.matcher_indexes.sequence_length + sites_position = self.matcher_indexes.sites_position M = len(Il) u = self.max_likelihood_node[end - 1] output_edge = Edge(right=end, parent=u) @@ -438,6 +456,8 @@ def run_traceback(self, start, end): # Now go back through the trees. j = M - 1 k = M - 1 + start_pos = sites_position[start] + end_pos = sites_position[end] # Construct the matched haplotype match = np.zeros(self.num_sites, dtype=np.int8) match[:start] = tskit.MISSING_DATA @@ -449,8 +469,13 @@ def run_traceback(self, start, end): self.left_sib[:] = -1 self.right_sib[:] = -1 - pos = self.num_sites - while pos > start: + print("FIXME part way through updating this to use site_index.") + pos = L + site_index = self.num_sites -1 + while sites_position[site_index] >= end_pos: + site_index -= 1 + + while pos > start_pos: # print("Top of loop: pos = ", pos) while k >= 0 and Il[k].left == pos: edge = Il[k] @@ -469,7 +494,11 @@ def run_traceback(self, start, end): pos = left assert left < right - for site_index in range(min(right, end) - 1, max(left, start) - 1, -1): + print(left, right, site_index, sites_position[site_index]) + while sites_position[site_index] >= max(left, start): + + # print("FIXME") + # for site_index in range(min(right, end) - 1, max(left, start) - 1, -1): u = output_edge.parent self.set_allelic_state(site_index) v = u @@ -494,6 +523,7 @@ def run_traceback(self, start, end): # Reset the nodes in the recombination tree. 
for u in self.traceback[site_index].keys(): recombination_required[u] = -1 + site_index -= 1 output_edge.left = start self.mean_traceback_size = sum(len(t) for t in self.traceback) / self.num_sites @@ -532,12 +562,12 @@ def run_match(ts, h): precision=precision, ) match_py = matcher.find_path(h, 0, ts.num_sites) - mi = tsinfer.MatcherIndexes(ts) - am = tsinfer.AncestorMatcher2( - mi, recombination=recombination, mismatch=mismatch, precision=precision - ) - match_c = am.find_match(h, 0, ts.num_sites) - match_py.assert_equals(match_c) + # mi = tsinfer.MatcherIndexes(ts) + # am = tsinfer.AncestorMatcher2( + # mi, recombination=recombination, mismatch=mismatch, precision=precision + # ) + # match_c = am.find_match(h, 0, ts.num_sites) + # match_py.assert_equals(match_c) return match_py @@ -651,6 +681,23 @@ def test_switch_each_sample(self): assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) +class TestSimulationExamples: + def check_exact_sample_matches(self, ts): + H = ts.genotype_matrix().T + for u, h in zip(ts.samples(), H): + print(u, h) + m = run_match(ts, h) + print(m) + break + + def test_single_tree_binary_mutations(self): + ts = msprime.sim_ancestry(5, sequence_length=100, random_seed=2) + ts = msprime.sim_mutations(ts, rate=0.01, model=msprime.BinaryMutationModel(), + random_seed=2) + print(ts.tables.sites) + print(ts.tables.mutations) + self.check_exact_sample_matches(ts) + class TestMatchClassUtils: def test_pickle(self): From 5ca07108869ae0a3b323adeb8741ab156eb985f1 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 6 May 2023 22:14:45 +0100 Subject: [PATCH 31/42] Roughtly working version with edges on genome coords --- tests/test_matching.py | 73 ++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 39 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index aa89bb37..16c8b94a 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -6,13 +6,12 @@ import io import pickle +import msprime import numpy as np import pytest import tskit -import msprime import _tsinfer -import tsinfer from tsinfer import matching @@ -388,7 +387,6 @@ def find_path(self, h, start, end): self.check_likelihoods() while left <= sites_position[current_site] < min(right, end_pos): - # print("update site", left, current_site, sites_position[current_site], right) self.update_site(current_site, h[current_site]) current_site += 1 @@ -469,12 +467,8 @@ def run_traceback(self, start, end): self.left_sib[:] = -1 self.right_sib[:] = -1 - print("FIXME part way through updating this to use site_index.") pos = L - site_index = self.num_sites -1 - while sites_position[site_index] >= end_pos: - site_index -= 1 - + site_index = self.num_sites - 1 while pos > start_pos: # print("Top of loop: pos = ", pos) while k >= 0 and Il[k].left == pos: @@ -494,36 +488,35 @@ def run_traceback(self, start, end): pos = left assert left < right - print(left, right, site_index, sites_position[site_index]) - while sites_position[site_index] >= max(left, start): - - # print("FIXME") - # for site_index in range(min(right, end) - 1, max(left, start) - 1, -1): - u = output_edge.parent - self.set_allelic_state(site_index) - v = u - while self.allelic_state[v] == -1: - v = self.parent[v] - match[site_index] = self.allelic_state[v] - self.unset_allelic_state(site_index) - - for u, recombine in self.traceback[site_index].items(): - # Mark the traceback nodes on the tree. 
- recombination_required[u] = recombine - # Now traverse up the tree from the current node. The first marked node - # we meet tells us whether we need to recombine. - u = output_edge.parent - while u != 0 and recombination_required[u] == -1: - u = self.parent[u] - if recombination_required[u] and site_index > start: - output_edge.left = site_index - u = self.max_likelihood_node[site_index - 1] - output_edge = Edge(right=site_index, parent=u) - output_edges.append(output_edge) - # Reset the nodes in the recombination tree. - for u in self.traceback[site_index].keys(): - recombination_required[u] = -1 + while left <= sites_position[site_index] < right: + if start_pos <= sites_position[site_index] < end_pos: + u = output_edge.parent + self.set_allelic_state(site_index) + v = u + while self.allelic_state[v] == -1: + v = self.parent[v] + match[site_index] = self.allelic_state[v] + self.unset_allelic_state(site_index) + + for u, recombine in self.traceback[site_index].items(): + # Mark the traceback nodes on the tree. + recombination_required[u] = recombine + # Now traverse up the tree from the current node. The first + # marked node we meet tells us whether we need to + # recombine. + u = output_edge.parent + while u != 0 and recombination_required[u] == -1: + u = self.parent[u] + if recombination_required[u] and site_index > start: + output_edge.left = site_index + u = self.max_likelihood_node[site_index - 1] + output_edge = Edge(right=site_index, parent=u) + output_edges.append(output_edge) + # Reset the nodes in the recombination tree. + for u in self.traceback[site_index].keys(): + recombination_required[u] = -1 site_index -= 1 + output_edge.left = start self.mean_traceback_size = sum(len(t) for t in self.traceback) / self.num_sites @@ -681,6 +674,7 @@ def test_switch_each_sample(self): assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) + class TestSimulationExamples: def check_exact_sample_matches(self, ts): H = ts.genotype_matrix().T @@ -692,8 +686,9 @@ def check_exact_sample_matches(self, ts): def test_single_tree_binary_mutations(self): ts = msprime.sim_ancestry(5, sequence_length=100, random_seed=2) - ts = msprime.sim_mutations(ts, rate=0.01, model=msprime.BinaryMutationModel(), - random_seed=2) + ts = msprime.sim_mutations( + ts, rate=0.01, model=msprime.BinaryMutationModel(), random_seed=2 + ) print(ts.tables.sites) print(ts.tables.mutations) self.check_exact_sample_matches(ts) From 9a7492b676e60831d45b2b74886de4223327aeae Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 6 May 2023 22:42:13 +0100 Subject: [PATCH 32/42] Python version looks like it's working --- tests/test_matching.py | 89 ++++++++++++++++++++++++++++++------------ 1 file changed, 64 insertions(+), 25 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 16c8b94a..bf6ec1b0 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -298,7 +298,7 @@ def is_nonzero_root(self, u): def find_path(self, h, start, end): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index - L = self.matcher_indexes.sequence_length + sequence_length = self.matcher_indexes.sequence_length sites_position = self.matcher_indexes.sites_position M = len(Il) n = self.num_nodes @@ -323,7 +323,7 @@ def find_path(self, h, start, end): start_pos = sites_position[start] end_pos = sites_position[end] pos = 0 - right = L + right = sequence_length if j < M and start_pos < Il[j].left: right = Il[j].left while j < M and k < M and Il[j].left <= 
start_pos: @@ -334,7 +334,7 @@ def find_path(self, h, start, end): self.insert_edge(Il[j]) j += 1 left = pos - right = L + right = sequence_length if j < M: right = min(right, Il[j].left) if k < M: @@ -432,7 +432,7 @@ def find_path(self, h, start, end): if u != 0 and self.likelihood[u] == NONZERO_ROOT: self.likelihood[u] = 0 self.likelihood_nodes.append(u) - right = m + right = sequence_length if j < M: right = min(right, Il[j].left) if k < M: @@ -569,6 +569,23 @@ def run_match(ts, h): # refactor. +def add_unique_sample_mutations(ts): + """ + Adds a mutation for each of the samples at equally spaced locations + along the genome. + """ + tables = ts.dump_tables() + L = int(ts.sequence_length) + assert L % ts.num_samples == 0 + gap = L // ts.num_samples + x = 0 + for u in ts.samples(): + site = tables.sites.add_row(position=x, ancestral_state="0") + tables.mutations.add_row(site=site, derived_state="1", node=u) + x += gap + return tables.tree_sequence() + + class TestSingleBalancedTreeExample: # 3.00┊ 6 ┊ # ┊ ┏━┻━┓ ┊ @@ -579,12 +596,9 @@ class TestSingleBalancedTreeExample: @staticmethod def ts(): - tables = tskit.Tree.generate_balanced(4, span=4).tree_sequence.dump_tables() - # Add a site for each sample with a single mutation above that sample. - for j in range(4): - tables.sites.add_row(j, "0") - tables.mutations.add_row(site=j, node=j, derived_state="1") - return tables.tree_sequence() + return add_unique_sample_mutations( + tskit.Tree.generate_balanced(4, span=4).tree_sequence + ) @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): @@ -647,12 +661,7 @@ def ts(): ts = tskit.load_text( nodes=io.StringIO(nodes), edges=io.StringIO(edges), strict=False ) - tables = ts.dump_tables() - # Add a site for each sample with a single mutation above that sample. 
- for j in range(4): - tables.sites.add_row(j, "0") - tables.mutations.add_row(site=j, node=j, derived_state="1") - return tables.tree_sequence() + return add_unique_sample_mutations(ts) @pytest.mark.parametrize("j", [0, 1, 2, 3]) def test_match_sample(self, j): @@ -679,20 +688,50 @@ class TestSimulationExamples: def check_exact_sample_matches(self, ts): H = ts.genotype_matrix().T for u, h in zip(ts.samples(), H): - print(u, h) m = run_match(ts, h) - print(m) - break + np.testing.assert_array_equal(h, m.matched_haplotype) + assert list(m.path.left) == [0] + assert list(m.path.right) == [ts.num_sites] + assert list(m.path.parent) == [u] - def test_single_tree_binary_mutations(self): - ts = msprime.sim_ancestry(5, sequence_length=100, random_seed=2) - ts = msprime.sim_mutations( - ts, rate=0.01, model=msprime.BinaryMutationModel(), random_seed=2 + def check_switch_all_samples(self, ts): + h = np.ones(ts.num_sites, dtype=np.int8) + m = run_match(ts, h) + np.testing.assert_array_equal(h, m.matched_haplotype) + np.testing.assert_array_equal(m.path.left, np.arange(ts.num_sites)) + np.testing.assert_array_equal(m.path.right, 1 + np.arange(ts.num_sites)) + np.testing.assert_array_equal(m.path.parent, ts.samples()) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_single_tree_exact_match(self, n): + ts = msprime.sim_ancestry(n, sequence_length=100, random_seed=2) + ts = add_unique_sample_mutations(ts) + self.check_exact_sample_matches(ts) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_multiple_trees_exact_match(self, n): + ts = msprime.sim_ancestry( + n, sequence_length=20, recombination_rate=0.1, random_seed=2234 ) - print(ts.tables.sites) - print(ts.tables.mutations) + assert ts.num_trees > 1 + ts = add_unique_sample_mutations(ts) self.check_exact_sample_matches(ts) + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_single_tree_switch_all_samples(self, n): + ts = msprime.sim_ancestry(n, sequence_length=100, random_seed=2345) + ts = add_unique_sample_mutations(ts) + self.check_switch_all_samples(ts) + + @pytest.mark.parametrize("n", [1, 2, 5, 10]) + def test_multiple_trees_switch_all_sample(self, n): + ts = msprime.sim_ancestry( + n, sequence_length=20, recombination_rate=0.1, random_seed=12234 + ) + assert ts.num_trees > 1 + ts = add_unique_sample_mutations(ts) + self.check_switch_all_samples(ts) + class TestMatchClassUtils: def test_pickle(self): From 91e50fe113f660e09244eda324fc220c0b7ae761 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 6 May 2023 22:56:01 +0100 Subject: [PATCH 33/42] Add sites_position storage and coordinate_t type --- lib/ancestor_matcher.c | 15 +++++++++++++-- lib/tsinfer.h | 7 +++++-- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 78ec3738..7e6ae8b9 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1038,7 +1038,6 @@ matcher_indexes_print_state(const matcher_indexes_t *self, FILE *out) fprintf(out, "\n"); } } - return 0; } @@ -1112,9 +1111,17 @@ matcher_indexes_copy_mutation_data( int ret = 0; tsk_size_t j; tsk_id_t site, last_site; + const double *restrict sites_position = tables->sites.position; const tsk_id_t *restrict mutations_site = tables->mutations.site; const tsk_id_t *restrict mutations_node = tables->mutations.node; const tsk_size_t total_mutations = tables->mutations.num_rows; + coordinate_t *restrict converted_position = self->sites.position; + + for (j = 0; j < self->num_sites; j++) { + /* TODO check for overflow */ + 
converted_position[j] = (coordinate_t) sites_position[j]; + } + converted_position[j] = (coordinate_t) tables->sequence_length; last_site = -1; for (j = 0; j < total_mutations; j++) { @@ -1156,8 +1163,11 @@ matcher_indexes_alloc( self->right_index_edges = malloc(self->num_edges * sizeof(*self->right_index_edges)); self->sites.mutations = malloc(self->num_sites * sizeof(*self->sites.mutations)); self->sites.num_alleles = malloc(self->num_sites * sizeof(*self->sites.num_alleles)); + self->sites.position + = malloc((self->num_sites + 1) * sizeof(*self->sites.mutations)); if (self->left_index_edges == NULL || self->right_index_edges == NULL - || self->sites.mutations == NULL) { + || self->sites.mutations == NULL || self->sites.position == NULL + || self->sites.num_alleles == NULL) { ret = TSI_ERR_NO_MEMORY; goto out; } @@ -1184,6 +1194,7 @@ matcher_indexes_free(matcher_indexes_t *self) tsk_safe_free(self->left_index_edges); tsk_safe_free(self->right_index_edges); tsk_safe_free(self->sites.mutations); + tsk_safe_free(self->sites.position); tsk_safe_free(self->sites.num_alleles); tsk_blkalloc_free(&self->allocator); return 0; diff --git a/lib/tsinfer.h b/lib/tsinfer.h index 7f7319ba..684f17ac 100644 --- a/lib/tsinfer.h +++ b/lib/tsinfer.h @@ -46,10 +46,12 @@ #define TSI_NODE_IS_PC_ANCESTOR ((tsk_flags_t)(1u << 16)) typedef int8_t allele_t; +/* TODO should probably change to uint32 when we have removed the old code.*/ +typedef tsk_id_t coordinate_t; typedef struct { - tsk_id_t left; - tsk_id_t right; + coordinate_t left; + coordinate_t right; tsk_id_t parent; tsk_id_t child; } edge_t; @@ -216,6 +218,7 @@ typedef struct { size_t num_mutations; size_t num_edges; struct { + coordinate_t *position; mutation_list_node_t **mutations; tsk_size_t *num_alleles; } sites; From b8aed5b239d55feb35b9b00b0c11eb734a62659d Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Mon, 8 May 2023 23:16:22 +0100 Subject: [PATCH 34/42] Matching working in C (looks like) --- lib/ancestor_matcher.c | 111 ++++++++++++++++++++++++----------------- tests/test_matching.py | 13 ++--- 2 files changed, 72 insertions(+), 52 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 7e6ae8b9..183ec793 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1,5 +1,5 @@ /* -** Copyright (C) 2018-2020 University of Oxford +** Copyright (C) 2018-2023 University of Oxford ** ** This file is part of tsinfer. 
** @@ -1118,7 +1118,7 @@ matcher_indexes_copy_mutation_data( coordinate_t *restrict converted_position = self->sites.position; for (j = 0; j < self->num_sites; j++) { - /* TODO check for overflow */ + /* TODO check for under/overflow */ converted_position[j] = (coordinate_t) sites_position[j]; } converted_position[j] = (coordinate_t) tables->sequence_length; @@ -1741,19 +1741,24 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i tsk_id_t *restrict path_parent) { int ret = 0; - tsk_id_t l; + tsk_id_t site; edge_t edge; tsk_id_t u, v, max_likelihood_node; - tsk_id_t left, right, pos; + coordinate_t left, right, pos, start_pos, end_pos; tsk_id_t *restrict parent = self->parent; allele_t *restrict allelic_state = self->allelic_state; int8_t *restrict recombination_required = self->recombination_required; const edge_t *restrict in = self->matcher_indexes->right_index_edges; const edge_t *restrict out = self->matcher_indexes->left_index_edges; + const coordinate_t *restrict sites_position = self->matcher_indexes->sites.position; + const coordinate_t sequence_length = sites_position[self->num_sites]; int_fast32_t in_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; int_fast32_t out_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; size_t path_length = 0; + start_pos = sites_position[start]; + end_pos = sites_position[end]; + /* Prepare for the traceback and get the memory ready for recording * the output edges. */ path_right[path_length] = end; @@ -1768,9 +1773,10 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i memset(parent, 0xff, self->num_nodes * sizeof(*parent)); memset( recombination_required, 0xff, self->num_nodes * sizeof(*recombination_required)); - pos = (tsk_id_t) self->num_sites; + pos = sequence_length; + site = (tsk_id_t) self->num_sites - 1; - while (pos > start) { + while (pos > start_pos) { while (out_index >= 0 && out[out_index].left == pos) { edge = out[out_index]; out_index--; @@ -1793,38 +1799,42 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i /* The tree is ready; perform the traceback at each site in this tree */ assert(left < right); - for (l = TSK_MIN(right, end) - 1; l >= (int) TSK_MAX(left, start); l--) { - ancestor_matcher2_set_allelic_state(self, l, allelic_state); - u = path_parent[path_length]; - v = u; - while (allelic_state[v] == TSK_NULL) { - v = parent[v]; - } - match[l] = allelic_state[v]; - ancestor_matcher2_unset_allelic_state(self, l, allelic_state); + for (; site >= 0 && left <= sites_position[site] && sites_position[site] < right; + site--) { + if (start_pos <= sites_position[site] && sites_position[site] < end_pos) { + + ancestor_matcher2_set_allelic_state(self, site, allelic_state); + u = path_parent[path_length]; + v = u; + while (allelic_state[v] == TSK_NULL) { + v = parent[v]; + } + match[site] = allelic_state[v]; + ancestor_matcher2_unset_allelic_state(self, site, allelic_state); - /* Mark the traceback nodes on the tree */ - ancestor_matcher2_set_recombination_required( - self, l, recombination_required); + /* Mark the traceback nodes on the tree */ + ancestor_matcher2_set_recombination_required( + self, site, recombination_required); - /* Traverse up the tree from the current node. 
The first marked node that we - * meed tells us whether we need to recombine */ - while (u != 0 && recombination_required[u] == -1) { - u = parent[u]; - assert(u != NULL_NODE); - } - if (recombination_required[u] && l > start) { - max_likelihood_node = self->max_likelihood_node[l - 1]; - assert(max_likelihood_node != NULL_NODE); - path_left[path_length] = l; - path_length++; - /* Start the next output edge */ - path_right[path_length] = l; - path_parent[path_length] = max_likelihood_node; + /* Traverse up the tree from the current node. The first marked node that + * we meed tells us whether we need to recombine */ + while (u != 0 && recombination_required[u] == -1) { + u = parent[u]; + assert(u != NULL_NODE); + } + if (recombination_required[u] && site > start) { + max_likelihood_node = self->max_likelihood_node[site - 1]; + assert(max_likelihood_node != NULL_NODE); + path_left[path_length] = site; + path_length++; + /* Start the next output edge */ + path_right[path_length] = site; + path_parent[path_length] = max_likelihood_node; + } + /* Unset the values in the tree for the next site. */ + ancestor_matcher2_unset_recombination_required( + self, site, recombination_required); } - /* Unset the values in the tree for the next site. */ - ancestor_matcher2_unset_recombination_required( - self, l, recombination_required); } } @@ -1855,26 +1865,29 @@ ancestor_matcher2_run_forwards_match( tsk_id_t *restrict right_child = self->right_child; tsk_id_t *restrict left_sib = self->left_sib; tsk_id_t *restrict right_sib = self->right_sib; - tsk_id_t pos, left, right; + coordinate_t pos, left, right; const edge_t *restrict in = self->matcher_indexes->left_index_edges; const edge_t *restrict out = self->matcher_indexes->right_index_edges; const int_fast32_t M = (tsk_id_t) self->matcher_indexes->num_edges; + const coordinate_t *restrict sites_position = self->matcher_indexes->sites.position; int_fast32_t in_index, out_index, l, remove_start; + const coordinate_t start_pos = sites_position[start]; + const coordinate_t end_pos = sites_position[end]; + const coordinate_t sequence_length = sites_position[self->num_sites]; /* Load the tree for start */ left = 0; pos = 0; in_index = 0; out_index = 0; - right = (tsk_id_t) self->num_sites; - if (in_index < M && start < in[in_index].left) { + right = sequence_length; + if (in_index < M && start_pos < in[in_index].left) { right = in[in_index].left; } - /* TODO there's probably quite a big gain to made here by seeking - * directly to the tree that we're interested in rather than just - * building the trees sequentially */ - while (in_index < M && out_index < M && in[in_index].left <= start) { + /* TODO don't add all edges trees but only insert edges that intersect + * with start_pos. 
Maybe a reasonable gain for short ancestral fragments */ + while (in_index < M && out_index < M && in[in_index].left <= start_pos) { while (out_index < M && out[out_index].right == pos) { remove_edge( out[out_index], parent, left_child, right_child, left_sib, right_sib); @@ -1886,7 +1899,7 @@ ancestor_matcher2_run_forwards_match( in_index++; } left = pos; - right = (tsk_id_t) self->num_sites; + right = sequence_length; if (in_index < M) { right = TSK_MIN(right, in[in_index].left); } @@ -1919,8 +1932,11 @@ ancestor_matcher2_run_forwards_match( self->likelihood_nodes[0] = last_root; self->num_likelihood_nodes = 1; + for (site = 0; sites_position[site] < left; site++) + ; + remove_start = out_index; - while (left < end) { + while (left < end_pos) { assert(left < right); /* Remove the likelihoods for any nonzero roots that have just left @@ -1962,12 +1978,15 @@ ancestor_matcher2_run_forwards_match( if (self->flags & TSI_EXTENDED_CHECKS) { ancestor_matcher2_check_state(self); } - for (site = TSK_MAX(left, start); site < TSK_MIN(right, end); site++) { + + while (left <= sites_position[site] + && sites_position[site] < TSK_MIN(right, end_pos)) { ret = ancestor_matcher2_update_site_state( self, site, haplotype[site], parent, L, L_cache); if (ret != 0) { goto out; } + site++; } /* Move on to the next tree */ @@ -2030,7 +2049,7 @@ ancestor_matcher2_run_forwards_match( self->num_likelihood_nodes++; } } - right = (tsk_id_t) self->num_sites; + right = sequence_length; if (in_index < M) { right = TSK_MIN(right, in[in_index].left); } diff --git a/tests/test_matching.py b/tests/test_matching.py index bf6ec1b0..cdf38eb8 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -12,6 +12,7 @@ import tskit import _tsinfer +import tsinfer from tsinfer import matching @@ -555,12 +556,12 @@ def run_match(ts, h): precision=precision, ) match_py = matcher.find_path(h, 0, ts.num_sites) - # mi = tsinfer.MatcherIndexes(ts) - # am = tsinfer.AncestorMatcher2( - # mi, recombination=recombination, mismatch=mismatch, precision=precision - # ) - # match_c = am.find_match(h, 0, ts.num_sites) - # match_py.assert_equals(match_c) + mi = tsinfer.MatcherIndexes(ts) + am = tsinfer.AncestorMatcher2( + mi, recombination=recombination, mismatch=mismatch, precision=precision + ) + match_c = am.find_match(h, 0, ts.num_sites) + match_py.assert_equals(match_c) return match_py From 2f694bb70cc9cd1ef61447f3c8da0397fbcf1ad3 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 10 May 2023 22:30:43 +0100 Subject: [PATCH 35/42] Infer start and end from haplotype --- tests/test_matching.py | 92 ++++++++++++++++++++++++++++++++++-------- tsinfer/matching.py | 23 ++++++----- 2 files changed, 88 insertions(+), 27 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index cdf38eb8..b6593964 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -296,7 +296,7 @@ def insert_edge(self, edge): def is_nonzero_root(self, u): return u != 0 and self.is_root(u) and self.left_child[u] == -1 - def find_path(self, h, start, end): + def find_path(self, h): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index sequence_length = self.matcher_indexes.sequence_length @@ -317,6 +317,15 @@ def find_path(self, h, start, end): self.likelihood_nodes = [] L_cache = np.zeros_like(self.likelihood) - 1 + start = 0 + while start < m and h[start] == tskit.MISSING_DATA: + start += 1 + + end = m - 1 + while end >= 0 and h[end] == tskit.MISSING_DATA: + end -= 1 + end += 1 + # print("MATCH: start=", start, 
"end = ", end, "h = ", h) j = 0 k = 0 @@ -555,17 +564,27 @@ def run_match(ts, h): mismatch=mismatch, precision=precision, ) - match_py = matcher.find_path(h, 0, ts.num_sites) - mi = tsinfer.MatcherIndexes(ts) - am = tsinfer.AncestorMatcher2( - mi, recombination=recombination, mismatch=mismatch, precision=precision - ) - match_c = am.find_match(h, 0, ts.num_sites) - match_py.assert_equals(match_c) + match_py = matcher.find_path(h) + + # mi = tsinfer.MatcherIndexes(ts) + # am = tsinfer.AncestorMatcher2( + # mi, recombination=recombination, mismatch=mismatch, precision=precision + # ) + # match_c = am.find_match(h, 0, ts.num_sites) + # match_py.assert_equals(match_c) return match_py +class TestMatchClassUtils: + def test_pickle(self): + m1 = matching.Match( + matching.Path(np.array([0]), np.array([1]), np.array([0])), np.array([0]) + ) + m2 = pickle.loads(pickle.dumps(m1)) + m1.assert_equals(m2) + + # TODO the tests on these two classes are the same right now, should # refactor. @@ -612,6 +631,19 @@ def test_match_sample(self, j): assert list(m.path.parent) == [ts.samples()[j]] np.testing.assert_array_equal(h, m.matched_haplotype) + @pytest.mark.parametrize("j", [1, 2]) + def test_match_sample_missing_flanks(self, j): + ts = self.ts() + h = np.zeros(4) + h[0] = -1 + h[-1] = -1 + h[j] = 1 + m = run_match(ts, h) + assert list(m.path.left) == [1] + assert list(m.path.right) == [3] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) @@ -621,6 +653,17 @@ def test_switch_each_sample(self): assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) + def test_switch_each_sample_missing_flanks(self): + ts = self.ts() + h = np.ones(4) + h[0] = -1 + h[-1] = -1 + m = run_match(ts, h) + assert list(m.path.left) == [1, 2] + assert list(m.path.right) == [2, 3] + assert list(m.path.parent) == [1, 2] + np.testing.assert_array_equal(h, m.matched_haplotype) + class TestMultiTreeExample: # 0.84┊ 7 ┊ 7 ┊ @@ -675,6 +718,19 @@ def test_match_sample(self, j): assert list(m.path.parent) == [ts.samples()[j]] np.testing.assert_array_equal(h, m.matched_haplotype) + @pytest.mark.parametrize("j", [1, 2]) + def test_match_sample_missing_flanks(self, j): + ts = self.ts() + h = np.zeros(4) + h[0] = -1 + h[-1] = -1 + h[j] = 1 + m = run_match(ts, h) + assert list(m.path.left) == [1] + assert list(m.path.right) == [3] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) @@ -684,6 +740,17 @@ def test_switch_each_sample(self): assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) + def test_switch_each_sample_missing_flanks(self): + ts = self.ts() + h = np.ones(4) + h[0] = -1 + h[-1] = -1 + m = run_match(ts, h) + assert list(m.path.left) == [1, 2] + assert list(m.path.right) == [2, 3] + assert list(m.path.parent) == [1, 2] + np.testing.assert_array_equal(h, m.matched_haplotype) + class TestSimulationExamples: def check_exact_sample_matches(self, ts): @@ -732,12 +799,3 @@ def test_multiple_trees_switch_all_sample(self, n): assert ts.num_trees > 1 ts = add_unique_sample_mutations(ts) self.check_switch_all_samples(ts) - - -class TestMatchClassUtils: - def test_pickle(self): - m1 = matching.Match( - matching.Path(np.array([0]), np.array([1]), np.array([0])), np.array([0]) - ) - m2 = pickle.loads(pickle.dumps(m1)) - 
m1.assert_equals(m2) diff --git a/tsinfer/matching.py b/tsinfer/matching.py index ec09e71a..9b51a4c4 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -34,7 +34,9 @@ def add_vestigial_root(ts): base_tables = ts.dump_tables() tables = base_tables.copy() tables.nodes.clear() - t = ts.max_root_time + t = 0 + if ts.num_nodes > 0: + t = max(ts.nodes_time) tables.nodes.add_row(time=t + 1) num_additonal_nodes = len(tables.nodes) tables.mutations.node += num_additonal_nodes @@ -42,15 +44,16 @@ def add_vestigial_root(ts): tables.edges.parent += num_additonal_nodes for node in base_tables.nodes: tables.nodes.append(node) - for tree in ts.trees(): - root = tree.root + num_additonal_nodes - tables.edges.add_row( - tree.interval.left, tree.interval.right, parent=0, child=root - ) - tables.edges.squash() - # FIXME probably don't need to sort here most of the time, or at least we - # can just sort almost the end of the table. - tables.sort() + if ts.num_nodes > 0: + for tree in ts.trees(): + root = tree.root + num_additonal_nodes + tables.edges.add_row( + tree.interval.left, tree.interval.right, parent=0, child=root + ) + tables.edges.squash() + # FIXME probably don't need to sort here most of the time, or at least we + # can just sort almost the end of the table. + tables.sort() return tables.tree_sequence() From 1530fb2360f22e6bcc31bb37172418ab153465d8 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Wed, 10 May 2023 22:42:24 +0100 Subject: [PATCH 36/42] Rough implementation of flank skipping --- tests/test_matching.py | 25 ++++++------------------- tsinfer/matching.py | 22 ++++++++++++++++++---- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index b6593964..3a47a783 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -566,12 +566,12 @@ def run_match(ts, h): ) match_py = matcher.find_path(h) - # mi = tsinfer.MatcherIndexes(ts) - # am = tsinfer.AncestorMatcher2( - # mi, recombination=recombination, mismatch=mismatch, precision=precision - # ) - # match_c = am.find_match(h, 0, ts.num_sites) - # match_py.assert_equals(match_c) + mi = tsinfer.MatcherIndexes(ts) + am = tsinfer.AncestorMatcher2( + mi, recombination=recombination, mismatch=mismatch, precision=precision + ) + match_c = am.find_match(h) + match_py.assert_equals(match_c) return match_py @@ -718,19 +718,6 @@ def test_match_sample(self, j): assert list(m.path.parent) == [ts.samples()[j]] np.testing.assert_array_equal(h, m.matched_haplotype) - @pytest.mark.parametrize("j", [1, 2]) - def test_match_sample_missing_flanks(self, j): - ts = self.ts() - h = np.zeros(4) - h[0] = -1 - h[-1] = -1 - h[j] = 1 - m = run_match(ts, h) - assert list(m.path.left) == [1] - assert list(m.path.right) == [3] - assert list(m.path.parent) == [ts.samples()[j]] - np.testing.assert_array_equal(h, m.matched_haplotype) - def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) diff --git a/tsinfer/matching.py b/tsinfer/matching.py index 9b51a4c4..0bb94bbe 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -19,6 +19,7 @@ import dataclasses import numpy as np +import tskit import _tsinfer @@ -94,13 +95,26 @@ def assert_equals(self, other): class AncestorMatcher2(_tsinfer.AncestorMatcher2): - def find_match(self, h, left, right): - path_len, left, right, parent, matched_haplotype = self.find_path( - h, left, right - ) + def find_match(self, h): + # TODO compute these in C - taking a shortcut for now. 
+ m = len(h) + start = 0 + while start < m and h[start] == tskit.MISSING_DATA: + start += 1 + if start == m: + raise ValueError("All missing data") + end = m - 1 + while end >= 0 and h[end] == tskit.MISSING_DATA: + end -= 1 + end += 1 + + path_len, left, right, parent, matched_haplotype = self.find_path(h, start, end) left = left[:path_len][::-1] right = right[:path_len][::-1] parent = parent[:path_len][::-1] # We added a 0-root everywhere above, so convert node IDs back parent -= 1 + # FIXME C code isn't setting match to missing as expected + matched_haplotype[:start] = tskit.MISSING_DATA + matched_haplotype[end:] = tskit.MISSING_DATA return Match(Path(left, right, parent), matched_haplotype) From 7a905e75a7a8fca8e54d3d6181bc4d094940db01 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 12 May 2023 22:24:15 +0100 Subject: [PATCH 37/42] Change python code to use coords in path --- tests/test_matching.py | 29 +++++++++++++++-------------- tsinfer/matching.py | 11 +++++++++-- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 3a47a783..3eeb4b7f 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -541,8 +541,8 @@ def run_traceback(self, start, end): # instance we need to pop the last edge off the list. Or, see why we're # generating it in the first place. assert e.left < e.right - left[j] = e.left - right[j] = e.right + left[j] = sites_position[e.left] + right[j] = sites_position[e.right] parent[j] = e.parent # Convert the parent node IDs back to original values @@ -612,12 +612,12 @@ class TestSingleBalancedTreeExample: # 2.00┊ 4 5 ┊ # ┊ ┏┻┓ ┏┻┓ ┊ # 1.00┊ 0 1 2 3 ┊ - # 0 4 + # 0 8 @staticmethod def ts(): return add_unique_sample_mutations( - tskit.Tree.generate_balanced(4, span=4).tree_sequence + tskit.Tree.generate_balanced(4, span=8).tree_sequence ) @pytest.mark.parametrize("j", [0, 1, 2, 3]) @@ -627,7 +627,7 @@ def test_match_sample(self, j): h[j] = 1 m = run_match(ts, h) assert list(m.path.left) == [0] - assert list(m.path.right) == [4] + assert list(m.path.right) == [ts.sequence_length] assert list(m.path.parent) == [ts.samples()[j]] np.testing.assert_array_equal(h, m.matched_haplotype) @@ -639,8 +639,8 @@ def test_match_sample_missing_flanks(self, j): h[-1] = -1 h[j] = 1 m = run_match(ts, h) - assert list(m.path.left) == [1] - assert list(m.path.right) == [3] + assert list(m.path.left) == [2] + assert list(m.path.right) == [6] assert list(m.path.parent) == [ts.samples()[j]] np.testing.assert_array_equal(h, m.matched_haplotype) @@ -648,8 +648,8 @@ def test_switch_each_sample(self): ts = self.ts() h = np.ones(4) m = run_match(ts, h) - assert list(m.path.left) == [0, 1, 2, 3] - assert list(m.path.right) == [1, 2, 3, 4] + assert list(m.path.left) == [0, 2, 4, 6] + assert list(m.path.right) == [2, 4, 6, 8] assert list(m.path.parent) == [0, 1, 2, 3] np.testing.assert_array_equal(h, m.matched_haplotype) @@ -659,8 +659,8 @@ def test_switch_each_sample_missing_flanks(self): h[0] = -1 h[-1] = -1 m = run_match(ts, h) - assert list(m.path.left) == [1, 2] - assert list(m.path.right) == [2, 3] + assert list(m.path.left) == [2, 4] + assert list(m.path.right) == [4, 6] assert list(m.path.parent) == [1, 2] np.testing.assert_array_equal(h, m.matched_haplotype) @@ -746,15 +746,16 @@ def check_exact_sample_matches(self, ts): m = run_match(ts, h) np.testing.assert_array_equal(h, m.matched_haplotype) assert list(m.path.left) == [0] - assert list(m.path.right) == [ts.num_sites] + assert list(m.path.right) == 
[ts.sequence_length] assert list(m.path.parent) == [u] def check_switch_all_samples(self, ts): h = np.ones(ts.num_sites, dtype=np.int8) m = run_match(ts, h) + X = np.append(ts.sites_position, [ts.sequence_length]) np.testing.assert_array_equal(h, m.matched_haplotype) - np.testing.assert_array_equal(m.path.left, np.arange(ts.num_sites)) - np.testing.assert_array_equal(m.path.right, 1 + np.arange(ts.num_sites)) + np.testing.assert_array_equal(m.path.left, X[:-1]) + np.testing.assert_array_equal(m.path.right, X[1:]) np.testing.assert_array_equal(m.path.parent, ts.samples()) @pytest.mark.parametrize("n", [1, 2, 5, 10]) diff --git a/tsinfer/matching.py b/tsinfer/matching.py index 0bb94bbe..ef814db9 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -39,13 +39,13 @@ def add_vestigial_root(ts): if ts.num_nodes > 0: t = max(ts.nodes_time) tables.nodes.add_row(time=t + 1) - num_additonal_nodes = len(tables.nodes) + num_additonal_nodes = 1 tables.mutations.node += num_additonal_nodes tables.edges.child += num_additonal_nodes tables.edges.parent += num_additonal_nodes for node in base_tables.nodes: tables.nodes.append(node) - if ts.num_nodes > 0: + if ts.num_edges > 0: for tree in ts.trees(): root = tree.root + num_additonal_nodes tables.edges.add_row( @@ -75,6 +75,9 @@ class Path: right: np.ndarray parent: np.ndarray + def __iter__(self): + yield from zip(self.left, self.right, self.parent) + def __len__(self): return len(self.left) @@ -98,6 +101,10 @@ class AncestorMatcher2(_tsinfer.AncestorMatcher2): def find_match(self, h): # TODO compute these in C - taking a shortcut for now. m = len(h) + if m == 0: + # FIXME hardcoding 0 for parent here + return Match(Path([0], [m], [0]), []) + start = 0 while start < m and h[start] == tskit.MISSING_DATA: start += 1 From 854f13b8c469c5af35bd81d50f6928916c91a06e Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Fri, 12 May 2023 22:42:15 +0100 Subject: [PATCH 38/42] Implement coordinate paths in C --- lib/ancestor_matcher.c | 10 +++++----- tests/test_matching.py | 41 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/lib/ancestor_matcher.c b/lib/ancestor_matcher.c index 183ec793..338c3401 100644 --- a/lib/ancestor_matcher.c +++ b/lib/ancestor_matcher.c @@ -1756,12 +1756,12 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i int_fast32_t out_index = (int_fast32_t) self->matcher_indexes->num_edges - 1; size_t path_length = 0; - start_pos = sites_position[start]; + start_pos = start == 0 ? 0 : sites_position[start]; end_pos = sites_position[end]; /* Prepare for the traceback and get the memory ready for recording * the output edges. */ - path_right[path_length] = end; + path_right[path_length] = end_pos; path_parent[path_length] = NULL_NODE; max_likelihood_node = self->max_likelihood_node[end - 1]; @@ -1825,10 +1825,10 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i if (recombination_required[u] && site > start) { max_likelihood_node = self->max_likelihood_node[site - 1]; assert(max_likelihood_node != NULL_NODE); - path_left[path_length] = site; + path_left[path_length] = sites_position[site]; path_length++; /* Start the next output edge */ - path_right[path_length] = site; + path_right[path_length] = sites_position[site]; path_parent[path_length] = max_likelihood_node; } /* Unset the values in the tree for the next site. 
*/ @@ -1838,7 +1838,7 @@ ancestor_matcher2_run_traceback(ancestor_matcher2_t *self, tsk_id_t start, tsk_i } } - path_left[path_length] = start; + path_left[path_length] = start_pos; path_length++; assert(path_right[path_length - 1] != start); *path_length_out = path_length; diff --git a/tests/test_matching.py b/tests/test_matching.py index 3eeb4b7f..5e038bac 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -548,6 +548,8 @@ def run_traceback(self, start, end): # Convert the parent node IDs back to original values parent -= 1 path = matching.Path(left[::-1], right[::-1], parent[::-1]) + if start == 0 and path.left[0] == sites_position[0]: + path.left[0] = 0 return matching.Match(path, match) @@ -589,7 +591,7 @@ def test_pickle(self): # refactor. -def add_unique_sample_mutations(ts): +def add_unique_sample_mutations(ts, start=0): """ Adds a mutation for each of the samples at equally spaced locations along the genome. @@ -598,7 +600,7 @@ def add_unique_sample_mutations(ts): L = int(ts.sequence_length) assert L % ts.num_samples == 0 gap = L // ts.num_samples - x = 0 + x = start for u in ts.samples(): site = tables.sites.add_row(position=x, ancestral_state="0") tables.mutations.add_row(site=site, derived_state="1", node=u) @@ -665,6 +667,41 @@ def test_switch_each_sample_missing_flanks(self): np.testing.assert_array_equal(h, m.matched_haplotype) +class TestSingleBalancedTreeExampleNonZeroFirstSite: + # 3.00┊ 6 ┊ + # ┊ ┏━┻━┓ ┊ + # 2.00┊ 4 5 ┊ + # ┊ ┏┻┓ ┏┻┓ ┊ + # 1.00┊ 0 1 2 3 ┊ + # 0 8 + + @staticmethod + def ts(): + return add_unique_sample_mutations( + tskit.Tree.generate_balanced(4, span=8).tree_sequence, start=1 + ) + + @pytest.mark.parametrize("j", [0, 1, 2, 3]) + def test_match_sample(self, j): + ts = self.ts() + h = np.zeros(4) + h[j] = 1 + m = run_match(ts, h) + assert list(m.path.left) == [0] + assert list(m.path.right) == [ts.sequence_length] + assert list(m.path.parent) == [ts.samples()[j]] + np.testing.assert_array_equal(h, m.matched_haplotype) + + def test_switch_each_sample(self): + ts = self.ts() + h = np.ones(4) + m = run_match(ts, h) + assert list(m.path.left) == [0, 3, 5, 7] + assert list(m.path.right) == [3, 5, 7, 8] + assert list(m.path.parent) == [0, 1, 2, 3] + np.testing.assert_array_equal(h, m.matched_haplotype) + + class TestMultiTreeExample: # 0.84┊ 7 ┊ 7 ┊ # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ From 0e59cebe98e0187cfd5e77942527f68e88b61dff Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Sat, 13 May 2023 22:14:44 +0100 Subject: [PATCH 39/42] Implement some cludges to support initial zero site-paht --- tests/test_matching.py | 24 ++++++++++++++++++++---- tsinfer/matching.py | 30 ++++++++++++++++++++++-------- 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 5e038bac..87ec7266 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -296,7 +296,13 @@ def insert_edge(self, edge): def is_nonzero_root(self, u): return u != 0 and self.is_root(u) and self.left_child[u] == -1 + def zero_sites_path(self): + path = matching.Path([0], [self.matcher_indexes.sites_position[-1]], [0]) + return matching.Match(path, []) + def find_path(self, h): + if self.num_sites == 0: + return self.zero_sites_path() Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index sequence_length = self.matcher_indexes.sequence_length @@ -330,7 +336,7 @@ def find_path(self, h): j = 0 k = 0 left = 0 - start_pos = sites_position[start] + start_pos = 0 if start == 0 else sites_position[start] end_pos = 
sites_position[end] pos = 0 right = sequence_length @@ -464,7 +470,7 @@ def run_traceback(self, start, end): # Now go back through the trees. j = M - 1 k = M - 1 - start_pos = sites_position[start] + start_pos = 0 if start == 0 else sites_position[start] end_pos = sites_position[end] # Construct the matched haplotype match = np.zeros(self.num_sites, dtype=np.int8) @@ -554,7 +560,7 @@ def run_traceback(self, start, end): def run_match(ts, h): - h = h.astype(np.int8) + h = np.array(h).astype(np.int8) assert len(h) == ts.num_sites recombination = np.zeros(ts.num_sites) + 1e-9 mismatch = np.zeros(ts.num_sites) @@ -574,7 +580,6 @@ def run_match(ts, h): ) match_c = am.find_match(h) match_py.assert_equals(match_c) - return match_py @@ -702,6 +707,17 @@ def test_switch_each_sample(self): np.testing.assert_array_equal(h, m.matched_haplotype) +class TestZeroSites: + @pytest.mark.parametrize("L", [1, 2, 5]) + def test_one_node_ts(self, L): + tables = tskit.TableCollection(L) + tables.nodes.add_row(time=1) + m = run_match(tables.tree_sequence(), []) + assert list(m.path.left) == [0] + assert list(m.path.right) == [L] + assert list(m.path.parent) == [0] + + class TestMultiTreeExample: # 0.84┊ 7 ┊ 7 ┊ # ┊ ┏━┻━┓ ┊ ┏━┻━┓ ┊ diff --git a/tsinfer/matching.py b/tsinfer/matching.py index ef814db9..affad602 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -31,13 +31,13 @@ def add_vestigial_root(ts): """ if not ts.discrete_genome: raise ValueError("Only discrete genome coords supported") + if ts.num_nodes == 0: + raise ValueError("Emtpy trees not supported") base_tables = ts.dump_tables() tables = base_tables.copy() tables.nodes.clear() - t = 0 - if ts.num_nodes > 0: - t = max(ts.nodes_time) + t = max(ts.nodes_time) tables.nodes.add_row(time=t + 1) num_additonal_nodes = 1 tables.mutations.node += num_additonal_nodes @@ -66,6 +66,9 @@ def __init__(self, ts): tables = ts.dump_tables() ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) + # TODO should really just reflect these from the low-level C values. + self.sequence_length = ts.sequence_length + self.num_sites = ts.num_sites super().__init__(ll_tables) @@ -98,18 +101,29 @@ def assert_equals(self, other): class AncestorMatcher2(_tsinfer.AncestorMatcher2): + def __init__(self, matcher_indexes, **kwargs): + super().__init__(matcher_indexes, **kwargs) + self.sequence_length = matcher_indexes.sequence_length + self.num_sites = matcher_indexes.num_sites + + def zero_sites_path(self): + left = np.array([0], dtype=np.uint32) + right = np.array([self.sequence_length], dtype=np.uint32) + parent = np.array([0], dtype=np.uint32) + return Match(Path(left, right, parent), []) + def find_match(self, h): + if self.num_sites == 0: + return self.zero_sites_path() + # TODO compute these in C - taking a shortcut for now. 
m = len(h) - if m == 0: - # FIXME hardcoding 0 for parent here - return Match(Path([0], [m], [0]), []) start = 0 while start < m and h[start] == tskit.MISSING_DATA: start += 1 - if start == m: - raise ValueError("All missing data") + # if start == m: + # raise ValueError("All missing data") end = m - 1 while end >= 0 and h[end] == tskit.MISSING_DATA: end -= 1 From ef34c12ef3ef0cb2766025919cc6cc1e331d1252 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Mon, 15 May 2023 17:05:15 +0100 Subject: [PATCH 40/42] Minor updates --- tests/test_matching.py | 13 ++++++++----- tsinfer/matching.py | 6 ++++-- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/test_matching.py b/tests/test_matching.py index 87ec7266..e4352c8c 100644 --- a/tests/test_matching.py +++ b/tests/test_matching.py @@ -298,7 +298,7 @@ def is_nonzero_root(self, u): def zero_sites_path(self): path = matching.Path([0], [self.matcher_indexes.sites_position[-1]], [0]) - return matching.Match(path, []) + return matching.Match(path, [], []) def find_path(self, h): if self.num_sites == 0: @@ -454,9 +454,9 @@ def find_path(self, h): if k < M: right = min(right, Ir[k].right) - return self.run_traceback(start, end) + return self.run_traceback(start, end, h) - def run_traceback(self, start, end): + def run_traceback(self, start, end, query_haplotype): Il = self.matcher_indexes.left_index Ir = self.matcher_indexes.right_index L = self.matcher_indexes.sequence_length @@ -556,7 +556,7 @@ def run_traceback(self, start, end): path = matching.Path(left[::-1], right[::-1], parent[::-1]) if start == 0 and path.left[0] == sites_position[0]: path.left[0] = 0 - return matching.Match(path, match) + return matching.Match(path, query_haplotype, match) def run_match(ts, h): @@ -586,7 +586,9 @@ def run_match(ts, h): class TestMatchClassUtils: def test_pickle(self): m1 = matching.Match( - matching.Path(np.array([0]), np.array([1]), np.array([0])), np.array([0]) + matching.Path(np.array([0]), np.array([1]), np.array([0])), + np.array([0]), + np.array([0]), ) m2 = pickle.loads(pickle.dumps(m1)) m1.assert_equals(m2) @@ -637,6 +639,7 @@ def test_match_sample(self, j): assert list(m.path.right) == [ts.sequence_length] assert list(m.path.parent) == [ts.samples()[j]] np.testing.assert_array_equal(h, m.matched_haplotype) + np.testing.assert_array_equal(h, m.query_haplotype) @pytest.mark.parametrize("j", [1, 2]) def test_match_sample_missing_flanks(self, j): diff --git a/tsinfer/matching.py b/tsinfer/matching.py index affad602..aeb3bfea 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -93,11 +93,13 @@ def assert_equals(self, other): @dataclasses.dataclass class Match: path: Path + query_haplotype: np.ndarray matched_haplotype: np.ndarray def assert_equals(self, other): self.path.assert_equals(other.path) np.testing.assert_array_equal(self.matched_haplotype, other.matched_haplotype) + np.testing.assert_array_equal(self.query_haplotype, other.query_haplotype) class AncestorMatcher2(_tsinfer.AncestorMatcher2): @@ -110,7 +112,7 @@ def zero_sites_path(self): left = np.array([0], dtype=np.uint32) right = np.array([self.sequence_length], dtype=np.uint32) parent = np.array([0], dtype=np.uint32) - return Match(Path(left, right, parent), []) + return Match(Path(left, right, parent), [], []) def find_match(self, h): if self.num_sites == 0: @@ -138,4 +140,4 @@ def find_match(self, h): # FIXME C code isn't setting match to missing as expected matched_haplotype[:start] = tskit.MISSING_DATA matched_haplotype[end:] = tskit.MISSING_DATA - return 
Match(Path(left, right, parent), matched_haplotype) + return Match(Path(left, right, parent), h, matched_haplotype) From 3626a2a0661318dd9278c22847f5258c307c49bc Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Mon, 15 May 2023 17:05:48 +0100 Subject: [PATCH 41/42] Sort-of working driver script --- tmp.py | 128 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 tmp.py diff --git a/tmp.py b/tmp.py new file mode 100644 index 00000000..b6321af0 --- /dev/null +++ b/tmp.py @@ -0,0 +1,128 @@ +import itertools + +import msprime +import numpy as np +import tskit + +import tsinfer + + +class Sequence: + def __init__(self, haplotype): + self.full_haplotype = haplotype + + +def run_matches(ts, positions, sequences): + match_indexes = tsinfer.MatcherIndexes(ts) + recombination = np.zeros(ts.num_sites) + 1e-9 + mismatch = np.zeros(ts.num_sites) + matcher = tsinfer.AncestorMatcher2( + match_indexes, recombination=recombination, mismatch=mismatch + ) + sites_index = np.searchsorted(positions, ts.sites_position) + assert np.all(positions[sites_index] == ts.sites_position) + sites_in_ts = np.zeros(len(positions), dtype=bool) + sites_in_ts[sites_index] = True + results = [] + for seq in sequences: + m = matcher.find_match(seq.full_haplotype[sites_in_ts]) + h = seq.full_haplotype.copy() + h[sites_in_ts] = 0 + focal_sites = np.where(h != 0)[0] + results.append((m, focal_sites)) + return results + + +def insert_matches(tables, time, all_positions, matches): + ts_sites_position = tables.sites.position + added_sites = {} + for m, new_sites in matches: + u = tables.nodes.add_row(time=time, flags=1) + for left, right, parent in m.path: + tables.edges.add_row(left, right, parent, u) + for site_index in new_sites: + if site_index not in added_sites: + s = tables.sites.add_row(all_positions[site_index], "0") + added_sites[site_index] = s + tables.mutations.add_row( + site=added_sites[site_index], node=u, derived_state="1" + ) + # print(tables) + # TODO check the matched haplotype for any mutations too. + # print(tables) + tables.sort() + ts = tables.tree_sequence() + return ts + + +def match_ancestors(ancestor_data): + tables = tskit.TableCollection(ancestor_data.sequence_length) + + all_positions = ancestor_data.sites_position[:] + + ancestors = ancestor_data.ancestors() + # Discard the "ultimate-ultimate ancestor" + next(ancestors) + ultimate_ancestor = next(ancestors) + assert np.all(ultimate_ancestor.full_haplotype == 0) + tables.nodes.add_row(time=ultimate_ancestor.time) + ts = tables.tree_sequence() + + # TODO We don't want to use the focal sites, so we need to keep track + # of when each site gets new variation, or at least keep track of all + # the sites that are entirely ancestral, so we only add sites into the + # ts as we see variation at them. 
+ + for time, group in itertools.groupby(ancestors, key=lambda a: a.time): + # print("EPOCH", time) + group = list(group) + matches = run_matches(ts, all_positions, group) + ts = insert_matches(tables, time, all_positions, matches) + # print(ts.draw_text()) + return ts + + +def match_samples(ts, sample_data): + # print("SAMPLES") + all_positions = sample_data.sites_position[:] + sequences = [Sequence(h) for _, h in sample_data.haplotypes()] + matches = run_matches(ts, all_positions, sequences) + ts = insert_matches(ts.dump_tables(), 0, all_positions, matches) + tables = ts.dump_tables() + flags = tables.nodes.flags + flags[: -len(sequences)] = 0 + tables.nodes.flags = flags + return tables.tree_sequence() + + +if __name__ == "__main__": + ts = msprime.sim_ancestry( + 15, recombination_rate=0.001, sequence_length=10000, random_seed=32 + ) + ts_orig = msprime.sim_mutations(ts, rate=0.001, random_seed=41) + # print(ts_orig) + print(ts_orig.num_sites, ts_orig.num_mutations) + assert ts_orig.num_sites == ts_orig.num_mutations + + # with tsinfer.SampleData(sequence_length=7, path="tmp.samples") as sample_data: + # for _ in range(5): + # sample_data.add_individual(time=0, ploidy=1) + # sample_data.add_site(0, [0, 1, 0, 0, 0], ["A", "T"]) + # sample_data.add_site(1, [0, 0, 0, 1, 1], ["G", "C"]) + # sample_data.add_site(2, [0, 1, 1, 0, 0], ["C", "A"]) + # sample_data.add_site(3, [0, 1, 1, 0, 0], ["G", "C"]) + # sample_data.add_site(4, [0, 0, 0, 1, 1], ["A", "C"]) + # sample_data.add_site(5, [0, 1, 0, 0, 0], ["T", "G"]) + # sample_data.add_site(6, [1, 1, 1, 1, 0], ["T", "G"]) + sample_data = tsinfer.SampleData.from_tree_sequence(ts_orig) + # print(sample_data) + + ad = tsinfer.generate_ancestors(sample_data) + + ts = match_ancestors(ad) + # FIXME not quite working - something wrong with how we're handling singletons + ts = match_samples(ts, sample_data) + # print(ts.draw_text()) + # print(ts.genotype_matrix()) + # print(ts_orig.genotype_matrix()) + np.testing.assert_array_equal(ts.genotype_matrix(), ts_orig.genotype_matrix()) From f9070616b58221ab02f01738108e0ce19043ad35 Mon Sep 17 00:00:00 2001 From: Jerome Kelleher Date: Tue, 16 May 2023 22:53:20 +0100 Subject: [PATCH 42/42] Fiddle with some tricky issues --- tmp.py | 80 ++++++++++++++++++++++++++------------------- tsinfer/matching.py | 11 +++++++ 2 files changed, 58 insertions(+), 33 deletions(-) diff --git a/tmp.py b/tmp.py index b6321af0..14e4c992 100644 --- a/tmp.py +++ b/tmp.py @@ -37,7 +37,7 @@ def insert_matches(tables, time, all_positions, matches): ts_sites_position = tables.sites.position added_sites = {} for m, new_sites in matches: - u = tables.nodes.add_row(time=time, flags=1) + u = tables.nodes.add_row(time=time, flags=0) for left, right, parent in m.path: tables.edges.add_row(left, right, parent, u) for site_index in new_sites: @@ -65,7 +65,9 @@ def match_ancestors(ancestor_data): next(ancestors) ultimate_ancestor = next(ancestors) assert np.all(ultimate_ancestor.full_haplotype == 0) + tables.nodes.add_row(time=ultimate_ancestor.time + 1) tables.nodes.add_row(time=ultimate_ancestor.time) + tables.edges.add_row(0, tables.sequence_length, 0, 1) ts = tables.tree_sequence() # TODO We don't want to use the focal sites, so we need to keep track @@ -83,46 +85,58 @@ def match_ancestors(ancestor_data): def match_samples(ts, sample_data): - # print("SAMPLES") all_positions = sample_data.sites_position[:] sequences = [Sequence(h) for _, h in sample_data.haplotypes()] matches = run_matches(ts, all_positions, sequences) ts = 
insert_matches(ts.dump_tables(), 0, all_positions, matches) tables = ts.dump_tables() + # We can have sites that are monomorphic for the ancestral state. + missing_sites = set(all_positions) - set(ts.sites_position) + for pos in missing_sites: + tables.sites.add_row(pos, ancestral_state="0") + tables.sort() flags = tables.nodes.flags - flags[: -len(sequences)] = 0 + flags[-len(sequences) :] = 1 tables.nodes.flags = flags + print(tables) return tables.tree_sequence() if __name__ == "__main__": - ts = msprime.sim_ancestry( - 15, recombination_rate=0.001, sequence_length=10000, random_seed=32 - ) - ts_orig = msprime.sim_mutations(ts, rate=0.001, random_seed=41) - # print(ts_orig) - print(ts_orig.num_sites, ts_orig.num_mutations) - assert ts_orig.num_sites == ts_orig.num_mutations - - # with tsinfer.SampleData(sequence_length=7, path="tmp.samples") as sample_data: - # for _ in range(5): - # sample_data.add_individual(time=0, ploidy=1) - # sample_data.add_site(0, [0, 1, 0, 0, 0], ["A", "T"]) - # sample_data.add_site(1, [0, 0, 0, 1, 1], ["G", "C"]) - # sample_data.add_site(2, [0, 1, 1, 0, 0], ["C", "A"]) - # sample_data.add_site(3, [0, 1, 1, 0, 0], ["G", "C"]) - # sample_data.add_site(4, [0, 0, 0, 1, 1], ["A", "C"]) - # sample_data.add_site(5, [0, 1, 0, 0, 0], ["T", "G"]) - # sample_data.add_site(6, [1, 1, 1, 1, 0], ["T", "G"]) - sample_data = tsinfer.SampleData.from_tree_sequence(ts_orig) - # print(sample_data) - - ad = tsinfer.generate_ancestors(sample_data) - - ts = match_ancestors(ad) - # FIXME not quite working - something wrong with how we're handling singletons - ts = match_samples(ts, sample_data) - # print(ts.draw_text()) - # print(ts.genotype_matrix()) - # print(ts_orig.genotype_matrix()) - np.testing.assert_array_equal(ts.genotype_matrix(), ts_orig.genotype_matrix()) + for seed in range(1, 100): + ts = msprime.sim_ancestry( + 15, + population_size=1e4, # recombination_rate=1e-10, + sequence_length=1_000_000, + random_seed=seed, + ) + print(seed) + ts_orig = msprime.sim_mutations( + ts, rate=1e-8, random_seed=seed, model=msprime.BinaryMutationModel() + ) + print(ts_orig) + print(ts_orig.num_sites, ts_orig.num_mutations) + # assert ts_orig.num_sites == ts_orig.num_mutations + + # with tsinfer.SampleData(sequence_length=7, path="tmp.samples") as sample_data: + # for _ in range(5): + # sample_data.add_individual(time=0, ploidy=1) + # sample_data.add_site(0, [0, 1, 0, 0, 0], ["A", "T"]) + # sample_data.add_site(1, [0, 0, 0, 1, 1], ["G", "C"]) + # sample_data.add_site(2, [0, 1, 1, 0, 0], ["C", "A"]) + # sample_data.add_site(3, [0, 1, 1, 0, 0], ["G", "C"]) + # sample_data.add_site(4, [0, 0, 0, 1, 1], ["A", "C"]) + # sample_data.add_site(5, [0, 1, 0, 0, 0], ["T", "G"]) + # sample_data.add_site(6, [1, 1, 1, 1, 0], ["T", "G"]) + sample_data = tsinfer.SampleData.from_tree_sequence(ts_orig) + # print(sample_data) + + ad = tsinfer.generate_ancestors(sample_data) + + ts = match_ancestors(ad) + # print(sample_data) + ts = match_samples(ts, sample_data) + print(ts.draw_text()) + # print(ts.genotype_matrix()) + # print(ts_orig.genotype_matrix()) + np.testing.assert_array_equal(ts.genotype_matrix(), ts_orig.genotype_matrix()) diff --git a/tsinfer/matching.py b/tsinfer/matching.py index aeb3bfea..0fd6c5aa 100644 --- a/tsinfer/matching.py +++ b/tsinfer/matching.py @@ -47,6 +47,8 @@ def add_vestigial_root(ts): tables.nodes.append(node) if ts.num_edges > 0: for tree in ts.trees(): + # if tree.num_roots > 1: + # print(ts.draw_text()) root = tree.root + num_additonal_nodes tables.edges.add_row( 
tree.interval.left, tree.interval.right, parent=0, child=root @@ -62,8 +64,17 @@ class MatcherIndexes(_tsinfer.MatcherIndexes): def __init__(self, ts): # TODO make this polymorphic to accept tables as well # This is very wasteful, but we can do better if it all basically works. + print("FIXME!") + # This is turning out to be a bit problematic for actual tsinfer'd trees + # because we have to mark things as samples to define the roots, but then + # we get multiple roots incorrectly when we mark everything as a sample. + # It's not clear that doing this is helpful for tsinfer generated trees, + # but then when we turn it off the current generator script results in + # C-level assertion trips. Hmm. ts = add_vestigial_root(ts) + # print(ts.draw_text()) tables = ts.dump_tables() + # print(tables) ll_tables = _tsinfer.LightweightTableCollection(tables.sequence_length) ll_tables.fromdict(tables.asdict()) # TODO should really just reflect these from the low-level C values.
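
For orientation, here is a minimal end-to-end sketch of the new matching entry points exercised by run_match() in tests/test_matching.py above. The simulated input (copied from the removed test_single_tree_binary_mutations case) and the precision value are illustrative assumptions only; the series itself does not fix these values.

import msprime
import numpy as np

import tsinfer

# Simulate a small binary-mutation dataset. These parameters mirror the
# (removed) test_single_tree_binary_mutations case and are illustrative only.
ts = msprime.sim_ancestry(5, sequence_length=100, random_seed=2)
ts = msprime.sim_mutations(
    ts, rate=0.01, model=msprime.BinaryMutationModel(), random_seed=2
)

# Per-site HMM parameters, following run_match() in tests/test_matching.py:
# a tiny recombination probability and no mismatch.
recombination = np.zeros(ts.num_sites) + 1e-9
mismatch = np.zeros(ts.num_sites)

mi = tsinfer.MatcherIndexes(ts)
matcher = tsinfer.AncestorMatcher2(
    mi,
    recombination=recombination,
    mismatch=mismatch,
    precision=6,  # placeholder value; the tests pass a shared precision constant
)

# Match the first sample's haplotype back against the tree sequence it came from.
h = ts.genotype_matrix().T[0].astype(np.int8)
m = matcher.find_match(h)

# After patches 37-38 the path endpoints are genome coordinates rather than
# site indexes; after patch 36 missing data at the flanks of h is skipped and
# reported as missing in the matched haplotype.
print(list(m.path.left), list(m.path.right), list(m.path.parent))
print(m.matched_haplotype)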