Skip to content

Commit

Permalink
Merge pull request #394 from dorianjanezic/main
Browse files Browse the repository at this point in the history
feat: video generation plugin
  • Loading branch information
sirkitree authored Nov 19, 2024
2 parents 87a367e + 591b971 commit af4e1fd
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 0 deletions.
125 changes: 125 additions & 0 deletions packages/core/src/tests/videoGeneration.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import { IAgentRuntime, Memory, State } from "@ai16z/eliza";
import { videoGenerationPlugin } from "../index";

// Mock the fetch function
global.fetch = jest.fn();

// Mock the fs module
jest.mock('fs', () => ({
writeFileSync: jest.fn(),
existsSync: jest.fn(),
mkdirSync: jest.fn(),
}));

describe('Video Generation Plugin', () => {
let mockRuntime: IAgentRuntime;
let mockCallback: jest.Mock;

beforeEach(() => {
// Reset mocks
jest.clearAllMocks();

// Setup mock runtime
mockRuntime = {
getSetting: jest.fn().mockReturnValue('mock-api-key'),
agentId: 'mock-agent-id',
composeState: jest.fn().mockResolvedValue({}),
} as unknown as IAgentRuntime;

mockCallback = jest.fn();

// Setup fetch mock for successful response
(global.fetch as jest.Mock).mockImplementation(() =>
Promise.resolve({
ok: true,
json: () => Promise.resolve({
id: 'mock-generation-id',
status: 'completed',
assets: {
video: 'https://example.com/video.mp4'
}
}),
text: () => Promise.resolve(''),
})
);
});

it('should validate when API key is present', async () => {
const mockMessage = {} as Memory;
const result = await videoGenerationPlugin.actions[0].validate(mockRuntime, mockMessage);
expect(result).toBe(true);
expect(mockRuntime.getSetting).toHaveBeenCalledWith('LUMA_API_KEY');
});

it('should handle video generation request', async () => {
const mockMessage = {
content: {
text: 'Generate a video of a sunset'
}
} as Memory;
const mockState = {} as State;

await videoGenerationPlugin.actions[0].handler(
mockRuntime,
mockMessage,
mockState,
{},
mockCallback
);

// Check initial callback
expect(mockCallback).toHaveBeenCalledWith(
expect.objectContaining({
text: expect.stringContaining('I\'ll generate a video based on your prompt')
})
);

// Check final callback with video
expect(mockCallback).toHaveBeenCalledWith(
expect.objectContaining({
text: 'Here\'s your generated video!',
attachments: expect.arrayContaining([
expect.objectContaining({
source: 'videoGeneration'
})
])
}),
expect.arrayContaining([expect.stringMatching(/generated_video_.*\.mp4/)])
);
});

it('should handle API errors gracefully', async () => {
// Mock API error
(global.fetch as jest.Mock).mockImplementationOnce(() =>
Promise.resolve({
ok: false,
status: 500,
statusText: 'Internal Server Error',
text: () => Promise.resolve('API Error'),
})
);

const mockMessage = {
content: {
text: 'Generate a video of a sunset'
}
} as Memory;
const mockState = {} as State;

await videoGenerationPlugin.actions[0].handler(
mockRuntime,
mockMessage,
mockState,
{},
mockCallback
);

// Check error callback
expect(mockCallback).toHaveBeenCalledWith(
expect.objectContaining({
text: expect.stringContaining('Video generation failed'),
error: true
})
);
});
});
18 changes: 18 additions & 0 deletions packages/plugin-video-generation/package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
{
"name": "@ai16z/plugin-video-generation",
"version": "0.0.1",
"main": "dist/index.js",
"type": "module",
"types": "dist/index.d.ts",
"dependencies": {
"@ai16z/eliza": "workspace:*",
"tsup": "^8.3.5"
},
"scripts": {
"build": "tsup --format esm --dts",
"dev": "tsup --watch"
},
"peerDependencies": {
"whatwg-url": "7.1.0"
}
}
4 changes: 4 additions & 0 deletions packages/plugin-video-generation/src/constants.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export const LUMA_CONSTANTS = {
API_URL: 'https://api.lumalabs.ai/dream-machine/v1/generations',
API_KEY_SETTING: "LUMA_API_KEY" // The setting name to fetch from runtime
};
221 changes: 221 additions & 0 deletions packages/plugin-video-generation/src/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
import { elizaLogger } from "@ai16z/eliza/src/logger.ts";
import {
Action,
HandlerCallback,
IAgentRuntime,
Memory,
Plugin,
State,
} from "@ai16z/eliza/src/types.ts";
import fs from "fs";
import { LUMA_CONSTANTS } from './constants';

