Skip to content

Commit

Permalink
Fix E721: do not compare types, for exact checks use is / is not
Browse files Browse the repository at this point in the history
  • Loading branch information
maresb authored and ricardoV94 committed Jun 6, 2024
1 parent 4b6a444 commit 1935809
Show file tree
Hide file tree
Showing 27 changed files with 41 additions and 41 deletions.
4 changes: 2 additions & 2 deletions pytensor/compile/debugmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,7 @@ def _lessbroken_deepcopy(a):
else:
rval = copy.deepcopy(a)

assert type(rval) == type(a), (type(rval), type(a))
assert type(rval) is type(a), (type(rval), type(a))

if isinstance(rval, np.ndarray):
assert rval.dtype == a.dtype
Expand Down Expand Up @@ -1154,7 +1154,7 @@ def __str__(self):
return str(self.__dict__)

def __eq__(self, other):
rval = type(self) == type(other)
rval = type(self) is type(other)
if rval:
# nodes are not compared because this comparison is
# supposed to be true for corresponding events that happen
Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(self, fn, itypes, otypes, infer_shape):
self.infer_shape = self._infer_shape

def __eq__(self, other):
return type(self) == type(other) and self.__fn == other.__fn
return type(self) is type(other) and self.__fn == other.__fn

def __hash__(self):
return hash(type(self)) ^ hash(self.__fn)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def __eq__(self, other):
return True

