forked from StrawberryCheesecake/image_recog
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimage_preprocessing.py
62 lines (40 loc) · 2.11 KB
/
image_preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import random
from six.moves import xrange # pylint: disable=redefined-builtin
def preprocess_image(image, height, width):
    """Resize, randomly distort and standardize a single image tensor.

    The image is resized to ``[height, width]``, cast to float32 and then
    augmented: a random left/right flip, followed by a random brightness
    change and a random contrast change, each applied with probability 0.6.
    Finally the result is standardized per image.

    Args:
        image: 3-D tensor ``[height, width, channels]`` containing the image.
            It is reshaped to ``[height, width, 3]``, so it must have
            exactly 3 channels.
        height (int): expected image height after resizing.
        width (int): expected image width after resizing.

    Returns:
        A float32 3-D tensor with static shape ``[height, width, 3]``.
    """
    image = tf.image.resize_images(image, [height, width])
    reshaped_image = tf.cast(image, tf.float32)
    reshaped_image = tf.reshape(reshaped_image, [height, width, 3])

    print("Reached data augmentation.")

    # Data augmentation - we apply several kinds of distortions, some at a
    # given probability, to create more permutations and variance in our
    # data set.

    # Randomly crop a [height, width] section of the image.
    # NOTE(review): the image was just resized to exactly [height, width],
    # so this crop currently always selects the whole image. Resize to a
    # larger size first if a real random crop is wanted.
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

    # Randomly flip the image horizontally.
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # Randomly brighten the image at a given probability.
    distorted_image = _apply_randomly(
        distorted_image,
        lambda img: tf.image.random_brightness(img, max_delta=63))

    # Randomly change the image's contrast at a given probability.
    # Ideally we'd do a PCA analysis over all 3 of the image's channels and
    # adjust the contrast based on some kind of calculated offset. This can
    # be added when we refactor for segmentation and object detection.
    distorted_image = _apply_randomly(
        distorted_image,
        lambda img: tf.image.random_contrast(img, lower=0.2, upper=1.8))

    print("Reached data pre-processing.")

    # Data pre-processing - center the pixel values on the origin and
    # normalize their scale, independently for each image.
    float_image = tf.image.per_image_standardization(distorted_image)

    # Make the static shape explicit for downstream ops.
    float_image.set_shape([height, width, 3])
    return float_image


def _apply_randomly(image, distortion, skip_threshold=0.4):
    """Apply ``distortion`` to ``image`` with probability ``1 - skip_threshold``.

    The decision is drawn with a TensorFlow random op rather than Python's
    ``random`` module, so it is re-drawn every time the tensor is evaluated
    instead of being fixed once when the graph is built.
    """
    draw = tf.random_uniform([], minval=0.0, maxval=1.0)
    return tf.cond(tf.greater_equal(draw, skip_threshold),
                   lambda: distortion(image),
                   lambda: image)