tests from new lazy branch (#2774)

* tests from new lazy branch

* fix lin 11

* that was needed

* doesn't fail

* mark

* meant that

* llvm passes
This commit is contained in:
George Hotz 2023-12-14 23:06:39 -08:00 committed by GitHub
parent a044125c39
commit c6eb618013
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 64 additions and 20 deletions

View File

@ -390,7 +390,7 @@ jobs:
cache-name: cache-gpuocelot-build
with:
path: ${{ github.workspace }}/gpuocelot/ocelot
key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild
key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild-2
- name: Clone/compile gpuocelot
if: (matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton') && steps.cache-build.outputs.cache-hit != 'true'
run: |

View File

@ -416,7 +416,8 @@ def train_cifar():
model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
cl = time.monotonic()
if not getenv("DIST"):
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
# 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
else:
print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
st = cl

View File

@ -13,7 +13,6 @@ import onnx
from tqdm import tqdm
from typing import Tuple, List, Optional, Dict
from extra.onnx import get_run_onnx
from tinygrad.graph import log_schedule_item
from tinygrad import Tensor, Device
from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
from tinygrad.realize import run_schedule, lower_schedule_item
@ -111,10 +110,6 @@ if __name__ == "__main__":
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
if GRAPH:
for si in schedule_input: log_schedule_item(si)
for si in schedule: log_schedule_item(si)
GlobalCounters.reset()
run_schedule(schedule[:])

View File

@ -1,4 +1,4 @@
import unittest, time
import unittest, time, gc
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
@ -14,28 +14,32 @@ from examples.hlb_cifar10 import SpeedyResNet
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
from examples.stable_diffusion import UNetModel
global_mem_used = 0
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
tms = []
for _ in range(4):
early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
GlobalCounters.reset()
GlobalCounters.mem_used = 0
Device[Device.DEFAULT].synchronize()
st = time.perf_counter_ns()
train(*early_gen)
Device[Device.DEFAULT].synchronize()
tms.append(time.perf_counter_ns() - st)
mem_used = GlobalCounters.mem_used - global_mem_used
# TODO: jit should expose this correctly with graph
kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
if all_jitted:
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
class TestRealWorld(unittest.TestCase):
def setUp(self):
gc.collect()
global global_mem_used
global_mem_used = GlobalCounters.mem_used
self.old_type = Tensor.default_type
np.random.seed(2002)
@ -62,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
@TinyJit
def test(t): return model(t, 0).realize()
# TODO: test first token vs rest properly, also memory test is broken with CacheCollector
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.25 if CI else 13.5, 181 if CI else 685, all_jitted=True)
@unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
def test_gpt2(self):
@ -73,7 +77,24 @@ class TestRealWorld(unittest.TestCase):
derandomize_model(model)
@TinyJit
def test(t, v): return model(t, v).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 164 if CI else 468, all_jitted=True)
helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 164 if CI else 468, all_jitted=True)
def test_train_mnist(self):
from examples.beautiful_mnist import Model
with Tensor.train():
model = Model()
optimizer = optim.Adam(get_parameters(model))
BS = 32
@TinyJit
def train(X):
out = model(X)
loss = out.mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 127)
@unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
@unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")

File diff suppressed because one or more lines are too long

View File

@ -25,7 +25,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
print("op", i)
print("kernel", i+1)
print_tree(s.ast)
assert len(sched) == allowed
# test the (non loadops) ops linearize
@ -260,6 +260,22 @@ class TestSchedule(unittest.TestCase):
check_schedule(c, 1)
check_schedule(e, 1)
def test_shrink_fuse(self):
a = Tensor.empty(8192, 16)
b = Tensor.empty(8192, 16)
c = a * b
d = Tensor.empty(1, 16)
e = c[0] * d
check_schedule(e, 1)
def test_expand_nofuse(self):
a = Tensor.empty(1, 16)
b = Tensor.empty(1, 16)
c = a * b
d = Tensor.empty(8192, 16)
e = c * d
check_schedule(e, 2)
# this is the failing case in openpilot...it's very simple like this
@unittest.skip("failing in old lazy")
def test_image_conv_fusion(self):
@ -304,12 +320,18 @@ class TestSchedule(unittest.TestCase):
check_schedule(x, 3)
def test_resnet_block(self):
from extra.models.resnet import BasicBlock
Tensor.training = False
bb = BasicBlock(64,64)
in_planes, planes = 64, 64
conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
bn1 = nn.BatchNorm2d(planes)
conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
bn2 = nn.BatchNorm2d(planes)
x = Tensor.empty(1, 64, 32, 32)
out = bb(x)
out = bn1(conv1(x)).relu()
out = bn2(conv2(out))
out = (out + x).relu()
check_schedule(out, 4)
def test_contiguous_while_contiguous(self):

View File

@ -115,8 +115,9 @@ class TestSymbolicReshape(unittest.TestCase):
def test_reshape_into_symbols_bad_shape(self):
vi = Variable("i", 1, 10).bind(4)
with self.assertRaises(ValueError):
Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
# TODO: this never actually worked, it relied on lazy
#with self.assertRaises(ValueError):
# Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77) # reshape to a different size new shape through symbolic shape
with self.assertRaises(AssertionError):
Tensor.rand(3, 4).reshape(3, (vi+1)) # reshape into non-Variable Node

View File

@ -81,7 +81,7 @@ class ContextVar:
def __lt__(self, x): return self.value < x
DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
class Timing(contextlib.ContextDecorator):
def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled