fast mnist indexing (#5921)

* fast mnist indexing

* more tests

* remove those tests, new indexing rule
George Hotz 2024-08-05 13:55:15 -07:00 committed by GitHub
parent e81c18f494
commit 5d17f54e3c
5 changed files with 29 additions and 6 deletions

View File

@@ -82,12 +82,12 @@ class TestIndexing(unittest.TestCase):
#assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
np.testing.assert_allclose(real_index, X.numpy())
def test_index_fused(self):
def test_index_fused(self, noopt=1):
dataset = Tensor.rand(16384, 256).realize()
idxs = Tensor([0,3,5,6]).realize()
real_index = dataset.numpy()[idxs.numpy()]
print("*** indexing ***")
with Context(NOOPT=1, FUSE_ARANGE=1):
with Context(NOOPT=noopt, FUSE_ARANGE=1):
GlobalCounters.reset()
X = dataset[idxs]
assert X.shape == (4,256)
@@ -96,6 +96,23 @@ class TestIndexing(unittest.TestCase):
run_schedule(sched)
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
np.testing.assert_allclose(real_index, X.numpy())
@unittest.skip("not ready")
def test_index_fused_opt(self): self.test_index_fused(0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
def test_index_mnist(self, noopt=1):
from tinygrad.nn.datasets import mnist
X_train, Y_train, _, _ = mnist()
with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
GlobalCounters.reset()
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
x = X_train[samples].numpy()
y = Y_train[samples].numpy()
assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
@unittest.skip("not ready")
def test_index_mnist_opt(self): self.test_index_mnist(0)
if __name__ == "__main__":
unittest.main()
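
For readers outside the diff, the new test mirrors the pattern a training loop would use: index a whole dataset tensor with a tensor of random sample indices, with FUSE_ARANGE enabled so the gather fuses into a single kernel instead of materializing large arange buffers. Below is a minimal sketch of that usage, assuming Context and GlobalCounters are importable from tinygrad.helpers as in these tests (the batch size and the final print are illustrative, not part of the diff):

from tinygrad import Tensor
from tinygrad.helpers import Context, GlobalCounters
from tinygrad.nn.datasets import mnist

X_train, Y_train, _, _ = mnist()
with Context(NOOPT=1, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):  # same flags as the test above
  GlobalCounters.reset()
  samples = Tensor.randint(512, high=X_train.shape[0])  # random row indices for one batch
  x = X_train[samples].numpy()                          # fused gather of the sampled images
  y = Y_train[samples].numpy()                          # and their labels
print(GlobalCounters.global_ops)  # stays small when the arange reduce collapses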

View File

@@ -9,7 +9,7 @@ from tinygrad import nn, dtypes
from tinygrad.device import Device
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, flatten, getenv
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
@@ -496,7 +496,7 @@ class TestSchedule(unittest.TestCase):
check_schedule(out, 2)
# multireduce spec
@unittest.skipUnless(getenv("SPLIT_REDUCEOP", 1), "Testing split reducop requires SPLIT_REDUCEOP")
@unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reducop requires SPLIT_REDUCEOP")
def test_preserve_multistage_reduce(self):
big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
x = Tensor.randn(big_enough).realize()

View File

@@ -193,11 +193,16 @@ constant_folder = PatternMatcher([
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
(NOp(UOps.REDUCE, src=(NOp.var('idx').ne(NOp(UOps.RANGE, name="rng")).__neg__().cast()*
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
(NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where(
NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),),
arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
# other arange folders
(NOp.cvar("c1") - (NOp.var("x") + NOp.cvar("c2")), lambda c1, c2, x: (c1-c2)-x), # c1 - (x + c2) -> (c1-c2) - x
(-(NOp.var("x") * NOp.cvar("c1")), lambda x, c1: x*-c1),
# max folding
(NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if x.vmax.arg <= y.vmin.arg else None),
# const rules
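
These REDUCE patterns implement the "new indexing rule" from the commit message: a sum over a RANGE where each loaded element is masked by idx == rng (or by the negation of idx != rng) is really just a single load at idx, so index_collapse rewrites the whole reduce into that one load. The following is a plain-Python/NumPy illustration of the identity being exploited, not of the uop rewrite itself (buf and idx are made-up values):

import numpy as np

buf = np.arange(16, dtype=np.float32) * 2.0  # stand-in for the loaded buffer
idx = 7                                      # stand-in for the computed index

# the "arange" form that fancy indexing lowers to: mask every element, then reduce
masked_sum = sum(float(idx == r) * buf[r] for r in range(len(buf)))
# the collapsed form the pattern produces: one direct load
direct = buf[idx]

assert masked_sum == direct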

View File

@@ -108,6 +108,7 @@ GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPAT
MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
USE_TC, TC_OPT, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("TRANSCENDENTAL", 1)
FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
SPLIT_REDUCEOP = ContextVar("SPLIT_REDUCEOP", 1)
@dataclass(frozen=True)
class Metadata:
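
Making SPLIT_REDUCEOP a ContextVar (default 1, overridable by the environment) is what lets the new MNIST test disable reduce splitting only for its block via Context(SPLIT_REDUCEOP=0) instead of relying on a process-wide getenv. A small sketch of that behaviour, assuming the ContextVar exposes its current setting via .value like the other flags in helpers.py:

from tinygrad.helpers import Context, SPLIT_REDUCEOP

print(SPLIT_REDUCEOP.value)      # 1 unless SPLIT_REDUCEOP=0 was set in the environment
with Context(SPLIT_REDUCEOP=0):  # temporarily disable reduce splitting
  print(SPLIT_REDUCEOP.value)    # 0 inside the block
print(SPLIT_REDUCEOP.value)      # restored on exit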

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from typing import Union, Optional, Any, Tuple, List, get_args
from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
@@ -184,7 +184,7 @@ class LazyBuffer:
return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
return self._reduce_op(op, axis)
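
The lazy.py change only swaps the per-call getenv for the shared flag; the surrounding condition is what decides whether a large reduce gets split into two stages. Restated as a standalone predicate for clarity (should_split_reduce, in_shape, and out_shape are illustrative names, not tinygrad API):

from tinygrad.helpers import getenv, all_int, prod, SPLIT_REDUCEOP

def should_split_reduce(in_shape, out_shape) -> bool:
  # splitting must be enabled and the shape fully concrete and non-empty
  if not SPLIT_REDUCEOP or not all_int(in_shape) or 0 in in_shape: return False
  # and enough elements must be reduced per output element to clear the threshold
  return prod(in_shape) // prod(out_shape) >= getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)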