Skip to content

feat: add hugging face text to image integration #1162

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion models/spring-ai-huggingface/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
<version>3.0.46</version>
<executions>
<execution>
<id>generate-chat-api</id>
<goals>
<goal>generate</goal>
</goals>
Expand All @@ -101,7 +102,7 @@
<language>java</language>
<library>resttemplate</library>
<apiPackage>org.springframework.ai.huggingface.api</apiPackage>
<modelPackage>org.springframework.ai.huggingface.model</modelPackage>
<modelPackage>org.springframework.ai.huggingface.model.chat</modelPackage>
<invokerPackage>org.springframework.ai.huggingface.invoker</invokerPackage>
<generateApiTests>false</generateApiTests>
<generateModelTests>false</generateModelTests>
Expand All @@ -113,6 +114,30 @@
</configOptions>
</configuration>
</execution>
<!--
Second code-generation run: generates the client for the text-to-image endpoint
from openapi-imagegen.json. Model classes go to their own ".model.imagegen"
package, while the api and invoker packages are the same ones configured for
the "generate-chat-api" execution.
-->
<execution>
<id>generate-imagegen-api</id>
<goals>
<goal>generate</goal>
</goals>
<configuration>
<inputSpec>${project.basedir}/src/main/resources/openapi-imagegen.json</inputSpec>
<language>java</language>
<library>resttemplate</library>
<apiPackage>org.springframework.ai.huggingface.api</apiPackage>
<modelPackage>org.springframework.ai.huggingface.model.imagegen</modelPackage>
<invokerPackage>org.springframework.ai.huggingface.invoker</invokerPackage>
<generateApiTests>false</generateApiTests>
<generateModelTests>false</generateModelTests>
<!-- Use a custom codegen template to avoid Accept-header selection in the generated inference API. -->
<templateDirectory>src/main/resources/swagger-codegen/templates/Java</templateDirectory>
<configOptions>
<sourceFolder>src/main/java</sourceFolder>
<dateLibrary>java8</dateLibrary>
<!-- Jackson-related generator switch; NOTE(review): presumably affects null handling in generated models - confirm against the generator docs. -->
<notNullJacksonAnnotation>true</notNullJacksonAnnotation>
</configOptions>
</configuration>
</execution>
</executions>
</plugin>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@

package org.springframework.ai.huggingface;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

Expand All @@ -32,10 +28,16 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.huggingface.api.TextGenerationInferenceApi;
import org.springframework.ai.huggingface.invoker.ApiClient;
import org.springframework.ai.huggingface.model.AllOfGenerateResponseDetails;
import org.springframework.ai.huggingface.model.CompatGenerateRequest;
import org.springframework.ai.huggingface.model.GenerateParameters;
import org.springframework.ai.huggingface.model.GenerateResponse;

import org.springframework.ai.huggingface.model.chat.*;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.ai.huggingface.model.chat.AllOfGenerateResponseDetails;
import org.springframework.ai.huggingface.model.chat.GenerateParameters;
import org.springframework.ai.huggingface.model.chat.GenerateRequest;
import org.springframework.ai.huggingface.model.chat.GenerateResponse;

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like your code style settings don't align with those of the project. None of these lines should change.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not related to code style. I changed the build path in the OpenAPI codegen configuration to improve the package structure of the generated code. Is this an undesired change?
However, I had been using my IDE's default code style; I have now switched to the settings described here: https://github.com/spring-projects/spring-framework/wiki/IntelliJ-IDEA-Editor-Settings

/**
* An implementation of {@link ChatModel} that interfaces with HuggingFace Inference
Expand Down Expand Up @@ -112,6 +114,10 @@ public ChatResponse call(Prompt prompt) {
return new ChatResponse(generations);
}

/**
 * Retrieves the model information exposed by the underlying text generation API
 * client.
 * @return The {@link Info} object returned by the API client.
 */
public Info info() {
	final Info modelInfo = this.textGenApi.getModelInfo();
	return modelInfo;
}

/**
* Gets the maximum number of new tokens to be generated.
* @return The maximum number of new tokens.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package org.springframework.ai.huggingface;

import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.List;

import org.springframework.ai.huggingface.api.ImageGenerationInferenceApi;
import org.springframework.ai.huggingface.invoker.ApiClient;
import org.springframework.ai.huggingface.model.imagegen.GenerateParameters;
import org.springframework.ai.huggingface.model.imagegen.GenerateRequest;
import org.springframework.ai.huggingface.text2image.HuggingfaceImageOptions;
import org.springframework.ai.image.*;

/**
 * An implementation of {@link ImageModel} that interfaces with HuggingFace Inference
 * Endpoints for text-to-image generation.
 *
 * <p>
 * The image is always handed back to the caller as a base64 string. The prompt's
 * response format selects how it is requested from the endpoint: {@code "base64"}
 * asks for JSON and uses the response body as-is, {@code "bytes"} asks for the MIME
 * type configured in {@link HuggingfaceImageOptions} and base64-encodes the raw bytes.
 *
 * @author Denis Lobo
 */
public class HuggingfaceImageModel implements ImageModel {

	private static final String APPLICATION_JSON = "application/json";

	private static final String RESPONSE_FORMAT_BASE64 = "base64";

