mirror of https://github.com/commaai/tinygrad.git
move get_parameters to optim.py
This commit is contained in:
parent a0c0239ff1
commit ff11c4316b
@@ -23,18 +23,7 @@ def fetch(url):
     os.rename(fp+".tmp", fp)
   return dat
 
-# TODO: move this to optim.py?
-def get_parameters(obj):
-  parameters = []
-  if isinstance(obj, Tensor):
-    parameters.append(obj)
-  elif isinstance(obj, list) or isinstance(obj, tuple):
-    for x in obj:
-      parameters.extend(get_parameters(x))
-  elif hasattr(obj, '__dict__'):
-    for v in obj.__dict__.values():
-      parameters.extend(get_parameters(v))
-  return parameters
+from tinygrad.nn.optim import get_parameters
 
 def my_unpickle(fb0):
   key_prelookup = {}
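For context, a minimal sketch of how a caller picks up the helper from its new home after this commit. The TinyNet model is hypothetical, assuming the usual tinygrad Tensor API (uniform, dot, relu, log_softmax) and the Adam constructor shown in optim.py:

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import Adam, get_parameters  # new import location after this commit

class TinyNet:  # hypothetical two-layer model
  def __init__(self):
    self.l1 = Tensor.uniform(784, 128)
    self.l2 = Tensor.uniform(128, 10)
  def forward(self, x):
    return x.dot(self.l1).relu().dot(self.l2).log_softmax()

model = TinyNet()
# get_parameters walks model.__dict__ recursively and returns [model.l1, model.l2]
opt = Adam(get_parameters(model), lr=0.001)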
@@ -64,3 +64,15 @@ class Adam(Optimizer):
       self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
       t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
     self.realize(self.m + self.v)
+
+def get_parameters(obj):
+  parameters = []
+  if isinstance(obj, Tensor):
+    parameters.append(obj)
+  elif isinstance(obj, list) or isinstance(obj, tuple):
+    for x in obj:
+      parameters.extend(get_parameters(x))
+  elif hasattr(obj, '__dict__'):
+    for v in obj.__dict__.values():
+      parameters.extend(get_parameters(v))
+  return parameters
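A quick illustration of the recursion, using hypothetical container objects: Tensors are collected from object attributes and from nested lists/tuples, while values that are neither Tensors nor containers and have no __dict__ (like the str below) contribute nothing:

from tinygrad.tensor import Tensor
from tinygrad.nn.optim import get_parameters

t1, t2 = Tensor.zeros(2, 2), Tensor.zeros(3)

class Holder:  # hypothetical container
  def __init__(self):
    self.w = t1                          # Tensor attribute: collected directly
    self.layers = [t2, "not a tensor"]   # list is recursed into; the str is skipped

params = get_parameters(Holder())
print([p.shape for p in params])  # expected: [(2, 2), (3,)]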