Skip to content

Commit

Permalink
narrow return type of bool, int, float on UOp [pr] (#7246)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz authored Oct 24, 2024
1 parent 9f370cc commit 451c043
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tinygrad/ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, TypeVar, DefaultDict
from typing import Any, List, Optional, Set, Union, Tuple, Dict, Callable, cast, TYPE_CHECKING, Type, TypeVar, DefaultDict
import sys, time, functools, itertools, math, operator, hashlib, os, types, pickle
from enum import auto, IntEnum, Enum
from dataclasses import dataclass, field
Expand Down Expand Up @@ -234,12 +234,11 @@ def simplify(self):
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, symbolic)
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is UOps.CONST else ret
def _eval(self, dtype, expected_type) -> ConstType:
def _eval(self, dtype, expected_type:Type[T]) -> T:
assert self.dtype in dtype, f"eval with wrong dtype {self}"
simple_self = self.simplify()
vmin, vmax = simple_self._min_max
vmin, vmax = (simple_self:=self.simplify())._min_max
if vmin != vmax: raise ValueError(f"eval failed to be a single number, range is {vmin} to {vmax} in {simple_self.render()}")
assert type(vmin) is expected_type, f"vmin is wrong dtype {vmin} != {expected_type}"
assert isinstance(vmin, expected_type), f"vmin is wrong dtype {type(vmin)} != {expected_type}"
return vmin
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
Expand Down

0 comments on commit 451c043

Please sign in to comment.