-
Notifications
You must be signed in to change notification settings - Fork 7
/
models.py
32 lines (25 loc) · 1.37 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch.nn as nn
from tha2.poser.modes.mode_20 import load_face_morpher, load_face_rotater, load_combiner
class TalkingAnimeLight(nn.Module):
    """Chain a face morpher, a face rotator and a combiner into one module.

    The three sub-networks are obtained from the ``load_*`` helpers of
    ``tha2.poser.modes.mode_20`` using fixed relative checkpoint paths
    under ``pretrained/``.
    """

    def __init__(self):
        super().__init__()
        self.face_morpher = load_face_morpher('pretrained/face_morpher.pt')
        self.two_algo_face_rotator = load_face_rotater('pretrained/two_algo_face_rotator.pt')
        self.combiner = load_combiner('pretrained/combiner.pt')

    def forward(self, image, mouth_eye_vector, pose_vector):
        """Morph the central face window, rotate the result, then combine.

        Args:
            image: Tensor indexed with four subscripts; presumably laid out
                as (batch, channels, height, width) -- TODO confirm. It is
                cloned first, so the caller's tensor is not written to by
                this method.
            mouth_eye_vector: Passed unchanged to ``face_morpher``.
            pose_vector: Passed unchanged to both ``two_algo_face_rotator``
                and ``combiner``.

        Returns:
            Whatever ``combiner`` returns for the first two rotator outputs.
        """
        # Same window as the literal subscript ``[:, :, 32:224, 32:224]``.
        face_window = (slice(None), slice(None), slice(32, 224), slice(32, 224))

        canvas = image.clone()
        morphed_face = self.face_morpher(image[face_window], mouth_eye_vector)
        # Paste the morphed window back into the copy of the full image.
        canvas[face_window] = morphed_face

        # Only the first two rotator outputs are forwarded to the combiner.
        rotated = self.two_algo_face_rotator(canvas, pose_vector)[:2]
        return self.combiner(rotated[0], rotated[1], pose_vector)
class TalkingAnime(nn.Module):
    """Morph -> rotate -> combine pipeline with caller-supplied components.

    ``forward`` needs three callables on the instance: ``face_morpher``,
    ``two_algo_face_rotator`` and ``combiner``. They can be handed to the
    constructor or assigned as attributes afterwards. Constructing with no
    arguments leaves them unset, in which case ``forward`` raises
    ``AttributeError`` until they are attached.
    """

    def __init__(self, face_morpher=None, two_algo_face_rotator=None, combiner=None):
        """Create the module, optionally attaching its three components.

        Args:
            face_morpher: Callable invoked as
                ``face_morpher(image_window, mouth_eye_vector)``.
            two_algo_face_rotator: Callable invoked as
                ``two_algo_face_rotator(image, pose_vector)``; its result
                must support ``[:2]`` and integer indexing.
            combiner: Callable invoked as
                ``combiner(first, second, pose_vector)``.

        Any argument left as ``None`` is simply not set on the instance.
        """
        super().__init__()
        # Assign only what was supplied: storing ``None`` would turn the
        # missing-attribute error into a less helpful "not callable" one.
        # ``nn.Module`` values assigned here are registered as submodules.
        if face_morpher is not None:
            self.face_morpher = face_morpher
        if two_algo_face_rotator is not None:
            self.two_algo_face_rotator = two_algo_face_rotator
        if combiner is not None:
            self.combiner = combiner

    def forward(self, image, mouth_eye_vector, pose_vector):
        """Morph the central face window, rotate the result, then combine.

        Args:
            image: Tensor indexed with four subscripts; presumably laid out
                as (batch, channels, height, width) -- TODO confirm. It is
                cloned first, so the caller's tensor is not written to by
                this method.
            mouth_eye_vector: Passed unchanged to ``face_morpher``.
            pose_vector: Passed unchanged to both ``two_algo_face_rotator``
                and ``combiner``.

        Returns:
            Whatever ``combiner`` returns for the first two rotator outputs.

        Raises:
            AttributeError: If any of the three components was never attached.
        """
        x = image.clone()
        mouth_eye_morp_image = self.face_morpher(image[:, :, 32:224, 32:224], mouth_eye_vector)
        # Paste the morphed window back into the copy of the full image.
        x[:, :, 32:224, 32:224] = mouth_eye_morp_image
        # Only the first two rotator outputs are forwarded to the combiner.
        rotate_image = self.two_algo_face_rotator(x, pose_vector)[:2]
        output_image = self.combiner(rotate_image[0], rotate_image[1], pose_vector)
        return output_image