From a39930e8c6cc7795d469ae0932d93d2a2633d6a0 Mon Sep 17 00:00:00 2001
From: Tom Wise <79859203+tomwisecodes@users.noreply.github.com>
Date: Thu, 5 Dec 2024 13:53:49 +0000
Subject: [PATCH] chore: regeneration
---
.../src/app/image-spike/[slug]/images.tsx | 155 ++++++++++++------
.../src/components/RegenerationForm.tsx | 96 +++++++++++
apps/nextjs/src/hooks/useImageSearch.ts | 67 ++++++++
packages/api/src/router/imageGen.ts | 135 ++++++++++++++-
4 files changed, 406 insertions(+), 47 deletions(-)
create mode 100644 apps/nextjs/src/components/RegenerationForm.tsx
diff --git a/apps/nextjs/src/app/image-spike/[slug]/images.tsx b/apps/nextjs/src/app/image-spike/[slug]/images.tsx
index a70b78c63..dfa3eaaf0 100644
--- a/apps/nextjs/src/app/image-spike/[slug]/images.tsx
+++ b/apps/nextjs/src/app/image-spike/[slug]/images.tsx
@@ -12,6 +12,7 @@ import Link from "next/link";
import type { ImageResponse } from "types/imageTypes";
import LoadingWheel from "@/components/LoadingWheel";
+import { RegenerationForm } from "@/components/RegenerationForm";
export type Cycle = {
title: string;
@@ -70,8 +71,51 @@ const ImagesPage = ({ pageData }) => {
},
]);
+ // Tracks whether an image regeneration request is currently in flight
+ const [isRegenerating, setIsRegenerating] = useState(false);
+
const { bestImage, findBestImage } = useBestImage({ pageData });
- const { fetchImages, availableSources } = useImageSearch({ pageData });
+ const { fetchImages, availableSources, regenerateImageWithAnalysis } =
+ useImageSearch({ pageData });
+
+  /**
+   * Regenerates the first image of a comparison column using user feedback.
+   *
+   * Looks the column up by id, asks `regenerateImageWithAnalysis` for a new
+   * image and, on success, replaces the column's `imageSearchBatch` with that
+   * single image. A failed regeneration is logged and recorded on the column
+   * as an error message; it is not rethrown.
+   *
+   * @param columnId - id of the comparison column whose image is regenerated
+   * @param feedback - free-text feedback describing what should change
+   * @throws Error when the column's image has no `imageSource`
+   */
+  const handleRegeneration = async (columnId: string, feedback: string) => {
+    const column = comparisonColumns.find((col) => col.id === columnId);
+    const image = column?.imageSearchBatch?.[0];
+    if (!image) return;
+
+    // Validate before flagging the column as loading: this throw sits outside
+    // the try/finally below, so nothing would reset the flag afterwards.
+    if (!image.imageSource) {
+      throw new Error("Image source is required to regenerate image.");
+    }
+
+    // Compare lower case against lower case. The previous check lower-cased
+    // the source and then searched it for the mixed-case "Stability AI", which
+    // can never match, so every regeneration was routed to OpenAI.
+    // "Stable Diffusion" is the label the render code checks for when showing
+    // the regeneration form; "Stability AI" is the label regenerated images get.
+    const source = image.imageSource.toLowerCase();
+    const provider =
+      source.includes("stability") || source.includes("stable diffusion")
+        ? "stability"
+        : "openai";
+
+    updateColumn(columnId, { isLoading: true });
+    setIsRegenerating(true);
+    try {
+      const regeneratedImage = await regenerateImageWithAnalysis(
+        image.url,
+        selectedImagePrompt,
+        feedback,
+        provider,
+      );
+      updateColumn(columnId, {
+        imageSearchBatch: [regeneratedImage],
+      });
+    } catch (error) {
+      console.error("Error regenerating image:", error);
+      updateColumn(columnId, {
+        error: "Failed to regenerate image",
+      });
+    } finally {
+      // Reset both flags on every exit path.
+      setIsRegenerating(false);
+      updateColumn(columnId, { isLoading: false });
+    }
+  };
if (
!pageData?.lessonPlan?.cycle1?.explanation?.imagePrompt ||
@@ -196,8 +240,8 @@ const ImagesPage = ({ pageData }) => {
{promptConstructor(
prompt,
pageData.title,
- pageData.keyStage,
pageData.subject,
+ pageData.keyStage,
pageData.lessonPlan,
)}
@@ -348,58 +392,77 @@ const ImagesPage = ({ pageData }) => {
{column.imageSearchBatch && (
- {column.imageSearchBatch?.map((image) => (
-
-
-
-
-
-
- License:{" "}
- {image.license}
-
-
- Score:{" "}
- {image.appropriatenessScore}
-
-
- Prompt used:{" "}
- {image.imagePrompt}
-
-
- Reasoning:{" "}
- {image.appropriatenessReasoning}
-
- {image.photographer && (
+ {column.imageSearchBatch?.map((image) => {
+ console.log("iumage", image);
+ return (
+
+
+ {isRegenerating ? (
+
+ ) : (
+
+ )}
+
+ {(image.imageSource === "DAL-E" ||
+ image.imageSource?.includes(
+ "Stable Diffusion",
+ )) && (
+
+ handleRegeneration(column.id, feedback)
+ }
+ imageSource={image.imageSource}
+ />
+ )}
+
- By:{" "}
- {image.photographer}
+ License:{" "}
+ {image.license}
- )}
- {image.title && (
- Title:{" "}
- {image.title}
+ Score:{" "}
+ {image.appropriatenessScore}
+
+
+ Prompt used:{" "}
+ {image.imagePrompt}
- )}
-
-
Fetch: {image.timing.fetch.toFixed(2)}ms
- Validation: {image.timing.validation.toFixed(2)}ms
+ Reasoning:{" "}
+ {image.appropriatenessReasoning}
-
Total: {image.timing.total.toFixed(2)}ms
+ {image.photographer && (
+
+ By:{" "}
+ {image.photographer}
+
+ )}
+ {image.title && (
+
+ Title:{" "}
+ {image.title}
+
+ )}
+
+
Fetch: {image.timing.fetch.toFixed(2)}ms
+
+ Validation: {image.timing.validation.toFixed(2)}
+ ms
+
+
Total: {image.timing.total.toFixed(2)}ms
+
-
- ))}
+ );
+ })}
)}
diff --git a/apps/nextjs/src/components/RegenerationForm.tsx b/apps/nextjs/src/components/RegenerationForm.tsx
new file mode 100644
index 000000000..e3be99d28
--- /dev/null
+++ b/apps/nextjs/src/components/RegenerationForm.tsx
@@ -0,0 +1,96 @@
+import React, { useState } from "react";
+
+import { OakIcon } from "@oaknational/oak-components";
+import * as Dialog from "@radix-ui/react-dialog";
+
+import LoadingWheel from "@/components/LoadingWheel";
+
+import { Icon } from "./Icon";
+
+/** Props accepted by the `RegenerationForm` component. */
+interface RegenerationFormProps {
+  /**
+   * Called with the entered feedback when the form is submitted. The form
+   * awaits the returned promise before closing.
+   */
+  onSubmit: (feedback: string) => Promise<void>;
+  /** Source label of the image this form regenerates. */
+  imageSource: string;
+}
+
+/**
+ * Feedback form used to request a regenerated version of an image.
+ *
+ * Keeps local state for whether the form is open, the feedback text and
+ * whether a submission is in progress. On submit the feedback is passed to
+ * `onSubmit`; once that resolves the form closes and the feedback is cleared.
+ * If `onSubmit` rejects, the error is logged and only the loading flag is
+ * reset.
+ *
+ * @param onSubmit - async handler that receives the feedback text
+ * @param imageSource - source label of the image being regenerated
+ */
+export const RegenerationForm: React.FC<RegenerationFormProps> = ({
+  onSubmit,
+  imageSource,
+}) => {
+  const [isOpen, setIsOpen] = useState(false);
+  const [feedback, setFeedback] = useState("");
+  const [isLoading, setIsLoading] = useState(false);
+
+  // Submits the feedback and resets the form once the parent handler resolves.
+  const handleSubmit = async (e: React.FormEvent) => {
+    e.preventDefault();
+    setIsLoading(true);
+    try {
+      await onSubmit(feedback);
+      setIsOpen(false);
+      setFeedback("");
+    } catch (error) {
+      console.error("Error regenerating image:", error);
+    } finally {
+      setIsLoading(false);
+    }
+  };
+
+ return (
+
+
+
+
+
+
+
+
+
+ Regenerate Image
+
+
+
+
+
+
+
+
+
+
+ );
+};
diff --git a/apps/nextjs/src/hooks/useImageSearch.ts b/apps/nextjs/src/hooks/useImageSearch.ts
index b414ccff3..31ed3c239 100644
--- a/apps/nextjs/src/hooks/useImageSearch.ts
+++ b/apps/nextjs/src/hooks/useImageSearch.ts
@@ -17,6 +17,9 @@ export const useImageSearch = ({ pageData }: ImageSearchHookProps) => {
const flickrMutation = trpc.imageSearch.getImagesFromFlickr.useMutation();
const unsplashMutation = trpc.imageSearch.getImagesFromUnsplash.useMutation();
const stableDiffusionCoreMutation = trpc.imageGen.stableDifCore.useMutation();
+ const analyzeAndRegenerateMutation =
+ trpc.imageGen.analyzeAndRegenerate.useMutation();
+
const daleMutation = trpc.imageGen.openAi.useMutation();
const validateImageMutation = trpc.imageGen.validateImage.useMutation();
const validateImagesInParallel =
@@ -34,6 +37,69 @@ export const useImageSearch = ({ pageData }: ImageSearchHookProps) => {
Cloudinary: cloudinaryMutation,
};
+  /**
+   * Regenerates an image from user feedback and validates the result.
+   *
+   * Sends the original image URL, prompt and feedback to the
+   * `imageGen.analyzeAndRegenerate` mutation, then passes the returned image
+   * to `validateSingleImage` together with the original prompt.
+   *
+   * @param originalImageUrl - URL of the image being replaced
+   * @param originalPrompt - prompt the original image was generated from
+   * @param feedback - user feedback describing the desired changes
+   * @param provider - image generator to use for the new image
+   * @returns the regenerated image with validation metadata and timings in ms
+   * @throws Error when required page data is missing; mutation and validation
+   *   failures are logged and rethrown
+   */
+  // NOTE(review): the explicit return annotation had lost its type argument
+  // (`Promise` alone does not compile), so the return type is inferred here.
+  // Restore the intended image type if one is declared for this hook.
+  const regenerateImageWithAnalysis = async (
+    originalImageUrl: string,
+    originalPrompt: string,
+    feedback: string,
+    provider: "openai" | "stability",
+  ) => {
+    try {
+      if (
+        !pageData?.title ||
+        !pageData?.keyStage ||
+        !pageData?.subject ||
+        !pageData?.lessonPlan
+      ) {
+        throw new Error("Missing required page data");
+      }
+
+      const startTime = performance.now();
+
+      const response = await analyzeAndRegenerateMutation.mutateAsync({
+        originalImageUrl,
+        originalPrompt,
+        feedback,
+        lessonTitle: pageData.title,
+        subject: pageData.subject,
+        keyStage: pageData.keyStage,
+        lessonPlan: pageData.lessonPlan,
+        provider,
+      });
+
+      const fetchEnd = performance.now();
+
+      // Validate the regenerated image against the original prompt.
+      const validationStart = performance.now();
+      const validationResult = await validateSingleImage(
+        response,
+        originalPrompt,
+        true,
+      );
+      const endTime = performance.now();
+
+      // The same label is used for both the license and the image source.
+      const providerLabel =
+        provider === "openai" ? "OpenAI DALL-E 3" : "Stability AI";
+
+      return {
+        // Short random base-36 id; `slice(2, 11)` yields the same characters
+        // as the deprecated `substr(2, 9)`.
+        id: Math.random().toString(36).slice(2, 11),
+        url: response,
+        license: providerLabel,
+        imagePrompt: validationResult.metadata.promptUsed ?? originalPrompt,
+        appropriatenessScore: validationResult.metadata.appropriatenessScore,
+        appropriatenessReasoning: validationResult.metadata.validationReasoning,
+        imageSource: providerLabel,
+        timing: {
+          total: endTime - startTime,
+          fetch: fetchEnd - startTime,
+          validation: endTime - validationStart,
+        },
+      };
+    } catch (error) {
+      console.error("Error regenerating image:", error);
+      throw error;
+    }
+  };
+
const validateSingleImage = async (
imageUrl: string,
prompt: string,
@@ -303,5 +369,6 @@ export const useImageSearch = ({ pageData }: ImageSearchHookProps) => {
return {
fetchImages,
availableSources: trpcMutations,
+ regenerateImageWithAnalysis,
};
};
diff --git a/packages/api/src/router/imageGen.ts b/packages/api/src/router/imageGen.ts
index b4c58e5a5..677c53d7d 100644
--- a/packages/api/src/router/imageGen.ts
+++ b/packages/api/src/router/imageGen.ts
@@ -20,6 +20,18 @@ const STABLE_DIF_API_KEY = process.env.STABLE_DIF_API_KEY;
const MAX_PARALLEL_CHECKS = 5;
// Types
+
+/**
+ * Inputs needed to refine a previously generated image from user feedback.
+ *
+ * NOTE(review): the fields mirror the `analyzeAndRegenerate` input schema;
+ * confirm whether the two are meant to be kept in sync.
+ */
+interface RefinementInput {
+  /** Image generator used to produce the new image. */
+  provider: "openai" | "stability";
+
+  /** URL of the image being refined. */
+  originalImageUrl: string;
+  /** Prompt the original image was generated from. */
+  originalPrompt: string;
+  /** User feedback on the original image. */
+  feedback: string;
+
+  /** Lesson context. */
+  lessonTitle: string;
+  subject: string;
+  keyStage: string;
+  lessonPlan: LessonPlan;
+}
+
interface ValidatedImage {
id: string;
url: string;
@@ -428,7 +440,7 @@ function findTheRelevantCycle({
throw new Error("Cycle not found");
}
}
-// Router Definition
+
export const imageGen = router({
customPipelineWithReasoning: protectedProcedure
.input(
@@ -889,4 +901,125 @@ export const imageGen = router({
throw error;
}
}),
+  /**
+   * Regenerates an image based on user feedback about a previous version.
+   *
+   * Steps:
+   * 1. Find the lesson cycle matching the original prompt
+   *    (`findTheRelevantCycle` throws when none matches).
+   * 2. Send the original image and the feedback to `gpt-4o` and take its
+   *    reply as a refined prompt.
+   * 3. Build the final prompt from `promptConstructor` plus the refined prompt.
+   * 4. Generate the new image with DALL-E 3 (`openai`) or the Stability
+   *    image-to-image endpoint (`stability`).
+   *
+   * Returns the image URL for `openai`, or a base64 PNG data URL for
+   * `stability`. Errors are logged and rethrown.
+   */
+  analyzeAndRegenerate: protectedProcedure
+    .input(
+      z.object({
+        originalImageUrl: z.string(),
+        originalPrompt: z.string(),
+        feedback: z.string(),
+        lessonTitle: z.string(),
+        subject: z.string(),
+        keyStage: z.string(),
+        lessonPlan: imageLessonPlan,
+        provider: z.enum(["openai", "stability"]),
+      }),
+    )
+    .mutation(async ({ input }) => {
+      try {
+        const cycleInfo = findTheRelevantCycle({
+          lessonPlan: input.lessonPlan,
+          searchExpression: input.originalPrompt,
+        });
+
+        const openai = new OpenAI({
+          apiKey: process.env.OPENAI_API_KEY,
+          baseURL: process.env.HELICONE_EU_HOST,
+        });
+
+        // Ask gpt-4o to look at the original image and the feedback and to
+        // write a refined prompt.
+        const analysisResponse = await openai.chat.completions.create({
+          model: "gpt-4o",
+          messages: [
+            {
+              role: "user",
+              content: [
+                {
+                  type: "text",
+                  text: `Analyze this image that was generated for an educational context.
+ Original prompt: ${input.originalPrompt}
+ User feedback: ${input.feedback}
+
+ Based on the image and feedback, provide a detailed prompt that would generate an improved version.
+ Focus on specific visual elements that need to change.`,
+                },
+                {
+                  type: "image_url",
+                  image_url: {
+                    url: input.originalImageUrl,
+                  },
+                },
+              ],
+            },
+          ],
+          max_tokens: 500,
+        });
+
+        const refinedPrompt = analysisResponse.choices[0]?.message?.content;
+        if (!refinedPrompt) {
+          throw new Error("Failed to generate refined prompt");
+        }
+
+        console.log("refinedPrompt", refinedPrompt);
+
+        // Construct the final prompt
+        const finalPrompt = `${promptConstructor(
+          input.originalPrompt,
+          input.lessonTitle,
+          input.subject,
+          input.keyStage,
+          cycleInfo,
+        )}\n\nRefined requirements based on previous version: ${refinedPrompt}`;
+
+        if (input.provider === "openai") {
+          // Generate new image with DALL-E
+          const response = await openai.images.generate({
+            model: "dall-e-3",
+            prompt: finalPrompt,
+            n: 1,
+            size: "1024x1024",
+          });
+
+          if (!response?.data?.[0]?.url) {
+            throw new Error("Failed to generate image with feedback");
+          }
+
+          return response.data[0].url;
+        } else if (input.provider === "stability") {
+          // Fail with a clear message instead of sending "Bearer undefined".
+          if (!STABLE_DIF_API_KEY) {
+            throw new Error("STABLE_DIF_API_KEY is not configured");
+          }
+
+          // NOTE(review): `originalImageUrl` is caller-supplied and fetched
+          // from the server; consider restricting it to trusted hosts (SSRF).
+          const initImageResponse = await fetch(input.originalImageUrl);
+          if (!initImageResponse.ok) {
+            throw new Error(
+              `Failed to fetch original image: ${initImageResponse.statusText}`,
+            );
+          }
+
+          // For Stability AI, use their image-to-image endpoint
+          const formData = new FormData();
+          formData.append("init_image", await initImageResponse.blob());
+          // This endpoint takes prompts as an indexed `text_prompts` array,
+          // not as a single `prompt` field.
+          formData.append("text_prompts[0][text]", finalPrompt);
+          formData.append("image_strength", "0.35"); // How much to preserve from original
+
+          const response = await fetch(
+            "https://api.stability.ai/v1/generation/stable-diffusion-xl-1024-v1-0/image-to-image",
+            {
+              method: "POST",
+              headers: {
+                Authorization: `Bearer ${STABLE_DIF_API_KEY}`,
+                // Request raw PNG bytes; without this the endpoint answers
+                // with JSON, which must not be wrapped as a PNG data URL.
+                Accept: "image/png",
+              },
+              body: formData,
+            },
+          );
+
+          if (!response.ok) {
+            throw new Error(`Stability API error: ${response.statusText}`);
+          }
+
+          const imageBuffer = await response.arrayBuffer();
+          return `data:image/png;base64,${Buffer.from(imageBuffer).toString("base64")}`;
+        }
+
+        throw new Error("Invalid provider specified");
+      } catch (error) {
+        console.error("[AnalyzeAndRegenerate] Error:", error);
+        throw error;
+      }
+    }),
});