mirror of https://github.com/commaai/tinygrad.git
oops, stay in float32
This commit is contained in:
parent
5e7e359706
commit
1dde4ce609
|
@ -108,7 +108,8 @@ class EfficientNet:
|
|||
mv = eval(mk.replace(".weight", ""))
|
||||
except AttributeError:
|
||||
mv = eval(mk.replace(".bias", "_bias"))
|
||||
mv.data[:] = v.numpy() if k != '_fc.weight' else v.numpy().T
|
||||
vnp = v.numpy().astype(np.float32)
|
||||
mv.data[:] = vnp if k != '_fc.weight' else vnp.T
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
|
@ -121,7 +122,7 @@ if __name__ == "__main__":
|
|||
img = img.resize((224, 224))
|
||||
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32).reshape(1,3,224,224)
|
||||
print(img.shape)
|
||||
print(img.shape, img.dtype)
|
||||
|
||||
# run the net
|
||||
out = model.forward(Tensor(img))
|
||||
|
|
|
@ -15,7 +15,7 @@ class Tensor:
|
|||
assert(False)
|
||||
if data.dtype != np.float32:
|
||||
# warning? float64 is actually needed for numerical jacobian
|
||||
pass
|
||||
print("warning, %r isn't float32" % (data.shape,))
|
||||
|
||||
self.data = data
|
||||
self.grad = None
|
||||
|
@ -75,11 +75,11 @@ class Tensor:
|
|||
return self.sum().mul(div)
|
||||
|
||||
def sqrt(self):
|
||||
root = Tensor(np.zeros(self.shape)+0.5)
|
||||
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)+0.5)
|
||||
return self.pow(root)
|
||||
|
||||
def div(self, y):
|
||||
root = Tensor(np.zeros(self.shape)-1)
|
||||
root = Tensor(np.zeros(self.shape, dtype=self.data.dtype)-1)
|
||||
return self.mul(y.pow(root))
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
|
|
Loading…
Reference in New Issue