test pickle variable (#5150)

* test pickle variable

* fix process replay
This commit is contained in:
George Hotz 2024-06-25 19:49:21 -07:00 committed by GitHub
parent 8fcc41582f
commit c98ca23cb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 4 deletions

View File

@ -1,6 +1,6 @@
import unittest, pickle
import numpy as np
from tinygrad import Tensor, TinyJit
from tinygrad import Tensor, TinyJit, Variable
from tinygrad.engine.schedule import create_schedule
class TestPickle(unittest.TestCase):
@ -16,6 +16,16 @@ class TestPickle(unittest.TestCase):
t2:Tensor = pickle.loads(st)
np.testing.assert_equal(t.numpy(), t2.numpy())
def test_pickle_variable(self):
v = Variable("i", 1, 20).bind(10)
t1 = Tensor.ones(10, v).contiguous()
t2 = Tensor.ones(10, v).contiguous()
ret = (t1+t2).sum(1)
st = pickle.dumps(ret)
del ret
vt2 = pickle.loads(st)
np.testing.assert_equal(vt2.numpy(), 20)
def test_pickle_buffer_view(self):
t = Tensor.arange(10, device="CLANG").contiguous().realize()
vt = t[3:5].contiguous().realize()

View File

@ -22,11 +22,9 @@ class Node:
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
def hash(self) -> int: return hash(self.key)
def __repr__(self): return self.render(ctx="REPR")
def __str__(self): return "<"+self.key+">"
def __hash__(self): return self.hash
def __hash__(self): return hash(self.key)
def __bool__(self): return not (self.max == self.min == 0)
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented