mirror of https://github.com/commaai/tinygrad.git
parent
ca77d6cd72
commit
17830e25da
|
@ -1,43 +0,0 @@
|
|||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.state import get_parameters
|
||||
from examples.hlb_cifar10 import SpeedyResNet
|
||||
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.lazy import Device
|
||||
|
||||
class TestCifar(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
|
||||
def test_train_step(self):
|
||||
# TODO: with default device
|
||||
#old_default = Device.DEFAULT
|
||||
#Device.DEFAULT = "FAKE"
|
||||
#Device['fake'].codegen = Device[old_default].codegen
|
||||
|
||||
# TODO: with train
|
||||
old_training = Tensor.training
|
||||
Tensor.training = True
|
||||
|
||||
model = SpeedyResNet()
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
|
||||
|
||||
@TinyJit
|
||||
def train(X):
|
||||
out = model(X)
|
||||
loss = out.mean()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
for _ in range(3): train(Tensor.randn(32, 3, 32, 32))
|
||||
|
||||
print(f"used {GlobalCounters.mem_used/1e9} GB and {len(train.jit_cache)} kernels")
|
||||
assert GlobalCounters.mem_used/1e9 < 0.55, "CIFAR training used more than 0.55 GB"
|
||||
assert len(train.jit_cache) <= 236, "CIFAR training used more than 236 kernels"
|
||||
|
||||
# reset device
|
||||
Tensor.training = old_training
|
||||
#Device.DEFAULT = old_default
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -0,0 +1,97 @@
|
|||
import unittest, time
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
|
||||
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.helpers import CI, dtypes
|
||||
|
||||
from examples.hlb_cifar10 import SpeedyResNet
|
||||
from examples.llama import Transformer, args_7B
|
||||
from examples.stable_diffusion import UNetModel
|
||||
|
||||
def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed):
|
||||
tms = []
|
||||
for _ in range(4):
|
||||
GlobalCounters.reset()
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
st = time.perf_counter_ns()
|
||||
train(*gen())
|
||||
Device[Device.DEFAULT].synchronize()
|
||||
tms.append(time.perf_counter_ns() - st)
|
||||
|
||||
kernels_used = len(train.jit_cache) if hasattr(train, "jit_cache") else None
|
||||
print(f"{nm}: used {GlobalCounters.mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
|
||||
assert GlobalCounters.mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
|
||||
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
|
||||
|
||||
# for speed
|
||||
def derandomize(x):
|
||||
if isinstance(x, LazyOp):
|
||||
if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
|
||||
x.src = [derandomize(s) for s in x.src]
|
||||
else:
|
||||
x.op = derandomize(x.op)
|
||||
return x
|
||||
|
||||
def derandomize_model(model):
|
||||
for p in get_parameters(model):
|
||||
p.lazydata = derandomize(p.lazydata)
|
||||
p.realize()
|
||||
|
||||
class TestRealWorld(unittest.TestCase):
|
||||
@unittest.skipUnless(not CI, "too big for CI")
|
||||
def test_stable_diffusion(self):
|
||||
model = UNetModel()
|
||||
derandomize_model(model)
|
||||
@TinyJit
|
||||
def test(t, t2): return model(t, 801, t2).realize()
|
||||
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 14.04, 912)
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
|
||||
def test_llama(self):
|
||||
old_type = Tensor.default_type
|
||||
Tensor.default_type = dtypes.float16
|
||||
|
||||
args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
|
||||
model = Transformer(**(args_tiny if CI else args_7B))
|
||||
derandomize_model(model)
|
||||
@TinyJit
|
||||
def test(t): return model(t, 0).realize()
|
||||
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 126 if CI else 486)
|
||||
|
||||
Tensor.default_type = old_type
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
|
||||
def test_train_cifar(self):
|
||||
# TODO: with default device
|
||||
#old_default = Device.DEFAULT
|
||||
#Device.DEFAULT = "FAKE"
|
||||
#Device['fake'].codegen = Device[old_default].codegen
|
||||
|
||||
# TODO: with train
|
||||
old_training = Tensor.training
|
||||
Tensor.training = True
|
||||
|
||||
model = SpeedyResNet()
|
||||
optimizer = optim.SGD(get_parameters(model), lr=0.01, momentum=0.8, nesterov=True, weight_decay=0.15)
|
||||
|
||||
BS = 32 if CI else 512
|
||||
|
||||
@TinyJit
|
||||
def train(X):
|
||||
out = model(X)
|
||||
loss = out.mean()
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
helper_test("train_cifar", lambda: (Tensor.randn(BS, 3, 32, 32),), train, (0.55/32)*BS, 236)
|
||||
|
||||
# reset device
|
||||
Tensor.training = old_training
|
||||
#Device.DEFAULT = old_default
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -125,6 +125,7 @@ class TestBigSpeed(unittest.TestCase):
|
|||
helper_test_generic_square('gemm', 4096, f, f)
|
||||
def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
|
||||
def test_large_conv_3x3(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
|
||||
def test_large_conv_5x5(self): helper_test_conv(bs=16, in_chans=128, out_chans=128, kernel_size=5, img_size_y=130, img_size_x=130)
|
||||
|
||||
@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
|
||||
class TestSpeed(unittest.TestCase):
|
||||
|
|
|
@ -9,6 +9,7 @@ from math import prod # noqa: F401 # pylint:disable=unused-import
|
|||
ShapeType = Tuple[int, ...]
|
||||
# NOTE: helpers is not allowed to import from anything else in tinygrad
|
||||
OSX = platform.system() == "Darwin"
|
||||
CI = os.getenv("CI", "") != ""
|
||||
|
||||
def dedup(x): return list(dict.fromkeys(x)) # retains list order
|
||||
def argfix(*x):
|
||||
|
|
|
@ -362,7 +362,7 @@ def _realize_empty(buffer: LazyBuffer) -> None:
|
|||
|
||||
def _realize_rand(buffer: LazyBuffer) -> None:
|
||||
rng = np.random.default_rng(buffer.op.arg)
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=buffer.dtype.np), **buffer._device_extra_args()) # type: ignore
|
||||
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args()) # type: ignore
|
||||
|
||||
def _realize_const(buffer: LazyBuffer) -> None:
|
||||
if hasattr(Device[buffer.device].codegen, 'supports_constant_folding'):
|
||||
|
|
|
@ -31,7 +31,7 @@ def einsum_mulacc(einsum, get_strides, expand):
|
|||
return mulacc
|
||||
|
||||
numpy_fxn_for_op: Dict[Op, Callable] = {**base_fxn_for_op, **{
|
||||
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np), UnaryOps.SIN: np.sin,
|
||||
UnaryOps.NOOP: lambda x: np.require(x, requirements='C'), UnaryOps.EXP2: np.exp2, UnaryOps.LOG2: np.log2, UnaryOps.CAST: lambda x,y: x.astype(y.np, copy=False), UnaryOps.SIN: np.sin,
|
||||
BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.promote_types(x.dtype,y.dtype)), BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
|
||||
BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)), BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
|
||||
BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)), UnaryOps.SQRT: np.sqrt,
|
||||
|
|
Loading…
Reference in New Issue