wmma: clean up to make WMMA arg order consistent (#2014)

also add cache defeat to extra/gemm/simple_matmul.py
This commit is contained in:
Francis Lam 2023-10-07 17:45:40 -07:00 committed by GitHub
parent cea4cbfc7a
commit dece9958f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 9 deletions

View File

@ -7,6 +7,8 @@ N = getenv("N", 4096)
CNT = getenv("CNT", 10)
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
for i in range(CNT):
if i > 0 and getenv("RAND", 0) != 0:
a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
nc = c.numpy()

View File

@ -280,18 +280,18 @@ class Linearizer(OptimizedKernel):
i = 0
for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), ("METAL",))
i += 2
else:
k = len(locals_to_store[1][2]) // 2
for i in range(0, len(acc), 2):
for y0,y1,x0,x1 in zip(locals_to_store[1][2][:k], locals_to_store[1][2][k:], locals_to_store[0][2][k*i:], locals_to_store[0][2][k*i+k:]):
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), ("METAL",))
elif self.opts.device == "HIP":
i = 0
for y in range(0, len(locals_to_store[1][2]), 0x10):
for x in range(0, len(locals_to_store[0][2]), 0x10):
self.uop(UOps.WMMA, None, tuple(acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]), "HIP")
self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]+acc[i:i+8]), ("HIP",))
i += 8
else:
if locals_to_store:

View File

@ -132,7 +132,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
depth -= 1
kk("}")
elif uop == UOps.WMMA:
if args == "METAL":
if args[0] == "METAL":
# ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
kk("{ simdgroup_float8x8 a,b,c;")
kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
@ -140,13 +140,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
kk("simdgroup_multiply_accumulate(c, a, b, c);")
kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
elif args == "HIP":
elif args[0] == "HIP":
kk("{")
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[8:8+16]])} }};")
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[8+16:8+32]])} }};")
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[:8]])} }};")
kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
for i in range(8): kk(f"{r[vin[i]]} = c_frag[{i}];")
for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
kk("}")
else:
raise NotImplementedError(f"WMMA not implemented for {args}")