Skip to content

Commit 2b0d07f

Browse files
authored
Training Update (#145)
* split requirement file * add packages * update sample image name * update check condition * update scripts as per new OD model * update metadata * flake update * flake update
1 parent 7782a10 commit 2b0d07f

File tree

4 files changed

+15
-11
lines changed

4 files changed

+15
-11
lines changed

tests/training/test_sample_training_response.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,17 @@ def test_metadata():
3939
assert r.status_code == 200
4040

4141
metadata = r.json()
42-
assert metadata['id'] == 'ssd_mobilenet_v1_coco_2017_11_17-tf-mobilenet'
43-
assert metadata['name'] == 'ssd_mobilenet_v1_coco_2017_11_17 TensorFlow Model'
44-
assert metadata['description'] == 'ssd_mobilenet_v1_coco_2017_11_17 TensorFlow model trained on MobileNet'
42+
assert metadata['id'] == 'object-detector-ssd_mobilenet_v1'
43+
assert metadata['name'] == 'ssd_mobilenet_v1 TensorFlow Object Detector Model'
44+
assert metadata['description'] == 'ssd_mobilenet_v1 TensorFlow object detector model'
4545
assert metadata['type'] == 'Object Detection'
4646
assert metadata['source'] == 'https://developer.ibm.com/exchanges/models/all/max-object-detector/'
4747
assert metadata['license'] == 'ApacheV2'
4848

4949

5050
def test_predict():
5151
model_endpoint = 'http://localhost:5000/model/predict'
52-
file_path = 'samples/data_1.jpg'
52+
file_path = 'samples/a-pen-i-am.jpg'
5353

5454
with open(file_path, 'rb') as file:
5555
file_form = {'image': (file_path, file, 'image/jpeg')}
@@ -59,11 +59,9 @@ def test_predict():
5959
response = r.json()
6060

6161
assert response['status'] == 'ok'
62-
63-
# Teddy Bear
64-
# assert response['predictions'][0]['label_id'] == '88'
65-
assert response['predictions'][0]['label'] == 'toy'
66-
# assert response['predictions'][0]['probability'] > 0.95
62+
# Training run uses fewer samples and epochs to train. Results have randomness due to this.
63+
# Checking for all the new classes used in the sample data.
64+
assert response['predictions'][0]['label'] in ('toy', 'pen')
6765

6866

6967
if __name__ == '__main__':

training/training_code/train-max-model.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ echo "Training data is stored in $DATA_DIR"
2626
# The WML stores work files in the $RESULT_DIR.
2727
echo "Training work files and results will be stored in $RESULT_DIR"
2828

29-
echo "Installing prerequisite packages ..."
29+
# Cython is needed for pycococtools installation. Splitting the installation.
30+
echo "Installing prerequisite packages - 1..."
31+
pip install --user --no-deps -r training_prerequirements.txt
32+
echo "Installing prerequisite packages - 2..."
3033
pip install --user --no-deps -r training_requirements.txt
3134

3235
# ---------------------------------------------------------------
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Cython
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
Cython
21
contextlib2
32
pycocotools==2.0.0
43
coremltools==2.0
@@ -7,3 +6,6 @@ tensorflowjs==0.8.0
76
tensorflow-hub==0.3.0
87
h5py==2.8.0
98
numpy==1.17.5
9+
matplotlib
10+
cycler
11+
kiwisolver

0 commit comments

Comments
 (0)