Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui, worker): add invocation progress events to model loading #7286

Merged
merged 13 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions invokeai/app/invocations/compel.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, CLIPTextModel)
assert isinstance(tokenizer, CLIPTokenizer)
compel = Compel(
Expand Down Expand Up @@ -191,6 +192,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
ti_manager,
),
):
context.util.signal_progress("Building conditioning")
assert isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(tokenizer, CLIPTokenizer)

Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/create_denoise_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def invoke(self, context: InvocationContext) -> DenoiseMaskOutput:
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
# TODO:
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(vae_info, self.fp32, self.tiled, masked_image.clone())

masked_latents_name = context.tensors.save(tensor=masked_latents)
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/create_gradient_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def invoke(self, context: InvocationContext) -> GradientMaskOutput:
image_tensor = image_tensor.unsqueeze(0)
img_mask = tv_resize(mask, image_tensor.shape[-2:], T.InterpolationMode.BILINEAR, antialias=False)
masked_image = image_tensor * torch.where(img_mask < 0.5, 0.0, 1.0)
context.util.signal_progress("Running VAE encoder")
masked_latents = ImageToLatentsInvocation.vae_encode(
vae_info, self.fp32, self.tiled, masked_image.clone()
)
Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/invocations/flux_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def _t5_encode(self, context: InvocationContext) -> torch.Tensor:

t5_encoder = HFEncoder(t5_text_encoder, t5_tokenizer, False, self.t5_max_seq_len)

context.util.signal_progress("Running T5 encoder")
prompt_embeds = t5_encoder(prompt)

assert isinstance(prompt_embeds, torch.Tensor)
Expand Down Expand Up @@ -111,6 +112,7 @@ def _clip_encode(self, context: InvocationContext) -> torch.Tensor:

clip_encoder = HFEncoder(clip_text_encoder, clip_tokenizer, True, 77)

context.util.signal_progress("Running CLIP encoder")
pooled_prompt_embeds = clip_encoder(prompt)

assert isinstance(pooled_prompt_embeds, torch.Tensor)
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/flux_vae_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _vae_decode(self, vae_info: LoadedModel, latents: torch.Tensor) -> Image.Ima
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)
vae_info = context.models.load(self.vae.vae)
context.util.signal_progress("Running VAE")
image = self._vae_decode(vae_info=vae_info, latents=latents)

TorchDevice.empty_cache()
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/flux_vae_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

context.util.signal_progress("Running VAE")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/image_to_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

context.util.signal_progress("Running VAE encoder")
latents = self.vae_encode(
vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
)
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/latents_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE decoder")
assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
latents = latents.to(vae.device)
if self.fp32:
Expand Down
1 change: 1 addition & 0 deletions invokeai/app/invocations/sd3_latents_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
vae_info = context.models.load(self.vae.vae)
assert isinstance(vae_info.model, (AutoencoderKL))
with SeamlessExt.static_patch_model(vae_info.model, self.vae.seamless_axes), vae_info as vae:
context.util.signal_progress("Running VAE")
assert isinstance(vae, (AutoencoderKL))
latents = latents.to(vae.device)

Expand Down
2 changes: 2 additions & 0 deletions invokeai/app/invocations/sd3_text_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _t5_encode(self, context: InvocationContext, max_seq_len: int) -> torch.Tens
t5_text_encoder_info as t5_text_encoder,
t5_tokenizer_info as t5_tokenizer,
):
context.util.signal_progress("Running T5 encoder")
assert isinstance(t5_text_encoder, T5EncoderModel)
assert isinstance(t5_tokenizer, (T5Tokenizer, T5TokenizerFast))

Expand Down Expand Up @@ -137,6 +138,7 @@ def _clip_encode(
clip_tokenizer_info as clip_tokenizer,
ExitStack() as exit_stack,
):
context.util.signal_progress("Running CLIP encoder")
assert isinstance(clip_text_encoder, (CLIPTextModel, CLIPTextModelWithProjection))
assert isinstance(clip_tokenizer, CLIPTokenizer)

Expand Down
33 changes: 28 additions & 5 deletions invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ def error(self, message: str) -> None:


class ImagesInterface(InvocationContextInterface):
def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util

