-
Notifications
You must be signed in to change notification settings - Fork 7k
/
image.py
255 lines (205 loc) · 9.18 KB
/
image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
from enum import Enum
from warnings import warn
import torch
from ..extension import _load_library
from ..utils import _log_api_usage_once
# Loading the compiled "image" extension is best-effort: a failure is reported as a
# warning instead of being raised, so importing this module still succeeds.
# NOTE(review): presumably this is what makes the ``torch.ops.image`` ops used below
# available -- confirm against ``_load_library``.
try:
    _load_library("image")
except (ImportError, OSError) as err:
    warn(f"Failed to load image Python extension: {err}")
class ImageReadMode(Enum):
    """
    Read modes that can be requested when an image is read or decoded.

    - ``ImageReadMode.UNCHANGED``: load the image as-is.
    - ``ImageReadMode.GRAY``: convert to grayscale.
    - ``ImageReadMode.GRAY_ALPHA``: grayscale with transparency.
    - ``ImageReadMode.RGB``: convert to RGB.
    - ``ImageReadMode.RGB_ALPHA``: RGB with transparency.
    """

    # The integer values are what the decode functions in this module forward to the
    # native ops as ``mode.value``, so they must not be renumbered.
    UNCHANGED = 0
    GRAY = 1
    GRAY_ALPHA = 2
    RGB = 3
    RGB_ALPHA = 4
def read_file(path: str) -> torch.Tensor:
    """
    Loads the raw bytes of a file into a one dimensional uint8 Tensor.

    Args:
        path (str): location of the file to read

    Returns:
        data (Tensor)
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_file)
    return torch.ops.image.read_file(path)
def write_file(filename: str, data: torch.Tensor) -> None:
    """
    Dumps a one dimensional uint8 tensor to a file, byte for byte.

    Args:
        filename (str): location of the file to write
        data (Tensor): the bytes that make up the output file
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_file)
    # The native op performs the actual write; nothing is returned to the caller.
    torch.ops.image.write_file(filename, data)
def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Turns the raw bytes of a PNG image into a 3 dimensional RGB or grayscale Tensor,
    optionally converting it to the format requested through ``mode``.
    Output values are uint8 in [0, 255].

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw bytes
            of the PNG image.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_png)
    # The trailing ``False`` is the flag that ``_read_png_16`` passes as ``True``.
    return torch.ops.image.decode_png(input, mode.value, False)
def encode_png(input: torch.Tensor, compression_level: int = 6) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6

    Returns:
        Tensor[1]: A one dimensional int8 tensor that contains the raw bytes of the
            PNG file.
    """
    # NOTE(review): the docstring says ``int8``, but ``write_file`` (which ``write_png``
    # feeds with this function's result) documents its payload as ``uint8`` and the
    # decoders return ``uint8``; the dtype wording is presumably meant to be ``uint8``
    # -- confirm against the native op before changing it.
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_png)
    # ``compression_level`` is forwarded as-is; no range check is done on the Python side.
    return torch.ops.image.encode_png(input, compression_level)
def write_png(input: torch.Tensor, filename: str, compression_level: int = 6) -> None:
    """
    Takes an input tensor in CHW layout (or HW in the case of grayscale images)
    and saves it in a PNG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        compression_level (int): Compression factor for the resulting file, it must be a number
            between 0 and 9. Default: 6
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_png)
    # Encode to an in-memory PNG buffer first, then hand the bytes to ``write_file``.
    encoded = encode_png(input, compression_level)
    write_file(filename, encoded)
