mirror of https://github.com/commaai/tinygrad.git
good stuff from tensor cores branch (#1199)
This commit is contained in: parent 7151382364, commit 67e34b356a
@@ -6,6 +6,7 @@ from tinygrad.helpers import dtypes, getenv
 from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
 
 N = getenv("N", 2048)
+LID = 2
 
 a = RawMetalBuffer(N*N, dtypes.float32)
 
@@ -21,10 +22,10 @@ prog = MetalProgram("test", f"""
 #include <metal_stdlib>
 #include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
 using namespace metal;
-kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[thread_position_in_grid]], uint3 xid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint sidx [[simdgroup_index_in_threadgroup]]) {{
-  a += gid.y * 32 * {N} + gid.z * 32;
-  data1 += gid.y * 32 * {N};
-  data2 += gid.z * 32;
+kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
+  a += gid.x * 32 * {N} + (gid.y * {LID} + lid.y) * 32;
+  data1 += gid.x * 32 * {N};
+  data2 += (gid.y * {LID} + lid.y) * 32;
 
   simdgroup_float8x8 acc[4][4];
   for (uint i = 0; i < 4; i++) {{
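
Editor's note (not part of the diff): the rewritten kernel drops thread_position_in_grid and simdgroup_index_in_threadgroup and addresses everything from the threadgroup position plus lid.y, with LID simdgroups per threadgroup along y. A rough Python sketch of the resulting tile mapping, assuming N and LID as defined above; tile_origin is an illustrative helper, not code from the commit. Each simdgroup still owns one 32x32 output tile built from its 4x4 grid of simdgroup_float8x8 accumulators.

# Sketch only: mirrors the kernel's new pointer arithmetic in Python.
N, LID = 2048, 2

def tile_origin(gid_x, gid_y, lid_y):
  # each (threadgroup, simdgroup) pair owns one 32x32 tile of the output
  row = gid_x * 32                   # a += gid.x * 32 * N
  col = (gid_y * LID + lid_y) * 32   # ... + (gid.y * LID + lid.y) * 32
  return row, col

# collect every tile origin the launch below would touch
tiles = {tile_origin(gx, gy, ly)
         for gx in range(N // 32)
         for gy in range(N // (32 * LID))
         for ly in range(LID)}
assert len(tiles) == (N // 32) ** 2  # the full N x N output is covered exactly once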
@@ -85,7 +86,7 @@ def timeit(fxn):
   et = fxn()
   # NOTE: et doesn't contain the launch overhead
   return time.perf_counter() - st
-tm = min([timeit(lambda: prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True)) for _ in range(20)])
+tm = min([timeit(lambda: prog([N//(8*4), N//(8*4*LID), 1], [32, LID, 1], a, b, c, wait=True)) for _ in range(20)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
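
Editor's note: the launch geometry changes together with the kernel. A hedged sketch of the arithmetic, assuming (consistent with the kernel's threadgroup_position_in_grid attribute) that prog()'s first argument is the threadgroup grid and the second the threads per threadgroup; the GFLOPS figure uses the usual 2*N^3 operation count for a square matmul, and the 1 ms wall time is just an example value.

# Sketch only: the numbers behind the new launch for N=2048, LID=2.
N, LID = 2048, 2
global_size = [N//(8*4), N//(8*4*LID), 1]  # threadgroups: [64, 32, 1]
local_size  = [32, LID, 1]                 # 32 threads (one simdgroup) x LID simdgroups
simdgroups  = global_size[0] * global_size[1] * LID
assert simdgroups * 32 * 32 == N * N       # one 32x32 output tile per simdgroup

FLOPS = 2 * N**3                           # multiply-adds for a square matmul
tm = 1e-3                                  # example: 1 ms wall time
print(f"{FLOPS / tm * 1e-9:.1f} GFLOPS")   # ~17180 GFLOPS at 1 ms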
@@ -709,7 +709,7 @@ class TestOps(unittest.TestCase):
           lambda x,w: torch.nn.functional.conv_transpose2d(x,w, stride=stride).relu(),
           lambda x,w: Tensor.conv_transpose2d(x,w,stride=stride).relu(), atol=1e-4, grad_rtol=1e-5)
 
-  @unittest.skipIf(Device.DEFAULT == "METAL", "weird, broken in METAL CI")
+  @unittest.skipIf(Device.DEFAULT == "METAL" and getenv("CI", "") != "", "broken in METAL CI")
   def test_output_padded_conv_transpose2d(self):
     for output_padding, stride in [((1,1), (2,3)), ((2,1), (3,2))]:
       helper_test_op([(2,4,6,5), (4,4,3,3),(4,)],
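
Editor's note: the skip now only applies when a CI environment variable is set, so the test still runs on local Metal machines. A small standalone sketch of that gating logic; it uses os.environ.get as a stand-in for tinygrad's getenv with a string default, and should_skip is an illustrative name.

# Sketch only: how the new guard behaves locally vs. in CI.
import os

def should_skip(device: str) -> bool:
  # mirrors: Device.DEFAULT == "METAL" and getenv("CI", "") != ""
  return device == "METAL" and os.environ.get("CI", "") != ""

os.environ.pop("CI", None)
assert not should_skip("METAL")  # local Metal run: test executes
os.environ["CI"] = "true"
assert should_skip("METAL")      # CI Metal run: test is skipped
assert not should_skip("CPU")    # other backends are never skipped here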
@@ -878,25 +878,38 @@ class TestOps(unittest.TestCase):
           lambda x,w: torch.nn.functional.conv2d(torch.nn.functional.pad(x, p),w).relu(),
           lambda x,w: Tensor.conv2d(x,w,padding=p).relu(), atol=1e-4)
 
-  def test_padded_conv2d(self):
-    bs = 4
-    cin = 3
-    H,W = 3,3
-    for p in [2, (2,1), (2,2)]:
-      with self.subTest(padding := p):
-        helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
-          lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
-          lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
+  @unittest.skipIf(Device.DEFAULT == "METAL" and getenv("CI", "") != "", "broken in METAL CI")
+  def test_padded_conv2d_p21(self):
+    bs,cin,H,W,padding = 4, 3, 3, 3, (2,1)
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
+
+  @unittest.skipIf(Device.DEFAULT == "METAL" and getenv("CI", "") != "", "broken in METAL CI")
+  def test_padded_conv2d_p22(self):
+    bs,cin,H,W,padding = 4, 3, 3, 3, (2,2)
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
+
+  def test_padded_conv2d_1x1(self):
+    bs,cin,H,W,padding = 4, 3, 1, 1, 2
+    helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
+      lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
+      lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
 
   def test_padded_conv2d_bs1(self):
-    bs = 1
-    cin = 3
-    H,W = 3,3
-    padding = 1
+    bs,cin,H,W,padding = 1, 3, 3, 3, 1
     helper_test_op([(bs,cin,11,28), (4,cin,H,W)],
       lambda x,w: torch.nn.functional.conv2d(x,w,padding=padding).relu(),
       lambda x,w: Tensor.conv2d(x,w,padding=padding).relu(), atol=1e-4)
 
+  def test_padding_add(self):
+    helper_test_op([(64,64), (60,60)],
+      lambda x,w: x+torch.nn.functional.pad(w, (2,2,2,2)),
+      lambda x,w: x+w.pad2d((2,2,2,2)),
+    )
+
   def test_dilated_conv2d(self):
     bs = 4
     cin = 3
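
Editor's note: the new test_padding_add relies on a (60,60) tensor padded by 2 on every side becoming (64,64), so it can be added elementwise to the first input. A quick NumPy sketch of that shape arithmetic (illustrative, not from the commit):

import numpy as np

x = np.random.randn(64, 64).astype(np.float32)
w = np.random.randn(60, 60).astype(np.float32)
wp = np.pad(w, ((2, 2), (2, 2)))  # pad 2 on each side: (60,60) -> (64,64)
assert wp.shape == x.shape
out = x + wp                      # what both the torch and tinygrad lambdas compute
assert out.shape == (64, 64)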
@@ -33,7 +33,7 @@ def colorize_float(x):
   ret = f"{x:7.2f}x"
   if x < 0.75:
     return colored(ret, 'green')
-  elif x > 1.33:
+  elif x > 1.15:
     return colored(ret, 'red')
   else:
     return colored(ret, 'yellow')
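
Editor's note: tightening the red threshold from 1.33x to 1.15x means a smaller slowdown versus torch is now flagged red. A standalone sketch of the new bands, with plain strings standing in for colored() from tinygrad.helpers:

def band(x: float) -> str:
  # mirrors colorize_float after this change: <0.75x green, >1.15x red, else yellow
  if x < 0.75: return "green"
  elif x > 1.15: return "red"
  else: return "yellow"

assert [band(r) for r in (0.5, 1.0, 1.2, 1.4)] == ["green", "yellow", "red", "red"]
# before the change, 1.2x would still have been yellow (red started at 1.33x)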
@@ -118,10 +118,10 @@ class TestBigSpeed(unittest.TestCase):
     return super().setUp()
   def test_add(self):
     def f(a, b): return a+b
-    helper_test_generic_square('add', 16384, f, f)
+    helper_test_generic_square('add', 8192, f, f)
   def test_exp(self):
     def f(a, b): return a.exp()
-    helper_test_generic_square('exp', 16384, f, f, onearg=True)
+    helper_test_generic_square('exp', 8192, f, f, onearg=True)
   def test_gemm_2048(self):
     def f(a, b): return a @ b
     helper_test_generic_square('gemm', 2048, f, f)
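
Editor's note: halving the square side from 16384 to 8192 cuts each float32 buffer from 1 GiB to 256 MiB, which is a plausible motivation (the commit itself doesn't say). The arithmetic:

# Sketch only: per-buffer memory for a square float32 tensor of side n.
def gib(n: int) -> float:
  return n * n * 4 / 2**30  # 4 bytes per float32

assert gib(16384) == 1.0    # 1 GiB per buffer at the old size
assert gib(8192) == 0.25    # 256 MiB per buffer at the new size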
@@ -103,6 +103,7 @@ class dtypes:
 
   # NOTE: these are internal dtypes, should probably check for that
   _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
+  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
   _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
 
   # HACK: staticmethods are not callable in 3.8 so we have to compare the class
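
Editor's note: judging by the neighbouring entries, the DType arguments here read as (priority, itemsize, name, numpy dtype, vector size), making _float2 an 8-byte, 2-wide float vector with no numpy equivalent. That field reading is an assumption, not stated in the diff; the sketch below uses a stand-in namedtuple rather than tinygrad's real DType.

from collections import namedtuple

# Stand-in for tinygrad's DType, with the field meanings assumed above.
DType = namedtuple("DType", ["priority", "itemsize", "name", "np", "sz"])

_float2 = DType(4, 4*2, "float2", None, 2)
_float4 = DType(4, 4*4, "float4", None, 4)
assert _float2.itemsize == 8 and _float2.sz == 2
assert _float4.itemsize == _float2.itemsize * 2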
@@ -12,10 +12,12 @@ class Node:
   b: int
   min: int
   max: int
-  def render(self, ops=None, ctx=None) -> str:
+  def render(self, ops=None, ctx=None, strip_parens=False) -> str:
     if ops is None: ops = render_python
     assert self.__class__ in (Variable, NumNode) or self.min != self.max
-    return ops[type(self)](self, ops, ctx)
+    ret = ops[type(self)](self, ops, ctx)
+    if strip_parens and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
+    return ret
   def vars(self): return []
   @functools.cached_property
   def key(self) -> str: return self.render(ctx="DEBUG")
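
Editor's note: a minimal sketch of what the new strip_parens flag does to a rendered node, using plain strings in place of real Node renders. The check only looks at the first and last characters, so it is meant for results wrapped in a single pair of outer parentheses; presumably this lets code renderers avoid redundant parentheses in generated source.

def strip_outer_parens(ret: str) -> str:
  # mirrors the check added to Node.render (with an extra empty-string guard)
  if ret and ret[0] == '(' and ret[-1] == ')': ret = ret[1:-1]
  return ret

assert strip_outer_parens("(idx0*4+idx1)") == "idx0*4+idx1"
assert strip_outer_parens("idx0*4") == "idx0*4"  # untouched when not parenthesised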