mirror of https://github.com/commaai/tinygrad.git
don't alloc for InterpretedASTRunner (#2999)
This commit is contained in:
parent
bca0b95ee3
commit
9699c8c90b
|
@ -124,7 +124,7 @@ class TestSafetensors(unittest.TestCase):
|
|||
def test_save_all_dtypes(self):
|
||||
for dtype in dtypes.fields().values():
|
||||
if dtype in [dtypes.bfloat16]: continue # not supported in numpy
|
||||
path = temp("ones.safetensors")
|
||||
path = temp(f"ones.{dtype}.safetensors")
|
||||
ones = Tensor.rand((10,10), dtype=dtype)
|
||||
safe_save(get_state_dict(ones), path)
|
||||
assert ones == list(safe_load(path).values())[0]
|
||||
|
|
|
@ -179,7 +179,7 @@ class InterpretedASTRunner(JITRunner):
|
|||
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float:
|
||||
st = time.perf_counter()
|
||||
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
|
||||
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs[1:]], var_vals)
|
||||
et = time.perf_counter() - st
|
||||
update_stats(f"<interpreted {rawbufs[0].size}>", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, device=rawbufs[0].device)
|
||||
return et
|
||||
|
@ -219,7 +219,7 @@ def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> Interpret
|
|||
|
||||
if ast.op in BufferOps:
|
||||
if ast.op == ast.op == BufferOps.CONST: tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
|
||||
else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx}], ({gstr(ast.arg.dtype)}, True))"
|
||||
else: tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx-1}], ({gstr(ast.arg.dtype)}, True))"
|
||||
for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
|
||||
else:
|
||||
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import List, Dict, Optional, cast
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
|
||||
from tinygrad.graph import print_tree, realized_lazybuffer
|
||||
from tinygrad.helpers import prod, colored, getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
@ -39,7 +39,8 @@ def run_schedule(schedule:List[ScheduleItem]):
|
|||
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
|
||||
Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype)
|
||||
Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype,
|
||||
"PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
|
||||
del si.out.srcs
|
||||
|
||||
# run the function (put it in JIT)
|
||||
|
|
Loading…
Reference in New Issue