def save(
self,
image: Image,
Expand All @@ -186,6 +190,8 @@ def save(
The saved image DTO.
"""

self._util.signal_progress("Saving image")

# If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`, falling back to None.
metadata_ = None
if metadata:
Expand Down Expand Up @@ -336,6 +342,10 @@ def load(self, name: str) -> ConditioningFieldData:
class ModelsInterface(InvocationContextInterface):
"""Common API for loading, downloading and managing models."""

def __init__(self, services: InvocationServices, data: InvocationContextData, util: "UtilInterface") -> None:
super().__init__(services, data)
self._util = util

def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
"""Check if a model exists.

Expand Down Expand Up @@ -368,11 +378,15 @@ def load(

if isinstance(identifier, str):
model = self._services.model_manager.store.get_model(identifier)
return self._services.model_manager.load.load_model(model, submodel_type)
else:
_submodel_type = submodel_type or identifier.submodel_type
submodel_type = submodel_type or identifier.submodel_type
model = self._services.model_manager.store.get_model(identifier.key)
return self._services.model_manager.load.load_model(model, _submodel_type)

message = f"Loading model {model.name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(model, submodel_type)

def load_by_attrs(
self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
Expand All @@ -397,6 +411,10 @@ def load_by_attrs(
if len(configs) > 1:
raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

message = f"Loading model {name}"
if submodel_type:
message += f" ({submodel_type.value})"
self._util.signal_progress(message)
return self._services.model_manager.load.load_model(configs[0], submodel_type)

def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
Expand Down Expand Up @@ -467,6 +485,7 @@ def download_and_cache_model(
Returns:
Path to the downloaded model
"""
self._util.signal_progress(f"Downloading model {source}")
return self._services.model_manager.install.download_and_cache_model(source=source)

def load_local_model(
Expand All @@ -489,6 +508,8 @@ def load_local_model(
Returns:
A LoadedModelWithoutConfig object.
"""

self._util.signal_progress(f"Loading model {model_path.name}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)

def load_remote_model(
Expand All @@ -514,6 +535,8 @@ def load_remote_model(
A LoadedModelWithoutConfig object.
"""
model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))

self._util.signal_progress(f"Loading model {source}")
return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)


Expand Down Expand Up @@ -707,12 +730,12 @@ def build_invocation_context(
"""

logger = LoggerInterface(services=services, data=data)
images = ImagesInterface(services=services, data=data)
tensors = TensorsInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data)
config = ConfigInterface(services=services, data=data)
util = UtilInterface(services=services, data=data, is_canceled=is_canceled)
conditioning = ConditioningInterface(services=services, data=data)
models = ModelsInterface(services=services, data=data, util=util)
images = ImagesInterface(services=services, data=data, util=util)
boards = BoardsInterface(services=services, data=data)

ctx = InvocationContext(
Expand Down
4 changes: 3 additions & 1 deletion invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,8 @@
"placeholderSelectAModel": "Select a model",
"reset": "Reset",
"none": "None",
"new": "New"
"new": "New",
"generating": "Generating"
},
"hrf": {
"hrf": "High Resolution Fix",
Expand Down Expand Up @@ -1139,6 +1140,7 @@
"resetWebUI": "Reset Web UI",
"resetWebUIDesc1": "Resetting the web UI only resets the browser's local cache of your images and remembered settings. It does not delete any images from disk.",
"resetWebUIDesc2": "If images aren't showing up in the gallery or something else isn't working, please try resetting before submitting an issue on GitHub.",
"showDetailedInvocationProgress": "Show Progress Details",
"showProgressInViewer": "Show Progress Images in Viewer",
"ui": "User Interface",
"clearIntermediatesDisabled": "Queue must be empty to clear intermediates",
Expand Down
3 changes: 2 additions & 1 deletion invokeai/frontend/web/src/app/types/invokeai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ export type AppFeature =
| 'invocationCache'
| 'bulkDownload'
| 'starterModels'
| 'hfToken';
| 'hfToken'
| 'invocationProgressAlert';

/**
* A disable-able Stable Diffusion feature
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import { Alert, AlertDescription, AlertIcon, AlertTitle } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { useAppSelector } from 'app/store/storeHooks';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { selectSystemShouldShowInvocationProgressDetail } from 'features/system/store/systemSlice';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { $invocationProgressMessage } from 'services/events/stores';

const CanvasAlertsInvocationProgressContent = memo(() => {
const { t } = useTranslation();
const invocationProgressMessage = useStore($invocationProgressMessage);

if (!invocationProgressMessage) {
return null;
}

return (
<Alert status="loading" borderRadius="base" fontSize="sm" shadow="md" w="fit-content">
<AlertIcon />
<AlertTitle>{t('common.generating')}</AlertTitle>
<AlertDescription>{invocationProgressMessage}</AlertDescription>
</Alert>
);
});
CanvasAlertsInvocationProgressContent.displayName = 'CanvasAlertsInvocationProgressContent';

export const CanvasAlertsInvocationProgress = memo(() => {
const isProgressMessageAlertEnabled = useFeatureStatus('invocationProgressAlert');
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);

// The alert is disabled at the system level
if (!isProgressMessageAlertEnabled) {
return null;
}

// The alert is disabled at the user level
if (!shouldShowInvocationProgressDetail) {
return null;
}

return <CanvasAlertsInvocationProgressContent />;
});

CanvasAlertsInvocationProgress.displayName = 'CanvasAlertsInvocationProgress';
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import { GatedImageViewer } from 'features/gallery/components/ImageViewer/ImageV
import { memo, useCallback, useRef } from 'react';
import { PiDotsThreeOutlineVerticalFill } from 'react-icons/pi';

import { CanvasAlertsInvocationProgress } from './CanvasAlerts/CanvasAlertsInvocationProgress';

const MenuContent = () => {
return (
<CanvasManagerProviderGate>
Expand Down Expand Up @@ -84,6 +86,7 @@ export const CanvasMainPanelContent = memo(() => {
<CanvasAlertsSelectedEntityStatus />
<CanvasAlertsPreserveMask />
<CanvasAlertsSendingToGallery />
<CanvasAlertsInvocationProgress />
</Flex>
<Flex position="absolute" top={1} insetInlineEnd={1}>
<Menu>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { Box, Flex } from '@invoke-ai/ui-library';
import { useStore } from '@nanostores/react';
import { skipToken } from '@reduxjs/toolkit/query';
import { useAppSelector } from 'app/store/storeHooks';
import { CanvasAlertsInvocationProgress } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsInvocationProgress';
import { CanvasAlertsSendingToCanvas } from 'features/controlLayers/components/CanvasAlerts/CanvasAlertsSendingTo';
import { DndImage } from 'features/dnd/DndImage';
import ImageMetadataViewer from 'features/gallery/components/ImageMetadataViewer/ImageMetadataViewer';
Expand Down Expand Up @@ -48,9 +49,18 @@ const CurrentImagePreview = () => {
position="relative"
>
<ImageContent imageDTO={imageDTO} />
<Box position="absolute" top={0} insetInlineStart={0}>
<Flex
flexDir="column"
gap={2}
position="absolute"
top={0}
insetInlineStart={0}
pointerEvents="none"
alignItems="flex-start"
>
<CanvasAlertsSendingToCanvas />
</Box>
<CanvasAlertsInvocationProgress />
</Flex>
{shouldShowImageDetails && imageDTO && (
<Box position="absolute" opacity={0.8} top={0} width="full" height="full" borderRadius="base">
<ImageMetadataViewer image={imageDTO} />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ export const ImageViewer = memo(({ closeButton }: Props) => {
right={0}
bottom={0}
left={0}
rowGap={2}
alignItems="center"
justifyContent="center"
>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,20 @@ import { SettingsDeveloperLogLevel } from 'features/system/components/SettingsMo
import { SettingsDeveloperLogNamespaces } from 'features/system/components/SettingsModal/SettingsDeveloperLogNamespaces';
import { useClearIntermediates } from 'features/system/components/SettingsModal/useClearIntermediates';
import { StickyScrollable } from 'features/system/components/StickyScrollable';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import {
selectSystemShouldAntialiasProgressImage,
selectSystemShouldConfirmOnDelete,
selectSystemShouldConfirmOnNewSession,
selectSystemShouldEnableInformationalPopovers,
selectSystemShouldEnableModelDescriptions,
selectSystemShouldShowInvocationProgressDetail,
selectSystemShouldUseNSFWChecker,
selectSystemShouldUseWatermarker,
setShouldConfirmOnDelete,
setShouldEnableInformationalPopovers,
setShouldEnableModelDescriptions,
setShouldShowInvocationProgressDetail,
shouldAntialiasProgressImageChanged,
shouldConfirmOnNewSessionToggled,
shouldUseNSFWCheckerChanged,
Expand Down Expand Up @@ -103,6 +106,8 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
const shouldEnableInformationalPopovers = useAppSelector(selectSystemShouldEnableInformationalPopovers);
const shouldEnableModelDescriptions = useAppSelector(selectSystemShouldEnableModelDescriptions);
const shouldConfirmOnNewSession = useAppSelector(selectSystemShouldConfirmOnNewSession);
const shouldShowInvocationProgressDetail = useAppSelector(selectSystemShouldShowInvocationProgressDetail);
const isInvocationProgressAlertEnabled = useFeatureStatus('invocationProgressAlert');
const onToggleConfirmOnNewSession = useCallback(() => {
dispatch(shouldConfirmOnNewSessionToggled());
}, [dispatch]);
Expand Down Expand Up @@ -170,6 +175,13 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
[dispatch]
);

const handleChangeShouldShowInvocationProgressDetail = useCallback(
(e: ChangeEvent<HTMLInputElement>) => {
dispatch(setShouldShowInvocationProgressDetail(e.target.checked));
},
[dispatch]
);

return (
<>
{cloneElement(children, {
Expand Down Expand Up @@ -221,6 +233,15 @@ const SettingsModal = ({ config = defaultConfig, children }: SettingsModalProps)
onChange={handleChangeShouldAntialiasProgressImage}
/>
</FormControl>
{isInvocationProgressAlertEnabled && (
<FormControl>
<FormLabel>{t('settings.showDetailedInvocationProgress')}</FormLabel>
<Switch
isChecked={shouldShowInvocationProgressDetail}
onChange={handleChangeShouldShowInvocationProgressDetail}
/>
</FormControl>
)}
<FormControl>
<InformationalPopover feature="noiseUseCPU" inPortal={false}>
<FormLabel>{t('parameters.useCpuNoise')}</FormLabel>
Expand Down
Loading
Loading