Skip to content

Commit

Permalink
Make all of TH and THC C++. (pytorch#6913)
Browse files Browse the repository at this point in the history
Changelist:

- Move *.c to *.cpp
- Change includes of ".c" to ".cpp"
- A bunch of CMake configuration that modified CMAKE_C_FLAGS was changed
to use CMAKE_CXX_FLAGS or add_compile_options, because CMAKE_C_FLAGS only applies when compiling C code
- Explicitly cast void* to T* in a number of places
- Delete extern "C" { ... } blocks; instead, properly apply TH_API to everything that should have it (TH_API handles extern "C")
- Stop using stdatomic.h; instead, use <atomic>. This required adding placement-new/delete in a number of places to be fully correct
- Refactor of THLongStorageView to not have static constructor methods (since it no longer has a copy/move constructor)
- Documentation about how the TH C interface (and extern C business) works
- Note that THD master_worker mode is dead
- C++ headers in TH libraries are given .hpp suffix, to make it less likely that you'll confuse them with the C-compatible headers (now suffixed .h)
- New function THCStream_stream and THCStream_device to project out fields of THCStream instead of accessing fields directly
- New function THStorage_(retainIfLive), which is equivalent to a retain but only if the refcount is greater than zero.
- In general, I tried to avoid using hpp headers outside of ATen/TH. However, there were a few places where I gave up and depended on the headers for my own sanity. See Note [TH abstraction violation] for all the sites where this occurred. All other sites were refactored to use functions.
- Some extra Werror fixes (char* versus const char*)
  • Loading branch information
ezyang authored Apr 28, 2018
1 parent 4667983 commit 4caea64
Show file tree
Hide file tree
Showing 219 changed files with 916 additions and 849 deletions.
36 changes: 21 additions & 15 deletions aten/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ ENDIF()
IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
MESSAGE(STATUS "Found CUDA with FP16 support, compiling with torch.CudaHalfTensor")
LIST(APPEND CUDA_NVCC_FLAGS "-DCUDA_HAS_FP16=1 -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__")
SET(CMAKE_C_FLAGS "-DCUDA_HAS_FP16=1 ${CMAKE_C_FLAGS}")
add_compile_options(-DCUDA_HAS_FP16=1)
ELSE(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
MESSAGE(STATUS "Could not find CUDA with FP16 support, compiling without torch.CudaHalfTensor")
ENDIF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
Expand Down Expand Up @@ -162,7 +162,7 @@ IF (APPLE AND CMAKE_COMPILER_IS_GNUCC)
IF (APPLE_OPENMP_SUCKS AND GCC_VERSION VERSION_LESS 4.6.2)
MESSAGE(STATUS "Warning: Disabling OpenMP (unstable with this version of GCC)")
MESSAGE(STATUS " Install GCC >= 4.6.2 or change your OS to enable OpenMP")
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wno-unknown-pragmas")
add_compile_options(-Wno-unknown-pragmas)
SET(WITH_OPENMP OFF CACHE BOOL "OpenMP support if available?" FORCE)
ENDIF ()
ENDIF ()
Expand Down Expand Up @@ -212,18 +212,18 @@ ENDIF()
FIND_PACKAGE(ARM)
IF (ASIMD_FOUND)
MESSAGE(STATUS "asimd/Neon found with compiler flag : -D__NEON__")
SET(CMAKE_C_FLAGS "-D__NEON__ ${CMAKE_C_FLAGS}")
add_compile_options(-D__NEON__)
ELSEIF (NEON_FOUND)
MESSAGE(STATUS "Neon found with compiler flag : -mfpu=neon -D__NEON__")
SET(CMAKE_C_FLAGS "-mfpu=neon -D__NEON__ ${CMAKE_C_FLAGS}")
add_compile_options(-mfpu=neon -D__NEON__)
ENDIF (ASIMD_FOUND)
IF (CORTEXA8_FOUND)
MESSAGE(STATUS "Cortex-A8 Found with compiler flag : -mcpu=cortex-a8")
SET(CMAKE_C_FLAGS "-mcpu=cortex-a8 -fprefetch-loop-arrays ${CMAKE_C_FLAGS}")
add_compile_options(-mcpu=cortex-a8 -fprefetch-loop-arrays)
ENDIF (CORTEXA8_FOUND)
IF (CORTEXA9_FOUND)
MESSAGE(STATUS "Cortex-A9 Found with compiler flag : -mcpu=cortex-a9")
SET(CMAKE_C_FLAGS "-mcpu=cortex-a9 ${CMAKE_C_FLAGS}")
add_compile_options(-mcpu=cortex-a9)
ENDIF (CORTEXA9_FOUND)

IF(UNIX)
Expand Down Expand Up @@ -264,7 +264,7 @@ IF(HAVE_CPUID_H)
}" HAVE_GCC_GET_CPUID)
ENDIF()
IF(HAVE_GCC_GET_CPUID)
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DHAVE_GCC_GET_CPUID")
add_compile_options(-DHAVE_GCC_GET_CPUID)
ENDIF(HAVE_GCC_GET_CPUID)

CHECK_C_SOURCE_COMPILES("#include <stdint.h>
Expand All @@ -282,34 +282,40 @@ CHECK_C_SOURCE_COMPILES("#include <stdint.h>
}" NO_GCC_EBX_FPIC_BUG)

IF(NOT NO_GCC_EBX_FPIC_BUG)
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_GCC_GET_CPUID")
add_compile_options(-DUSE_GCC_GET_CPUID)
ENDIF(NOT NO_GCC_EBX_FPIC_BUG)

FIND_PACKAGE(SSE) # checks SSE, AVX and AVX2
IF(C_SSE2_FOUND)
MESSAGE(STATUS "SSE2 Found")
SET(CMAKE_C_FLAGS "${C_SSE2_FLAGS} -DUSE_SSE2 ${CMAKE_C_FLAGS}")
# TODO: Work out correct way to do this. Note that C_SSE2_FLAGS is often
# empty, in which case it expands to " " flag which is bad
SET(CMAKE_C_FLAGS "${C_SSE2_FLAGS} ${CMAKE_C_FLAGS}")
SET(CMAKE_CXX_FLAGS "${C_SSE2_FLAGS} ${CMAKE_CXX_FLAGS}")
add_compile_options(-DUSE_SSE2)
ENDIF(C_SSE2_FOUND)
IF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND)
SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} -DUSE_SSE4_1 ${C_SSE4_2_FLAGS} -DUSE_SSE4_2 ${CMAKE_C_FLAGS}")
SET(CMAKE_C_FLAGS "${C_SSE4_1_FLAGS} ${C_SSE4_2_FLAGS} ${CMAKE_C_FLAGS}")
SET(CMAKE_CXX_FLAGS "${C_SSE4_1_FLAGS} ${C_SSE4_2_FLAGS} ${CMAKE_CXX_FLAGS}")
add_compile_options(-DUSE_SSE4_1 -DUSE_SSE4_2)
ENDIF()
IF(C_SSE3_FOUND)
MESSAGE(STATUS "SSE3 Found")
SET(CMAKE_C_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_C_FLAGS}")
SET(CMAKE_CXX_FLAGS "${C_SSE3_FLAGS} -DUSE_SSE3 ${CMAKE_CXX_FLAGS}")
SET(CMAKE_C_FLAGS "${C_SSE3_FLAGS} ${CMAKE_C_FLAGS}")
SET(CMAKE_CXX_FLAGS "${C_SSE3_FLAGS} ${CMAKE_CXX_FLAGS}")
add_compile_options(-DUSE_SSE3)
ENDIF(C_SSE3_FOUND)

