mirror of https://github.com/commaai/tinygrad.git
from tensor cores + lb touchup (#1127)
This commit is contained in:
parent 2f968f8547
commit 793a670187
@@ -4,7 +4,7 @@ import math
 import numpy as np
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import getenv, IMAGE
+from tinygrad.helpers import getenv, IMAGE, DEBUG
 from tinygrad.lazy import Device

 FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
@@ -36,6 +36,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
     except Exception:
       raise Exception(f"{s} failed shape {x.shape}")

+  if DEBUG >= 4:
+    np.set_printoptions(linewidth=200, suppress=True)
+    print(ret.numpy())
+    print(out.detach().numpy())
   compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)

   torch_fbp, tinygrad_fbp = np.nan, np.nan
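
The new block is gated on DEBUG >= 4; DEBUG is read from the environment via tinygrad.helpers.getenv, so a failing op can be inspected by rerunning the test with DEBUG=4 set. The suppress=True option matters for eyeballing the two dumps side by side: it stops numpy from flipping into scientific notation. A standalone illustration of that print option:

import numpy as np

a = np.array([1e-08, 1.0])
print(a)                                  # [1.e-08 1.e+00]
np.set_printoptions(linewidth=200, suppress=True)
print(a)                                  # [0. 1.]
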
@@ -328,6 +332,10 @@ class TestOps(unittest.TestCase):
   @unittest.skipIf(IMAGE>0, "no batched matmul on images")
   def test_matmul_batched_vector(self):
     helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
+  def test_small_gemm(self):
+    helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3)
+  def test_small_gemm_eye(self):
+    helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
   def test_gemm(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
   def test_big_gemm(self):
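
The two new tests target the tensor core path named in the commit title. test_small_gemm_eye is the more surgical one: instead of random inputs it feeds explicit values, and identity @ identity must return the identity, so any mis-indexed accumulation shows up immediately. A minimal standalone version of the same check:

import numpy as np
from tinygrad.tensor import Tensor

a = np.eye(8).astype(np.float32)
out = Tensor(a).matmul(Tensor(a)).numpy()   # identity @ identity == identity
np.testing.assert_allclose(out, np.eye(8), atol=1e-3)
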
@@ -67,6 +67,12 @@ class CheckingShapeTracker:
     assert self.st.shape == self.shape
     assert x == y, f"mismatch shapetracker:{x} real:{y}"

+class TestRealIssues(unittest.TestCase):
+  def test_reshape_doesnt_multiview(self):
+    self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
+    self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
+    assert len(self.st.views) == 1
+
 class TestRealDoesntSimplify(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
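
The regression test asserts that a reshape which only splits the leading axis (256 into 128*2) stays representable with the existing strides, so the tracker keeps a single View rather than stacking a second one (a multiview blocks later stride analysis). A smaller sketch of the same invariant, assuming ShapeTracker is importable from tinygrad.shape.shapetracker:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker((4, 6))     # contiguous 2-D tracker
st.reshape((2, 2, 6))         # splitting the 4 into 2*2 needs no extra view
assert len(st.views) == 1
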
@@ -17,8 +17,10 @@ class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = aut

 class LocalBuffer(NamedTuple):
   name: str
+  size: int
   dtype: DType = dtypes.float32
   realized: None = None
+  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"

 class Token(NamedTuple):
   name: str
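
LocalBuffer gains a required size field and a readable __str__ (put to use by the printbufs change further down). A trimmed standalone sketch, dropping the dtype field for brevity:

from typing import NamedTuple

class LocalBuffer(NamedTuple):
  name: str
  size: int
  realized: None = None
  def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"

print(LocalBuffer("temp", 64))   # localbuffer<temp[64]>
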
@@ -218,9 +220,9 @@ class Linearizer:

     # add a local buffer for multistage reduce
     if len(self.group_for_reduce):
-      self.bufs.append(LocalBuffer("temp"))
       # TODO: the strides of this can be controlled
       self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
+      self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
       self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))

     # print
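
The reorder is the point of this hunk: the ShapeTracker for the local buffer is appended first, so its element count can size the LocalBuffer, and the buffer metadata now agrees with the DEFINE_LOCAL uop (both read self.sts[-1].size()). A rough illustration of that sizing, with a made-up local shape:

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker((1, 16, 1, 4))   # hypothetical shape for a grouped reduce
print(st.size())                   # 64 -- elements the "temp" buffer must hold
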
@@ -401,10 +403,11 @@ class Linearizer:
     assert len(colors) == self.shape_len, "colors size mismatch"
     return colors

+  def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
   def printbufs(self, prefix=""):
     for i in range(len(self.sts)):
-      print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i] is not None else 'FAKE':47s}", self.sts[i].views)
-    print(' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors())))
+      print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
+    print(self.colored_shape())

   # ******************** base simplifiers ********************
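
Two things change here: the colored shape string is factored into colored_shape() so it can be reused, and the per-buffer line fixes its condition. The old test (self.bufs[i] is not None) was essentially always true, so the temp LocalBuffer, whose realized field is always None, printed as the literal string "None"; the new code falls back to the buffer's own str(). Continuing the LocalBuffer sketch above:

buf = LocalBuffer("temp", 64)
print(str(buf.realized) if buf.realized is not None else str(buf))
# -> localbuffer<temp[64]>
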
@@ -164,10 +164,10 @@ class ShapeTracker:
         ret[idxs.index(this_dim.a)] = this_dim.b
       elif isinstance(this_dim, Variable):
         ret[idxs.index(this_dim)] = 1
-    render_idx, render_valid = idx.render(), valid.render()
-    for i in range(len(self.shape)):
-      if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
-      elif f'idx{i}' not in render_idx: ret[i] = 0
+    idx_vars, valid_vars = idx.vars(), valid.vars()
+    for i,tidx in enumerate(idxs):
+      if tidx in valid_vars and not ignore_valid: ret[i] = None
+      elif tidx not in idx_vars: ret[i] = 0
     return tuple(ret)
   def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
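
real_strides() stops doing substring tests on rendered expression strings and instead asks for the actual Variable objects. The string version is fragile once there are eleven or more dimensions: 'idx1' is a substring of 'idx10', so an expression mentioning only idx10 would wrongly classify dimension 1. A sketch of both checks, assuming the vars() methods this commit adds to the symbolic nodes:

from tinygrad.shape.symbolic import Variable

idx1, idx10 = Variable("idx1", 0, 3), Variable("idx10", 0, 3)
expr = idx10 * 512

print("idx1" in expr.render())   # True  -- substring false positive
print(idx1 in expr.vars())       # False -- exact symbolic membership
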
@@ -16,6 +16,7 @@ class Node:
     if ops is None: ops = render_python
     assert self.__class__ in (Variable, NumNode) or self.min != self.max
     return ops[type(self)](self, ops, ctx)
+  def vars(self): return []
   @functools.cached_property
   def key(self) -> str: return self.render(ctx="DEBUG")
   @functools.cached_property
@@ -124,6 +125,7 @@ class Variable(Node):

   def __init__(self, expr:Optional[str], nmin:int, nmax:int):
     self.expr, self.min, self.max = expr, nmin, nmax
+  def vars(self): return [self]

 class NumNode(Node):
   def __init__(self, num:int):
@@ -138,6 +140,7 @@ class OpNode(Node):
   def __init__(self, a:Node, b:int):
     self.a, self.b = a, b
     self.min, self.max = self.get_bounds()
+  def vars(self): return self.a.vars()
   @abstractmethod
   def get_bounds(self) -> Tuple[int, int]: pass
@@ -174,6 +177,7 @@ class ModNode(OpNode):

 class RedNode(Node):
   def __init__(self, nodes:List[Node]): self.nodes = nodes
+  def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])

 class SumNode(RedNode):
   def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
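
Taken together, the four one-liners give every symbolic node a vars() method: the Node base contributes nothing, a Variable returns itself, an OpNode delegates to its symbolic operand (its other operand is a plain int), and a RedNode concatenates across all children. Callers like real_strides() above can then ask exactly which Variables an expression depends on. A quick sketch:

from tinygrad.shape.symbolic import Variable

x, y = Variable("x", 0, 7), Variable("y", 0, 3)
expr = x*4 + y                               # a SumNode over MulNode(x, 4) and y
print(sorted(v.expr for v in expr.vars()))   # ['x', 'y']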