fix nits in compare

George Hotz 2023-03-02 08:15:26 -08:00
parent 52204a7b88
commit 8902764167
5 changed files with 10 additions and 9 deletions

View File

@@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 ### mlops
-mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple.
+mlops are mid level ops. They understand derivatives. They are very simple.
 ```
 Log, Exp                                      # unary ops
 Sum, Max                                      # reduce ops (with axis argument)
-Maximum, Add, Sub, Mul, Pow, Div              # binary ops (no broadcasting, use expand)
+Maximum, Add, Sub, Mul, Pow, Div, CompareLess, CompareEqual  # binary ops (no broadcasting, use expand)
 Expand, Reshape, Permute, Pad, Shrink, Flip   # movement ops
 ```
 You no longer need to write mlops for a new accelerator
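The README change documents the two new comparison mlops. They follow the same binary-op contract as the rest of the list: inputs must already share a shape (no broadcasting, expand first) and the result is a 0/1 float mask. A minimal NumPy sketch of that contract, purely illustrative rather than tinygrad code, assuming float32 masks as in the CPU backend below:

```
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
y = np.array([[2.0]], dtype=np.float32)

# "no broadcasting, use expand": expand y to x's shape explicitly first
y_expanded = np.broadcast_to(y, x.shape)

cmplt = (x < y_expanded).astype(np.float32)   # CompareLess: 1.0 where x < y
cmpeq = (x == y_expanded).astype(np.float32)  # CompareEqual: 1.0 where x == y
print(cmplt)
print(cmpeq)
```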

View File

@@ -1,6 +1,7 @@
 from typing import Callable, List, Tuple, Any, Dict, cast, Union
 import itertools
 from tinygrad.helpers import DEBUG, colored
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer

View File

@@ -51,10 +51,10 @@ class Max(Function):
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
 # ************* binary ops *************
 class CompareLess(Function):
   def forward(self, x, y):
-    self.ret = x.binary_op(BinaryOps.CMPLT, y)
-    return self.ret
+    return x.binary_op(BinaryOps.CMPLT, y)
 class CompareEqual(Function):
   def forward(self, x, y):
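The mlops.py nit: CompareLess.forward used to stash its result in self.ret before returning it, but nothing else reads that attribute, so returning the buffer directly is enough. A rough NumPy stand-in for the pattern, not the tinygrad Function class itself, with the zero-gradient assumption made explicit:

```
import numpy as np

class CompareLessSketch:
  # hypothetical stand-in: the real Function dispatches BinaryOps.CMPLT
  # to the device buffer instead of calling NumPy
  def forward(self, x, y):
    # no self.ret: the mask is never consulted again
    return (x < y).astype(np.float32)

  def backward(self, grad_output):
    # assumption for this sketch: comparisons pass no gradient back
    return None, None

f = CompareLessSketch()
print(f.forward(np.array([1.0, 3.0]), np.array([2.0, 2.0])))  # [1. 0.]
```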

View File

@@ -24,7 +24,7 @@ def einsum_mulacc(einsum, get_strides, expand):
 numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: np.ascontiguousarray(x), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x, y: (x<y).astype(np.float32),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.float32),
   MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
   MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to)
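Both whitespace fixes land in these op tables: dictionaries that map each low-level op to a backend callable, so the runtime can execute a graph with plain dictionary lookups. A stripped-down sketch of that dispatch pattern, using a hypothetical two-entry enum rather than tinygrad's real BinaryOps:

```
from enum import Enum, auto
from typing import Callable, Dict
import numpy as np

class BinOp(Enum):  # hypothetical mini enum standing in for BinaryOps
  CMPLT = auto()
  CMPEQ = auto()

fxn_for_op: Dict[BinOp, Callable] = {
  BinOp.CMPLT: lambda x, y: (x < y).astype(np.float32),
  BinOp.CMPEQ: lambda x, y: (x == y).astype(np.float32),
}

def apply_op(op: BinOp, x: np.ndarray, y: np.ndarray) -> np.ndarray:
  return fxn_for_op[op](x, y)  # executing an op is a dict lookup plus a call

print(apply_op(BinOp.CMPLT, np.array([1.0, 3.0]), np.array([2.0, 2.0])))  # [1. 0.]
```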

View File

@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
-  BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), BinaryOps.CMPLT: lambda x, y:(x<y).float(),
+  BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), BinaryOps.CMPLT: lambda x,y: (x<y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(torch.einsum, lambda x: x.stride(), lambda x,s: x.expand(s))
 }}
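The torch table gets the same lambda-spacing nit. The subtler entry next to it is MovementOps.PAD, which converts a per-dimension ((before, after), ...) spec into the flat, last-dimension-first list that torch.nn.functional.pad expects. A small standalone check of that conversion, with a made-up padding spec for illustration:

```
import torch

x = torch.zeros(2, 3)
padding = ((1, 1), (2, 0))  # (before, after) per dimension, NumPy-style order

# reverse the dimension order, then flatten, as the table entry does
flat = [item for sublist in padding[::-1] for item in sublist]
print(flat)  # [2, 0, 1, 1]

print(torch.nn.functional.pad(x, flat).shape)  # torch.Size([4, 5])
```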