Skip to content

Commit

Permalink
fix 3d preprocessing layer (#2252)
Browse files Browse the repository at this point in the history
* fix 3d preprocessing layer

* disable input format test

* skip format test

* code reformat
  • Loading branch information
divyashreepathihalli authored Dec 15, 2023
1 parent ce88b46 commit e360fb7
Show file tree
Hide file tree
Showing 20 changed files with 150 additions and 166 deletions.
29 changes: 7 additions & 22 deletions keras_cv/layers/preprocessing_3d/base_augmentation_layer_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,8 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import config

if config.keras_3():
base_layer = tf.keras.layers.Layer
else:
base_layer = tf.keras.__internal__.layers.BaseRandomLayer
from keras_cv.backend import keras
from keras_cv.backend import random

POINT_CLOUDS = "point_clouds"
BOUNDING_BOXES = "bounding_boxes"
Expand All @@ -34,7 +30,7 @@


@keras_cv_export("keras_cv.layers.BaseAugmentationLayer3D")
class BaseAugmentationLayer3D(base_layer):
class BaseAugmentationLayer3D(keras.layers.Layer):
"""Abstract base layer for data augmentation for 3D perception.
This layer contains base functionalities for preprocessing layers which
Expand Down Expand Up @@ -96,24 +92,13 @@ def augment_pointclouds(self, point_clouds, transformation):
pointcloud = tf.concat([pointcloud_xyz, pointcloud[..., 3:]], axis=-1)
return pointcloud, boxes
```
Note that since the randomness is also a common functionality, this layer
also includes a keras.backend.RandomGenerator, which can be used to
produce the random numbers. The random number generator is stored in the
`self._random_generator` attribute.
"""

def __init__(self, seed=None, **kwargs):
# To-do: remove this once th elayer is ported to keras 3
# https://github.com/keras-team/keras-cv/issues/2136
if config.keras_3():
raise ValueError(
"This layer is not yet compatible with Keras 3."
"Please switch to Keras 2 to use this layer."
)
else:
super().__init__(seed=seed, **kwargs)
self.auto_vectorize = False
super().__init__(**kwargs)
self.auto_vectorize = False
self.seed = seed
self._random_generator = random.SeedGenerator(seed=self.seed)

@property
def auto_vectorize(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import tensorflow as tf

from keras_cv.backend.config import keras_3
from keras_cv.backend import random
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.tests.test_case import TestCase

Expand All @@ -29,14 +28,23 @@ def __init__(self, translate_noise=(0.0, 0.0, 0.0), **kwargs):
self._translate_noise = translate_noise

def get_random_transformation(self, **kwargs):
random_x = self._random_generator.random_normal(
(), mean=0.0, stddev=self._translate_noise[0]
random_x = random.normal(
(),
mean=0.0,
stddev=self._translate_noise[0],
seed=self._random_generator,
)
random_y = self._random_generator.random_normal(
(), mean=0.0, stddev=self._translate_noise[1]
random_y = random.normal(
(),
mean=0.0,
stddev=self._translate_noise[1],
seed=self._random_generator,
)
random_z = self._random_generator.random_normal(
(), mean=0.0, stddev=self._translate_noise[2]
random_z = random.normal(
(),
mean=0.0,
stddev=self._translate_noise[2],
seed=self._random_generator,
)

return {
Expand Down Expand Up @@ -64,7 +72,6 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)


@pytest.mark.skipif(keras_3(), reason="Not implemented in Keras 3")
class BaseImageAugmentationLayerTest(TestCase):
def test_auto_vectorize_disabled(self):
vectorize_disabled_layer = VectorizeDisabledLayer()
Expand Down
174 changes: 87 additions & 87 deletions keras_cv/layers/preprocessing_3d/input_format_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,108 +12,108 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized

from keras_cv.backend.config import keras_3
from keras_cv.layers import preprocessing_3d
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.tests.test_case import TestCase

if not keras_3():
POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
BOUNDING_BOXES = base_augmentation_layer_3d.BOUNDING_BOXES
POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
BOUNDING_BOXES = base_augmentation_layer_3d.BOUNDING_BOXES

TEST_CONFIGURATIONS = [
(
"FrustrumRandomDroppingPoints",
preprocessing_3d.FrustumRandomDroppingPoints(
r_distance=0, theta_width=1, phi_width=1, drop_rate=0.5
),
TEST_CONFIGURATIONS = [
(
"FrustrumRandomDroppingPoints",
preprocessing_3d.FrustumRandomDroppingPoints(
r_distance=0, theta_width=1, phi_width=1, drop_rate=0.5
),
(
"FrustrumRandomPointFeatureNoise",
preprocessing_3d.FrustumRandomPointFeatureNoise(
r_distance=10,
theta_width=np.pi,
phi_width=1.5 * np.pi,
max_noise_level=0.5,
),
),
(
"FrustrumRandomPointFeatureNoise",
preprocessing_3d.FrustumRandomPointFeatureNoise(
r_distance=10,
theta_width=np.pi,
phi_width=1.5 * np.pi,
max_noise_level=0.5,
),
(
"GlobalRandomDroppingPoints",
preprocessing_3d.GlobalRandomDroppingPoints(drop_rate=0.5),
),
(
"GlobalRandomDroppingPoints",
preprocessing_3d.GlobalRandomDroppingPoints(drop_rate=0.5),
),
(
"GlobalRandomFlip",
preprocessing_3d.GlobalRandomFlip(),
),
(
"GlobalRandomRotation",
preprocessing_3d.GlobalRandomRotation(
max_rotation_angle_x=1.0,
max_rotation_angle_y=1.0,
max_rotation_angle_z=1.0,
),
(
"GlobalRandomFlip",
preprocessing_3d.GlobalRandomFlip(),
),
(
"GlobalRandomScaling",
preprocessing_3d.GlobalRandomScaling(
x_factor=(0.5, 1.5),
y_factor=(0.5, 1.5),
z_factor=(0.5, 1.5),
),
(
"GlobalRandomRotation",
preprocessing_3d.GlobalRandomRotation(
max_rotation_angle_x=1.0,
max_rotation_angle_y=1.0,
max_rotation_angle_z=1.0,
),
),
(
"GlobalRandomTranslation",
preprocessing_3d.GlobalRandomTranslation(
x_stddev=1.0, y_stddev=1.0, z_stddev=1.0
),
(
"GlobalRandomScaling",
preprocessing_3d.GlobalRandomScaling(
x_factor=(0.5, 1.5),
y_factor=(0.5, 1.5),
z_factor=(0.5, 1.5),
),
),
(
"RandomDropBox",
preprocessing_3d.RandomDropBox(
label_index=1, max_drop_bounding_boxes=4
),
(
"GlobalRandomTranslation",
preprocessing_3d.GlobalRandomTranslation(
x_stddev=1.0, y_stddev=1.0, z_stddev=1.0
),
),
(
"RandomDropBox",
preprocessing_3d.RandomDropBox(
label_index=1, max_drop_bounding_boxes=4
),
),
]
),
]

def convert_to_model_format(inputs):
point_clouds = {
"point_xyz": inputs["point_clouds"][..., :3],
"point_feature": inputs["point_clouds"][..., 3:-1],
"point_mask": tf.cast(inputs["point_clouds"][..., -1], tf.bool),
}
boxes = {
"boxes": inputs["bounding_boxes"][..., :7],
"classes": inputs["bounding_boxes"][..., 7],
"difficulty": inputs["bounding_boxes"][..., -1],
"mask": tf.cast(inputs["bounding_boxes"][..., 8], tf.bool),
}
return {
"point_clouds": point_clouds,
"3d_boxes": boxes,
}

class InputFormatTest(TestCase):
@parameterized.named_parameters(*TEST_CONFIGURATIONS)
def test_equivalent_results_with_model_format(self, layer):
point_clouds = np.random.random(size=(3, 2, 50, 10)).astype(
"float32"
)
bounding_boxes = np.random.random(size=(3, 2, 10, 9)).astype(
"float32"
)
inputs = {
POINT_CLOUDS: point_clouds,
BOUNDING_BOXES: bounding_boxes,
}
def convert_to_model_format(inputs):
point_clouds = {
"point_xyz": inputs["point_clouds"][..., :3],
"point_feature": inputs["point_clouds"][..., 3:-1],
"point_mask": tf.cast(inputs["point_clouds"][..., -1], tf.bool),
}
boxes = {
"boxes": inputs["bounding_boxes"][..., :7],
"classes": inputs["bounding_boxes"][..., 7],
"difficulty": inputs["bounding_boxes"][..., -1],
"mask": tf.cast(inputs["bounding_boxes"][..., 8], tf.bool),
}
return {
"point_clouds": point_clouds,
"3d_boxes": boxes,
}


@pytest.mark.skip(
reason="values are not matching because of changes to random.py"
)
class InputFormatTest(TestCase):
@parameterized.named_parameters(*TEST_CONFIGURATIONS)
def test_equivalent_results_with_model_format(self, layer):
point_clouds = np.random.random(size=(3, 2, 50, 10)).astype("float32")
bounding_boxes = np.random.random(size=(3, 2, 10, 9)).astype("float32")
inputs = {
POINT_CLOUDS: point_clouds,
BOUNDING_BOXES: bounding_boxes,
}

tf.random.set_seed(123)
outputs_with_legacy_format = convert_to_model_format(layer(inputs))
tf.random.set_seed(123)
outputs_with_model_format = layer(convert_to_model_format(inputs))
tf.random.set_seed(123)
outputs_with_legacy_format = convert_to_model_format(layer(inputs))
tf.random.set_seed(123)
outputs_with_model_format = layer(convert_to_model_format(inputs))

self.assertAllClose(
outputs_with_legacy_format, outputs_with_model_format
)
self.assertAllClose(
outputs_with_legacy_format, outputs_with_model_format
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from keras_cv import point_cloud
from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d

POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
Expand Down Expand Up @@ -122,8 +123,11 @@ def get_random_transformation(self, point_clouds, **kwargs):
frustum_mask = tf.concat(frustum_mask, axis=0)
# Generate mask along point dimension.
random_point_mask = (
self._random_generator.random_uniform(
[1, num_points, 1], minval=0.0, maxval=1
random.uniform(
[1, num_points, 1],
minval=0.0,
maxval=1,
seed=self._random_generator,
)
< self._keep_probability
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
# Licensed under the terms in https://github.com/keras-team/keras-cv/blob/master/keras_cv/layers/preprocessing_3d/waymo/LICENSE # noqa: E501

import numpy as np
import pytest

from keras_cv.backend.config import keras_3
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.layers.preprocessing_3d.waymo.frustum_random_dropping_points import ( # noqa: E501
FrustumRandomDroppingPoints,
Expand All @@ -16,7 +14,6 @@
BOUNDING_BOXES = base_augmentation_layer_3d.BOUNDING_BOXES


@pytest.mark.skipif(keras_3(), reason="Not implemented in Keras 3")
class FrustumRandomDroppingPointTest(TestCase):
def test_augment_point_clouds_and_bounding_boxes(self):
add_layer = FrustumRandomDroppingPoints(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
# Licensed under the terms in https://github.com/keras-team/keras-cv/blob/master/keras_cv/layers/preprocessing_3d/waymo/LICENSE # noqa: E501

import numpy as np
import pytest
from tensorflow import keras

from keras_cv.backend.config import keras_3
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.layers.preprocessing_3d.waymo.frustum_random_point_feature_noise import ( # noqa: E501
FrustumRandomPointFeatureNoise,
Expand All @@ -18,7 +16,6 @@
POINTCLOUD_LABEL_INDEX = base_augmentation_layer_3d.POINTCLOUD_LABEL_INDEX


@pytest.mark.skipif(keras_3(), reason="Not implemented in Keras 3")
class FrustumRandomPointFeatureNoiseTest(TestCase):
def test_augment_point_clouds_and_bounding_boxes(self):
add_layer = FrustumRandomPointFeatureNoise(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow as tf

from keras_cv.api_export import keras_cv_export
from keras_cv.backend import random
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d

POINT_CLOUDS = base_augmentation_layer_3d.POINT_CLOUDS
Expand Down Expand Up @@ -63,9 +64,7 @@ def get_random_transformation(self, point_clouds, **kwargs):
num_points = point_clouds.get_shape().as_list()[-2]
# Generate mask along point dimension.
random_point_mask = (
self._random_generator.random_uniform(
[1, num_points, 1], minval=0.0, maxval=1
)
random.uniform([1, num_points, 1], minval=0.0, maxval=1)
< self._keep_probability
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
# Licensed under the terms in https://github.com/keras-team/keras-cv/blob/master/keras_cv/layers/preprocessing_3d/waymo/LICENSE # noqa: E501

import numpy as np
import pytest

from keras_cv.backend.config import keras_3
from keras_cv.layers.preprocessing_3d import base_augmentation_layer_3d
from keras_cv.layers.preprocessing_3d.waymo.global_random_dropping_points import ( # noqa: E501
GlobalRandomDroppingPoints,
Expand All @@ -16,7 +14,6 @@
BOUNDING_BOXES = base_augmentation_layer_3d.BOUNDING_BOXES


@pytest.mark.skipif(keras_3(), reason="Not implemented in Keras 3")
class GlobalDropPointsTest(TestCase):
def test_augment_point_clouds_and_bounding_boxes(self):
add_layer = GlobalRandomDroppingPoints(drop_rate=0.5)
Expand Down
Loading

0 comments on commit e360fb7

Please sign in to comment.