don't alloc for InterpretedASTRunner (#2999)

This commit is contained in:
George Hotz 2024-01-03 17:05:53 -08:00 committed by GitHub
parent bca0b95ee3
commit 9699c8c90b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 6 additions and 5 deletions

View File

@@ -124,7 +124,7 @@ class TestSafetensors(unittest.TestCase):
def test_save_all_dtypes(self):
for dtype in dtypes.fields().values():
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
path = temp("ones.safetensors")
path = temp(f"ones.{dtype}.safetensors")
ones = Tensor.rand((10,10), dtype=dtype)
safe_save(get_state_dict(ones), path)
assert ones == list(safe_load(path).values())[0]

View File

@@ -179,7 +179,7 @@ class InterpretedASTRunner(JITRunner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
st = time.perf_counter()
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs[1:]], var_vals)
et = time.perf_counter() - st
update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
return et
@@ -219,7 +219,7 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
if ast.op in BufferOps:
if ast.op == ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx}], ({gstr(ast.arg.dtype)}, True))"
else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx-1}], ({gstr(ast.arg.dtype)}, True))"
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
else:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"

View File

@@ -1,6 +1,6 @@
from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import prod, colored, getenv
from tinygrad.shape.symbolic import Variable
@@ -39,7 +39,8 @@ def run_schedule(schedule:List[ScheduleItem]):
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype,
"PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
del si.out.srcs
# run the function (put it in JIT)