From 67e34b356a55dd043ed523457b48cff9e7a41dba Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 8 Jul 2023 16:58:26 -0700
Subject: [PATCH] good stuff from tensor cores branch (#1199)

---
 extra/gemm/metal_matmul.py | 11 +++++-----
 test/test_ops.py           | 45 ++++++++++++++++++++++++--------------
 test/test_speed_v_torch.py |  6 ++---
 tinygrad/helpers.py        |  1 +
 tinygrad/shape/symbolic.py |  6 +++--
 5 files changed, 43 insertions(+), 26 deletions(-)

diff --git a/extra/gemm/metal_matmul.py b/extra/gemm/metal_matmul.py
index 20ab342b..610bdecf 100644
--- a/extra/gemm/metal_matmul.py
+++ b/extra/gemm/metal_matmul.py
@@ -6,6 +6,7 @@ from tinygrad.helpers import dtypes, getenv
 from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
 
 N = getenv("N", 2048)
+LID = 2
 
 a = RawMetalBuffer(N*N, dtypes.float32)
 
@@ -21,10 +22,10 @@ prog = MetalProgram("test", f"""
 #include <metal_stdlib>
 #include <metal_simdgroup_matrix>  // Available from Metal version 2.3 released with OS X 11.0+
 using namespace metal;
-kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[thread_position_in_grid]], uint3 xid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint sidx [[simdgroup_index_in_threadgroup]]) {{
-  a += gid.y * 32 * {N} + gid.z * 32;
-  data1 += gid.y * 32 * {N};
-  data2 += gid.z * 32;
+kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
+  a += gid.x * 32 * {N} + (gid.y * {LID} + lid.y) * 32;
+  data1 += gid.x * 32 * {N};
+  data2 += (gid.y * {LID} + lid.y) * 32;
 
   simdgroup_float8x8 acc[4][4];
   for (uint i = 0; i < 4; i++) {{
@@ -85,7 +86,7 @@ def timeit(fxn):
   et = fxn()
   # NOTE: et doesn't contain the launch overhead
   return time.perf_counter() - st
-tm = min([timeit(lambda: prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True)) for _ in range(20)])
+tm = min([timeit(lambda: prog([N//(8*4), N//(8*4*LID), 1], [32, LID, 1], a, b, c, wait=True)) for _ in range(20)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
diff --git a/test/test_ops.py b/test/test_ops.py
index 4986b660..03262e66 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -709,7 +709,7 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(),
       lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
 
-  @unittest.skipIf(Device.DEFAULT == "METAL", "weird, broken in METAL CI")
+  @unittest.skipIf(Device.DEFAULT == "METAL" and getenv("CI", "") != "", "broken in METAL CI")
   def test_output_padded_conv_transpose2d(self):
     for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]:
       helper_test_op([(2,4,6,5), (4,4,3,3),(4,)],
@@ -878,25 +878,38 @@ class TestOps(unittest.TestCase):
       lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
       lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
 
-  def test_padded_conv2d(self):
-    bs = 4
-    cin = 3
-    H,W = 3,3
-    for p in [2, (2,1), (2,2)]:
-      with self.subTest(padding := p):
-        helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-          lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
-          lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
-
-  def test_padded_conv2d_bs1(self):
-    bs = 1
-    cin = 3
-    H,W = 3,3
-    padding = 1
+  @unittest.skipIf(Device.DEFAULT == "METAL" and getenv("CI", "") != "", "broken in METAL CI")
+  def test_padded_conv2d_p21(self):
+    bs,cin,H,W,padding = 4, 3, 3, 3, (2,1)
     helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
       lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
 
+  @unittest.skipIf(Device.DEFAULT == "METAL" and getenv("CI", "") != "", "broken in METAL CI")
+  def test_padded_conv2d_p22(self):
+    bs,cin,H,W,padding = 4, 3, 3, 3, (2,2)
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
+
+  def test_padded_conv2d_1x1(self):
+    bs,cin,H,W,padding = 4, 3, 1, 1, 2
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
+
+  def test_padded_conv2d_bs1(self):
+    bs,cin,H,W,padding = 1, 3, 3, 3, 1
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
+
+  def test_padding_add(self):
+    helper_test_op([(64,64), (60,60)],
+      lambda x,w: x+torch.nn.functional.pad(w, (2,2,2,2)),
+      lambda x,w: x+w.pad2d((2,2,2,2)),
+    )
+
   def test_dilated_conv2d(self):
     bs = 4
     cin = 3
diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index 935174ea..3af3a3c6 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -33,7 +33,7 @@ def colorize_float(x):
   ret = f"{x:7.2f}x"
   if x < 0.75:
     return colored(ret, 'green')
-  elif x > 1.33:
+  elif x > 1.15:
     return colored(ret, 'red')
   else:
     return colored(ret, 'yellow')
@@ -118,10 +118,10 @@ class TestBigSpeed(unittest.TestCase):
     return super().setUp()
   def test_add(self):
     def f(a, b): return a+b
-    helper_test_generic_square('add', 16384, f, f)
+    helper_test_generic_square('add', 8192, f, f)
   def test_exp(self):
     def f(a, b): return a.exp()
-    helper_test_generic_square('exp', 16384, f, f, onearg=True)
+    helper_test_generic_square('exp', 8192, f, f, onearg=True)
   def test_gemm_2048(self):
     def f(a, b): return a @ b
     helper_test_generic_square('gemm', 2048, f, f)
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index f742cb33..d7636437 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -103,6 +103,7 @@ class dtypes:
 
   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
+  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
   _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
 
   # HACK: staticmethods are not callable in 3.8 so we have to compare the class
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index d5ce2dff..0c2006b8 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -12,10 +12,12 @@ class Node:
   b: int
   min: int
   max: int
-  def render(self, ops=None, ctx=None) -> str:
+  def render(self, ops=None, ctx=None, strip_parens=False) -> str:
     if ops is None: ops = render_python
     assert self.__class__ in (Variable, NumNode) or self.min != self.max
-    return ops[type(self)](self, ops, ctx)
+    ret = ops[type(self)](self, ops, ctx)
+    if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
+    return ret
   def vars(self): return []
   @functools.cached_property
   def key(self) -> str: return self.render(ctx="DEBUG")
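
For the extra/gemm/metal_matmul.py change above, a quick sanity check (not part of the patch; N and LID mirror the script's defaults) of the new launch geometry: each threadgroup carries LID simdgroups of 32 threads, lid.y picks the simdgroup, and every simdgroup accumulates one 32x32 output tile as 4x4 simdgroup_float8x8 registers:

N, LID = 2048, 2                     # defaults from the script
grid  = [N//(8*4), N//(8*4*LID), 1]  # threadgroups launched: row tiles x column-tile groups
local = [32, LID, 1]                 # 32 threads per simdgroup, LID simdgroups per threadgroup
row_tiles    = grid[0]               # gid.x selects the 32-row band
column_tiles = grid[1] * LID         # gid.y * LID + lid.y selects the 32-column band
assert row_tiles * column_tiles * (32 * 32) == N * N  # tiles exactly cover the output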
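
For the tinygrad/shape/symbolic.py change, a minimal sketch (not part of the patch) of what the new strip_parens flag does, assuming the default render_python backend that render() falls back to, which wraps compound nodes such as MulNode in one outer pair of parentheses:

from tinygrad.shape.symbolic import Variable

expr = Variable("a", 0, 8) * 4         # builds a MulNode
print(expr.render())                   # "(a*4)"
print(expr.render(strip_parens=True))  # "a*4" -- the outer pair is dropped

Note that the check ret[0] == '(' and ret[-1] == ')' does not verify the two parentheses actually match; it leans on the default renderers always producing a single balanced outer pair around compound nodes.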