Skip to content

Commit

Permalink
Fix test_dali_tf_dataset_mnist_eager.py and test_dali_tf_dataset_mnis…
Browse files Browse the repository at this point in the history
…t_graph.py tests (#3987)

- adds a proper import for skip_for_incompatible_tf in test_dali_tf_dataset_mnist_eager.py
  test
- adds missing import of with_setup to test_dali_tf_dataset_mnist_graph.py

Signed-off-by: Janusz Lisiecki <[email protected]>
  • Loading branch information
JanuszL authored Jun 15, 2022
1 parent e687379 commit 6a2eb23
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
9 changes: 5 additions & 4 deletions dali/test/python/test_dali_tf_dataset_mnist_eager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@
from nose.tools import with_setup

import test_dali_tf_dataset_mnist as mnist
from test_utils_tensorflow import skip_for_incompatible_tf
from nose_utils import raises

tf.compat.v1.enable_eager_execution()
Expand All @@ -33,7 +34,7 @@ def test_keras_single_cpu():
mnist.run_keras_single_device('cpu', 0)


@with_setup(mnist.skip_for_incompatible_tf)
@with_setup(skip_for_incompatible_tf)
@raises(Exception, "TF device and DALI device mismatch")
def test_keras_wrong_placement_gpu():
with tf.device('cpu:0'):
Expand All @@ -46,7 +47,7 @@ def test_keras_wrong_placement_gpu():
steps_per_epoch=mnist.ITERATIONS)


@with_setup(mnist.skip_for_incompatible_tf)
@with_setup(skip_for_incompatible_tf)
@raises(Exception, "TF device and DALI device mismatch")
def test_keras_wrong_placement_cpu():
with tf.device('gpu:0'):
Expand All @@ -59,7 +60,7 @@ def test_keras_wrong_placement_cpu():
steps_per_epoch=mnist.ITERATIONS)


@with_setup(mnist.skip_for_incompatible_tf)
@with_setup(skip_for_incompatible_tf)
def test_keras_multi_gpu_mirrored_strategy():
strategy = tf.distribute.MirroredStrategy(devices=mnist.available_gpus())

Expand Down
3 changes: 2 additions & 1 deletion dali/test/python/test_dali_tf_dataset_mnist_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,6 +14,7 @@

from test_dali_tf_dataset_mnist import *
from nose_utils import raises
from nose import with_setup

tf.compat.v1.disable_eager_execution()

Expand Down

0 comments on commit 6a2eb23

Please sign in to comment.