	private static final String RESPONSE_FORMAT_BYTES = "bytes";

	/**
	 * Token required for authenticating with the HuggingFace Inference API.
	 */
	private final String apiToken;

	/**
	 * Client for making API calls.
	 */
	private final ApiClient apiClient = new ApiClient();

	private final ImageGenerationInferenceApi imageGenApi = new ImageGenerationInferenceApi();

	/**
	 * Constructs a new HuggingfaceImageModel with the specified API token and base path.
	 * @param apiToken The API token for HuggingFace.
	 * @param basePath The base path for API requests.
	 */
	public HuggingfaceImageModel(final String apiToken, String basePath) {
		this.apiToken = apiToken;
		this.apiClient.setBasePath(basePath);
		this.apiClient.addDefaultHeader("Authorization", "Bearer " + this.apiToken);
		this.imageGenApi.setApiClient(this.apiClient);
	}

	/**
	 * Generates an image for the given prompt.
	 * @param prompt The prompt carrying the instructions and the image options.
	 * @return A response holding a single generation whose image carries the base64
	 * payload and no URL.
	 * @throws UnsupportedOperationException if the response format is missing or is
	 * neither {@code "base64"} nor {@code "bytes"}.
	 * @throws IllegalArgumentException if the {@code "bytes"} format is requested
	 * without {@link HuggingfaceImageOptions}.
	 */
	@Override
	public ImageResponse call(ImagePrompt prompt) {
		final GenerateParameters generateParameters = createGenerateParameters(prompt.getOptions());
		final GenerateRequest generateRequest = createGenerateRequest(prompt.getInstructions(), generateParameters);

		// hf text-to-image endpoints return only a single image in default mode
		final String base64Encoded = generateImage(generateRequest, prompt);
		final Image image = new Image(null, base64Encoded);
		final ImageGeneration imageGeneration = new ImageGeneration(image);
		return new ImageResponse(List.of(imageGeneration), new ImageResponseMetadata());
	}

	/**
	 * Calls the inference endpoint and returns the generated image as base64 text.
	 * @param generateRequest The request to send.
	 * @param prompt The prompt whose options select the response format.
	 * @return The base64 representation of the generated image.
	 */
	private String generateImage(GenerateRequest generateRequest, ImagePrompt prompt) {
		final ImageOptions options = prompt.getOptions();
		final String responseFormat = (options != null) ? options.getResponseFormat() : null;

		if (RESPONSE_FORMAT_BASE64.equals(responseFormat)) {
			// The body already is base64 text; decode it with a fixed charset instead
			// of the platform default.
			final byte[] body = this.imageGenApi.generate(generateRequest, APPLICATION_JSON);
			return new String(body, StandardCharsets.UTF_8);
		}

		if (RESPONSE_FORMAT_BYTES.equals(responseFormat)) {
			// Only this branch needs the HuggingFace-specific MIME type, so the type
			// check lives here and does not break "base64" requests with plain options.
			if (!(options instanceof HuggingfaceImageOptions hfImageOptions)) {
				throw new IllegalArgumentException(
						"Response format 'bytes' requires HuggingfaceImageOptions to resolve the response mime type");
			}
			final byte[] bytes = this.imageGenApi.generate(generateRequest, hfImageOptions.getResponseMimeType());
			return Base64.getEncoder().encodeToString(bytes);
		}

		// Also reached for a missing (null) format, which used to fail with a bare
		// NullPointerException inside the switch.
		throw new UnsupportedOperationException(String
			.format("Unsupported response format: %s, should be 'base64' or 'bytes'", responseFormat));
	}

	/**
	 * Builds the request from the text of every prompt message.
	 * @param promptInstructs The prompt messages.
	 * @param generateParameters The generation parameters to attach.
	 * @return The populated request.
	 */
	private GenerateRequest createGenerateRequest(List<ImageMessage> promptInstructs,
			GenerateParameters generateParameters) {
		final GenerateRequest request = new GenerateRequest();
		final List<String> instructions = promptInstructs.stream().map(ImageMessage::getText).toList();

		request.setParameters(generateParameters);
		request.setInputs(instructions);
		return request;
	}

	/**
	 * Maps the prompt options onto the request parameters.
	 * @param options The image options, may be {@code null}.
	 * @return The parameters; left unpopulated when no options are given.
	 */
	private GenerateParameters createGenerateParameters(ImageOptions options) {
		final GenerateParameters params = new GenerateParameters();
		if (options == null) {
			return params;
		}

		params.setWidth(options.getWidth());
		params.setHeight(options.getHeight());
		params.setNumImagesPerPrompt(options.getN());

		if (options instanceof HuggingfaceImageOptions hfImageOptions) {
			params.setClipSkip(hfImageOptions.getClipSkip());
			params.setGuidanceScale(hfImageOptions.getGuidanceScale());
			params.setNumInferenceSteps(hfImageOptions.getNumInferenceSteps());
			// List.of rejects null elements, so an unset negative prompt is left out.
			if (hfImageOptions.getNegativePrompt() != null) {
				params.setNegativePrompt(List.of(hfImageOptions.getNegativePrompt()));
			}
		}
		return params;
	}

}
Loading