From 7735d8ef1a103a142614f941c018e79cf44d5f7b Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 25 Apr 2022 12:43:20 -0700 Subject: [PATCH] Use GPU instead of DML to identify DML devices in SqueezeNet model (#236) --- TensorFlow/TF2/squeezenet/squeezenet.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/TensorFlow/TF2/squeezenet/squeezenet.py b/TensorFlow/TF2/squeezenet/squeezenet.py index 9193cd42..dd273ef9 100644 --- a/TensorFlow/TF2/squeezenet/squeezenet.py +++ b/TensorFlow/TF2/squeezenet/squeezenet.py @@ -8,7 +8,6 @@ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' gpu_available = (len(tf.config.list_physical_devices('GPU')) > 0) -dml_available = (len(tf.config.list_physical_devices('DML')) > 0) tf.random.set_seed(1234) @@ -20,7 +19,7 @@ parser.add_argument('--batch_size', type=int, default=32, help='Number of images per batch fed through network') parser.add_argument('--num_epochs', type=int, default=100, help='Number of passes through training data before stopping') parser.add_argument("--log_device_placement", action="store_true", help="Print the operator device placement on the pre-optimized graph") -parser.add_argument('--device', type=str, default='CPU:0' if not dml_available else 'DML:0', help='Specify manually to use non-DML GPU device eg. GPU:0') +parser.add_argument('--device', type=str, default='CPU:0' if not gpu_available else 'GPU:0', help='Specify manually to use non-DML GPU device eg. CPU:0') parser.add_argument("--inter_op_threads", default=0, type=int, help="Max number of threads for the runtime to use for kernel scheduling") parser.add_argument('--cifar10', action='store_true', help='Train with CIFAR-10 dataset') parser.add_argument('--tb_profile', action='store_true', help='Performance profiling using TensorBoard') @@ -109,8 +108,8 @@ def main(): tf.debugging.set_log_device_placement(True) if args.inter_op_threads > 0: tf.config.threading.set_inter_op_parallelism_threads(args.inter_op_threads) - if args.device[:3] != 'DML': - tf.config.set_visible_devices([], 'DML') + if args.device[:3] != 'GPU': + tf.config.set_visible_devices([], 'GPU') batch_size = args.batch_size num_epochs = args.num_epochs