mirror of https://github.com/commaai/tinygrad.git
tests from new lazy branch (#2774)
* tests from new lazy branch
* fix lin 11
* that was needed
* doesn't fail
* mark
* meant that
* llvm passes
parent a044125c39
commit c6eb618013
@@ -390,7 +390,7 @@ jobs:
           cache-name: cache-gpuocelot-build
         with:
           path: ${{ github.workspace }}/gpuocelot/ocelot
-          key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild
+          key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild-2
     - name: Clone/compile gpuocelot
       if: (matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton') && steps.cache-build.outputs.cache-hit != 'true'
       run: |
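Note on the key change above: GitHub Actions caches are immutable once written, so appending -2 to the cache key is the standard way to invalidate the stale gpuocelot build and force a fresh compile on the next run.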
@@ -416,7 +416,8 @@ def train_cifar():
       model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
     cl = time.monotonic()
     if not getenv("DIST"):
-      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
+      # 53  221.74 ms run,    2.22 ms python,  219.52 ms CL,  803.39 loss, 0.000807 LR, 4.66 GB used,   3042.49 GFLOPS,    674.65 GOPS
+      print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
     else:
       print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
     st = cl
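The new trailing field separates throughput from total work: GFLOPS divides the step's counted operations by wall-clock time, while GOPS is the raw operation count. A quick check of the arithmetic using the sample output line added in the diff (illustrative values only):

# numbers taken from the sample output comment in the hunk above
global_ops = 674.65e9                       # GlobalCounters.global_ops for the step
step_seconds = 0.22174                      # cl - st, i.e. the 221.74 ms run time
gflops = global_ops * 1e-9 / step_seconds   # ~3042.5, the logged GFLOPS (throughput)
gops = global_ops * 1e-9                    # 674.65, the logged GOPS (total work)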
@@ -13,7 +13,6 @@ import onnx
 from tqdm import tqdm
 from typing import Tuple, List, Optional, Dict
 from extra.onnx import get_run_onnx
-from tinygrad.graph import log_schedule_item
 from tinygrad import Tensor, Device
 from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
 from tinygrad.realize import run_schedule, lower_schedule_item
@@ -111,10 +110,6 @@ if __name__ == "__main__":
   image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
   print(f"**** running real kernels {image_count}/{len(schedule)} images ****")

-  if GRAPH:
-    for si in schedule_input: log_schedule_item(si)
-    for si in schedule: log_schedule_item(si)
-
   GlobalCounters.reset()
   run_schedule(schedule[:])
@@ -1,4 +1,4 @@
-import unittest, time
+import unittest, time, gc
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.nn import optim
@@ -14,28 +14,32 @@ from examples.hlb_cifar10 import SpeedyResNet
 from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
 from examples.stable_diffusion import UNetModel

+global_mem_used = 0
 def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
   tms = []
   for _ in range(4):
     early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
     GlobalCounters.reset()
-    GlobalCounters.mem_used = 0
     Device[Device.DEFAULT].synchronize()
     st = time.perf_counter_ns()
     train(*early_gen)
     Device[Device.DEFAULT].synchronize()
     tms.append(time.perf_counter_ns() - st)
+  mem_used = GlobalCounters.mem_used - global_mem_used

   # TODO: jit should expose this correctly with graph
   kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
-  print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
-  assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
+  print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
+  assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
   assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
   if all_jitted:
     assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"  # noqa: E501

 class TestRealWorld(unittest.TestCase):
   def setUp(self):
+    gc.collect()
+    global global_mem_used
+    global_mem_used = GlobalCounters.mem_used
     self.old_type = Tensor.default_type
     np.random.seed(2002)
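The helper change above replaces zeroing GlobalCounters.mem_used (which desynchronized the counter from buffers still alive across tests) with snapshotting a baseline in setUp and reporting the delta. A standalone sketch of the same pattern, with a stand-in counter class since only the structure is the point here:

import gc

class Counters:       # stand-in for tinygrad's GlobalCounters
  mem_used = 0        # grows and shrinks as buffers are allocated and freed

baseline = 0

def set_up():         # what TestRealWorld.setUp now does
  global baseline
  gc.collect()                   # drop tensors left over from the previous test
  baseline = Counters.mem_used   # snapshot instead of zeroing the counter

def mem_delta():
  return Counters.mem_used - baseline   # allocations attributable to this test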
@@ -62,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
     @TinyJit
     def test(t): return model(t, 0).realize()
     # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.25 if CI else 13.5, 181 if CI else 685, all_jitted=True)

   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):
@@ -73,7 +77,24 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t, v): return model(t, v).realize()
-    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 164 if CI else 468, all_jitted=True)
+    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 164 if CI else 468, all_jitted=True)

+  def test_train_mnist(self):
+    from examples.beautiful_mnist import Model
+    with Tensor.train():
+      model = Model()
+      optimizer = optim.Adam(get_parameters(model))
+      BS = 32
+
+      @TinyJit
+      def train(X):
+        out = model(X)
+        loss = out.mean()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+      helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 127)
+
   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")
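A reading of the new train_mnist test (my gloss, not from the commit message): helper_test realizes the generator's tensors before starting its timer, so only the jitted forward/backward/optimizer step is measured, and the 0.07 GB / 127-kernel bounds act as a regression guard on memory and kernel count for a full training step.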
File diff suppressed because one or more lines are too long
@@ -25,7 +25,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     for i, s in enumerate(sched):
-      print("op", i)
+      print("kernel", i+1)
       print_tree(s.ast)
   assert len(sched) == allowed
   # test the (non loadops) ops linearize
@@ -260,6 +260,22 @@ class TestSchedule(unittest.TestCase):
     check_schedule(c, 1)
     check_schedule(e, 1)

+  def test_shrink_fuse(self):
+    a = Tensor.empty(8192, 16)
+    b = Tensor.empty(8192, 16)
+    c = a * b
+    d = Tensor.empty(1, 16)
+    e = c[0] * d
+    check_schedule(e, 1)
+
+  def test_expand_nofuse(self):
+    a = Tensor.empty(1, 16)
+    b = Tensor.empty(1, 16)
+    c = a * b
+    d = Tensor.empty(8192, 16)
+    e = c * d
+    check_schedule(e, 2)
+
   # this is the failing case in openpilot...it's very simple like this
   @unittest.skip("failing in old lazy")
   def test_image_conv_fusion(self):
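A gloss on the two new tests (my interpretation, not text from the commit): shrinking the product with c[0] leaves each output element depending on a fixed input element, so the whole elementwise chain can fuse into one kernel, whereas broadcasting the (1, 16) product c against the (8192, 16) tensor d would recompute c for every broadcast row if fused, so the scheduler materializes c first and two kernels are expected.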
@@ -304,12 +320,18 @@ class TestSchedule(unittest.TestCase):
     check_schedule(x, 3)

   def test_resnet_block(self):
-    from extra.models.resnet import BasicBlock
     Tensor.training = False
-    bb = BasicBlock(64,64)
+
+    in_planes, planes = 64, 64
+    conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+    bn1 = nn.BatchNorm2d(planes)
+    conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
+    bn2 = nn.BatchNorm2d(planes)

     x = Tensor.empty(1, 64, 32, 32)
-    out = bb(x)
+    out = bn1(conv1(x)).relu()
+    out = bn2(conv2(out))
+    out = (out + x).relu()
     check_schedule(out, 4)

   def test_contiguous_while_contiguous(self):
@@ -115,8 +115,9 @@ class TestSymbolicReshape(unittest.TestCase):

   def test_reshape_into_symbols_bad_shape(self):
     vi = Variable("i", 1, 10).bind(4)
-    with self.assertRaises(ValueError):
-      Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77)  # reshape to a different size new shape through symbolic shape
+    # TODO: this never actually worked, it relied on lazy
+    #with self.assertRaises(ValueError):
+    #  Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77)  # reshape to a different size new shape through symbolic shape
     with self.assertRaises(AssertionError):
       Tensor.rand(3, 4).reshape(3, (vi+1))  # reshape into non-Variable Node
@@ -81,7 +81,7 @@ class ContextVar:
   def __lt__(self, x): return self.value < x

 DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
-GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
+GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")

 class Timing(contextlib.ContextDecorator):
   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
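The practical effect of the last change: as a ContextVar, GRAPH can now be overridden for a scoped block at runtime instead of being fixed at import from the environment. A sketch of the usage this enables, assuming the Context manager from tinygrad.helpers (already imported elsewhere in this commit) applies to it like the other ContextVars:

from tinygrad.helpers import Context, GRAPH

print(GRAPH.value)       # 0 by default, or whatever the GRAPH env var set at import
with Context(GRAPH=1):   # scoped override, restored on exit
  pass                   # code here sees GRAPH as truthy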