4
4
"""
5
5
6
6
import collections
7
- import math
8
7
import queue
9
8
import threading
10
9
11
- # TODO remove this dependency. It's not used for anything important.
12
- import attr
13
-
14
10
import numpy as np
15
11
import tqdm
16
12
import humanize
@@ -189,8 +185,9 @@ def __init__(
189
185
190
186
def initialise (self ):
191
187
# This is slow, so we should figure out a way to report progress on it.
192
- self .logger .info ("Initialising ancestor builder for {} samples and {} sites" .
193
- format (self .num_samples , self .num_sites ))
188
+ self .logger .info (
189
+ "Initialising ancestor builder for {} samples and {} sites" .format (
190
+ self .num_samples , self .num_sites ))
194
191
self .ancestor_builder = self .ancestor_builder_class (self .samples , self .positions )
195
192
self .num_ancestors = self .ancestor_builder .num_ancestors
196
193
self .tree_sequence_builder = self .tree_sequence_builder_class (
@@ -333,16 +330,12 @@ def match_worker(thread_index):
333
330
for j in range (self .num_threads ):
334
331
match_threads [j ].join ()
335
332
336
-
337
333
def __process_ancestors_single_threaded (self ):
338
334
a = np .zeros (self .num_sites , dtype = np .int8 )
339
335
matcher = self .ancestor_matcher_class (
340
336
self .tree_sequence_builder , self .recombination_rate )
341
337
results = ResultBuffer ()
342
338
343
- # TODO remove this stuff for ancestor ID maps. It's just here for
344
- # debugging.
345
- ancestor_id = 1
346
339
for epoch in range (self .num_epochs ):
347
340
time = self .epoch_time [epoch ]
348
341
ancestor_focal_sites = self .epoch_ancestors [epoch ]
@@ -391,8 +384,9 @@ def worker(thread_index):
391
384
work_queue .task_done ()
392
385
if num_matches > 0 :
393
386
mean_traceback_size /= num_matches
394
- self .logger .info ("Thread {} done: mean_tb_size={:.2f}; total_edges={}" .format (
395
- thread_index , mean_traceback_size , results [thread_index ].num_edges ))
387
+ self .logger .info (
388
+ "Thread {} done: mean_tb_size={:.2f}; total_edges={}" .format (
389
+ thread_index , mean_traceback_size , results [thread_index ].num_edges ))
396
390
work_queue .task_done ()
397
391
398
392
threads = [
@@ -467,7 +461,7 @@ def get_tree_sequence(self):
467
461
mutations .set_columns (
468
462
site = site , node = node , derived_state = derived_state ,
469
463
derived_state_length = np .ones (tsb .num_mutations , dtype = np .uint32 ),
470
- parent = parent );
464
+ parent = parent )
471
465
msprime .sort_tables (nodes , edges , sites = sites , mutations = mutations )
472
466
return msprime .load_tables (
473
467
nodes = nodes , edges = edges , sites = sites , mutations = mutations )
@@ -482,14 +476,14 @@ def ancestors(self):
482
476
ancestor_id = 1
483
477
for age , ancestor_focal_sites in frequency_classes :
484
478
for focal_sites in ancestor_focal_sites :
485
- builder .make_ancestor (focal_sites , A [ancestor_id ,:])
479
+ builder .make_ancestor (focal_sites , A [ancestor_id , :])
486
480
ancestor_id += 1
487
481
return A
488
482
489
483
490
- def infer (samples , positions , sequence_length , recombination_rate ,
491
- error_rate = 0 , method = "C" ,
492
- num_threads = 1 , progress = False , log_level = "WARNING" ,
484
+ def infer (
485
+ samples , positions , sequence_length , recombination_rate , error_rate = 0 ,
486
+ method = "C" , num_threads = 1 , progress = False , log_level = "WARNING" ,
493
487
resolve_shared_recombinations = False , resolve_polytomies = False ):
494
488
# Primary entry point.
495
489
manager = InferenceManager (
@@ -517,18 +511,19 @@ def infer(samples, positions, sequence_length, recombination_rate,
517
511
###############################################################
518
512
519
513
520
- @attr .s
521
514
class Edge (object ):
522
- left = attr .ib (default = None )
523
- right = attr .ib (default = None )
524
- parent = attr .ib (default = None )
525
- child = attr .ib (default = None )
515
+
516
+ def __init__ (self , left = None , right = None , parent = None , child = None ):
517
+ self .left = left
518
+ self .right = right
519
+ self .parent = parent
520
+ self .child = child
526
521
527
522
528
- @attr .s
529
523
class Site (object ):
530
- id = attr .ib (default = None )
531
- frequency = attr .ib (default = None )
524
+ def __init__ (self , id , frequency ):
525
+ self .id = id
526
+ self .frequency = frequency
532
527
533
528
534
529
class AncestorBuilder (object ):
@@ -613,7 +608,6 @@ def make_ancestor(self, focal_sites, a):
613
608
return a
614
609
615
610
616
-
617
611
def edge_group_equal (edges , group1 , group2 ):
618
612
"""
619
613
Returns true if the specified subsets of the list of edges are considered
@@ -637,7 +631,6 @@ def edge_group_equal(edges, group1, group2):
637
631
return ret
638
632
639
633
640
-
641
634
class TreeSequenceBuilder (object ):
642
635
643
636
def __init__ (
@@ -675,6 +668,7 @@ def print_state(self):
675
668
print ("num_nodes = " , self .num_nodes )
676
669
nodes = msprime .NodeTable ()
677
670
flags = np .zeros (self .num_nodes , dtype = np .uint32 )
671
+ time = np .zeros (self .num_nodes , dtype = np .float64 )
678
672
self .dump_nodes (flags = flags , time = time )
679
673
nodes .set_columns (flags = flags , time = time )
680
674
print ("nodes = " )
@@ -802,7 +796,8 @@ def _replace_recombinations(self):
802
796
if not match_found [j ]:
803
797
for k in range (j + 1 , len (groups )):
804
798
# Compare this group to the others.
805
- if not match_found [k ] and edge_group_equal (active , groups [j ], groups [k ]):
799
+ if not match_found [k ] and edge_group_equal (
800
+ active , groups [j ], groups [k ]):
806
801
matches .append (k )
807
802
match_found [k ] = True
808
803
if len (matches ) > 0 :
@@ -825,8 +820,10 @@ def _replace_recombinations(self):
825
820
for group_index in group_index_list :
826
821
start , end = groups [group_index ]
827
822
left_set .add (tuple ([active [j ].left for j in range (start , end )]))
828
- right_set .add (tuple ([active [j ].right for j in range (start , end )]))
829
- parent_set .add (tuple ([active [j ].parent for j in range (start , end )]))
823
+ right_set .add (
824
+ tuple ([active [j ].right for j in range (start , end )]))
825
+ parent_set .add (
826
+ tuple ([active [j ].parent for j in range (start , end )]))
830
827
children = set (active [j ].child for j in range (start , end ))
831
828
assert len (children ) == 1
832
829
for j in range (start , end - 1 ):
@@ -895,7 +892,6 @@ def _replace_recombinations(self):
895
892
# for e in self.edges:
896
893
# print("\t", e)
897
894
898
-
899
895
def insert_polytomy_ancestor (self , edges ):
900
896
"""
901
897
Insert a new ancestor for the specified edges and update the parents
@@ -933,12 +929,12 @@ def _resolve_polytomies(self):
933
929
active [j - 1 ].parent != active [j ].parent )
934
930
if condition :
935
931
size = j - group_start
936
- if size > 1 and size != parent_count [active [j - 1 ].parent ]:
932
+ if size > 1 and size != parent_count [active [j - 1 ].parent ]:
937
933
groups .append ((group_start , j ))
938
934
group_start = j
939
935
j = len (active )
940
936
size = j - group_start
941
- if size > 1 and size != parent_count [active [j - 1 ].parent ]:
937
+ if size > 1 and size != parent_count [active [j - 1 ].parent ]:
942
938
groups .append ((group_start , j ))
943
939
944
940
for start , end in groups :
@@ -988,7 +984,6 @@ def update(self, num_nodes, time, left, right, parent, child, site, node):
988
984
# print("AFTER UPDATE")
989
985
# self.print_state()
990
986
991
-
992
987
def dump_nodes (self , flags , time ):
993
988
time [:] = self .time [:self .num_nodes ]
994
989
flags [:] = self .flags
@@ -1106,7 +1101,8 @@ def find_path(self, h):
1106
1101
if state == 0 :
1107
1102
assert len (L ) > 0
1108
1103
traceback [site ] = dict (L )
1109
- if site > 0 and (site - 1 ) not in self .tree_sequence_builder .mutations :
1104
+ if site > 0 and (site - 1 ) \
1105
+ not in self .tree_sequence_builder .mutations :
1110
1106
assert traceback [site ] == traceback [site - 1 ]
1111
1107
continue
1112
1108
mutation_node = self .tree_sequence_builder .mutations [site ]
@@ -1167,8 +1163,6 @@ def find_path(self, h):
1167
1163
# for l in range(self.num_sites):
1168
1164
# print("\t", l, traceback[l])
1169
1165
1170
-
1171
-
1172
1166
u = self .get_max_likelihood_node (L )
1173
1167
output_edge = Edge (right = m , parent = u )
1174
1168
output_edges = [output_edge ]
@@ -1220,5 +1214,3 @@ def find_path(self, h):
1220
1214
right [j ] = e .right
1221
1215
parent [j ] = e .parent
1222
1216
return (left , right , parent ), np .array (mismatches , dtype = np .uint32 )
1223
-
1224
-
0 commit comments