Skip to content

Commit feabb65

Browse files
Add low-level Python-C support for arbirary derived state
1 parent 57594f4 commit feabb65

File tree

2 files changed

+48
-8
lines changed

2 files changed

+48
-8
lines changed

_tsinfermodule.c

+30-5
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,17 @@ uint64_PyArray_converter(PyObject *in, PyObject **out)
6060
return NPY_SUCCEED;
6161
}
6262

63+
static int
64+
int8_PyArray_converter(PyObject *in, PyObject **out)
65+
{
66+
PyObject *ret = PyArray_FROMANY(in, NPY_INT8, 1, 1, NPY_ARRAY_IN_ARRAY);
67+
if (ret == NULL) {
68+
return NPY_FAIL;
69+
}
70+
*out = ret;
71+
return NPY_SUCCEED;
72+
}
73+
6374
/*===================================================================
6475
* AncestorBuilder
6576
*===================================================================
@@ -429,30 +440,43 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
429440
{
430441
int ret = -1;
431442
int err;
432-
static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", NULL};
443+
static char *kwlist[] = {"num_alleles", "max_nodes", "max_edges", "derived_state",
444+
NULL};
433445
PyArrayObject *num_alleles = NULL;
446+
PyArrayObject *derived_state = NULL;
447+
int8_t *derived_state_data = NULL;
434448
unsigned long max_nodes = 1024;
435449
unsigned long max_edges = 1024;
436450
unsigned long num_sites;
437451
npy_intp *shape;
438452
int flags = 0;
439453

440454
self->tree_sequence_builder = NULL;
441-
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kk", kwlist,
455+
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O&|kkO&", kwlist,
442456
uint64_PyArray_converter, &num_alleles,
443-
&max_nodes, &max_edges)) {
457+
&max_nodes, &max_edges,
458+
int8_PyArray_converter, &derived_state)) {
444459
goto out;
445460
}
446461
shape = PyArray_DIMS(num_alleles);
447462
num_sites = shape[0];
448-
463+
if (derived_state != NULL) {
464+
shape = PyArray_DIMS(derived_state);
465+
if (shape[0] != (npy_intp) num_sites) {
466+
PyErr_SetString(PyExc_ValueError, "derived state array wrong size");
467+
goto out;
468+
}
469+
derived_state_data = PyArray_DATA(derived_state);
470+
}
449471
self->tree_sequence_builder = PyMem_Malloc(sizeof(tree_sequence_builder_t));
450472
if (self->tree_sequence_builder == NULL) {
451473
PyErr_NoMemory();
452474
goto out;
453475
}
454476
err = tree_sequence_builder_alloc(self->tree_sequence_builder,
455-
num_sites, PyArray_DATA(num_alleles),
477+
num_sites,
478+
PyArray_DATA(num_alleles),
479+
derived_state_data,
456480
max_nodes, max_edges, flags);
457481
if (err != 0) {
458482
handle_library_error(err);
@@ -461,6 +485,7 @@ TreeSequenceBuilder_init(TreeSequenceBuilder *self, PyObject *args, PyObject *kw
461485
ret = 0;
462486
out:
463487
Py_XDECREF(num_alleles);
488+
Py_XDECREF(derived_state);
464489
return ret;
465490
}
466491

tests/test_low_level.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,31 @@ class TestTreeSequenceBuilder:
8888
def test_init(self):
8989
with pytest.raises(TypeError):
9090
_tsinfer.TreeSequenceBuilder()
91-
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
92-
with pytest.raises(ValueError):
93-
_tsinfer.TreeSequenceBuilder(bad_array)
9491

9592
for bad_type in [None, "sdf", {}]:
9693
with pytest.raises(TypeError):
9794
_tsinfer.TreeSequenceBuilder([2], max_nodes=bad_type)
9895
with pytest.raises(TypeError):
9996
_tsinfer.TreeSequenceBuilder([2], max_edges=bad_type)
10097

98+
def test_bad_num_alleles(self):
99+
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
100+
with pytest.raises(ValueError):
101+
_tsinfer.TreeSequenceBuilder(bad_array)
102+
with pytest.raises(_tsinfer.LibraryError, match="number of alleles"):
103+
_tsinfer.TreeSequenceBuilder([1000])
104+
105+
def test_bad_derived_state(self):
106+
for bad_array in [None, "serf", [[], []], ["asdf"], {}]:
107+
with pytest.raises(ValueError):
108+
_tsinfer.TreeSequenceBuilder([2], derived_state=bad_array)
109+
with pytest.raises(_tsinfer.LibraryError, match="Bad derived state"):
110+
for bad_derived_state in [-1, 2, 100]:
111+
_tsinfer.TreeSequenceBuilder([2], derived_state=[bad_derived_state])
112+
113+
with pytest.raises(ValueError, match="derived state array wrong size"):
114+
_tsinfer.TreeSequenceBuilder([2, 2, 2], derived_state=[])
115+
101116

102117
class TestAncestorBuilder:
103118
"""

0 commit comments

Comments
 (0)