mirror of https://github.com/commaai/tinygrad.git
parent 5a30a32af8
commit 0ab06d5840
@@ -66,10 +66,10 @@ jobs:
         TC=3 PYTHONPATH=. DEBUG=3 AMX=1 EMULATE_AMX=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm
     - name: Test device flop counts
       run: |
-        PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
-        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
-        PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
-        PYTHONPATH=. DEBUG=2 EMULATE_INTEL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul_half
+        PYTHONPATH=. DEBUG=2 EMULATE_METAL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
+        PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
+        PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
+        PYTHONPATH=. DEBUG=2 EMULATE_INTEL=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStatsMatmulHalf
         PYTHONPATH=. DEBUG=2 AMX=1 EMULATE_AMX=1 PYTHON=1 python3 ./test/test_uops_stats.py TestUOpsStats.test_simple_matmul
     - name: Test dtype with Python emulator
       run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py
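Note on the CI change above: the run lines switch from selecting one test method (TestUOpsStats.test_simple_matmul_half) to selecting a whole class (TestUOpsStatsMatmulHalf). With unittest, passing a class name on the command line runs every test method in it, so each emulated backend now also covers the bigger and batched matmul variants added below. A minimal sketch of that selection behavior (demo_test.py and the stub bodies are hypothetical, not part of the repo):

import unittest

class TestUOpsStatsMatmulHalf(unittest.TestCase):
  # stand-in bodies; the real tests live in test/test_uops_stats.py
  def test_simple_matmul_half(self): pass
  def test_bigger_matmul_half(self): pass
  def test_batched_matmul_half(self): pass

if __name__ == "__main__":
  unittest.main()  # "python3 demo_test.py TestUOpsStatsMatmulHalf" runs all three tests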
@@ -60,6 +60,27 @@ class TestMemoryCount(unittest.TestCase):
     _, mem = get_stats(a.assign(a+a))
     self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write

+# NOTE: this still isn't testing unroll using the acc
+@unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
+class TestUOpsStatsMatmulHalf(unittest.TestCase):
+  def test_simple_matmul_half(self, N=16):
+    GlobalCounters.reset()
+    a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
+    c = a.matmul(b)
+    c.realize()
+    expected_ops = N ** 3 * 2
+    self.assertEqual(expected_ops, GlobalCounters.global_ops)
+
+  def test_bigger_matmul_half(self): self.test_simple_matmul_half(64)
+
+  def test_batched_matmul_half(self, N=16):
+    GlobalCounters.reset()
+    a, b = Tensor.empty(4, N, N, dtype=dtypes.half), Tensor.empty(1, N, N, dtype=dtypes.half)
+    c = a.matmul(b)
+    c.realize()
+    expected_ops = 4 * N ** 3 * 2
+    self.assertEqual(expected_ops, GlobalCounters.global_ops)
+
 class TestUOpsStats(unittest.TestCase):
   @unittest.skipIf(getenv("PTX"), "wrong in PTX")
   def test_simple_add(self):
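Note on expected_ops in the new tests: an N x N matmul produces N*N outputs, each a length-N dot product costing one multiply and one add per step, hence N ** 3 * 2; the batched case broadcasts the (1, N, N) operand against the (4, N, N) one, so the count scales by the batch of 4. A quick standalone check of that arithmetic (plain Python, not tinygrad code; matmul_op_count is a hypothetical helper):

def matmul_op_count(N: int, batch: int = 1) -> int:
  # one MUL + one ADD per inner-loop step, N steps per output, N*N outputs per batch
  return batch * N * N * N * 2

assert matmul_op_count(16) == 16 ** 3 * 2                # test_simple_matmul_half
assert matmul_op_count(64) == 64 ** 3 * 2                # test_bigger_matmul_half
assert matmul_op_count(16, batch=4) == 4 * 16 ** 3 * 2   # test_batched_matmul_half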
@@ -95,16 +116,6 @@ class TestUOpsStats(unittest.TestCase):
     # NOTE: it's hard to assert on the memory here, all depends on caching
     assert required_mem <= mem

-  @unittest.skipUnless(getenv("PYTHON"), "only run test on emulated tensor cores")
-  def test_simple_matmul_half(self):
-    GlobalCounters.reset()
-    N = 16
-    a, b = Tensor.empty(N, N, dtype=dtypes.half), Tensor.empty(N, N, dtype=dtypes.half)
-    c = a.matmul(b)
-    c.realize()
-    expected_ops = N ** 3 * 2
-    self.assertEqual(expected_ops, GlobalCounters.global_ops)
-
   #MULACC should have the same stats as MUL + ADD
   def test_mulacc(self):
     globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
@@ -219,6 +219,31 @@ def index_collapse(idx,rng,buf,add,mul,ld,reduce):
   return UOp(reduce.op, reduce.dtype, (UOp(ld.op, ld.dtype, (buf, add+mul*idx, ld.const_like(0), idx.ge(rng.src[0]) & idx.lt(rng.src[1]))),)+
     tuple(x for x in reduce.src[1:] if x is not rng), reduce.arg)

+# TODO: there's a lot shared with no_vectorized_wmma here
+def gep_through_wmma(gep:UOp, wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  wmma_idxs = gep.arg[::out_sz]
+  for i in range(out_sz):
+    if tuple(x-i for x in gep.arg[i::out_sz]) != wmma_idxs: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    src_args = []
+    ssz = prod(x[1] for x in sz)
+    for w in wmma_idxs: src_args += list(range((w//out_sz)*ssz, (w//out_sz)*ssz + ssz))
+    tsrcs.append(s.gep(tuple(src_args)))
+  return UOp(UOps.WMMA, gep.dtype, tuple(tsrcs), wmma.arg)
+
+def no_vectorized_wmma(wmma:UOp):
+  out_sz = prod(x[1] for x in wmma.arg[6][-1])
+  if wmma.dtype.count == out_sz: return None
+  tsrcs = []
+  for s,sz in zip(wmma.src, wmma.arg[6]):
+    ssz = prod(x[1] for x in sz)
+    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
+  wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
+  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
+  return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
+
 # this is symbolic 2.0
 constant_folder = PatternMatcher([
   # bool ADD is OR, MUL is AND. prevents other rules to rewrite bool ADD/MUL incorrectly
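Note on the index check in gep_through_wmma above: a WMMA result is a vector made of out_sz-wide output groups, and the GEP can only be pushed through when its arg picks whole groups with every lane of a group selected in order. The loop verifies that by de-shifting each lane's indices back to the group's lane-0 indices. A standalone sketch of just that check (plain Python, no tinygrad imports; selects_whole_groups is a hypothetical name):

def selects_whole_groups(gep_arg: tuple, out_sz: int) -> bool:
  wmma_idxs = gep_arg[::out_sz]  # lane-0 index of each selected group
  # lane i of every selected group must sit exactly i past its lane-0 index
  return all(tuple(x - i for x in gep_arg[i::out_sz]) == wmma_idxs for i in range(out_sz))

# with out_sz=4: (4,5,6,7) is group 1 taken whole, so the GEP can pass through
assert selects_whole_groups((4, 5, 6, 7), 4)
# (4,5,6,8) straddles two groups, so gep_through_wmma returns None instead
assert not selects_whole_groups((4, 5, 6, 8), 4)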
@@ -247,6 +272,8 @@ constant_folder = PatternMatcher([
   # push all GEPs through ALUs (fix arange stuff)
   (UPat(UOps.GEP, src=(UPat((UOps.ALU, UOps.CAST, UOps.BITCAST), name='alu'),), name='gep'),
    lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg)),
+  # push some GEPs through WMMAs
+  (UPat(UOps.GEP, src=(UPat(UOps.WMMA, name="wmma"),), name="gep"), gep_through_wmma),
   # tensor core with a 0 input is acc
   (UPat(UOps.WMMA, src=(UPat.const(None, 0.0), UPat.var(), UPat.var("acc"))), lambda acc: acc),
   (UPat(UOps.WMMA, src=(UPat.var(), UPat.const(None, 0.0), UPat.var("acc"))), lambda acc: acc),
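Note on the "tensor core with a 0 input is acc" patterns shown as context above: WMMA is a multiply-accumulate, roughly a*b + acc per element, so when either multiplicand is the constant 0 the whole op folds to its accumulator. A scalar analogue of that rewrite (plain Python, not tinygrad code):

def wmma_scalar(a: float, b: float, acc: float) -> float:
  return a * b + acc  # per-element view of what WMMA accumulates

acc = 3.5
assert wmma_scalar(0.0, 7.0, acc) == acc  # first pattern: zero on the left
assert wmma_scalar(7.0, 0.0, acc) == acc  # second pattern: zero on the right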
@@ -542,17 +569,6 @@ def no_vectorized_acc(acc:UOp):
     tuple(UOp(UOps.GEP, s.dtype.scalar(), (s,), (i,)) if j == 0 else s for j,s in enumerate(acc.src)), acc.arg+(i,)) for i in range(acc.dtype.count))
   return UOp(UOps.VECTORIZE, acc.dtype, alus)

-def no_vectorized_wmma(wmma:UOp):
-  out_sz = prod(x[1] for x in wmma.arg[6][-1])
-  if wmma.dtype.count == out_sz: return None
-  tsrcs = []
-  for s,sz in zip(wmma.src, wmma.arg[6]):
-    ssz = prod(x[1] for x in sz)
-    tsrcs.append([s.gep(tuple(range(grp, grp+ssz))) for grp in range(0, s.dtype.count, ssz)])
-  wmmas = [UOp(UOps.WMMA, wmma.dtype.scalar().vec(out_sz), tsrc, wmma.arg) for tsrc in zip(*tsrcs)]
-  wmma_ex = flatten([[e.gep(i) for i in range(out_sz)] for e in wmmas])
-  return UOp(UOps.VECTORIZE, wmma.dtype, tuple(wmma_ex))
-
 def delete_redundant_gates(root:UOp) -> Optional[UOp]:
   @functools.lru_cache(None)
   def find_gate(x:UOp) -> Optional[UOp]:
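Note: the no_vectorized_wmma body deleted here is identical to the one re-added in the @@ -219,6 +219,31 @@ hunk above; the function is moved up the file, presumably so it sits next to the new gep_through_wmma it shares logic with (see the TODO there).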
@@ -74,7 +74,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
 uops_colors = {UOps.ALU: "#ffffc0", UOps.LOAD: "#ffc0c0", UOps.STORE: "#c0ffc0", UOps.CONST: "#e0e0e0", UOps.VCONST: "#e0e0e0",
                UOps.DEFINE_GLOBAL: "#ffe0b0", UOps.DEFINE_LOCAL: "#ffe0d0", UOps.DEFINE_ACC: "#f0ffe0", UOps.REDUCE: "#C4A484",
                UOps.RANGE: "#c8a0e0", UOps.ASSIGN: "#e0ffc0", UOps.BARRIER: "#ff8080", UOps.IF: "#c8b0c0", UOps.SPECIAL: "#c0c0ff",
-               UOps.SWIZZLE: "#7ACD93", UOps.SHAPETRACKER: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
+               UOps.WMMA: "#efefc0", UOps.SWIZZLE: "#7ACD93", UOps.SHAPETRACKER: "#C8F9D4", UOps.REDUCE_AXIS: "#f58488"}
 graph_uops_cnt = 0
 def word_wrap(x, wrap=80): return x if len(x) <= wrap else (x[0:wrap] + "\n" + word_wrap(x[wrap:], wrap))
 def graph_uops(uops:List[UOp]):