Skip to content

Commit 7424d71

Browse files
Cleaned up code and removed the unnecessary attrs dependency.
1 parent c9a0ef7 commit 7424d71

File tree

2 files changed

+31
-1308
lines changed

2 files changed

+31
-1308
lines changed

tsinfer/inference.py

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,9 @@
44
"""
55

66
import collections
7-
import math
87
import queue
98
import threading
109

11-
# TODO remove this dependency. It's not used for anything important.
12-
import attr
13-
1410
import numpy as np
1511
import tqdm
1612
import humanize
@@ -189,8 +185,9 @@ def __init__(
189185

190186
def initialise(self):
191187
# This is slow, so we should figure out a way to report progress on it.
192-
self.logger.info("Initialising ancestor builder for {} samples and {} sites".
193-
format(self.num_samples, self.num_sites))
188+
self.logger.info(
189+
"Initialising ancestor builder for {} samples and {} sites".format(
190+
self.num_samples, self.num_sites))
194191
self.ancestor_builder = self.ancestor_builder_class(self.samples, self.positions)
195192
self.num_ancestors = self.ancestor_builder.num_ancestors
196193
self.tree_sequence_builder = self.tree_sequence_builder_class(
@@ -333,16 +330,12 @@ def match_worker(thread_index):
333330
for j in range(self.num_threads):
334331
match_threads[j].join()
335332

336-
337333
def __process_ancestors_single_threaded(self):
338334
a = np.zeros(self.num_sites, dtype=np.int8)
339335
matcher = self.ancestor_matcher_class(
340336
self.tree_sequence_builder, self.recombination_rate)
341337
results = ResultBuffer()
342338

343-
# TODO remove this stuff for ancestor ID maps. It's just here for
344-
# debugging.
345-
ancestor_id = 1
346339
for epoch in range(self.num_epochs):
347340
time = self.epoch_time[epoch]
348341
ancestor_focal_sites = self.epoch_ancestors[epoch]
@@ -391,8 +384,9 @@ def worker(thread_index):
391384
work_queue.task_done()
392385
if num_matches > 0:
393386
mean_traceback_size /= num_matches
394-
self.logger.info("Thread {} done: mean_tb_size={:.2f}; total_edges={}".format(
395-
thread_index, mean_traceback_size, results[thread_index].num_edges))
387+
self.logger.info(
388+
"Thread {} done: mean_tb_size={:.2f}; total_edges={}".format(
389+
thread_index, mean_traceback_size, results[thread_index].num_edges))
396390
work_queue.task_done()
397391

398392
threads = [
@@ -467,7 +461,7 @@ def get_tree_sequence(self):
467461
mutations.set_columns(
468462
site=site, node=node, derived_state=derived_state,
469463
derived_state_length=np.ones(tsb.num_mutations, dtype=np.uint32),
470-
parent=parent);
464+
parent=parent)
471465
msprime.sort_tables(nodes, edges, sites=sites, mutations=mutations)
472466
return msprime.load_tables(
473467
nodes=nodes, edges=edges, sites=sites, mutations=mutations)
@@ -482,14 +476,14 @@ def ancestors(self):
482476
ancestor_id = 1
483477
for age, ancestor_focal_sites in frequency_classes:
484478
for focal_sites in ancestor_focal_sites:
485-
builder.make_ancestor(focal_sites, A[ancestor_id,:])
479+
builder.make_ancestor(focal_sites, A[ancestor_id, :])
486480
ancestor_id += 1
487481
return A
488482

489483

490-
def infer(samples, positions, sequence_length, recombination_rate,
491-
error_rate=0, method="C",
492-
num_threads=1, progress=False, log_level="WARNING",
484+
def infer(
485+
samples, positions, sequence_length, recombination_rate, error_rate=0,
486+
method="C", num_threads=1, progress=False, log_level="WARNING",
493487
resolve_shared_recombinations=False, resolve_polytomies=False):
494488
# Primary entry point.
495489
manager = InferenceManager(
@@ -517,18 +511,19 @@ def infer(samples, positions, sequence_length, recombination_rate,
517511
###############################################################
518512

519513

520-
@attr.s
521514
class Edge(object):
522-
left = attr.ib(default=None)
523-
right = attr.ib(default=None)
524-
parent = attr.ib(default=None)
525-
child = attr.ib(default=None)
515+
516+
def __init__(self, left=None, right=None, parent=None, child=None):
517+
self.left = left
518+
self.right = right
519+
self.parent = parent
520+
self.child = child
526521

527522

528-
@attr.s
529523
class Site(object):
530-
id = attr.ib(default=None)
531-
frequency = attr.ib(default=None)
524+
def __init__(self, id, frequency):
525+
self.id = id
526+
self.frequency = frequency
532527

533528

534529
class AncestorBuilder(object):
@@ -613,7 +608,6 @@ def make_ancestor(self, focal_sites, a):
613608
return a
614609

615610

616-
617611
def edge_group_equal(edges, group1, group2):
618612
"""
619613
Returns true if the specified subsets of the list of edges are considered
@@ -637,7 +631,6 @@ def edge_group_equal(edges, group1, group2):
637631
return ret
638632

