move the op casting logic from mlops to tensor try 2 (#2887)

* unary works

* where works

* add sub mul

* xor div

* CMPLT

* sparse_categorical_crossentropy

* image const

* sparse_categorical_crossentropy
chenyu 2023-12-20 23:50:37 -05:00 committed by GitHub
parent 7da2325dc7
commit 8a04107d30
4 changed files with 47 additions and 45 deletions

View File

@@ -107,7 +107,7 @@ def Atan(y: Tensor):

 def Trilu(x: Tensor, k: Union[Tensor, int]=0, upper=1):
   k = int(k.numpy().item()) if isinstance(k, Tensor) else 0 # onnx passes k as a tensor int64 with one element, default is 0
-  return x.triu(k) if upper else x.tril(k)
+  return x.triu(k).cast(dtypes.int64) if upper else x.tril(k).cast(dtypes.int64)

 def Squeeze(data: Tensor, axes):
   if isinstance(axes, Tensor): axes = safe_numpy(axes)
@@ -122,7 +122,7 @@ def Unsqueeze(data: Tensor, axes):
       new_shape[i] = next(ptr)
   return data.reshape(new_shape)

-def Binarizer(input, threshold=0.0): return input > threshold
+def Binarizer(input, threshold=0.0): return (input > threshold).cast(dtypes.float32)

 def ArgMax(x: Tensor, axis=0, keepdims=1, select_last_index=0):
   axis = axis + x.ndim if axis < 0 else axis
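Since comparisons now come back as bool (see the mlops and tensor changes below), the ONNX handlers cast their results to the dtypes the ONNX spec expects. A minimal illustration of the Binarizer case, assuming a backend with bool tensor support; values are made up:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x = Tensor([0.5, -1.0, 2.0])
mask = x > 0.0                    # comparison now yields a bool tensor
out = mask.cast(dtypes.float32)   # Binarizer's ONNX output dtype is float32
print(mask.dtype, out.dtype)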

View File

@@ -59,7 +59,9 @@ class LazyBuffer:
     return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else ())

   def const(self, val:Union[float, int]) -> LazyBuffer:
-    return LazyBuffer.loadop(LoadOps.CONST, (), self.dtype, self.device, val).reshape((1,)*len(self.shape)).expand(self.shape)
+    # NOTE: we force the image dtype const to be a float32
+    const_dtype = self.dtype if not isinstance(self.dtype, ImageDType) else dtypes.float32
+    return LazyBuffer.loadop(LoadOps.CONST, tuple(), const_dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)

   def contiguous(self):
     if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
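The const path now pins the constant's dtype: ordinary buffers keep their own dtype, while ImageDType buffers fall back to float32 (per the NOTE in the diff). A standalone sketch of that rule; const_dtype_for is a hypothetical helper, not tinygrad API:

from tinygrad.helpers import dtypes, ImageDType

def const_dtype_for(buf_dtype):
  # hypothetical: mirrors the rule above, image dtypes fall back to float32 for consts
  return dtypes.float32 if isinstance(buf_dtype, ImageDType) else buf_dtype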

View File

@@ -1,6 +1,6 @@
 import math
 from typing import Tuple, Optional, cast
-from tinygrad.helpers import argsort, DType, least_upper_float, dtypes, least_upper_dtype
+from tinygrad.helpers import argsort, DType, dtypes
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
@@ -35,7 +35,7 @@ class Neg(Function):
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.cast(least_upper_float(x.dtype)).e(UnaryOps.SIN)
+    return x.e(UnaryOps.SIN)

   def backward(self, grad:LazyBuffer) -> LazyBuffer:
     return self.x.const(math.pi / 2).e(BinaryOps.SUB, self.x).e(UnaryOps.SIN).e(BinaryOps.MUL, grad)
@@ -47,19 +47,19 @@ class Relu(Function):
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
+    return self.ret.const(0).e(BinaryOps.CMPLT, self.ret).cast(grad_output.dtype).e(BinaryOps.MUL, grad_output)

 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.cast(ftype:= least_upper_float(x.dtype)).e(UnaryOps.LOG2).e(BinaryOps.MUL, x.cast(ftype).const(math.log(2)))
+    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
     return grad_output.e(BinaryOps.DIV, self.x)

 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(ftype:=least_upper_float(x.dtype)).e(BinaryOps.MUL, x.cast(ftype).const(1/math.log(2))).e(UnaryOps.EXP2)
+    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -67,7 +67,7 @@ class Exp(Function):

 class Sqrt(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.cast(least_upper_float(x.dtype)).e(UnaryOps.SQRT)
+    self.ret = x.e(UnaryOps.SQRT)
     return self.ret

   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
@@ -78,7 +78,6 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    x = x.cast(least_upper_float(x.dtype))
     self.ret = x.const(1).e(BinaryOps.DIV, x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)))
     return self.ret

@@ -89,60 +88,58 @@ class Sigmoid(Function):
 class Less(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    output_dtype = least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.CMPLT, y.cast(output_dtype))
+    # in webgpu bool cannot be used as a storage buffer type
+    return x.e(BinaryOps.CMPLT, y).cast(dtypes.float if self.device == "WEBGPU" else dtypes.bool)

 class Xor(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    output_dtype = least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.XOR, y.cast(output_dtype))
+    return x.e(BinaryOps.XOR, y)

 class Add(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.ADD, y.cast(output_dtype))
+    return x.e(BinaryOps.ADD, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y_dtype) if self.needs_input_grad[1] else None
+    return grad_output if self.needs_input_grad[0] else None, \
+           grad_output if self.needs_input_grad[1] else None

 class Sub(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x_dtype, self.y_dtype, output_dtype = x.dtype, y.dtype, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.SUB, y.cast(output_dtype))
+    self.x_dtype, self.y_dtype = x.dtype, y.dtype
+    return x.e(BinaryOps.SUB, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x_dtype) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y_dtype).e(UnaryOps.NEG) if self.needs_input_grad[1] else None
+    return grad_output if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None

 class Mul(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.MUL, y.cast(output_dtype))
+    self.x, self.y = x, y
+    return x.e(BinaryOps.MUL, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return self.y.cast(self.x.dtype).e(BinaryOps.MUL, grad_output.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \
-           self.x.cast(self.y.dtype).e(BinaryOps.MUL, grad_output.cast(self.y.dtype)) if self.needs_input_grad[1] else None
+    return self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None, \
+           self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None

 class Div(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
-    self.x, self.y, output_dtype = x, y, least_upper_dtype(x.dtype, y.dtype)
-    return x.cast(output_dtype).e(BinaryOps.DIV, y.cast(output_dtype))
+    self.x, self.y = x, y
+    return x.e(BinaryOps.DIV, y)

   def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    return grad_output.cast(self.x.dtype).e(BinaryOps.DIV, self.y.cast(self.x.dtype)) if self.needs_input_grad[0] else None, \
-           grad_output.cast(self.y.dtype).e(UnaryOps.NEG).e(BinaryOps.MUL, self.x.cast(self.y.dtype)).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None  # noqa: E501
+    return grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None, \
+           grad_output.e(UnaryOps.NEG).e(BinaryOps.MUL, self.x).e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y)) if self.needs_input_grad[1] else None  # noqa: E501

 # ************* ternary ops *************

 class Where(Function):
   def forward(self, x:LazyBuffer, y:LazyBuffer, z:LazyBuffer) -> LazyBuffer:
-    self.x, self.y_dtype, self.z_dtype, output_type = x.cast(dtypes.bool), y.dtype, z.dtype, least_upper_dtype(y.dtype, z.dtype)
-    return self.x.e(TernaryOps.WHERE, y.cast(output_type), z.cast(output_type))
+    self.x = x
+    return self.x.e(TernaryOps.WHERE, y, z)

   def backward(self, grad_output:LazyBuffer) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
     return None, \
-           self.x.e(TernaryOps.WHERE, grad_output.cast(self.y_dtype), grad_output.cast(self.y_dtype).const(0)) if self.needs_input_grad[1] else None, \
+           self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0)) if self.needs_input_grad[1] else None, \
            self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output) if self.needs_input_grad[2] else None

 # ************* reduce ops *************
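After this change the mlops Functions assume their LazyBuffer inputs already share a dtype, and CMPLT yields a bool buffer (float on WEBGPU), which is why Relu's backward casts the mask back to the gradient's dtype. A small illustration of the user-visible effect, assuming the default backend; values are made up:

from tinygrad.tensor import Tensor

a = Tensor([1.0, 2.0, 3.0])
b = Tensor([2.0, 1.0, 4.0])
print((a < b).dtype)  # dtypes.bool on most backends after this change (dtypes.float on WEBGPU)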

View File

@@ -7,7 +7,7 @@ from functools import partialmethod, reduce
 from itertools import accumulate
 import numpy as np

-from tinygrad.helpers import DType, dtypes, ImageDType
+from tinygrad.helpers import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
 from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
 from tinygrad.lazy import LazyBuffer, create_schedule
 from tinygrad.ops import LoadOps
@@ -249,7 +249,7 @@ class Tensor:
     # fill in the first grad with one. don't use Tensor.ones because we don't need contiguous
     # this is "implicit gradient creation"
-    self.grad = Tensor(1, device=self.device, requires_grad=False)
+    self.grad = Tensor(1.0, device=self.device, requires_grad=False)

    for t0 in reversed(self.deepwalk()):
      assert (t0.grad is not None)
@@ -670,14 +670,14 @@ class Tensor:
   def neg(self): return mlops.Neg.apply(self)
   def contiguous(self): return mlops.Contiguous.apply(self)
   def contiguous_backward(self): return mlops.ContiguousBackward.apply(self)
-  def log(self): return mlops.Log.apply(self)
-  def log2(self): return mlops.Log.apply(self)/math.log(2)
-  def exp(self): return mlops.Exp.apply(self)
+  def log(self): return mlops.Log.apply(self.cast(least_upper_float(self.dtype)))
+  def log2(self): return self.log()/math.log(2)
+  def exp(self): return mlops.Exp.apply(self.cast(least_upper_float(self.dtype)))
   def exp2(self): return mlops.Exp.apply(self*math.log(2))
   def relu(self): return mlops.Relu.apply(self)
-  def sigmoid(self): return mlops.Sigmoid.apply(self)
-  def sin(self): return mlops.Sin.apply(self)
-  def sqrt(self): return mlops.Sqrt.apply(self)
+  def sigmoid(self): return mlops.Sigmoid.apply(self.cast(least_upper_float(self.dtype)))
+  def sin(self): return mlops.Sin.apply(self.cast(least_upper_float(self.dtype)))
+  def sqrt(self): return mlops.Sqrt.apply(self.cast(least_upper_float(self.dtype)))
   def rsqrt(self): return self.reciprocal().sqrt()
   def cos(self): return ((math.pi/2)-self).sin()
   def tan(self): return self.sin() / self.cos()
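The float promotion the unary Functions used to do internally now happens once in these wrappers via least_upper_float, so integer inputs are still promoted before reaching the kernels. A quick sketch, assuming the default backend and default float dtype:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

t = Tensor([1, 2, 3], dtype=dtypes.int32)
print(t.exp().dtype)  # promoted to the default float dtype by Tensor.exp, not inside mlops.Exp
print(t.sin().dtype)  # same promotion via least_upper_float(self.dtype)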
@@ -729,6 +729,9 @@ class Tensor:
       x = x.cast(y_dtype)
       y = Tensor(y, self.device, y_dtype, requires_grad=False)

+    output_dtype = least_upper_dtype(x.dtype, y.dtype)
+    x, y = x.cast(output_dtype), y.cast(output_dtype)
+
     if reverse: x, y = y, x

     # left pad shape with 1s
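Binary-op dtype promotion now lives in _broadcasted: both operands are cast to their least_upper_dtype before mlops sees them, which is why Add/Sub/Mul/Div/Xor dropped their casts above. An example of the resulting behavior, with illustrative values:

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes

x = Tensor([1, 2], dtype=dtypes.int32)
y = Tensor([0.5, 0.5], dtype=dtypes.float32)
print((x + y).dtype)  # least_upper_dtype(int32, float32) -> float32, resolved here rather than in mlops.Add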
@@ -784,7 +787,7 @@ class Tensor:
   def where(self:Tensor, input_:Union[Tensor, float], other:Union[Tensor, float]):
     x_,y = self._broadcasted(input_)
     x,z = x_._broadcasted(other)
-    return mlops.Where.apply(x, *y._broadcasted(z))
+    return mlops.Where.apply(x.cast(dtypes.bool), *y._broadcasted(z))

   # ***** op wrappers (wasted lines to make the typechecker happy) *****
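Tensor.where now casts the condition to bool itself, so mlops.Where can assume a bool mask. A small sketch with illustrative values:

from tinygrad.tensor import Tensor

cond = Tensor([1, 0, 1])                   # not a bool tensor
out = cond.where(Tensor([10, 20, 30]), 0)  # where() casts cond to bool before mlops.Where
print(out.numpy())                         # [10, 0, 30] (output dtype follows the two branches)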
@@ -860,9 +863,9 @@ class Tensor:
   def sparse_categorical_crossentropy(self, Y:Tensor, ignore_index=-1) -> Tensor:
     # NOTE: self is a logits input
-    loss_mask = Y != ignore_index
-    y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])  # noqa: E501
-    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
+    loss_mask = (Y != ignore_index).cast(dtypes.float)
+    y_counter = Tensor.arange(self.shape[-1], requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])  # noqa: E501
+    y = ((y_counter == Y.flatten().reshape(-1, 1)).where(-1, 0) * loss_mask.reshape(-1, 1)).reshape(*Y.shape, self.shape[-1])
     return self.log_softmax().mul(y).sum() / loss_mask.sum()

 # ***** cast ops *****
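With Y != ignore_index now producing a bool mask, the loss casts it to float before using it for masking and for the normalizing sum. A usage sketch with made-up shapes and labels:

from tinygrad.tensor import Tensor

logits = Tensor.randn(4, 10)    # batch of 4, 10 classes
labels = Tensor([3, 7, -1, 0])  # the -1 entry is dropped via ignore_index
loss = logits.sparse_categorical_crossentropy(labels, ignore_index=-1)
print(loss.numpy())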