Skip to content

Commit

Permalink
Add support for matrix multiplication. Fixes #66.
Browse files Browse the repository at this point in the history
  • Loading branch information
ionelmc committed Jan 4, 2023
1 parent 30e8c5a commit 8f6f9d3
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 2 deletions.
34 changes: 34 additions & 0 deletions src/lazy_object_proxy/cext.c
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,16 @@ static PyObject *Proxy_multiply(PyObject *o1, PyObject *o2)

/* ------------------------------------------------------------------------- */

static PyObject *Proxy_matrix_multiply(PyObject *o1, PyObject *o2)
{
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o1);
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o2);

return PyNumber_MatrixMultiply(o1, o2);
}

/* ------------------------------------------------------------------------- */

static PyObject *Proxy_remainder(PyObject *o1, PyObject *o2)
{
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o1);
Expand Down Expand Up @@ -458,6 +468,28 @@ static PyObject *Proxy_inplace_multiply(

/* ------------------------------------------------------------------------- */

static PyObject *Proxy_inplace_matrix_multiply(
ProxyObject *self, PyObject *other)
{
PyObject *object = NULL;

Proxy__ENSURE_WRAPPED_OR_RETURN_NULL(self);
Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(other);

object = PyNumber_InPlaceMatrixMultiply(self->wrapped, other);

if (!object)
return NULL;

Py_DECREF(self->wrapped);
self->wrapped = object;

Py_INCREF(self);
return (PyObject *)self;
}

/* ------------------------------------------------------------------------- */

static PyObject *Proxy_inplace_remainder(
ProxyObject *self, PyObject *other)
{
Expand Down Expand Up @@ -1239,6 +1271,8 @@ static PyNumberMethods Proxy_as_number = {
(binaryfunc)Proxy_inplace_floor_divide, /*nb_inplace_floor_divide*/
(binaryfunc)Proxy_inplace_true_divide, /*nb_inplace_true_divide*/
(unaryfunc)Proxy_index, /*nb_index*/
(binaryfunc)Proxy_matrix_multiply, /*nb_matrix_multiply*/
(binaryfunc)Proxy_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/
};

static PySequenceMethods Proxy_as_sequence = {
Expand Down
5 changes: 5 additions & 0 deletions src/lazy_object_proxy/simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __delattr__(self, name):
__add__ = make_proxy_method(operator.add)
__sub__ = make_proxy_method(operator.sub)
__mul__ = make_proxy_method(operator.mul)
__matmul__ = make_proxy_method(operator.matmul)
__truediv__ = make_proxy_method(operator.truediv)
__floordiv__ = make_proxy_method(operator.floordiv)
__mod__ = make_proxy_method(operator.mod)
Expand All @@ -161,6 +162,9 @@ def __rsub__(self, other):
def __rmul__(self, other):
return other * self.__wrapped__

def __rmatmul__(self, other):
return other @ self.__wrapped__

def __rdiv__(self, other):
return operator.div(other, self.__wrapped__)

Expand Down Expand Up @@ -197,6 +201,7 @@ def __ror__(self, other):
__iadd__ = make_proxy_method(operator.iadd)
__isub__ = make_proxy_method(operator.isub)
__imul__ = make_proxy_method(operator.imul)
__imatmul__ = make_proxy_method(operator.imatmul)
__itruediv__ = make_proxy_method(operator.itruediv)
__ifloordiv__ = make_proxy_method(operator.ifloordiv)
__imod__ = make_proxy_method(operator.imod)
Expand Down
10 changes: 8 additions & 2 deletions src/lazy_object_proxy/slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def __sub__(self, other):
def __mul__(self, other):
return self.__wrapped__ * other

def __matmul__(self, other):
return self.__wrapped__ @ other

def __truediv__(self, other):
return operator.truediv(self.__wrapped__, other)

Expand Down Expand Up @@ -265,6 +268,9 @@ def __rsub__(self, other):
def __rmul__(self, other):
return other * self.__wrapped__

def __rmatmul__(self, other):
return other @ self.__wrapped__

def __rdiv__(self, other):
return operator.div(other, self.__wrapped__)

Expand Down Expand Up @@ -310,8 +316,8 @@ def __imul__(self, other):
self.__wrapped__ *= other
return self

def __idiv__(self, other):
self.__wrapped__ = operator.idiv(self.__wrapped__, other)
def __imatmul__(self, other):
self.__wrapped__ @= other
return self

def __itruediv__(self, other):
Expand Down
41 changes: 41 additions & 0 deletions tests/test_lazy_object_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,26 @@ def test_mul(lop):
assert two * 3 == 2 * 3


def test_matmul(lop):
import numpy

one = numpy.array((1, 2, 3))
two = numpy.array((2, 3, 4))
assert one @ two == 20

one = lop.Proxy(lambda: numpy.array((1, 2, 3)))
two = lop.Proxy(lambda: numpy.array((2, 3, 4)))
assert one @ two == 20

one = lop.Proxy(lambda: numpy.array((1, 2, 3)))
two = numpy.array((2, 3, 4))
assert one @ two == 20

one = numpy.array((1, 2, 3))
two = lop.Proxy(lambda: numpy.array((2, 3, 4)))
assert one @ two == 20


def test_div(lop):
# On Python 2 this will pick up div and on Python
# 3 it will pick up truediv.
Expand Down Expand Up @@ -1067,6 +1087,27 @@ def test_imul(lop):
assert type(value) == lop.Proxy


def test_imatmul(lop):
class InplaceMatmul:
value = None

def __imatmul__(self, other):
self.value = other
return self

value = InplaceMatmul()
assert value.value is None
value @= 123
assert value.value == 123

value = lop.Proxy(InplaceMatmul)
value @= 234
assert value.value == 234

if lop.kind != 'simple':
assert type(value) == lop.Proxy


def test_idiv(lop):
# On Python 2 this will pick up div and on Python
# 3 it will pick up truediv.
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ deps =
pytest
pytest-benchmark
Django
numpy
objproxies==0.9.4
hunter
cover: pytest-cov
Expand Down

0 comments on commit 8f6f9d3

Please sign in to comment.