forked from zhoubolei/TRN-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
giphy reactions and siamese net flows
- Loading branch information
EC2 Default User
committed
Jan 28, 2019
1 parent
57ad19a
commit 6ec1523
Showing
15 changed files
with
3,740 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import random | ||
import os | ||
import numpy as np | ||
import torch | ||
|
||
from dataset import TSNDataSet | ||
from vidaug import augmentors as va # pip3 install git+https://github.com/okankop/vidaug --user | ||
|
||
|
||
class RGB2Gray(object):
    """Clip transform that removes colour while keeping an RGB-mode frame."""

    def __call__(self, clip):
        """Return a new list with every frame of `clip` made grayscale.

        Each frame is converted to mode 'L' and then back to mode 'RGB'.
        Frames are expected to expose a PIL-style ``convert(mode)`` method.
        """
        grayscale_clip = []
        for frame in clip:
            luminance = frame.convert('L')
            grayscale_clip.append(luminance.convert('RGB'))
        return grayscale_clip
|
||
|
||
def augmentation(prob=0.5, N=2, random_order=True):
    """Build the video-clip augmentation pipeline.

    Every candidate augmentor is wrapped in ``va.Sometimes(prob, ...)`` and
    the wrapped candidates are handed to a single ``va.SomeOf`` stage, which
    is itself the only member of the returned ``va.Sequential``.

    Args:
        prob: probability passed to ``va.Sometimes`` for each candidate.
        N: forwarded to ``va.SomeOf`` as ``N``.
        random_order: forwarded to ``va.SomeOf`` as ``random_order``.

    Returns:
        The assembled ``va.Sequential`` augmentor.
    """
    def sometimes(augmentor):
        # Apply `augmentor` only with probability `prob`.
        return va.Sometimes(prob, augmentor)

    candidates = [
        va.GaussianBlur(sigma=3.0),
        va.ElasticTransformation(alpha=3.5, sigma=0.25),
        va.PiecewiseAffineTransform(
            displacement=5,
            displacement_kernel=1,
            displacement_magnification=1,
        ),
        va.RandomRotate(degrees=10),
        va.RandomResize(0.5),
        va.RandomTranslate(x=20, y=20),
        va.RandomShear(x=0.2, y=0.2),
        va.InvertColor(),
        va.Add(100),
        va.Multiply(1.2),
        va.Pepper(),
        va.Salt(),
        va.HorizontalFlip(),
        va.TemporalElasticTransformation(),
        RGB2Gray(),
    ]
    wrapped = [sometimes(candidate) for candidate in candidates]
    some_of_stage = va.SomeOf(wrapped, N=N, random_order=random_order)
    return va.Sequential([some_of_stage])
|
||
|
||
# Module-level augmentation pipeline, built once at import time with the
# default settings written out explicitly.
aug = augmentation(prob=0.5, N=2, random_order=True)
|
||
|
||
class SiameseDataset(TSNDataSet):
    """Dataset that yields random pairs of clips for siamese training.

    Every item is a freshly sampled ``(clip, other_clip, target)`` triple in
    which ``target`` is 1.0 when both clips carry the same label and 0.0
    otherwise.

    NOTE(review): relies on ``self.label2videos``, which is not defined in
    this class. It is indexed by label and its values are passed to
    ``get()`` as indices, so it is presumably a mapping from label to a list
    of positions in ``self.video_list`` -- confirm against ``TSNDataSet``.
    """

    def __getitem__(self, _):
        """Return ``(data, other_data, target)`` for a randomly drawn pair.

        The requested index is ignored: the anchor clip is drawn at random,
        and a fair coin decides whether the partner clip comes from the same
        class or from a different one.
        """
        _, data, label, _ = self.get()
        want_same_class = bool(random.randint(0, 1))
        other_index = self._pick_partner_index(label, want_same_class)
        _, other_data, other_label, _ = self.get(other_index)
        return data, other_data, torch.Tensor([float(label == other_label)])

    def _pick_partner_index(self, label, same_class):
        """Pick the index of a partner video for an anchor labelled `label`.

        Raises:
            ValueError: if a different-class partner is requested but no
                other label exists.
        """
        if same_class:
            partner_label = label
        else:
            # Uniform choice among the other labels. Sampling straight from
            # the dict's key view is rejected by random.sample() on
            # Python 3.11+, so an explicit list is built instead.
            other_labels = [x for x in self.label2videos if x != label]
            if not other_labels:
                raise ValueError(
                    'Cannot draw a different-class pair: only one label is available'
                )
            partner_label = random.choice(other_labels)
        return random.choice(self.label2videos[partner_label])

    def get(self, index=None, apply_aug=True):
        """Load, optionally augment, and transform one video.

        Args:
            index: position in ``self.video_list``; a random position is
                drawn when it is None.
            apply_aug: when true, run the module-level ``aug`` pipeline on
                the loaded frames before ``self.transform``.

        Returns:
            ``(record.path, transformed_frames, record.label, index)``.
        """
        if index is None:
            index = random.randrange(len(self.video_list))
        record = self.video_list[index]

        # Choose the frame-sampling strategy from the dataset mode flags.
        if self.test_mode:
            indices = self._get_test_indices(record)
        elif self.random_shift:
            indices = self._sample_indices(record)
        else:
            indices = self._get_val_indices(record)

        images = []
        for seg_ind in indices:
            p = int(seg_ind)
            for _ in range(self.new_length):
                images.extend(self._load_image(record.path, p))
                # Advance the frame pointer, but never past num_frames.
                if p < record.num_frames:
                    p += 1

        if apply_aug:
            images = aug(images)
        process_data = self.transform(images)
        return record.path, process_data, record.label, index

    def __len__(self):
        """Report four times the number of videos as the dataset length."""
        return len(self.video_list) * 4
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
import pandas as pd | ||
import moviepy.editor as mov_editor | ||
import logging | ||
from time import sleep | ||
from multiprocessing import Pool | ||
from PIL import Image | ||
from itertools import islice | ||
|
||
# Emit timestamped, level-tagged log lines from INFO upwards.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s:%(levelname)s: %(message)s',
)

