tinygrad/test/test_pickle.py

61 lines
1.8 KiB
Python

import unittest, pickle
import numpy as np
from tinygrad import Tensor, TinyJit
from tinygrad.engine.schedule import create_schedule
class TestPickle(unittest.TestCase):
  """Round-trip tinygrad objects (tensors, buffer views, TinyJit, schedules) through pickle."""

  def test_pickle_realized_tensor(self):
    """A realized tensor unpickles to the same values."""
    t = Tensor.rand(10, 10).realize()
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_unrealized_tensor(self):
    """A tensor that was never explicitly realized can still be pickled."""
    t = Tensor.ones(10, 10)
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_buffer_view(self):
    """A sliced tensor whose buffer exposes `base` keeps that property and its values after a round-trip."""
    t = Tensor.arange(10, device="CLANG").contiguous().realize()
    vt = t[3:5].contiguous().realize()
    # the test expects the sliced copy to be backed by a buffer exposing `base` (i.e. a view)
    self.assertTrue(hasattr(vt.lazydata.buffer, 'base'))
    ref_value = vt.tolist()
    st = pickle.dumps(vt)
    # drop the original tensors before unpickling so the restored view is exercised on its own
    del t, vt
    vt2 = pickle.loads(st)
    self.assertTrue(hasattr(vt2.lazydata.buffer, 'base'))
    self.assertEqual(vt2.tolist(), ref_value)

  def test_pickle_numpy(self):
    """A tensor constructed from a numpy array round-trips with the same values."""
    t = Tensor(np.array([1,2,3,4.]))
    st = pickle.dumps(t)
    t2:Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_jit(self):
    """A TinyJit pickled after warm-up still computes a+b+1 once unpickled."""
    @TinyJit
    def add(a, b): return a+b+1
    # warm the JIT up with several calls before pickling it
    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
    # pickling the JIT requires the wrapped function to be deleted; presumably because `add`
    # is a local function, which pickle cannot serialize by reference
    del add.fxn
    st = pickle.dumps(add)
    # drop the original so only the unpickled copy is used below
    del add
    add_fxn = pickle.loads(st)
    x = Tensor.ones(10, 10).contiguous().realize()
    y = Tensor.ones(10, 10).contiguous().realize()
    out = add_fxn(x, y)
    np.testing.assert_equal(out.numpy(), 3)  # 1 + 1 + 1

  def test_pickle_schedule(self):
    """A schedule from create_schedule round-trips with its length and last item's `ast` intact."""
    a = Tensor([1,2])
    out = a + 2
    sched = create_schedule([out.lazydata])
    pk = pickle.dumps(sched)
    sched_pk = pickle.loads(pk)
    self.assertEqual(len(sched_pk), len(sched))
    self.assertEqual(sched_pk[-1].ast, sched[-1].ast)
# allow running this test file directly as a script
if __name__ == '__main__': unittest.main()