-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #394 from dorianjanezic/main
feat: video generation plugin
- Loading branch information
Showing
6 changed files
with
402 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import { IAgentRuntime, Memory, State } from "@ai16z/eliza"; | ||
import { videoGenerationPlugin } from "../index"; | ||
|
||
// Mock the fetch function | ||
global.fetch = jest.fn(); | ||
|
||
// Mock the fs module | ||
jest.mock('fs', () => ({ | ||
writeFileSync: jest.fn(), | ||
existsSync: jest.fn(), | ||
mkdirSync: jest.fn(), | ||
})); | ||
|
||
describe('Video Generation Plugin', () => { | ||
let mockRuntime: IAgentRuntime; | ||
let mockCallback: jest.Mock; | ||
|
||
beforeEach(() => { | ||
// Reset mocks | ||
jest.clearAllMocks(); | ||
|
||
// Setup mock runtime | ||
mockRuntime = { | ||
getSetting: jest.fn().mockReturnValue('mock-api-key'), | ||
agentId: 'mock-agent-id', | ||
composeState: jest.fn().mockResolvedValue({}), | ||
} as unknown as IAgentRuntime; | ||
|
||
mockCallback = jest.fn(); | ||
|
||
// Setup fetch mock for successful response | ||
(global.fetch as jest.Mock).mockImplementation(() => | ||
Promise.resolve({ | ||
ok: true, | ||
json: () => Promise.resolve({ | ||
id: 'mock-generation-id', | ||
status: 'completed', | ||
assets: { | ||
video: 'https://example.com/video.mp4' | ||
} | ||
}), | ||
text: () => Promise.resolve(''), | ||
}) | ||
); | ||
}); | ||
|
||
it('should validate when API key is present', async () => { | ||
const mockMessage = {} as Memory; | ||
const result = await videoGenerationPlugin.actions[0].validate(mockRuntime, mockMessage); | ||
expect(result).toBe(true); | ||
expect(mockRuntime.getSetting).toHaveBeenCalledWith('LUMA_API_KEY'); | ||
}); | ||
|
||
it('should handle video generation request', async () => { | ||
const mockMessage = { | ||
content: { | ||
text: 'Generate a video of a sunset' | ||
} | ||
} as Memory; | ||
const mockState = {} as State; | ||
|
||
await videoGenerationPlugin.actions[0].handler( | ||
mockRuntime, | ||
mockMessage, | ||
mockState, | ||
{}, | ||
mockCallback | ||
); | ||
|
||
// Check initial callback | ||
expect(mockCallback).toHaveBeenCalledWith( | ||
expect.objectContaining({ | ||
text: expect.stringContaining('I\'ll generate a video based on your prompt') | ||
}) | ||
); | ||
|
||
// Check final callback with video | ||
expect(mockCallback).toHaveBeenCalledWith( | ||
expect.objectContaining({ | ||
text: 'Here\'s your generated video!', | ||
attachments: expect.arrayContaining([ | ||
expect.objectContaining({ | ||
source: 'videoGeneration' | ||
}) | ||
]) | ||
}), | ||
expect.arrayContaining([expect.stringMatching(/generated_video_.*\.mp4/)]) | ||
); | ||
}); | ||
|
||
it('should handle API errors gracefully', async () => { | ||
// Mock API error | ||
(global.fetch as jest.Mock).mockImplementationOnce(() => | ||
Promise.resolve({ | ||
ok: false, | ||
status: 500, | ||
statusText: 'Internal Server Error', | ||
text: () => Promise.resolve('API Error'), | ||
}) | ||
); | ||
|
||
const mockMessage = { | ||
content: { | ||
text: 'Generate a video of a sunset' | ||
} | ||
} as Memory; | ||
const mockState = {} as State; | ||
|
||
await videoGenerationPlugin.actions[0].handler( | ||
mockRuntime, | ||
mockMessage, | ||
mockState, | ||
{}, | ||
mockCallback | ||
); | ||
|
||
// Check error callback | ||
expect(mockCallback).toHaveBeenCalledWith( | ||
expect.objectContaining({ | ||
text: expect.stringContaining('Video generation failed'), | ||
error: true | ||
}) | ||
); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
{ | ||
"name": "@ai16z/plugin-video-generation", | ||
"version": "0.0.1", | ||
"main": "dist/index.js", | ||
"type": "module", | ||
"types": "dist/index.d.ts", | ||
"dependencies": { | ||
"@ai16z/eliza": "workspace:*", | ||
"tsup": "^8.3.5" | ||
}, | ||
"scripts": { | ||
"build": "tsup --format esm --dts", | ||
"dev": "tsup --watch" | ||
}, | ||
"peerDependencies": { | ||
"whatwg-url": "7.1.0" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
export const LUMA_CONSTANTS = { | ||
API_URL: 'https://api.lumalabs.ai/dream-machine/v1/generations', | ||
API_KEY_SETTING: "LUMA_API_KEY" // The setting name to fetch from runtime | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,221 @@ | ||
import { elizaLogger } from "@ai16z/eliza/src/logger.ts"; | ||
import { | ||
Action, | ||
HandlerCallback, | ||
IAgentRuntime, | ||
Memory, | ||
Plugin, | ||
State, | ||
} from "@ai16z/eliza/src/types.ts"; | ||
import fs from "fs"; | ||
import { LUMA_CONSTANTS } from './constants'; | ||
|
||
const generateVideo = async (prompt: string, runtime: IAgentRuntime) => { | ||
const API_KEY = runtime.getSetting(LUMA_CONSTANTS.API_KEY_SETTING); | ||
|
||
try { | ||
elizaLogger.log("Starting video generation with prompt:", prompt); | ||
|
||
const response = await fetch(LUMA_CONSTANTS.API_URL, { | ||
method: 'POST', | ||
headers: { | ||
'Authorization': `Bearer ${API_KEY}`, | ||
'accept': 'application/json', | ||
'Content-Type': 'application/json' | ||
}, | ||
body: JSON.stringify({ prompt }) | ||
}); | ||
|
||
if (!response.ok) { | ||
const errorText = await response.text(); | ||
elizaLogger.error("Luma API error:", { | ||
status: response.status, | ||
statusText: response.statusText, | ||
error: errorText | ||
}); | ||
throw new Error(`Luma API error: ${response.statusText} - ${errorText}`); | ||
} | ||
|
||
const data = await response.json(); | ||
elizaLogger.log("Generation request successful, received response:", data); | ||
|
||
// Poll for completion | ||
let status = data.status; | ||
let videoUrl = null; | ||
const generationId = data.id; | ||
|
||
while (status !== 'completed' && status !== 'failed') { | ||
await new Promise(resolve => setTimeout(resolve, 5000)); // Wait 5 seconds | ||
|
||
const statusResponse = await fetch(`${LUMA_CONSTANTS.API_URL}/${generationId}`, { | ||
method: 'GET', | ||
headers: { | ||
'Authorization': `Bearer ${API_KEY}`, | ||
'accept': 'application/json' | ||
} | ||
}); | ||
|
||
if (!statusResponse.ok) { | ||
const errorText = await statusResponse.text(); | ||
elizaLogger.error("Status check error:", { | ||
status: statusResponse.status, | ||
statusText: statusResponse.statusText, | ||
error: errorText | ||
}); | ||
throw new Error('Failed to check generation status: ' + errorText); | ||
} | ||
|
||
const statusData = await statusResponse.json(); | ||
elizaLogger.log("Status check response:", statusData); | ||
|
||
status = statusData.state; | ||
if (status === 'completed') { | ||
videoUrl = statusData.assets?.video; | ||
} | ||
} | ||
|
||
if (status === 'failed') { | ||
throw new Error('Video generation failed'); | ||
} | ||
|
||
if (!videoUrl) { | ||
throw new Error('No video URL in completed response'); | ||
} | ||
|
||
return { | ||
success: true, | ||
data: videoUrl | ||
}; | ||
} catch (error) { | ||
elizaLogger.error("Video generation error:", error); | ||
return { | ||
success: false, | ||
error: error.message || 'Unknown error occurred' | ||
}; | ||
} | ||
} | ||
|
||
const videoGeneration: Action = { | ||
name: "GENERATE_VIDEO", | ||
similes: [ | ||
"VIDEO_GENERATION", | ||
"VIDEO_GEN", | ||
"CREATE_VIDEO", | ||
"MAKE_VIDEO", | ||
"RENDER_VIDEO", | ||
"ANIMATE", | ||
"CREATE_ANIMATION", | ||
"VIDEO_CREATE", | ||
"VIDEO_MAKE" | ||
], | ||
description: "Generate a video based on a text prompt", | ||
validate: async (runtime: IAgentRuntime, message: Memory) => { | ||
elizaLogger.log("Validating video generation action"); | ||
const lumaApiKey = runtime.getSetting("LUMA_API_KEY"); | ||
elizaLogger.log("LUMA_API_KEY present:", !!lumaApiKey); | ||
return !!lumaApiKey; | ||
}, | ||
handler: async ( | ||
runtime: IAgentRuntime, | ||
message: Memory, | ||
state: State, | ||
options: any, | ||
callback: HandlerCallback | ||
) => { | ||
elizaLogger.log("Video generation request:", message); | ||
|
||
// Clean up the prompt by removing mentions and commands | ||
let videoPrompt = message.content.text | ||
.replace(/<@\d+>/g, '') // Remove mentions | ||
.replace(/generate video|create video|make video|render video/gi, '') // Remove commands | ||
.trim(); | ||
|
||
if (!videoPrompt || videoPrompt.length < 5) { | ||
callback({ | ||
text: "Could you please provide more details about what kind of video you'd like me to generate? For example: 'Generate a video of a sunset on a beach' or 'Create a video of a futuristic city'", | ||
}); | ||
return; | ||
} | ||
|
||
elizaLogger.log("Video prompt:", videoPrompt); | ||
|
||
callback({ | ||
text: `I'll generate a video based on your prompt: "${videoPrompt}". This might take a few minutes...`, | ||
}); | ||
|
||
try { | ||
const result = await generateVideo(videoPrompt, runtime); | ||
|
||
if (result.success && result.data) { | ||
// Download the video file | ||
const response = await fetch(result.data); | ||
const arrayBuffer = await response.arrayBuffer(); | ||
const videoFileName = `content_cache/generated_video_${Date.now()}.mp4`; | ||
|
||
// Save video file | ||
fs.writeFileSync(videoFileName, Buffer.from(arrayBuffer)); | ||
|
||
callback({ | ||
text: "Here's your generated video!", | ||
attachments: [ | ||
{ | ||
id: crypto.randomUUID(), | ||
url: result.data, | ||
title: "Generated Video", | ||
source: "videoGeneration", | ||
description: videoPrompt, | ||
text: videoPrompt, | ||
}, | ||
], | ||
}, [videoFileName]); // Add the video file to the attachments | ||
} else { | ||
callback({ | ||
text: `Video generation failed: ${result.error}`, | ||
error: true | ||
}); | ||
} | ||
} catch (error) { | ||
elizaLogger.error(`Failed to generate video. Error: ${error}`); | ||
callback({ | ||
text: `Video generation failed: ${error.message}`, | ||
error: true | ||
}); | ||
} | ||
}, | ||
examples: [ | ||
[ | ||
{ | ||
user: "{{user1}}", | ||
content: { text: "Generate a video of a cat playing piano" }, | ||
}, | ||
{ | ||
user: "{{agentName}}", | ||
content: { | ||
text: "I'll create a video of a cat playing piano for you", | ||
action: "GENERATE_VIDEO" | ||
}, | ||
} | ||
], | ||
[ | ||
{ | ||
user: "{{user1}}", | ||
content: { text: "Can you make a video of a sunset at the beach?" }, | ||
}, | ||
{ | ||
user: "{{agentName}}", | ||
content: { | ||
text: "I'll generate a beautiful beach sunset video for you", | ||
action: "GENERATE_VIDEO" | ||
}, | ||
} | ||
] | ||
] | ||
} as Action; | ||
|
||
export const videoGenerationPlugin: Plugin = { | ||
name: "videoGeneration", | ||
description: "Generate videos using Luma AI", | ||
actions: [videoGeneration], | ||
evaluators: [], | ||
providers: [], | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"extends": "../../tsconfig.json", | ||
"compilerOptions": { | ||
"outDir": "dist", | ||
"rootDir": ".", | ||
"module": "ESNext", | ||
"moduleResolution": "Bundler", | ||
"types": [ | ||
"node" | ||
] | ||
}, | ||
"include": [ | ||
"src" | ||
] | ||
} |
Oops, something went wrong.