bump tinygrad to 0.5, move reshape logic from mlops

This commit is contained in:
George Hotz 2023-02-28 18:07:03 -08:00
parent e9e71fbfc4
commit ea3fa07c2a
3 changed files with 6 additions and 5 deletions

View File

@ -8,7 +8,7 @@ with open(os.path.join(directory, 'README.md'), encoding='utf-8') as f:
long_description = f.read()
setup(name='tinygrad',
version='0.4.0',
version='0.5.0',
description='You like pytorch? You like micrograd? You love tinygrad! heart',
author='George Hotz',
license='MIT',

View File

@ -1,4 +1,4 @@
from tinygrad.helpers import prod, argsort
from tinygrad.helpers import argsort
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps
from tinygrad.tensor import Function
@ -126,9 +126,7 @@ class Expand(Function):
class Reshape(Function):
def forward(self, x, shape):
assert len(shape) > 0 and all(x != 0 for x in shape), f"zeros not allowed in shape {shape}"
self.input_shape = x.shape
shape = tuple(-prod(x.shape) // prod(shape) if s == -1 else s for s in shape)
return x.movement_op(MovementOps.RESHAPE, shape)
def backward(self, grad_output):

View File

@ -188,7 +188,10 @@ class Tensor:
# ***** movement mlops *****
def reshape(self, shape, *args) -> Tensor: return mlops.Reshape.apply(self, shape=argfix(shape, *args))
def reshape(self, shape, *args) -> Tensor:
new_shape = argfix(shape, *args)
assert len(new_shape) > 0 and all(x != 0 for x in new_shape), f"zeros not allowed in shape {new_shape}"
return mlops.Reshape.apply(self, shape=tuple(-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape))
def expand(self, shape, *args) -> Tensor: return mlops.Expand.apply(self, shape=tuple(x if x != -1 else s for s,x in zip(self.shape, argfix(shape, *args))))
def permute(self, order, *args) -> Tensor: return mlops.Permute.apply(self, order=argfix(order, *args))
def flip(self, axis, *args) -> Tensor: return mlops.Flip.apply(self, axis=argfix(axis, *args))