const generateVideo = async (prompt: string, runtime: IAgentRuntime) => {
const API_KEY = runtime.getSetting(LUMA_CONSTANTS.API_KEY_SETTING);

try {
elizaLogger.log("Starting video generation with prompt:", prompt);

const response = await fetch(LUMA_CONSTANTS.API_URL, {
method: 'POST',
headers: {
'Authorization': `Bearer ${API_KEY}`,
'accept': 'application/json',
'Content-Type': 'application/json'
},
body: JSON.stringify({ prompt })
});

if (!response.ok) {
const errorText = await response.text();
elizaLogger.error("Luma API error:", {
status: response.status,
statusText: response.statusText,
error: errorText
});
throw new Error(`Luma API error: ${response.statusText} - ${errorText}`);
}

const data = await response.json();
elizaLogger.log("Generation request successful, received response:", data);

// Poll for completion
let status = data.status;
let videoUrl = null;
const generationId = data.id;

while (status !== 'completed' && status !== 'failed') {
await new Promise(resolve => setTimeout(resolve, 5000)); // Wait 5 seconds

const statusResponse = await fetch(`${LUMA_CONSTANTS.API_URL}/${generationId}`, {
method: 'GET',
headers: {
'Authorization': `Bearer ${API_KEY}`,
'accept': 'application/json'
}
});

if (!statusResponse.ok) {
const errorText = await statusResponse.text();
elizaLogger.error("Status check error:", {
status: statusResponse.status,
statusText: statusResponse.statusText,
error: errorText
});
throw new Error('Failed to check generation status: ' + errorText);
}

const statusData = await statusResponse.json();
elizaLogger.log("Status check response:", statusData);

status = statusData.state;
if (status === 'completed') {
videoUrl = statusData.assets?.video;
}
}

if (status === 'failed') {
throw new Error('Video generation failed');
}

if (!videoUrl) {
throw new Error('No video URL in completed response');
}

return {
success: true,
data: videoUrl
};
} catch (error) {
elizaLogger.error("Video generation error:", error);
return {
success: false,
error: error.message || 'Unknown error occurred'
};
}
}

const videoGeneration: Action = {
name: "GENERATE_VIDEO",
similes: [
"VIDEO_GENERATION",
"VIDEO_GEN",
"CREATE_VIDEO",
"MAKE_VIDEO",
"RENDER_VIDEO",
"ANIMATE",
"CREATE_ANIMATION",
"VIDEO_CREATE",
"VIDEO_MAKE"
],
description: "Generate a video based on a text prompt",
validate: async (runtime: IAgentRuntime, message: Memory) => {
elizaLogger.log("Validating video generation action");
const lumaApiKey = runtime.getSetting("LUMA_API_KEY");
elizaLogger.log("LUMA_API_KEY present:", !!lumaApiKey);
return !!lumaApiKey;
},
handler: async (
runtime: IAgentRuntime,
message: Memory,
state: State,
options: any,
callback: HandlerCallback
) => {
elizaLogger.log("Video generation request:", message);

// Clean up the prompt by removing mentions and commands
let videoPrompt = message.content.text
.replace(/<@\d+>/g, '') // Remove mentions
.replace(/generate video|create video|make video|render video/gi, '') // Remove commands
.trim();

if (!videoPrompt || videoPrompt.length < 5) {
callback({
text: "Could you please provide more details about what kind of video you'd like me to generate? For example: 'Generate a video of a sunset on a beach' or 'Create a video of a futuristic city'",
});
return;
}

elizaLogger.log("Video prompt:", videoPrompt);

callback({
text: `I'll generate a video based on your prompt: "${videoPrompt}". This might take a few minutes...`,
});

try {
const result = await generateVideo(videoPrompt, runtime);

if (result.success && result.data) {
// Download the video file
const response = await fetch(result.data);
const arrayBuffer = await response.arrayBuffer();
const videoFileName = `content_cache/generated_video_${Date.now()}.mp4`;

// Save video file
fs.writeFileSync(videoFileName, Buffer.from(arrayBuffer));

callback({
text: "Here's your generated video!",
attachments: [
{
id: crypto.randomUUID(),
url: result.data,
title: "Generated Video",
source: "videoGeneration",
description: videoPrompt,
text: videoPrompt,
},
],
}, [videoFileName]); // Add the video file to the attachments
} else {
callback({
text: `Video generation failed: ${result.error}`,
error: true
});
}
} catch (error) {
elizaLogger.error(`Failed to generate video. Error: ${error}`);
callback({
text: `Video generation failed: ${error.message}`,
error: true
});
}
},
examples: [
[
{
user: "{{user1}}",
content: { text: "Generate a video of a cat playing piano" },
},
{
user: "{{agentName}}",
content: {
text: "I'll create a video of a cat playing piano for you",
action: "GENERATE_VIDEO"
},
}
],
[
{
user: "{{user1}}",
content: { text: "Can you make a video of a sunset at the beach?" },
},
{
user: "{{agentName}}",
content: {
text: "I'll generate a beautiful beach sunset video for you",
action: "GENERATE_VIDEO"
},
}
]
]
} as Action;

export const videoGenerationPlugin: Plugin = {
name: "videoGeneration",
description: "Generate videos using Luma AI",
actions: [videoGeneration],
evaluators: [],
providers: [],
};
15 changes: 15 additions & 0 deletions packages/plugin-video-generation/tsconfig.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"extends": "../../tsconfig.json",
"compilerOptions": {
"outDir": "dist",
"rootDir": ".",
"module": "ESNext",
"moduleResolution": "Bundler",
"types": [
"node"
]
},
"include": [
"src"
]
}
Loading

0 comments on commit af4e1fd

Please sign in to comment.