mirror of https://github.com/commaai/tinygrad.git

default threefry (#6116)

commit c100f3d406
parent 992cde05d7
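What the change amounts to, inferred from the hunks below: the THREEFRY ContextVar is deleted, the numpy-backed custom_random fallback is removed, and Tensor.rand always runs the counter-based threefry generator on the device, so the THREEFRY=1 toggle used in CI and tests goes away. A minimal usage sketch (exact values depend on this revision):

    from tinygrad import Tensor

    Tensor.manual_seed(1337)
    print(Tensor.rand(20).numpy())  # deterministic for a given seed; no THREEFRY=1 needed anymore

Kernel counts, allocated-tensor counts, and the stable diffusion reference image in the tests shift accordingly, since random numbers are now produced by threefry kernels instead of being copied in from numpy.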
@@ -238,9 +238,6 @@ jobs:
     - if: ${{ matrix.task == 'onnx' }}
       name: Test MLPerf datasets
       run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
-    - if: ${{ matrix.task == 'onnx' }}
-      name: Test THREEFRY
-      run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py --durations=20
     - if: ${{ matrix.task == 'onnx' }}
       name: Run handcode_opt
       run: PYTHONPATH=. MODEL=resnet GPU=1 DEBUG=1 BS=4 HALF=0 python3 examples/handcode_opt.py
Binary file not shown (image updated: 1.3 MiB before, 1.5 MiB after).
@@ -218,7 +218,7 @@ class StableDiffusion:
 if __name__ == "__main__":
   default_prompt = "a horse sized cat eating a bagel"
   parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
+  parser.add_argument('--steps', type=int, default=6, help="Number of steps in diffusion")
   parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to render")
   parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
@@ -287,8 +287,8 @@ if __name__ == "__main__":
   if not args.noshow: im.show()

   # validation!
-  if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5:
+  if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
     ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
     distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
-    assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
+    assert distance < 45e-5, colored(f"validation failed with {distance=}", "red")
     print(colored(f"output validated with {distance=}", "green"))
Binary file not shown (image updated: 479 KiB before, 454 KiB after).
Binary file not shown.
@@ -142,8 +142,8 @@ class TestIndexing(unittest.TestCase):
     from tinygrad.nn.datasets import mnist
     X_train, Y_train, _, _ = mnist()
     with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
+      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
       GlobalCounters.reset()
-      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
       x = X_train[samples].numpy()
       y = Y_train[samples].numpy()
     assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
@@ -22,7 +22,7 @@ class TestGC(unittest.TestCase):
     Tensor.manual_seed(0)
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     b = Tensor.rand(4, 4, requires_grad=True)
-    assert (tensors_allocated() == 3)
+    assert (tensors_allocated() == 4)
     (a*b).mean().backward()
     assert (tensors_allocated() == 5)
     del b
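The extra allocation comes from the per-device rng counter tensor that the reworked Tensor.rand keeps alive (see the Tensor.rand hunk near the end of this diff). A rough sketch of the assumption behind the new count:

    from tinygrad import Tensor

    Tensor.manual_seed(0)
    b = Tensor.rand(4, 4)
    # besides b itself, Tensor._device_rng_counters now holds a small uint32 counter
    # tensor for the default device, so one more Tensor object stays alive than before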
@@ -555,10 +555,15 @@ class TestMultiTensor(unittest.TestCase):
     # don't allow assigns that change axes
     t_none.assign(t_zero)

-  def test_rand_on_multiple_devices(self):
+  def test_rand_with_multiple_devices(self):
     with self.assertRaises(ValueError):
       Tensor.rand(256, device=devices_2)

+  def test_rand_on_multiple_devices(self):
+    d0_rand = Tensor.rand(256, device=d0).realize()
+    d1_rand = Tensor.rand(256, device=d1).realize()
+    assert not np.allclose(d0_rand.numpy(), d1_rand.numpy())
+
   def test_rand_like_on_shard(self):
     t = Tensor.empty((16, 16)).shard(devices_2)
     t2 = Tensor.rand_like(t)
@@ -591,11 +596,11 @@ class TestMultiTensor(unittest.TestCase):

   def test_dropout_on_shard_axis(self):
     with Tensor.train():
-      X = Tensor.ones(256).shard(devices_2, axis=0)
+      X = Tensor.ones(512).shard(devices_2, axis=0)
       output = X.dropout(0.5).numpy()
       unique, counts = np.unique(output, return_counts=True)
       assert set(unique) == {0, 2}, unique
-      assert 100 < counts[0] < 156, counts[0]
+      assert 228 < counts[0] < 284, counts[0]

   def test_dropout_on_uneven_shard_axis(self):
     with Tensor.train():
@@ -452,6 +452,7 @@ class TestNN(unittest.TestCase):

   def test_embedding_one_kernel(self):
     layer = Embedding(20, 30)
+    layer.weight = Tensor.zeros_like(layer.weight).contiguous()
     a = Tensor([[1, 5, 9, 11],
                 [12, 19, 8, 1]])
     result = layer(a)
@@ -4,7 +4,7 @@ from functools import partial
 import numpy as np
 import torch
 from tinygrad import nn, dtypes, Tensor, Device, TinyJit
-from tinygrad.helpers import THREEFRY, getenv, CI
+from tinygrad.helpers import getenv, CI
 from test.helpers import is_dtype_supported
 from hypothesis import given, settings, strategies as strat

@@ -75,7 +75,7 @@ class TestRandomness(unittest.TestCase):
     assert nx[nx == 0].size > 0
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

-  @unittest.skipIf(not THREEFRY.value, "not using threefry")
+  @unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
   def test_threefly_against_reference(self):
     Tensor.manual_seed(1337)
@@ -96,28 +96,28 @@ class TestRandomness(unittest.TestCase):

     np.testing.assert_allclose(jr, r)

   @unittest.skipUnless(Device.DEFAULT == "GPU", "reference is on GPU")
-  @unittest.skipIf(not THREEFRY.value, "not using threefry")
   def test_threefly_against_reference_full(self):
     Tensor.manual_seed(1337)

     # reference generated using
     """
     key0 = 1337
-    key1 = 0
+    key1 = int.from_bytes(hashlib.sha256(int(0).to_bytes(4)).digest(), "big") & 0xffffffff
     values = jax.extend.random.threefry_2x32((np.uint32(key1), np.uint32(key0)), np.arange(20, dtype=np.uint32))
     values = (values >> (32 - 23)) | np.array(1, dtype=np.float32).view(np.uint32)
     values = values.view(np.float32) - 1
     print(f"[{', '.join(f'{v}' for v in values)}]")
     """
-    jr = np.array([0.7882130146026611, 0.0680311918258667, 0.6758031845092773, 0.2525523900985718, 0.5712389945983887,
-                   0.8758237361907959, 0.13559412956237793, 0.9069793224334717, 0.8781528472900391, 0.7737162113189697,
-                   0.050452232360839844, 0.1645597219467163, 0.06776463985443115, 0.09560704231262207, 0.2754603624343872,
-                   0.10108339786529541, 0.3488548994064331, 0.7904064655303955, 0.2519160509109497, 0.7925788164138794], dtype=np.float32)
+    jr = np.array([0.9073467254638672, 0.8235964775085449, 0.6872662305831909, 0.9920015335083008, 0.4941047430038452,
+                   0.3108327388763428, 0.09639489650726318, 0.004686474800109863, 0.8435229063034058, 0.824237585067749,
+                   0.5873836278915405, 0.4232727289199829, 0.2530076503753662, 0.40300023555755615, 0.03966474533081055,
+                   0.27904558181762695, 0.9150195121765137, 0.48057758808135986, 0.23821306228637695, 0.7676635980606079], dtype=np.float32)

     r = Tensor.rand(20).numpy()

     np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)

-  @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
+  @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
   def test_threefly_tensors_cnt(self):
     Tensor.manual_seed(1337)
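The docstring in the hunk above records how the reference array was generated with JAX. For illustration only, here is a standalone NumPy sketch of the same computation, assuming the standard Threefry-2x32 round structure that jax.extend.random.threefry_2x32 implements (rotation constants 13/15/26/6 and 17/29/16/24, key-schedule constant 0x1BD11BDA, 20 rounds); it is not part of the commit:

    import hashlib
    import numpy as np

    def threefry_2x32(keypair, counts):
      # counts: uint32 array of even length, split into the two 32-bit lanes
      k0, k1 = np.uint32(keypair[0]), np.uint32(keypair[1])
      ks = (k0, k1, np.uint32(int(k0) ^ int(k1) ^ 0x1BD11BDA))
      rots = ((13, 15, 26, 6), (17, 29, 16, 24))
      x0, x1 = np.split(counts.astype(np.uint32), 2)
      x0, x1 = x0 + ks[0], x1 + ks[1]
      for i in range(5):
        for r in rots[i % 2]:
          x0 = x0 + x1                                                     # add, wrapping mod 2**32
          x1 = ((x1 << np.uint32(r)) | (x1 >> np.uint32(32 - r))) ^ x0    # rotate-left, then xor
        x0 = x0 + ks[(i + 1) % 3]
        x1 = x1 + ks[(i + 2) % 3] + np.uint32(i + 1)
      return np.concatenate([x0, x1])

    key0 = 1337                                                            # the manual seed
    key1 = int.from_bytes(hashlib.sha256(int(0).to_bytes(4, "big")).digest(), "big") & 0xffffffff
    bits = threefry_2x32((np.uint32(key1), np.uint32(key0)), np.arange(20, dtype=np.uint32))
    # keep the top 23 bits as a float32 mantissa in [1, 2), then shift down to [0, 1)
    floats = ((bits >> np.uint32(9)) | np.array(1, dtype=np.float32).view(np.uint32)).view(np.float32) - 1

Under these assumptions, floats should match the jr reference values checked by the test.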
@@ -141,10 +141,9 @@ class TestRandomness(unittest.TestCase):
     N = 128
     x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
     assert x.dtype == dtypes.bfloat16
-    if THREEFRY.value:
-      nx = x.numpy()
-      assert nx[nx == 1].size == 0
-      assert nx[nx == 0].size > 0
+    nx = x.numpy()
+    assert nx[nx == 1].size == 0
+    assert nx[nx == 0].size > 0
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))

   def test_rand_like(self):
@@ -47,13 +47,16 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
   l.linearize()
   return sched

+def _realize_weights(m):
+  for p in nn.state.get_parameters(m): p.realize()
+
 def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
-  dtypes.default_float = dtype
+  old_default_float, dtypes.default_float = dtypes.default_float, dtype
   Tensor.manual_seed(0)
   BS, CIN = 2, 3
-  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True)
-  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True)
+  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True).realize()
+  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
   ret = Tensor.conv2d(img, w).relu().mean().backward()
+  dtypes.default_float = old_default_float
   with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
@@ -256,12 +259,13 @@ class TestSchedule(unittest.TestCase):

   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
+    for optim, cnt in [(nn.optim.Adam, 17), (nn.optim.SGD, 15)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)
           c1 = nn.Conv2d(3,32,3)
           bn = nn.BatchNorm2d(32, track_running_stats=False)
+          _realize_weights([c1, bn])
           opt = optim(nn.state.get_parameters([c1, bn]))
           img_bn = bn(c1(img)).elu().sum()
           opt.zero_grad()
@@ -919,57 +923,63 @@ class TestSchedule(unittest.TestCase):
     with Tensor.train():
       x = Tensor.empty(4, 64, 768)
       layer = nn.Linear(768, 768*4)
+      _realize_weights(layer)
       opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
       layer(x).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 11)
+      check_schedule(opt.schedule_step(), 9)

   def test_adam_conv_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
+      _realize_weights(c1)
       opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 11)
+      check_schedule(opt.schedule_step(), 9)

   def test_adam_2convs_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2])
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 13)
+      check_schedule(opt.schedule_step(), 12)

   def test_sgd_conv_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
+      _realize_weights(c1)
       opt = nn.optim.SGD(nn.state.get_parameters(c1))
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 7)
+      check_schedule(opt.schedule_step(), 5)

   def test_sgd_2convs_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2])
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 7)
+      check_schedule(opt.schedule_step(), 6)

   def test_fold_2convs_sgd_nesterov_momentum_wd(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2])
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 9)
+      check_schedule(opt.schedule_step(), 8)

   def test_sgd_4convs_fuse(self):
     with Tensor.train():
@@ -978,10 +988,11 @@ class TestSchedule(unittest.TestCase):
       c2 = nn.Conv2d(4,8,3,bias=False)
       c3 = nn.Conv2d(8,16,3,bias=False)
       c4 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2, c3, c4])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 22)
+      check_schedule(opt.schedule_step(), 18)

   def test_sgd_4convs_fuse_conv_bw(self):
     with Tensor.train():
@@ -990,10 +1001,11 @@ class TestSchedule(unittest.TestCase):
       c2 = nn.Conv2d(4,8,3,bias=False)
       c3 = nn.Conv2d(8,16,3,bias=False)
       c4 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2, c3, c4])
      opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
      opt.zero_grad()
      c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 19)
+      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15)

   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
@@ -1184,7 +1196,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.rand(3, 4, 5).realize()
     out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
     run_schedule(check_schedule(out, 1))
-    np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
+    np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)

   # multireduce spec
   def test_multireduce_pad_reduce_safe(self):
@@ -1202,7 +1214,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.rand(3, 4, 5).realize()
     out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
     run_schedule(check_schedule(out, 2))
-    np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), rtol=1e-6)
+    np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)

   # multireduce spec
   def test_multireduce_pad_reduce_unsafe(self):
@@ -1213,7 +1225,7 @@ class TestSchedule(unittest.TestCase):
     # run_schedule(check_schedule(out, 1))
     run_schedule(check_schedule(out, 4))
     np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
-      b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-6)
+      b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-6)

   def test_shrink_pad_safe(self):
     a = Tensor.ones((3, )).contiguous().realize()
@@ -1301,18 +1313,18 @@ class TestSchedule(unittest.TestCase):
     out = x.argmax(1)
     run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape

-  def test_conv2d(self): _test_conv2d(8)
-  def test_conv2d_fused(self): _test_conv2d(7, FUSE_CONV_BW=1)
-  def test_conv2d_fused_ast_rewrite(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)
+  def test_conv2d(self): _test_conv2d(7)
+  def test_conv2d_fused(self): _test_conv2d(6, FUSE_CONV_BW=1)
+  def test_conv2d_fused_ast_rewrite(self): _test_conv2d(6, FUSE_CONV_BW=1, AST_REWRITE=1)

   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-  def test_conv2d_half(self): _test_conv2d(8, dtype=dtypes.half)
+  def test_conv2d_half(self): _test_conv2d(7, dtype=dtypes.half)
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @unittest.expectedFailure
-  def test_conv2d_fused_half(self): _test_conv2d(7, dtype=dtypes.half)
+  def test_conv2d_fused_half(self): _test_conv2d(5, dtype=dtypes.half)
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @unittest.expectedFailure
-  def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)
+  def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(6, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)

   def test_buf_cnt_at_limit(self): _test_buf_cnt(5, buf_max=5, allowed=1)
   @unittest.expectedFailure
@@ -1395,18 +1407,18 @@ class TestIndexing(unittest.TestCase):

   def test_arange_transposed(self):
     Tensor.manual_seed(0)
-    x = Tensor.randint(4, 1)
+    x = Tensor.randint(4, 1).realize()
     a = (Tensor.arange(4,)*x).T
-    self.check_schedule(a, 2)
+    self.check_schedule(a, 1)
     np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T)

   def test_arange_transposed_descendants(self):
     Tensor.manual_seed(0)
-    x = Tensor.randint(4, 1)
+    x = Tensor.randint(4, 1).realize()
     a = (Tensor.arange(4,)*x).T
     b = Tensor.randint(4, 4).realize()
     out = a+b
-    self.check_schedule(out, 2)
+    self.check_schedule(out, 1)
     np.testing.assert_equal(out.numpy(), (np.arange(4)*x.numpy()).T+b.numpy())

   def test_arange_index(self):
@@ -1415,7 +1427,7 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.arange(10)
     out = (x + a[2]).sum()
     self.check_schedule(out, 1)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_index_contiguous(self):
     Tensor.manual_seed(0)

@@ -1423,7 +1435,7 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.arange(10).contiguous()
     out = (x + a[2]).sum()
     self.check_schedule(out, 2)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_index_child(self):
     Tensor.manual_seed(0)

@@ -1431,7 +1443,7 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.arange(10)+1
     out = (x + a[2]).sum()
     self.check_schedule(out, 1)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_index_contiguous_child(self):
     Tensor.manual_seed(0)

@@ -1439,7 +1451,7 @@ class TestIndexing(unittest.TestCase):
     a = (Tensor.arange(10)+1).contiguous()
     out = (x + a[2]).sum()
     self.check_schedule(out, 2)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_childless_base(self):
     a = Tensor.arange(4)

@@ -1453,7 +1465,7 @@ class TestIndexing(unittest.TestCase):

   def test_arange_group_childless_base(self):
     Tensor.manual_seed(0)
-    x = Tensor.randint(4)
+    x = Tensor.randint(4).realize()
     a = Tensor.arange(4)+x
     self.check_schedule(a, 1)
     np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy())
@@ -1527,8 +1539,8 @@ class TestIndexing(unittest.TestCase):
     from tinygrad.nn.datasets import mnist
     import torch
     _, Y_train, _, _ = mnist()
-    samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1]))
-    yt = Tensor.randn(BS, 10)
+    samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize()
+    yt = Tensor.randn(BS, 10).realize()
     with Context(SPLIT_REDUCEOP=0):
       loss = yt.sparse_categorical_crossentropy(Y_train[samples])
       self.check_schedule(loss, 6)
@@ -32,7 +32,7 @@ class TestTimeLinearizer(unittest.TestCase):
     assert all(r.size > 0 for r in rawbufs)

   def test_bufs_from_lin_alt(self):
-    a = Tensor.randn(4, 4)
+    a = Tensor.randn(4, 4).realize()
     b = a+a[0]
     si = [si for si in b.schedule() if si.ast.op is UOps.SINK][0]
     rawbufs = bufs_from_lin(k:=Kernel(si.ast))
@@ -105,7 +105,7 @@ class ContextVar:
   def __lt__(self, x): return self.value < x

 DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
-WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
+WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
 MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
 USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
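For orientation, a rough sketch of how these ContextVars behave in the tests elsewhere in this diff; treat the exact semantics as an assumption about tinygrad.helpers rather than a spec. Each one reads its environment variable once and can be overridden inside a Context block:

    from tinygrad.helpers import Context, DEBUG

    print(DEBUG.value)      # whatever DEBUG was in the environment, default 0
    with Context(DEBUG=2):
      print(DEBUG.value)    # temporarily 2 inside the block
    print(DEBUG.value)      # restored on exit

Removing THREEFRY from this list is what makes the THREEFRY=1 environment variable a no-op.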
@@ -8,7 +8,7 @@ import numpy as np

 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
+from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import MetaOps, truncate
@@ -438,26 +438,20 @@ class Tensor:
     if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
     if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
     if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
-    device, had_counter = Device.canonicalize(device), False
+    _device = device = Device.canonicalize(device)

     # when using MOCKGPU and NV generate rand on CLANG
-    if THREEFRY and getenv("MOCKGPU") and device.startswith("NV"): _device, device = device, "CLANG"
-    else: _device = None
+    if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"

     # generate per device seeds and rng counter if we haven't seen this device yet
     if device not in Tensor._device_seeds:
-      Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(device.encode()).digest(), "big") & 0xffffffff
+      Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
       Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
+      had_counter = False
     else: had_counter = True

-    if not THREEFRY:
-      # for bfloat16, numpy rand passes buffer in float
-      if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
-        return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
-      return Tensor._metaop(MetaOps.CUSTOM, shape, arg=custom_random, device=device, dtype=dtype, **kwargs)
-
     # if shape has 0, return zero tensor
-    if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
+    if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)

     # increment rng counter for devices
     if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
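The seed derivation above no longer hashes the device name; each new device gets a seed from the sha256 of its registration index, and a per-device uint32 counter then advances by the number of 32-bit words generated. A small sketch of just the seed part, mirroring the line added above:

    import hashlib

    def device_seed(index: int) -> int:
      # index is the order in which the device was first seen by Tensor.rand
      return int.from_bytes(hashlib.sha256(index.to_bytes(4, "big")).digest(), "big") & 0xffffffff

    print(device_seed(0), device_seed(1))  # the first two devices seen get distinct seeds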
@@ -3462,14 +3456,6 @@ if IMAGE:
   setattr(Tensor, "conv2d", Tensor.image_conv2d)
   setattr(Tensor, "dot", Tensor.image_dot)

-# TODO: eventually remove this
-def custom_random(out:Buffer):
-  Tensor._seed += 1
-  rng = np.random.default_rng(Tensor._seed)
-  if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
-  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
-  out.copyin(rng_np_buffer.data)
-
 def _metadata_wrapper(fn):
   def _wrapper(*args, **kwargs):
     if _METADATA.get() is not None: return fn(*args, **kwargs)
@@ -42,8 +42,8 @@ class TestViz(unittest.TestCase):

   def test_ctx_groups(self):
     contexts.clear()
-    schedule1 = Tensor.randn(4, 1).contiguous().schedule()
-    schedule2 = Tensor.randn(4, 4).contiguous().schedule()
+    schedule1 = Tensor.zeros(4, 1).contiguous().exp().schedule()
+    schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
     list(lower_schedule(schedule1))
     list(lower_schedule(schedule2))
     ret = load_kernels(contexts)
@@ -118,8 +118,8 @@ class TestViz(unittest.TestCase):

   def test_dedup_ast(self):
     contexts.clear()
-    a = Tensor.randn(4, 4)+2
-    b = Tensor.randn(4, 4)+2
+    a = Tensor.empty(4, 4).contiguous().realize()+2
+    b = Tensor.empty(4, 4).contiguous().realize()+2
     Tensor.schedule(a, b)
     kernels = load_kernels(contexts)
     self.assertEqual(len(kernels), 1)