minor changes from prerender (#1734)

This commit is contained in:
George Hotz 2023-09-01 10:04:47 -07:00 committed by GitHub
parent f964b9e5ee
commit 458eb89463
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 9 additions and 2 deletions

View File

@ -269,12 +269,13 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: float("inf")/x, lambda x: float("inf")/x)
helper_test_op([(45,65)], lambda x: (-float("inf"))/x, lambda x: (-float("inf"))/x)
helper_test_op([(45,65)], lambda x: float("nan")/x, lambda x: float("nan")/x)
def test_pow_full(self):
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
def test_pow(self):
# TODO: why is a=0 for these tests?
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
# Regression tests for https://github.com/tinygrad/tinygrad/issues/1151

View File

@ -83,6 +83,10 @@ class ImageDType(DType):
super().__init__()
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
class PtrDType(DType):
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
def __repr__(self): return f"ptr.{super().__repr__()}"
class dtypes:
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)

View File

@ -170,7 +170,7 @@ class ShapeTracker:
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def _expr_idx(self, idx, valid):
def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
for v in reversed(self.views[0:-1]):
valid = v.expr_node_mask(idx, valid)
idx = v.expr_node(idx)

View File

@ -159,6 +159,7 @@ class Variable(Node):
class NumNode(Node):
def __init__(self, num:int):
assert isinstance(num, int), f"{num} is not an int"
self.b:int = num
self.min, self.max = num, num
def __int__(self): return self.b

View File

@ -587,6 +587,7 @@ class Tensor:
if x.__class__ is not Tensor and not reverse:
# simple pow identities
if x < 0: return (1.0/self).pow(-x)
if x == 3.0: return self*self*self
if x == 2.0: return self*self
if x == 1.0: return self
if x == 0.5: return self.sqrt()