mirror of https://github.com/commaai/tinygrad.git
test pickle variable (#5150)
* test pickle variable * fix process replay
This commit is contained in:
parent
8fcc41582f
commit
c98ca23cb9
|
@ -1,6 +1,6 @@
|
|||
import unittest, pickle
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit
|
||||
from tinygrad import Tensor, TinyJit, Variable
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
|
||||
class TestPickle(unittest.TestCase):
|
||||
|
@ -16,6 +16,16 @@ class TestPickle(unittest.TestCase):
|
|||
t2:Tensor = pickle.loads(st)
|
||||
np.testing.assert_equal(t.numpy(), t2.numpy())
|
||||
|
||||
def test_pickle_variable(self):
|
||||
v = Variable("i", 1, 20).bind(10)
|
||||
t1 = Tensor.ones(10, v).contiguous()
|
||||
t2 = Tensor.ones(10, v).contiguous()
|
||||
ret = (t1+t2).sum(1)
|
||||
st = pickle.dumps(ret)
|
||||
del ret
|
||||
vt2 = pickle.loads(st)
|
||||
np.testing.assert_equal(vt2.numpy(), 20)
|
||||
|
||||
def test_pickle_buffer_view(self):
|
||||
t = Tensor.arange(10, device="CLANG").contiguous().realize()
|
||||
vt = t[3:5].contiguous().realize()
|
||||
|
|
|
@ -22,11 +22,9 @@ class Node:
|
|||
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render(ctx="DEBUG")
|
||||
@functools.cached_property
|
||||
def hash(self) -> int: return hash(self.key)
|
||||
def __repr__(self): return self.render(ctx="REPR")
|
||||
def __str__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return self.hash
|
||||
def __hash__(self): return hash(self.key)
|
||||
def __bool__(self): return not (self.max == self.min == 0)
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
|
|
Loading…
Reference in New Issue