Skip to content

Commit

Permalink
Use GPU instead of DML to identify DML devices in SqueezeNet model (#236
Browse files Browse the repository at this point in the history
)
  • Loading branch information
PatriceVignola authored Apr 25, 2022
1 parent 5f959ba commit 7735d8e
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions TensorFlow/TF2/squeezenet/squeezenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

gpu_available = (len(tf.config.list_physical_devices('GPU')) > 0)
dml_available = (len(tf.config.list_physical_devices('DML')) > 0)

tf.random.set_seed(1234)

Expand All @@ -20,7 +19,7 @@
parser.add_argument('--batch_size', type=int, default=32, help='Number of images per batch fed through network')
parser.add_argument('--num_epochs', type=int, default=100, help='Number of passes through training data before stopping')
parser.add_argument("--log_device_placement", action="store_true", help="Print the operator device placement on the pre-optimized graph")
parser.add_argument('--device', type=str, default='CPU:0' if not dml_available else 'DML:0', help='Specify manually to use non-DML GPU device eg. GPU:0')
parser.add_argument('--device', type=str, default='CPU:0' if not gpu_available else 'GPU:0', help='Specify manually to use non-DML GPU device eg. CPU:0')
parser.add_argument("--inter_op_threads", default=0, type=int, help="Max number of threads for the runtime to use for kernel scheduling")
parser.add_argument('--cifar10', action='store_true', help='Train with CIFAR-10 dataset')
parser.add_argument('--tb_profile', action='store_true', help='Performance profiling using TensorBoard')
Expand Down Expand Up @@ -109,8 +108,8 @@ def main():
tf.debugging.set_log_device_placement(True)
if args.inter_op_threads > 0:
tf.config.threading.set_inter_op_parallelism_threads(args.inter_op_threads)
if args.device[:3] != 'DML':
tf.config.set_visible_devices([], 'DML')
if args.device[:3] != 'GPU':
tf.config.set_visible_devices([], 'GPU')

batch_size = args.batch_size
num_epochs = args.num_epochs
Expand Down

0 comments on commit 7735d8e

Please sign in to comment.