Commit: Updated project folder structure and readme. Updated script files for faster image generation
Showing 14 changed files with 651 additions and 0 deletions.
Binary file not shown.
@@ -0,0 +1,52 @@
from keras.layers import Conv2D, BatchNormalization, MaxPooling2D, Flatten, Dense, Dropout
from keras.models import Sequential
from keras import backend as K


class ConvolutionalNeuralNet:
    @staticmethod
    def build(width, height, depth, classes):
        # initialize the model along with the input shape to be
        # "channels last" and the channels dimension itself
        model = Sequential()
        inputShape = (width, height, depth)
        chanDim = -1

        # if we are using "channels first", update the
        # input shape and channels dimension
        if K.image_data_format() == "channels_first":
            inputShape = (depth, height, width)
            chanDim = 1

        # CONV -> RELU -> POOL
        model.add(Conv2D(32, (3, 3), padding='same', activation='relu', input_shape=inputShape))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(3, 3)))
        model.add(Dropout(0.25))

        # (CONV -> RELU) * 2 -> POOL
        model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # (CONV -> RELU) * 2 -> POOL
        model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
        model.add(BatchNormalization(axis=chanDim))
        model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))
        model.add(BatchNormalization(axis=chanDim))
        model.add(MaxPooling2D(pool_size=(2, 2)))
        model.add(Dropout(0.25))

        # first and only set of fully connected layers (FC -> RELU)
        model.add(Flatten())
        model.add(Dense(1024, activation='relu'))
        model.add(BatchNormalization())
        model.add(Dropout(0.5))

        # softmax classifier; use the classes argument instead of hard-coding 2
        model.add(Dense(classes, activation='softmax'))

        # return the constructed network architecture
        return model
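
For reference, a minimal sketch of how the builder is called. The dimensions here are illustrative; the training script below passes them in from IMAGE_DIMS in constants.py:

from cnn import ConvolutionalNeuralNet

# build a two-class (ad / notad) classifier for 128x72 grayscale images
model = ConvolutionalNeuralNet.build(width=128, height=72, depth=1, classes=2)
model.summary()  # print the layer-by-layer architecture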
@@ -0,0 +1,64 @@
from constants import IMAGE_DIMS, BATCH_SIZE, INIT_LR, EPOCHS, IMAGE_PATH, MAX_IMAGES
from keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from cnn import ConvolutionalNeuralNet
from keras.optimizers import Adam
#import matplotlib.pyplot as plt
import numpy as np
import cv2
import os

input_shape = IMAGE_DIMS

ad_directory = os.path.join(IMAGE_PATH, 'ad')
notad_directory = os.path.join(IMAGE_PATH, 'notad')

ad_images = [os.path.join(ad_directory, x) for x in os.listdir(ad_directory) if 'DS_Store' not in x][:MAX_IMAGES]
notad_images = [os.path.join(notad_directory, x) for x in os.listdir(notad_directory) if 'DS_Store' not in x][:MAX_IMAGES]
all_images = ad_images + notad_images

ad_labels = [1 for x in ad_images]
notad_labels = [0 for x in notad_images]
all_labels = ad_labels + notad_labels

#Loading images into arrays as grayscale; cv2.resize takes (width, height),
#so the dims are passed reversed to get arrays of shape (input_shape[0], input_shape[1])
print('[INFO] Loading images...')
all_image_data = [cv2.resize(cv2.imread(x, 0), (input_shape[1], input_shape[0])) for x in all_images]
random_permutation = np.random.permutation(len(all_image_data))
#Turn images and labels into numpy arrays and randomise both with the same permutation
x, y = (np.array(all_image_data, dtype='float') / 255)[random_permutation], np.array(all_labels)[random_permutation]

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.27, random_state=37)

#Reshape arrays into the network input shape, e.g. (num of images, 128, 72, 1)
x_train = np.reshape(x_train, (x_train.shape[0], input_shape[0], input_shape[1], input_shape[2]))
x_test = np.reshape(x_test, (x_test.shape[0], input_shape[0], input_shape[1], input_shape[2]))

#plt.imshow(x_train[0], cmap='gray')
#plt.show()

aug = ImageDataGenerator(rotation_range=0, width_shift_range=0.1,
                         height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
                         horizontal_flip=True, fill_mode='nearest')

model = ConvolutionalNeuralNet.build(width=input_shape[0], height=input_shape[1],
                                     depth=input_shape[2], classes=2)
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)

#Compile the model; sparse_categorical_crossentropy works with the integer labels directly
print('[INFO] Compiling the model...')
model.compile(loss="sparse_categorical_crossentropy",
              optimizer=opt,
              metrics=["accuracy"])

#Train the model on augmented batches
print('[INFO] Training the network...')
model.fit_generator(aug.flow(x_train, y_train, batch_size=BATCH_SIZE),
                    validation_data=(x_test, y_test),
                    steps_per_epoch=len(x_train) // BATCH_SIZE,
                    epochs=EPOCHS, verbose=1)

#Save the model locally
print('[INFO] Saving model...')
model.save("cnn.model")
@@ -0,0 +1,97 @@
from constants import TMP_PATH, VIDEO_PATH
from multiprocessing.dummy import Pool as ThreadPool
from urllib.parse import urlparse
import urllib.request
import requests
import pyprind
import os
import re


class M3U8:
    """
    M3U8 downloader using a thread pool (multiprocessing.dummy) to speed up the download
    """

    M3U8_TAG = "#EXT-X-I-FRAME-STREAM-INF:"
    RES_REGEX = "([0-9]{0,5})x([0-9]{0,5})"
    URI_REGEX = 'URI="(.*?)"'
    NUM_THREAD = 1000

    # Initialize all instance variables
    def __init__(self, m3u8_playlist_page, video_name):
        self.playlist_page = m3u8_playlist_page
        self.m3u8_url = ''
        self.m3u8_body = ''
        self.video_name = video_name
        self.progress_bar = None
        self.download_percentage = 0
        parsed_playlist_path = os.path.basename(urlparse(m3u8_playlist_page).path)
        self.basepath = m3u8_playlist_page.split(parsed_playlist_path)[0]

    # Choose a playlist (i.e. the quality of the stream) and save its URL into self.m3u8_url
    def choose_playlist(self, index):
        req = requests.get(self.playlist_page)
        split_response = req.text.split('\n')
        m3u8_lines = []
        # Strip the m3u8 of any unnecessary tags
        for line in split_response:
            if self.M3U8_TAG in line:
                m3u8_lines.append(line.split(self.M3U8_TAG)[1])
        # Sort the m3u8 playlists in order of resolution, highest first
        m3u8_lines = sorted(m3u8_lines, key=lambda x: int(re.search(re.compile(self.RES_REGEX), x).group(1)), reverse=True)
        print(f"[INFO] You have chosen the stream with resolution: {re.search(re.compile(self.RES_REGEX), m3u8_lines[index]).group(0)}")
        print(f"[INFO] The URL is: {re.search(re.compile(self.URI_REGEX), m3u8_lines[index]).group(1)}")
        print(f"[INFO] The Video Title is: {m3u8_lines[index]}")
        # Select the m3u8 playlist and extract the URL from the line
        self.m3u8_url = self.basepath + re.search(re.compile(self.URI_REGEX), m3u8_lines[index]).group(1)

    # Go through the chosen URL and save the .ts segment URLs in self.m3u8_body
    def get_m3u8_body(self):
        url_identifier = os.path.basename(self.m3u8_url).split('.')[0]
        req = requests.get(self.m3u8_url)
        regex = url_identifier + r'.*?\.ts'
        results = re.findall(re.compile(regex), req.text)
        # Take every second match because there are two URLs per video section
        results = [self.basepath + x for x in results][::2]
        self.m3u8_body = results
        self.progress_bar = pyprind.ProgBar(len(results))
        print(f'[INFO] Number of .ts files: {len(results)}')

    # Download a single segment from the given URL and update the progress bar
    def download_url(self, url, index):
        self.download_percentage += 1
        self.progress_bar.update()
        data = urllib.request.urlopen(url).read()
        with open(f'{TMP_PATH}/{index}.ts', 'wb') as f:
            f.write(data)

    # Main download function which initiates a pool of workers to download the ts stream in parallel
    def download(self):
        print("[STATUS] Downloading...")
        pool = ThreadPool(14)
        # Using starmap as it accepts a tuple for the function arguments
        pool.starmap(self.download_url, zip(self.m3u8_body, range(len(self.m3u8_body))))
        pool.close()
        pool.join()
        print('[STATUS] Joining files into single video')
        self.cleanup_writefile()

    # Clean up the tmp folder and join the segments into one video
    def cleanup_writefile(self):
        pattern = re.compile(r'([0-9]+)\.ts')
        files = [x for x in os.listdir(TMP_PATH) if 'DS_Store' not in x]
        # Sort the list of files by segment number
        files = sorted(files, key=lambda x: int(re.search(pattern, x).group(1)))
        files = [os.path.join(TMP_PATH, x) for x in files]

        previous_data = b''
        with open(f'{VIDEO_PATH}/{self.video_name.strip()}.ts', 'wb') as video_file:
            for f in files:
                with open(f, 'rb') as sub_file:
                    data = sub_file.read()
                    # skip segments that are byte-for-byte duplicates of the previous one
                    if data != previous_data:
                        video_file.write(data)
                        previous_data = data
                os.remove(f)
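
For reference, the class is driven exactly as in the scraper script at the end of this commit; a condensed sketch (the playlist URL here is a placeholder):

downloader = M3U8('http://example.com/master.m3u8', 'SomeMatch')
downloader.choose_playlist(0)  # index 0 = highest resolution after sorting
downloader.get_m3u8_body()     # collect the .ts segment URLs
downloader.download()          # download in parallel and join into one file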
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,36 @@
from keras.preprocessing.image import img_to_array
from keras.models import load_model
from constants import IMAGE_DIMS
import numpy as np
import argparse
import cv2

ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True)
ap.add_argument("-m", "--model", required=True)
args = vars(ap.parse_args())

#load the image as grayscale
image = cv2.imread(args["image"], 0)

#pre-process image for classification;
#cv2.resize expects (width, height), so pass the dims reversed
#to get an array of shape (IMAGE_DIMS[0], IMAGE_DIMS[1])
image = cv2.resize(image, (IMAGE_DIMS[1], IMAGE_DIMS[0]))
image = image.astype('float') / 255.0
image = img_to_array(image)
#add a batch dimension, e.g. (128, 72, 1) to (1, 128, 72, 1)
image = np.expand_dims(image, axis=0)

#load the trained model
model = load_model(args["model"])

#classify the input image
print('[INFO] Classifying image...')
prb = model.predict(image)[0]
idx = np.argmax(prb)
#index 1 corresponds to 'ad' and 0 to 'notad', matching the training labels
print(idx)
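
A hypothetical invocation, assuming this script is saved as classify.py and a model has been trained with the training script above (the image path is a placeholder):

python classify.py --image images/ad/example.jpg --model cnn.model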
@@ -0,0 +1,20 @@
import os

# Video and image constants
VIDEO_EXT = '.mp4'
ASPECT_RATIO = 1080 / 1920
# (width, height, depth); depth is 1 because images are loaded as grayscale
IMAGE_DIMS = (128, int(128 * ASPECT_RATIO), 1)

# Path constants
WORKING_DIR = os.path.abspath(os.path.join(os.getcwd(), '../'))
TMP_PATH = os.path.join(WORKING_DIR, 'tmp')
IMAGE_PATH = os.path.join(WORKING_DIR, 'images')
VIDEO_PATH = os.path.join(WORKING_DIR, 'videos')
VIDEO_CUT_PATH = os.path.join(VIDEO_PATH, 'cuts')
VIDEO_FINAL_PATH = os.path.join(VIDEO_PATH, 'final_videos')

# Learning constants
MAX_IMAGES = 1000
BATCH_SIZE = 32
INIT_LR = 1e-3
EPOCHS = 50
@@ -0,0 +1,53 @@
from moviepy.video.io.VideoFileClip import VideoFileClip
from constants import VIDEO_CUT_PATH
import subprocess
import argparse
import random
import os

ap = argparse.ArgumentParser()
ap.add_argument('-s', '--start_time', required=True)
ap.add_argument('-d', '--duration', required=True)
ap.add_argument('-f', '--filename', required=True)
ap.add_argument('-g', '--get_images')
ap.add_argument('-c', '--count', type=float)
ap.add_argument('-a', '--ad')
args = vars(ap.parse_args())

print('*' * 32)
print('\n\tArguments: \n\t Start Time (Hours:Minutes:Seconds)\n\t End Time (Same) \n\t Cut Images (Bool) \n\t Seconds Between Frames (float) \n\t Ad Section (Bool)\n')
print('*' * 32)

filename, start_time, end_time = args['filename'], args['start_time'], args['duration']

# If images are being generated, is the section being cut an ad section or not?
# Note: argparse's type=bool treats any non-empty string as True, so the
# boolean flags are passed as strings and compared against 'True' instead.
image_extra_path = ''
get_images = False
every_second = 0
if args['get_images'] is not None:
    get_images = args['get_images'].strip() == 'True'
    every_second = args['count']
    image_extra_path = 'ad' if args['ad'].strip() == 'True' else 'notad'

print(f"* Start Time: {start_time} , End Time: {end_time} *\n")
video_name = f'{os.path.splitext(os.path.basename(filename))[0].replace(" ", "")}_{str(start_time).replace(":", "-")}_{str(end_time).replace(":", "-")}_{image_extra_path}'
new_video_path = os.path.join(VIDEO_CUT_PATH, video_name + '.mp4')

#new_clip.write_videofile(new_video_path, audio_codec='aac')

# Cut the section out of the source video without re-encoding
subprocess.call(['ffmpeg', '-ss', f'{start_time}', '-i', f'{filename}', '-to', f'{end_time}',
                 '-c', 'copy', f'{VIDEO_CUT_PATH}/{video_name}.mp4'])

if get_images:
    video_f = VideoFileClip(new_video_path)
    total_seconds = video_f.duration
    # Calculate how many images to cut (total seconds / seconds between frames)
    image_cut_count = int(total_seconds / every_second)
    print(f'[STATUS] Cutting {image_cut_count} Images')
    for image in range(image_cut_count):
        print(f'\r [STATUS] {image_cut_count - image} Images Left', end='')
        video_f.save_frame(f'images/{image_extra_path}/{os.path.basename(filename.strip())[0:3]}{random.randrange(1, 1000000)}.jpg', t=image * every_second)
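
A hypothetical invocation, assuming this script is saved as cut_video.py: cut the section from 00:05:00 to 00:15:00 out of a downloaded video (the filename is a placeholder) and save a frame into the ad image folder every 2 seconds:

python cut_video.py -s 00:05:00 -d 00:15:00 -f ../videos/SomeMatch.ts -g True -c 2 -a True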
@@ -0,0 +1,69 @@
from urllib.parse import urlparse
import requests
import m3u8_downloader
import os
import re

URL = "https://overwatchleague.com/en-gb/videos"

headers = {
    'authority': "overwatchleague.com",
    'cache-control': "max-age=0,no-cache",
    'upgrade-insecure-requests': "1",
    'user-agent': "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_14_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/72.0.3626.119 Safari/537.36",
    'dnt': "1",
    'accept': "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8",
    'referer': "https://overwatchleague.com/en-gb/schedule",
    'accept-encoding': "gzip, deflate, br",
    'accept-language': "en-GB,en-US;q=0.9,en;q=0.8",
    'cookie': "locale=en_GB; optimizelyEndUserId=oeu1550218223756r0.2747767037296329; _ga=GA1.2.1054216460.1550218224; _cb_ls=1; showSpoilers=true; _cb=BRYHJDCpz3rqBQ0dRl; __cfduid=ddf4186d24f9805e516c775a2cbdcab541550771447; session=jCRjuNS_4sTtazUOZi_-ug._hY9Xk1qwzVG1poL79tuhXOSBaJLUaEli4oeyknMvz3fxCcUh_oeL5IuI2DG5w4wf3IBYxbRWhC8uoXA_u3UXQ.1551982450964.86400000.jpokG5wnTF9qWh9s9wtjMO195w9Xvjon93dbbsvZKJs; _gid=GA1.2.1606585728.1551982452; _chartbeat2=.1550218225376.1551982462841.0000111001000001.BN9zODC8Ei_YKosLBCfQrCwBTXfPr.1; _cb_svref=https%3A%2F%2Foverwatchleague.com%2Fen-gb%2Fschedule",
    'Postman-Token': "14b6702a-d5f6-45b5-bd28-59fe7d21ac82"
}

req = requests.get(URL, headers=headers)
video_tags = re.compile('class="Card-link(.*?)"(.*?)data-mlg-embed="(.*?)"')
#get the video URLs
top_video_urls = re.findall(video_tags, req.text)

#loop over each video and get the title to choose from
data_title = re.compile('data-title="(.*?)"')
videos = []
for link in top_video_urls:
    title_search = re.search(data_title, link[1].strip())
    title = title_search.group(1).strip()

    #there are previews, interviews etc.; we just want full matches
    if "Full Match" in title:
        url = link[-1]
        split_title = title.split("|")
        videos.append((split_title[-1], url))

for current_video, video in enumerate(videos, start=1):
    print(f"Video {current_video}: {video[0]}")

chosen_video = int(input("Choose a video: "))
chosen_video_url = videos[chosen_video - 1][1]
print(f"[INFO] Chosen Video: {videos[chosen_video - 1][0]}")

#find the m3u8 variable in the page's JavaScript, e.g. streamUrl: https://something.m3u8
req = requests.get(chosen_video_url)
variable = re.compile(r'streamUrl":"(.*?)(\.m3u8)')
url = variable.search(req.text).group(1) + variable.search(req.text).group(2)

#parse the URL to get the basepath, since the m3u8 playlist only contains relative paths
parsed_url = urlparse(url)
file_path = parsed_url.path
basepath = 'http://' + parsed_url.netloc + file_path.split(os.path.basename(file_path))[0]
#add http: because the scraped URL starts with //
general_m3u8 = 'http:' + url
print("[INFO] Top-Level M3U8 URL: " + general_m3u8)

downloader = m3u8_downloader.M3U8(general_m3u8, videos[chosen_video - 1][0])
downloader.choose_playlist(0)
downloader.get_m3u8_body()
downloader.download()