Skip to content

Commit f4e8758

Browse files
Marcnchristensen
authored andcommitted
Fix small bugs related to nan values in vector passed to pdf (#256)
* Enable replacing InCondition and ForbiddenRelation constraints * Allow nan values in CategoricalHP _pdf function
1 parent 33ac083 commit f4e8758

File tree

4 files changed

+65
-12
lines changed

4 files changed

+65
-12
lines changed

ConfigSpace/configuration_space.pyx

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,12 +1526,12 @@ class ConfigurationSpace(collections.abc.Mapping):
15261526
new_child = new_configspace[child_name]
15271527
new_parent = new_configspace[parent_name]
15281528

1529-
if hasattr(condition, 'value'):
1530-
condition_arg = getattr(condition, 'value')
1531-
substituted_condition = condition_type(child=new_child, parent=new_parent, value=condition_arg)
1532-
elif hasattr(condition, 'values'):
1529+
if hasattr(condition, 'values'):
15331530
condition_arg = getattr(condition, 'values')
15341531
substituted_condition = condition_type(child=new_child, parent=new_parent, values=condition_arg)
1532+
elif hasattr(condition, 'value'):
1533+
condition_arg = getattr(condition, 'value')
1534+
substituted_condition = condition_type(child=new_child, parent=new_parent, value=condition_arg)
15351535
else:
15361536
raise AttributeError(f'Did not find the expected attribute in condition {type(condition)}.')
15371537

@@ -1573,15 +1573,24 @@ class ConfigurationSpace(collections.abc.Mapping):
15731573
hyperparameter_name = getattr(forbidden.hyperparameter, 'name')
15741574
new_hyperparameter = new_configspace[hyperparameter_name]
15751575

1576-
if hasattr(forbidden, 'value'):
1577-
forbidden_arg = getattr(forbidden, 'value')
1578-
substituted_forbidden = forbidden_type(hyperparameter=new_hyperparameter, value=forbidden_arg)
1579-
elif hasattr(forbidden, 'values'):
1576+
if hasattr(forbidden, 'values'):
15801577
forbidden_arg = getattr(forbidden, 'values')
15811578
substituted_forbidden = forbidden_type(hyperparameter=new_hyperparameter, values=forbidden_arg)
1579+
elif hasattr(forbidden, 'value'):
1580+
forbidden_arg = getattr(forbidden, 'value')
1581+
substituted_forbidden = forbidden_type(hyperparameter=new_hyperparameter, value=forbidden_arg)
15821582
else:
15831583
raise AttributeError(f'Did not find the expected attribute in forbidden {type(forbidden)}.')
15841584

1585+
new_forbiddens.append(substituted_forbidden)
1586+
elif isinstance(forbidden, ForbiddenRelation):
1587+
forbidden_type = type(forbidden)
1588+
left_name = getattr(forbidden.left, 'name')
1589+
left_hyperparameter = new_configspace[left_name]
1590+
right_name = getattr(forbidden.right, 'name')
1591+
right_hyperparameter = new_configspace[right_name]
1592+
1593+
substituted_forbidden = forbidden_type(left=left_hyperparameter, right=right_hyperparameter)
15851594
new_forbiddens.append(substituted_forbidden)
15861595
else:
15871596
raise TypeError(f'Did not expect the supplied forbidden type {type(forbidden)}.')

ConfigSpace/hyperparameters.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2617,7 +2617,13 @@ cdef class CategoricalHyperparameter(Hyperparameter):
26172617
Probability density values of the input vector
26182618
"""
26192619
probs = np.array(self.probabilities)
2620+
nan = np.isnan(vector)
2621+
if np.any(nan):
2622+
# Temporarily pick any valid index to use `vector` as an index for `probs`
2623+
vector[nan] = 0
26202624
res = np.array(probs[vector.astype(int)])
2625+
if np.any(nan):
2626+
res[nan] = 0
26212627
if res.ndim == 0:
26222628
return res.reshape(-1)
26232629
return res

test/test_configuration_space.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
BetaIntegerHyperparameter,
4747
OrdinalHyperparameter)
4848
from ConfigSpace.exceptions import ForbiddenValueError
49-
from ConfigSpace.forbidden import ForbiddenEqualsRelation
49+
from ConfigSpace.forbidden import ForbiddenEqualsRelation, ForbiddenLessThanRelation
5050

5151

5252
def byteify(input):
@@ -919,6 +919,34 @@ def test_substitute_hyperparameters_in_conditions(self):
919919
self.assertEqual(new_conditions[0], test_conditions[0])
920920
self.assertEqual(new_conditions[1], test_conditions[1])
921921

922+
def test_substitute_hyperparameters_in_inconditions(self):
923+
cs1 = ConfigurationSpace()
924+
a = UniformIntegerHyperparameter('a', lower=0, upper=10)
925+
b = UniformFloatHyperparameter('b', lower=1., upper=8., log=False)
926+
cs1.add_hyperparameters([a, b])
927+
928+
cond = InCondition(b, a, [1, 2, 3, 4])
929+
cs1.add_conditions([cond])
930+
931+
cs2 = ConfigurationSpace()
932+
sub_a = UniformIntegerHyperparameter('a', lower=0, upper=10)
933+
sub_b = UniformFloatHyperparameter('b', lower=1., upper=8., log=False)
934+
cs2.add_hyperparameters([sub_a, sub_b])
935+
new_conditions = cs1.substitute_hyperparameters_in_conditions(cs1.get_conditions(), cs2)
936+
937+
test_cond = InCondition(b, a, [1, 2, 3, 4])
938+
cs2.add_conditions([test_cond])
939+
test_conditions = cs2.get_conditions()
940+
941+
self.assertEqual(new_conditions[0], test_conditions[0])
942+
self.assertIsNot(new_conditions[0], test_conditions[0])
943+
944+
self.assertEqual(new_conditions[0].get_parents(), test_conditions[0].get_parents())
945+
self.assertIsNot(new_conditions[0].get_parents(), test_conditions[0].get_parents())
946+
947+
self.assertEqual(new_conditions[0].get_children(), test_conditions[0].get_children())
948+
self.assertIsNot(new_conditions[0].get_children(), test_conditions[0].get_children())
949+
922950
def test_substitute_hyperparameters_in_forbiddens(self):
923951
cs1 = ConfigurationSpace()
924952
orig_hp1 = CategoricalHyperparameter("input1", [0, 1])
@@ -930,7 +958,8 @@ def test_substitute_hyperparameters_in_forbiddens(self):
930958
forb_2 = ForbiddenEqualsClause(orig_hp2, 1)
931959
forb_3 = ForbiddenEqualsClause(orig_hp3, 10)
932960
forb_4 = ForbiddenAndConjunction(forb_1, forb_2)
933-
cs1.add_forbidden_clauses([forb_3, forb_4])
961+
forb_5 = ForbiddenLessThanRelation(orig_hp1, orig_hp2)
962+
cs1.add_forbidden_clauses([forb_3, forb_4, forb_5])
934963

935964
cs2 = ConfigurationSpace()
936965
sub_hp1 = CategoricalHyperparameter("input1", [0, 1, 2])
@@ -944,9 +973,11 @@ def test_substitute_hyperparameters_in_forbiddens(self):
944973
test_forb_2 = ForbiddenEqualsClause(sub_hp2, 1)
945974
test_forb_3 = ForbiddenEqualsClause(sub_hp3, 10)
946975
test_forb_4 = ForbiddenAndConjunction(test_forb_1, test_forb_2)
947-
cs2.add_forbidden_clauses([test_forb_3, test_forb_4])
976+
test_forb_5 = ForbiddenLessThanRelation(sub_hp1, sub_hp2)
977+
cs2.add_forbidden_clauses([test_forb_3, test_forb_4, test_forb_5])
948978
test_forbiddens = cs2.get_forbiddens()
949979

980+
self.assertEqual(new_forbiddens[2], test_forbiddens[2])
950981
self.assertEqual(new_forbiddens[1], test_forbiddens[1])
951982
self.assertEqual(new_forbiddens[0], test_forbiddens[0])
952983

test/test_hyperparameters.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1946,6 +1946,7 @@ def test_categorical__pdf(self):
19461946
point_1 = np.array([0])
19471947
point_2 = np.array([1])
19481948
array_1 = np.array([1, 0, 2])
1949+
nan = np.array([0, np.nan])
19491950
self.assertEqual(c1._pdf(point_1)[0], 0.4)
19501951
self.assertEqual(c1._pdf(point_2)[0], 0.2)
19511952
self.assertAlmostEqual(c2._pdf(point_1)[0], 0.7142857142857143)
@@ -1957,14 +1958,20 @@ def test_categorical__pdf(self):
19571958
for res, exp_res in zip(array_results, expected_results):
19581959
self.assertEqual(res, exp_res)
19591960

1961+
nan_results = c1._pdf(nan)
1962+
expected_results = np.array([0.4, 0])
1963+
self.assertEqual(nan_results.shape, expected_results.shape)
1964+
for res, exp_res in zip(nan_results, expected_results):
1965+
self.assertEqual(res, exp_res)
1966+
19601967
# pdf must take a numpy array
19611968
with self.assertRaises(TypeError):
19621969
c1._pdf(0.2)
19631970
with self.assertRaises(TypeError):
19641971
c1._pdf('pdf')
19651972
with self.assertRaises(TypeError):
19661973
c1._pdf('one')
1967-
with self.assertRaises(ValueError):
1974+
with self.assertRaises(TypeError):
19681975
c1._pdf(np.array(['zero']))
19691976

19701977
def test_categorical_get_max_density(self):

0 commit comments

Comments
 (0)