# we don't set -mavx and -mavx2 flags globally, but only for specific files
# however, we want to enable the AVX codepaths, so we still need to
# add USE_AVX and USE_AVX2 macro defines
IF(C_AVX_FOUND)
MESSAGE(STATUS "AVX Found")
SET(CMAKE_C_FLAGS "-DUSE_AVX ${CMAKE_C_FLAGS}")
add_compile_options(-DUSE_AVX)
ENDIF(C_AVX_FOUND)
IF(C_AVX2_FOUND)
MESSAGE(STATUS "AVX2 Found")
SET(CMAKE_C_FLAGS "-DUSE_AVX2 ${CMAKE_C_FLAGS}")
SET(CMAKE_CXX_FLAGS "-DUSE_AVX2 ${CMAKE_CXX_FLAGS}")
add_compile_options(-DUSE_AVX2)
ENDIF(C_AVX2_FOUND)

CHECK_C_SOURCE_RUNS("
Expand Down
21 changes: 10 additions & 11 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,31 @@ ENDIF()
# so we need to set these commands here rather than in src/TH
IF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND)
IF(MSVC)
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast")
ELSE(MSVC)
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_sse.cpp PROPERTIES COMPILE_FLAGS "-O3 -ffast-math")
ENDIF(MSVC)
ENDIF(C_SSE4_1_FOUND AND C_SSE4_2_FOUND)
IF(C_AVX_FOUND)
IF(MSVC)
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast ${C_AVX_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX ${C_AVX_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/fp:fast ${CXX_AVX_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX ${CXX_AVX_FLAGS}")
ELSE(MSVC)
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.c PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${C_AVX_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.c PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/generic/simd/convolve5x5_avx.cpp PROPERTIES COMPILE_FLAGS "-O3 -ffast-math ${CXX_AVX_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX.cpp PROPERTIES COMPILE_FLAGS "-O3 ${CXX_AVX_FLAGS}")
ENDIF(MSVC)
ENDIF(C_AVX_FOUND)

IF(C_AVX2_FOUND)
IF(MSVC)
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX2 ${C_AVX2_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX2 ${CXX_AVX2_FLAGS}")
ELSE(MSVC)
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX2_FLAGS}")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "-O3 ${CXX_AVX2_FLAGS}")
ENDIF(MSVC)
ENDIF(C_AVX2_FOUND)

IF(NOT MSVC AND NOT "${CMAKE_C_COMPILER_ID}" MATCHES "Clang")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/THAtomic.c PROPERTIES COMPILE_FLAGS "-fno-openmp")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/THAllocator.c PROPERTIES COMPILE_FLAGS "-fno-openmp")
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/THAllocator.cpp PROPERTIES COMPILE_FLAGS "-fno-openmp")
ENDIF()

FILE(GLOB cpu_kernel_cpp_in RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "native/cpu/*.cpp")
Expand Down Expand Up @@ -332,7 +331,7 @@ ENDIF(NOT MSVC)
IF(NOT C_HAS_THREAD)
MESSAGE(STATUS "Warning: __thread is not supported, generating thread-unsafe code")
ELSE(NOT C_HAS_THREAD)
SET(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DTH_HAVE_THREAD")
add_compile_options(-DTH_HAVE_THREAD)
ENDIF(NOT C_HAS_THREAD)

if(MKLDNN_FOUND)
Expand Down
59 changes: 45 additions & 14 deletions aten/src/ATen/THLongStorageView.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "TH/TH.h"
#include "TH/THStorage.hpp"

namespace at {

Expand All @@ -10,34 +11,63 @@ static inline bool is_noelem_tensor_size(ArrayRef<int64_t> size) {
return size.size() == 1 && size[0] == 0;
}

enum class THLongStorageViewKind {
SIZE,
// noelem_to_empty is to differentiate strides of empty tensors vs scalars. In ATen, both may have strides [1],
// but in TH an empty tensor should have stride [], while a scalar should have stride [1].
STRIDE_EMPTY_TENSOR, // noelem_to_empty = true
STRIDE_SCALAR, // noelem_to_empty = false
LENGTH,
};

// make a fake storage out of a size, pointer pair...
// used as an argument where THSize and THStride are passed into TH
class THLongStorageView {
public:
operator THLongStorage*() {
if (storage.size == 0 && zero_dim_to_null) {
return nullptr;
}
return &storage;
}

/*
// This is done as an enum, and not as these static constructors, as there
// is no move/copy constructor for THLongStorageView
static THLongStorageView makeFromSize(ArrayRef<int64_t> ref) {
return THLongStorageView(ref, true, false, false);
}
// noelem_to_empty is to differentiate strides of empty tensors vs scalars. In ATen, both may have strides [1],
// but in TH an empty tensor should have stride [], while a scalar should have stride [1].
static THLongStorageView makeFromStride(ArrayRef<int64_t> ref, bool noelem_to_empty) {
return THLongStorageView(ref, false, true, noelem_to_empty);
}
static THLongStorageView makeFromLength(ArrayRef<int64_t> ref) {
return THLongStorageView(ref, false, false, false);
}
operator THLongStorage*() {
if (storage.size == 0 && zero_dim_to_null) {
return nullptr;
}
return &storage;
}
private:
// zero_dim_to_one converts an empty ArrayRef into [1]
// zero_dim_to_null converts an empty ArrayRef into a null THLongStorage
// noelem_to_empty makes an ArrayRef of [0] into an empty THLongStorage
THLongStorageView(ArrayRef<int64_t> ref, bool zero_dim_to_one, bool zero_dim_to_null, bool noelem_to_empty)
: zero_dim_to_null(zero_dim_to_null)
*/

THLongStorageView(ArrayRef<int64_t> ref, THLongStorageViewKind kind)
: zero_dim_to_null(false)
{
// zero_dim_to_one converts an empty ArrayRef into [1]
// zero_dim_to_null converts an empty ArrayRef into a null THLongStorage
// noelem_to_empty makes an ArrayRef of [0] into an empty THLongStorage
bool zero_dim_to_one = false;
bool noelem_to_empty = false;
switch (kind) {
case THLongStorageViewKind::SIZE:
zero_dim_to_one = true;
break;
case THLongStorageViewKind::STRIDE_EMPTY_TENSOR:
zero_dim_to_null = true;
noelem_to_empty = true;
break;
case THLongStorageViewKind::STRIDE_SCALAR:
zero_dim_to_null = true;
break;
case THLongStorageViewKind::LENGTH:
break;
}
if(zero_dim_to_one && ref.size() == 0) {
// make storage of size 0 actually a 1-length storage with 1 element
// so that our 0-dim tensors get allocated as 1-dim inside TH
Expand All @@ -57,6 +87,7 @@ class THLongStorageView {
storage.allocator = nullptr;
storage.allocatorContext = nullptr;
}
private:
int64_t one;
THLongStorage storage;
bool zero_dim_to_null;
Expand Down
36 changes: 26 additions & 10 deletions aten/src/ATen/function_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,15 +233,23 @@ def __init__(self, reason):
'THGenerator*':
CodeTemplate(
'check_generator<${Backend}Generator>(${arg_name}, &context->defaultGenerator(backend()))'),
'THSize*': CodeTemplate('THLongStorageView::makeFromSize(${arg_name})'),
'THStride*': CodeTemplate('THLongStorageView::makeFromStride(${arg_name}, ${noelem_to_empty})'),
# This is a cast done via direct-construction
'THSize*': CodeTemplate('THLongStorageView ${result_name}(${arg_name}, THLongStorageViewKind::SIZE);'),
# This is a cast done via direct-construction
'THStride*':
CodeTemplate(
'THLongStorageView ${result_name}(${arg_name}, '
'${noelem_to_empty} ? '
'THLongStorageViewKind::STRIDE_EMPTY_TENSOR : THLongStorageViewKind::STRIDE_SCALAR);'),
'real': CodeTemplate('${arg_name}.to${ScalarName}()'),
'accreal': CodeTemplate('${arg_name}.to${AccScalarName}()'),
'TensorList': CodeTemplate('tensor_list_checked_cast<${Tensor}, Tensor, '
'${THTensor}>(${arg_name},"${arg_name}",${arg_pos})'),
'IntList': CodeTemplate('check_intlist<${size}>(${arg_name}, "${arg_name}", ${arg_pos}${,default_init})')
}

DIRECT_CONSTRUCTION_CHECKED_CAST = {'THSize*', 'THStride*'}

CHECKED_USE = {
'THTensor*': '{}_->tensor',
'THSTensor*': '{}_->tensor',
Expand Down Expand Up @@ -271,7 +279,7 @@ def __init__(self, reason):
CONSTANT_REPLACEMENTS = [
('AS_REAL', '${AS_REAL}'),
('__storage_size.get\\(\\)',
'THLongStorageView::makeFromLength(static_cast<int64_t>(storage.size()))'),
'THLongStorageView(static_cast<int64_t>(storage.size()), THLongStorageViewKind::LENGTH)'),
('__last_dim', 'self.ndimension()-1'),
]

Expand Down Expand Up @@ -1235,13 +1243,21 @@ def emit_body(env, option):
default_init.append(arg['default_init'])

noelem_to_empty = 'is_noelem_tensor_size(size)' if 'size' in seen_names else 'false'
check_cast = CHECKED_CAST[arg['type']].substitute(
env, arg_name=arg['name'], arg_pos=count,
null_okay=null_okay, default_init=default_init,
size=arg.get('size'),
noelem_to_empty=noelem_to_empty)
body.append("auto {}_ = {};".format(
arg['name'], check_cast))
if arg['type'] in DIRECT_CONSTRUCTION_CHECKED_CAST:
body.append(CHECKED_CAST[arg['type']].substitute(
env, arg_name=arg['name'], arg_pos=count,
null_okay=null_okay, default_init=default_init,
size=arg.get('size'),
noelem_to_empty=noelem_to_empty,
result_name=arg['name'] + '_'))
else:
check_cast = CHECKED_CAST[arg['type']].substitute(
env, arg_name=arg['name'], arg_pos=count,
null_okay=null_okay, default_init=default_init,
size=arg.get('size'),
noelem_to_empty=noelem_to_empty)
body.append("auto {}_ = {};".format(
arg['name'], check_cast))
if drop_argument(arg, option) or replace_with_null(arg):
body.append(
"(void) {}_; //silence unused warning".format(arg['name']))
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,12 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
if backend == 'CUDA':
env['th_headers'] = [
'#include <THC/THC.h>',
'#include <THC/THCTensor.hpp>',
'#include <THCUNN/THCUNN.h>',
'#undef THNN_',
'#undef THCIndexTensor_',
'#include <THCS/THCS.h>',
'#include <THCS/THCSTensor.hpp>',
'#undef THCIndexTensor_',
]
env['extra_cuda_headers'] = ['#include <ATen/cuda/CUDAHalf.cuh>']
Expand All @@ -263,9 +265,11 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
else:
env['th_headers'] = [
'#include <TH/TH.h>',
'#include <TH/THTensor.hpp>',
'#include <THNN/THNN.h>',
'#undef THNN_',
'#include <THS/THS.h>',
'#include <THS/THSTensor.hpp>',
]
env['extra_cuda_headers'] = []
env['THType'] = scalar_name
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Distributions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include <functional>

#include "TH/THRandom.h"
#include "TH/THGenerator.h"
#include "TH/THGenerator.hpp"
#include "TH/THMath.h"

namespace {
Expand Down
6 changes: 2 additions & 4 deletions aten/src/ATen/native/cuda/Distributions.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@

#include "ATen/native/Distributions.h"

#include <TH/THAtomic.h>

#include <THC/THCGeneral.h>
#include <THC/THCTensorRandom.h>
#include <THC/THCGenerator.h>
#include <THC/THCGenerator.hpp>
#include <THC/THCApply.cuh>
#include <THC/THCNumerics.cuh>

Expand All @@ -32,7 +30,7 @@ THCGenerator* THCRandom_getGenerator(THCState* state);
namespace {
std::pair<uint64_t, uint64_t> next_philox_seed(at::Generator* gen, uint64_t increment) {
auto gen_ = THCRandom_getGenerator(at::globalContext().thc_state);
uint64_t offset = THAtomicAddLong(&gen_->state.philox_seed_offset, increment);
uint64_t offset = gen_->state.philox_seed_offset.fetch_add(increment);
return std::make_pair(gen_->state.initial_seed, offset);
}

Expand Down
Loading

0 comments on commit 4caea64

Please sign in to comment.