Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HIP compilation #39

Merged
merged 6 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions include/exchcxx/enums/kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,19 @@ enum class Kernel {
EPC18_2,
};

/// Report whether a kernel may be evaluated in the unpolarized spin setting.
///
/// @param kern Kernel identifier to query.
/// @return `false` for EPC17_1, EPC17_2, EPC18_1 and EPC18_2; `true` for
///         every other kernel value.
inline static bool supports_unpolarized(ExchCXX::Kernel kern) {
  // The four EPC kernels are the only ones flagged as polarized-only.
  const bool polarized_only =
    kern == ExchCXX::Kernel::EPC17_1 ||
    kern == ExchCXX::Kernel::EPC17_2 ||
    kern == ExchCXX::Kernel::EPC18_1 ||
    kern == ExchCXX::Kernel::EPC18_2;
  return !polarized_only;
}


// Two-way lookup table between strings and Kernel values; only declared
// here, the definition lives elsewhere.
extern BidirectionalMap<std::string, Kernel> kernel_map;

// Stream-insertion operator for Kernel values.
std::ostream& operator<<( std::ostream& out, Kernel kern );
Expand Down
9 changes: 4 additions & 5 deletions src/builtin_interface.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ namespace detail {
std::unique_ptr<BuiltinKernel>
gen_from_kern( Kernel kern, Spin polar ) {

// Bail if unpolarized eval is requested for a kernel that only supports spin-polarized eval
EXCHCXX_BOOL_CHECK(kernel_map.key(kern) + " Needs to be Spin-Polarized!",
supports_unpolarized(kern) or polar == Spin::Polarized);

if( kern == Kernel::SlaterExchange )
return std::make_unique<BuiltinSlaterExchange>( polar );
else if( kern == Kernel::VWN3 )
Expand Down Expand Up @@ -119,22 +123,17 @@ std::unique_ptr<BuiltinKernel>
return std::make_unique<BuiltinPC07OPT_K>( polar );

else if( kern == Kernel::EPC17_1) {
EXCHCXX_BOOL_CHECK("EPC17_1 Needs to be Spin-Polarized!",polar==Spin::Polarized);
return std::make_unique<BuiltinEPC17_1>( polar );
} else if( kern == Kernel::EPC17_2) {
EXCHCXX_BOOL_CHECK("EPC17_2 Needs to be Spin-Polarized!",polar==Spin::Polarized);
return std::make_unique<BuiltinEPC17_2>( polar );
} else if( kern == Kernel::EPC18_1) {
EXCHCXX_BOOL_CHECK("EPC18_1 Needs to be Spin-Polarized!",polar==Spin::Polarized);
return std::make_unique<BuiltinEPC18_1>( polar );
} else if( kern == Kernel::EPC18_2) {
EXCHCXX_BOOL_CHECK("EPC18_2 Needs to be Spin-Polarized!",polar==Spin::Polarized);
return std::make_unique<BuiltinEPC18_2>( polar );

} else
throw std::runtime_error("Specified kernel does not have a builtin implementation");


}

BuiltinKernelInterface::~BuiltinKernelInterface() noexcept = default;
Expand Down
315 changes: 315 additions & 0 deletions src/hip/builtin.hip
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,209 @@ __global__ GGA_EXC_VXC_INC_GENERATOR( device_eval_exc_vxc_inc_helper_polar_kerne

}

/// Device kernel: unpolarized MGGA energy evaluation, one index per thread.
///
/// For every tid in [0, N) this forwards rho[tid], sigma[tid], the Laplacian
/// value and tau[tid] to kernel_traits<KernelType>::eval_exc_unpolar, with
/// eps[tid] as the trailing argument. lapl is only read when the traits
/// report needs_laplacian; otherwise 0.0 is passed in its place.
template <typename KernelType>
__global__ MGGA_EXC_GENERATOR( device_eval_exc_helper_unpolar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  double lapl_val = 0.0;
  if( traits::needs_laplacian ) lapl_val = lapl[tid];

  traits::eval_exc_unpolar( rho[tid], sigma[tid], lapl_val, tau[tid],
                            eps[tid] );

}


/// Device kernel: polarized MGGA energy evaluation, one index per thread.
///
/// For every tid in [0, N) the inputs are read with two values per index
/// from rho, lapl and tau and three values per index from sigma, then passed
/// to kernel_traits<KernelType>::eval_exc_polar together with eps[tid].
/// lapl is only read when the traits report needs_laplacian; otherwise both
/// Laplacian arguments are 0.0.
template <typename KernelType>
__global__ MGGA_EXC_GENERATOR( device_eval_exc_helper_polar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  // Start of this index's slice in each input array.
  auto* rho_pt   = rho   + 2*tid;
  auto* sigma_pt = sigma + 3*tid;
  auto* tau_pt   = tau   + 2*tid;

  double lapl_a = 0.0;
  double lapl_b = 0.0;
  if( traits::needs_laplacian ) {
    auto* lapl_pt = lapl + 2*tid;
    lapl_a = lapl_pt[0];
    lapl_b = lapl_pt[1];
  }

  traits::eval_exc_polar(
    rho_pt[0], rho_pt[1],
    sigma_pt[0], sigma_pt[1], sigma_pt[2],
    lapl_a, lapl_b,
    tau_pt[0], tau_pt[1],
    eps[tid] );

}

/// Device kernel: unpolarized MGGA energy + derivative evaluation, one index
/// per thread.
///
/// For every tid in [0, N) this calls
/// kernel_traits<KernelType>::eval_exc_vxc_unpolar with the inputs at tid and
/// eps[tid], vrho[tid], vsigma[tid], the Laplacian slot and vtau[tid] as the
/// trailing arguments. When the traits do not report needs_laplacian, lapl
/// and vlapl are left untouched: 0.0 is passed as the Laplacian input and a
/// thread-local scratch variable takes the place of vlapl[tid].
template <typename KernelType>
__global__ MGGA_EXC_VXC_GENERATOR( device_eval_exc_vxc_helper_unpolar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  double lapl_val = 0.0;
  if( traits::needs_laplacian ) lapl_val = lapl[tid];

  // Redirect the Laplacian slot to scratch storage when vlapl is not in use.
  double vlapl_scratch;
  double* vlapl_out = &vlapl_scratch;
  if( traits::needs_laplacian ) vlapl_out = &vlapl[tid];

  traits::eval_exc_vxc_unpolar(
    rho[tid], sigma[tid], lapl_val, tau[tid],
    eps[tid], vrho[tid], vsigma[tid], *vlapl_out, vtau[tid] );

}

/// Device kernel: polarized MGGA energy + derivative evaluation, one index
/// per thread.
///
/// For every tid in [0, N) the arrays rho, lapl, tau, vrho, vlapl and vtau
/// are addressed with two values per index and sigma / vsigma with three;
/// eps has one value per index. Everything is forwarded to
/// kernel_traits<KernelType>::eval_exc_vxc_polar. When the traits do not
/// report needs_laplacian, lapl and vlapl are left untouched: 0.0 is passed
/// for both Laplacian inputs and a thread-local scratch pair replaces the
/// vlapl slots.
template <typename KernelType>
__global__ MGGA_EXC_VXC_GENERATOR( device_eval_exc_vxc_helper_polar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Scratch target for the Laplacian slots when vlapl is unused.
  double vlapl_scratch[2];

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  // Start of this index's slice in each input array.
  auto* rho_pt   = rho   + 2*tid;
  auto* sigma_pt = sigma + 3*tid;
  auto* tau_pt   = tau   + 2*tid;

  double lapl_a = 0.0;
  double lapl_b = 0.0;
  if( traits::needs_laplacian ) {
    auto* lapl_pt = lapl + 2*tid;
    lapl_a = lapl_pt[0];
    lapl_b = lapl_pt[1];
  }

  // Start of this index's slice in each output array.
  auto* vrho_pt   = vrho   + 2*tid;
  auto* vsigma_pt = vsigma + 3*tid;
  auto* vtau_pt   = vtau   + 2*tid;
  auto* vlapl_pt  = traits::needs_laplacian ? (vlapl + 2*tid) : vlapl_scratch;

  traits::eval_exc_vxc_polar(
    rho_pt[0], rho_pt[1],
    sigma_pt[0], sigma_pt[1], sigma_pt[2],
    lapl_a, lapl_b,
    tau_pt[0], tau_pt[1],
    eps[tid],
    vrho_pt[0], vrho_pt[1],
    vsigma_pt[0], vsigma_pt[1], vsigma_pt[2],
    vlapl_pt[0], vlapl_pt[1],
    vtau_pt[0], vtau_pt[1] );

}


/// Device kernel: unpolarized MGGA energy evaluation with scaled
/// accumulation, one index per thread.
///
/// For every tid in [0, N) the value produced by
/// kernel_traits<KernelType>::eval_exc_unpolar is multiplied by scal_fact and
/// added onto eps[tid] (eps is updated in place, not overwritten). lapl is
/// only read when the traits report needs_laplacian.
template <typename KernelType>
__global__ MGGA_EXC_INC_GENERATOR( device_eval_exc_inc_helper_unpolar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  double lapl_val = 0.0;
  if( traits::needs_laplacian ) lapl_val = lapl[tid];

  // Evaluate into a local first so the existing eps[tid] is preserved.
  double exc_val;
  traits::eval_exc_unpolar( rho[tid], sigma[tid], lapl_val, tau[tid], exc_val );

  eps[tid] += scal_fact * exc_val;

}

/// Device kernel: polarized MGGA energy evaluation with scaled accumulation,
/// one index per thread.
///
/// For every tid in [0, N) the inputs are read with two values per index
/// from rho, lapl and tau and three from sigma; the value produced by
/// kernel_traits<KernelType>::eval_exc_polar is multiplied by scal_fact and
/// added onto eps[tid]. lapl is only read when the traits report
/// needs_laplacian.
template <typename KernelType>
__global__ MGGA_EXC_INC_GENERATOR( device_eval_exc_inc_helper_polar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  // Start of this index's slice in each input array.
  auto* rho_pt   = rho   + 2*tid;
  auto* sigma_pt = sigma + 3*tid;
  auto* tau_pt   = tau   + 2*tid;

  double lapl_a = 0.0;
  double lapl_b = 0.0;
  if( traits::needs_laplacian ) {
    auto* lapl_pt = lapl + 2*tid;
    lapl_a = lapl_pt[0];
    lapl_b = lapl_pt[1];
  }

  // Evaluate into a local first so the existing eps[tid] is preserved.
  double exc_val;
  traits::eval_exc_polar(
    rho_pt[0], rho_pt[1],
    sigma_pt[0], sigma_pt[1], sigma_pt[2],
    lapl_a, lapl_b,
    tau_pt[0], tau_pt[1],
    exc_val );

  eps[tid] += scal_fact * exc_val;

}

/// Device kernel: unpolarized MGGA energy + derivative evaluation with scaled
/// accumulation, one index per thread.
///
/// For every tid in [0, N) the results of
/// kernel_traits<KernelType>::eval_exc_vxc_unpolar are multiplied by
/// scal_fact and added onto eps[tid], vrho[tid], vsigma[tid] and vtau[tid].
/// lapl is read and vlapl[tid] is updated only when the traits report
/// needs_laplacian.
template <typename KernelType>
__global__ MGGA_EXC_VXC_INC_GENERATOR( device_eval_exc_vxc_inc_helper_unpolar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  double lapl_val = 0.0;
  if( traits::needs_laplacian ) lapl_val = lapl[tid];

  // Evaluate into locals first so the existing output values are preserved.
  double exc_val, vrho_val, vsigma_val, vlapl_val, vtau_val;
  traits::eval_exc_vxc_unpolar(
    rho[tid], sigma[tid], lapl_val, tau[tid],
    exc_val, vrho_val, vsigma_val, vlapl_val, vtau_val );

  eps[tid]    += scal_fact * exc_val;
  vrho[tid]   += scal_fact * vrho_val;
  vsigma[tid] += scal_fact * vsigma_val;
  vtau[tid]   += scal_fact * vtau_val;
  if( traits::needs_laplacian ) {
    vlapl[tid] += scal_fact * vlapl_val;
  }

}

/// Device kernel: polarized MGGA energy + derivative evaluation with scaled
/// accumulation, one index per thread.
///
/// For every tid in [0, N) the arrays rho, lapl, tau, vrho, vlapl and vtau
/// are addressed with two values per index and sigma / vsigma with three;
/// eps has one value per index. The results of
/// kernel_traits<KernelType>::eval_exc_vxc_polar are multiplied by scal_fact
/// and added onto the output arrays. lapl is read and vlapl is updated only
/// when the traits report needs_laplacian.
template <typename KernelType>
__global__ MGGA_EXC_VXC_INC_GENERATOR( device_eval_exc_vxc_inc_helper_polar_kernel ) {

  using traits = kernel_traits<KernelType>;

  const int tid = threadIdx.x + blockIdx.x * blockDim.x;

  // Threads beyond the end of the input have nothing to do.
  if( tid >= N ) return;

  // Start of this index's slice in each input array.
  auto* rho_pt   = rho   + 2*tid;
  auto* sigma_pt = sigma + 3*tid;
  auto* tau_pt   = tau   + 2*tid;

  double lapl_a = 0.0;
  double lapl_b = 0.0;
  if( traits::needs_laplacian ) {
    auto* lapl_pt = lapl + 2*tid;
    lapl_a = lapl_pt[0];
    lapl_b = lapl_pt[1];
  }

  // Evaluate into locals first so the existing output values are preserved.
  double exc_val;
  double vrho_a, vrho_b;
  double vsigma_aa, vsigma_ab, vsigma_bb;
  double vlapl_a, vlapl_b;
  double vtau_a, vtau_b;
  traits::eval_exc_vxc_polar(
    rho_pt[0], rho_pt[1],
    sigma_pt[0], sigma_pt[1], sigma_pt[2],
    lapl_a, lapl_b,
    tau_pt[0], tau_pt[1],
    exc_val,
    vrho_a, vrho_b,
    vsigma_aa, vsigma_ab, vsigma_bb,
    vlapl_a, vlapl_b,
    vtau_a, vtau_b );

  // Start of this index's slice in each output array.
  auto* vrho_pt   = vrho   + 2*tid;
  auto* vsigma_pt = vsigma + 3*tid;
  auto* vtau_pt   = vtau   + 2*tid;

  eps[tid]     += scal_fact * exc_val;
  vrho_pt[0]   += scal_fact * vrho_a;
  vrho_pt[1]   += scal_fact * vrho_b;
  vsigma_pt[0] += scal_fact * vsigma_aa;
  vsigma_pt[1] += scal_fact * vsigma_ab;
  vsigma_pt[2] += scal_fact * vsigma_bb;
  vtau_pt[0]   += scal_fact * vtau_a;
  vtau_pt[1]   += scal_fact * vtau_b;

  // vlapl is only touched for kernels that actually use the Laplacian.
  if( traits::needs_laplacian ) {
    auto* vlapl_pt = vlapl + 2*tid;
    vlapl_pt[0] += scal_fact * vlapl_a;
    vlapl_pt[1] += scal_fact * vlapl_b;
  }

}

template <typename KernelType>
LDA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar ) {
Expand Down Expand Up @@ -582,6 +785,99 @@ GGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar ) {

}