639633

640-
641634
class TreeSequenceBuilder(object):
642635

643636
def __init__(
@@ -675,6 +668,7 @@ def print_state(self):
675668
print("num_nodes = ", self.num_nodes)
676669
nodes = msprime.NodeTable()
677670
flags = np.zeros(self.num_nodes, dtype=np.uint32)
671+
time = np.zeros(self.num_nodes, dtype=np.float64)
678672
self.dump_nodes(flags=flags, time=time)
679673
nodes.set_columns(flags=flags, time=time)
680674
print("nodes = ")
@@ -802,7 +796,8 @@ def _replace_recombinations(self):
802796
if not match_found[j]:
803797
for k in range(j + 1, len(groups)):
804798
# Compare this group to the others.
805-
if not match_found[k] and edge_group_equal(active, groups[j], groups[k]):
799+
if not match_found[k] and edge_group_equal(
800+
active, groups[j], groups[k]):
806801
matches.append(k)
807802
match_found[k] = True
808803
if len(matches) > 0:
@@ -825,8 +820,10 @@ def _replace_recombinations(self):
825820
for group_index in group_index_list:
826821
start, end = groups[group_index]
827822
left_set.add(tuple([active[j].left for j in range(start, end)]))
828-
right_set.add(tuple([active[j].right for j in range(start, end)]))
829-
parent_set.add(tuple([active[j].parent for j in range(start, end)]))
823+
right_set.add(
824+
tuple([active[j].right for j in range(start, end)]))
825+
parent_set.add(
826+
tuple([active[j].parent for j in range(start, end)]))
830827
children = set(active[j].child for j in range(start, end))
831828
assert len(children) == 1
832829
for j in range(start, end - 1):
@@ -895,7 +892,6 @@ def _replace_recombinations(self):
895892
# for e in self.edges:
896893
# print("\t", e)
897894

898-
899895
def insert_polytomy_ancestor(self, edges):
900896
"""
901897
Insert a new ancestor for the specified edges and update the parents
@@ -933,12 +929,12 @@ def _resolve_polytomies(self):
933929
active[j - 1].parent != active[j].parent)
934930
if condition:
935931
size = j - group_start
936-
if size > 1 and size != parent_count[active[j -1].parent]:
932+
if size > 1 and size != parent_count[active[j - 1].parent]:
937933
groups.append((group_start, j))
938934
group_start = j
939935
j = len(active)
940936
size = j - group_start
941-
if size > 1 and size != parent_count[active[j -1].parent]:
937+
if size > 1 and size != parent_count[active[j - 1].parent]:
942938
groups.append((group_start, j))
943939

944940
for start, end in groups:
@@ -988,7 +984,6 @@ def update(self, num_nodes, time, left, right, parent, child, site, node):
988984
# print("AFTER UPDATE")
989985
# self.print_state()
990986

991-
992987
def dump_nodes(self, flags, time):
993988
time[:] = self.time[:self.num_nodes]
994989
flags[:] = self.flags
@@ -1106,7 +1101,8 @@ def find_path(self, h):
11061101
if state == 0:
11071102
assert len(L) > 0
11081103
traceback[site] = dict(L)
1109-
if site > 0 and (site - 1) not in self.tree_sequence_builder.mutations:
1104+
if site > 0 and (site - 1) \
1105+
not in self.tree_sequence_builder.mutations:
11101106
assert traceback[site] == traceback[site - 1]
11111107
continue
11121108
mutation_node = self.tree_sequence_builder.mutations[site]
@@ -1167,8 +1163,6 @@ def find_path(self, h):
11671163
# for l in range(self.num_sites):
11681164
# print("\t", l, traceback[l])
11691165

1170-
1171-
11721166
u = self.get_max_likelihood_node(L)
11731167
output_edge = Edge(right=m, parent=u)
11741168
output_edges = [output_edge]
@@ -1220,5 +1214,3 @@ def find_path(self, h):
12201214
right[j] = e.right
12211215
parent[j] = e.parent
12221216
return (left, right, parent), np.array(mismatches, dtype=np.uint32)
1223-
1224-

0 commit comments

Comments
 (0)