Skip to content

Commit 89ffb0c

Browse files
author
lobo
committed
consolidate implementation with recent changes in upstream
1 parent fee40bd commit 89ffb0c

File tree

3 files changed

+46
-12
lines changed

3 files changed

+46
-12
lines changed

models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceChatModel.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
import org.springframework.ai.chat.prompt.Prompt;
3232
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
3333
import org.springframework.ai.huggingface.invoker.ApiClient;
34-
import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
35-
import org.springframework.ai.huggingface.model.GenerateParameters;
36-
import org.springframework.ai.huggingface.model.GenerateRequest;
37-
import org.springframework.ai.huggingface.model.GenerateResponse;
34+
import org.springframework.ai.huggingface.model.chat.AllOfGenerateResponseDetails;
35+
import org.springframework.ai.huggingface.model.chat.GenerateParameters;
36+
import org.springframework.ai.huggingface.model.chat.GenerateRequest;
37+
import org.springframework.ai.huggingface.model.chat.GenerateResponse;
3838

3939
/**
4040
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference

models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/HuggingfaceImageModel.java

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
*/
1919
public class HuggingfaceImageModel implements ImageModel {
2020

21+
private final String APPLICATION_JSON = "application/json";
22+
2123
/**
2224
* Token required for authenticating with the HuggingFace Inference API.
2325
*/
@@ -47,24 +49,28 @@ public ImageResponse call(ImagePrompt prompt) {
4749
final GenerateParameters generateParameters = createGenerateParameters(prompt.getOptions());
4850
final GenerateRequest generateRequest = createGenerateRequest(prompt.getInstructions(), generateParameters);
4951

50-
// huggingface eps with text-to-image models return only single image in default
51-
// mode
52+
// hf text-to-image endpoints return only a single image in default mode
5253
final String base64Encoded = generateImage(generateRequest, prompt);
5354
final Image image = new Image(null, base64Encoded);
5455
final ImageGeneration imageGeneration = new ImageGeneration(image);
5556
return new ImageResponse(List.of(imageGeneration), new ImageResponseMetadata());
5657
}
5758

5859
private String generateImage(GenerateRequest generateRequest, ImagePrompt prompt) {
59-
final String mimeType = prompt.getOptions().getResponseFormat();
60-
switch (mimeType) {
61-
case "application/json" -> {
62-
return new String(this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat()));
60+
final String responseFormat = prompt.getOptions().getResponseFormat();
61+
final HuggingfaceImageOptions options = (HuggingfaceImageOptions) prompt.getOptions();
62+
switch (responseFormat) {
63+
case "base64" -> {
64+
return new String(this.imageGenApi.generate(generateRequest, APPLICATION_JSON));
6365
}
64-
default -> {
65-
byte[] bytes = this.imageGenApi.generate(generateRequest, prompt.getOptions().getResponseFormat());
66+
case "bytes" -> {
67+
byte[] bytes = this.imageGenApi.generate(generateRequest, options.getResponseMimeType());
6668
return Base64.getEncoder().encodeToString(bytes);
6769
}
70+
default -> {
71+
throw new UnsupportedOperationException(String
72+
.format("Unsupported response format: %s, should be 'base64' or 'bytes'", responseFormat));
73+
}
6874
}
6975
}
7076

models/spring-ai-huggingface/src/main/java/org/springframework/ai/huggingface/text2image/HuggingfaceImageOptions.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,19 @@ public class HuggingfaceImageOptions implements ImageOptions {
1212

1313
private Integer height;
1414

15+
/**
16+
* should be one of 'base64' or 'bytes'
17+
*/
1518
private String responseFormat;
1619

20+
private String style;
21+
22+
/**
23+
* considered only if responseFormat = 'bytes' should be one of 'image/png',
24+
* 'image/jpg', 'image/tiff' etc.
25+
*/
26+
private String responseMimeType;
27+
1728
private String negativePrompt;
1829

1930
private Float sigmaItems;
@@ -71,6 +82,23 @@ public void setResponseFormat(String responseFormat) {
7182
this.responseFormat = responseFormat;
7283
}
7384

85+
@Override
86+
public String getStyle() {
87+
return style;
88+
}
89+
90+
public void setStyle(String style) {
91+
this.style = style;
92+
}
93+
94+
public String getResponseMimeType() {
95+
return responseMimeType;
96+
}
97+
98+
public void setResponseMimeType(String responseMimeType) {
99+
this.responseMimeType = responseMimeType;
100+
}
101+
74102
public String getNegativePrompt() {
75103
return negativePrompt;
76104
}

0 commit comments

Comments
 (0)