Skip to content

Commit

Permalink
patching nvfuser conv cudnn test numerics mismatch (#2048)
Browse files Browse the repository at this point in the history
Tests failed on upstream, not yet in our devel branch. Disabling TF32 in the test, which creates numerical issue when validating outputs.
  • Loading branch information
jjsjann123 authored Oct 11, 2022
1 parent 65af1a4 commit 7117a7e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions torch/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2976,6 +2976,7 @@ TEST_F(NVFuserTest, FusionConv2D_CUDA) {
TEST_F(NVFuserTest, FusionConv2DNoPadding_CUDA) {
Fusion fusion;
FusionGuard fg(&fusion);
ContextCudnnTF32Disabled disabling_tf32_cudnn;

// Input: [C, H, W]
auto inp = makeSymbolicTensor(3);
Expand Down
16 changes: 16 additions & 0 deletions torch/csrc/jit/codegen/cuda/test/test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <torch/csrc/jit/codegen/cuda/lower_magic_zero.h>
#include <torch/csrc/jit/codegen/cuda/transform_replay.h>

#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <torch/torch.h>
Expand Down Expand Up @@ -340,6 +341,21 @@ struct TransformPropagatorWithCheck : public TransformPropagator {

} // namespace

class ContextCudnnTF32Disabled {
public:
ContextCudnnTF32Disabled() {
flag_ = at::globalContext().allowTF32CuDNN();
at::globalContext().setAllowTF32CuDNN(false);
}

~ContextCudnnTF32Disabled() {
at::globalContext().setAllowTF32CuDNN(flag_);
}

private:
bool flag_;
};

// Fixture class must be uniquely identified, i.e., can't be in an
// anonymous namespace
class NVFuserTest : public ::testing::Test {
Expand Down

0 comments on commit 7117a7e

Please sign in to comment.