fix nits in compare

George Hotz 2023-03-02 08:15:26 -08:00
parent 52204a7b88
commit 8902764167
5 changed files with 10 additions and 9 deletions


@@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 ### mlops
-mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple.
+mlops are mid level ops. They understand derivatives. They are very simple.
 ```
 Log, Exp # unary ops
 Sum, Max # reduce ops (with axis argument)
-Maximum, Add, Sub, Mul, Pow, Div # binary ops (no broadcasting, use expand)
+Maximum, Add, Sub, Mul, Pow, Div, CompareLess, CompareEqual # binary ops (no broadcasting, use expand)
 Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops
 ```
 You no longer need to write mlops for a new accelerator
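For readers new to the codebase: an mlop is just a `Function` subclass with a `forward` on lazy buffers and a `backward` for the derivative, as the `CompareLess` hunk further down shows. Below is a minimal illustrative sketch of that pattern; it is not part of this commit, and the import path for `Function` is an assumption.

```
# minimal sketch of an mlop, following the Function pattern visible in mlops.py;
# not part of this commit -- the import path for Function is an assumption
from tinygrad.ops import BinaryOps
from tinygrad.tensor import Function  # assumed location of the Function base class

class Mul(Function):
  def forward(self, x, y):
    self.x, self.y = x, y                 # keep the inputs around for backward
    return x.binary_op(BinaryOps.MUL, y)
  def backward(self, grad_output):
    # d(x*y)/dx = y and d(x*y)/dy = x, each scaled by the incoming gradient
    return self.y.binary_op(BinaryOps.MUL, grad_output), \
           self.x.binary_op(BinaryOps.MUL, grad_output)
```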


@@ -1,6 +1,7 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union
import itertools
from tinygrad.helpers import DEBUG, colored
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer


@@ -51,10 +51,10 @@ class Max(Function):
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
 # ************* binary ops *************
 class CompareLess(Function):
   def forward(self, x, y):
-    self.ret = x.binary_op(BinaryOps.CMPLT, y)
-    return self.ret
+    return x.binary_op(BinaryOps.CMPLT, y)
 class CompareEqual(Function):
   def forward(self, x, y):
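The nit in `CompareLess` is that `forward` was stashing `self.ret` even though nothing else appears to use it, so it can simply return the comparison result. At the tensor level these ops surface as 0/1 float masks, matching the CMPLT/CMPEQ backend lambdas below. A hedged usage sketch; the comparison operator overload on `Tensor` is an assumption, not something this commit shows:

```
# hedged sketch: comparisons come back as float masks of 0.0 / 1.0;
# the < overload on Tensor is an assumption, not shown in this commit
from tinygrad.tensor import Tensor

a = Tensor([1.0, 5.0, 3.0])
b = Tensor([2.0, 2.0, 3.0])
mask = a < b                 # routed through CompareLess / BinaryOps.CMPLT
print(mask.numpy())          # expected: [1. 0. 0.]
```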


@@ -24,7 +24,7 @@ def einsum_mulacc(einsum, get_strides, expand):
 numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: np.ascontiguousarray(x), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x, y: (x<y).astype(np.float32),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.float32),
   MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
   MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to)
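Besides the elementwise entries, the table also wires `FusedOps.MULACC` to an einsum-based multiply-accumulate. A self-contained sketch of what that fusion computes in plain numpy; the spec string here is illustrative, not the one tinygrad generates:

```
# hedged sketch of the multiply-accumulate that FusedOps.MULACC expresses via np.einsum;
# the "ik,kj->ij" spec is illustrative -- tinygrad builds its own spec string
import numpy as np

a = np.random.rand(2, 3).astype(np.float32)
b = np.random.rand(3, 4).astype(np.float32)
out = np.einsum("ik,kj->ij", a, b)   # multiply, then accumulate over the shared axis k
assert np.allclose(out, a @ b)       # i.e. exactly a matrix multiply
```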


@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
-  BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), BinaryOps.CMPLT: lambda x, y:(x<y).float(),
+  BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), BinaryOps.CMPLT: lambda x,y: (x<y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(torch.einsum, lambda x: x.stride(), lambda x,s: x.expand(s))
 }}
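One detail worth calling out in the torch table: `torch.nn.functional.pad` takes a flat `(before, after)` list starting from the last dimension, which is why the `MovementOps.PAD` lambda reverses the per-dimension padding and flattens it. A small sketch; the per-dimension, first-dim-first layout of `padding` is an assumption here:

```
# hedged sketch of the flattening done by the MovementOps.PAD lambda above;
# the ((before, after), ...) first-dim-first layout of `padding` is an assumption
import torch

x = torch.zeros(2, 3)
padding = ((1, 1), (0, 2))                                       # pad dim0 by (1,1), dim1 by (0,2)
flat = [item for sublist in padding[::-1] for item in sublist]   # -> [0, 2, 1, 1]
print(torch.nn.functional.pad(x, flat).shape)                    # torch.Size([4, 5])
```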