mirror of https://github.com/commaai/tinygrad.git
bump tinygrad to 0.5, move reshape logic from mlops
parent e9e71fbfc4
commit ea3fa07c2a
setup.py | 2
@@ -8,7 +8,7 @@ with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f:
   long_description = f.read()
 
 setup(name='tinygrad',
-      version='0.4.0',
+      version='0.5.0',
       description='You like pytorch? You like micrograd? You love tinygrad! ❤️',
       author='George Hotz',
       license='MIT',
tinygrad/mlops.py
@@ -1,4 +1,4 @@
-from tinygrad.helpers import prod, argsort
+from tinygrad.helpers import argsort
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps
 from tinygrad.tensor import Function
 
@@ -126,9 +126,7 @@ class Expand(Function):
 
 class Reshape(Function):
   def forward(self, x, shape):
-    assert len(shape) > 0 and all(x != 0 for x in shape), f"zeros not allowed in shape {shape}"
     self.input_shape = x.shape
-    shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
     return x.movement_op(MovementOps.RESHAPE, shape)
 
   def backward(self, grad_output):
tinygrad/tensor.py
@@ -188,7 +188,10 @@ class Tensor:
 
   # ***** movement mlops *****
 
-  def reshape(self, shape, *args) -> Tensor: return mlops.Reshape.apply(self, shape=argfix(shape, *args))
+  def reshape(self, shape, *args) -> Tensor:
+    new_shape = argfix(shape, *args)
+    assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
+    return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
   def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
   def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
   def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))
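
For reference, a minimal standalone sketch of the -1 shape inference that this commit moves from mlops.Reshape.forward into Tensor.reshape. infer_shape is a hypothetical helper for illustration only (it is not part of tinygrad), and math.prod stands in for tinygrad.helpers.prod.

from math import prod

def infer_shape(old_shape, new_shape):
  # Hypothetical helper mirroring the moved logic: reject empty shapes and
  # zero-sized dimensions, then resolve a single -1 from the element count.
  assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
  # With one -1 present, prod(new_shape) is negative, so the floor division
  # -prod(old_shape) // prod(new_shape) yields the missing positive dimension.
  return tuple(-prod(old_shape) // prod(new_shape) if s == -1 else s for s in new_shape)

assert infer_shape((2, 3, 4), (6, -1)) == (6, 4)   # 24 elements -> (6, 4)
assert infer_shape((2, 3, 4), (-1,)) == (24,)      # flatten

Resolving -1 and rejecting zero dimensions at the Tensor.reshape level keeps the low-level Reshape mlop a thin wrapper around MovementOps.RESHAPE that only ever receives a fully resolved shape.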