From 8902764167f346aec68fad8a51660d1b99ac3727 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Thu, 2 Mar 2023 08:15:26 -0800
Subject: [PATCH] fix nits in compare

---
 README.md                     | 10 +++++-----
 tinygrad/jit.py               |  1 +
 tinygrad/mlops.py             |  4 ++--
 tinygrad/runtime/ops_cpu.py   |  2 +-
 tinygrad/runtime/ops_torch.py |  2 +-
 5 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index fd3ebe53..62e787c3 100644
--- a/README.md
+++ b/README.md
@@ -129,13 +129,13 @@ hlops are syntactic sugar around mlops. They support most things torch does.
 
 ### mlops
 
-mlops are mid level ops, there's 16 of them. They understand derivatives. They are very simple.
+mlops are mid level ops. They understand derivatives. They are very simple.
 
 ```
-Log, Exp                                                    # unary ops
-Sum, Max                                                    # reduce ops (with axis argument)
-Maximum, Add, Sub, Mul, Pow, Div                            # binary ops (no broadcasting, use expand)
-Expand, Reshape, Permute, Pad, Shrink, Flip                 # movement ops
+Log, Exp                                                    # unary ops
+Sum, Max                                                    # reduce ops (with axis argument)
+Maximum, Add, Sub, Mul, Pow, Div, CompareLess, CompareEqual # binary ops (no broadcasting, use expand)
+Expand, Reshape, Permute, Pad, Shrink, Flip                 # movement ops
 ```
 
 You no longer need to write mlops for a new accelerator

diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index 87ab04d9..885e49e7 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -1,6 +1,7 @@
 from typing import Callable, List, Tuple, Any, Dict, cast, Union
 import itertools
 from tinygrad.helpers import DEBUG, colored
+from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer

diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 959a3fe6..75d118a2 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -51,10 +51,10 @@ class Max(Function):
     return max_is_amount.binary_op(BinaryOps.MUL, grad_output_expanded)
 
 # ************* binary ops *************
+
 class CompareLess(Function):
   def forward(self, x, y):
-    self.ret = x.binary_op(BinaryOps.CMPLT, y)
-    return self.ret
+    return x.binary_op(BinaryOps.CMPLT, y)
 
 class CompareEqual(Function):
   def forward(self, x, y):

diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py
index 7c7f2a49..f638e92b 100644
--- a/tinygrad/runtime/ops_cpu.py
+++ b/tinygrad/runtime/ops_cpu.py
@@ -24,7 +24,7 @@ def einsum_mulacc(einsum, get_strides, expand):
 numpy_fxn_for_op : Dict[Op, Callable] = {**base_fxn_for_op, **{
   UnaryOps.NOOP: lambda x: np.ascontiguousarray(x), UnaryOps.EXP: lambda x: np.exp(x), UnaryOps.LOG: lambda x: np.log(x),
-  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x, y: (x<y).astype(np.float32),
+  BinaryOps.MAX: np.maximum, BinaryOps.CMPEQ: lambda x,y: (x==y).astype(np.float32), BinaryOps.CMPLT: lambda x,y: (x<y).astype(np.float32),
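
For reviewers who want to sanity-check the comparison semantics outside of tinygrad: below is a minimal numpy-only sketch of the float32-mask convention the CMPLT/CMPEQ lambdas in ops_cpu.py follow (the `cmplt`/`cmpeq` names here are illustrative, not tinygrad API). It also explains the mlops.py change: CompareLess.forward stops stashing `self.ret`, presumably because the comparison ops define no backward pass, so there is nothing to save for gradient computation.

```python
import numpy as np

# Comparison ops return float32 masks: 1.0 where the predicate holds,
# 0.0 elsewhere, so the result stays an ordinary float tensor.
cmplt = lambda x, y: (x < y).astype(np.float32)
cmpeq = lambda x, y: (x == y).astype(np.float32)

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([2.0, 2.0, 2.0], dtype=np.float32)

print(cmplt(x, y))  # [1. 0. 0.]
print(cmpeq(x, y))  # [0. 1. 0.]
```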