Skip to content

Commit 655928f

Browse files
Merge pull request #14 from christianversloot/ensure-numpy
Fix #8
2 parents 8004180 + ae1c5d7 commit 655928f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

extra_keras_datasets/iris.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def load_data(path="iris.npz", test_split=0.2):
8080
testing_data = samples[:num_test_samples]
8181

8282
# Split into inputs and targets
83-
input_train = [i[0:4] for i in training_data]
84-
input_test = [i[0:4] for i in testing_data]
85-
target_train = [i[4] for i in training_data]
86-
target_test = [i[4] for i in testing_data]
83+
input_train = np.array([i[0:4] for i in training_data])
84+
input_test = np.array([i[0:4] for i in testing_data])
85+
target_train = np.array([i[4] for i in training_data])
86+
target_test = np.array([i[4] for i in testing_data])
8787

8888
# Warn about citation
8989
warn_citation()

0 commit comments

Comments
 (0)