2024-09-30 21:54:46 +08:00
|
|
|
import unittest, pickle, types
|
2024-03-18 09:29:14 +08:00
|
|
|
import numpy as np
|
2024-09-05 13:36:36 +08:00
|
|
|
from test.helpers import assert_equiv_uops
|
2024-09-30 21:54:46 +08:00
|
|
|
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
2024-04-10 04:47:25 +08:00
|
|
|
from tinygrad.engine.schedule import create_schedule
|
2024-09-30 21:54:46 +08:00
|
|
|
from tinygrad.ops import PatternMatcher, UPat, UOp
|
2024-03-18 09:29:14 +08:00
|
|
|
|
2024-09-05 13:36:36 +08:00
|
|
|
class TestPickle(unittest.TestCase):
  """Round-trip tinygrad objects through ``pickle`` and check the restored copies."""

  def test_pickle_code_object(self):
    # a bare code object is pickled, restored, and rewrapped into a callable function
    y = lambda x: x*2 # noqa: E731
    code_str = pickle.dumps(y.__code__)
    fxn = types.FunctionType(pickle.loads(code_str), globals())
    self.assertEqual(fxn(2), 4)

  def test_pickle_pattern_matcher(self):
    # the restored matcher must rewrite the same input to a UOp with the same key
    pm = PatternMatcher([(UPat.cvar('x'), lambda x: x*2)])
    sink = UOp.const(dtypes.int, 2)
    tt = pm.rewrite(sink)
    pm_str = pickle.dumps(pm)
    pm2 = pickle.loads(pm_str)
    self.assertEqual(pm2.rewrite(sink).key, tt.key)

  def test_pickle_main_pattern_matcher(self):
    # smoke test: only checks that dumping `sym` does not raise
    from tinygrad.codegen.uopgraph import sym
    pickle.dumps(sym)

  def test_pickle_realized_tensor(self):
    t = Tensor.rand(10, 10).realize()
    st = pickle.dumps(t)
    t2: Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_unrealized_tensor(self):
    # same round trip as the realized case, but .realize() is never called before dumping
    t = Tensor.ones(10, 10)
    st = pickle.dumps(t)
    t2: Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_variable(self):
    # a tensor whose shape contains a bound Variable is pickled before it is evaluated
    v = Variable("i", 1, 20).bind(10)
    t1 = Tensor.ones(10, v).contiguous()
    t2 = Tensor.ones(10, v).contiguous()
    ret = (t1+t2).sum(1)
    st = pickle.dumps(ret)
    # presumably dropped so the restored tensor cannot lean on the original object
    del ret
    vt2 = pickle.loads(st)
    # (1+1) summed along the axis bound to 10 gives 20
    np.testing.assert_equal(vt2.numpy(), 20)

  def test_pickle_buffer_view(self):
    t = Tensor.arange(10, device="CLANG").contiguous().realize()
    vt = t[3:5].contiguous().realize()
    # unittest assertions instead of bare `assert`: those are stripped under `python -O`
    self.assertTrue(hasattr(vt.lazydata.buffer, 'base'))
    ref_value = vt.tolist()
    st = pickle.dumps(vt)
    # remove both the source tensor and the slice before loading the pickled copy
    del t, vt
    vt2 = pickle.loads(st)
    self.assertTrue(hasattr(vt2.lazydata.buffer, 'base'))
    self.assertEqual(ref_value, vt2.tolist())

  def test_pickle_numpy(self):
    # tensor constructed directly from a numpy array
    t = Tensor(np.array([1,2,3,4.]))
    st = pickle.dumps(t)
    t2: Tensor = pickle.loads(st)
    np.testing.assert_equal(t.numpy(), t2.numpy())

  def test_pickle_jit(self):
    @TinyJit
    def add(a, b): return a.sum()+b+1
    # call the jitted function three times before pickling it
    for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
    st = pickle.dumps(add)
    # only the unpickled copy is used from here on
    del add

    add_fxn = pickle.loads(st)

    x = Tensor.ones(10, 10).contiguous().realize()
    y = Tensor.ones(10, 10).contiguous().realize()
    print("post jit")
    out = add_fxn(x, y)
    # sum of 100 ones, plus 1 from y, plus the constant 1
    np.testing.assert_equal(out.numpy(), 102)

  def test_pickle_schedule(self):
    a = Tensor([1,2])
    out = a + 2
    sched = create_schedule([out.lazydata])
    pk = pickle.dumps(sched)
    sched_pk = pickle.loads(pk)
    # compare the AST of the last schedule item before and after the round trip
    assert_equiv_uops(sched_pk[-1].ast, sched[-1].ast)

  def test_pickle_renderer(self):
    # smoke test: dumping and loading the default device's renderer must not raise
    from tinygrad.device import Device
    pk = pickle.dumps(Device.default.renderer)
    pickle.loads(pk)
|
|
|
|
|
2024-08-02 03:39:59 +08:00
|
|
|
class TestPickleJIT(unittest.TestCase):
  """Checks on the pickled byte string of a ``TinyJit``-wrapped function."""

  @classmethod
  def setUpClass(cls):
    # build the pickle once; every test in this class reads it from cls.st
    @TinyJit
    def add(a, b):
      return a.sum()+b+1

    # three calls on 1000x1000 random inputs before dumping
    for _ in range(3):
      add(Tensor.rand(1000, 1000), Tensor.rand(1000, 1000))

    cls.st = pickle.dumps(add)
    del add

  def test_inspect(self):
    import io

    class FakeClass:
      # stand-in returned for every class the pickle references; on construction it
      # prints the module and name attached to the specialized subclass
      def __init__(self, *args, **kwargs):
        print(self.module, self.name)

    class InspectUnpickler(pickle.Unpickler):
      # never resolve the real class: hand back a FakeClass subclass tagged with the lookup
      def find_class(self, module, name):
        tags = {"name": name, "module": module}
        return type("SpecializedFakeClass", (FakeClass,), tags)

    stream = io.BytesIO(self.st)
    InspectUnpickler(stream).load()

  @unittest.skip("we are still saving intermediate buffers")
  def test_size(self):
    # confirm no intermediate buffers are saved
    pickled_size = len(self.st)
    self.assertLess(pickled_size, 1_000_000)
|
|
|
|
|
2024-03-18 09:29:14 +08:00
|
|
|
# run the test cases when this file is executed directly
if __name__ == "__main__":
  unittest.main()
|