diff --git a/src/lazy_object_proxy/cext.c b/src/lazy_object_proxy/cext.c index 54f69e1..17f1098 100644 --- a/src/lazy_object_proxy/cext.c +++ b/src/lazy_object_proxy/cext.c @@ -248,6 +248,16 @@ static PyObject *Proxy_multiply(PyObject *o1, PyObject *o2) /* ------------------------------------------------------------------------- */ +static PyObject *Proxy_matrix_multiply(PyObject *o1, PyObject *o2) +{ + Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o1); + Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o2); + + return PyNumber_MatrixMultiply(o1, o2); +} + +/* ------------------------------------------------------------------------- */ + static PyObject *Proxy_remainder(PyObject *o1, PyObject *o2) { Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(o1); @@ -458,6 +468,28 @@ static PyObject *Proxy_inplace_multiply( /* ------------------------------------------------------------------------- */ +static PyObject *Proxy_inplace_matrix_multiply( + ProxyObject *self, PyObject *other) +{ + PyObject *object = NULL; + + Proxy__ENSURE_WRAPPED_OR_RETURN_NULL(self); + Proxy__WRAPPED_REPLACE_OR_RETURN_NULL(other); + + object = PyNumber_InPlaceMatrixMultiply(self->wrapped, other); + + if (!object) + return NULL; + + Py_DECREF(self->wrapped); + self->wrapped = object; + + Py_INCREF(self); + return (PyObject *)self; +} + +/* ------------------------------------------------------------------------- */ + static PyObject *Proxy_inplace_remainder( ProxyObject *self, PyObject *other) { @@ -1239,6 +1271,8 @@ static PyNumberMethods Proxy_as_number = { (binaryfunc)Proxy_inplace_floor_divide, /*nb_inplace_floor_divide*/ (binaryfunc)Proxy_inplace_true_divide, /*nb_inplace_true_divide*/ (unaryfunc)Proxy_index, /*nb_index*/ + (binaryfunc)Proxy_matrix_multiply, /*nb_matrix_multiply*/ + (binaryfunc)Proxy_inplace_matrix_multiply, /*nb_inplace_matrix_multiply*/ }; static PySequenceMethods Proxy_as_sequence = { diff --git a/src/lazy_object_proxy/simple.py b/src/lazy_object_proxy/simple.py index cfd175d..283894b 100644 --- a/src/lazy_object_proxy/simple.py +++ b/src/lazy_object_proxy/simple.py @@ -141,6 +141,7 @@ def __delattr__(self, name): __add__ = make_proxy_method(operator.add) __sub__ = make_proxy_method(operator.sub) __mul__ = make_proxy_method(operator.mul) + __matmul__ = make_proxy_method(operator.matmul) __truediv__ = make_proxy_method(operator.truediv) __floordiv__ = make_proxy_method(operator.floordiv) __mod__ = make_proxy_method(operator.mod) @@ -161,6 +162,9 @@ def __rsub__(self, other): def __rmul__(self, other): return other * self.__wrapped__ + def __rmatmul__(self, other): + return other @ self.__wrapped__ + def __rdiv__(self, other): return operator.div(other, self.__wrapped__) @@ -197,6 +201,7 @@ def __ror__(self, other): __iadd__ = make_proxy_method(operator.iadd) __isub__ = make_proxy_method(operator.isub) __imul__ = make_proxy_method(operator.imul) + __imatmul__ = make_proxy_method(operator.imatmul) __itruediv__ = make_proxy_method(operator.itruediv) __ifloordiv__ = make_proxy_method(operator.ifloordiv) __imod__ = make_proxy_method(operator.imod) diff --git a/src/lazy_object_proxy/slots.py b/src/lazy_object_proxy/slots.py index 1e5c841..4b62859 100644 --- a/src/lazy_object_proxy/slots.py +++ b/src/lazy_object_proxy/slots.py @@ -226,6 +226,9 @@ def __sub__(self, other): def __mul__(self, other): return self.__wrapped__ * other + def __matmul__(self, other): + return self.__wrapped__ @ other + def __truediv__(self, other): return operator.truediv(self.__wrapped__, other) @@ -265,6 +268,9 @@ def __rsub__(self, other): def __rmul__(self, other): return other * self.__wrapped__ + def __rmatmul__(self, other): + return other @ self.__wrapped__ + def __rdiv__(self, other): return operator.div(other, self.__wrapped__) @@ -310,8 +316,8 @@ def __imul__(self, other): self.__wrapped__ *= other return self - def __idiv__(self, other): - self.__wrapped__ = operator.idiv(self.__wrapped__, other) + def __imatmul__(self, other): + self.__wrapped__ @= other return self def __itruediv__(self, other): diff --git a/tests/test_lazy_object_proxy.py b/tests/test_lazy_object_proxy.py index 1257726..0e6dad3 100644 --- a/tests/test_lazy_object_proxy.py +++ b/tests/test_lazy_object_proxy.py @@ -912,6 +912,26 @@ def test_mul(lop): assert two * 3 == 2 * 3 +def test_matmul(lop): + import numpy + + one = numpy.array((1, 2, 3)) + two = numpy.array((2, 3, 4)) + assert one @ two == 20 + + one = lop.Proxy(lambda: numpy.array((1, 2, 3))) + two = lop.Proxy(lambda: numpy.array((2, 3, 4))) + assert one @ two == 20 + + one = lop.Proxy(lambda: numpy.array((1, 2, 3))) + two = numpy.array((2, 3, 4)) + assert one @ two == 20 + + one = numpy.array((1, 2, 3)) + two = lop.Proxy(lambda: numpy.array((2, 3, 4))) + assert one @ two == 20 + + def test_div(lop): # On Python 2 this will pick up div and on Python # 3 it will pick up truediv. @@ -1067,6 +1087,27 @@ def test_imul(lop): assert type(value) == lop.Proxy +def test_imatmul(lop): + class InplaceMatmul: + value = None + + def __imatmul__(self, other): + self.value = other + return self + + value = InplaceMatmul() + assert value.value is None + value @= 123 + assert value.value == 123 + + value = lop.Proxy(InplaceMatmul) + value @= 234 + assert value.value == 234 + + if lop.kind != 'simple': + assert type(value) == lop.Proxy + + def test_idiv(lop): # On Python 2 this will pick up div and on Python # 3 it will pick up truediv. diff --git a/tox.ini b/tox.ini index 1d82d44..7ff7b5b 100644 --- a/tox.ini +++ b/tox.ini @@ -42,6 +42,7 @@ deps = pytest pytest-benchmark Django + numpy objproxies==0.9.4 hunter cover: pytest-cov