Skip to content

Commit c376c81

Browse files
GitHubLabeler refactored to use ML.NET v0.7 and the common-code approach, plus additional global refactoring
1 parent e87afb3 commit c376c81

File tree

32 files changed

+557
-410
lines changed

32 files changed

+557
-410
lines changed

samples/csharp/common/ConsoleHelper.cs

+27-3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,30 @@ public static void PrintRegressionMetrics(string name, RegressionEvaluator.Resul
4545
Console.WriteLine($"*************************************************");
4646
}
4747

48+
public static void PrintMulticlassClassificationFoldsAverageMetrics(
49+
string algorithmName,
50+
(MultiClassClassifierEvaluator.Result metrics,
51+
ITransformer model,
52+
IDataView scoredTestData)[] crossValResults
53+
)
54+
{
55+
var metricsInMultipleFolds = crossValResults.Select(r => r.metrics);
56+
57+
var microAccuracies = metricsInMultipleFolds.Select(m => m.AccuracyMicro);
58+
var macroAccuracies = metricsInMultipleFolds.Select(m => m.AccuracyMacro);
59+
var logLoss = metricsInMultipleFolds.Select(m => m.LogLoss);
60+
var logLossReduction = metricsInMultipleFolds.Select(m => m.LogLossReduction);
61+
62+
Console.WriteLine($"**************************************************************************");
63+
Console.WriteLine($"* Metrics for {algorithmName} Multi-class Classification model ");
64+
Console.WriteLine($"*-------------------------------------------------------------------------");
65+
Console.WriteLine($"* Average MicroAccuracy: {microAccuracies.Average():0.##}");
66+
Console.WriteLine($"* Average MacroAccuracy: {macroAccuracies.Average():0.##}");
67+
Console.WriteLine($"* Average LogLoss: {logLoss.Average():#.##}");
68+
Console.WriteLine($"* Average LogLossReduction: {logLossReduction.Average():#.##}");
69+
Console.WriteLine($"**************************************************************************");
70+
}
71+
4872
public static void PrintClusteringMetrics(string name, ClusteringEvaluator.Result metrics)
4973
{
5074
Console.WriteLine($"*************************************************");
@@ -58,7 +82,7 @@ public static void PrintClusteringMetrics(string name, ClusteringEvaluator.Resul
5882
public static List<TObservation> PeekDataViewInConsole<TObservation>(MLContext mlContext, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
5983
where TObservation : class, new()
6084
{
61-
string msg = string.Format("Showing {0} rows with all the columns", numberOfRows.ToString());
85+
string msg = string.Format("Peek data in DataView: Showing {0} rows with the columns specified by TObservation class", numberOfRows.ToString());
6286
ConsoleWriteHeader(msg);
6387

6488
//https://github.com/dotnet/machinelearning/blob/master/docs/code/MlNetCookBook.md#how-do-i-look-at-the-intermediate-data
@@ -85,9 +109,9 @@ public static List<TObservation> PeekDataViewInConsole<TObservation>(MLContext m
85109
return someRows;
86110
}
87111

88-
public static List<float[]> PeekFeaturesColumnDataInConsole(MLContext mlContext, string columnName, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
112+
public static List<float[]> PeekVectorColumnDataInConsole(MLContext mlContext, string columnName, IDataView dataView, IEstimator<ITransformer> pipeline, int numberOfRows = 4)
89113
{
90-
string msg = string.Format("Show {0} rows with just the '{1}' column", numberOfRows, columnName );
114+
string msg = string.Format("Peek data in DataView: : Show {0} rows with just the '{1}' column", numberOfRows, columnName );
91115
ConsoleWriteHeader(msg);
92116

93117
var transformedData = pipeline.Fit(dataView).Transform(dataView);

samples/csharp/common/ModelBuilder.cs

+32-7
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,38 @@
88

99
namespace Common
1010
{
11-
public class ModelBuilder<TObservation, TPrediction>
11+
public class ModelBuilder<TObservation, TPrediction>
1212
where TObservation : class
1313
where TPrediction : class, new()
1414
{
1515
private MLContext _mlcontext;
16-
private IEstimator<ITransformer> _trainingPipeline;
16+
public IEstimator<ITransformer> TrainingPipeline { get; private set; }
1717
public ITransformer TrainedModel { get; private set; }
1818

1919
public ModelBuilder(
2020
MLContext mlContext,
21-
IEstimator<ITransformer> dataProcessPipeline,
22-
IEstimator<ITransformer> trainer)
21+
IEstimator<ITransformer> dataProcessPipeline //, IEstimator<ITransformer> trainer
22+
)
2323
{
2424
_mlcontext = mlContext;
25-
_trainingPipeline = dataProcessPipeline.Append(trainer);
25+
TrainingPipeline = dataProcessPipeline;
26+
27+
//??? TrainingPipeline.Append(trainer);
28+
}
29+
30+
public void AddTrainer(IEstimator<ITransformer> trainer)
31+
{
32+
TrainingPipeline = TrainingPipeline.Append(trainer);
33+
}
34+
35+
public void AddEstimator(IEstimator<ITransformer> estimator)
36+
{
37+
TrainingPipeline = TrainingPipeline.Append(estimator);
2638
}
27-
39+
2840
public ITransformer Train(IDataView trainingData)
2941
{
30-
TrainedModel = _trainingPipeline.Fit(trainingData);
42+
TrainedModel = TrainingPipeline.Fit(trainingData);
3143
return TrainedModel;
3244
}
3345

@@ -39,6 +51,19 @@ public RegressionEvaluator.Result EvaluateRegressionModel(IDataView testData)
3951
return metrics;
4052
}
4153

54+
public (MultiClassClassifierEvaluator.Result metrics,
55+
ITransformer model,
56+
IDataView scoredTestData)[]
57+
CrossValidateAndEvaluateMulticlassClassificationModel(IDataView data, int numFolds = 5, string labelColumn = "Label", string stratificationColumn = null)
58+
{
59+
//CrossValidation happens actually before training, so no check.
60+
//...
61+
var context = new MulticlassClassificationContext(_mlcontext);
62+
63+
var crossValidationResults = context.CrossValidate(data, TrainingPipeline, numFolds, labelColumn, stratificationColumn);
64+
return crossValidationResults;
65+
}
66+
4267
public ClusteringEvaluator.Result EvaluateClusteringModel(IDataView dataView)
4368
{
4469
CheckTrained();

samples/csharp/common/ModelScorer.cs

+10-1
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,21 @@ public class ModelScorer<TObservation, TPrediction>
1818
public ITransformer TrainedModel { get; private set; }
1919
public PredictionFunction<TObservation, TPrediction> PredictionFunction;
2020

21-
public ModelScorer(MLContext mlContext)
21+
public ModelScorer(MLContext mlContext, ITransformer trainedModel = null)
2222
{
2323
_mlContext = mlContext;
24+
TrainedModel = trainedModel;
2425
}
2526

2627
public TPrediction PredictSingle(TObservation input)
2728
{
29+
CheckTrainedModelIsLoaded();
2830
return PredictionFunction.Predict(input);
2931
}
3032

3133
public IEnumerable<TPrediction> PredictBatch(IDataView inputDataView)
3234
{
35+
CheckTrainedModelIsLoaded();
3336
var predictions = TrainedModel.Transform(inputDataView);
3437
return predictions.AsEnumerable<TPrediction>(_mlContext, reuseRowObject: false);
3538
}
@@ -46,6 +49,12 @@ public ITransformer LoadModelFromZipFile(string modelPath)
4649

4750
return TrainedModel;
4851
}
52+
53+
private void CheckTrainedModelIsLoaded()
54+
{
55+
if (TrainedModel == null)
56+
throw new InvalidOperationException("Need to have a model before scoring. Call LoadModelFromZipFile(modelPath) first or provided a model through the constructor.");
57+
}
4958
}
5059

5160
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Microsoft.ML;
2+
using Microsoft.ML.Core.Data;
3+
using Microsoft.ML.Runtime.Data;
4+
using Microsoft.ML.Transforms;
5+
using System;
6+
using System.Collections.Generic;
7+
using System.Text;
8+
9+
namespace GitHubLabeler
10+
{
11+
class DataLoader
12+
{
13+
MLContext _mlContext;
14+
private TextLoader _loader;
15+
16+
public DataLoader(MLContext mlContext)
17+
{
18+
_mlContext = mlContext;
19+
20+
_loader = mlContext.Data.TextReader(new TextLoader.Arguments()
21+
{
22+
Separator = "tab",
23+
HasHeader = true,
24+
Column = new[]
25+
{
26+
new TextLoader.Column("ID", DataKind.Text, 0),
27+
new TextLoader.Column("Area", DataKind.Text, 1),
28+
new TextLoader.Column("Title", DataKind.Text, 2),
29+
new TextLoader.Column("Description", DataKind.Text, 3),
30+
}
31+
});
32+
}
33+
34+
public IDataView GetDataView(string filePath)
35+
{
36+
return _loader.Read(filePath);
37+
}
38+
}
39+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using Microsoft.ML;
2+
using Microsoft.ML.Core.Data;
3+
using Microsoft.ML.Runtime.Data;
4+
using Microsoft.ML.Transforms;
5+
using Microsoft.ML.Transforms.Categorical;
6+
using Microsoft.ML.Transforms.Text;
7+
using System;
8+
using System.Collections.Generic;
9+
using System.Text;
10+
11+
namespace GitHubLabeler
12+
{
13+
public class DataProcessor
14+
{
15+
public IEstimator<ITransformer> DataProcessPipeline { get; private set; }
16+
17+
public DataProcessor(MLContext mlContext)
18+
{
19+
// Configure data transformations in the Process pipeline
20+
21+
DataProcessPipeline = new ValueToKeyMappingEstimator(mlContext, "Area", "Label")
22+
.Append(new TextFeaturizingEstimator(mlContext, "Title", "TitleFeaturized"))
23+
.Append(new TextFeaturizingEstimator(mlContext, "Description", "DescriptionFeaturized"))
24+
.Append(new ColumnConcatenatingEstimator(mlContext, "Features", "TitleFeaturized", "DescriptionFeaturized"));
25+
}
26+
}
27+
}

samples/csharp/end-to-end-apps/github-labeler/GitHubLabeler/GitHubLabelerConsoleApp/GitHubIssue.cs renamed to samples/csharp/end-to-end-apps/MulticlassClassification-GitHubLabeler/GitHubLabeler/GitHubLabelerConsoleApp/DataStructures/GitHubIssue.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22

33
#pragma warning disable 649 // We don't care about unsused fields here, because they are mapped with the input file.
44

5-
namespace GitHubLabeler
5+
namespace GitHubLabeler.DataStructures
66
{
7+
//The only purpose of this class is for peek data after transforming it with the pipeline
78
internal class GitHubIssue
89
{
910
[Column(ordinal: "0")]

samples/csharp/end-to-end-apps/github-labeler/GitHubLabeler/GitHubLabelerConsoleApp/GitHubIssuePrediction.cs renamed to samples/csharp/end-to-end-apps/MulticlassClassification-GitHubLabeler/GitHubLabeler/GitHubLabelerConsoleApp/DataStructures/GitHubIssuePrediction.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
#pragma warning disable 649 // We don't care about unsused fields here, because they are mapped with the input file.
44

5-
namespace GitHubLabeler
5+
namespace GitHubLabeler.DataStructures
66
{
77
internal class GitHubIssuePrediction
88
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using Microsoft.ML.Runtime.Api;
2+
3+
#pragma warning disable 649 // We don't care about unsused fields here, because they are mapped with the input file.
4+
5+
namespace GitHubLabeler.DataStructures
6+
{
7+
internal class GitHubIssueTransformed
8+
{
9+
public string ID;
10+
public string Area;
11+
//public float[] Label; // -> Area dictionarized
12+
public string Title;
13+
//public float[] TitleFeaturized; // -> Title Featurized
14+
public string Description;
15+
//public float[] DescriptionFeaturized; // -> Description Featurized
16+
}
17+
}
18+
19+
20+
//public Scalar<bool> label { get; set; }
21+
//public Scalar<float> score { get; set; }

samples/csharp/end-to-end-apps/github-labeler/GitHubLabeler/GitHubLabelerConsoleApp/Labeler.cs renamed to samples/csharp/end-to-end-apps/MulticlassClassification-GitHubLabeler/GitHubLabeler/GitHubLabelerConsoleApp/Labeler.cs

+29-21
Original file line numberDiff line numberDiff line change
@@ -7,47 +7,56 @@
77
using Octokit;
88
using System.IO;
99
using Microsoft.ML.Runtime.Data;
10+
using GitHubLabeler.DataStructures;
11+
using Common;
1012

1113
namespace GitHubLabeler
1214
{
15+
//This "Labeler" class could be used in a different End-User application (Web app, other console app, desktop app, etc.)
1316
internal class Labeler
1417
{
1518
private readonly GitHubClient _client;
1619
private readonly string _repoOwner;
1720
private readonly string _repoName;
1821
private readonly string _modelPath;
1922
private readonly MLContext _mlContext;
20-
private readonly ITransformer _loadedModel;
21-
private readonly PredictionFunction<GitHubIssue, GitHubIssuePrediction> _engine;
2223

23-
public Labeler(string modelPath, string repoOwner, string repoName, string accessToken)
24+
private readonly ModelScorer<GitHubIssue, GitHubIssuePrediction> _modelScorer;
25+
26+
public Labeler(string modelPath, string repoOwner = "", string repoName = "", string accessToken = "")
2427
{
2528
_modelPath = modelPath;
2629
_repoOwner = repoOwner;
2730
_repoName = repoName;
28-
29-
31+
3032
_mlContext = new MLContext(seed:1);
3133

32-
//Load model from .ZIP file
33-
using (var stream = new FileStream(_modelPath, System.IO.FileMode.Open, FileAccess.Read, FileShare.Read))
34+
//Load file model into ModelScorer
35+
_modelScorer = new ModelScorer<GitHubIssue, GitHubIssuePrediction>(_mlContext);
36+
_modelScorer.LoadModelFromZipFile(_modelPath);
37+
38+
//Configure Client to access a GitHub repo
39+
if (accessToken != string.Empty)
3440
{
35-
_loadedModel = TransformerChain.LoadFrom(_mlContext, stream);
41+
var productInformation = new ProductHeaderValue("MLGitHubLabeler");
42+
_client = new GitHubClient(productInformation)
43+
{
44+
Credentials = new Credentials(accessToken)
45+
};
3646
}
47+
}
3748

38-
// Create prediction engine
39-
_engine = _loadedModel.MakePredictionFunction<GitHubIssue, GitHubIssuePrediction>(_mlContext);
49+
public void TestPredictionForSingleIssue()
50+
{
51+
GitHubIssue singleIssue = new GitHubIssue() { ID = "Any-ID", Title = "Entity Framework crashes", Description = "When connecting to the database, EF is crashing" };
4052

41-
// Client to access GitHub
42-
var productInformation = new ProductHeaderValue("MLGitHubLabeler");
43-
_client = new GitHubClient(productInformation)
44-
{
45-
Credentials = new Credentials(accessToken)
46-
};
53+
//Predict label for single hard-coded issue
54+
var prediction = _modelScorer.PredictSingle(singleIssue);
55+
Console.WriteLine($"=============== Single Prediction - Result: {prediction.Area} ===============");
4756
}
4857

4958
// Label all issues that are not labeled yet
50-
public async Task LabelAllNewIssues()
59+
public async Task LabelAllNewIssuesInGitHubRepo()
5160
{
5261
var newIssues = await GetNewIssues();
5362
foreach (var issue in newIssues.Where(issue => !issue.Labels.Any()))
@@ -73,7 +82,7 @@ private async Task<IReadOnlyList<Issue>> GetNewIssues()
7382
.ToList();
7483
}
7584

76-
private string PredictLabel(Issue issue)
85+
private string PredictLabel(Octokit.Issue issue)
7786
{
7887
var corefxIssue = new GitHubIssue
7988
{
@@ -88,11 +97,10 @@ private string PredictLabel(Issue issue)
8897
}
8998

9099
public string Predict(GitHubIssue issue)
91-
{
92-
var prediction = _engine.Predict(issue);
100+
{
101+
var prediction = _modelScorer.PredictSingle(issue);
93102

94103
return prediction.Area;
95-
96104
}
97105

98106
private void ApplyLabel(Issue issue, string label)

0 commit comments

Comments
 (0)