use __getnewargs__ to fix unpickling Variable (#3441)

it's recommended to use __getnewargs__ to update the args of classes that use __new__ when unpickling.
It's preferred because it does not change the __new__ behavior.
This commit is contained in:
chenyu 2024-02-18 10:28:37 -05:00 committed by GitHub
parent 5647148937
commit 2da734920e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 9 deletions

View File

@ -3,14 +3,9 @@ import unittest, pickle
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer
class TestSymbolicPickle(unittest.TestCase):
def test_pickle_variable(self):
dat = Variable("a", 3, 8)
datp = pickle.loads(pickle.dumps(dat))
self.assertEqual(str(datp), "<a[3-8]>")
def test_pickle_variable_times_2(self):
dat = Variable("a", 3, 8)*2
datp = pickle.loads(pickle.dumps(dat))
self.assertEqual(str(datp), "<(a[3-8]*2)>")
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
def test_pickle_variable_times_2(self): self._test_pickle_unpickle(Variable("a", 3, 8)*2)
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):

View File

@ -118,12 +118,13 @@ class Node:
class Variable(Node):
def __new__(cls, *args):
if len(args) == 0: return super().__new__(cls) # fix pickle
expr, nmin, nmax = args
assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __getnewargs__(self): return (self.expr, self.min, self.max) # args passed to __new__ when unpickling
def __init__(self, expr:str, nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
self._val: Optional[int] = None