Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Kokkos::complex #293

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pykokkos/core/visitors/pykokkos_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,17 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
)
elif name in ["PerTeam", "PerThread", "fence"]:
name = "Kokkos::" + name
elif name in {"complex32", "complex64"}:
name = "Kokkos::complex"
if "32" in name:
name += "<float>"
else:
name += "<double>"

function = cppast.DeclRefExpr(name)
args: List[cppast.Expr] = [self.visit(a) for a in node.args]

if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence"]:
if visitors_util.is_math_function(name) or name in ["printf", "abs", "Kokkos::PerTeam", "Kokkos::PerThread", "Kokkos::fence", "Kokkos::complex<float>", "Kokkos::complex<double>"]:
return cppast.CallExpr(function, args)

if function in self.kokkos_functions:
Expand Down
7 changes: 5 additions & 2 deletions pykokkos/core/visitors/visitors_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def pretty_print(node):
"double": "double",
"bool": "bool",
"TeamMember": f"Kokkos::TeamPolicy<{Keywords.DefaultExecSpace.value}>::member_type",
"cpp_auto": "auto"
"cpp_auto": "auto",
"complex32": "Kokkos::complex<float>",
"complex64": "Kokkos::complex<double>"
}

# Maps from the DataType enum to cppast
Expand Down Expand Up @@ -307,7 +309,8 @@ def parse_view_template_params(

if parameter in ("int", "double", "float",
"int8_t", "int16_t", "int32_t", "int64_t",
"uint8_t", "uint16_t", "uint32_t", "uint64_t"):
"uint8_t", "uint16_t", "uint32_t", "uint64_t",
"Kokkos::complex<float>", "Kokkos::complex<double>"):
datatype: str = parameter + "*" * rank
params["dtype"] = datatype

Expand Down
1 change: 1 addition & 0 deletions pykokkos/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
uint16, uint32, uint64,
float, double, real,
float32, float64, bool,
complex32, complex64
)
from .decorators import (
callback, classtype, Decorator, function, functor, main,
Expand Down
102 changes: 102 additions & 0 deletions pykokkos/interface/data_types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from enum import Enum

from pykokkos.bindings import kokkos
import pykokkos.kokkos_manager as km

import numpy as np


Expand All @@ -26,6 +28,8 @@ class DataType(Enum):
float64 = kokkos.double
real = None
bool = np.bool_
complex32 = kokkos.complex_float32_dtype
complex64 = kokkos.complex_float64_dtype


class DataTypeClass:
Expand Down Expand Up @@ -91,3 +95,101 @@ class float64(DataTypeClass):
class bool(DataTypeClass):
value = kokkos.uint8
np_equiv = np.bool_


class complex(DataTypeClass):
def __add__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot add '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex + other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex + other.kokkos_complex)

def __iadd__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot add '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex += other.kokkos_complex
return self

def __sub__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot subtract '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex - other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex - other.kokkos_complex)

def __isub__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot subtract '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex -= other.kokkos_complex
return self

def __mul__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot multiply '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex * other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex * other.kokkos_complex)

def __imul__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot multiply '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex *= other.kokkos_complex
return self

def __truediv__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot divide '{}' and '{}'".format(type(other), type(self)))

if isinstance(self, complex32):
return complex32(self.kokkos_complex / other.kokkos_complex)
elif isinstance(self, complex64):
return complex64(self.kokkos_complex / other.kokkos_complex)

def __itruediv__(self, other):
if not isinstance(other, type(self)):
raise TypeError("cannot divide '{}' and '{}'".format(type(other), type(self)))

self.kokkos_complex /= other.kokkos_complex
return self

def __repr__(self):
return f"({self.kokkos_complex.real_const()}, {self.kokkos_complex.imag_const()})"

@property
def real(self):
return self.kokkos_complex.real_const()

@property
def imag(self):
return self.kokkos_complex.imag_const()

class complex32(complex):
value = kokkos.complex_float32_dtype
np_equiv = np.complex64 # 32 bits from real + 32 from imaginary

def __init__(self, real: float, imaginary: float = 0.0):
if isinstance(real, kokkos.complex_float32):
self.kokkos_complex = real
else:
self.kokkos_complex = km.get_kokkos_module(is_cpu=True).complex_float32(real, imaginary)


class complex64(complex):
value = kokkos.complex_float64_dtype
np_equiv = np.complex128 # 64 bits from real + 64 from imaginary

def __init__(self, real: float, imaginary: float = 0.0):
if isinstance(real, kokkos.complex_float64):
self.kokkos_complex = real
else:
self.kokkos_complex = km.get_kokkos_module(is_cpu=True).complex_float64(real, imaginary)
31 changes: 27 additions & 4 deletions pykokkos/interface/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
uint8,
uint16, uint32, uint64,
double, float32, float64,
complex32, complex64
)
from .data_types import float as pk_float
from .layout import get_default_layout, Layout
Expand Down Expand Up @@ -98,13 +99,16 @@ def extent(self, dimension: int) -> int:

return self.shape[dimension]

def fill(self, value: Union[int, float]) -> None:
def fill(self, value: Union[int, float, complex, complex32, complex64]) -> None:
"""
Sets all elements to a scalar value

:param value: the scalar value
"""

if isinstance(value, (complex, complex32, complex64)):
value = np.complex64(value.real, value.imag) if self.dtype is complex32 else np.complex128(value.real, value.imag)

if self.trait is Trait.Unmanaged:
self.xp_array.fill(value)
else:
Expand All @@ -126,8 +130,16 @@ def __getitem__(self, key: Union[int, TeamMember, slice, Tuple]) -> Union[int, f

if isinstance(key, int) or isinstance(key, TeamMember):
if self.trait is Trait.Unmanaged:
return self.xp_array[key]
return self.data[key]
return_val = self.xp_array[key]
else:
return_val = self.data[key]

if self.dtype is complex32:
return_val = complex32(return_val.real, return_val.imag)
elif self.dtype is complex64:
return_val = complex64(return_val.real, return_val.imag)

return return_val

length: int = 1 if isinstance(key, slice) else len(key)
if length != self.rank():
Expand All @@ -137,7 +149,7 @@ def __getitem__(self, key: Union[int, TeamMember, slice, Tuple]) -> Union[int, f

return subview

def __setitem__(self, key: Union[int, TeamMember], value: Union[int, float]) -> None:
def __setitem__(self, key: Union[int, TeamMember], value: Union[int, float, complex, complex32, complex64]) -> None:
"""
Overloads the indexing operator setting an item in the View.

Expand All @@ -148,6 +160,9 @@ def __setitem__(self, key: Union[int, TeamMember], value: Union[int, float]) ->
if "PK_FUSION" in os.environ:
runtime_singleton.runtime.flush_data(self)

if isinstance(value, (complex, complex32, complex64)):
value = np.complex64(value.real, value.imag) if self.dtype is complex32 else np.complex128(value.real, value.imag)

if self.trait is Trait.Unmanaged:
self.xp_array[key] = value
else:
Expand Down Expand Up @@ -674,6 +689,10 @@ def from_numpy(array: np.ndarray, space: Optional[MemorySpace] = None, layout: O
dtype = float64
elif np_dtype is np.bool_:
dtype = uint8
elif np_dtype is np.complex64:
dtype = complex32
elif np_dtype is np.complex128:
dtype = complex64
else:
raise RuntimeError(f"ERROR: unsupported numpy datatype {np_dtype}")

Expand Down Expand Up @@ -736,6 +755,10 @@ def from_array(array) -> ViewType:
ctype = ctypes.c_double
elif np_dtype is np.bool_:
ctype = ctypes.c_uint8
elif np_dtype is np.complex64:
dtype = complex32
elif np_dtype is np.complex128:
dtype = complex64
else:
raise RuntimeError(f"ERROR: unsupported numpy datatype {np_dtype}")

Expand Down
Loading