push geps through wmma (#6559)

* push geps through wmma

* update tests
This commit is contained in:
George Hotz 2024-09-17 14:38:40 +08:00 committed by GitHub
parent 5a30a32af8
commit 0ab06d5840
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 53 additions and 26 deletions

View File

@ -66,10 +66,10 @@ jobs:
TC=3 PYTHONPATH=. DEBUG=3 AMX=1 EMULATE_AMX=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
- name: Test device flop counts
run: |
PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_INTEL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
PYTHONPATH=. DEBUG=2 EMULATE_INTEL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
PYTHONPATH=. DEBUG=2 AMX=1 EMULATE_AMX=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul
- name: Test dtype with Python emulator
run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py

View File

@ -60,6 +60,27 @@ class TestMemoryCount(unittest.TestCase):
_, mem = get_stats(a.assign(a+a))
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
# NOTE: this still isn't testing unroll using the acc
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
class TestUOpsStatsMatmulHalf(unittest.TestCase):
def test_simple_matmul_half(self, N=16):
GlobalCounters.reset()
a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
c = a.matmul(b)
c.realize()
expected_ops = N ** 3 * 2
self.assertEqual(expected_ops, GlobalCounters.global_ops)
def test_bigger_matmul_half(self): self.test_simple_matmul_half(64)
def test_batched_matmul_half(self, N=16):
GlobalCounters.reset()
a, b = Tensor.empty(4, N, N, dtype=dtypes.half), Tensor.empty(1, N, N, dtype=dtypes.half)
c = a.matmul(b)
c.realize()
expected_ops = 4 * N ** 3 * 2
self.assertEqual(expected_ops, GlobalCounters.global_ops)
class TestUOpsStats(unittest.TestCase):
@unittest.skipIf(getenv("PTX"), "wrong in PTX")
def test_simple_add(self):
@ -95,16 +116,6 @@ class TestUOpsStats(unittest.TestCase):
# NOTE: it's hard to assert on the memory here, all depends on caching
assert required_mem <= mem
@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
def test_simple_matmul_half(self):
GlobalCounters.reset()
N = 16
a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
c = a.matmul(b)
c.realize()
expected_ops = N ** 3 * 2
self.assertEqual(expected_ops, GlobalCounters.global_ops)
#MULACC should have the same stats as MUL + ADD
def test_mulacc(self):
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())

View File

@ -219,6 +219,31 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)
# TODO: there's a lot shared with no_vectorized_wmma here
def gep_through_wmma(gep:UOp, wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
wmma_idxs = gep.arg[::out_sz]
for i in range(out_sz):
if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
src_args = []
ssz = prod(x[1] for x in sz)
for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
tsrcs.append(s.gep(tuple(src_args)))
return UOp(UOps.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
if wmma.dtype.count == out_sz: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
# this is symbolic 2.0
constant_folder = PatternMatcher([
# bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
@ -247,6 +272,8 @@ constant_folder = PatternMatcher([
# push all GEPs through ALUs (fix arange stuff)
(UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
# push some GEPs through WMMAs
(UPat(UOps.GEP, src=(UPat(UOps.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
# tensor core with a 0 input is acc
(UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
(UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
@ -542,17 +569,6 @@ def no_vectorized_acc(acc:UOp):
tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
return UOp(UOps.VECTORIZE, acc.dtype, alus)
def no_vectorized_wmma(wmma:UOp):
out_sz = prod(x[1] for x in wmma.arg[6][-1])
if wmma.dtype.count == out_sz: return None
tsrcs = []
for s,sz in zip(wmma.src, wmma.arg[6]):
ssz = prod(x[1] for x in sz)
tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
def delete_redundant_gates(root:UOp) -> Optional[UOp]:
@functools.lru_cache(None)
def find_gate(x:UOp) -> Optional[UOp]:

View File

@ -74,7 +74,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
UOps.SWIZZLE: "#7ACD93", UOps.SHAPETRACKER: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
UOps.WMMA: "#efefc0", UOps.SWIZZLE: "#7ACD93", UOps.SHAPETRACKER: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
graph_uops_cnt = 0
def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
def graph_uops(uops:List[UOp]):