tests from new lazy branch (#2774)

* tests from new lazy branch

* fix lin 11

* that was needed

* doesn't fail

* mark

* meant that

* llvm passes
George Hotz 2023-12-14 23:06:39 -08:00 committed by GitHub
parent a044125c39
commit c6eb618013
8 changed files with 64 additions and 20 deletions

View File

@@ -390,7 +390,7 @@ jobs:
         cache-name: cache-gpuocelot-build
       with:
         path: ${{ github.workspace }}/gpuocelot/ocelot
-        key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild
+        key: ubuntu22.04-gpuocelot-18401f4245b27ca4b3af433196583cc81ef84480-rebuild-2
     - name: Clone/compile gpuocelot
       if: (matrix.backend == 'cuda' || matrix.backend == 'ptx' || matrix.backend == 'triton') && steps.cache-build.outputs.cache-hit != 'true'
       run: |
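Renaming the cache key (the appended -2) is the standard way to invalidate a stale GitHub Actions cache: keys match exactly, so the new key misses, gpuocelot is rebuilt once, and the rebuilt artifact is cached under the new name.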

View File

@@ -416,7 +416,8 @@ def train_cifar():
         model_ema.update(model, Tensor([projected_ema_decay_val*(i/STEPS)**hyp['ema']['decay_pow']]))
       cl = time.monotonic()
       if not getenv("DIST"):
-        print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
+        # 53 221.74 ms run, 2.22 ms python, 219.52 ms CL, 803.39 loss, 0.000807 LR, 4.66 GB used, 3042.49 GFLOPS, 674.65 GOPS
+        print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS, {GlobalCounters.global_ops*1e-9:9.2f} GOPS")
       else:
         print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS")
       st = cl
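The print line now reports total work (GOPS) next to throughput (GFLOPS). A minimal sketch of the arithmetic, plugging in the sample numbers from the comment in the diff (names mirror the diff; the values are just that example):

global_ops = 674.65e9      # GlobalCounters.global_ops accumulated over the step
cl_minus_st = 0.22174      # cl - st: ~221.74 ms of wall time for the step
gops = global_ops * 1e-9                  # total work: 674.65 GOPS
gflops = global_ops * 1e-9 / cl_minus_st  # throughput: ~3042.49 GFLOPS
print(f"{gflops:9.2f} GFLOPS, {gops:9.2f} GOPS")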

View File

@@ -13,7 +13,6 @@ import onnx
 from tqdm import tqdm
 from typing import Tuple, List, Optional, Dict
 from extra.onnx import get_run_onnx
-from tinygrad.graph import log_schedule_item
 from tinygrad import Tensor, Device
 from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH, DEBUG
 from tinygrad.realize import run_schedule, lower_schedule_item
@@ -111,10 +110,6 @@ if __name__ == "__main__":
   image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
   print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
-  if GRAPH:
-    for si in schedule_input: log_schedule_item(si)
-    for si in schedule: log_schedule_item(si)
-
   GlobalCounters.reset()
   run_schedule(schedule[:])
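The GRAPH-gated block can go away because this same commit turns GRAPH into a ContextVar (see the tinygrad/helpers.py hunk at the end); presumably schedule-item graph logging is now handled inside run_schedule itself rather than by each caller invoking log_schedule_item.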

View File

@@ -1,4 +1,4 @@
-import unittest, time
+import unittest, time, gc
 import numpy as np
 from tinygrad.tensor import Tensor
 from tinygrad.nn import optim
@@ -14,28 +14,32 @@ from examples.hlb_cifar10 import SpeedyResNet
 from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
 from examples.stable_diffusion import UNetModel
 
+global_mem_used = 0
 def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
   tms = []
   for _ in range(4):
     early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()]
     GlobalCounters.reset()
-    GlobalCounters.mem_used = 0
     Device[Device.DEFAULT].synchronize()
     st = time.perf_counter_ns()
     train(*early_gen)
     Device[Device.DEFAULT].synchronize()
     tms.append(time.perf_counter_ns() - st)
 
+  mem_used = GlobalCounters.mem_used - global_mem_used
   # TODO: jit should expose this correctly with graph
   kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
-  print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
-  assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
+  print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
+  assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
   assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
   if all_jitted:
     assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used == 1 and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"  # noqa: E501
 
 class TestRealWorld(unittest.TestCase):
   def setUp(self):
+    gc.collect()
+    global global_mem_used
+    global_mem_used = GlobalCounters.mem_used
     self.old_type = Tensor.default_type
     np.random.seed(2002)
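helper_test stops asserting on the absolute GlobalCounters.mem_used and instead charges each test only for its own delta: setUp garbage-collects (so tensors kept alive only by reference cycles are actually freed) and snapshots the bytes still allocated by earlier tests. A minimal sketch of the pattern, with a stand-in counter class (only the snapshot/delta logic is the point):

import gc

class Counters:          # stand-in for tinygrad's GlobalCounters
  mem_used = 0           # running total of live allocated bytes

baseline = 0

def set_up():            # mirrors TestRealWorld.setUp above
  global baseline
  gc.collect()           # drop tensors held only by reference cycles
  baseline = Counters.mem_used  # snapshot leftovers from earlier tests

def mem_delta():
  # memory this test allocated, regardless of what earlier tests left behind
  return Counters.mem_used - baseline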
@@ -62,7 +66,7 @@ class TestRealWorld(unittest.TestCase):
     @TinyJit
     def test(t): return model(t, 0).realize()
     # TODO: test first token vs rest properly, also memory test is broken with CacheCollector
-    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)
+    helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.25 if CI else 13.5, 181 if CI else 685, all_jitted=True)
 
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "GPU"] and CI, "too long on CI LLVM, GPU requires cl_khr_fp16")
   def test_gpt2(self):
@@ -73,7 +77,24 @@ class TestRealWorld(unittest.TestCase):
     derandomize_model(model)
     @TinyJit
     def test(t, v): return model(t, v).realize()
-    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.21 if CI else 0.9, 164 if CI else 468, all_jitted=True)
+    helper_test("test_gpt2", lambda: (Tensor([[1,]]),Variable("pos", 1, 100).bind(1)), test, 0.23 if CI else 0.9, 164 if CI else 468, all_jitted=True)
+
+  def test_train_mnist(self):
+    from examples.beautiful_mnist import Model
+    with Tensor.train():
+      model = Model()
+      optimizer = optim.Adam(get_parameters(model))
+      BS = 32
+
+      @TinyJit
+      def train(X):
+        out = model(X)
+        loss = out.mean()
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+      helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 127)
 
   @unittest.skipIf(Device.DEFAULT == "LLVM", "LLVM segmentation fault")
   @unittest.skipIf(Device.DEFAULT in ["LLVM", "CLANG"] and CI, "too long on CI LLVM and CLANG")

File diff suppressed because one or more lines are too long

View File

@@ -25,7 +25,7 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     for i, s in enumerate(sched):
-      print("op", i)
+      print("kernel", i+1)
       print_tree(s.ast)
   assert len(sched) == allowed
   # test the (non loadops) ops linearize
@@ -260,6 +260,22 @@ class TestSchedule(unittest.TestCase):
     check_schedule(c, 1)
     check_schedule(e, 1)
 
+  def test_shrink_fuse(self):
+    a = Tensor.empty(8192, 16)
+    b = Tensor.empty(8192, 16)
+    c = a * b
+    d = Tensor.empty(1, 16)
+    e = c[0] * d
+    check_schedule(e, 1)
+
+  def test_expand_nofuse(self):
+    a = Tensor.empty(1, 16)
+    b = Tensor.empty(1, 16)
+    c = a * b
+    d = Tensor.empty(8192, 16)
+    e = c * d
+    check_schedule(e, 2)
+
   # this is the failing case in openpilot...it's very simple like this
   @unittest.skip("failing in old lazy")
   def test_image_conv_fusion(self):
@@ -304,12 +320,18 @@ class TestSchedule(unittest.TestCase):
     check_schedule(x, 3)
 
   def test_resnet_block(self):
-    from extra.models.resnet import BasicBlock
     Tensor.training = False
-    bb = BasicBlock(64,64)
+
+    in_planes, planes = 64, 64
+    conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+    bn1 = nn.BatchNorm2d(planes)
+    conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, stride=1, bias=False)
+    bn2 = nn.BatchNorm2d(planes)
+
     x = Tensor.empty(1, 64, 32, 32)
-    out = bb(x)
+    out = bn1(conv1(x)).relu()
+    out = bn2(conv2(out))
+    out = (out + x).relu()
     check_schedule(out, 4)
 
   def test_contiguous_while_contiguous(self):
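The two new fusion tests above pin down a scheduling rule of the new lazy engine: shrinking the product keeps fusion legal because each element of a*b feeds at most one output, while broadcasting a (1,16) product against (8192,16) would, if fused, recompute the multiply once per output row, so the product is realized first, presumably to avoid that recomputation. A back-of-the-envelope check of the read amplification:

# shrink: e = (a*b)[0] * d  with a, b: (8192,16) and d: (1,16)
# expand: e = (a*b) * d     with a, b: (1,16)    and d: (8192,16)
rows = 8192
uses_per_product_element_shrink = 1     # each c element feeds <= 1 output: fuse, 1 kernel
uses_per_product_element_expand = rows  # each c element feeds 8192 outputs: realize c, 2 kernels
print(uses_per_product_element_shrink, uses_per_product_element_expand)  # -> 1 8192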

View File

@@ -115,8 +115,9 @@ class TestSymbolicReshape(unittest.TestCase):
   def test_reshape_into_symbols_bad_shape(self):
     vi = Variable("i", 1, 10).bind(4)
-    with self.assertRaises(ValueError):
-      Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77)  # reshape to a different size new shape through symbolic shape
+    # TODO: this never actually worked, it relied on lazy
+    #with self.assertRaises(ValueError):
+    #  Tensor.rand(4, 6).reshape(vi, 6).reshape(1, 77)  # reshape to a different size new shape through symbolic shape
     with self.assertRaises(AssertionError):
       Tensor.rand(3, 4).reshape(3, (vi+1))  # reshape into non-Variable Node

View File

@@ -81,7 +81,7 @@ class ContextVar:
   def __lt__(self, x): return self.value < x
 
 DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
-GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
+GRAPH, GRAPHPATH = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
 
 class Timing(contextlib.ContextDecorator):
   def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
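Turning GRAPH from a getenv snapshot into a ContextVar means it can be overridden per-scope at runtime rather than only at process start, which is presumably what lets the explicit GRAPH handling disappear from the compile script above. A minimal sketch, assuming the Context manager from this same helpers.py applies to GRAPH the way it does to DEBUG:

from tinygrad.helpers import Context, GRAPH

print(GRAPH.value)       # a ContextVar reads like an int via .value (cf. __lt__ above)
with Context(GRAPH=1):   # temporarily enable graph logging in this scope
  assert GRAPH.value == 1
# on exit, GRAPH falls back to its env-var/default value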