mirror of https://github.com/commaai/tinygrad.git
make the example simpler
This commit is contained in:
parent
1f0514e5df
commit
43591a1e71
|
@ -15,11 +15,10 @@ The Tensor class is a wrapper around a numpy array, except it does Tensor things
|
|||
### Example
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
x = Tensor(np.eye(3))
|
||||
y = Tensor(np.array([[2.0,0,-2.0]]))
|
||||
x = Tensor.eye(3)
|
||||
y = Tensor([[2.0,0,-2.0]])
|
||||
z = y.dot(x).sum()
|
||||
z.backward()
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -8,7 +8,7 @@ with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f:
|
|||
long_description = f.read()
|
||||
|
||||
setup(name='tinygrad',
|
||||
version='0.1.0',
|
||||
version='0.2.0',
|
||||
description='You like pytorch? You like micrograd? You love tinygrad! heart',
|
||||
author='George Hotz',
|
||||
license='MIT',
|
||||
|
|
|
@ -8,6 +8,8 @@ import numpy as np
|
|||
class Tensor:
|
||||
def __init__(self, data):
|
||||
#print(type(data), data)
|
||||
if type(data) == list:
|
||||
data = np.array(data, dtype=np.float32)
|
||||
if type(data) != np.ndarray:
|
||||
print("error constructing tensor with %r" % data)
|
||||
assert(False)
|
||||
|
@ -35,6 +37,10 @@ class Tensor:
|
|||
def randn(*shape):
|
||||
return Tensor(np.random.randn(*shape).astype(np.float32))
|
||||
|
||||
@staticmethod
|
||||
def eye(dim):
|
||||
return Tensor(np.eye(dim).astype(np.float32))
|
||||
|
||||
def backward(self, allow_fill=True):
|
||||
#print("running backward on", self)
|
||||
if self._ctx is None:
|
||||
|
|
Loading…
Reference in New Issue