Skip to content

Commit 3e569c8

Browse files
Mary Hippmaryhipp
Mary Hipp
authored andcommitted
feat(ui): add fields for CLIP embed models and Flux VAE models in workflows
1 parent 16825ee commit 3e569c8

File tree

10 files changed

+268
-5
lines changed

10 files changed

+268
-5
lines changed

invokeai/frontend/web/src/features/nodes/components/flow/nodes/Invocation/fields/InputFieldRenderer.tsx

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import {
66
isBoardFieldInputTemplate,
77
isBooleanFieldInputInstance,
88
isBooleanFieldInputTemplate,
9+
isCLIPEmbedModelFieldInputInstance,
10+
isCLIPEmbedModelFieldInputTemplate,
911
isColorFieldInputInstance,
1012
isColorFieldInputTemplate,
1113
isControlNetModelFieldInputInstance,
@@ -16,6 +18,8 @@ import {
1618
isFloatFieldInputTemplate,
1719
isFluxMainModelFieldInputInstance,
1820
isFluxMainModelFieldInputTemplate,
21+
isFluxVAEModelFieldInputInstance,
22+
isFluxVAEModelFieldInputTemplate,
1923
isImageFieldInputInstance,
2024
isImageFieldInputTemplate,
2125
isIntegerFieldInputInstance,
@@ -49,10 +53,12 @@ import { memo } from 'react';
4953

5054
import BoardFieldInputComponent from './inputs/BoardFieldInputComponent';
5155
import BooleanFieldInputComponent from './inputs/BooleanFieldInputComponent';
56+
import CLIPEmbedModelFieldInputComponent from './inputs/CLIPEmbedModelFieldInputComponent';
5257
import ColorFieldInputComponent from './inputs/ColorFieldInputComponent';
5358
import ControlNetModelFieldInputComponent from './inputs/ControlNetModelFieldInputComponent';
5459
import EnumFieldInputComponent from './inputs/EnumFieldInputComponent';
5560
import FluxMainModelFieldInputComponent from './inputs/FluxMainModelFieldInputComponent';
61+
import FluxVAEModelFieldInputComponent from './inputs/FluxVAEModelFieldInputComponent';
5662
import ImageFieldInputComponent from './inputs/ImageFieldInputComponent';
5763
import IPAdapterModelFieldInputComponent from './inputs/IPAdapterModelFieldInputComponent';
5864
import LoRAModelFieldInputComponent from './inputs/LoRAModelFieldInputComponent';
@@ -122,6 +128,13 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
122128
if (isT5EncoderModelFieldInputInstance(fieldInstance) && isT5EncoderModelFieldInputTemplate(fieldTemplate)) {
123129
return <T5EncoderModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
124130
}
131+
if (isCLIPEmbedModelFieldInputInstance(fieldInstance) && isCLIPEmbedModelFieldInputTemplate(fieldTemplate)) {
132+
return <CLIPEmbedModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
133+
}
134+
135+
if (isFluxVAEModelFieldInputInstance(fieldInstance) && isFluxVAEModelFieldInputTemplate(fieldTemplate)) {
136+
return <FluxVAEModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
137+
}
125138

126139
if (isLoRAModelFieldInputInstance(fieldInstance) && isLoRAModelFieldInputTemplate(fieldTemplate)) {
127140
return <LoRAModelFieldInputComponent nodeId={nodeId} field={fieldInstance} fieldTemplate={fieldTemplate} />;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
2+
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
3+
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
4+
import { fieldCLIPEmbedValueChanged } from 'features/nodes/store/nodesSlice';
5+
import type { CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate } from 'features/nodes/types/field';
6+
import { memo, useCallback } from 'react';
7+
import { useTranslation } from 'react-i18next';
8+
import { useClipEmbedModels } from 'services/api/hooks/modelsByType';
9+
import type { ClipEmbedModelConfig } from 'services/api/types';
10+
11+
import type { FieldComponentProps } from './types';
12+
13+
type Props = FieldComponentProps<CLIPEmbedModelFieldInputInstance, CLIPEmbedModelFieldInputTemplate>;
14+
15+
const CLIPEmbedModelFieldInputComponent = (props: Props) => {
16+
const { nodeId, field } = props;
17+
const { t } = useTranslation();
18+
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
19+
const dispatch = useAppDispatch();
20+
const [modelConfigs, { isLoading }] = useClipEmbedModels();
21+
const _onChange = useCallback(
22+
(value: ClipEmbedModelConfig | null) => {
23+
if (!value) {
24+
return;
25+
}
26+
dispatch(
27+
fieldCLIPEmbedValueChanged({
28+
nodeId,
29+
fieldName: field.name,
30+
value,
31+
})
32+
);
33+
},
34+
[dispatch, field.name, nodeId]
35+
);
36+
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
37+
modelConfigs,
38+
onChange: _onChange,
39+
isLoading,
40+
selectedModel: field.value,
41+
});
42+
43+
return (
44+
<Flex w="full" alignItems="center" gap={2}>
45+
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
46+
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
47+
<Combobox
48+
value={value}
49+
placeholder={placeholder}
50+
options={options}
51+
onChange={onChange}
52+
noOptionsMessage={noOptionsMessage}
53+
/>
54+
</FormControl>
55+
</Tooltip>
56+
</Flex>
57+
);
58+
};
59+
60+
export default memo(CLIPEmbedModelFieldInputComponent);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
2+
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
3+
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
4+
import { fieldFluxVAEModelValueChanged } from 'features/nodes/store/nodesSlice';
5+
import type { FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate } from 'features/nodes/types/field';
6+
import { memo, useCallback } from 'react';
7+
import { useTranslation } from 'react-i18next';
8+
import { useFluxVAEModels } from 'services/api/hooks/modelsByType';
9+
import type { VAEModelConfig } from 'services/api/types';
10+
11+
import type { FieldComponentProps } from './types';
12+
13+
type Props = FieldComponentProps<FluxVAEModelFieldInputInstance, FluxVAEModelFieldInputTemplate>;
14+
15+
const FluxVAEModelFieldInputComponent = (props: Props) => {
16+
const { nodeId, field } = props;
17+
const { t } = useTranslation();
18+
const disabledTabs = useAppSelector((s) => s.config.disabledTabs);
19+
const dispatch = useAppDispatch();
20+
const [modelConfigs, { isLoading }] = useFluxVAEModels();
21+
const _onChange = useCallback(
22+
(value: VAEModelConfig | null) => {
23+
if (!value) {
24+
return;
25+
}
26+
dispatch(
27+
fieldFluxVAEModelValueChanged({
28+
nodeId,
29+
fieldName: field.name,
30+
value,
31+
})
32+
);
33+
},
34+
[dispatch, field.name, nodeId]
35+
);
36+
const { options, value, onChange, placeholder, noOptionsMessage } = useGroupedModelCombobox({
37+
modelConfigs,
38+
onChange: _onChange,
39+
isLoading,
40+
selectedModel: field.value,
41+
});
42+
43+
return (
44+
<Flex w="full" alignItems="center" gap={2}>
45+
<Tooltip label={!disabledTabs.includes('models') && t('modelManager.starterModelsInModelManager')}>
46+
<FormControl className="nowheel nodrag" isDisabled={!options.length} isInvalid={!value}>
47+
<Combobox
48+
value={value}
49+
placeholder={placeholder}
50+
options={options}
51+
onChange={onChange}
52+
noOptionsMessage={noOptionsMessage}
53+
/>
54+
</FormControl>
55+
</Tooltip>
56+
</Flex>
57+
);
58+
};
59+
60+
export default memo(FluxVAEModelFieldInputComponent);

invokeai/frontend/web/src/features/nodes/store/nodesSlice.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ import { SHARED_NODE_PROPERTIES } from 'features/nodes/types/constants';
66
import type {
77
BoardFieldValue,
88
BooleanFieldValue,
9+
CLIPEmbedModelFieldValue,
910
ColorFieldValue,
1011
ControlNetModelFieldValue,
1112
EnumFieldValue,
1213
FieldValue,
1314
FloatFieldValue,
15+
FluxVAEModelFieldValue,
1416
ImageFieldValue,
1517
IntegerFieldValue,
1618
IPAdapterModelFieldValue,
@@ -29,10 +31,12 @@ import type {
2931
import {
3032
zBoardFieldValue,
3133
zBooleanFieldValue,
34+
zCLIPEmbedModelFieldValue,
3235
zColorFieldValue,
3336
zControlNetModelFieldValue,
3437
zEnumFieldValue,
3538
zFloatFieldValue,
39+
zFluxVAEModelFieldValue,
3640
zImageFieldValue,
3741
zIntegerFieldValue,
3842
zIPAdapterModelFieldValue,
@@ -346,6 +350,12 @@ export const nodesSlice = createSlice({
346350
fieldT5EncoderValueChanged: (state, action: FieldValueAction<T5EncoderModelFieldValue>) => {
347351
fieldValueReducer(state, action, zT5EncoderModelFieldValue);
348352
},
353+
fieldCLIPEmbedValueChanged: (state, action: FieldValueAction<CLIPEmbedModelFieldValue>) => {
354+
fieldValueReducer(state, action, zCLIPEmbedModelFieldValue);
355+
},
356+
fieldFluxVAEModelValueChanged: (state, action: FieldValueAction<FluxVAEModelFieldValue>) => {
357+
fieldValueReducer(state, action, zFluxVAEModelFieldValue);
358+
},
349359
fieldEnumModelValueChanged: (state, action: FieldValueAction<EnumFieldValue>) => {
350360
fieldValueReducer(state, action, zEnumFieldValue);
351361
},
@@ -408,6 +418,8 @@ export const {
408418
fieldStringValueChanged,
409419
fieldVaeModelValueChanged,
410420
fieldT5EncoderValueChanged,
421+
fieldCLIPEmbedValueChanged,
422+
fieldFluxVAEModelValueChanged,
411423
nodeEditorReset,
412424
nodeIsIntermediateChanged,
413425
nodeIsOpenChanged,
@@ -521,6 +533,8 @@ export const isAnyNodeOrEdgeMutation = isAnyOf(
521533
fieldStringValueChanged,
522534
fieldVaeModelValueChanged,
523535
fieldT5EncoderValueChanged,
536+
fieldCLIPEmbedValueChanged,
537+
fieldFluxVAEModelValueChanged,
524538
nodesChanged,
525539
nodeIsIntermediateChanged,
526540
nodeIsOpenChanged,

invokeai/frontend/web/src/features/nodes/types/field.ts

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,14 @@ const zT5EncoderModelFieldType = zFieldTypeBase.extend({
151151
name: z.literal('T5EncoderModelField'),
152152
originalType: zStatelessFieldType.optional(),
153153
});
154+
const zCLIPEmbedModelFieldType = zFieldTypeBase.extend({
155+
name: z.literal('CLIPEmbedModelField'),
156+
originalType: zStatelessFieldType.optional(),
157+
});
158+
const zFluxVAEModelFieldType = zFieldTypeBase.extend({
159+
name: z.literal('FluxVAEModelField'),
160+
originalType: zStatelessFieldType.optional(),
161+
});
154162
const zSchedulerFieldType = zFieldTypeBase.extend({
155163
name: z.literal('SchedulerField'),
156164
originalType: zStatelessFieldType.optional(),
@@ -175,6 +183,8 @@ const zStatefulFieldType = z.union([
175183
zT2IAdapterModelFieldType,
176184
zSpandrelImageToImageModelFieldType,
177185
zT5EncoderModelFieldType,
186+
zCLIPEmbedModelFieldType,
187+
zFluxVAEModelFieldType,
178188
zColorFieldType,
179189
zSchedulerFieldType,
180190
]);
@@ -667,7 +677,53 @@ export const isT5EncoderModelFieldInputInstance = (val: unknown): val is T5Encod
667677
export const isT5EncoderModelFieldInputTemplate = (val: unknown): val is T5EncoderModelFieldInputTemplate =>
668678
zT5EncoderModelFieldInputTemplate.safeParse(val).success;
669679

670-
// #endregio
680+
// #endregion
681+
682+
// #region FluxVAEModelField
683+
684+
export const zFluxVAEModelFieldValue = zModelIdentifierField.optional();
685+
const zFluxVAEModelFieldInputInstance = zFieldInputInstanceBase.extend({
686+
value: zFluxVAEModelFieldValue,
687+
});
688+
const zFluxVAEModelFieldInputTemplate = zFieldInputTemplateBase.extend({
689+
type: zFluxVAEModelFieldType,
690+
originalType: zFieldType.optional(),
691+
default: zFluxVAEModelFieldValue,
692+
});
693+
694+
export type FluxVAEModelFieldValue = z.infer<typeof zFluxVAEModelFieldValue>;
695+
696+
export type FluxVAEModelFieldInputInstance = z.infer<typeof zFluxVAEModelFieldInputInstance>;
697+
export type FluxVAEModelFieldInputTemplate = z.infer<typeof zFluxVAEModelFieldInputTemplate>;
698+
export const isFluxVAEModelFieldInputInstance = (val: unknown): val is FluxVAEModelFieldInputInstance =>
699+
zFluxVAEModelFieldInputInstance.safeParse(val).success;
700+
export const isFluxVAEModelFieldInputTemplate = (val: unknown): val is FluxVAEModelFieldInputTemplate =>
701+
zFluxVAEModelFieldInputTemplate.safeParse(val).success;
702+
703+
// #endregion
704+
705+
// #region CLIPEmbedModelField
706+
707+
export const zCLIPEmbedModelFieldValue = zModelIdentifierField.optional();
708+
const zCLIPEmbedModelFieldInputInstance = zFieldInputInstanceBase.extend({
709+
value: zCLIPEmbedModelFieldValue,
710+
});
711+
const zCLIPEmbedModelFieldInputTemplate = zFieldInputTemplateBase.extend({
712+
type: zCLIPEmbedModelFieldType,
713+
originalType: zFieldType.optional(),
714+
default: zCLIPEmbedModelFieldValue,
715+
});
716+
717+
export type CLIPEmbedModelFieldValue = z.infer<typeof zCLIPEmbedModelFieldValue>;
718+
719+
export type CLIPEmbedModelFieldInputInstance = z.infer<typeof zCLIPEmbedModelFieldInputInstance>;
720+
export type CLIPEmbedModelFieldInputTemplate = z.infer<typeof zCLIPEmbedModelFieldInputTemplate>;
721+
export const isCLIPEmbedModelFieldInputInstance = (val: unknown): val is CLIPEmbedModelFieldInputInstance =>
722+
zCLIPEmbedModelFieldInputInstance.safeParse(val).success;
723+
export const isCLIPEmbedModelFieldInputTemplate = (val: unknown): val is CLIPEmbedModelFieldInputTemplate =>
724+
zCLIPEmbedModelFieldInputTemplate.safeParse(val).success;
725+
726+
// #endregion
671727

672728
// #region SchedulerField
673729

@@ -758,6 +814,8 @@ export const zStatefulFieldValue = z.union([
758814
zT2IAdapterModelFieldValue,
759815
zSpandrelImageToImageModelFieldValue,
760816
zT5EncoderModelFieldValue,
817+
zFluxVAEModelFieldValue,
818+
zCLIPEmbedModelFieldValue,
761819
zColorFieldValue,
762820
zSchedulerFieldValue,
763821
]);
@@ -788,6 +846,8 @@ const zStatefulFieldInputInstance = z.union([
788846
zT2IAdapterModelFieldInputInstance,
789847
zSpandrelImageToImageModelFieldInputInstance,
790848
zT5EncoderModelFieldInputInstance,
849+
zFluxVAEModelFieldInputInstance,
850+
zCLIPEmbedModelFieldInputInstance,
791851
zColorFieldInputInstance,
792852
zSchedulerFieldInputInstance,
793853
]);
@@ -819,6 +879,8 @@ const zStatefulFieldInputTemplate = z.union([
819879
zT2IAdapterModelFieldInputTemplate,
820880
zSpandrelImageToImageModelFieldInputTemplate,
821881
zT5EncoderModelFieldInputTemplate,
882+
zFluxVAEModelFieldInputTemplate,
883+
zCLIPEmbedModelFieldInputTemplate,
822884
zColorFieldInputTemplate,
823885
zSchedulerFieldInputTemplate,
824886
zStatelessFieldInputTemplate,

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputInstance.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ const FIELD_VALUE_FALLBACK_MAP: Record<StatefulFieldType['name'], FieldValue> =
2323
VAEModelField: undefined,
2424
ControlNetModelField: undefined,
2525
T5EncoderModelField: undefined,
26+
FluxVAEModelField: undefined,
27+
CLIPEmbedModelField: undefined,
2628
};
2729

2830
export const buildFieldInputInstance = (id: string, template: FieldInputTemplate): FieldInputInstance => {

invokeai/frontend/web/src/features/nodes/util/schema/buildFieldInputTemplate.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ import { FieldParseError } from 'features/nodes/types/error';
22
import type {
33
BoardFieldInputTemplate,
44
BooleanFieldInputTemplate,
5+
CLIPEmbedModelFieldInputTemplate,
56
ColorFieldInputTemplate,
67
ControlNetModelFieldInputTemplate,
78
EnumFieldInputTemplate,
89
FieldInputTemplate,
910
FieldType,
1011
FloatFieldInputTemplate,
1112
FluxMainModelFieldInputTemplate,
13+
FluxVAEModelFieldInputTemplate,
1214
ImageFieldInputTemplate,
1315
IntegerFieldInputTemplate,
1416
IPAdapterModelFieldInputTemplate,
@@ -238,6 +240,34 @@ const buildT5EncoderModelFieldInputTemplate: FieldInputTemplateBuilder<T5Encoder
238240
return template;
239241
};
240242

243+
const buildCLIPEmbedModelFieldInputTemplate: FieldInputTemplateBuilder<CLIPEmbedModelFieldInputTemplate> = ({
244+
schemaObject,
245+
baseField,
246+
fieldType,
247+
}) => {
248+
const template: CLIPEmbedModelFieldInputTemplate = {
249+
...baseField,
250+
type: fieldType,
251+
default: schemaObject.default ?? undefined,
252+
};
253+
254+
return template;
255+
};
256+
257+
const buildFluxVAEModelFieldInputTemplate: FieldInputTemplateBuilder<FluxVAEModelFieldInputTemplate> = ({
258+
schemaObject,
259+
baseField,
260+
fieldType,
261+
}) => {
262+
const template: FluxVAEModelFieldInputTemplate = {
263+
...baseField,
264+
type: fieldType,
265+
default: schemaObject.default ?? undefined,
266+
};
267+
268+
return template;
269+
};
270+
241271
const buildLoRAModelFieldInputTemplate: FieldInputTemplateBuilder<LoRAModelFieldInputTemplate> = ({
242272
schemaObject,
243273
baseField,
@@ -423,6 +453,8 @@ export const TEMPLATE_BUILDER_MAP: Record<StatefulFieldType['name'], FieldInputT
423453
SpandrelImageToImageModelField: buildSpandrelImageToImageModelFieldInputTemplate,
424454
VAEModelField: buildVAEModelFieldInputTemplate,
425455
T5EncoderModelField: buildT5EncoderModelFieldInputTemplate,
456+
CLIPEmbedModelField: buildCLIPEmbedModelFieldInputTemplate,
457+
FluxVAEModelField: buildFluxVAEModelFieldInputTemplate,
426458
} as const;
427459

428460
export const buildFieldInputTemplate = (

0 commit comments

Comments
 (0)