move the op casting logic from mlops to tensor try 2 (#2887)

* unary works

* where works

* add sub mul

* xor div

* CMPLT

* sparse_categorical_crossentropy

* image const

* sparse_categorical_crossentropy
chenyu 2023-12-20 23:50:37 -05:00 committed by GitHub
parent 7da2325dc7
commit 8a04107d30
4 changed files with 47 additions and 45 deletions

@@ -107,7 +107,7 @@ def Atan(y: Tensor):
def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
  k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
-  return x.triu(k) if upper else x.tril(k)
+  return x.triu(k).cast(dtypes.int64) if upper else x.tril(k).cast(dtypes.int64)
def Squeeze(data: Tensor, axes):
  if isinstance(axes, Tensor): axes = safe_numpy(axes)
@@ -122,7 +122,7 @@ def Unsqueeze(data: Tensor, axes):
    new_shape[i] = next(ptr)
  return data.reshape(new_shape)
-def Binarizer(input, threshold=0.0): return input > threshold
+def Binarizer(input, threshold=0.0): return (input > threshold).cast(dtypes.float32)
def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
  axis = axis + x.ndim if axis < 0 else axis
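
With the dtype handling moved into Tensor, comparisons now come back as bool tensors (see the Less change further down), so ONNX ops that must return a numeric tensor cast back explicitly, as Binarizer and Trilu do above. A minimal sketch of that pattern, assuming the tinygrad API at this commit and the default device:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x = Tensor([0.5, 1.5, 2.5])
mask = x > 1.0                    # comparison now yields a bool tensor (float on WEBGPU)
out = mask.cast(dtypes.float32)   # Binarizer-style explicit cast back to float32
print(out.numpy())                # expected: [0. 1. 1.]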

@@ -59,7 +59,9 @@ class LazyBuffer:
    return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())
  def const(self, val:Union[float, int]) -> LazyBuffer:
-    return LazyBuffer.loadop(LoadOps.CONST, (), self.dtype, self.device, val).reshape((1,)*len(self.shape)).expand(self.shape)
+    # NOTE: we force the image dtype const to be a float32
+    const_dtype = self.dtype if not isinstance(self.dtype, ImageDType) else dtypes.float32
+    return LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)
  def contiguous(self):
    if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
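
The new rule above can be restated on its own: a scalar const attached to an image-typed buffer cannot itself carry an image dtype, so it falls back to float32. The helper below is purely illustrative and not part of the codebase:

from tinygrad.helpers import ImageDType, dtypes

# illustrative helper (not in the repo): the dtype LazyBuffer.const now picks
def const_dtype_for(dt):
  # image dtypes describe texture storage, so a scalar const falls back to float32
  return dtypes.float32 if isinstance(dt, ImageDType) else dt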

@@ -1,6 +1,6 @@
import math
from typing import Tuple, Optional, cast
-from tinygrad.helpers import argsort, DType, least_upper_float, dtypes, least_upper_dtype
+from tinygrad.helpers import argsort, DType, dtypes
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
@@ -35,7 +35,7 @@ class Neg(Function):
class Sin(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
-    return x.cast(least_upper_float(x.dtype)).e(UnaryOps.SIN)
+    return x.e(UnaryOps.SIN)
  def backward(self, grad:LazyBuffer) -> LazyBuffer:
    return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
@@ -47,19 +47,19 @@ class Relu(Function):
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
+    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)
class Log(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x
-    return x.cast(ftype:= least_upper_float(x.dtype)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2)))
+    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.e(BinaryOps.DIV, self.x)
class Exp(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(ftype:=least_upper_float(x.dtype)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2)
+    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -67,7 +67,7 @@ class Exp(Function):
class Sqrt(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(least_upper_float(x.dtype)).e(UnaryOps.SQRT)
+    self.ret = x.e(UnaryOps.SQRT)
    return self.ret
  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -78,7 +78,6 @@ class Sqrt(Function):
# TODO: have the backend automatically find this
class Sigmoid(Function):
  def forward(self, x:LazyBuffer) -> LazyBuffer:
-    x = x.cast(least_upper_float(x.dtype))
    self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
    return self.ret
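
The net effect of the hunks above: the mlops Functions now assume their inputs already carry a suitable dtype, and the least_upper_float promotion that used to happen here is applied by the Tensor wrappers (see the Tensor changes further down). A hedged sketch of the assumed promotion behaviour:

from tinygrad.helpers import dtypes, least_upper_float

# float dtypes should pass through unchanged; integer dtypes get promoted to a
# float dtype (presumably float32 for int32, if I read the promotion lattice correctly)
print(least_upper_float(dtypes.float16))
print(least_upper_float(dtypes.int32))
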
@@ -89,60 +88,58 @@ class Sigmoid(Function):
class Less(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    output_dtype = least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.CMPLT, y.cast(output_dtype))
+    # in webgpu bool cannot be used as a storage buffer type
+    return x.e(BinaryOps.CMPLT, y).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)
class Xor(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    output_dtype = least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.XOR, y.cast(output_dtype))
+    return x.e(BinaryOps.XOR, y)
class Add(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.ADD, y.cast(output_dtype))
+    return x.e(BinaryOps.ADD, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y_dtype) if self.needs_input_grad[1] else None
+    return grad_output if self.needs_input_grad[0] else None, \
+           grad_output if self.needs_input_grad[1] else None
class Sub(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.SUB, y.cast(output_dtype))
+    self.x_dtype, self.y_dtype = x.dtype, y.dtype
+    return x.e(BinaryOps.SUB, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y_dtype).e(UnaryOps.NEG) if self.needs_input_grad[1] else None
+    return grad_output if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None
class Mul(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.MUL, y.cast(output_dtype))
+    self.x, self.y = x, y
+    return x.e(BinaryOps.MUL, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return self.y.cast(self.x.dtype).e(BinaryOps.MUL, grad_output.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \
-           self.x.cast(self.y.dtype).e(BinaryOps.MUL, grad_output.cast(self.y.dtype)) if self.needs_input_grad[1] else None
+    return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
+           self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None
class Div(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.DIV, y.cast(output_dtype))
+    self.x, self.y = x, y
+    return x.e(BinaryOps.DIV, y)
  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x.dtype).e(BinaryOps.DIV, self.y.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y.dtype).e(UnaryOps.NEG).e(BinaryOps.MUL, self.x.cast(self.y.dtype)).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501
+    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501
# ************* ternary ops *************
class Where(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
-    self.x, self.y_dtype, self.z_dtype, output_type = x.cast(dtypes.bool), y.dtype, z.dtype, least_upper_dtype(y.dtype, z.dtype)
-    return self.x.e(TernaryOps.WHERE, y.cast(output_type), z.cast(output_type))
+    self.x = x
+    return self.x.e(TernaryOps.WHERE, y, z)
  def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
    return None, \
-           self.x.e(TernaryOps.WHERE, grad_output.cast(self.y_dtype), grad_output.cast(self.y_dtype).const(0)) if self.needs_input_grad[1] else None, \
+           self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
           self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None
# ************* reduce ops *************
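
Since Where.forward above no longer casts its condition or promotes y and z, callers are expected to hand it a bool condition with matching branch dtypes; at the user level that cast happens inside Tensor.where (see the Tensor changes below). A small usage sketch, assuming default device behaviour:

from tinygrad.tensor import Tensor

cond = Tensor([1.0, 0.0, 1.0])            # any dtype; Tensor.where casts it to bool
out = cond.where(Tensor([10.0, 20.0, 30.0]), 0.0)
print(out.numpy())                        # expected: [10.  0. 30.]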

@@ -7,7 +7,7 @@ from functools import partialmethod, reduce
from itertools import accumulate
import numpy as np
-from tinygrad.helpers import DType, dtypes, ImageDType
+from tinygrad.helpers import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.ops import LoadOps
@@ -249,7 +249,7 @@ class Tensor:
    # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
    # this is "implicit gradient creation"
-    self.grad = Tensor(1, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, device=self.device, requires_grad=False)
    for t0 in reversed(self.deepwalk()):
      assert (t0.grad is not None)
@@ -670,14 +670,14 @@ class Tensor:
  def neg(self): return mlops.Neg.apply(self)
  def contiguous(self): return mlops.Contiguous.apply(self)
  def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
-  def log(self): return mlops.Log.apply(self)
-  def log2(self): return mlops.Log.apply(self)/math.log(2)
-  def exp(self): return mlops.Exp.apply(self)
+  def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
+  def log2(self): return self.log()/math.log(2)
+  def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
  def exp2(self): return mlops.Exp.apply(self*math.log(2))
  def relu(self): return mlops.Relu.apply(self)
-  def sigmoid(self): return mlops.Sigmoid.apply(self)
-  def sin(self): return mlops.Sin.apply(self)
-  def sqrt(self): return mlops.Sqrt.apply(self)
+  def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
+  def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
+  def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
  def rsqrt(self): return self.reciprocal().sqrt()
  def cos(self): return ((math.pi/2)-self).sin()
  def tan(self): return self.sin() / self.cos()
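
In other words, the float promotion for log/exp/sigmoid/sin/sqrt now lives in these one-liners rather than in mlops. A quick sketch of what that buys, assuming int32 promotes to float32 here:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x = Tensor([1, 2, 3], dtype=dtypes.int32)
y = x.sqrt()             # the cast to a float dtype happens in Tensor.sqrt, not in mlops.Sqrt
print(y.dtype)           # presumably dtypes.float32
print(y.numpy())
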
@@ -729,6 +729,9 @@ class Tensor:
      x = x.cast(y_dtype)
      y = Tensor(y, self.device, y_dtype, requires_grad=False)
+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    x, y = x.cast(output_dtype), y.cast(output_dtype)
    if reverse: x, y = y, x
    # left pad shape with 1s
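
This is the other half of the move: _broadcasted promotes both operands to their least_upper_dtype once, so the mlops Add/Sub/Mul/Div above can assume matching dtypes. A hedged example of the resulting behaviour:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, least_upper_dtype

a = Tensor([1, 2, 3], dtype=dtypes.int32)
b = Tensor([0.5, 0.5, 0.5], dtype=dtypes.float32)
c = a + b                                            # both sides cast in _broadcasted before mlops.Add
print(c.dtype, least_upper_dtype(a.dtype, b.dtype))  # expected to match
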
@@ -784,7 +787,7 @@ class Tensor:
  def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
    x_,y = self._broadcasted(input_)
    x,z = x_._broadcasted(other)
-    return mlops.Where.apply(x, *y._broadcasted(z))
+    return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))
  # ***** op wrappers (wasted lines to make the typechecker happy) *****
@@ -860,9 +863,9 @@ class Tensor:
  def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
    # NOTE: self is a logits input
-    loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501
-    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+    loss_mask = (Y != ignore_index).cast(dtypes.float)
+    y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501
+    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
    return self.log_softmax().mul(y).sum() / loss_mask.sum()
  # ***** cast ops *****
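
Because Y != ignore_index now produces a bool tensor, the loss mask in the hunk above is cast to float before it is multiplied into y and summed in the denominator. A minimal usage sketch (values made up for illustration):

from tinygrad.tensor import Tensor

logits = Tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])  # (batch, classes)
labels = Tensor([0, 1])
loss = logits.sparse_categorical_crossentropy(labels)
print(loss.numpy())   # scalar cross-entropy averaged over non-ignored labels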