mirror of https://github.com/commaai/tinygrad.git
early assert buffer count limit [run_process_replay] (#6746)
* better error message for buffer count limit [run_process_replay]
* 3.9 needs that
* assert ScheduleItem
* new _test_buf_cnt
parent 4ebc9589a6
commit b629a7998d
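In short: create_schedule_with_vars now raises a RuntimeError as soon as a ScheduleItem binds at least as many buffers as the active renderer's buf_max, instead of letting the backend compiler fail later with a CompileError. A minimal sketch of triggering the new check, assuming the default device's renderer reports buf_max=32 (as METAL's does):

import functools
from tinygrad import Tensor

# summing 31 realized buffers fuses into a single kernel that binds
# 31 inputs + 1 output = 32 buffers, which trips the >= 32 limit
xs = [Tensor.ones((1, 1)).contiguous().realize() for _ in range(31)]
out = functools.reduce(lambda a, b: a + b, xs)
try:
  out.schedule()  # raises here, before any code generation or compile
except RuntimeError as e:
  assert "buffer count limit" in str(e)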
@@ -1,6 +1,5 @@
 import csv, pathlib, time, numpy as np
 from os import getenv
-from tinygrad.device import CompileError
 import torch
 torch.set_num_threads(1)
 import onnx
@@ -73,10 +72,10 @@ def benchmark_model(m, devices, validate_outs=False):
       for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
       benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
       del inputs, tinygrad_model, tinygrad_jitted_model
-    except CompileError as e:
-      # METAL fails with buffer count limit
-      if m == "dm" and device == "METAL": return
-      raise e
+    except RuntimeError as e:
+      # TODO: we don't run the dm model on METAL for now
+      if Device.DEFAULT == "METAL": assert "buffer count limit" in str(e)
+      else: raise e
 
   # convert model to torch
   try:
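The benchmark change follows from the scheduler change: the dm model's failure on METAL now surfaces at schedule time as a RuntimeError rather than at compile time as a CompileError, so the handler matches the new exception type and inspects its message. The pattern in isolation, as a sketch (run_model is a hypothetical stand-in for the jitted-model calls above):

from tinygrad.device import Device

def tolerate_buf_limit(run_model):
  try: run_model()
  except RuntimeError as e:
    # only the known METAL buffer-count failure is tolerated; anything else still fails the run
    if Device.DEFAULT == "METAL" and "buffer count limit" in str(e): return
    raise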
@@ -10,7 +10,6 @@ from typing import List, Optional, Union, cast
 from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.dtype import DType, PtrDType
-from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.tensor import Tensor
@@ -72,18 +71,6 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
 
-def _test_buf_cnt(cnt:int, buf_max:int, allowed:int):
-  backup_renderer = Device[Device.DEFAULT].renderer
-  r = CStyleLanguage()
-  r.buf_max = buf_max
-  alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
-  s = alu.schedule()
-  assert len(s) == allowed
-  Device[Device.DEFAULT].renderer = backup_renderer
-  run_schedule(s)
-  expected = functools.reduce(lambda x,y: x+y, [np.ones((1, 1)) for _ in range(cnt-1)])
-  np.testing.assert_equal(alu.numpy(), expected)
-
 class TestSchedule(unittest.TestCase):
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
@@ -1326,11 +1313,20 @@
   @unittest.expectedFailure
   def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(6, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)
 
-  def test_buf_cnt_at_limit(self): _test_buf_cnt(5, buf_max=5, allowed=1)
+  def _test_buf_cnt(self, cnt:int, allowed:int):
+    if (m:=Device[Device.DEFAULT].renderer.buf_max) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
+    alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
+    s = alu.schedule()
+    assert len(s) == allowed
+    run_schedule(s)
+    expected = functools.reduce(lambda x,y: x+y, [np.ones((1, 1)) for _ in range(cnt-1)])
+    np.testing.assert_equal(alu.numpy(), expected)
+
+  def test_buf_cnt_at_limit(self): self._test_buf_cnt(31, allowed=1)
   @unittest.expectedFailure
-  def test_buf_cnt_over_limit(self): _test_buf_cnt(7, buf_max=5, allowed=2)
+  def test_buf_cnt_over_limit(self): self._test_buf_cnt(32, allowed=2)
   @unittest.expectedFailure
-  def test_buf_cnt_over_limit_alt(self): _test_buf_cnt(11, buf_max=5, allowed=3)
+  def test_buf_cnt_over_limit_alt(self): self._test_buf_cnt(63, allowed=3)
 
 class TestIndexing(unittest.TestCase):
   def check_schedule(self, xt:Union[Tensor,List[Tensor]], cnt:int):
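The new cnt values follow from simple counting, assuming METAL's buf_max of 32 (which is why the helper now skips on any other renderer instead of swapping in a patched CStyleLanguage): _test_buf_cnt(cnt, ...) sums cnt-1 realized (1, 1) buffers, so one fully fused kernel binds (cnt-1) inputs plus 1 output, i.e. cnt buffers, and the scheduler's check fires once cnt >= 32. A sketch of that arithmetic, not part of the diff:

buf_max = 32
for cnt, allowed in [(31, 1), (32, 2), (63, 3)]:
  bufs_if_fused = (cnt - 1) + 1  # inputs + output of one fully fused kernel
  # cnt=31 still fits in one kernel; 32 and 63 would need a split into 2 and 3
  # kernels, which the scheduler cannot do yet, hence @unittest.expectedFailure
  print(cnt, "fits" if bufs_if_fused < buf_max else f"needs {allowed} kernels")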
@@ -10,7 +10,7 @@ from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.device import Buffer
+from tinygrad.device import Buffer, Device
 from tinygrad.shape.view import View, strides_for_shape
 
 # creation can recurse a lot
@@ -412,7 +412,9 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem]
     kernel_number += 1
     for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
     for out in lsi.outputs: del out.srcs # can only schedule once
-    schedule.append(ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))
+    schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.bufs if x.size != 0), lsi.metadata))
+    if (m:=Device[(device:=si.outputs[0].device)].renderer.buf_max) and len(si.bufs) >= m:
+      raise RuntimeError(f"{si} exceeded the buffer count limit for {device}: {len(si.bufs)} >= {m}")
     for x in graph[lsi]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
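Note that the check runs while the schedule is being built, not at compile time: each ScheduleItem already carries the full tuple of buffers its kernel will bind, so no code generation is needed to compare that count against the renderer's declared limit, and a falsy buf_max disables the check for backends without one. The same guard factored out, as a sketch (check_buf_count is a hypothetical helper, not part of the diff):

def check_buf_count(si, renderer):
  # renderer.buf_max is None for backends that declare no buffer limit
  if renderer.buf_max and len(si.bufs) >= renderer.buf_max:
    raise RuntimeError(f"{si} exceeded the buffer count limit: {len(si.bufs)} >= {renderer.buf_max}")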