Skip to content

Commit 0edabdb

Browse files
authored
Fix LTR rescorer with model alias (#126273) (#126655)
1 parent 0442be3 commit 0edabdb

File tree

5 files changed

+13
-329
lines changed

5 files changed

+13
-329
lines changed

docs/changelog/126273.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 126273
2+
summary: Fix LTR rescorer with model alias
3+
area: Ranking
4+
type: bug
5+
issues: []

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/loadingservice/ModelLoadingService.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,7 @@ public ModelLoadingService(
239239
this.licenseState = licenseState;
240240
}
241241

242-
// for testing
243-
String getModelId(String modelIdOrAlias) {
242+
public String getModelId(String modelIdOrAlias) {
244243
return modelAliasToId.getOrDefault(modelIdOrAlias, modelIdOrAlias);
245244
}
246245

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankService.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public void loadLocalModel(String modelId, ActionListener<LocalModel> listener)
101101
*/
102102
public void loadLearningToRankConfig(String modelId, Map<String, Object> params, ActionListener<LearningToRankConfig> listener) {
103103
trainedModelProvider.getTrainedModel(
104-
modelId,
104+
modelLoadingService.getModelId(modelId),
105105
GetTrainedModelsAction.Includes.all(),
106106
null,
107107
ActionListener.wrap(trainedModelConfig -> {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankServiceTests.java

+6-1
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,14 @@
4242
import static org.hamcrest.Matchers.hasKey;
4343
import static org.hamcrest.Matchers.hasSize;
4444
import static org.mockito.ArgumentMatchers.any;
45+
import static org.mockito.ArgumentMatchers.anyString;
4546
import static org.mockito.ArgumentMatchers.argThat;
4647
import static org.mockito.ArgumentMatchers.eq;
4748
import static org.mockito.ArgumentMatchers.isA;
4849
import static org.mockito.Mockito.doAnswer;
4950
import static org.mockito.Mockito.mock;
5051
import static org.mockito.Mockito.verify;
52+
import static org.mockito.Mockito.when;
5153

5254
public class LearningToRankServiceTests extends ESTestCase {
5355
public static final String GOOD_MODEL = "inference-entity-id";
@@ -185,7 +187,10 @@ protected NamedXContentRegistry xContentRegistry() {
185187
}
186188

187189
private ModelLoadingService mockModelLoadingService() {
188-
return mock(ModelLoadingService.class);
190+
ModelLoadingService modelLoadingService = mock(ModelLoadingService.class);
191+
when(modelLoadingService.getModelId(anyString())).thenAnswer(i -> i.getArgument(0));
192+
193+
return modelLoadingService;
189194
}
190195

191196
@SuppressWarnings("unchecked")

x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/ml/learning_to_rank_rescorer.yml

-325
This file was deleted.

0 commit comments

Comments
 (0)