mirror of https://github.com/commaai/tinygrad.git
wmma: clean up to make WMMA arg order consistent (#2014)
also add cache defeat to extra/gemm/simple_matmul.py
This commit is contained in:
parent
cea4cbfc7a
commit
dece9958f8
|
@ -7,6 +7,8 @@ N = getenv("N", 4096)
|
|||
CNT = getenv("CNT", 10)
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
for i in range(CNT):
|
||||
if i > 0 and getenv("RAND", 0) != 0:
|
||||
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
|
||||
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
|
||||
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
|
||||
nc = c.numpy()
|
||||
|
|
|
@ -280,18 +280,18 @@ class Linearizer(OptimizedKernel):
|
|||
i = 0
|
||||
for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
|
||||
for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
|
||||
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
|
||||
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), ("METAL",))
|
||||
i += 2
|
||||
else:
|
||||
k = len(locals_to_store[1][2]) // 2
|
||||
for i in range(0, len(acc), 2):
|
||||
for y0,y1,x0,x1 in zip(locals_to_store[1][2][:k], locals_to_store[1][2][k:], locals_to_store[0][2][k*i:], locals_to_store[0][2][k*i+k:]):
|
||||
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
|
||||
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), ("METAL",))
|
||||
elif self.opts.device == "HIP":
|
||||
i = 0
|
||||
for y in range(0, len(locals_to_store[1][2]), 0x10):
|
||||
for x in range(0, len(locals_to_store[0][2]), 0x10):
|
||||
self.uop(UOps.WMMA, None, tuple(acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]), "HIP")
|
||||
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]+acc[i:i+8]), ("HIP",))
|
||||
i += 8
|
||||
else:
|
||||
if locals_to_store:
|
||||
|
|
|
@ -132,7 +132,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
|||
depth -= 1
|
||||
kk("}")
|
||||
elif uop == UOps.WMMA:
|
||||
if args == "METAL":
|
||||
if args[0] == "METAL":
|
||||
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
|
||||
kk("{ simdgroup_float8x8 a,b,c;")
|
||||
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
|
||||
|
@ -140,13 +140,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
|||
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
|
||||
kk("simdgroup_multiply_accumulate(c, a, b, c);")
|
||||
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
|
||||
elif args == "HIP":
|
||||
elif args[0] == "HIP":
|
||||
kk("{")
|
||||
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[8:8+16]])} }};")
|
||||
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[8+16:8+32]])} }};")
|
||||
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[:8]])} }};")
|
||||
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
|
||||
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
|
||||
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
|
||||
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
|
||||
for i in range(8): kk(f"{r[vin[i]]} = c_frag[{i}];")
|
||||
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
|
||||
kk("}")
|
||||
else:
|
||||
raise NotImplementedError(f"WMMA not implemented for {args}")
|
||||
|
|
Loading…
Reference in New Issue