mirror of https://github.com/commaai/tinygrad.git
fast mnist indexing (#5921)
* fast mnist indexing
* more tests
* remove those tests, new indexing rule
This commit is contained in:
parent e81c18f494
commit 5d17f54e3c
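What the new tests exercise, as a standalone sketch (not part of the diff; it mirrors test_index_fused below): with FUSE_ARANGE=1, indexing a dataset by a tensor of indices fuses the arange-based gather into a single kernel, keeping the op count near one load per output element.

from tinygrad import Tensor
from tinygrad.helpers import Context, GlobalCounters

dataset = Tensor.rand(16384, 256).realize()
idxs = Tensor([0, 3, 5, 6]).realize()
with Context(NOOPT=1, FUSE_ARANGE=1):  # fuse the arange-based gather
  GlobalCounters.reset()
  X = dataset[idxs].realize()          # gathers rows 0, 3, 5, 6
print(GlobalCounters.global_ops)       # the tests bound this by 4*16384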
@@ -82,12 +82,12 @@ class TestIndexing(unittest.TestCase):
       #assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
       np.testing.assert_allclose(real_index, X.numpy())
 
-  def test_index_fused(self):
+  def test_index_fused(self, noopt=1):
     dataset = Tensor.rand(16384, 256).realize()
     idxs = Tensor([0,3,5,6]).realize()
     real_index = dataset.numpy()[idxs.numpy()]
     print("*** indexing ***")
-    with Context(NOOPT=1, FUSE_ARANGE=1):
+    with Context(NOOPT=noopt, FUSE_ARANGE=1):
       GlobalCounters.reset()
       X = dataset[idxs]
       assert X.shape == (4,256)
@@ -96,6 +96,23 @@ class TestIndexing(unittest.TestCase):
       run_schedule(sched)
       assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
       np.testing.assert_allclose(real_index, X.numpy())
+
+  @unittest.skip("not ready")
+  def test_index_fused_opt(self): self.test_index_fused(0)
+
+  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
+  def test_index_mnist(self, noopt=1):
+    from tinygrad.nn.datasets import mnist
+    X_train, Y_train, _, _ = mnist()
+    with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
+      GlobalCounters.reset()
+      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
+      x = X_train[samples].numpy()
+      y = Y_train[samples].numpy()
+      assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
+      np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
+      np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)
+
+  @unittest.skip("not ready")
+  def test_index_mnist_opt(self): self.test_index_mnist(0)
 
 if __name__ == "__main__":
   unittest.main()
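For context on what the op-count assertion measures: tinygrad lowers a tensor-index gather to an arange comparison followed by a multiply-accumulate, roughly the numpy sketch below (variable names are mine). Without the new indexing rule the kernel does on the order of 16384 multiply-adds per output row; with it, the reduce collapses to one load.

import numpy as np
dataset = np.random.rand(16384, 256).astype(np.float32)
idxs = np.array([0, 3, 5, 6])
onehot = (idxs[:, None] == np.arange(dataset.shape[0]))  # (4, 16384) one-hot rows
gathered = onehot.astype(np.float32) @ dataset           # multiply-accumulate gather
np.testing.assert_allclose(gathered, dataset[idxs])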
@@ -9,7 +9,7 @@ from tinygrad import nn, dtypes
 from tinygrad.device import Device
 from tinygrad.tensor import Tensor
 from tinygrad.ops import BinaryOps, MetaOps, ReduceOps, UnaryOps, verify_lazyop
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, flatten, getenv
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, flatten, getenv, SPLIT_REDUCEOP
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.realize import run_schedule
@@ -496,7 +496,7 @@ class TestSchedule(unittest.TestCase):
     check_schedule(out, 2)
 
   # multireduce spec
-  @unittest.skipUnless(getenv("SPLIT_REDUCEOP", 1), "Testing split reducop requires SPLIT_REDUCEOP")
+  @unittest.skipUnless(SPLIT_REDUCEOP, "Testing split reducop requires SPLIT_REDUCEOP")
   def test_preserve_multistage_reduce(self):
     big_enough = getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
     x = Tensor.randn(big_enough).realize()
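A note on the decorator swap: tinygrad's ContextVar reads its default from the environment at import time and is truthy on its current value, so it can stand in for the getenv call directly (a sketch, assuming those semantics):

from tinygrad.helpers import SPLIT_REDUCEOP
print(bool(SPLIT_REDUCEOP))  # True by default; False when run with SPLIT_REDUCEOP=0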
@@ -193,11 +193,16 @@ constant_folder = PatternMatcher([
   (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).cast()*
     NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"),),
     arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
   (NOp(UOps.REDUCE, src=(NOp.var('idx').ne(NOp(UOps.RANGE, name="rng")).__neg__().cast()*
     NOp(UOps.LOAD, src=(NOp.var("buf"), NOp(UOps.RANGE, name="rng")), name="ld"),),
     arg=BinaryOps.ADD, name="reduce", allow_any_len=True),
    lambda **kwargs: index_collapse(add=UOp.const(dtypes.int, 0), mul=UOp.const(dtypes.int, 1), **kwargs)),
+  (NOp(UOps.REDUCE, src=(NOp.var('idx').eq(NOp(UOps.RANGE, name="rng")).where(
+    NOp(UOps.LOAD, src=(NOp.var("buf"), NOp.var('add')+NOp.var('mul')*NOp(UOps.RANGE, name="rng")), name="ld"), NOp.const(None, 0.0)),),
+    arg=BinaryOps.ADD, name="reduce", allow_any_len=True), index_collapse),
   # other arange folders
+  (NOp.cvar("c1") - (NOp.var("x") + NOp.cvar("c2")), lambda c1, c2, x: (c1-c2)-x), # c1 - (x + c2) -> (c1-c2) - x
+  (-(NOp.var("x") * NOp.cvar("c1")), lambda x, c1: x*-c1),
   # max folding
   (NOp.max(NOp.var('x'), NOp.var('y')), lambda x,y: x if x.vmin.arg >= y.vmax.arg else y if x.vmax.arg <= y.vmin.arg else None),
   # const rules
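The where-form REDUCE pattern is the new indexing rule from the commit message. It applies the same index_collapse as the mul-form: summing where(idx==rng, load(buf, add+mul*rng), 0.0) over the whole range picks out exactly one load, so the REDUCE folds to load(buf, add+mul*idx). A numpy check of that identity (a sketch; variable names are mine):

import numpy as np
buf = np.random.rand(64).astype(np.float32)
idx, add, mul = 5, 2, 3
rng = np.arange(16)  # stands in for the UOps.RANGE loop counter
summed = np.where(idx == rng, buf[add + mul*rng], 0.0).sum()
assert summed == buf[add + mul*idx]  # the whole reduce is a single load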
@@ -108,6 +108,7 @@ GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPAT
 MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
 USE_TC, TC_OPT, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("TRANSCENDENTAL", 1)
 FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_BW", 0)
+SPLIT_REDUCEOP = ContextVar("SPLIT_REDUCEOP", 1)
 
 @dataclass(frozen=True)
 class Metadata:
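Because the flag is now a ContextVar rather than a raw getenv, it can also be flipped for a scoped block, which is what test_index_mnist does with SPLIT_REDUCEOP=0. A sketch, assuming tinygrad's Context manager and that the env var isn't set:

from tinygrad.helpers import Context, SPLIT_REDUCEOP
assert bool(SPLIT_REDUCEOP)        # default is 1
with Context(SPLIT_REDUCEOP=0):
  assert not bool(SPLIT_REDUCEOP)  # override visible inside the block
assert bool(SPLIT_REDUCEOP)        # restored on exit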
@@ -1,7 +1,7 @@
 from __future__ import annotations
 from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, DTypeLike, ConstType, to_dtype
-from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata
+from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
 from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, reduce_st
 from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -184,7 +184,7 @@ class LazyBuffer:
       return self.const(self.base.arg * {ReduceOps.SUM: prod(self.shape[i] for i in axis), ReduceOps.MAX: 1}[op], new_shape)
 
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
-    if not getenv("SPLIT_REDUCEOP", 1) or not all_int(self.shape) or (0 in self.shape) or \
+    if not SPLIT_REDUCEOP or not all_int(self.shape) or (0 in self.shape) or \
       prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, axis)
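The guard keeps small reduces single-stage: splitting only happens when the flag is on, the shape is concrete, and at least REDUCEOP_SPLIT_THRESHOLD (default 32768) elements are reduced per output. A quick worked check (my numbers, following the diff):

from math import prod
shape, new_shape = (16384, 256), (1, 256)  # e.g. summing over axis 0
ratio = prod(shape) // prod(new_shape)     # 16384 inputs reduced per output
print(ratio >= 32768)                      # False -> below threshold, no split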