diff --git a/test/test_pickle.py b/test/test_pickle.py index 7955b17a..d1f0a1ac 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -1,6 +1,6 @@ import unittest, pickle import numpy as np -from tinygrad import Tensor, TinyJit +from tinygrad import Tensor, TinyJit, Variable from tinygrad.engine.schedule import create_schedule class TestPickle(unittest.TestCase): @@ -16,6 +16,16 @@ class TestPickle(unittest.TestCase): t2:Tensor = pickle.loads(st) np.testing.assert_equal(t.numpy(), t2.numpy()) + def test_pickle_variable(self): + v = Variable("i", 1, 20).bind(10) + t1 = Tensor.ones(10, v).contiguous() + t2 = Tensor.ones(10, v).contiguous() + ret = (t1+t2).sum(1) + st = pickle.dumps(ret) + del ret + vt2 = pickle.loads(st) + np.testing.assert_equal(vt2.numpy(), 20) + def test_pickle_buffer_view(self): t = Tensor.arange(10, device="CLANG").contiguous().realize() vt = t[3:5].contiguous().realize() diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py index 6e780868..f18bac1d 100644 --- a/tinygrad/shape/symbolic.py +++ b/tinygrad/shape/symbolic.py @@ -22,11 +22,9 @@ class Node: @functools.cached_property def key(self) -> str: return self.render(ctx="DEBUG") - @functools.cached_property - def hash(self) -> int: return hash(self.key) def __repr__(self): return self.render(ctx="REPR") def __str__(self): return "<"+self.key+">" - def __hash__(self): return self.hash + def __hash__(self): return hash(self.key) def __bool__(self): return not (self.max == self.min == 0) def __eq__(self, other:object) -> bool: if not isinstance(other, Node): return NotImplemented