Skip to content

feat(ui): flux redux canvas #7752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 35 additions & 11 deletions invokeai/app/invocations/flux_redux.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField
from invokeai.app.services.model_records.model_records_base import ModelRecordChanges
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.redux.flux_redux_model import FluxReduxModel
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelType
from invokeai.backend.model_manager.starter_models import siglip
from invokeai.backend.sig_lip.sig_lip_pipeline import SigLipPipeline
from invokeai.backend.util.devices import TorchDevice

Expand All @@ -35,16 +38,12 @@ class FluxReduxOutput(BaseInvocationOutput):
)


SIGLIP_STARTER_MODEL_NAME = "SigLIP - google/siglip-so400m-patch14-384"
FLUX_REDUX_STARTER_MODEL_NAME = "FLUX Redux"


@invocation(
"flux_redux",
title="FLUX Redux",
tags=["ip_adapter", "control"],
category="ip_adapter",
version="1.0.0",
version="2.0.0",
classification=Classification.Prototype,
)
class FluxReduxInvocation(BaseInvocation):
Expand All @@ -61,11 +60,6 @@ class FluxReduxInvocation(BaseInvocation):
title="FLUX Redux Model",
ui_type=UIType.FluxReduxModel,
)
siglip_model: ModelIdentifierField = InputField(
description="The SigLIP model to use.",
title="SigLIP Model",
ui_type=UIType.SigLipModel,
)

def invoke(self, context: InvocationContext) -> FluxReduxOutput:
image = context.images.get_pil(self.image.image_name, "RGB")
Expand All @@ -80,7 +74,8 @@ def invoke(self, context: InvocationContext) -> FluxReduxOutput:

@torch.no_grad()
def _siglip_encode(self, context: InvocationContext, image: Image.Image) -> torch.Tensor:
with context.models.load(self.siglip_model).model_on_device() as (_, siglip_pipeline):
siglip_model_config = self._get_siglip_model(context)
with context.models.load(siglip_model_config.key).model_on_device() as (_, siglip_pipeline):
assert isinstance(siglip_pipeline, SigLipPipeline)
return siglip_pipeline.encode_image(
x=image, device=TorchDevice.choose_torch_device(), dtype=TorchDevice.choose_torch_dtype()
Expand All @@ -93,3 +88,32 @@ def _flux_redux_encode(self, context: InvocationContext, encoded_x: torch.Tensor
dtype = next(flux_redux.parameters()).dtype
encoded_x = encoded_x.to(dtype=dtype)
return flux_redux(encoded_x)

def _get_siglip_model(self, context: InvocationContext) -> AnyModelConfig:
siglip_models = context.models.search_by_attrs(name=siglip.name, base=BaseModelType.Any, type=ModelType.SigLIP)

if not len(siglip_models) > 0:
context.logger.warning(
f"The SigLIP model required by FLUX Redux ({siglip.name}) is not installed. Downloading and installing now. This may take a while."
)

# TODO(psyche): Can the probe reliably determine the type of the model? Just hardcoding it bc I don't want to experiment now
config_overrides = ModelRecordChanges(name=siglip.name, type=ModelType.SigLIP)

# Queue the job
job = context._services.model_manager.install.heuristic_import(siglip.source, config=config_overrides)

# Wait for up to 10 minutes - model is ~3.5GB
context._services.model_manager.install.wait_for_job(job, timeout=600)

siglip_models = context.models.search_by_attrs(
name=siglip.name,
base=BaseModelType.Any,
type=ModelType.SigLIP,
)

if len(siglip_models) == 0:
context.logger.error("Error while fetching SigLIP for FLUX Redux")
assert len(siglip_models) == 1

return siglip_models[0]
4 changes: 2 additions & 2 deletions invokeai/backend/model_manager/starter_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,7 @@ class StarterModelBundles(BaseModel):
source="black-forest-labs/FLUX.1-Redux-dev::flux1-redux-dev.safetensors",
description="FLUX Redux model (for image variation).",
type=ModelType.FluxRedux,
dependencies=[siglip],
)
# endregion

Expand Down Expand Up @@ -717,7 +718,6 @@ class StarterModelBundles(BaseModel):
scribble_sdxl,
tile_sdxl,
swinir,
flux_redux,
]

flux_bundle: list[StarterModel] = [
Expand All @@ -730,7 +730,7 @@ class StarterModelBundles(BaseModel):
ip_adapter_flux,
flux_canny_control_lora,
flux_depth_control_lora,
siglip,
flux_redux,
]

STARTER_BUNDLES: dict[str, list[StarterModel]] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import type { AnyModelConfig } from 'services/api/types';
import {
isCLIPEmbedModelConfig,
isControlLayerModelConfig,
isFluxReduxModelConfig,
isFluxVAEModelConfig,
isIPAdapterModelConfig,
isLoRAModelConfig,
Expand Down Expand Up @@ -77,6 +78,7 @@ export const addModelsLoadedListener = (startAppListening: AppStartListening) =>
handleT5EncoderModels(models, state, dispatch, log);
handleCLIPEmbedModels(models, state, dispatch, log);
handleFLUXVAEModels(models, state, dispatch, log);
handleFLUXReduxModels(models, state, dispatch, log);
},
});
};
Expand Down Expand Up @@ -209,6 +211,10 @@ const handleControlAdapterModels: ModelHandler = (models, state, dispatch, log)
const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
const ipaModels = models.filter(isIPAdapterModelConfig);
selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'ip_adapter') {
return;
}

const selectedIPAdapterModel = entity.ipAdapter.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
Expand All @@ -224,6 +230,10 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {

selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'ip_adapter') {
return;
}

