From 1dde4ce6094be7a5451b659b99b08c0e8a6f3ec9 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Thu, 29 Oct 2020 08:24:12 -0700 Subject: [PATCH] oops, stay in float32 --- examples/efficientnet.py | 5 +++-- tinygrad/tensor.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/efficientnet.py b/examples/efficientnet.py index 95915321..cd347b56 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -108,7 +108,8 @@ class EfficientNet: mv = eval(mk.replace(".weight", "")) except AttributeError: mv = eval(mk.replace(".bias", "_bias")) - mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T + vnp = v.numpy().astype(np.float32) + mv.data[:] = vnp if k != '_fc.weight' else vnp.T if __name__ == "__main__": # instantiate my net @@ -121,7 +122,7 @@ if __name__ == "__main__": img = img.resize((224, 224)) img = np.moveaxis(np.array(img), [2,0,1], [0,1,2]) img = img.astype(np.float32).reshape(1,3,224,224) - print(img.shape) + print(img.shape, img.dtype) # run the net out = model.forward(Tensor(img)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cca93694..a6d83b1a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -15,7 +15,7 @@ class Tensor: assert(False) if data.dtype != np.float32: # warning? float64 is actually needed for numerical jacobian - pass + print("warning, %r isn't float32" % (data.shape,)) self.data = data self.grad = None @@ -75,11 +75,11 @@ class Tensor: return self.sum().mul(div) def sqrt(self): - root = Tensor(np.zeros(self.shape)+0.5) + root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)+0.5) return self.pow(root) def div(self, y): - root = Tensor(np.zeros(self.shape)-1) + root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)-1) return self.mul(y.pow(root)) # An instantiation of the Function is the Context