# Root directory that receives one sub-directory of frames per gif.
# Alternative location kept for reference: '/home/ec2-user/mnt/giphy_dataset'
OUTPUT_PATH = '/home/ec2-user/gifs'
|
||
# gifs2cat = pd.read_csv('/home/ec2-user/gifs2cat.csv') | ||
# gifs2cat = gifs2cat.groupby('gif_id')['category'].apply(list) | ||
# gifs2cat = pd.read_csv('/home/ec2-user/reactions.csv') | ||
import pickle | ||
from math import ceil | ||
|
||
# NOTE(review): unpickling can execute arbitrary code; only point this at a
# file produced by a trusted pipeline.
# The context manager closes the file handle as soon as loading finishes.
with open('/home/ec2-user/siamese_dataset.pkl', 'rb') as payload_file:
    payload = pickle.load(payload_file)

# Flatten the payload (an iterable of gif-id collections) into one flat list.
gifs = []
for x in payload:
    gifs.extend(x)
|
||
def evenly_spaced_sampling(array, n):
    """Pick `n` evenly spaced elements from the sequence `array`.

    Returns an empty list when `n` is 0 or `array` is empty, and `array`
    itself (not a copy) when `n` equals its length. When `n` is larger than
    the length, the sequence is repeated from the start until `n` elements
    have been collected.
    """
    size = len(array)

    if n == 0 or size == 0:
        return []
    if n == size:
        return array
    if n > size:
        # Tile the whole sequence often enough, then trim to n elements.
        repeats = ceil(n / size)
        return (list(array) * repeats)[:n]

    # n < size: step through the sequence in strides of size / n.
    picked = []
    for step in range(n):
        picked.append(array[ceil(step * size / n)])
    return picked
|
||
def get_gif_mov_url(id, ext='mp4'):
    """Return the media.giphy.com URL of gif `id` with file extension `ext`."""
    url_template = 'https://media.giphy.com/media/{id}/giphy.{ext}'
    return url_template.format(id=id, ext=ext)
|
||
def get_gif(id, retries=5, delay=1):
    """Open the giphy clip for `id`, trying the mp4 rendition before the gif.

    Each rendition is attempted up to `retries` times. Every failure is
    logged and followed by a pause of `delay` seconds.

    Args:
        id: giphy identifier of the clip.
        retries: maximum number of attempts per file extension.
        delay: seconds to sleep after a failed attempt.

    Returns:
        The opened ``VideoFileClip``, or None when every attempt for both
        extensions failed.
    """
    for ext in ('mp4', 'gif'):
        for attempt in range(1, retries + 1):
            try:
                return mov_editor.VideoFileClip(get_gif_mov_url(id, ext=ext))
            except Exception as ex:
                # Deliberately broad: any failure just triggers another
                # attempt. The count is one-based so the first failure
                # reads "1 times".
                logging.info(
                    'Error: %s:%s. %d times. With extension %s.',
                    type(ex), ex, attempt, ext,
                )
                sleep(delay)
    return None
|
||
def save_gif(id):
    """Download gif `id` and store 50 evenly spaced frames as JPEG files.

    Frames are written to ``OUTPUT_PATH/<id>/<i>.jpg``. Nothing is done when
    that directory already exists or when the clip cannot be opened.

    Returns:
        None in every case.
    """
    print(OUTPUT_PATH, id)
    directory = os.path.join(OUTPUT_PATH, id)
    if os.path.isdir(directory):
        return None

    gif = get_gif(id)
    if gif is None:
        return None

    try:
        try:
            os.mkdir(directory)
        except FileExistsError:
            # Another worker created the directory between the isdir check
            # and this call; leave the job to that worker.
            return None
        frames = list(gif.iter_frames())
        for i, frame in enumerate(evenly_spaced_sampling(frames, 50)):
            Image.fromarray(frame).save(os.path.join(directory, f'{i}.jpg'))
    finally:
        # Release the clip's reader explicitly instead of relying on `del`.
        gif.close()
    return None
|
||
# Guarded so that worker processes started with the "spawn" method can import
# this module without launching a pool of their own.
if __name__ == '__main__':
    print('Total gifs:', len(gifs))
    with Pool(processes=16) as executor:
        executor.map(save_gif, gifs)
        # Earlier runs fed ids from the CSV-based tables instead:
        # executor.map(save_gif, gifs2cat.index)
        # executor.map(save_gif, gifs2cat.gif_id)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.