# Mirror of https://github.com/commaai/tinygrad.git (Python, 60 lines, 1.8 KiB)
import unittest
|
|
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
|
import numpy as np
|
|
|
|
class TestSetitem(unittest.TestCase):
  """Slice-assignment tests for Tensor, each checked against the same assignment done in numpy."""

  def test_simple_setitem(self):
    # Assign a Tensor into a (2, 2) window of a realized 6x6 zero tensor.
    t = Tensor.zeros(6, 6).contiguous().realize()
    t[2:4, 3:5] = Tensor.ones(2, 2)
    n = np.zeros((6, 6))
    n[2:4, 3:5] = np.ones((2, 2))
    np.testing.assert_allclose(t.numpy(), n)

    # Same window, but the assigned value is a plain Python scalar.
    t = Tensor.zeros(6, 6).contiguous().realize()
    t[2:4, 3:5] = 1.0
    n = np.zeros((6, 6))
    n[2:4, 3:5] = 1.0
    np.testing.assert_allclose(t.numpy(), n)

  def test_setitem_into_unrealized(self):
    # The target is built from arange/reshape with no explicit realize() before the assignment.
    t = Tensor.arange(4).reshape(2, 2)
    t[1] = 5
    np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])

  def test_setitem_into_noncontiguous(self):
    t = Tensor.ones(4)
    # Precondition of the test: the target must report a non-contiguous view.
    # A unittest assertion is used so the check is not stripped under `python -O`.
    self.assertFalse(t.lazydata.st.contiguous)
    # Assigning into such a tensor is expected to be rejected with AssertionError.
    with self.assertRaises(AssertionError):
      t[1] = 5

  def test_simple_jit_setitem(self):
    # The slice assignment is performed inside a TinyJit-wrapped function.
    @TinyJit
    def f(t:Tensor, a:Tensor):
      t[2:4, 3:5] = a

    # Call the wrapped function repeatedly, each time with a fresh target and a different fill value.
    for i in range(1, 6):
      t = Tensor.zeros(6, 6).contiguous().realize()
      a = Tensor.full((2, 2), fill_value=i, dtype=dtypes.float).contiguous()
      f(t, a)

      n = np.zeros((6, 6))
      n[2:4, 3:5] = np.full((2, 2), i)
      np.testing.assert_allclose(t.numpy(), n)

  def test_jit_setitem_variable_offset(self):
    # The row being written is selected by a bound Variable rather than a Python int.
    @TinyJit
    def f(t:Tensor, a:Tensor, v:Variable):
      t.shrink(((v,v+1), None)).assign(a).realize()

    # One target tensor is reused across all calls, so rows accumulate.
    t = Tensor.zeros(6, 6).contiguous().realize()
    n = np.zeros((6, 6))

    for i in range(6):
      v = Variable("v", 0, 6).bind(i)
      a = Tensor.full((1, 6), fill_value=i+1, dtype=dtypes.float).contiguous()
      n[i, :] = i+1  # numpy reference for the write performed by f
      f(t, a, v)
      # Check after every call so a wrong offset is caught at the iteration that caused it.
      np.testing.assert_allclose(t.numpy(), n)

    # After all six calls, row r holds the value r+1 in every column.
    np.testing.assert_allclose(t.numpy(), [[row + 1] * 6 for row in range(6)])
|
|
|
|
# Run the tests in this module when the file is executed directly.
if __name__ == '__main__':
  unittest.main()