Skip to content

Commit

Permalink
CtxMixin (#1548)
Browse files Browse the repository at this point in the history
- DRY-fy #1545 (cd30189)
- Apply the fix to two more classes with the same issue (`Domain`, `Group`)
  • Loading branch information
gsakkis authored Jan 11, 2023
1 parent cd30189 commit af02b8e
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 203 deletions.
27 changes: 7 additions & 20 deletions tiledb/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from typing import Any, Sequence, TYPE_CHECKING, Union

import tiledb.cc as lt
from .ctx import default_ctx
from .ctx import CtxMixin
from .filter import FilterList, Filter
from .util import array_type_ncells, numpy_dtype, tiledb_type_is_datetime

if TYPE_CHECKING:
from .libtiledb import Ctx


class Attr(lt.Attribute):
class Attr(CtxMixin, lt.Attribute):
"""
Represents a TileDB attribute.
"""
Expand Down Expand Up @@ -46,15 +46,13 @@ def __init__(
:raises: :py:exc:`tiledb.TileDBError`
"""
self._ctx = ctx or default_ctx()
_dtype = None

if _capsule is not None:
return super().__init__(self._ctx, _capsule)
return super().__init__(ctx, _capsule)

if _lt_obj is not None:
return super().__init__(_lt_obj)
return super().__init__(ctx, _lt_obj=_lt_obj)

_dtype = None
if isinstance(dtype, str) and dtype == "ascii":
tiledb_dtype = lt.DataType.STRING_ASCII
_ncell = lt.TILEDB_VAR_NUM()
Expand Down Expand Up @@ -96,14 +94,7 @@ def __init__(

var = var or False

try:
super().__init__(self._ctx, name, tiledb_dtype)
except lt.TileDBError as e:
# we set this here because if the super().__init__() constructor above
# fails, we want to check if self._ctx is None in __repr__ so we can
# perform a safe repr
self._ctx = None
raise lt.TileDBError(e) from None
super().__init__(ctx, name, tiledb_dtype)

if _ncell:
self._ncell = _ncell
Expand Down Expand Up @@ -180,7 +171,7 @@ def filters(self) -> FilterList:
:raises: :py:exc:`tiledb.TileDBError`
"""
return FilterList(_lt_obj=self._filters)
return FilterList(ctx=self._ctx, _lt_obj=self._filters)

def _get_fill(self, value, dtype) -> Any:
if dtype in (lt.DataType.CHAR, lt.DataType.BLOB):
Expand Down Expand Up @@ -242,10 +233,6 @@ def isascii(self) -> bool:
return self._tiledb_dtype == lt.DataType.STRING_ASCII

def __repr__(self):
# use safe repr if pybind11 constructor failed
if self._ctx is None:
return object.__repr__(self)

filters_str = ""
if self.filters:
filters_str = ", filters=FilterList(["
Expand Down
20 changes: 20 additions & 0 deletions tiledb/ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,26 @@ def get_stats(self, print_out: bool = True, json: bool = False):
return output


class CtxMixin:
"""
Base mixin class for pure Python classes that extend PyBind11 TileDB classes.
To use this class, a subclass must:
- Inherit from it first (i.e. `class Foo(CtxMixin, Bar)`, not `class Foo(Bar, CtxMixin)`
- Call super().__init__ by passing `ctx` (tiledb.Ctx or None) as first parameter and
- either zero or more pure Python positional parameters
- or a single `_lt_obj` PyBind11 parameter
"""

def __init__(self, ctx, *args, _lt_obj=None):
if not ctx:
ctx = default_ctx()
super().__init__(_lt_obj or ctx, *args)
# we set this here because if the super().__init__() constructor above fails,
# we don't want to set self._ctx
self._ctx = ctx


def check_ipykernel_warn_once():
"""
This function checks if we have imported ipykernel version < 6 in the
Expand Down
25 changes: 5 additions & 20 deletions tiledb/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

import tiledb.cc as lt
from .ctx import default_ctx
from .ctx import CtxMixin
from .filter import FilterList, Filter
from .util import (
dtype_to_tiledb,
Expand Down Expand Up @@ -77,7 +77,7 @@ def _tiledb_cast_domain(
return (np_dtype(domain[0]), np_dtype(domain[1]))


class Dim(lt.Dimension):
class Dim(CtxMixin, lt.Dimension):
"""
Represents a TileDB dimension.
"""
Expand Down Expand Up @@ -110,10 +110,8 @@ def __init__(
:param tiledb.Ctx ctx: A TileDB Context
"""
self._ctx = ctx or default_ctx()

if _lt_obj is not None:
return super().__init__(_lt_obj)
return super().__init__(ctx, _lt_obj=_lt_obj)

if var is not None:
if var and np.dtype(dtype) not in (np.str_, np.bytes_):
Expand Down Expand Up @@ -169,16 +167,7 @@ def __init__(
if tile_size_array.size != 1:
raise ValueError("tile extent must be a scalar")

try:
super().__init__(
self._ctx, name, dim_datatype, domain_array, tile_size_array
)
except lt.TileDBError as e:
# we set this here because if the super().__init__() constructor above
# fails, we want to check if self._ctx is None in __repr__ so we can
# perform a safe repr
self._ctx = None
raise lt.TileDBError(e) from None
super().__init__(ctx, name, dim_datatype, domain_array, tile_size_array)

if filters is not None:
if isinstance(filters, FilterList):
Expand All @@ -187,10 +176,6 @@ def __init__(
self._filters = FilterList(filters)

def __repr__(self) -> str:
# use safe repr if pybind11 constructor failed
if self._ctx is None:
return object.__repr__(self)

filters_str = ""
if self.filters:
filters_str = ", filters=FilterList(["
Expand Down Expand Up @@ -303,7 +288,7 @@ def filters(self) -> FilterList:
:raises: :py:exc:`tiledb.TileDBError`
"""
return FilterList(_lt_obj=self._filters)
return FilterList(ctx=self._ctx, _lt_obj=self._filters)

@property
def shape(self) -> Tuple["np.generic", "np.generic"]:
Expand Down
14 changes: 6 additions & 8 deletions tiledb/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import TYPE_CHECKING

import tiledb.cc as lt
from .ctx import default_ctx
from .ctx import CtxMixin
from .dimension import Dim
from .util import numpy_dtype

Expand All @@ -12,7 +12,7 @@
from .libtiledb import Ctx


class Domain(lt.Domain):
class Domain(CtxMixin, lt.Domain):
"""
Represents a TileDB domain.
"""
Expand All @@ -32,15 +32,13 @@ def __init__(
:param tiledb.Ctx ctx: A TileDB Context
"""
self._ctx = ctx or default_ctx()

if _capsule is not None:
return super().__init__(self._ctx, _capsule)
return super().__init__(ctx, _capsule)

if _lt_obj is not None:
return super().__init__(_lt_obj)
return super().__init__(ctx, _lt_obj=_lt_obj)

super().__init__(self._ctx)
super().__init__(ctx)

# support passing a list of dims without splatting
if len(dims) == 1 and isinstance(dims[0], list):
Expand Down Expand Up @@ -79,7 +77,7 @@ def clone_dim_with_name(dim, name):
self._add_dim(d)

def __repr__(self):
dims = ",\n ".join([repr(self.dim(i)) for i in range(self.ndim)])
dims = ",\n ".join(repr(self.dim(i)) for i in range(self.ndim))
return "Domain({0!s})".format(dims)

def _repr_html_(self) -> str:
Expand Down
Loading

0 comments on commit af02b8e

Please sign in to comment.