const selectedIPAdapterModel = ipAdapter.model;
// `null` is a valid IP adapter model - no need to do anything.
if (!selectedIPAdapterModel) {
Expand All @@ -241,6 +251,49 @@ const handleIPAdapterModels: ModelHandler = (models, state, dispatch, log) => {
});
};

const handleFLUXReduxModels: ModelHandler = (models, state, dispatch, log) => {
const fluxReduxModels = models.filter(isFluxReduxModelConfig);

selectCanvasSlice(state).referenceImages.entities.forEach((entity) => {
if (entity.ipAdapter.type !== 'flux_redux') {
return;
}
const selectedFLUXReduxModel = entity.ipAdapter.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
}
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
if (isModelAvailable) {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(referenceImageIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), modelConfig: null }));
});

selectCanvasSlice(state).regionalGuidance.entities.forEach((entity) => {
entity.referenceImages.forEach(({ id: referenceImageId, ipAdapter }) => {
if (ipAdapter.type !== 'flux_redux') {
return;
}

const selectedFLUXReduxModel = ipAdapter.model;
// `null` is a valid FLUX Redux model - no need to do anything.
if (!selectedFLUXReduxModel) {
return;
}
const isModelAvailable = fluxReduxModels.some((m) => m.key === selectedFLUXReduxModel.key);
if (isModelAvailable) {
return;
}
log.debug({ selectedFLUXReduxModel }, 'Selected FLUX Redux model is not available, clearing');
dispatch(
rgIPAdapterModelChanged({ entityIdentifier: getEntityIdentifier(entity), referenceImageId, modelConfig: null })
);
});
});
};

const handlePostProcessingModel: ModelHandler = (models, state, dispatch, log) => {
const selectedPostProcessingModel = state.upscale.postProcessingModel;
const allSpandrelModels = models.filter(isSpandrelImageToImageModelConfig);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
useAddRegionalGuidance,
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
Expand All @@ -22,7 +22,6 @@ export const CanvasAddEntityButtons = memo(() => {
const addControlLayer = useAddControlLayer();
const addGlobalReferenceImage = useAddGlobalReferenceImage();
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);

return (
Expand Down Expand Up @@ -75,7 +74,7 @@ export const CanvasAddEntityButtons = memo(() => {
justifyContent="flex-start"
leftIcon={<PiPlusBold />}
onClick={addRegionalReferenceImage}
isDisabled={isFLUX || isSD3}
isDisabled={isSD3}
>
{t('controlLayers.regionalReferenceImage')}
</Button>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
useAddRegionalReferenceImage,
} from 'features/controlLayers/hooks/addLayerHooks';
import { useCanvasIsBusy } from 'features/controlLayers/hooks/useCanvasIsBusy';
import { selectIsFLUX, selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { selectIsSD3 } from 'features/controlLayers/store/paramsSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';
Expand All @@ -23,7 +23,6 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
const addRegionalReferenceImage = useAddRegionalReferenceImage();
const addRasterLayer = useAddRasterLayer();
const addControlLayer = useAddControlLayer();
const isFLUX = useAppSelector(selectIsFLUX);
const isSD3 = useAppSelector(selectIsSD3);

return (
Expand Down Expand Up @@ -52,7 +51,7 @@ export const EntityListGlobalActionBarAddLayerMenu = memo(() => {
<MenuItem icon={<PiPlusBold />} onClick={addRegionalGuidance} isDisabled={isSD3}>
{t('controlLayers.regionalGuidance')}
</MenuItem>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isFLUX || isSD3}>
<MenuItem icon={<PiPlusBold />} onClick={addRegionalReferenceImage} isDisabled={isSD3}>
{t('controlLayers.regionalReferenceImage')}
</MenuItem>
</MenuGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import type { ComboboxOnChange } from '@invoke-ai/ui-library';
import { Combobox, FormControl } from '@invoke-ai/ui-library';
import { useAppSelector } from 'app/store/storeHooks';
import { selectIsFLUX } from 'features/controlLayers/store/paramsSlice';
import type { CLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { isCLIPVisionModelV2 } from 'features/controlLayers/store/types';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { assert } from 'tsafe';

// at this time, ViT-L is the only supported clip model for FLUX IP adapter
const FLUX_CLIP_VISION = 'ViT-L';

const CLIP_VISION_OPTIONS = [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
{ label: FLUX_CLIP_VISION, value: FLUX_CLIP_VISION },
];

type Props = {
model: CLIPVisionModelV2;
onChange: (clipVisionModel: CLIPVisionModelV2) => void;
};

export const CLIPVisionModel = memo(({ model, onChange }: Props) => {
const { t } = useTranslation();

const _onChangeCLIPVisionModel = useCallback<ComboboxOnChange>(
(v) => {
assert(isCLIPVisionModelV2(v?.value));
onChange(v.value);
},
[onChange]
);

const isFLUX = useAppSelector(selectIsFLUX);

const clipVisionOptions = useMemo(() => {
return CLIP_VISION_OPTIONS.map((option) => ({
...option,
isDisabled: isFLUX && option.value !== FLUX_CLIP_VISION,
}));
}, [isFLUX]);

const clipVisionModelValue = useMemo(() => {
return CLIP_VISION_OPTIONS.find((o) => o.value === model);
}, [model]);

return (
<FormControl width="max-content" minWidth={28}>
<Combobox
options={clipVisionOptions}
placeholder={t('common.placeholderSelectAModel')}
value={clipVisionModelValue}
onChange={_onChangeCLIPVisionModel}
/>
</FormControl>
);
});

CLIPVisionModel.displayName = 'CLIPVisionModel';
Loading
Loading