Skip to content

Commit

Permalink
giphy reactions and siamese net flows
Browse files Browse the repository at this point in the history
  • Loading branch information
EC2 Default User committed Jan 28, 2019
1 parent 57ad19a commit 6ec1523
Show file tree
Hide file tree
Showing 15 changed files with 3,740 additions and 8 deletions.
5 changes: 5 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os.path
import numpy as np
from numpy.random import randint
from collections import defaultdict

class VideoRecord(object):
def __init__(self, row):
Expand Down Expand Up @@ -43,6 +44,10 @@ def __init__(self, root_path, list_file,
self.new_length += 1# Diff needs one more image to calculate diff

self._parse_list()

self.label2videos = defaultdict(list)
for i, x in enumerate(self.video_list):
self.label2videos[x.label].append(i)

def _load_image(self, directory, idx):
if self.modality == 'RGB' or self.modality == 'RGBDiff':
Expand Down
84 changes: 84 additions & 0 deletions dataset_siamese.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import random
import os
import numpy as np
import torch

from dataset import TSNDataSet
from vidaug import augmentors as va # pip3 install git+https://github.com/okankop/vidaug --user


class RGB2Gray(object):
def __call__(self, clip):
return [x.convert('L').convert('RGB') for x in clip]


def augmentation(prob=0.5, N=2, random_order=True):
sometimes = lambda aug: va.Sometimes(prob, aug) # Used to apply augmentor with 50% probability
return va.Sequential([
va.SomeOf(
[
sometimes(va.GaussianBlur(sigma=3.0)),
sometimes(va.ElasticTransformation(alpha=3.5, sigma=0.25)),
sometimes(va.PiecewiseAffineTransform(displacement=5, displacement_kernel=1, displacement_magnification=1)),
sometimes(va.RandomRotate(degrees=10)),
sometimes(va.RandomResize(0.5)),
sometimes(va.RandomTranslate(x=20, y=20)),
sometimes(va.RandomShear(x=0.2, y=0.2)),
sometimes(va.InvertColor()),
sometimes(va.Add(100)),
sometimes(va.Multiply(1.2)),
sometimes(va.Pepper()),
sometimes(va.Salt()),
sometimes(va.HorizontalFlip()),
sometimes(va.TemporalElasticTransformation()),
sometimes(RGB2Gray())
],
N=N,
random_order=random_order
)])


aug = augmentation()


class SiameseDataset(TSNDataSet):
def __getitem__(self, _):
path, data, label, _ = self.get()
should_get_same_class = random.randint(0, 1)
if bool(should_get_same_class):
other_index = random.choice(self.label2videos[label])
other_path, other_data, other_label, _ = self.get(other_index)
else:
# TODO: fix this dirty hack
other_labels = random.sample(self.label2videos.keys(), 2)
other_label = next(x for x in other_labels if x != label)
other_index = random.choice(self.label2videos[other_label])
other_path, other_data, other_label, _ = self.get(other_index)
return data, other_data, torch.Tensor([float(label == other_label)])

def get(self, index=None, apply_aug=True):
if index is None:
index = random.choice(range(len(self.video_list)))
record = self.video_list[index]

if not self.test_mode:
indices = self._sample_indices(record) if self.random_shift else self._get_val_indices(record)
else:
indices = self._get_test_indices(record)

images = list()
for seg_ind in indices:
p = int(seg_ind)
for i in range(self.new_length):
seg_imgs = self._load_image(record.path, p)
images.extend(seg_imgs)
if p < record.num_frames:
p += 1

if apply_aug:
images = aug(images)
process_data = self.transform(images)
return record.path, process_data, record.label, index

def __len__(self):
return len(self.video_list) * 4
6 changes: 4 additions & 2 deletions datasets_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import torchvision.datasets as datasets


ROOT_DATASET = '/home/ec2-user/mnt/giphy_dataset'
# ROOT_DATASET = '/home/ec2-user/mnt/giphy_dataset'
ROOT_DATASET = '/home/ec2-user/gifs'


def return_custom(modality):
filename_categories = 'category.txt'
if modality == 'RGB':
root_data = '/home/ec2-user/mnt/giphy_dataset'
# root_data = '/home/ec2-user/mnt/giphy_dataset'
root_data = '/home/ec2-user/gifs'
filename_imglist_train = 'train_videofolder.txt'
filename_imglist_val = 'val_videofolder.txt'
prefix = '{}.jpg'
Expand Down
84 changes: 84 additions & 0 deletions grab_gifs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os
import pandas as pd
import moviepy.editor as mov_editor
import logging
from time import sleep
from multiprocessing import Pool
from PIL import Image
from itertools import islice

logging.basicConfig(
format='%(asctime)s:%(levelname)s: %(message)s',
level=logging.INFO
)

# OUTPUT_PATH = '/home/ec2-user/mnt/giphy_dataset'
OUTPUT_PATH = '/home/ec2-user/gifs'

# gifs2cat = pd.read_csv('/home/ec2-user/gifs2cat.csv')
# gifs2cat = gifs2cat.groupby('gif_id')['category'].apply(list)
# gifs2cat = pd.read_csv('/home/ec2-user/reactions.csv')
import pickle
from math import ceil

payload = pickle.load(open('/home/ec2-user/siamese_dataset.pkl', 'rb'))
gifs = []
for x in payload:
gifs.extend(list(x))

def evenly_spaced_sampling(array, n):
"""Choose `n` evenly spaced elements from `array` sequence"""
length = len(array)

if n == 0 or length == 0:
return []
elif n == length:
return array
elif n < length:
return [array[ceil(i * length / n)] for i in range(n)]
elif n > length:
result = []
for _ in range(ceil(n / length)):
result.extend(array)
return result[:n]

def get_gif_mov_url(id, ext='mp4'):
return f'https://media.giphy.com/media/{id}/giphy.{ext}'

def get_gif(id):
gif = None
for i in range(5):
try:
gif = mov_editor.VideoFileClip(get_gif_mov_url(id))
break
except Exception as ex:
sleep(1)
logging.info(f'Error: {type(ex)}:{ex}. {i} times. With extension mp4.')
else:
for i in range(5):
try:
gif = mov_editor.VideoFileClip(get_gif_mov_url(id, ext='gif'))
break
except Exception as ex:
sleep(1)
logging.info(f'Error: {type(ex)}:{ex}. {i} times. With extension gif.')
return gif

def save_gif(id):
print(OUTPUT_PATH, id)
directory = os.path.join(OUTPUT_PATH, id)
if os.path.isdir(directory):
return None
gif = get_gif(id)
if gif is not None:
os.mkdir(directory)
frames = list(gif.iter_frames())
for i, x in enumerate(evenly_spaced_sampling(frames, 50)):
Image.fromarray(x).save(os.path.join(directory, f'{i}.jpg'))
del gif

print('Total gifs:', len(gifs))
with Pool(processes=16) as executor:
executor.map(save_gif, gifs)
# executor.map(save_gif, gifs2cat.index)
#executor.map(save_gif, gifs2cat.gif_id)
4 changes: 2 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def main():
dropout=args.dropout,
img_feature_dim=args.img_feature_dim,
partial_bn=not args.no_partialbn)
base_model.train(False)
for p in base_model.parameters():
_, cnn =list(base_model.named_children())[0]
for p in cnn.parameters():
p.requires_grad = False

crop_size = base_model.crop_size
Expand Down
Loading

0 comments on commit 6ec1523

Please sign in to comment.