diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py
index a3ac4389..825089a9 100644
--- a/extra/gemm/simple_matmul.py
+++ b/extra/gemm/simple_matmul.py
@@ -7,6 +7,8 @@ N = getenv("N", 4096)
 CNT = getenv("CNT", 10)
 a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
 for i in range(CNT):
+  if i > 0 and getenv("RAND", 0) != 0:
+    a, b = Tensor.rand(N, N, dtype=dtype_in).realize(), Tensor.rand(N, N, dtype=dtype_in).realize()
   c = (a.reshape(N, 1, N) * b.permute(1,0).reshape(1, N, N)).float().sum(axis=2).realize() if getenv("ACCUM_FP32") else (a @ b).realize()
   comp = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)
   nc = c.numpy()
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index bbe2d4b4..98d4cc5f 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -280,18 +280,18 @@ class Linearizer(OptimizedKernel):
           i = 0
          for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
            for x0,x1 in zip(locals_to_store[0][2][::2], locals_to_store[0][2][1::2]):
-             self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
+             self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), ("METAL",))
              i += 2
        else:
          k = len(locals_to_store[1][2]) // 2
          for i in range(0, len(acc), 2):
            for y0,y1,x0,x1 in zip(locals_to_store[1][2][:k], locals_to_store[1][2][k:], locals_to_store[0][2][k*i:], locals_to_store[0][2][k*i+k:]):
-             self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
+             self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), ("METAL",))
      elif self.opts.device == "HIP":
        i = 0
        for y in range(0, len(locals_to_store[1][2]), 0x10):
          for x in range(0, len(locals_to_store[0][2]), 0x10):
-           self.uop(UOps.WMMA, None, tuple(acc[i:i+8]+locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]), "HIP")
+           self.uop(UOps.WMMA, None, tuple(locals_to_store[0][2][x:x+0x10]+locals_to_store[1][2][y:y+0x10]+acc[i:i+8]), ("HIP",))
            i += 8
      else:
        if locals_to_store:
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index 9df05684..bf988796 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -132,7 +132,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
       depth -= 1
       kk("}")
     elif uop == UOps.WMMA:
-      if args == "METAL":
+      if args[0] == "METAL":
        # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
        kk("{ simdgroup_float8x8 a,b,c;")
        kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
@@ -140,13 +140,13 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
        kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
        kk("simdgroup_multiply_accumulate(c, a, b, c);")
        kk(f"{r[vin[4]]} = c.thread_elements()[0]; {r[vin[5]]} = c.thread_elements()[1]; }}")
-      elif args == "HIP":
+      elif args[0] == "HIP":
        kk("{")
-        kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[8:8+16]])} }};")
-        kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[8+16:8+32]])} }};")
-        kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[:8]])} }};")
+        kk(f"half16 a_frag = {{ {','.join(['(half)'+r[x] for x in vin[0:16]])} }};")
+        kk(f"half16 b_frag = {{ {','.join(['(half)'+r[x] for x in vin[16:32]])} }};")
+        kk(f"float8 c_frag = {{ {','.join([r[x] for x in vin[32:]])} }};")
        kk("c_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, c_frag);")
-        for i in range(8): kk(f"{r[vin[i]]} = c_frag[{i}];")
+        for i in range(8): kk(f"{r[vin[32+i]]} = c_frag[{i}];")
        kk("}")
      else: raise NotImplementedError(f"WMMA not implemented for {args}")