/// Host-side launcher for device_eval_exc_helper_unpolar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread).
template <typename KernelType>
MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_helper_unpolar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>( N, rho, sigma, lapl, tau, eps );

}

/// Host-side launcher for device_eval_exc_helper_polar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread).
template <typename KernelType>
MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_polar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_helper_polar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>( N, rho, sigma, lapl, tau, eps );

}

/// Host-side launcher for
/// device_eval_exc_vxc_helper_unpolar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread).
template <typename KernelType>
MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_unpolar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_vxc_helper_unpolar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>(
      N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau );

}

/// Host-side launcher for device_eval_exc_vxc_helper_polar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread).
template <typename KernelType>
MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_polar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_vxc_helper_polar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>(
      N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau );

}


/// Host-side launcher for
/// device_eval_exc_inc_helper_unpolar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread). scal_fact is forwarded to the kernel
/// unchanged.
template <typename KernelType>
MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_unpolar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_inc_helper_unpolar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>(
      scal_fact, N, rho, sigma, lapl, tau, eps );

}

/// Host-side launcher for device_eval_exc_inc_helper_polar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread). scal_fact is forwarded to the kernel
/// unchanged.
template <typename KernelType>
MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_polar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_inc_helper_polar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>(
      scal_fact, N, rho, sigma, lapl, tau, eps );

}

