Skip to content

Commit

Permalink
AI Featured Image: support backend prompts (#37668)
Browse files Browse the repository at this point in the history
* Export generic image generation method by parameters

* Change generation call to a generic one so the backend can decide which one to use

* Introduce activeModel cost and use it to infer the model name

* changelog
  • Loading branch information
lhkowalski authored Jun 3, 2024
1 parent b6ef3c3 commit 5383ad7
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Significance: patch
Type: changed

AI Featured Image: export generic image generation request function.
2 changes: 1 addition & 1 deletion projects/js-packages/ai-client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"private": false,
"name": "@automattic/jetpack-ai-client",
"version": "0.14.2",
"version": "0.14.3-alpha",
"description": "A JS client for consuming Jetpack AI services",
"homepage": "https://github.com/Automattic/jetpack/tree/HEAD/projects/js-packages/ai-client/#readme",
"bugs": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ const getStableDiffusionImageGenerationPrompt = async (
};

const useImageGenerator = () => {
const executeImageGeneration = async function ( parameters: {
[ key: string ]: string;
} ): Promise< ImageGenerationResponse > {
const executeImageGeneration = async function (
parameters: object
): Promise< ImageGenerationResponse > {
let token = '';

try {
Expand Down Expand Up @@ -259,6 +259,7 @@ const useImageGenerator = () => {
return {
generateImage,
generateImageWithStableDiffusion,
generateImageWithParameters: executeImageGeneration,
};
};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Significance: minor
Type: other

AI Featured Image: let the backend decide the model for the image generation.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import { Icon, external } from '@wordpress/icons';
import './style.scss';
import UpgradePrompt from '../../../../blocks/ai-assistant/components/upgrade-prompt';
import useAiFeature from '../../../../blocks/ai-assistant/hooks/use-ai-feature';
import { getFeatureAvailability } from '../../../../blocks/ai-assistant/lib/utils/get-feature-availability';
import { PLAN_TYPE_UNLIMITED, usePlanType } from '../../../../shared/use-plan-type';
import usePostContent from '../../hooks/use-post-content';
import useSaveToMediaLibrary from '../../hooks/use-save-to-media-library';
Expand All @@ -34,17 +33,8 @@ const FEATURED_IMAGE_UPGRADE_PROMPT_PLACEMENT = 'ai-image-generator';
const FEATURED_IMAGE_FEATURE_NAME = 'featured-post-image';
export const FEATURED_IMAGE_PLACEMENT_MEDIA_SOURCE_DROPDOWN = 'media-source-dropdown';

/**
* Control experimental image generation for the featured image.
*/
const AI_ASSISTANT_EXPERIMENTAL_IMAGE_GENERATION_SUPPORT =
'ai-assistant-experimental-image-generation-support';
const isAiAssistantExperimentalImageGenerationSupportEnabled = getFeatureAvailability(
AI_ASSISTANT_EXPERIMENTAL_IMAGE_GENERATION_SUPPORT
);
const IMAGE_GENERATION_MODEL = isAiAssistantExperimentalImageGenerationSupportEnabled
? 'stable-diffusion'
: 'dall-e-3';
const IMAGE_GENERATION_MODEL_STABLE_DIFFUSION = 'stable-diffusion';
const IMAGE_GENERATION_MODEL_DALL_E_3 = 'dall-e-3';
/**
* Determine the site type for tracking purposes.
*
Expand Down Expand Up @@ -86,7 +76,7 @@ export default function FeaturedImage( {
const triggeredAutoGeneration = useRef( false );

const { enableComplementaryArea } = useDispatch( 'core/interface' );
const { generateImage, generateImageWithStableDiffusion } = useImageGenerator();
const { generateImageWithParameters } = useImageGenerator();
const { saveToMediaLibrary } = useSaveToMediaLibrary();
const { tracks } = useAnalytics();
const { recordEvent } = tracks;
Expand All @@ -104,9 +94,11 @@ export default function FeaturedImage( {
costs,
} = useAiFeature();
const planType = usePlanType( currentTier );
const featuredImageCost = isAiAssistantExperimentalImageGenerationSupportEnabled
? costs?.[ FEATURED_IMAGE_FEATURE_NAME ]?.stableDiffusion ?? 1
: costs?.[ FEATURED_IMAGE_FEATURE_NAME ]?.image;
const featuredImageCost = costs?.[ FEATURED_IMAGE_FEATURE_NAME ]?.activeModel ?? 10;
const featuredImageActiveModel =
featuredImageCost === costs?.[ FEATURED_IMAGE_FEATURE_NAME ]?.stableDiffusion
? IMAGE_GENERATION_MODEL_STABLE_DIFFUSION
: IMAGE_GENERATION_MODEL_DALL_E_3;
const isUnlimited = planType === PLAN_TYPE_UNLIMITED;
const requestsBalance = requestsLimit - requestsCount;
const notEnoughRequests = requestsBalance < featuredImageCost;
Expand Down Expand Up @@ -150,12 +142,12 @@ export default function FeaturedImage( {
recordEvent( 'jetpack_ai_featured_image_generation_error', {
placement,
error: data.error?.message,
model: IMAGE_GENERATION_MODEL,
model: featuredImageActiveModel,
site_type: SITE_TYPE,
} );
}
},
[ placement, recordEvent ]
[ placement, recordEvent, featuredImageActiveModel ]
);

const handlePreviousImage = useCallback( () => {
Expand Down Expand Up @@ -203,19 +195,25 @@ export default function FeaturedImage( {
return;
}

/** Decide between standard or experimental generation */
const generateImagePromise = isAiAssistantExperimentalImageGenerationSupportEnabled
? generateImageWithStableDiffusion( {
feature: FEATURED_IMAGE_FEATURE_NAME,
postContent,
userPrompt,
} )
: generateImage( {
feature: FEATURED_IMAGE_FEATURE_NAME,
postContent,
responseFormat: 'b64_json',
userPrompt,
} );
/**
* Make a generic call to backend and let it decide the model.
*/
const generateImagePromise = generateImageWithParameters( {
feature: FEATURED_IMAGE_FEATURE_NAME,
size: '1792x1024', // the size, when the generation happens with DALL-E-3
responseFormat: 'b64_json', // the response format, when the generation happens with DALL-E-3
style: 'photographic', // the style of the image, when the generation happens with Stable Diffusion
messages: [
{
role: 'jetpack-ai',
context: {
type: 'featured-image-generation',
request: userPrompt ? userPrompt : null,
content: postContent,
},
},
],
} );

generateImagePromise
.then( result => {
Expand All @@ -240,8 +238,7 @@ export default function FeaturedImage( {
}, [
notEnoughRequests,
updateImages,
generateImage,
generateImageWithStableDiffusion,
generateImageWithParameters,
postContent,
userPrompt,
updateRequestsCount,
Expand All @@ -261,36 +258,42 @@ export default function FeaturedImage( {
// track the generate image event
recordEvent( 'jetpack_ai_featured_image_generation_generate_image', {
placement,
model: IMAGE_GENERATION_MODEL,
model: featuredImageActiveModel,
site_type: SITE_TYPE,
} );

toggleFeaturedImageModal();
processImageGeneration();
}, [ toggleFeaturedImageModal, processImageGeneration, recordEvent, placement ] );
}, [
toggleFeaturedImageModal,
processImageGeneration,
recordEvent,
placement,
featuredImageActiveModel,
] );

const handleRegenerate = useCallback( () => {
// track the regenerate image event
recordEvent( 'jetpack_ai_featured_image_generation_generate_another_image', {
placement,
model: IMAGE_GENERATION_MODEL,
model: featuredImageActiveModel,
site_type: SITE_TYPE,
} );

processImageGeneration();
setCurrent( crrt => crrt + 1 );
}, [ processImageGeneration, recordEvent, placement ] );
}, [ processImageGeneration, recordEvent, placement, featuredImageActiveModel ] );

const handleTryAgain = useCallback( () => {
// track the try again event
recordEvent( 'jetpack_ai_featured_image_generation_try_again', {
placement,
model: IMAGE_GENERATION_MODEL,
model: featuredImageActiveModel,
site_type: SITE_TYPE,
} );

processImageGeneration();
}, [ processImageGeneration, recordEvent, placement ] );
}, [ processImageGeneration, recordEvent, placement, featuredImageActiveModel ] );

const handleUserPromptChange = useCallback(
( e: React.ChangeEvent< HTMLTextAreaElement > ) => {
Expand All @@ -309,7 +312,7 @@ export default function FeaturedImage( {
// track the accept/use image event
recordEvent( 'jetpack_ai_featured_image_generation_use_image', {
placement,
model: IMAGE_GENERATION_MODEL,
model: featuredImageActiveModel,
site_type: SITE_TYPE,
} );

Expand Down Expand Up @@ -355,6 +358,7 @@ export default function FeaturedImage( {
triggerComplementaryArea,
handleModalClose,
placement,
featuredImageActiveModel,
] );

/**
Expand Down

0 comments on commit 5383ad7

Please sign in to comment.