#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
import pytest
pytestmark = pytest.mark.webgpu
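
# NOTE: TinyJit roughly works in three phases: the first call runs eagerly, the
# second call records the compiled kernels, and later calls replay them with the
# new input buffers. The 5-iteration loops below exercise all three phases.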

# NOTE: METAL fails; this might be platform and optimization option dependent.
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
  def test_simple_jit(self):
    @TinyJit
    def add(a, b): return (a+b).realize()
    for _ in range(5):
      a = Tensor.randn(10, 10)
      b = Tensor.randn(10, 10)
      c = add(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(add.jit_cache) == 1
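    # ((a+b).realize() should compile to a single fused kernel, hence the one cache entry)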
  def test_jit_multiple_outputs(self):
    @TinyJit
    def f(a, b): return (a+b).realize(), (a-b).realize(), (a*b).realize()
    for _ in range(5):
      a = Tensor.randn(10, 10)
      b = Tensor.randn(10, 10)
      c, d, e = f(a, b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(d.numpy(), a.numpy()-b.numpy(), atol=1e-4, rtol=1e-5)
      np.testing.assert_allclose(e.numpy(), a.numpy()*b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(f.jit_cache) == 3
  def test_nothing_jitted(self):
    @TinyJit
    def add(a, b): return a+b
    with self.assertRaises(AssertionError):
      for _ in range(5):
        a = Tensor.randn(10, 10)
        b = Tensor.randn(10, 10)
        c = add(a, b)
  def test_jit_shape_mismatch(self):
    @TinyJit
    def add(a, b): return (a+b).realize()
    for _ in range(5):
      a = Tensor.randn(10, 10)
      b = Tensor.randn(10, 10)
      c = add(a, b)
    bad = Tensor.randn(20, 20)
    with self.assertRaises(AssertionError):
      add(a, bad)
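    # (a replayed call must match the input shapes recorded when the jit captured the kernels)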
  def test_jit_shape_views_mismatch(self):
    @TinyJit
    def add(a): return (a+1).realize()
    with self.assertRaises(AssertionError):
      for i in range(1, 5):
        # a has an offset that the kernel doesn't know about
        a = Tensor.randn(10, 10).realize()[:, i:i+2]
        add(a)
  def test_jit_duplicate_fail(self):
    # the jit doesn't support duplicate arguments
    @TinyJit
    def add(a, b): return (a+b).realize()
    a = Tensor.randn(10, 10)
    with self.assertRaises(AssertionError):
      add(a, a)
  def test_kwargs_jit(self):
    @TinyJit
    def add_kwargs(first, second): return (first+second).realize()
    for _ in range(5):
      a = Tensor.randn(10, 10)
      b = Tensor.randn(10, 10)
      c = add_kwargs(first=a, second=b)
      np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(add_kwargs.jit_cache) == 1
  def test_array_jit(self):
    @TinyJit
    def add_array(a, arr): return (a+arr[0]).realize()
    for i in range(5):
      a = Tensor.randn(10, 10)
      b = Tensor.randn(10, 10)
      a.realize(), b.realize()
      c = add_array(a, [b])
      if i >= 2:
        # should fail once jitted: the jit can't track Tensors inside a list, so the
        # replayed kernel keeps reading the buffer captured on the recording call
        assert not np.allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
      else:
        np.testing.assert_allclose(c.numpy(), a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(add_array.jit_cache) == 1
  def test_method_jit(self):
    class Fun:
      def __init__(self):
        self.a = Tensor.randn(10, 10)
      @TinyJit
      def __call__(self, b:Tensor) -> Tensor:
        return (self.a+b).realize()
    fun = Fun()
    for _ in range(5):
      b = Tensor.randn(10, 10)
      c = fun(b)
      np.testing.assert_allclose(c.numpy(), fun.a.numpy()+b.numpy(), atol=1e-4, rtol=1e-5)
    assert len(fun.__call__.func.__self__.jit_cache) == 1
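    # (TinyJit defines __get__, so fun.__call__ is a functools.partial; .func.__self__
    #  recovers the underlying TinyJit instance)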
  def test_jit_size1_input(self):
    @TinyJit
    def f(a, b): return (a+b).realize()
    a = Tensor([1, 2, 3])
    for i in range(5):
      np.testing.assert_allclose(f(a, Tensor([i])).numpy(), (a+i).numpy(), atol=1e-4, rtol=1e-5)
    assert len(f.jit_cache) == 1
  def test_jit_output_non_tensor_fail(self):
    @TinyJit
    def f(a, b, i): return (a+b).realize(), i
    output1, output2 = [], []
    expect1, expect2 = [], []
    for i in range(5):
      a = Tensor.randn(10, 10)
      b = Tensor.randn(10, 10)
      o1, o2 = f(a, b, i)
      output1.append(o1.numpy().copy())
      output2.append(o2)
      expect1.append(a.numpy().copy()+b.numpy().copy())
      expect2.append(i)
    np.testing.assert_allclose(output1, expect1, atol=1e-4, rtol=1e-5)
    # the jit only works with Tensor outputs: once it replays, the function body
    # no longer runs, so the int output stays frozen at the value it had during capture
    assert output2 != expect2
    assert len(f.jit_cache) == 1
  @unittest.skip("random isn't working in JIT")
  def test_jit_random_regen(self):
    def f(a, b):
      rn = Tensor.randn(*a.shape)
      return ((a+b)*rn).realize()
    a = Tensor.randn(10, 10)
    b = Tensor.randn(10, 10)

    Tensor._seed = 1234
    jf = TinyJit(f)
    res = set()
    for _ in range(5):
      o1 = jf(a, b)
      res.add(o1.numpy()[0][0])
    assert len(res) == 5, "All values should be different; rand should regenerate inside the jit"

    Tensor._seed = 1234
    jf2 = TinyJit(f)
    res2 = set()
    for _ in range(5):
      o1 = jf2(a, b)
      res2.add(o1.numpy()[0][0])
    assert len(res2) == 5, "All values should be different; rand should regenerate inside the jit"
    assert res == res2, "Jit rand should be reproducible with the same seed"

    Tensor._seed = 3421
    jf3 = TinyJit(f)
    res3 = set()
    for _ in range(5):
      o1 = jf3(a, b)
      res3.add(o1.numpy()[0][0])
    assert len(res3) == 5, "All values should be different; rand should regenerate inside the jit"
    assert res3 != res2, "Jit rand should differ with different seeds"
if __name__ == '__main__':
  unittest.main()