mirror of https://github.com/commaai/tinygrad.git
minor changes from prerender (#1734)
This commit is contained in:
parent
f964b9e5ee
commit
458eb89463
|
@ -269,12 +269,13 @@ class TestOps(unittest.TestCase):
|
|||
helper_test_op([(45,65)], lambda x: float("inf")/x, lambda x: float("inf")/x)
|
||||
helper_test_op([(45,65)], lambda x: (-float("inf"))/x, lambda x: (-float("inf"))/x)
|
||||
helper_test_op([(45,65)], lambda x: float("nan")/x, lambda x: float("nan")/x)
|
||||
def test_pow_full(self):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
|
||||
def test_pow(self):
|
||||
# TODO: why is a=0 for these tests?
|
||||
helper_test_op([(45,65)], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**3, lambda x: Tensor.pow(x,3), a=0)
|
||||
helper_test_op([(45,65)], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: x**y, Tensor.pow, a=0)
|
||||
helper_test_op([()], lambda x: x**2, lambda x: Tensor.pow(x,2), a=0)
|
||||
helper_test_op([()], lambda x: x**-2, lambda x: Tensor.pow(x,-2), a=0)
|
||||
# Regression tests for https://github.com/tinygrad/tinygrad/issues/1151
|
||||
|
|
|
@ -83,6 +83,10 @@ class ImageDType(DType):
|
|||
super().__init__()
|
||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||
|
||||
class PtrDType(DType):
|
||||
def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
|
||||
def __repr__(self): return f"ptr.{super().__repr__()}"
|
||||
|
||||
class dtypes:
|
||||
@staticmethod # static methds on top, or bool in the type info will refer to dtypes.bool
|
||||
def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
|
||||
|
|
|
@ -170,7 +170,7 @@ class ShapeTracker:
|
|||
return tuple(ret)
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
||||
|
||||
def _expr_idx(self, idx, valid):
|
||||
def _expr_idx(self, idx, valid) -> Tuple[Node, Node]:
|
||||
for v in reversed(self.views[0:-1]):
|
||||
valid = v.expr_node_mask(idx, valid)
|
||||
idx = v.expr_node(idx)
|
||||
|
|
|
@ -159,6 +159,7 @@ class Variable(Node):
|
|||
|
||||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
assert isinstance(num, int), f"{num} is not an int"
|
||||
self.b:int = num
|
||||
self.min, self.max = num, num
|
||||
def __int__(self): return self.b
|
||||
|
|
|
@ -587,6 +587,7 @@ class Tensor:
|
|||
if x.__class__ is not Tensor and not reverse:
|
||||
# simple pow identities
|
||||
if x < 0: return (1.0/self).pow(-x)
|
||||
if x == 3.0: return self*self*self
|
||||
if x == 2.0: return self*self
|
||||
if x == 1.0: return self
|
||||
if x == 0.5: return self.sqrt()
|
||||
|
|
Loading…
Reference in New Issue