@@ -60,6 +60,17 @@ uint64_PyArray_converter(PyObject *in, PyObject **out)
60
60
return NPY_SUCCEED ;
61
61
}
62
62
63
+ static int
64
+ int8_PyArray_converter (PyObject * in , PyObject * * out )
65
+ {
66
+ PyObject * ret = PyArray_FROMANY (in , NPY_INT8 , 1 , 1 , NPY_ARRAY_IN_ARRAY );
67
+ if (ret == NULL ) {
68
+ return NPY_FAIL ;
69
+ }
70
+ * out = ret ;
71
+ return NPY_SUCCEED ;
72
+ }
73
+
63
74
/*===================================================================
64
75
* AncestorBuilder
65
76
*===================================================================
@@ -429,30 +440,43 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
429
440
{
430
441
int ret = -1 ;
431
442
int err ;
432
- static char * kwlist [] = {"num_alleles" , "max_nodes" , "max_edges" , NULL };
443
+ static char * kwlist [] = {"num_alleles" , "max_nodes" , "max_edges" , "derived_state" ,
444
+ NULL };
433
445
PyArrayObject * num_alleles = NULL ;
446
+ PyArrayObject * derived_state = NULL ;
447
+ int8_t * derived_state_data = NULL ;
434
448
unsigned long max_nodes = 1024 ;
435
449
unsigned long max_edges = 1024 ;
436
450
unsigned long num_sites ;
437
451
npy_intp * shape ;
438
452
int flags = 0 ;
439
453
440
454
self -> tree_sequence_builder = NULL ;
441
- if (!PyArg_ParseTupleAndKeywords (args , kwds , "O&|kk " , kwlist ,
455
+ if (!PyArg_ParseTupleAndKeywords (args , kwds , "O&|kkO& " , kwlist ,
442
456
uint64_PyArray_converter , & num_alleles ,
443
- & max_nodes , & max_edges )) {
457
+ & max_nodes , & max_edges ,
458
+ int8_PyArray_converter , & derived_state )) {
444
459
goto out ;
445
460
}
446
461
shape = PyArray_DIMS (num_alleles );
447
462
num_sites = shape [0 ];
448
-
463
+ if (derived_state != NULL ) {
464
+ shape = PyArray_DIMS (derived_state );
465
+ if (shape [0 ] != (npy_intp ) num_sites ) {
466
+ PyErr_SetString (PyExc_ValueError , "derived state array wrong size" );
467
+ goto out ;
468
+ }
469
+ derived_state_data = PyArray_DATA (derived_state );
470
+ }
449
471
self -> tree_sequence_builder = PyMem_Malloc (sizeof (tree_sequence_builder_t ));
450
472
if (self -> tree_sequence_builder == NULL ) {
451
473
PyErr_NoMemory ();
452
474
goto out ;
453
475
}
454
476
err = tree_sequence_builder_alloc (self -> tree_sequence_builder ,
455
- num_sites , PyArray_DATA (num_alleles ),
477
+ num_sites ,
478
+ PyArray_DATA (num_alleles ),
479
+ derived_state_data ,
456
480
max_nodes , max_edges , flags );
457
481
if (err != 0 ) {
458
482
handle_library_error (err );
@@ -461,6 +485,7 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
461
485
ret = 0 ;
462
486
out :
463
487
Py_XDECREF (num_alleles );
488
+ Py_XDECREF (derived_state );
464
489
return ret ;
465
490
}
466
491
0 commit comments