mirror of https://github.com/commaai/tinygrad.git
move the op casting logic from mlops to tensor try 2 (#2887)
* unary works
* where works
* add sub mul
* xor div
* CMPLT
* sparse_categorical_crossentropy
* image const
* sparse_categorical_crossentropy
This commit is contained in:
parent 7da2325dc7
commit 8a04107d30
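The whole diff below follows one theme: dtype promotion moves out of the individual mlops Functions (which used to cast their inputs with least_upper_float / least_upper_dtype) and into the Tensor layer, which now promotes once in its op wrappers and in _broadcasted before calling Function.apply. A minimal standalone sketch of that "promote once at the boundary" pattern — the lattice, names, and Kernel class here are invented for illustration and are not tinygrad's real implementation:

# Toy sketch only: illustrative dtype names and promotion rule, not tinygrad's dtype system.
PRIORITY = {"bool": 0, "int32": 1, "int64": 2, "float32": 3, "float64": 4}

def least_upper_dtype(*dts: str) -> str:
  # smallest dtype that can hold all inputs (toy version: highest priority wins)
  return max(dts, key=lambda d: PRIORITY[d])

def least_upper_float(dt: str) -> str:
  # floats stay as-is, everything else promotes to the default float
  return dt if dt.startswith("float") else "float32"

class Kernel:
  """Stands in for an mlops Function after this commit: it no longer casts, it just assumes."""
  @staticmethod
  def add(a_dtype: str, b_dtype: str) -> str:
    assert a_dtype == b_dtype, "mlops-level op assumes matching dtypes"
    return a_dtype

def tensor_add(a_dtype: str, b_dtype: str) -> str:
  # the Tensor layer promotes once, then calls the dtype-agnostic kernel
  out = least_upper_dtype(a_dtype, b_dtype)
  return Kernel.add(out, out)

print(tensor_add("int32", "float32"))  # float32
print(least_upper_float("int64"))      # float32 -- the kind of cast sin/log/exp/sqrt now get at the Tensor level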
@@ -107,7 +107,7 @@ def Atan(y: Tensor):

 def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
   k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
-  return x.triu(k) if upper else x.tril(k)
+  return x.triu(k).cast(dtypes.int64) if upper else x.tril(k).cast(dtypes.int64)

 def Squeeze(data: Tensor, axes):
   if isinstance(axes, Tensor): axes = safe_numpy(axes)

@@ -122,7 +122,7 @@ def Unsqueeze(data: Tensor, axes):
       new_shape[i] = next(ptr)
   return data.reshape(new_shape)

-def Binarizer(input, threshold=0.0): return input > threshold
+def Binarizer(input, threshold=0.0): return (input > threshold).cast(dtypes.float32)

 def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
   axis = axis + x.ndim if axis < 0 else axis
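Both onnx_ops changes compensate for the new Tensor-level dtype rules: a comparison such as input > threshold now comes back as a bool tensor (see the Less change in the mlops hunk further down), so Binarizer casts the result back to float32, and Trilu casts its output to int64, presumably to keep the integer dtype its ONNX tests expect. A hedged usage sketch with the import paths of this commit's era (dtype prints are expectations, not guarantees):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes  # dtypes lived in helpers at the time of this commit

x = Tensor([0.2, 0.8, 0.5])
mask = x > 0.5                             # comparison result is now a bool tensor
print(mask.dtype)                          # expected: dtypes.bool
print(mask.cast(dtypes.float32).numpy())   # [0., 1., 0.] -- the float result Binarizer returns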
@@ -59,7 +59,9 @@ class LazyBuffer:
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())

   def const(self, val:Union[float, int]) -> LazyBuffer:
-    return LazyBuffer.loadop(LoadOps.CONST, (), self.dtype, self.device, val).reshape((1,)*len(self.shape)).expand(self.shape)
+    # NOTE: we force the image dtype const to be a float32
+    const_dtype = self.dtype if not isinstance(self.dtype, ImageDType) else dtypes.float32
+    return LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)

   def contiguous(self):
     if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
@@ -1,6 +1,6 @@
 import math
 from typing import Tuple, Optional, cast
-from tinygrad.helpers import argsort, DType, least_upper_float, dtypes, least_upper_dtype
+from tinygrad.helpers import argsort, DType, dtypes
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer

@@ -35,7 +35,7 @@ class Neg(Function):
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.cast(least_upper_float(x.dtype)).e(UnaryOps.SIN)
+    return x.e(UnaryOps.SIN)

   def backward(self, grad:LazyBuffer) -> LazyBuffer:
     return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)

@@ -47,19 +47,19 @@ class Relu(Function):
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
+    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)

 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.cast(ftype:= least_upper_float(x.dtype)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2)))
+    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.e(BinaryOps.DIV, self.x)

 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(ftype:=least_upper_float(x.dtype)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2)
+    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:

@@ -67,7 +67,7 @@ class Exp(Function):

 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(least_upper_float(x.dtype)).e(UnaryOps.SQRT)
+    self.ret = x.e(UnaryOps.SQRT)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:

@@ -78,7 +78,6 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    x = x.cast(least_upper_float(x.dtype))
     self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
     return self.ret

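Every removal in the mlops hunks above is the same edit: Sin, Log, Exp, Sqrt and Sigmoid stop casting because the Tensor-level wrappers later in this diff now hand them float inputs. The one cast that is added, in Relu.backward, turns the bool CMPLT mask back into the gradient's dtype before the multiply. A small hedged sanity check of that gradient path, using the public API:

from tinygrad.tensor import Tensor

x = Tensor([-1.0, 2.0], requires_grad=True)
x.relu().sum().backward()
# relu's gradient is the 0/1 mask from CMPLT; after this change that bool mask is
# cast to the gradient's dtype before the multiply, so the grad stays float
print(x.grad.numpy())   # expected: [0., 1.]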
@@ -89,60 +88,58 @@ class Sigmoid(Function):

 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    output_dtype = least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.CMPLT, y.cast(output_dtype))
+    # in webgpu bool cannot be used as a storage buffer type
+    return x.e(BinaryOps.CMPLT, y).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)

 class Xor(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    output_dtype = least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.XOR, y.cast(output_dtype))
+    return x.e(BinaryOps.XOR, y)

 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.ADD, y.cast(output_dtype))
+    return x.e(BinaryOps.ADD, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y_dtype) if self.needs_input_grad[1] else None
+    return grad_output if self.needs_input_grad[0] else None, \
+           grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.SUB, y.cast(output_dtype))
+    self.x_dtype, self.y_dtype = x.dtype, y.dtype
+    return x.e(BinaryOps.SUB, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y_dtype).e(UnaryOps.NEG) if self.needs_input_grad[1] else None
+    return grad_output if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None

 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.MUL, y.cast(output_dtype))
+    self.x, self.y = x, y
+    return x.e(BinaryOps.MUL, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return self.y.cast(self.x.dtype).e(BinaryOps.MUL, grad_output.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \
-           self.x.cast(self.y.dtype).e(BinaryOps.MUL, grad_output.cast(self.y.dtype)) if self.needs_input_grad[1] else None
+    return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
+           self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

 class Div(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.DIV, y.cast(output_dtype))
+    self.x, self.y = x, y
+    return x.e(BinaryOps.DIV, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x.dtype).e(BinaryOps.DIV, self.y.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y.dtype).e(UnaryOps.NEG).e(BinaryOps.MUL, self.x.cast(self.y.dtype)).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501
+    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None # noqa: E501

 # ************* ternary ops *************

 class Where(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
-    self.x, self.y_dtype, self.z_dtype, output_type = x.cast(dtypes.bool), y.dtype, z.dtype, least_upper_dtype(y.dtype, z.dtype)
-    return self.x.e(TernaryOps.WHERE, y.cast(output_type), z.cast(output_type))
+    self.x = x
+    return self.x.e(TernaryOps.WHERE, y, z)

   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
     return None, \
-           self.x.e(TernaryOps.WHERE, grad_output.cast(self.y_dtype), grad_output.cast(self.y_dtype).const(0)) if self.needs_input_grad[1] else None, \
+           self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
            self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None

 # ************* reduce ops *************

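In this hunk the elementwise Functions stop casting entirely: Less pins its output to bool (float on WEBGPU, where bool cannot back a storage buffer), Xor/Add/Sub/Mul/Div simply assume their operands already share a dtype, and Where now receives an already-bool condition. A hedged sketch of what that looks like from the Tensor API of this era (the dtype prints are expectations, not guarantees):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1, 2, 3], dtype=dtypes.int32)
b = Tensor([2, 2, 2], dtype=dtypes.int32)

print((a < b).dtype)       # expected: dtypes.bool (dtypes.float on WEBGPU, per the comment above)
print(a.xor(b).numpy())    # [3, 0, 1] -- Xor runs directly on the already-matching int32 operands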
@@ -7,7 +7,7 @@ from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np

-from tinygrad.helpers import DType, dtypes, ImageDType
+from tinygrad.helpers import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
 from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
 from tinygrad.lazy import LazyBuffer, create_schedule
 from tinygrad.ops import LoadOps

@@ -249,7 +249,7 @@ class Tensor:

     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor(1, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, device=self.device, requires_grad=False)

     for t0 in reversed(self.deepwalk()):
       assert (t0.grad is not None)

@@ -670,14 +670,14 @@ class Tensor:
   def neg(self): return mlops.Neg.apply(self)
   def contiguous(self): return mlops.Contiguous.apply(self)
   def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
-  def log(self): return mlops.Log.apply(self)
-  def log2(self): return mlops.Log.apply(self)/math.log(2)
-  def exp(self): return mlops.Exp.apply(self)
+  def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
+  def log2(self): return self.log()/math.log(2)
+  def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
   def exp2(self): return mlops.Exp.apply(self*math.log(2))
   def relu(self): return mlops.Relu.apply(self)
-  def sigmoid(self): return mlops.Sigmoid.apply(self)
-  def sin(self): return mlops.Sin.apply(self)
-  def sqrt(self): return mlops.Sqrt.apply(self)
+  def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
+  def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
+  def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
   def rsqrt(self): return self.reciprocal().sqrt()
   def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
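Besides the import change, the tensor.py hunks above seed backward's implicit gradient with an explicit float (Tensor(1.0)) and make the unary wrappers cast to least_upper_float before applying the mlops Function — exactly the cast that was deleted from mlops above. A hedged usage sketch (import paths as in this diff's era; exact default float dtype may depend on settings):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

t = Tensor([1, 2, 3], dtype=dtypes.int32)
# the float cast now happens here in the wrapper, not inside mlops.Sqrt
print(t.sqrt().dtype)    # expected: a float dtype (float32 by default)
print(t.sqrt().numpy())  # approximately [1.0, 1.4142, 1.7321]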
@@ -729,6 +729,9 @@ class Tensor:
       x = x.cast(y_dtype)
       y = Tensor(y, self.device, y_dtype, requires_grad=False)

+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    x, y = x.cast(output_dtype), y.cast(output_dtype)
+
     if reverse: x, y = y, x

     # left pad shape with 1s
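_broadcasted now promotes both operands to their least_upper_dtype after turning a python scalar into a Tensor, so every binary mlops Function receives operands of a single dtype. A hedged sketch of the visible effect (assuming the usual int32 joined with float32 gives float32):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

a = Tensor([1, 2, 3], dtype=dtypes.int32)
b = Tensor([0.5, 0.5, 0.5], dtype=dtypes.float32)
c = a + b
print(c.dtype)    # expected: dtypes.float32 -- least_upper_dtype(int32, float32)
print(c.numpy())  # [1.5, 2.5, 3.5]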
@@ -784,7 +787,7 @@ class Tensor:
   def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
     x_,y = self._broadcasted(input_)
     x,z = x_._broadcasted(other)
-    return mlops.Where.apply(x, *y._broadcasted(z))
+    return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))

 # ***** op wrappers (wasted lines to make the typechecker happy) *****

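Tensor.where now casts the broadcast condition to bool itself, since mlops.Where no longer does, so a 0/1 mask of any dtype should still work as a condition. An illustrative, hedged sketch:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

cond = Tensor([1, 0, 1], dtype=dtypes.int32)   # not a bool tensor
picked = cond.where(Tensor([10, 10, 10], dtype=dtypes.float32),
                    Tensor([20, 20, 20], dtype=dtypes.float32))
# the condition is cast to bool inside Tensor.where before mlops.Where runs
print(picked.numpy())   # expected: [10., 20., 10.]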
@@ -860,9 +863,9 @@ class Tensor:

   def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
     # NOTE: self is a logits input
-    loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501
-    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+    loss_mask = (Y != ignore_index).cast(dtypes.float)
+    y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1]) # noqa: E501
+    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     return self.log_softmax().mul(y).sum() / loss_mask.sum()

 # ***** cast ops *****

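Because comparisons now produce bool tensors, sparse_categorical_crossentropy casts its loss mask to float before using it in arithmetic and drops the explicit int32 arange dtype. A small hedged worked example of the masked loss (the expected value is hand-computed from the log-softmax of the first row; treat it as approximate):

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

logits = Tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])
Y = Tensor([0, -1], dtype=dtypes.int32)   # second sample hits ignore_index=-1 and is masked out
loss = logits.sparse_categorical_crossentropy(Y)
# only the first sample contributes: -log_softmax(row 0)[class 0] / 1 ≈ 0.317
print(loss.numpy())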