2024-03-19 08:01:03 +08:00
|
|
|
import unittest
|
2024-07-12 07:41:51 +08:00
|
|
|
import numpy as np
|
|
|
|
from tinygrad import Tensor, GlobalCounters, dtypes
|
2024-07-12 10:00:57 +08:00
|
|
|
from tinygrad.helpers import Context, getenv
|
2024-07-12 07:41:51 +08:00
|
|
|
from tinygrad.engine.realize import run_schedule
|
2024-03-19 08:01:03 +08:00
|
|
|
|
|
|
|
class TestArange(unittest.TestCase):
    # Guards the cost of Tensor.arange: the op counter must not blow up as the length grows.

    def _get_flops(self, N):
        # Start from a clean counter, realize an arange of length N with NOOPT=1,
        # then report whatever GlobalCounters.global_ops accumulated.
        GlobalCounters.reset()
        with Context(NOOPT=1):
            Tensor.arange(N).realize()
        counted_ops = GlobalCounters.global_ops
        return counted_ops

    def test_complexity(self):
        # Compare the op count at length 256 against length 2560 (10X the input).
        # add 1 to avoid divide by 0. arange is 0 flops now!
        # NOTE: the names f1/f2 are printed verbatim by the `=` f-string specifier below.
        f1 = 1 + self._get_flops(256)
        f2 = 1 + self._get_flops(2560)
        print(f"{f1=}, {f2=}")
        growth = f2 / f1
        assert growth < 15, f"bad complexity, flops {growth:.1f}X while inputs 10X"
|
|
|
|
|
2024-07-12 07:41:51 +08:00
|
|
|
class TestIndexing(unittest.TestCase):
    # Indexing expressed through arange-style ranges: checks results against numpy and,
    # where asserted, the number of scheduled items and the global op count.

    def _indexing_fixture(self):
        # Shared setup for the index tests: a realized random (16384, 256) dataset,
        # a realized index tensor, and the numpy fancy-indexing result used as reference.
        dataset = Tensor.rand(16384, 256).realize()
        idxs = Tensor([0,3,5,6]).realize()
        real_index = dataset.numpy()[idxs.numpy()]
        print("*** indexing ***")
        return dataset, idxs, real_index

    def test_arange_2_reduce(self):
        # An all-zero int vector with a single 1 at position 1337; multiplying it by a
        # range and summing is expected to give back 1337 from a single scheduled item.
        needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
        needle[1337] = 1
        needle.realize()
        with Context(NOOPT=1, FUSE_ARANGE=1):
            GlobalCounters.reset()
            # TODO: it should work without these reshapes
            positions = Tensor.arange(1,16385).reshape(16384,1) - 1
            out = (positions * needle.reshape(16384,1)).sum()
            sched = out.schedule()
            assert len(sched) == 1
            run_schedule(sched)
        found = out.item()
        assert found == 1337, f"expected 1337, got {found}"

    @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
    def test_manual_index(self):
        # Hand-written gather: compare a cumsum-built range against the broadcast indices,
        # keep the matching dataset entries via where(), and sum out axes 2 and 3.
        dataset, idxs, real_index = self._indexing_fixture()
        expanded_shape = (4, 256, 16384, 1)
        with Context(NOOPT=1, FUSE_ARANGE=1):
            GlobalCounters.reset()
            ones = Tensor.ones(4, 256, 16384, dtype=dtypes.int)
            rng = ones._cumsum(axis=-1, _first_zero=True).reshape(*expanded_shape)
            wanted = idxs.reshape(4,1,1,1).expand(*expanded_shape)
            tiled_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(*expanded_shape)
            selected = (rng==wanted).where(tiled_dataset, Tensor.zeros(*expanded_shape))
            X = selected.sum(axis=(2,3))
            sched = X.schedule()
            assert len(sched) == 1
            run_schedule(sched)
            ops_used = GlobalCounters.global_ops
            assert ops_used < 4*16384, f"too many ops {ops_used}"
        np.testing.assert_allclose(real_index, X.numpy())

    def test_index(self):
        # Plain tensor indexing without FUSE_ARANGE: only shape and values are checked.
        dataset, idxs, real_index = self._indexing_fixture()
        with Context(NOOPT=1):
            GlobalCounters.reset()
            X = dataset[idxs]
            assert X.shape == (4,256)
            sched = X.schedule()
            # TODO: enable these asserts when the scheduler can handle this
            #assert len(sched) == 1, f"{len(sched)} != 1"
            run_schedule(sched)
            #assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
        np.testing.assert_allclose(real_index, X.numpy())

    def test_index_fused(self):
        # Same indexing with FUSE_ARANGE=1: expects exactly two scheduled items and an
        # op count below 4*16384, then checks the values against numpy.
        dataset, idxs, real_index = self._indexing_fixture()
        op_budget = 4*16384
        with Context(NOOPT=1, FUSE_ARANGE=1):
            GlobalCounters.reset()
            X = dataset[idxs]
            assert X.shape == (4,256)
            sched = X.schedule()
            assert len(sched) == 2
            run_schedule(sched)
            ops_used = GlobalCounters.global_ops
            assert ops_used < op_budget, f"too many ops {ops_used} != {op_budget}"
        np.testing.assert_allclose(real_index, X.numpy())
|
|
|
|
|
2024-05-03 10:34:30 +08:00
|
|
|
# Running this file directly hands control to unittest's command-line runner;
# importing it as a module only defines the test classes.
if __name__ == "__main__":
    unittest.main()
|