mirror of https://github.com/commaai/tinygrad.git
parent 88ff1edcf0, commit 877c78b4ce
@@ -104,7 +104,7 @@ if __name__ == "__main__":
   schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
   print(f"{len(schedule_input)} inputs")

-  run_schedule(schedule_independent, disable_logging=True)
+  run_schedule(schedule_independent)
   run_schedule(schedule_input)
   with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
     image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
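Note: partition here just splits the schedule into compute items and LoadOps items (buffer loads/copies) so the two groups can be run separately. A minimal stand-in with the same shape as the call above, for readers without tinygrad.helpers handy; the real helper may differ in details:

from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")

def partition(lst: List[T], fxn: Callable[[T], bool]) -> Tuple[List[T], List[T]]:
  # items where fxn(x) is True go in the first list, everything else in the second
  a: List[T] = []
  b: List[T] = []
  for x in lst: (a if fxn(x) else b).append(x)
  return a, b

# mirrors the call above: compute schedule items first, LoadOps (input/copy) items second
# schedule, schedule_input = partition(schedule, lambda si: si.ast.op not in LoadOps)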
@@ -12,7 +12,7 @@ from test.helpers import derandomize_model
 from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
 from examples.hlb_cifar10 import SpeedyResNet
 from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
-from examples.stable_diffusion import UNetModel
+from examples.stable_diffusion import UNetModel, ResBlock

 global_mem_used = 0
 def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
@@ -55,6 +55,16 @@ class TestRealWorld(unittest.TestCase):
     def test(t, t2): return model(t, 801, t2).realize()
     helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 953)

+  @unittest.skipIf(Device.DEFAULT in ["CPU", "TORCH"], "tons of ram with interpreted")
+  def test_mini_stable_diffusion(self):
+    model = [ResBlock(16, 24, 16) for _ in range(4)]
+    derandomize_model(model)
+    @TinyJit
+    def test(t, t2):
+      for l in model: t = l(t, t2)
+      return t.realize()
+    helper_test("test_mini_sd", lambda: (Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24)), test, 0.01, 43)
+
   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_llama(self):
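The new test_mini_stable_diffusion follows the file's usual recipe: build a tiny model (four ResBlocks), derandomize it, wrap the forward pass in TinyJit, and let helper_test enforce ceilings on memory and kernel count (0.01 GB and 43 kernels here). A hedged sketch of what such a harness boils down to; it assumes GlobalCounters from tinygrad.helpers tracks bytes used and kernels launched, and the tiny_helper_test name is hypothetical — the real helper_test also handles the all_jitted check and more bookkeeping:

from tinygrad.helpers import GlobalCounters

def tiny_helper_test(gen, test, max_memory_gb: float, max_kernels: int):
  # warm-up / JIT capture: call the (possibly jitted) function on fresh inputs a few times
  for _ in range(3): test(*gen())
  # count kernels launched by one steady-state call (assumes GlobalCounters.kernel_count exists)
  start_kernels = GlobalCounters.kernel_count
  test(*gen())
  kernels = GlobalCounters.kernel_count - start_kernels
  assert kernels <= max_kernels, f"ran {kernels} kernels, allowed {max_kernels}"
  # memory ceiling, in GB (assumes GlobalCounters.mem_used is in bytes)
  assert GlobalCounters.mem_used / 1e9 < max_memory_gb, f"used {GlobalCounters.mem_used/1e9:.3f} GB"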
@@ -66,7 +76,7 @@ class TestRealWorld(unittest.TestCase):
     @TinyJit
     def test(t): return model(t, 0).realize()
     # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.25 if CI else 13.5, 181 if CI else 685, all_jitted=True)
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.27 if CI else 13.5, 181 if CI else 685, all_jitted=True)

   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):
@@ -221,8 +221,9 @@ class TestJit(unittest.TestCase):
     np.testing.assert_equal([1], cache.bad_cache.numpy())

     for i in range(5):
-      cache.good_jitted(zero)
-      cache.bad_jitted(zero)
+      x = Tensor([i]) # NOTE: if this doesn't change, it just hits the lazybuffer cache
+      cache.good_jitted(x)
+      cache.bad_jitted(x)

     # verify the jitted calls read 1 from the cache
     np.testing.assert_equal([1], cache.good_jitted(zero).numpy())
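The NOTE explains the change: feeding the same constant back in can hit the lazybuffer cache, so the jitted function would never see a genuinely new input buffer. Building a fresh Tensor([i]) each iteration guarantees the JIT is captured and then replayed against changing inputs. A small hedged sketch of that pattern; the TinyJit import path has moved between tinygrad versions:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit  # note: this import path varies across tinygrad versions

@TinyJit
def add_one(x: Tensor) -> Tensor: return (x + 1).realize()

zero = Tensor([0])
for i in range(5):
  x = Tensor([i])  # a different value each time, so we never just hit the lazybuffer cache
  add_one(x)
# after capture, the jitted kernels are replayed with whatever input buffer is passed in
print(add_one(zero).numpy())  # expected [1]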
@@ -372,5 +372,52 @@ class TestSchedule(unittest.TestCase):
     out = x.sum(axis=2).T+y
     check_schedule(out, 2)

+  def test_two_elus_sum(self):
+    x = Tensor.empty(32, 32)
+    y = Tensor.empty(32, 32)
+    out = x.sum(1).relu().elu() + y.sum(1).relu().elu()
+    check_schedule(out, 2)
+
+  def test_multistage_reduce(self):
+    x = Tensor.empty(32, 32, 32)
+    out = x.sum(2).relu().sum(1)
+    check_schedule(out, 2)
+
+  def test_multistage_reduce_fork(self):
+    x = Tensor.empty(32, 32, 32)
+    x = x.sum(2)
+    out2 = x + 1
+    out = x.relu().sum(1) + out2[0]
+    check_schedule(out, 2)
+
+  def test_example_matmul(self):
+    x = Tensor.eye(64, requires_grad=True)
+    y = Tensor.eye(64, requires_grad=True)
+    z = y.matmul(x).sum()
+    z.backward()
+    out = x.grad.contiguous()
+    check_schedule(out, 2)
+
+  def test_contiguous_add(self):
+    x = Tensor.empty(32)
+    y = Tensor.empty(32)
+    z = Tensor.empty(32)
+    out = (x+y).contiguous()+z
+    check_schedule(out, 2)
+
+  def test_double_sum_ref(self):
+    x = Tensor.empty(32, 32, 32)
+    x = x.sum(2)
+    out = x + x[:, 4]
+    check_schedule(out, 2)
+
+  def test_reduce_shrink(self):
+    x = Tensor.empty(32, 32)
+    y = Tensor.empty(16)
+    x = x.sum(1)
+    x = x[:16]
+    out = x + y
+    check_schedule(out, 2) # TODO: this should be 1
+
 if __name__ == '__main__':
   unittest.main(verbosity=2)
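Each new test pins down one fusion decision: check_schedule(out, n) asserts that realizing out lowers to exactly n compute kernels, so a reduce feeding another reduce stays at two, .contiguous() forces an intermediate buffer, and the shrink-after-a-reduce case is flagged TODO because it could in principle fuse to one. A rough, hedged way to observe the same thing outside the test harness on a compiled backend, assuming GlobalCounters.kernel_count counts launched kernels; this is a stand-in, not the real check_schedule:

from tinygrad.tensor import Tensor
from tinygrad.helpers import GlobalCounters

def kernels_for(out: Tensor) -> int:
  # count how many kernels one realization launches
  start = GlobalCounters.kernel_count
  out.realize()
  return GlobalCounters.kernel_count - start

if __name__ == "__main__":
  x = Tensor.rand(32).realize()
  y = Tensor.rand(32).realize()
  z = Tensor.rand(32).realize()
  print(kernels_for(x + y + z))                 # one fused elementwise kernel
  print(kernels_for((x + y).contiguous() + z))  # contiguous() pins the intermediate -> two kernels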
@@ -78,12 +78,25 @@ class CheckingShapeTracker:
     assert self.st.shape == self.shape
     assert x == y, f"mismatch shapetracker:{x} real:{y}"

+@unittest.skip("don't create shapetrackers with views")
 class TestRealIssues(unittest.TestCase):
   def test_reshape_doesnt_multiview(self):
     self.st = ShapeTracker((View.create((256, 256, 2, 2, 2, 2, 2, 256, 8, 2), (0, 8, 0, 4, 0, 0, 2, 16384, 2048, 1), 0, None),))
     self.st.reshape((128, 2, 256, 2, 2, 2, 2, 2, 256, 8, 2))
     assert len(self.st.views) == 1

+  def test_reshape_stable_diffusion(self):
+    # regression test for https://github.com/tinygrad/tinygrad/pull/2616
+    st = ShapeTracker((View((2, 1920, 32, 32), (1310720, 1024, 32, 1), 0, ((0, 2), (0, 1280), (0, 32), (0, 32)), False),))
+    st = st.reshape((2, 32, 240, 256))
+    assert len(st.views) == 2
+
+  def test_reshape_trailing_invalid_ones(self):
+    st = ShapeTracker((View(shape=(1, 1, 5), strides=(0, 0, 1), offset=-5, mask=((1, 1), (0, 1), (0, 5)), contiguous=False),))
+    st = st.reshape((5,))
+    assert len(st.views) == 1
+    assert st.views[0].mask == ((0,0),)
+
 class TestRealDoesntSimplify(unittest.TestCase):
   def tearDown(self):
     st = self.st.real_strides()
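The moved tests capture when a ShapeTracker can fold a reshape into its existing View and when it has to stack a second one. The same idea with smaller shapes, as a hedged sketch (it assumes ShapeTracker.from_shape, reshape, and permute return new trackers, as they do elsewhere in this test file):

from tinygrad.shape.shapetracker import ShapeTracker

# a reshape of a contiguous tracker folds into the existing view
st = ShapeTracker.from_shape((4, 6)).reshape((6, 4))
print(len(st.views))  # 1

# a permute followed by an incompatible reshape cannot be expressed with a single view
st = ShapeTracker.from_shape((4, 6)).permute((1, 0)).reshape((4, 6))
print(len(st.views))  # 2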
@@ -136,7 +149,7 @@ class TestIndexExpressions2d(unittest.TestCase):
   def setUp(self):
     shapes = [(30, 5), (15, 10), (15, 1), (5, 10), (5, 1)] # Make sure dim0 is a multiple of 5, one of the tests divides this dimension by 5
     offsets = [0, 1, 15, 28, 10000]
-    self.sts = [ShapeTracker((View.create(base_shape, offset=offset),)) for base_shape in shapes for offset in offsets]
+    self.sts = [ShapeTracker.from_shape((prod(base_shape)+offset,)).shrink(((offset, offset+prod(base_shape)),)).reshape(base_shape) for base_shape in shapes for offset in offsets]
     self.offset = [NumNode(offset) for base_shape in shapes for offset in offsets]
     self.shapes = [shape for shape in shapes for offset in offsets]
     self.node_exprs = []
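The setUp change stops handing ShapeTracker a raw View with a nonzero offset and instead derives the same layout through movement ops, in the same spirit as the skip reason added above: start from a flat buffer that is offset elements longer, shrink off the leading offset elements, then reshape to the base shape. A hedged sketch with concrete numbers, assuming prod comes from tinygrad.helpers as elsewhere in the tests:

from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker

base_shape, offset = (30, 5), 15
st = ShapeTracker.from_shape((prod(base_shape) + offset,))  # flat, contiguous buffer of 165 elements
st = st.shrink(((offset, offset + prod(base_shape)),))      # drop the first 15: element (0, 0) now sits at buffer index 15
st = st.reshape(base_shape)                                 # view the remaining 150 elements as (30, 5)
print(st.shape)       # (30, 5)
print(len(st.views))  # 1: shrink + reshape of a contiguous tracker still folds into one View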
@@ -478,18 +491,6 @@ class TestComplexShapeTracker(unittest.TestCase):
     print(self.st.views)
     assert self.st.contiguous

-  def test_reshape_stable_diffusion(self):
-    # regression test for https://github.com/tinygrad/tinygrad/pull/2616
-    st = ShapeTracker((View((2, 1920, 32, 32), (1310720, 1024, 32, 1), 0, ((0, 2), (0, 1280), (0, 32), (0, 32)), False),))
-    st = st.reshape((2, 32, 240, 256))
-    assert len(st.views) == 2
-
-  def test_reshape_trailing_invalid_ones(self):
-    st = ShapeTracker((View(shape=(1, 1, 5), strides=(0, 0, 1), offset=-5, mask=((1, 1), (0, 1), (0, 5)), contiguous=False),))
-    st = st.reshape((5,))
-    assert len(st.views) == 1
-    assert st.views[0].mask == ((0,0),)
-
 class TestSingleShapeTracker(unittest.TestCase):
   def setUp(self):
     self.st = CheckingShapeTracker((7,4))