/// Host-side launcher for
/// device_eval_exc_vxc_inc_helper_unpolar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread). scal_fact is forwarded to the kernel
/// unchanged.
template <typename KernelType>
MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_unpolar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_vxc_inc_helper_unpolar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>(
      scal_fact, N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau );

}

/// Host-side launcher for
/// device_eval_exc_vxc_inc_helper_polar_kernel<KernelType>.
///
/// Launches on `stream` with 32 threads per block, no dynamic shared memory,
/// and util::div_ceil( N, 32 ) blocks (presumably N / 32 rounded up, so that
/// every index in [0, N) gets a thread). scal_fact is forwarded to the kernel
/// unchanged.
template <typename KernelType>
MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar ) {

  dim3 block_dim( 32 );
  dim3 grid_dim( util::div_ceil( N, block_dim.x ) );

  device_eval_exc_vxc_inc_helper_polar_kernel<KernelType>
    <<< grid_dim, block_dim, 0, stream >>>(
      scal_fact, N, rho, sigma, lapl, tau, eps, vrho, vsigma, vlapl, vtau );

}

#define LDA_GENERATE_DEVICE_HELPERS(KERN) \
template LDA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar<KERN> ); \
template LDA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_unpolar<KERN> ); \
Expand All @@ -602,6 +898,16 @@ GGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar ) {
template GGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_polar<KERN> ); \
template GGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar<KERN> );