def decode_jpeg(
    input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, device: str = "cpu"
) -> torch.Tensor:
    """
    Turns the raw bytes of a JPEG image into a 3 dimensional RGB or grayscale Tensor,
    optionally converting it to the format requested through ``mode``.
    Output values are uint8 between 0 and 255.

    Args:
        input (Tensor[1]): one dimensional uint8 tensor holding the raw bytes
            of the JPEG image. This tensor must be on CPU, regardless of the
            ``device`` parameter.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.
        device (str or torch.device): device on which the decoded image will be
            stored. If a cuda device is specified, the image will be decoded
            with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
            supported for CUDA version >= 10.1

            .. warning::
                There is a memory leak in the nvjpeg library for CUDA versions < 11.6.
                Make sure to rely on CUDA 11.6 or above before using ``device="cuda"``.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_jpeg)
    # Normalise to a ``torch.device`` so the backend can be picked from ``.type``.
    device = torch.device(device)
    if device.type == "cuda":
        # The CUDA op additionally receives the target device.
        return torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
    return torch.ops.image.decode_jpeg(input, mode.value)
def encode_jpeg(input: torch.Tensor, quality: int = 75) -> torch.Tensor:
    """
    Takes an input tensor in CHW layout and returns a buffer with the contents
    of its corresponding JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of
            ``c`` channels, where ``c`` must be 1 or 3.
        quality (int): Quality of the resulting JPEG file, it must be a number between
            1 and 100. Default: 75

    Returns:
        output (Tensor[1]): A one dimensional int8 tensor that contains the raw bytes of the
            JPEG file.

    Raises:
        ValueError: if ``quality`` is lower than 1 or greater than 100.
    """
    # NOTE(review): the docstring says ``int8``, but ``write_file`` (which ``write_jpeg``
    # feeds with this function's result) documents its payload as ``uint8`` and the
    # decoders return ``uint8``; the dtype wording is presumably meant to be ``uint8``
    # -- confirm against the native op before changing it.
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(encode_jpeg)
    # Reject out-of-range qualities here, before the native op is called.
    if quality < 1 or quality > 100:
        raise ValueError("Image quality should be a positive number between 1 and 100")
    return torch.ops.image.encode_jpeg(input, quality)
def write_jpeg(input: torch.Tensor, filename: str, quality: int = 75) -> None:
    """
    Takes an input tensor in CHW layout and saves it in a JPEG file.

    Args:
        input (Tensor[channels, image_height, image_width]): int8 image tensor of ``c``
            channels, where ``c`` must be 1 or 3.
        filename (str): Path to save the image.
        quality (int): Quality of the resulting JPEG file, it must be a number
            between 1 and 100. Default: 75

    Raises:
        ValueError: propagated from ``encode_jpeg`` when ``quality`` is outside 1..100.
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(write_jpeg)
    # Encode to an in-memory JPEG buffer first, then hand the bytes to ``write_file``.
    encoded = encode_jpeg(input, quality)
    write_file(filename, encoded)
def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Works out whether the given bytes are a JPEG or a PNG image and decodes them
    accordingly into a 3 dimensional RGB or grayscale Tensor, optionally converting
    the image to the format requested through ``mode``.
    Output values are uint8 in [0, 255].

    Args:
        input (Tensor): one dimensional uint8 tensor holding the raw bytes of the
            PNG or JPEG image.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(decode_image)
    # Format detection is left to the native op; only the enum's integer value is passed.
    return torch.ops.image.decode_image(input, mode.value)
def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Loads a JPEG or PNG file from disk and decodes it into a 3 dimensional RGB or
    grayscale Tensor, optionally converting it to the format requested through ``mode``.
    Output values are uint8 in [0, 255].

    Args:
        path (str): location of the JPEG or PNG image.
        mode (ImageReadMode): read mode used for the optional conversion.
            Default: ``ImageReadMode.UNCHANGED``.
            See the ``ImageReadMode`` class for the available modes.

    Returns:
        output (Tensor[image_channels, image_height, image_width])
    """
    # API-usage logging only happens in eager mode, never while scripting or tracing.
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
        _log_api_usage_once(read_image)
    # Read the raw bytes, then let ``decode_image`` pick the right decoder.
    return decode_image(read_file(path), mode)
def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Reads the file at ``path`` and decodes it with the PNG op, passing ``True`` as the
    op's third argument (``decode_png`` passes ``False`` there).

    NOTE(review): judging by the function name, the flag presumably enables decoding of
    16-bit PNGs -- confirm against the native op.

    Args:
        path (str): location of the PNG file.
        mode (ImageReadMode): read mode forwarded to the op as ``mode.value``.
            Default: ``ImageReadMode.UNCHANGED``.

    Returns:
        Tensor: whatever the native PNG decode op returns.
    """
    return torch.ops.image.decode_png(read_file(path), mode.value, True)