mirror of https://github.com/commaai/tinygrad.git
fix nits in compare
This commit is contained in:
parent 52204a7b88
commit 8902764167

README.md | 10
@@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 
 ### mlops
 
-mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple.
+mlops are mid level ops. They understand derivatives. They are very simple.
 
 ```
 Log, Exp # unary ops
 Sum, Max # reduce ops (with axis argument)
-Maximum, Add, Sub, Mul, Pow, Div # binary ops (no broadcasting, use expand)
+Maximum, Add, Sub, Mul, Pow, Div, CompareLess, CompareEqual # binary ops (no broadcasting, use expand)
 Expand, Reshape, Permute, Pad, Shrink, Flip # movement ops
 ```
 
 You no longer need to write mlops for a new accelerator
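For context on "they understand derivatives": an mlop is a small Function whose forward runs one low-level op and whose backward returns a gradient for each parent. The sketch below mirrors the style of the CompareLess/CompareEqual classes changed later in this commit; it is a hedged sketch, not a verbatim copy of tinygrad/mlops.py, and it assumes this era's import layout (Function in tinygrad.tensor, BinaryOps in tinygrad.ops).

```python
# Hedged sketch of an mlop, in the style of the classes changed in this commit.
# Assumes this era's layout: Function in tinygrad.tensor, BinaryOps in tinygrad.ops.
from tinygrad.ops import BinaryOps
from tinygrad.tensor import Function

class Add(Function):
  def forward(self, x, y):
    # forward runs a single low-level binary op on lazy buffers
    return x.binary_op(BinaryOps.ADD, y)

  def backward(self, grad_output):
    # d(x+y)/dx = d(x+y)/dy = 1, so the incoming gradient flows to both parents
    return grad_output, grad_output
```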
@@ -1,6 +1,7 @@
 from typing import Callable, List, Tuple, Any, Dict, cast, Union
 import itertools
 from tinygrad.helpers import DEBUG, colored
+
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer
@@ -51,10 +51,10 @@ class Max(Function):
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
 
 # ************* binary ops *************
 
 class CompareLess(Function):
   def forward(self, x, y):
-    self.ret = x.binary_op(BinaryOps.CMPLT, y)
-    return self.ret
+    return x.binary_op(BinaryOps.CMPLT, y)
 
 class CompareEqual(Function):
   def forward(self, x, y):
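The hunk is cut off here; CompareEqual.forward most likely gets the same treatment as CompareLess, since a comparison result is never reused in backward and does not need to be stashed on self. A hedged sketch of that pattern (not the verbatim hunk), using the BinaryOps.CMPEQ op that appears in the runtime tables below:

```python
class CompareEqual(Function):
  def forward(self, x, y):
    # same nit as CompareLess: return the comparison directly, no self.ret needed
    return x.binary_op(BinaryOps.CMPEQ, y)
```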
@@ -24,7 +24,7 @@ def einsum_mulacc(einsum, get_strides, expand):
 
 numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: np.ascontiguousarray(x), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x, y: (x<y).astype(np.float32),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.float32),
   MovementOps.FLIP: lambda x, axis: np.flip(x, axis), MovementOps.PERMUTE: lambda x, order: x.transpose(order),
   MovementOps.PAD: lambda x, padding: np.pad(x, padding), MovementOps.EXPAND: lambda x, new_shape: np.broadcast_to(x, new_shape),
   FusedOps.MULACC: einsum_mulacc(lambda s,a,b: np.einsum(s, a.copy(), b.copy()), lambda x: x.strides, np.broadcast_to)
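The table above maps each op enum value to a NumPy callable; executing an op on the CPU backend is essentially a dict lookup followed by a call, with comparisons cast back to float32 so the result stays a float buffer. A self-contained sketch of that dispatch idea (MiniBinaryOps, fxn_for_op, and apply are illustrative names, not tinygrad APIs):

```python
# Illustrative stand-in for the op -> callable dispatch table above.
from enum import Enum, auto
import numpy as np

class MiniBinaryOps(Enum):
  MUL = auto(); CMPLT = auto(); CMPEQ = auto()

fxn_for_op = {
  MiniBinaryOps.MUL: lambda x, y: x * y,
  MiniBinaryOps.CMPLT: lambda x, y: (x < y).astype(np.float32),   # bools become 0.0/1.0, as above
  MiniBinaryOps.CMPEQ: lambda x, y: (x == y).astype(np.float32),
}

def apply(op, x, y):
  # the CPU runtime is essentially this lookup-and-call
  return fxn_for_op[op](x, y)

a, b = np.array([1., 2., 3.]), np.array([3., 2., 1.])
print(apply(MiniBinaryOps.CMPLT, a, b))  # [1. 0. 0.]
print(apply(MiniBinaryOps.CMPEQ, a, b))  # [0. 1. 0.]
```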
@@ -6,7 +6,7 @@ from tinygrad.runtime.ops_cpu import base_fxn_for_op, einsum_mulacc
 
 torch_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: x.contiguous(), UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(),
-  BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), BinaryOps.CMPLT: lambda x, y:(x<y).float(),
+  BinaryOps.MAX: torch.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).float(), BinaryOps.CMPLT: lambda x,y: (x<y).float(),
   MovementOps.PAD: lambda x, padding: torch.nn.functional.pad(x, [item for sublist in padding[::-1] for item in sublist]),
   FusedOps.MULACC: einsum_mulacc(torch.einsum, lambda x: x.stride(), lambda x,s: x.expand(s))
 }}
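The PAD entry is the least obvious lambda in this table: torch.nn.functional.pad takes a flat sequence ordered from the last dimension to the first, while the padding passed in is one (before, after) pair per axis in axis order, hence the [::-1] reverse-and-flatten. A standalone check of that conversion (x, padding, and flat are just local names for this sketch):

```python
# Standalone check of the padding-order conversion used in the PAD lambda above.
import torch

x = torch.zeros(2, 3)
padding = ((1, 2), (3, 4))  # axis 0: pad 1 before / 2 after; axis 1: pad 3 before / 4 after

# F.pad wants (last_dim_before, last_dim_after, ..., first_dim_before, first_dim_after)
flat = [item for sublist in padding[::-1] for item in sublist]
out = torch.nn.functional.pad(x, flat)

print(flat)       # [3, 4, 1, 2]
print(out.shape)  # torch.Size([5, 10]) -> (2+1+2, 3+3+4)
```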