diff --git a/dali/kernels/context.h b/dali/kernels/context.h
index 1ad57649049..5eb2448b83d 100644
--- a/dali/kernels/context.h
+++ b/dali/kernels/context.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2018-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include "dali/core/access_order.h"
 #include "dali/core/tensor_view.h"
 #include "dali/core/mm/memory_resource.h"
 #include "dali/core/mm/memory_kind.h"
@@ -34,7 +35,7 @@ struct Context {};

 template <>
 struct Context<ComputeGPU> {
-  cudaStream_t stream = 0;
+  cudaStream_t stream = AccessOrder::null_stream();
 };

 class Scratchpad;
diff --git a/dali/kernels/dynamic_scratchpad.h b/dali/kernels/dynamic_scratchpad.h
index 109271e6a00..9a1d5128a51 100644
--- a/dali/kernels/dynamic_scratchpad.h
+++ b/dali/kernels/dynamic_scratchpad.h
@@ -60,6 +60,8 @@ class DynamicScratchpadImplT {
                              AccessOrder dealloc_order = {}) {
     static_assert(!std::is_same<Kind, mm::memory_kind::host>::value,
                   "Cannot use a stream-ordered resource for plain host memory");
+    if (!dealloc_order.has_value())
+      dealloc_order = alloc_order;
     adapter<Kind>() = { rsrc, alloc_order, dealloc_order };
     set_upstream_resource<Kind>(&adapter<Kind>());
   }
@@ -135,41 +137,71 @@ class DynamicScratchpad
     initial_sizes_ = initial_sizes;
     for (auto &s : initial_sizes_) {
       if (s == 0)
-        s = 4096;
+        s = 0x10000;  // 64k
     }
     if (!pinned_dealloc_order.has_value())
       pinned_dealloc_order = device_order;
     if (!managed_dealloc_order.has_value())
       managed_dealloc_order = device_order;
+    device_order_ = device_order;
+    pinned_dealloc_order_ = pinned_dealloc_order;
+    managed_dealloc_order_ = managed_dealloc_order;
+  }
+
+  virtual void *Alloc(mm::memory_kind_id kind_id, size_t bytes, size_t alignment) {
+    void *ret = nullptr;
+    TYPE_SWITCH(kind_id, mm::kind2id, Kind,
+      (mm::memory_kind::host,
+       mm::memory_kind::pinned,
+       mm::memory_kind::device,
+       mm::memory_kind::managed),
+      (ret = AllocImpl<Kind>(bytes, alignment)),
+      (assert(!"Incorrect memory kind id");));
+    return ret;
+  }
+
+  template <typename Kind>
+  struct type_tag {};
+
+  void InitResource(type_tag<mm::memory_kind::host>) {
     set_upstream_resource(mm::GetDefaultResource());
+  }
+
+  void InitResource(type_tag<mm::memory_kind::pinned>) {
     set_upstream_resource(
         mm::GetDefaultResource(),
         AccessOrder::host(),
-        pinned_dealloc_order);
+        pinned_dealloc_order_);
+  }
+
+  void InitResource(type_tag<mm::memory_kind::device>) {
     set_upstream_resource(
         mm::GetDefaultResource(),
-        device_order);
+        device_order_);
+  }
+
+  void InitResource(type_tag<mm::memory_kind::managed>) {
     set_upstream_resource(
         mm::GetDefaultResource(),
         AccessOrder::host(),
-        managed_dealloc_order);
+        managed_dealloc_order_);
   }

-  virtual void *Alloc(mm::memory_kind_id kind_id, size_t bytes, size_t alignment) {
-    void *ret = nullptr;
-    TYPE_SWITCH(kind_id, mm::kind2id, Kind,
-      (mm::memory_kind::host,
-       mm::memory_kind::pinned,
-       mm::memory_kind::device,
-       mm::memory_kind::managed),
-      (ret = resource<Kind>().allocate(bytes, alignment)),
-      (assert(!"Incorrect memory kind id");));
-    return ret;
+  template <typename Kind>
+  void *AllocImpl(size_t bytes, size_t alignment) {
+    if (bytes == 0)
+      return nullptr;  // do not initialize the resource in case of 0-sized allocation
+
+    auto &r = resource<Kind>();
+    if (!r.get_upstream()) {
+      InitResource(type_tag<Kind>());
+      assert(r.get_upstream() != nullptr);
+    }
+    return r.allocate(bytes, alignment);
   }
+
+  AccessOrder device_order_, pinned_dealloc_order_, managed_dealloc_order_;
 };

 }  // namespace kernels
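The dynamic_scratchpad.h hunks above defer upstream-resource creation: nothing is initialized until the first non-zero allocation of a given memory kind, and a kind's deallocation order now defaults to its allocation order. A minimal usage sketch of that behavior, assuming the constructor and Alloc signatures visible in this patch (the stream, sizes, and alignments are illustrative only):

// Sketch, not part of the patch; assumes the interfaces shown in the hunks above.
void scratchpad_example(cudaStream_t stream) {
  dali::kernels::DynamicScratchpad scratch({}, dali::AccessOrder(stream));
  // A 0-byte request returns nullptr without initializing any resource.
  void *p0 = scratch.Alloc(dali::mm::memory_kind_id::device, 0, 64);
  // The first real device allocation lazily initializes the device resource only;
  // its deallocations are ordered on `stream` (device_order_).
  void *p1 = scratch.Alloc(dali::mm::memory_kind_id::device, 64 << 10, 256);
  (void)p0; (void)p1;  // released in stream order when `scratch` is destroyed
}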
diff --git a/dali/kernels/imgproc/convolution/convolution_gpu_test.cu b/dali/kernels/imgproc/convolution/convolution_gpu_test.cu
index 3a2791a58e2..4ffb44177c5 100644
--- a/dali/kernels/imgproc/convolution/convolution_gpu_test.cu
+++ b/dali/kernels/imgproc/convolution/convolution_gpu_test.cu
@@ -158,6 +158,7 @@ struct ConvolutionGpuKernelTest : public ::testing::Test {

   void RunTest() {
     KernelContext ctx_cpu, ctx_gpu;
+    ctx_gpu.gpu.stream = 0;
     KernelCpu kernel_cpu;
     KernelGpu kernel_gpu;
diff --git a/dali/kernels/imgproc/convolution/laplacian_gpu_test.cu b/dali/kernels/imgproc/convolution/laplacian_gpu_test.cu
index 94949b1227e..cd61eaa9e88 100644
--- a/dali/kernels/imgproc/convolution/laplacian_gpu_test.cu
+++ b/dali/kernels/imgproc/convolution/laplacian_gpu_test.cu
@@ -161,6 +161,7 @@ struct LaplacianGpuTest : public ::testing::Test {
   void RunTest() {
     KernelContext ctx_cpu = {}, ctx_gpu = {};
+    ctx_gpu.gpu.stream = 0;
     KernelCpu kernel_cpu;
     KernelGpu kernel_gpu;
     int nsamples = in_.shape.size();
diff --git a/dali/kernels/imgproc/flip_gpu_test.cu b/dali/kernels/imgproc/flip_gpu_test.cu
index 723e465082d..7dc3f18c2d4 100644
--- a/dali/kernels/imgproc/flip_gpu_test.cu
+++ b/dali/kernels/imgproc/flip_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -51,6 +51,7 @@ class FlipGpuTest: public testing::TestWithParam>

 TEST_P(FlipGpuTest, ImplTest) {
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   FlipGPU kernel;
   auto in_view = ttl_in_.gpu(nullptr);
   ttl_in_.invalidate_cpu();
@@ -75,6 +76,7 @@ TEST_P(FlipGpuTest, ImplTest) {

 TEST_P(FlipGpuTest, KernelTest) {
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   FlipGPU kernel;
   auto in_view = ttl_in_.gpu(nullptr);
   ttl_in_.invalidate_cpu();
diff --git a/dali/kernels/imgproc/pointwise/linear_transformation_gpu_test.cu b/dali/kernels/imgproc/pointwise/linear_transformation_gpu_test.cu
index 9a9a60e7165..3eab77db4e3 100644
--- a/dali/kernels/imgproc/pointwise/linear_transformation_gpu_test.cu
+++ b/dali/kernels/imgproc/pointwise/linear_transformation_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -126,6 +126,7 @@ TYPED_TEST(LinearTransformationGpuTest, check_kernel) {
 TYPED_TEST(LinearTransformationGpuTest, setup_test) {
   TheKernel kernel;
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   InListGPU in(this->input_device_, this->in_shapes_);
   auto reqs = kernel.Setup(ctx, in, make_cspan(this->vmat_), make_cspan(this->vvec_));
   ASSERT_EQ(this->out_shapes_.size(), static_cast<size_t>(reqs.output_shapes[0].num_samples()))
@@ -140,6 +141,7 @@ TYPED_TEST(LinearTransformationGpuTest, setup_test) {
 TYPED_TEST(LinearTransformationGpuTest, setup_test_with_roi) {
   TheKernel kernel;
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   InListGPU in(this->input_device_, this->in_shapes_);
   auto reqs = kernel.Setup(ctx, in, make_cspan(this->vmat_), make_cspan(this->vvec_),
                            make_cspan(this->rois_));
@@ -150,20 +152,21 @@ TYPED_TEST(LinearTransformationGpuTest, setup_test_with_roi) {

 TYPED_TEST(LinearTransformationGpuTest, run_test) {
   TheKernel kernel;
-  KernelContext c;
+  KernelContext ctx;
+  ctx.gpu.stream = 0;
   InListGPU in(this->input_device_, this->in_shapes_);

-  auto reqs = kernel.Setup(c, in, make_cspan(this->vmat_), make_cspan(this->vvec_));
+  auto reqs = kernel.Setup(ctx, in, make_cspan(this->vmat_), make_cspan(this->vvec_));

   ScratchpadAllocator sa;
   sa.Reserve(reqs.scratch_sizes);
   auto scratchpad = sa.GetScratchpad();
-  c.scratchpad = &scratchpad;
+  ctx.scratchpad = &scratchpad;

   OutListGPU out(
       this->output_, reqs.output_shapes[0].template to_static());

-  kernel.Run(c, out, in, make_cspan(this->vmat_), make_cspan(this->vvec_));
+  kernel.Run(ctx, out, in, make_cspan(this->vmat_), make_cspan(this->vvec_));
   CUDA_CALL(cudaDeviceSynchronize());

   auto res = copy(out[0]);
@@ -175,22 +178,24 @@ TYPED_TEST(LinearTransformationGpuTest, run_test) {

 TYPED_TEST(LinearTransformationGpuTest, run_test_with_roi) {
   TheKernel kernel;
-  KernelContext c;
+  KernelContext ctx;
+  ctx.gpu.stream = 0;
   InListGPU in(this->input_device_, this->in_shapes_);

-  auto reqs = kernel.Setup(c, in,
+  auto reqs = kernel.Setup(ctx, in,
                            make_cspan(this->vmat_), make_cspan(this->vvec_),
                            make_cspan(this->rois_));

   ScratchpadAllocator sa;
   sa.Reserve(reqs.scratch_sizes);
   auto scratchpad = sa.GetScratchpad();
-  c.scratchpad = &scratchpad;
+  ctx.scratchpad = &scratchpad;

   OutListGPU out(
       this->output_, reqs.output_shapes[0].template to_static());

-  kernel.Run(c, out, in, make_cspan(this->vmat_), make_cspan(this->vvec_), make_cspan(this->rois_));
+  kernel.Run(ctx, out, in,
+             make_cspan(this->vmat_), make_cspan(this->vvec_), make_cspan(this->rois_));
   CUDA_CALL(cudaDeviceSynchronize());

   auto res = copy(out[0]);
diff --git a/dali/kernels/imgproc/pointwise/multiply_add_gpu_test.cu b/dali/kernels/imgproc/pointwise/multiply_add_gpu_test.cu
index e0a57af9da2..b23e21c20d0 100644
--- a/dali/kernels/imgproc/pointwise/multiply_add_gpu_test.cu
+++ b/dali/kernels/imgproc/pointwise/multiply_add_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -130,6 +130,7 @@ TYPED_TEST(MultiplyAddGpuTest, check_kernel) {
 TYPED_TEST(MultiplyAddGpuTest, setup_test) {
   TheKernel kernel;
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   InListGPU in(this->input_device_, this->shapes_);
   auto reqs = kernel.Setup(ctx, in, this->addends_, this->multipliers_);
   ASSERT_EQ(this->shapes_.size(), static_cast<size_t>(reqs.output_shapes[0].num_samples()))
@@ -143,18 +144,19 @@ TYPED_TEST(MultiplyAddGpuTest, setup_test) {

 TYPED_TEST(MultiplyAddGpuTest, run_test) {
   TheKernel kernel;
-  KernelContext c;
+  KernelContext ctx;
+  ctx.gpu.stream = 0;
   InListGPU in(this->input_device_, this->shapes_);
   OutListGPU out(this->output_, TensorListShape(this->shapes_));

-  auto reqs = kernel.Setup(c, in, this->addends_, this->multipliers_);
+  auto reqs = kernel.Setup(ctx, in, this->addends_, this->multipliers_);

   ScratchpadAllocator sa;
   sa.Reserve(reqs.scratch_sizes);
   auto scratchpad = sa.GetScratchpad();
-  c.scratchpad = &scratchpad;
-  kernel.Run(c, out, in, this->addends_, this->multipliers_);
+  ctx.scratchpad = &scratchpad;
+  kernel.Run(ctx, out, in, this->addends_, this->multipliers_);
   CUDA_CALL(cudaDeviceSynchronize());

   auto res = copy(out[0]);
diff --git a/dali/kernels/kernel_manager.h b/dali/kernels/kernel_manager.h
index 637252e9f1a..75dcb3619db 100644
--- a/dali/kernels/kernel_manager.h
+++ b/dali/kernels/kernel_manager.h
@@ -22,6 +22,7 @@
 #include "dali/kernels/scratch.h"
 #include "dali/kernels/context.h"
 #include "dali/kernels/kernel_req.h"
+#include "dali/kernels/dynamic_scratchpad.h"
 #include "dali/core/small_vector.h"
 #include "dali/core/mm/memory_kind.h"
@@ -218,9 +219,15 @@ class DLL_PUBLIC KernelManager {
    */
   template <typename Kernel, typename... OutInArgs>
   void Run(int thread_idx, int instance_idx, KernelContext &context, OutInArgs &&...out_in_args) {
-    assert(static_cast<size_t>(thread_idx) < scratchpads.size());
-    auto &sa = GetScratchpadAllocator(thread_idx);
-    Run<Kernel>(sa, instance_idx, context, std::forward<OutInArgs>(out_in_args)...);
+    assert(instance_idx >= 0 &&
+           static_cast<size_t>(instance_idx) < NumInstances() &&
+           "Kernel instance index (instance_idx) out of range");
+    auto &inst = instances[instance_idx];
+    DynamicScratchpad scratchpad({}, AccessOrder(context.gpu.stream));
+    auto *old_scratchpad = context.scratchpad;
+    context.scratchpad = &scratchpad;
+    inst.get<Kernel>().Run(context, std::forward<OutInArgs>(out_in_args)...);
+    context.scratchpad = old_scratchpad;
   }

   /**
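With the Run rewrite above, thread_idx no longer selects a pre-reserved scratchpad: each call provisions a DynamicScratchpad ordered on context.gpu.stream, installs it for the duration of the kernel's Run, and restores the caller's scratchpad afterwards. A hedged sketch of the resulting call site (MyKernel, in_view, and out_view are placeholders, not names from this patch):

// Sketch only; the kernel type and views are hypothetical.
dali::kernels::KernelManager kmgr;
kmgr.Resize<MyKernel>(1, 1);      // one thread, one kernel instance

dali::kernels::KernelContext ctx;
ctx.gpu.stream = stream;          // Run() orders its scratchpad on this stream

auto &req = kmgr.Setup<MyKernel>(0, ctx, in_view);
// No ScratchpadAllocator bookkeeping: Run() supplies a stream-ordered
// DynamicScratchpad internally and restores ctx.scratchpad when done.
kmgr.Run<MyKernel>(0, 0, ctx, out_view, in_view);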
diff --git a/dali/kernels/math/transform_points_test.cu b/dali/kernels/math/transform_points_test.cu
index f8346d727ff..84941014fff 100644
--- a/dali/kernels/math/transform_points_test.cu
+++ b/dali/kernels/math/transform_points_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2020, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -70,6 +70,7 @@ struct TransformPointsTest : ::testing::Test {
     const auto *in_points = reinterpret_cast *>(in_tensor.data);
     const auto *out_points = reinterpret_cast *>(out_tensor.data);
     KernelContext ctx;
+    ctx.gpu.stream = 0;
     auto &req = kmgr_.Setup(0, ctx, in_tensor.shape);
     ASSERT_EQ(req.output_shapes[0][0], out_tensor.shape);
     kmgr_.Run(0, 0, ctx, out_tensor, in_tensor, M, T);
@@ -103,6 +104,7 @@ struct TransformPointsTest : ::testing::Test {
     kmgr_.Resize(1, 1);

     KernelContext ctx;
+    ctx.gpu.stream = 0;
     auto &req = kmgr_.Setup(0, ctx, in_gpu.shape);
     ASSERT_EQ(req.output_shapes[0], out_gpu.shape);
     kmgr_.Run(0, 0, ctx, out_gpu, in_gpu, make_span(M), make_span(T));
diff --git a/dali/kernels/normalize/normalize_gpu_test.cu b/dali/kernels/normalize/normalize_gpu_test.cu
index 9ab36d339e3..a1e3d1ffbe7 100644
--- a/dali/kernels/normalize/normalize_gpu_test.cu
+++ b/dali/kernels/normalize/normalize_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -246,6 +246,7 @@ class NormalizeImplGPUTest> : public ::testing::Test {
   void RunTest() {
     kmgr_.Resize(1, 1);
     KernelContext ctx;
+    ctx.gpu.stream = 0;
     for (int iter = 0; iter < 3; iter++) {
       auto req = kmgr_.Setup(0, ctx, data_shape_, param_shape_,
                              use_scalar_base_, use_scalar_scale_, scale_is_stddev_);
@@ -276,6 +277,7 @@ class NormalizeImplGPUTest> : public ::testing::Test {
   void RunPerf() {
     kmgr_.Resize(1, 1);
     KernelContext ctx;
+    ctx.gpu.stream = 0;
     auto req = kmgr_.Setup(0, ctx, data_shape_, param_shape_,
                            use_scalar_base_, use_scalar_scale_, scale_is_stddev_);
     ASSERT_EQ(req.output_shapes.size(), 1u);
diff --git a/dali/kernels/reduce/reduce_gpu_test.h b/dali/kernels/reduce/reduce_gpu_test.h
index c0ac7f110e6..59167dcc15e 100644
--- a/dali/kernels/reduce/reduce_gpu_test.h
+++ b/dali/kernels/reduce/reduce_gpu_test.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2020, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@ struct ReductionKernelTest {
             Args &&...args) {
     in.reshape(in_shape);
     ref.reshape(ref_out_shape);
+    ctx.gpu.stream = 0;
     auto req = kernel.Setup(ctx, in_shape, axes, keep_dims, batch, std::forward<Args>(args)...);
     ASSERT_EQ(req.output_shapes.size(), 1), req;
     ASSERT_EQ(req.output_shapes[0], ref_out_shape), req;
@@ -66,6 +67,7 @@ struct ReductionKernelTest {
   template <typename... Args>
   void Run(Args &&...args) {
     auto scratchpad = sa.GetScratchpad();
+    ctx.gpu.stream = 0;
     ctx.scratchpad = &scratchpad;
     kernel.Run(ctx, out.gpu(stream()), in.gpu(stream()), std::forward<Args>(args)...);
   }
diff --git a/dali/kernels/signal/fft/fft_postprocess_test.cu b/dali/kernels/signal/fft/fft_postprocess_test.cu
index a027906cce5..eb8eb0a8523 100644
--- a/dali/kernels/signal/fft/fft_postprocess_test.cu
+++ b/dali/kernels/signal/fft/fft_postprocess_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2020, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -108,6 +108,7 @@ class FFTPostprocessTest> : public ::testin
     ToFreqMajorSpectrum tr;

     KernelContext ctx;
+    ctx.gpu.stream = 0;
     ScratchpadAllocator sa;
     KernelRequirements req = tr.Setup(ctx, in_shape);
     ASSERT_EQ(req.output_shapes.size(), 1u);
@@ -174,6 +175,7 @@ class FFTPostprocessTest> : public ::testin
     ConvertTimeMajorSpectrum tr;

     KernelContext ctx;
+    ctx.gpu.stream = 0;
     tr.Setup(ctx, in_shape);
     tr.Run(ctx, out_gpu, in.gpu());
     CUDA_CALL(cudaGetLastError());
diff --git a/dali/kernels/signal/fft/stft_gpu_test.cu b/dali/kernels/signal/fft/stft_gpu_test.cu
index 1695299b5d9..b32c44ee46c 100644
--- a/dali/kernels/signal/fft/stft_gpu_test.cu
+++ b/dali/kernels/signal/fft/stft_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -160,6 +160,7 @@ TEST(StftGPU, Setup) {
     args.time_major_layout = time_major;

     KernelContext ctx;
+    ctx.gpu.stream = 0;
     KernelRequirements req = stft.Setup(ctx, lengths, args);
     ASSERT_EQ(req.output_shapes.size(), 1u);
     auto &o_shape = req.output_shapes[0];
@@ -218,6 +219,7 @@ class StftGPUTest>
     TestTensorList out;

     KernelContext ctx;
+    ctx.gpu.stream = 0;
     KernelRequirements req = stft.Setup(ctx, in_shape, args);
     auto stream = ctx.gpu.stream;
     ASSERT_EQ(req.output_shapes.size(), 1u);
diff --git a/dali/kernels/signal/window/extract_windows_gpu_test.cu b/dali/kernels/signal/window/extract_windows_gpu_test.cu
index 58c82d88643..4ef8221371a 100644
--- a/dali/kernels/signal/window/extract_windows_gpu_test.cu
+++ b/dali/kernels/signal/window/extract_windows_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -130,6 +130,7 @@ void TestBatchedExtract(
   int in_win_start = (out_win_len_actual - args.window_length) / 2;

   KernelContext ctx;
+  ctx.gpu.stream = 0;

   auto in_gpu = in_list.gpu(0);
diff --git a/dali/kernels/slice/slice_flip_normalize_permute_pad_gpu_test.cu b/dali/kernels/slice/slice_flip_normalize_permute_pad_gpu_test.cu
index cb66d204bf9..bdfd4e9aa74 100644
--- a/dali/kernels/slice/slice_flip_normalize_permute_pad_gpu_test.cu
+++ b/dali/kernels/slice/slice_flip_normalize_permute_pad_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ class SliceFlipNormalizePermutePadGpuTest : public SliceFlipNormalizePermutePadT
   void Run() override {
     KernelContext ctx;
+    ctx.gpu.stream = 0;
     TestTensorList test_data;
     this->PrepareData(test_data);
diff --git a/dali/kernels/slice/slice_gpu_test.cu b/dali/kernels/slice/slice_gpu_test.cu
index 21a556d5c8b..ce9791582e8 100644
--- a/dali/kernels/slice/slice_gpu_test.cu
+++ b/dali/kernels/slice/slice_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -33,6 +33,7 @@ class SliceGPUTest : public SliceTest {
   void Run() override {
     KernelContext ctx;
+    ctx.gpu.stream = 0;
     TestTensorList test_data;
     this->PrepareData(test_data);
diff --git a/dali/kernels/test/kernel_poc_test.cu b/dali/kernels/test/kernel_poc_test.cu
index 8a737288cd3..4d816356c4c 100644
--- a/dali/kernels/test/kernel_poc_test.cu
+++ b/dali/kernels/test/kernel_poc_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -67,6 +67,10 @@ struct MADKernelGPU {

 template
 class KernelPoC_GPU : public ::testing::Test, public KernelPoCFixture {
+ public:
+  KernelPoC_GPU() {
+    this->ctx.gpu.stream = 0;
+  }
 };
diff --git a/dali/kernels/test/kernel_poc_test.h b/dali/kernels/test/kernel_poc_test.h
index ba1ced0d72f..0e1dbaffa5d 100644
--- a/dali/kernels/test/kernel_poc_test.h
+++ b/dali/kernels/test/kernel_poc_test.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2019, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -41,7 +41,7 @@ struct KernelPoCFixture : Base {
     ASSERT_NO_FATAL_FAILURE(Verify());
   }

- private:
+ protected:
   KernelContext ctx;
   Kernel kernel;
   TestTensorList tl1;
diff --git a/dali/kernels/test/manager_test.cc b/dali/kernels/test/manager_test.cc
index 50765364397..c0cfca1ba5d 100644
--- a/dali/kernels/test/manager_test.cc
+++ b/dali/kernels/test/manager_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -113,6 +113,7 @@ TEST(KernelManager, CreateOrGet_Get) {
   in.shape = {{ { 10, 10, 1 }, { 20, 20, 3 } }};
   out.shape = {{ { 10, 10, 1 }, { 20, 20, 3 } }};
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   mgr.Setup(0, ctx, in, 100, 1.25f);
   mgr.Run(0, 0, ctx, out, in, 100, 1.25f);
 }
@@ -125,6 +126,7 @@ TEST(KernelManager, TemplateResize) {
   in.shape = {{ { 10, 10, 1 }, { 20, 20, 3 } }};
   out.shape = {{ { 10, 10, 1 }, { 20, 20, 3 } }};
   KernelContext ctx;
+  ctx.gpu.stream = 0;
   mgr.Setup(0, ctx, in, 100, 1.25f);
   mgr.Run(0, 0, ctx, out, in, 100, 1.25f);
 }
diff --git a/dali/kernels/test/warp_test/warp_gpu_test.cu b/dali/kernels/test/warp_test/warp_gpu_test.cu
index 36093be01fa..725fbebf0ad 100644
--- a/dali/kernels/test/warp_test/warp_gpu_test.cu
+++ b/dali/kernels/test/warp_test/warp_gpu_test.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -69,6 +69,7 @@ void WarpGPU_Affine_Transpose(bool force_variable) {
   auto mapping_gpu = mm::alloc_raw_unique(1);
   TensorShape<2> out_shape = { img_tensor.shape[1], img_tensor.shape[0] };
   KernelContext ctx = {};
+  ctx.gpu.stream = 0;
   auto out_shapes_hw = make_span<1>(&out_shape);
   auto mappings = make_tensor_gpu<1>(mapping_gpu.get(), { 1 });
   copy(mappings, make_tensor_cpu<1>(&mapping_cpu, { 1 }));
@@ -169,6 +170,7 @@ TEST(WarpGPU, Affine_RotateScale_Single) {
   auto mapping_gpu = mm::alloc_raw_unique(1);
   TensorShape<2> out_shape = { img_tensor.shape[0] * scale, img_tensor.shape[1] * scale };
   KernelContext ctx = {};
+  ctx.gpu.stream = 0;
   auto out_shapes_hw = make_span<1>(&out_shape);
   auto mappings = make_tensor_gpu<1>(mapping_gpu.get(), { 1 });
   copy(mappings, make_tensor_cpu<1>(&mapping_cpu, { 1 }));
@@ -237,6 +239,7 @@ TEST(WarpGPU, Affine_RotateScale_Uniform) {
   auto mapping_gpu = mm::alloc_raw_unique(samples);
   TensorShape<2> out_shape = { img_tensor.shape[0] * scale, img_tensor.shape[1] * scale };
   KernelContext ctx = {};
+  ctx.gpu.stream = 0;
   std::vector<TensorShape<2>> out_shapes_hw(samples);
   for (int i = 0; i < samples; i++)
     out_shapes_hw[i] = out_shape;
diff --git a/dali/kernels/transpose/transpose_gpu_test.cc b/dali/kernels/transpose/transpose_gpu_test.cc
index fdca54b2b12..e5ae015f43a 100644
--- a/dali/kernels/transpose/transpose_gpu_test.cc
+++ b/dali/kernels/transpose/transpose_gpu_test.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+// Copyright (c) 2020, 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -68,6 +68,7 @@ TEST(TransposeGPU, Test4DAll) {
   UniformRandomFill(in_cpu, rng, 0, 1000);

   KernelContext ctx;
+  ctx.gpu.stream = 0;
   auto req = transpose.Setup(ctx, shape, make_span(perm), sizeof(int));
   auto out_shape = req.output_shapes[0];
   ASSERT_EQ(out_shape.num_elements(), shape.num_elements());
@@ -122,6 +123,7 @@ void RunPerfTest(RNG &rng, const TensorListShape<> &shape, span perm)
   UniformRandomFill(in_cpu, rng, 0, 100);

   KernelContext ctx;
+  ctx.gpu.stream = 0;
   auto req = transpose.Setup(ctx, shape, perm, sizeof(T));
   auto out_shape = req.output_shapes[0];
   ASSERT_EQ(out_shape.num_elements(), shape.num_elements());
diff --git a/dali/operators/generic/join.cc b/dali/operators/generic/join.cc
index 2c71a1601ad..fd435b8e1ee 100644
--- a/dali/operators/generic/join.cc
+++ b/dali/operators/generic/join.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -178,7 +178,8 @@ void TensorJoin<Backend, new_axis>::RunImpl(workspace_t<Backend> &ws) {
   TensorListShape<> shape;
   if (new_axis)
     shape = out.shape();
-  out.Copy(ws.template Input<Backend>(copy_idx_), ws.has_stream() ? ws.stream() : 0);
+  out.Copy(ws.template Input<Backend>(copy_idx_), ws.has_stream() ? ws.stream()
+                                                                  : AccessOrder::host());
   if (new_axis)
     out.Resize(shape);
   out.SetLayout(output_layout_);
diff --git a/dali/operators/generic/slice/slice_base.cu b/dali/operators/generic/slice/slice_base.cu
index 33a85848767..ef1af2d3c49 100644
--- a/dali/operators/generic/slice/slice_base.cu
+++ b/dali/operators/generic/slice/slice_base.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@ bool SliceBaseGpu::SetupImpl(std::vector &output_desc,
   kmgr_.Resize(1, 1);
   auto in_view = view(input);
   kernels::KernelContext ctx;
+  ctx.gpu.stream = ws.stream();
   auto req = kmgr_.Setup(0, ctx, in_view, args_);
   output_desc[0].shape = req.output_shapes[0];
   return true;
diff --git a/dali/operators/image/remap/warp.h b/dali/operators/image/remap/warp.h
index e0c3f6a8aeb..1681038c8dc 100644
--- a/dali/operators/image/remap/warp.h
+++ b/dali/operators/image/remap/warp.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -111,7 +111,7 @@ class WarpOpImpl : public OpImplInterface {

   kernels::KernelContext GetContext(const Workspace &ws) {
     kernels::KernelContext context;
-    context.gpu.stream = ws.has_stream() ? ws.stream() : 0;
+    context.gpu.stream = ws.has_stream() ? ws.stream() : AccessOrder::null_stream();
     return context;
   }
diff --git a/dali/operators/image/resize/resize_base.cc b/dali/operators/image/resize/resize_base.cc
index 91f418c2d3a..282416c6768 100644
--- a/dali/operators/image/resize/resize_base.cc
+++ b/dali/operators/image/resize/resize_base.cc
@@ -121,8 +121,6 @@ void ResizeBase::InitializeGPU(int minibatch_size, size_t temp_buffe
     minibatch_size_ = minibatch_size;
   }
   kmgr_.Resize(1, 0);
-  kmgr_.SetMemoryHint(temp_buffer_hint);
-  kmgr_.GetScratchpadAllocator(0).Reserve(temp_buffer_hint);
 }

 template
diff --git a/dali/operators/image/resize/resize_op_impl_gpu.h b/dali/operators/image/resize/resize_op_impl_gpu.h
index 5d3b8efa8f6..55ebd43f05c 100644
--- a/dali/operators/image/resize/resize_op_impl_gpu.h
+++ b/dali/operators/image/resize/resize_op_impl_gpu.h
@@ -79,7 +79,6 @@ class ResizeOpImplGPU : public ResizeBase::Impl {
       kernels::KernelRequirements &req = kmgr_.Setup(mb_idx, ctx, mb.input, param_slice);
       mb.out_shape = req.output_shapes[0].to_static();
     }
-    kmgr_.ReserveMaxScratchpad(0);
   }

   void RunResize(DeviceWorkspace &ws,
diff --git a/dali/operators/signal/decibel/to_decibels_op_gpu.cu b/dali/operators/signal/decibel/to_decibels_op_gpu.cu
index d873d720949..4ce7cc3d112 100644
--- a/dali/operators/signal/decibel/to_decibels_op_gpu.cu
+++ b/dali/operators/signal/decibel/to_decibels_op_gpu.cu
@@ -1,4 +1,4 @@
-// Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright (c) 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -63,6 +63,7 @@ bool ToDecibelsImpl<T>::SetupImpl(std::vector<OutputDesc> &output_desc,
   auto type = type2id<T>::value;

   kernels::KernelContext ctx;
+  ctx.gpu.stream = ws.stream();
   if (args_.ref_max) {
     auto& req_max = kmgr_max_.Setup(0, ctx, in_view);
     max_out_desc_.resize(1);
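Taken together, the operator-side hunks converge on one convention: "no stream" is spelled AccessOrder::null_stream() rather than 0 (which CUDA interprets as the legacy default stream), host-side ordering is named explicitly via AccessOrder::host(), and every KernelContext that reaches a GPU kernel carries an explicit stream. A short sketch of the distinction, assuming AccessOrder semantics as used in this patch:

// Assumed semantics, mirroring the patch's usage; not authoritative.
dali::kernels::KernelContext ctx;  // gpu.stream now defaults to AccessOrder::null_stream()
ctx.gpu.stream = 0;                // tests opt in to the legacy default stream explicitly
// Operators propagate the workspace stream instead:
//   ctx.gpu.stream = ws.stream();
// and host-ordered work names the host order rather than pretending to use stream 0:
//   out.Copy(in, ws.has_stream() ? AccessOrder(ws.stream()) : AccessOrder::host());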