move get_parameters to optim.py

This commit is contained in:
George Hotz 2022-09-25 13:16:58 -04:00
parent a0c0239ff1
commit ff11c4316b
2 changed files with 13 additions and 12 deletions

View File

@ -23,18 +23,7 @@ def fetch(url):
os.rename(fp+".tmp", fp) os.rename(fp+".tmp", fp)
return dat return dat
# TODO: move this to optim.py? from tinygrad.nn.optim import get_parameters
def get_parameters(obj):
parameters = []
if isinstance(obj, Tensor):
parameters.append(obj)
elif isinstance(obj, list) or isinstance(obj, tuple):
for x in obj:
parameters.extend(get_parameters(x))
elif hasattr(obj, '__dict__'):
for v in obj.__dict__.values():
parameters.extend(get_parameters(v))
return parameters
def my_unpickle(fb0): def my_unpickle(fb0):
key_prelookup = {} key_prelookup = {}

View File

@ -64,3 +64,15 @@ class Adam(Optimizer):
self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad) self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad)
t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps)) t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps))
self.realize(self.m + self.v) self.realize(self.m + self.v)
def get_parameters(obj):
parameters = []
if isinstance(obj, Tensor):
parameters.append(obj)
elif isinstance(obj, list) or isinstance(obj, tuple):
for x in obj:
parameters.extend(get_parameters(x))
elif hasattr(obj, '__dict__'):
for v in obj.__dict__.values():
parameters.extend(get_parameters(v))
return parameters