Skip to content

Commit e43e429

Browse files
created TensorFlow Image Classification web app to use custom Model (dotnet#511)
End-to-end app using TensorFlow model
1 parent 7b3c134 commit e43e429

File tree

66 files changed

+36286
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

66 files changed

+36286
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Image Classification - AspNet core Web/service Sample
2+
3+
| ML.NET version | API type | Status | App Type | Data type | Scenario | ML Task | Algorithms |
4+
|----------------|-------------------|-------------------------------|-------------|-----------|---------------------|---------------------------|-----------------------------|
5+
| v1.1.0 | Dynamic API | up-to-date | Console app | Images and text labels | Images classification | TensorFlow model | DeepLearning model |
6+
7+
8+
## Problem
9+
The problem is how to run/score a TensorFlow model in a web app/service while using in-memory images.
10+
11+
## Solution:
12+
The model (`model.pb`) is trained using TensorFlow as disscussed in the blogpost [Run with ML.NET C# code a TensorFlow model exported from Azure Cognitive Services Custom Vision](https://devblogs.microsoft.com/cesardelatorre/run-with-ml-net-c-code-a-tensorflow-model-exported-from-azure-cognitive-services-custom-vision/).
13+
14+
see the below architecture that shows how to run/score TensorFlow model in ASP.NET Core Razor web app/service
15+
16+
![](docs/scenario-architecture.png)
17+
18+
19+
The difference between [getting started sample](https://github.com/dotnet/machinelearning-samples/tree/master/samples/csharp/getting-started/DeepLearning_ImageClassification_TensorFlow) and this end-to-end sample is that the images are loaded from **file** in getting started sample where as the images are loaded from **in-memory** in this end-to-end sample.
20+
21+
**Note:** this sample is trained using Custom images and it predicts the only specific images that are in [TestImages](./TestImages) Folder.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
2+
Microsoft Visual Studio Solution File, Format Version 12.00
3+
# Visual Studio Version 16
4+
VisualStudioVersion = 16.0.28803.202
5+
MinimumVisualStudioVersion = 10.0.40219.1
6+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowImageClassification", "TensorFlowImageClassification\TensorFlowImageClassification.csproj", "{3B0E6181-5C60-45D7-A722-A4AC13103DC5}"
7+
EndProject
8+
Global
9+
GlobalSection(SolutionConfigurationPlatforms) = preSolution
10+
Debug|Any CPU = Debug|Any CPU
11+
Release|Any CPU = Release|Any CPU
12+
EndGlobalSection
13+
GlobalSection(ProjectConfigurationPlatforms) = postSolution
14+
{3B0E6181-5C60-45D7-A722-A4AC13103DC5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
15+
{3B0E6181-5C60-45D7-A722-A4AC13103DC5}.Debug|Any CPU.Build.0 = Debug|Any CPU
16+
{3B0E6181-5C60-45D7-A722-A4AC13103DC5}.Release|Any CPU.ActiveCfg = Release|Any CPU
17+
{3B0E6181-5C60-45D7-A722-A4AC13103DC5}.Release|Any CPU.Build.0 = Release|Any CPU
18+
EndGlobalSection
19+
GlobalSection(SolutionProperties) = preSolution
20+
HideSolutionNode = FALSE
21+
EndGlobalSection
22+
GlobalSection(ExtensibilityGlobals) = postSolution
23+
SolutionGuid = {DE0DFEB0-7A11-4CA7-B471-F8D94F61ABF2}
24+
EndGlobalSection
25+
EndGlobal
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Drawing;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Threading.Tasks;
7+
using Microsoft.AspNetCore.Http;
8+
using Microsoft.AspNetCore.Mvc;
9+
using Microsoft.Extensions.Configuration;
10+
using Microsoft.Extensions.Logging;
11+
using Microsoft.Extensions.ML;
12+
using TensorFlowImageClassification.ImageHelpers;
13+
using TensorFlowImageClassification.ML.DataModels;
14+
15+
namespace TensorFlowImageClassification.Controllers
16+
{
17+
[Route("api/[controller]")]
18+
[ApiController]
19+
public class ImageClassificationController : ControllerBase
20+
{
21+
public IConfiguration Configuration { get; }
22+
private readonly PredictionEnginePool<ImageInputData, ImageLabelPredictions> _predictionEnginePool;
23+
private readonly ILogger<ImageClassificationController> _logger;
24+
private readonly string _labelsFilePath;
25+
26+
public ImageClassificationController(PredictionEnginePool<ImageInputData, ImageLabelPredictions> predictionEnginePool, IConfiguration configuration, ILogger<ImageClassificationController> logger) //When using DI/IoC
27+
{
28+
// Get the ML Model Engine injected, for scoring
29+
_predictionEnginePool = predictionEnginePool;
30+
31+
Configuration = configuration;
32+
_labelsFilePath = GetAbsolutePath(Configuration["MLModel:LabelsFilePath"]);
33+
34+
//Get other injected dependencies
35+
_logger = logger;
36+
}
37+
38+
[HttpPost]
39+
[ProducesResponseType(200)]
40+
[ProducesResponseType(400)]
41+
[Route("classifyimage")]
42+
public async Task<IActionResult> ClassifyImage(IFormFile imageFile)
43+
{
44+
if (imageFile.Length == 0)
45+
return BadRequest();
46+
47+
MemoryStream imageMemoryStream = new MemoryStream();
48+
await imageFile.CopyToAsync(imageMemoryStream);
49+
50+
//Check that the image is valid
51+
byte[] imageData = imageMemoryStream.ToArray();
52+
if (!imageData.IsValidImage())
53+
return StatusCode(StatusCodes.Status415UnsupportedMediaType);
54+
55+
//Convert to Image
56+
Image image = Image.FromStream(imageMemoryStream);
57+
58+
//Convert to Bitmap
59+
Bitmap bitmapImage = (Bitmap)image;
60+
61+
_logger.LogInformation($"Start processing image...");
62+
63+
//Measure execution time
64+
var watch = System.Diagnostics.Stopwatch.StartNew();
65+
66+
//Set the specific image data into the ImageInputData type used in the DataView
67+
ImageInputData imageInputData = new ImageInputData { Image = bitmapImage };
68+
69+
//Predict code for provided image
70+
ImageLabelPredictions imageLabelPredictions = _predictionEnginePool.Predict(imageInputData);
71+
72+
//Stop measuring time
73+
watch.Stop();
74+
var elapsedMs = watch.ElapsedMilliseconds;
75+
_logger.LogInformation($"Image processed in {elapsedMs} miliseconds");
76+
77+
//Predict the image's label (The one with highest probability)
78+
ImagePredictedLabelWithProbability imageBestLabelPrediction
79+
= FindBestLabelWithProbability(imageLabelPredictions, imageInputData);
80+
81+
return Ok(imageBestLabelPrediction);
82+
}
83+
84+
private ImagePredictedLabelWithProbability FindBestLabelWithProbability(ImageLabelPredictions imageLabelPredictions, ImageInputData imageInputData)
85+
{
86+
//Read TF model's labels (labels.txt) to classify the image across those labels
87+
var labels = ReadLabels(_labelsFilePath);
88+
89+
float[] probabilities = imageLabelPredictions.PredictedLabels;
90+
91+
//Set a single label as predicted or even none if probabilities were lower than 70%
92+
var imageBestLabelPrediction = new ImagePredictedLabelWithProbability()
93+
{
94+
ImageId = imageInputData.GetHashCode().ToString(), //This ID is not really needed, it could come from the application itself, etc.
95+
};
96+
97+
(imageBestLabelPrediction.PredictedLabel, imageBestLabelPrediction.Probability) = GetBestLabel(labels, probabilities);
98+
99+
return imageBestLabelPrediction;
100+
}
101+
102+
private (string, float) GetBestLabel(string[] labels, float[] probs)
103+
{
104+
var max = probs.Max();
105+
var index = probs.AsSpan().IndexOf(max);
106+
107+
if (max > 0.7)
108+
return (labels[index], max);
109+
else
110+
return ("None", max);
111+
}
112+
113+
private string[] ReadLabels(string labelsLocation)
114+
{
115+
return System.IO.File.ReadAllLines(labelsLocation);
116+
}
117+
118+
public static string GetAbsolutePath(string relativePath)
119+
{
120+
FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
121+
string assemblyFolderPath = _dataRoot.Directory.FullName;
122+
123+
string fullPath = Path.Combine(assemblyFolderPath, relativePath);
124+
return fullPath;
125+
}
126+
127+
// GET api/ImageClassification
128+
[HttpGet]
129+
public ActionResult<IEnumerable<string>> Get()
130+
{
131+
return new string[] { "ACK Heart beat 1", "ACK Heart beat 2" };
132+
}
133+
}
134+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
using System.Linq;
2+
using System.Text;
3+
4+
namespace TensorFlowImageClassification.ImageHelpers
5+
{
6+
public static class ImageValidationExtensions
7+
{
8+
public static bool IsValidImage(this byte[] image)
9+
{
10+
var imageFormat = GetImageFormat(image);
11+
return imageFormat == ImageFormat.jpeg ||
12+
imageFormat == ImageFormat.png;
13+
}
14+
15+
public enum ImageFormat
16+
{
17+
bmp,
18+
jpeg,
19+
gif,
20+
tiff,
21+
png,
22+
unknown
23+
}
24+
25+
public static ImageFormat GetImageFormat(byte[] bytes)
26+
{
27+
// see http://www.mikekunz.com/image_file_header.html
28+
var bmp = Encoding.ASCII.GetBytes("BM"); // BMP
29+
var gif = Encoding.ASCII.GetBytes("GIF"); // GIF
30+
var png = new byte[] { 137, 80, 78, 71 }; // PNG
31+
var tiff = new byte[] { 73, 73, 42 }; // TIFF
32+
var tiff2 = new byte[] { 77, 77, 42 }; // TIFF
33+
var jpeg = new byte[] { 255, 216, 255, 224 }; // jpeg
34+
var jpeg2 = new byte[] { 255, 216, 255, 225 }; // jpeg canon
35+
36+
if (bmp.SequenceEqual(bytes.Take(bmp.Length)))
37+
return ImageFormat.bmp;
38+
39+
if (gif.SequenceEqual(bytes.Take(gif.Length)))
40+
return ImageFormat.gif;
41+
42+
if (png.SequenceEqual(bytes.Take(png.Length)))
43+
return ImageFormat.png;
44+
45+
if (tiff.SequenceEqual(bytes.Take(tiff.Length)))
46+
return ImageFormat.tiff;
47+
48+
if (tiff2.SequenceEqual(bytes.Take(tiff2.Length)))
49+
return ImageFormat.tiff;
50+
51+
if (jpeg.SequenceEqual(bytes.Take(jpeg.Length)))
52+
return ImageFormat.jpeg;
53+
54+
if (jpeg2.SequenceEqual(bytes.Take(jpeg2.Length)))
55+
return ImageFormat.jpeg;
56+
57+
return ImageFormat.unknown;
58+
}
59+
}
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Microsoft.ML.Transforms.Image;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Drawing;
5+
using System.IO;
6+
using System.Linq;
7+
8+
namespace TensorFlowImageClassification.ML.DataModels
9+
{
10+
public class ImageInputData
11+
{
12+
[ImageType(227, 227)]
13+
public Bitmap Image { get; set; }
14+
}
15+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
3+
using Microsoft.ML.Data;
4+
5+
namespace TensorFlowImageClassification.ML.DataModels
6+
{
7+
public class ImageLabelPredictions
8+
{
9+
//TODO: Change to fixed output column name for TensorFlow model
10+
[ColumnName("loss")]
11+
public float[] PredictedLabels;
12+
}
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+

2+
namespace TensorFlowImageClassification.ML.DataModels
3+
{
4+
public class ImagePredictedLabelWithProbability
5+
{
6+
public string ImageId;
7+
8+
public string PredictedLabel;
9+
public float Probability { get; set; }
10+
11+
public long PredictionExecutionTime;
12+
}
13+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
001-Green-Meeting-Chair-Redmond
2+
002-High-Metal-Chair-Redmond
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using Microsoft.ML;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.IO;
5+
using System.Linq;
6+
using System.Threading.Tasks;
7+
using TensorFlowImageClassification.ML.DataModels;
8+
9+
namespace TensorFlowImageClassification.ML
10+
{
11+
public class TensorFlowModelConfigurator
12+
{
13+
private readonly MLContext _mlContext;
14+
private readonly ITransformer _mlModel;
15+
16+
public TensorFlowModelConfigurator(string tensorFlowModelFilePath)
17+
{
18+
_mlContext = new MLContext();
19+
20+
// Model creation and pipeline definition for images needs to run just once, so calling it from the constructor:
21+
_mlModel = SetupMlnetModel(tensorFlowModelFilePath);
22+
}
23+
24+
public struct ImageSettings
25+
{
26+
public const int imageHeight = 227;
27+
public const int imageWidth = 227;
28+
public const float mean = 117; //offsetImage
29+
public const bool channelsLast = true; //interleavePixelColors
30+
}
31+
32+
// For checking tensor names, you can open the TF model .pb file with tools like Netron: https://github.com/lutzroeder/netron
33+
public struct TensorFlowModelSettings
34+
{
35+
// input tensor name
36+
public const string inputTensorName = "Placeholder";
37+
38+
// output tensor name
39+
public const string outputTensorName = "loss";
40+
}
41+
42+
private ITransformer SetupMlnetModel(string tensorFlowModelFilePath)
43+
{
44+
var pipeline = _mlContext.Transforms.ResizeImages(outputColumnName: TensorFlowModelSettings.inputTensorName, imageWidth: ImageSettings.imageWidth, imageHeight: ImageSettings.imageHeight, inputColumnName: nameof(ImageInputData.Image))
45+
.Append(_mlContext.Transforms.ExtractPixels(outputColumnName: TensorFlowModelSettings.inputTensorName, interleavePixelColors: ImageSettings.channelsLast, offsetImage: ImageSettings.mean))
46+
.Append(_mlContext.Model.LoadTensorFlowModel(tensorFlowModelFilePath).
47+
ScoreTensorFlowModel(outputColumnNames: new[] { TensorFlowModelSettings.outputTensorName },
48+
inputColumnNames: new[] { TensorFlowModelSettings.inputTensorName }, addBatchDimensionInput: false));
49+
50+
ITransformer mlModel = pipeline.Fit(CreateEmptyDataView());
51+
52+
return mlModel;
53+
}
54+
private IDataView CreateEmptyDataView()
55+
{
56+
//Create empty DataView ot Images. We just need the schema to call fit()
57+
List<ImageInputData> list = new List<ImageInputData>();
58+
list.Add(new ImageInputData() { Image = new System.Drawing.Bitmap(ImageSettings.imageWidth, ImageSettings.imageHeight) }); //Test: Might not need to create the Bitmap.. = null; ?
59+
IEnumerable<ImageInputData> enumerableData = list;
60+
61+
var dv = _mlContext.Data.LoadFromEnumerable<ImageInputData>(list);
62+
return dv;
63+
}
64+
65+
public void SaveMLNetModel(string mlnetModelFilePath)
66+
{
67+
// Save/persist the model to a .ZIP file to be loaded by the PredictionEnginePool
68+
_mlContext.Model.Save(_mlModel, null, mlnetModelFilePath);
69+
}
70+
}
71+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
@page
2+
@model ErrorModel
3+
@{
4+
ViewData["Title"] = "Error";
5+
}
6+
7+
<h1 class="text-danger">Error.</h1>
8+
<h2 class="text-danger">An error occurred while processing your request.</h2>
9+
10+
@if (Model.ShowRequestId)
11+
{
12+
<p>
13+
<strong>Request ID:</strong> <code>@Model.RequestId</code>
14+
</p>
15+
}
16+
17+
<h3>Development Mode</h3>
18+
<p>
19+
Swapping to the <strong>Development</strong> environment displays detailed information about the error that occurred.
20+
</p>
21+
<p>
22+
<strong>The Development environment shouldn't be enabled for deployed applications.</strong>
23+
It can result in displaying sensitive information from exceptions to end users.
24+
For local debugging, enable the <strong>Development</strong> environment by setting the <strong>ASPNETCORE_ENVIRONMENT</strong> environment variable to <strong>Development</strong>
25+
and restarting the app.
26+
</p>

0 commit comments

Comments
 (0)