mirror of https://github.com/commaai/tinygrad.git
move get_parameters to optim.py
This commit is contained in:
parent
a0c0239ff1
commit
ff11c4316b
|
@ -23,18 +23,7 @@ def fetch(url):
|
|||
os.rename(fp+".tmp", fp)
|
||||
return dat
|
||||
|
||||
# TODO: move this to optim.py?
|
||||
def get_parameters(obj):
|
||||
parameters = []
|
||||
if isinstance(obj, Tensor):
|
||||
parameters.append(obj)
|
||||
elif isinstance(obj, list) or isinstance(obj, tuple):
|
||||
for x in obj:
|
||||
parameters.extend(get_parameters(x))
|
||||
elif hasattr(obj, '__dict__'):
|
||||
for v in obj.__dict__.values():
|
||||
parameters.extend(get_parameters(v))
|
||||
return parameters
|
||||
from tinygrad.nn.optim import get_parameters
|
||||
|
||||
def my_unpickle(fb0):
|
||||
key_prelookup = {}
|
||||
|
|
|
@ -64,3 +64,15 @@ class Adam(Optimizer):
|
|||
self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
|
||||
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
|
||||
self.realize(self.m + self.v)
|
||||
|
||||
def get_parameters(obj):
|
||||
parameters = []
|
||||
if isinstance(obj, Tensor):
|
||||
parameters.append(obj)
|
||||
elif isinstance(obj, list) or isinstance(obj, tuple):
|
||||
for x in obj:
|
||||
parameters.extend(get_parameters(x))
|
||||
elif hasattr(obj, '__dict__'):
|
||||
for v in obj.__dict__.values():
|
||||
parameters.extend(get_parameters(v))
|
||||
return parameters
|
||||
|
|
Loading…
Reference in New Issue