from tensor cores + lb touchup (#1127)

commit 793a670187 (parent 2f968f8547)
Author: George Hotz
Date: 2023-07-04 15:45:20 -07:00 (committed by GitHub)
5 changed files with 29 additions and 8 deletions

test/test_ops.py

@@ -4,7 +4,7 @@ import math
import numpy as np
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, IMAGE
from tinygrad.helpers import getenv, IMAGE, DEBUG
from tinygrad.lazy import Device
FORWARD_ONLY = getenv("FORWARD_ONLY", 0)
@@ -36,6 +36,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
except Exception:
raise Exception(f"{s} failed shape {x.shape}")
if DEBUG >= 4:
np.set_printoptions(linewidth=200, suppress=True)
print(ret.numpy())
print(out.detach().numpy())
compare("forward pass", ret.numpy(), out.detach().numpy(), atol=atol, rtol=rtol)
torch_fbp, tinygrad_fbp = np.nan, np.nan
@@ -328,6 +332,10 @@ class TestOps(unittest.TestCase):
@unittest.skipIf(IMAGE>0, "no batched matmul on images")
def test_matmul_batched_vector(self):
helper_test_op([(4,3), (1,3,3,5)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4)
def test_small_gemm(self):
helper_test_op([(8,8), (8,8)], lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3)
def test_small_gemm_eye(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, atol=1e-3, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
def test_gemm(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3)
def test_big_gemm(self):
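
As a standalone illustration, a minimal sketch of what the new test_small_gemm_eye case asserts (assuming a local tinygrad checkout): multiplying the 8x8 identity by itself must reproduce the identity to atol=1e-3.

import numpy as np
from tinygrad.tensor import Tensor

a = Tensor(np.eye(8).astype(np.float32))
b = Tensor(np.eye(8).astype(np.float32))
# I @ I == I: a cheap sanity check that the matmul path is wired up correctly
np.testing.assert_allclose((a @ b).numpy(), np.eye(8, dtype=np.float32), atol=1e-3)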

test/unit/test_shapetracker.py

@@ -67,6 +67,12 @@ class CheckingShapeTracker:
assert self.st.shape == self.shape
assert x == y, f"mismatch shapetracker:{x} real:{y}"
class TestRealIssues(unittest.TestCase):
def test_reshape_doesnt_multiview(self):
self.st = ShapeTracker((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), views=[View((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None)])
self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
assert len(self.st.views) == 1
class TestRealDoesntSimplify(unittest.TestCase):
def tearDown(self):
st = self.st.real_strides()
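
For contrast, a hedged sketch of the happy path that test_reshape_doesnt_multiview guards (assuming this era's in-place reshape and the public views list used by the test above): a reshape that merely splits an axis of a contiguous tracker is expressible with strides alone and must not grow a second View.

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker((256, 256))  # contiguous, single view
st.reshape((128, 2, 256))      # splitting axis 0 needs no extra view
assert len(st.views) == 1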

tinygrad/codegen/linearizer.py

@@ -17,8 +17,10 @@ class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = aut
class LocalBuffer(NamedTuple):
name: str
size: int
dtype: DType = dtypes.float32
realized: None = None
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"
class Token(NamedTuple):
name: str
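
A hedged sketch of the widened LocalBuffer (the import path tinygrad/codegen/linearizer.py is assumed from this diff): the new dtype and realized fields let it sit in self.bufs next to real buffers, while __str__ keeps printbufs output readable.

from tinygrad.codegen.linearizer import LocalBuffer

buf = LocalBuffer("temp", 64)
print(buf)           # localbuffer<temp[64]>
print(buf.dtype)     # dtypes.float32, the default
print(buf.realized)  # None: a local buffer is never a realized device buffer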
@@ -218,9 +220,9 @@ class Linearizer:
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
self.bufs.append(LocalBuffer("temp"))
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
# print
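
A worked sketch of the temp-buffer shape computed above, with made-up example values (first_reduce=1, one grouped reduce dim of 16, shape_len=3, nothing upcasted):

first_reduce, group_for_reduce, shape_len, upcasted = 1, [16], 3, 0
upcast_shape = []  # stands in for [x[0] for x in self.upcasted_axis(0)]
shape = tuple([1] * first_reduce + group_for_reduce
              + [1] * (shape_len - upcasted - len(group_for_reduce) - first_reduce)
              + upcast_shape)
print(shape)  # (1, 16, 1): global dims collapse to 1, only the grouped reduce dim survives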
@@ -401,10 +403,11 @@ class Linearizer:
assert len(colors) == self.shape_len, "colors size mismatch"
return colors
def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
def printbufs(self, prefix=""):
for i in range(len(self.sts)):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i] is not None else 'FAKE':47s}", self.sts[i].views)
print(' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors())))
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
print(self.colored_shape())
# ******************** base simplifiers ********************

tinygrad/shape/shapetracker.py

@@ -164,10 +164,10 @@ class ShapeTracker:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable):
ret[idxs.index(this_dim)] = 1
render_idx, render_valid = idx.render(), valid.render()
for i in range(len(self.shape)):
if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
elif f'idx{i}' not in render_idx: ret[i] = 0
idx_vars, valid_vars = idx.vars(), valid.vars()
for i,tidx in enumerate(idxs):
if tidx in valid_vars and not ignore_valid: ret[i] = None
elif tidx not in idx_vars: ret[i] = 0
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
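
A hedged sketch of what the rewritten loop computes, assuming ShapeTracker's permute movement op from this era: after a permute, real_strides recovers the effective layout structurally, without rendering the index expression to a string.

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker((4, 8))     # contiguous, strides (8, 1)
st.permute((1, 0))            # shape (8, 4), effective strides (1, 8)
print(st.real_strides())      # (1, 8)
print(st.unit_stride_axes())  # [0]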

tinygrad/shape/symbolic.py

@@ -16,6 +16,7 @@ class Node:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
def vars(self): return []
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
@@ -124,6 +125,7 @@ class Variable(Node):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
def vars(self): return [self]
class NumNode(Node):
def __init__(self, num:int):
@@ -138,6 +140,7 @@ class OpNode(Node):
def __init__(self, a:Node, b:int):
self.a, self.b = a, b
self.min, self.max = self.get_bounds()
def vars(self): return self.a.vars()
@abstractmethod
def get_bounds(self) -> Tuple[int, int]: pass
@@ -174,6 +177,7 @@ class ModNode(OpNode):
class RedNode(Node):
def __init__(self, nodes:List[Node]): self.nodes = nodes
def vars(self): return functools.reduce(lambda l,x: l+x.vars(), self.nodes, [])
class SumNode(RedNode):
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
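
Finally, a hedged sketch of the new Node.vars() plumbing together with SumNode's distributing __mul__ (the import path tinygrad/shape/symbolic.py is assumed):

from tinygrad.shape.symbolic import Variable

idx0 = Variable("idx0", 0, 3)
idx1 = Variable("idx1", 0, 7)
expr = idx0 * 8 + idx1    # a SumNode over a MulNode and a Variable
print(expr.vars())        # both Variables, collected structurally rather than by string-matching a render()
print((idx0 + idx1) * 4)  # the mul distributes across the SumNode's terms: (idx0*4)+(idx1*4)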