Skip to content

Commit 50e40a1

Browse files
authored
Update metrics.py
using sklearn to calculate confusion matrix
1 parent 801fb20 commit 50e40a1

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

ptsemseg/metrics.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,15 @@
22
# https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
33

44
import numpy as np
5-
5+
from sklearn.metrics import confusion_matrix
66

77
class runningScore(object):
88
def __init__(self, n_classes):
99
self.n_classes = n_classes
1010
self.confusion_matrix = np.zeros((n_classes, n_classes))
1111

12-
def _fast_hist(self, label_true, label_pred, n_class):
13-
mask = (label_true >= 0) & (label_true < n_class)
14-
hist = np.bincount(
15-
n_class * label_true[mask].astype(int) + label_pred[mask], minlength=n_class ** 2
16-
).reshape(n_class, n_class)
17-
return hist
18-
1912
def update(self, label_trues, label_preds):
20-
for lt, lp in zip(label_trues, label_preds):
21-
self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten(), self.n_classes)
13+
self.confusion_matrix += confusion_matrix(label_trues.flatten(), label_preds.flatten(), list(range(self.n_classes)))
2214

2315
def get_scores(self):
2416
"""Returns accuracy score evaluation result.

0 commit comments

Comments
 (0)