-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathensure_multiArray_to_image.py
42 lines (31 loc) · 1.33 KB
/
ensure_multiArray_to_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
def get_nn(spec):
    """Return the neural-network sub-message selected in a model spec.

    The spec's ``Type`` oneof is queried; if it names one of the three
    neural-network variants, the attribute of that same name is returned.

    Args:
        spec: Object exposing ``WhichOneof("Type")`` and the attributes
            ``neuralNetwork``, ``neuralNetworkClassifier`` and
            ``neuralNetworkRegressor``.

    Returns:
        The attribute of ``spec`` named by its ``Type`` oneof.

    Raises:
        ValueError: If the ``Type`` oneof is not a neural-network variant.
    """
    nn_kinds = (
        "neuralNetwork",
        "neuralNetworkClassifier",
        "neuralNetworkRegressor",
    )
    selected = spec.WhichOneof("Type")
    if selected not in nn_kinds:
        raise ValueError("MLModel does not have a neural network")
    return getattr(spec, selected)
def convert_multiarray_to_image(feature, is_bgr=False):
    """Rewrite a multi-array feature description as an image feature, in place.

    The feature's multi-array shape must be ``(height, width)`` or
    ``(channels, height, width)`` with 1 or 3 channels. A single channel
    selects the GRAYSCALE color space; three channels select RGB, or BGR
    when ``is_bgr`` is true. The image width and height are then set from
    the shape.

    Args:
        feature: Feature description whose ``type.WhichOneof("Type")`` is
            ``"multiArrayType"``. Its ``type.imageType`` fields are written
            by this function.
        is_bgr: For three-channel shapes, use BGR instead of RGB. Has no
            effect on single-channel shapes.

    Returns:
        None. ``feature`` is modified in place.

    Raises:
        ValueError: If the feature is not a multi-array, or if its shape is
            not 2-D or 3-D with 1 or 3 channels.
    """
    if feature.type.WhichOneof("Type") != "multiArrayType":
        raise ValueError("%s is not a multiarray type" % feature.name)

    # Read the shape up front, before any imageType field is written.
    shape = tuple(feature.type.multiArrayType.shape)
    if len(shape) == 2:
        channels = 1
        height, width = shape
    elif len(shape) == 3:
        channels, height, width = shape
    else:
        # Any other rank (including an empty shape) has no image layout.
        channels = None

    if channels not in (1, 3):
        raise ValueError("Shape {} not supported for image type".format(shape))

    # Imported only after validation, so that invalid input is reported as
    # a ValueError without first requiring (or paying for) this import.
    import coremltools.proto.FeatureTypes_pb2 as ft

    if channels == 1:
        feature.type.imageType.colorSpace = ft.ImageFeatureType.GRAYSCALE
    elif is_bgr:
        feature.type.imageType.colorSpace = ft.ImageFeatureType.BGR
    else:
        feature.type.imageType.colorSpace = ft.ImageFeatureType.RGB

    feature.type.imageType.width = width
    feature.type.imageType.height = height