Skip to content

Commit

Permalink
Fix the inconsistency in affine_transform between numpy and jax (#926)
Browse files Browse the repository at this point in the history
* Fix the inconsistency of `map_coordinates`

* Cast float64 to float32 in unit test

* Bypass test with interpolation=nearest

* Cast to float32 in moments test

* Increase atol and rtol in `nn.moments`
  • Loading branch information
james77777778 authored Sep 20, 2023
1 parent 1f7d869 commit 9d341ae
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 55 deletions.
15 changes: 7 additions & 8 deletions keras_core/backend/jax/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ def resize(
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


Expand All @@ -80,11 +80,10 @@ def affine_transform(
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
)

transform = convert_to_tensor(transform)
Expand Down
22 changes: 10 additions & 12 deletions keras_core/backend/numpy/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def resize(
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
"constant",
"nearest",
"wrap",
"mirror",
"reflect",
}


Expand All @@ -79,11 +79,10 @@ def affine_transform(
f"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: "
f"interpolation={interpolation}"
)
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES.keys():
if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:
raise ValueError(
"Invalid value for argument `fill_mode`. Expected of one "
f"{set(AFFINE_TRANSFORM_FILL_MODES.keys())}. "
f"Received: fill_mode={fill_mode}"
f"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}"
)

transform = convert_to_tensor(transform)
Expand Down Expand Up @@ -153,13 +152,12 @@ def affine_transform(
# apply affine transformation
affined = np.stack(
[
scipy.ndimage.map_coordinates(
map_coordinates(
image[i],
coordinates[i],
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
cval=fill_value,
prefilter=False,
fill_mode=fill_mode,
fill_value=fill_value,
)
for i in range(batch_size)
],
Expand Down
44 changes: 22 additions & 22 deletions keras_core/ops/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ def test_map_coordinates(self):
"nearest": 0,
"bilinear": 1,
}
AFFINE_TRANSFORM_FILL_MODES = {
"constant": "grid-constant",
"nearest": "nearest",
"wrap": "grid-wrap",
"mirror": "mirror",
"reflect": "reflect",
}


def _compute_affine_transform_coordinates(image, transform):
Expand Down Expand Up @@ -121,7 +114,7 @@ def _compute_affine_transform_coordinates(image, transform):
return coordinates


def _map_coordinates(
def _fixed_map_coordinates(
input, coordinates, order, fill_mode="constant", fill_value=0.0
):
# SciPy's implementation of map_coordinates handles boundaries incorrectly,
Expand Down Expand Up @@ -266,13 +259,22 @@ def test_affine_transform(self, interpolation, fill_mode, data_format):
"affine_transform with fill_mode=wrap is inconsistent with"
"scipy"
)
# TODO: `nearest` interpolation and `nearest` fill_mode in torch and jax
# causes random index shifting, resulting in significant differences in
# output which leads to failure
if backend.backend() in ("torch", "jax") and interpolation == "nearest":
self.skipTest(
f"In {backend.backend()} backend, "
f"interpolation={interpolation} causes index shifting and "
"leads test failure"
)

# Unbatched case
if data_format == "channels_first":
x = np.random.random((3, 50, 50)) * 255
x = np.random.random((3, 50, 50)).astype("float32") * 255
else:
x = np.random.random((50, 50, 3)) * 255
transform = np.random.random(size=(6))
x = np.random.random((50, 50, 3)).astype("float32") * 255
transform = np.random.random(size=(6)).astype("float32")
transform = np.pad(transform, (0, 2)) # makes c0, c1 always 0
out = kimage.affine_transform(
x,
Expand All @@ -284,24 +286,23 @@ def test_affine_transform(self, interpolation, fill_mode, data_format):
if data_format == "channels_first":
x = np.transpose(x, (1, 2, 0))
coordinates = _compute_affine_transform_coordinates(x, transform)
ref_out = scipy.ndimage.map_coordinates(
ref_out = _fixed_map_coordinates(
x,
coordinates,
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
prefilter=False,
fill_mode=fill_mode,
)
if data_format == "channels_first":
ref_out = np.transpose(ref_out, (2, 0, 1))
self.assertEqual(tuple(out.shape), tuple(ref_out.shape))
self.assertAllClose(ref_out, out, atol=0.3)
self.assertAllClose(ref_out, out, atol=1e-3, rtol=1e-3)

# Batched case
if data_format == "channels_first":
x = np.random.random((2, 3, 50, 50)) * 255
x = np.random.random((2, 3, 50, 50)).astype("float32") * 255
else:
x = np.random.random((2, 50, 50, 3)) * 255
transform = np.random.random(size=(2, 6))
x = np.random.random((2, 50, 50, 3)).astype("float32") * 255
transform = np.random.random(size=(2, 6)).astype("float32")
transform = np.pad(transform, [(0, 0), (0, 2)]) # makes c0, c1 always 0
out = kimage.affine_transform(
x,
Expand All @@ -315,12 +316,11 @@ def test_affine_transform(self, interpolation, fill_mode, data_format):
coordinates = _compute_affine_transform_coordinates(x, transform)
ref_out = np.stack(
[
scipy.ndimage.map_coordinates(
_fixed_map_coordinates(
x[i],
coordinates[i],
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],
mode=AFFINE_TRANSFORM_FILL_MODES[fill_mode],
prefilter=False,
fill_mode=fill_mode,
)
for i in range(x.shape[0])
],
Expand Down Expand Up @@ -420,6 +420,6 @@ def test_map_coordinates(self, shape, dtype, order, fill_mode):
for size in input_shape
]
output = kimage.map_coordinates(input, coordinates, order, fill_mode)
expected = _map_coordinates(input, coordinates, order, fill_mode)
expected = _fixed_map_coordinates(input, coordinates, order, fill_mode)

self.assertAllClose(output, expected)
30 changes: 17 additions & 13 deletions keras_core/ops/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,26 +1183,30 @@ def test_moments(self):
# Test 1D moments
x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32)
mean, variance = knn.moments(x, axes=[0])
self.assertAllClose(mean, np.mean(x))
self.assertAllClose(variance, np.var(x))
self.assertAllClose(mean, np.mean(x), atol=1e-5, rtol=1e-5)
self.assertAllClose(variance, np.var(x), atol=1e-5, rtol=1e-5)

# Test batch statistics for 4D moments (batch, height, width, channels)
x = np.random.uniform(size=(2, 28, 28, 3))
x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)
mean, variance = knn.moments(x, axes=[0])
self.assertAllClose(mean, np.mean(x, axis=0))
self.assertAllClose(variance, np.var(x, axis=0))
self.assertAllClose(mean, np.mean(x, axis=0), atol=1e-5, rtol=1e-5)
self.assertAllClose(variance, np.var(x, axis=0), atol=1e-5, rtol=1e-5)

# Test global statistics for 4D moments (batch, height, width, channels)
x = np.random.uniform(size=(2, 28, 28, 3))
x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)
mean, variance = knn.moments(x, axes=[0, 1, 2])
self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2)))
self.assertAllClose(variance, np.var(x, axis=(0, 1, 2)))
expected_mean = np.mean(x, axis=(0, 1, 2))
expected_variance = np.var(x, axis=(0, 1, 2))
self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)
self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)

# Test keepdims
x = np.random.uniform(size=(2, 28, 28, 3))
x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)
mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True)
self.assertAllClose(mean, np.mean(x, axis=(0, 1, 2), keepdims=True))
self.assertAllClose(variance, np.var(x, axis=(0, 1, 2), keepdims=True))
expected_mean = np.mean(x, axis=(0, 1, 2), keepdims=True)
expected_variance = np.var(x, axis=(0, 1, 2), keepdims=True)
self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)
self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)

# Test float16 which causes overflow
x = np.array(
Expand All @@ -1214,5 +1218,5 @@ def test_moments(self):
# the output variance is clipped to the max value of np.float16 because
# it is overflowed
expected_variance = np.finfo(np.float16).max
self.assertAllClose(mean, expected_mean)
self.assertAllClose(variance, expected_variance)
self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)
self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)

0 comments on commit 9d341ae

Please sign in to comment.