From ff11c4316bd859fb9881bea167ca6162446bd28e Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 25 Sep 2022 13:16:58 -0400 Subject: [PATCH] move get_parameters to optim.py --- extra/utils.py | 13 +------------ tinygrad/nn/optim.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/extra/utils.py b/extra/utils.py index dcd73a62..23f1e903 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -23,18 +23,7 @@ def fetch(url): os.rename(fp+".tmp", fp) return dat -# TODO: move this to optim.py? -def get_parameters(obj): - parameters = [] - if isinstance(obj, Tensor): - parameters.append(obj) - elif isinstance(obj, list) or isinstance(obj, tuple): - for x in obj: - parameters.extend(get_parameters(x)) - elif hasattr(obj, '__dict__'): - for v in obj.__dict__.values(): - parameters.extend(get_parameters(v)) - return parameters +from tinygrad.nn.optim import get_parameters def my_unpickle(fb0): key_prelookup = {} diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py index 1e006c4e..db513a41 100644 --- a/tinygrad/nn/optim.py +++ b/tinygrad/nn/optim.py @@ -64,3 +64,15 @@ class Adam(Optimizer): self.v[i] = self.b2 * self.v[i] + (1.0 - self.b2) * (t.grad * t.grad) t.assign(t.detach() - a * self.m[i].div(self.v[i].sqrt() + self.eps)) self.realize(self.m + self.v) + +def get_parameters(obj): + parameters = [] + if isinstance(obj, Tensor): + parameters.append(obj) + elif isinstance(obj, list) or isinstance(obj, tuple): + for x in obj: + parameters.extend(get_parameters(x)) + elif hasattr(obj, '__dict__'): + for v in obj.__dict__.values(): + parameters.extend(get_parameters(v)) + return parameters