// Explicitly instantiates the eight MGGA host-side helpers for KERN:
// EXC, EXC_VXC, EXC_INC and EXC_VXC_INC, each in an unpolarized and a
// polarized variant. The expansion already ends with a ';'.
#define MGGA_GENERATE_DEVICE_HELPERS(KERN) \
  template MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_unpolar<KERN> ); \
  template MGGA_EXC_GENERATOR_DEVICE( device_eval_exc_helper_polar<KERN> ); \
  template MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_unpolar<KERN> ); \
  template MGGA_EXC_VXC_GENERATOR_DEVICE( device_eval_exc_vxc_helper_polar<KERN> ); \
  template MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_unpolar<KERN> ); \
  template MGGA_EXC_INC_GENERATOR_DEVICE( device_eval_exc_inc_helper_polar<KERN> ); \
  template MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_unpolar<KERN> ); \
  template MGGA_EXC_VXC_INC_GENERATOR_DEVICE( device_eval_exc_vxc_inc_helper_polar<KERN> );

LDA_GENERATE_DEVICE_HELPERS( BuiltinSlaterExchange );
LDA_GENERATE_DEVICE_HELPERS( BuiltinVWN3 );
LDA_GENERATE_DEVICE_HELPERS( BuiltinVWN_RPA );
Expand All @@ -624,6 +930,15 @@ MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCAN_X );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCAN_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCAN_X );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCAN_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinFT98_X );

MGGA_GENERATE_DEVICE_HELPERS( BuiltinPC07_K );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinPC07OPT_K );

MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCANL_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinSCANL_X );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCANL_C );
MGGA_GENERATE_DEVICE_HELPERS( BuiltinR2SCANL_X );

LDA_GENERATE_DEVICE_HELPERS( BuiltinEPC17_1 )
LDA_GENERATE_DEVICE_HELPERS( BuiltinEPC17_2 )
Expand Down
Loading
Loading