fix a bug in flops counting + touchups [pr] (#7126)

This commit is contained in:
George Hotz 2024-10-17 21:02:11 +08:00 committed by GitHub
parent a2eefa6f97
commit be9a433a60
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 5 deletions

View File

@ -433,7 +433,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
for u in uops:
if u.op is UOps.LOAD:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
elif u.op is UOps.STORE:
dont_count = dont_count.union(u.src[1].sparents)
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
@ -755,8 +755,9 @@ spec = PatternMatcher([
def type_verify(uops:List[UOp]):
for i,u in enumerate(uops):
chk = cast(bool, spec.rewrite(u))
assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
if cast(bool, spec.rewrite(u)) is not True:
print_uops(uops)
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
# *** most of symbolic lives here now ***

View File

@ -7,7 +7,7 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
from tinygrad.renderer import Renderer, TensorCore
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType) -> str:
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
return f"(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
@ -134,7 +134,9 @@ class CStyleLanguage(Renderer):
continue
# mark buffers that we store to writable
if u.op is UOps.STORE and u.src[0].op is UOps.DEFINE_GLOBAL: bufs[u.src[0]] = (bufs[u.src[0]][0], (bufs[u.src[0]][1][0], True))
if u.op is UOps.STORE:
for up in u.src[0].sparents:
if up.op is UOps.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
# naming
prefix = None