mirror of https://github.com/commaai/tinygrad.git
use __getnewargs__ to fix unpickling Variable (#3441)
it's recommended to use __getnewargs__ to update the args of classes that use __new__ when unpickling. It's preferred because it does not change the __new__ behavior.
This commit is contained in:
parent
5647148937
commit
2da734920e
|
@ -3,14 +3,9 @@ import unittest, pickle
|
|||
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer
|
||||
|
||||
class TestSymbolicPickle(unittest.TestCase):
|
||||
def test_pickle_variable(self):
|
||||
dat = Variable("a", 3, 8)
|
||||
datp = pickle.loads(pickle.dumps(dat))
|
||||
self.assertEqual(str(datp), "<a[3-8]>")
|
||||
def test_pickle_variable_times_2(self):
|
||||
dat = Variable("a", 3, 8)*2
|
||||
datp = pickle.loads(pickle.dumps(dat))
|
||||
self.assertEqual(str(datp), "<(a[3-8]*2)>")
|
||||
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
|
||||
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
|
||||
def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
|
|
|
@ -118,12 +118,13 @@ class Node:
|
|||
|
||||
class Variable(Node):
|
||||
def __new__(cls, *args):
|
||||
if len(args) == 0: return super().__new__(cls) # fix pickle
|
||||
expr, nmin, nmax = args
|
||||
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
|
||||
if nmin == nmax: return NumNode(nmin)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
|
||||
|
||||
def __init__(self, expr:str, nmin:int, nmax:int):
|
||||
self.expr, self.min, self.max = expr, nmin, nmax
|
||||
self._val: Optional[int] = None
|
||||
|
|
Loading…
Reference in New Issue