oops, stay in float32

This commit is contained in:
George Hotz 2020-10-29 08:24:12 -07:00
parent 5e7e359706
commit 1dde4ce609
2 changed files with 6 additions and 5 deletions

View File

@ -108,7 +108,8 @@ class EfficientNet:
mv = eval(mk.replace(".weight", ""))
except AttributeError:
mv = eval(mk.replace(".bias", "_bias"))
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
vnp = v.numpy().astype(np.float32)
mv.data[:] = vnp if k != '_fc.weight' else vnp.T
if __name__ == "__main__":
# instantiate my net
@ -121,7 +122,7 @@ if __name__ == "__main__":
img = img.resize((224, 224))
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
img = img.astype(np.float32).reshape(1,3,224,224)
print(img.shape)
print(img.shape, img.dtype)
# run the net
out = model.forward(Tensor(img))

View File

@ -15,7 +15,7 @@ class Tensor:
assert(False)
if data.dtype != np.float32:
# warning? float64 is actually needed for numerical jacobian
pass
print("warning, %r isn't float32" % (data.shape,))
self.data = data
self.grad = None
@ -75,11 +75,11 @@ class Tensor:
return self.sum().mul(div)
def sqrt(self):
root = Tensor(np.zeros(self.shape)+0.5)
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)+0.5)
return self.pow(root)
def div(self, y):
root = Tensor(np.zeros(self.shape)-1)
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)-1)
return self.mul(y.pow(root))
# An instantiation of the Function is the Context