mirror of https://github.com/commaai/tinygrad.git
refactor into utils
This commit is contained in:
parent
0c3dd12b3b
commit
cc9054e3ec
|
@ -42,11 +42,12 @@ print(y.grad) # dz/dy
|
|||
```python
|
||||
from tinygrad.tensor import Tensor
|
||||
import tinygrad.optim as optim
|
||||
from tinygrad.utils import layer_init_uniform
|
||||
|
||||
class TinyBobNet:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor(layer_init(784, 128))
|
||||
self.l2 = Tensor(layer_init(128, 10))
|
||||
self.l1 = Tensor(layer_init_uniform(784, 128))
|
||||
self.l2 = Tensor(layer_init_uniform(128, 10))
|
||||
|
||||
def forward(self, x):
|
||||
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import fetch_mnist
|
||||
from tinygrad.utils import layer_init_uniform, fetch_mnist
|
||||
import tinygrad.optim as optim
|
||||
|
||||
from tqdm import trange
|
||||
|
@ -13,14 +13,11 @@ X_train, Y_train, X_test, Y_test = fetch_mnist()
|
|||
# train a model
|
||||
|
||||
np.random.seed(1337)
|
||||
def layer_init(m, h):
|
||||
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
|
||||
return ret.astype(np.float32)
|
||||
|
||||
class TinyBobNet:
|
||||
def __init__(self):
|
||||
self.l1 = Tensor(layer_init(784, 128))
|
||||
self.l2 = Tensor(layer_init(128, 10))
|
||||
self.l1 = Tensor(layer_init_uniform(784, 128))
|
||||
self.l2 = Tensor(layer_init_uniform(128, 10))
|
||||
|
||||
def forward(self, x):
|
||||
return x.dot(self.l1).relu().dot(self.l2).logsoftmax()
|
||||
|
|
|
@ -1,3 +1,9 @@
|
|||
import numpy as np
|
||||
|
||||
def layer_init_uniform(m, h):
|
||||
ret = np.random.uniform(-1., 1., size=(m,h))/np.sqrt(m*h)
|
||||
return ret.astype(np.float32)
|
||||
|
||||
def fetch_mnist():
|
||||
def fetch(url):
|
||||
import requests, gzip, os, hashlib, numpy
|
||||
|
|
Loading…
Reference in New Issue