mirror of https://github.com/commaai/tinygrad.git
fix a bug in flops counting + touchups [pr] (#7126)
This commit is contained in:
parent
a2eefa6f97
commit
be9a433a60
|
@ -433,7 +433,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
|
|||
for u in uops:
|
||||
if u.op is UOps.LOAD:
|
||||
dont_count = dont_count.union(u.src[1].sparents)
|
||||
if len(u.src) > 3: dont_count = dont_count.union(u.src[2].sparents)
|
||||
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
|
||||
elif u.op is UOps.STORE:
|
||||
dont_count = dont_count.union(u.src[1].sparents)
|
||||
if len(u.src) > 3: dont_count = dont_count.union(u.src[3].sparents)
|
||||
|
@ -755,8 +755,9 @@ spec = PatternMatcher([
|
|||
|
||||
def type_verify(uops:List[UOp]):
|
||||
for i,u in enumerate(uops):
|
||||
chk = cast(bool, spec.rewrite(u))
|
||||
assert chk is True, f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
if cast(bool, spec.rewrite(u)) is not True:
|
||||
print_uops(uops)
|
||||
raise RuntimeError(f"UOp verification failed at {i} on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}")
|
||||
|
||||
# *** most of symbolic lives here now ***
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
|
|||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.renderer import Renderer, TensorCore
|
||||
|
||||
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
|
||||
def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType) -> str:
|
||||
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
|
||||
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
return f"(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
|
||||
|
@ -134,7 +134,9 @@ class CStyleLanguage(Renderer):
|
|||
continue
|
||||
|
||||
# mark buffers that we store to writable
|
||||
if u.op is UOps.STORE and u.src[0].op is UOps.DEFINE_GLOBAL: bufs[u.src[0]] = (bufs[u.src[0]][0], (bufs[u.src[0]][1][0], True))
|
||||
if u.op is UOps.STORE:
|
||||
for up in u.src[0].sparents:
|
||||
if up.op is UOps.DEFINE_GLOBAL: bufs[up] = (bufs[up][0], (bufs[up][1][0], True))
|
||||
|
||||
# naming
|
||||
prefix = None
|
||||
|
|
Loading…
Reference in New Issue