return (
type(self) == type(other)
type(self) is type(other)
and self.id == other.id
and self.type == other.type
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/null_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def values_eq(self, a, b, force_same_dtype=True):
raise ValueError("NullType has no values to compare")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
4 changes: 2 additions & 2 deletions pytensor/graph/rewriting/unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def __new__(cls, constraint, token=None, prefix=""):
return obj

def __eq__(self, other):
if type(self) == type(other):
return self.token == other.token and self.constraint == other.constraint
if type(self) is type(other):
return self.token is other.token and self.constraint == other.constraint
return NotImplemented

def __hash__(self):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def __hash__(self):
if "__eq__" not in dct:

def __eq__(self, other):
return type(self) == type(other) and tuple(
return type(self) is type(other) and tuple(
getattr(self, a) for a in props
) == tuple(getattr(other, a) for a in props)

Expand Down
2 changes: 1 addition & 1 deletion pytensor/ifelse.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, n_outs, as_view=False, name=None):
self.name = name

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False
if self.as_view != other.as_view:
return False
Expand Down
4 changes: 2 additions & 2 deletions pytensor/link/c/params_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __hash__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.__params_type__ == other.__params_type__
and all(
# NB: Params object should have been already filtered.
Expand Down Expand Up @@ -435,7 +435,7 @@ def __repr__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.fields == other.fields
and self.types == other.types
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/c/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,7 @@ def __hash__(self):

def __eq__(self, other):
return (
type(self) == type(other)
type(self) is type(other)
and self.ctype == other.ctype
and len(self) == len(other)
and len(self.aliases) == len(other.aliases)
Expand Down
4 changes: 2 additions & 2 deletions pytensor/raise_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

class ExceptionType(Generic):
def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -51,7 +51,7 @@ def __str__(self):
return f"CheckAndRaise{{{self.exc_type}({self.msg})}}"

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False

if self.msg == other.msg and self.exc_type == other.exc_type:
Expand Down
6 changes: 3 additions & 3 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ def __call__(self, *types):
return [rval]

def __eq__(self, other):
return type(self) == type(other) and self.tbl == other.tbl
return type(self) is type(other) and self.tbl == other.tbl

def __hash__(self):
return hash(type(self)) # ignore hash of table
Expand Down Expand Up @@ -1160,7 +1160,7 @@ def L_op(self, inputs, outputs, output_gradients):
return self.grad(inputs, output_gradients)

def __eq__(self, other):
test = type(self) == type(other) and getattr(
test = type(self) is type(other) and getattr(
self, "output_types_preference", None
) == getattr(other, "output_types_preference", None)
return test
Expand Down Expand Up @@ -4133,7 +4133,7 @@ def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
type(self) is not type(other)
or self.nin != other.nin
or self.nout != other.nout
):
Expand Down
10 changes: 5 additions & 5 deletions pytensor/scalar/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -679,7 +679,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -732,7 +732,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -1045,7 +1045,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down Expand Up @@ -1083,7 +1083,7 @@ def c_code(self, node, name, inp, out, sub):
raise NotImplementedError("only floatingpoint is implemented")

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
2 changes: 1 addition & 1 deletion pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1246,7 +1246,7 @@ def is_cpu_vector(s):
return apply_node

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False

if self.info != other.info:
Expand Down
2 changes: 1 addition & 1 deletion pytensor/sparse/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def __eq__(self, other):
return (
a == x
and (b.dtype == y.dtype)
and (type(b) == type(y))
and (type(b) is type(y))
and (b.shape == y.shape)
and (abs(b - y).sum() < 1e-6 * b.nnz)
)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/random/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _eq(sa, sb):
return _eq(sa, sb)

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1742,7 +1742,7 @@ def local_reduce_broadcastable(fgraph, node):
ii += 1
new_reduced = reduced.dimshuffle(*pattern)
if new_axis:
if type(node.op) == CAReduce:
if type(node.op) is CAReduce:
# This case handles `CAReduce` instances
# (e.g. generated by `scalar_elemwise`), and not the
# scalar `Op`-specific subclasses
Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def values_eq_approx(
return values_eq_approx(a, b, allow_remove_inf, allow_remove_nan, rtol, atol)

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return NotImplemented

return other.dtype == self.dtype and other.shape == self.shape
Expand Down Expand Up @@ -624,7 +624,7 @@ def c_code_cache_version(self):

class DenseTypeMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorType or isinstance(o, DenseTypeMeta):
if type(o) is TensorType or isinstance(o, DenseTypeMeta):
return True
return False

Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/type_other.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __str__(self):
return "slice"

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
6 changes: 3 additions & 3 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ class TensorConstantSignature(tuple):
"""

def __eq__(self, other):
if type(self) != type(other):
if type(self) is not type(other):
return False
try:
(t0, d0), (t1, d1) = self, other
Expand Down Expand Up @@ -1105,7 +1105,7 @@ def __deepcopy__(self, memo):

class DenseVariableMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorVariable or isinstance(o, DenseVariableMeta):
if type(o) is TensorVariable or isinstance(o, DenseVariableMeta):
return True
return False

Expand All @@ -1120,7 +1120,7 @@ class DenseTensorVariable(TensorType, metaclass=DenseVariableMeta):

class DenseConstantMeta(MetaType):
def __instancecheck__(self, o):
if type(o) == TensorConstant or isinstance(o, DenseConstantMeta):
if type(o) is TensorConstant or isinstance(o, DenseConstantMeta):
return True
return False

Expand Down
2 changes: 1 addition & 1 deletion pytensor/typed_list/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __eq__(self, other):
Two lists are equal if they contain the same type.
"""
return type(self) == type(other) and self.ttype == other.ttype
return type(self) is type(other) and self.ttype == other.ttype

def __hash__(self):
return hash((type(self), self.ttype))
Expand Down
2 changes: 1 addition & 1 deletion tests/graph/rewriting/test_unify.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def perform(self, node, inputs, outputs):

class CustomOpNoProps(CustomOpNoPropsNoEq):
def __eq__(self, other):
return type(self) == type(other) and self.a == other.a
return type(self) is type(other) and self.a == other.a

def __hash__(self):
return hash((type(self), self.a))
Expand Down
4 changes: 2 additions & 2 deletions tests/graph/test_fg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def test_pickle(self):
s = pickle.dumps(func)
new_func = pickle.loads(s)

assert all(type(a) == type(b) for a, b in zip(func.inputs, new_func.inputs))
assert all(type(a) == type(b) for a, b in zip(func.outputs, new_func.outputs))
assert all(type(a) is type(b) for a, b in zip(func.inputs, new_func.inputs))
assert all(type(a) is type(b) for a, b in zip(func.outputs, new_func.outputs))
assert all(
type(a.op) is type(b.op) # noqa: E721
for a, b in zip(func.apply_nodes, new_func.apply_nodes)
Expand Down
2 changes: 1 addition & 1 deletion tests/graph/test_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(self, thingy):
self.thingy = thingy

def __eq__(self, other):
return type(other) == type(self) and other.thingy == self.thingy
return type(other) is type(self) and other.thingy == self.thingy

def __str__(self):
return str(self.thingy)
Expand Down
2 changes: 1 addition & 1 deletion tests/link/c/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def c_code_cache_version(self):
return (1,)

def __eq__(self, other):
return type(self) == type(other)
return type(self) is type(other)

def __hash__(self):
return hash(type(self))
Expand Down
2 changes: 1 addition & 1 deletion tests/sparse/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def __init__(self, structured):
self.structured = structured

def __eq__(self, other):
return (type(self) == type(other)) and self.structured == other.structured
return (type(self) is type(other)) and self.structured == other.structured

def __hash__(self):
return hash(type(self)) ^ hash(self.structured)
Expand Down
2 changes: 1 addition & 1 deletion tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3163,7 +3163,7 @@ def test_stack():
sx, sy = dscalar(), dscalar()

rval = inplace_func([sx, sy], stack([sx, sy]))(-4.0, -2.0)
assert type(rval) == np.ndarray
assert type(rval) is np.ndarray
assert [-4, -2] == list(rval)


Expand Down
2 changes: 1 addition & 1 deletion tests/tensor/test_subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ def test_ok_list(self):
assert np.allclose(val, good), (val, good)

# Test reuse of output memory
if type(AdvancedSubtensor1) == AdvancedSubtensor1:
if type(AdvancedSubtensor1) is AdvancedSubtensor1:
op = AdvancedSubtensor1()
# When idx is a TensorConstant.
if hasattr(idx, "data"):
Expand Down

0 comments on commit 1935809

Please sign in to comment.