Skip to content

Commit 1f38698

Browse files
zekeSBrandeishanouticelinaWauplin
authored
Run Flux LoRAs on Replicate (#1395)
This PR updates the `ReplicateTextToImageTask` method to support running arbitrary LoRAs on Replicate. How it works: 1. Detect LoRAs by their tags before inference. 2. Use https://replicate.com/black-forest-labs/flux-dev-lora for inference 3. Use the requested LoRA as the `weights_url` input to the flux-dev-lora model ## Companion PR - replicate/huggingface-model-mappings#9 ## Testing Remove the `skip` from `describe.skip("InferenceClient" ...` in InferenceClient.spec.ts Then: ``` $ pnpx vitest run --config vitest.config.mts -t "Replicate" ``` --------- Co-authored-by: Simon Brandeis <[email protected]> Co-authored-by: célina <[email protected]> Co-authored-by: Lucain <[email protected]>
1 parent 23d556f commit 1f38698

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

packages/inference/src/providers/replicate.ts

+15-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import { isUrl } from "../lib/isUrl";
1919
import type { BodyParams, HeaderParams, UrlParams } from "../types";
2020
import { omit } from "../utils/omit";
2121
import { TaskProviderHelper, type TextToImageTaskHelper, type TextToVideoTaskHelper } from "./providerHelper";
22-
2322
export interface ReplicateOutput {
2423
output?: string | string[];
2524
}
@@ -63,6 +62,21 @@ abstract class ReplicateTask extends TaskProviderHelper {
6362
}
6463

6564
export class ReplicateTextToImageTask extends ReplicateTask implements TextToImageTaskHelper {
65+
override preparePayload(params: BodyParams): Record<string, unknown> {
66+
return {
67+
input: {
68+
...omit(params.args, ["inputs", "parameters"]),
69+
...(params.args.parameters as Record<string, unknown>),
70+
prompt: params.args.inputs,
71+
lora_weights:
72+
params.mapping?.adapter === "lora" && params.mapping.adapterWeightsPath
73+
? `https://huggingface.co/${params.mapping.hfModelId}`
74+
: undefined,
75+
},
76+
version: params.model.includes(":") ? params.model.split(":")[1] : undefined,
77+
};
78+
}
79+
6680
override async getResponse(
6781
res: ReplicateOutput | Blob,
6882
url?: string,

packages/inference/test/InferenceClient.spec.ts

+11
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,17 @@ describe.skip("InferenceClient", () => {
11601160
expect(res).toBeInstanceOf(Blob);
11611161
});
11621162

1163+
// Runs black-forest-labs/flux-dev-lora under the hood
1164+
// with fofr/flux-80s-cyberpunk as the LoRA weights
1165+
it("textToImage - all Flux LoRAs", async () => {
1166+
const res = await client.textToImage({
1167+
model: "fofr/flux-80s-cyberpunk",
1168+
provider: "replicate",
1169+
inputs: "style of 80s cyberpunk, a portrait photo",
1170+
});
1171+
expect(res).toBeInstanceOf(Blob);
1172+
});
1173+
11631174
it("textToImage canonical - stabilityai/stable-diffusion-3.5-large-turbo", async () => {
11641175
const res = await client.textToImage({
11651176
model: "stabilityai/stable-diffusion-3.5-large-turbo",

0 commit comments

Comments
 (0)