Skip to content

Commit b08b27b

Browse files
committed
Multilabel classifications scripts: adding TODO
1 parent e1d9e89 commit b08b27b

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

estimator_multiclass.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,16 @@ def model_fn(features, mode):
8989
tf.train.init_from_checkpoint(checkpoint_file,
9090
{v.name.split(':')[0]: v for v in variables_to_restore})
9191

92-
#Defining losses and regulization ops:
92+
#Defining losses and regulization ops:
9393
with tf.name_scope("loss_op"):
9494
loss = tf.losses.sigmoid_cross_entropy(multi_class_labels = features['image/class/id'], logits = logits)
9595
total_loss = tf.losses.get_total_loss() #obtain the regularization losses as well
9696
#FIXME: Replace classifier function (sigmoid / softmax)
97-
print(features['image/class/id'])
97+
98+
#TODO: Add a func to transform logit tensor to a label-like tensor
99+
# If value[][class_id]<0.5 then value[][class_id] = 0. else value[][class_id]= 1.
100+
#It is necessary for a multilabel classification problem
101+
98102
if mode != tf.estimator.ModeKeys.PREDICT:
99103
metrics = {
100104
'Accuracy': tf.metrics.accuracy(features['image/class/id'], logits, name="acc_op"),

0 commit comments

Comments
 (0)