refactor into utils

This commit is contained in:
George Hotz 2020-10-18 14:36:29 -07:00
parent 0c3dd12b3b
commit cc9054e3ec
3 changed files with 12 additions and 8 deletions

View File

@ -42,11 +42,12 @@ print(y.grad) # dz/dy
```python ```python
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
import tinygrad.optim as optim import tinygrad.optim as optim
from tinygrad.utils import layer_init_uniform
class TinyBobNet: class TinyBobNet:
def __init__(self): def __init__(self):
self.l1 = Tensor(layer_init(784, 128)) self.l1 = Tensor(layer_init_uniform(784, 128))
self.l2 = Tensor(layer_init(128, 10)) self.l2 = Tensor(layer_init_uniform(128, 10))
def forward(self, x): def forward(self, x):
return x.dot(self.l1).relu().dot(self.l2).logsoftmax() return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import numpy as np import numpy as np
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.utils import fetch_mnist from tinygrad.utils import layer_init_uniform, fetch_mnist
import tinygrad.optim as optim import tinygrad.optim as optim
from tqdm import trange from tqdm import trange
@ -13,14 +13,11 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
# train a model # train a model
np.random.seed(1337) np.random.seed(1337)
def layer_init(m, h):
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
return ret.astype(np.float32)
class TinyBobNet: class TinyBobNet:
def __init__(self): def __init__(self):
self.l1 = Tensor(layer_init(784, 128)) self.l1 = Tensor(layer_init_uniform(784, 128))
self.l2 = Tensor(layer_init(128, 10)) self.l2 = Tensor(layer_init_uniform(128, 10))
def forward(self, x): def forward(self, x):
return x.dot(self.l1).relu().dot(self.l2).logsoftmax() return x.dot(self.l1).relu().dot(self.l2).logsoftmax()

View File

@ -1,3 +1,9 @@
import numpy as np
def layer_init_uniform(m, h):
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
return ret.astype(np.float32)
def fetch_mnist(): def fetch_mnist():
def fetch(url): def fetch(url):
import requests, gzip, os, hashlib, numpy import requests, gzip, os, hashlib, numpy