remove expr_idxs [run_process_replay] (#6567)

* remove expr_idxs [run_process_replay]

* goodbye that test
This commit is contained in:
George Hotz 2024-09-17 18:34:51 +08:00 committed by GitHub
parent 9ebbedc37f
commit 67a03e72bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 18 additions and 81 deletions

View File

@ -130,8 +130,6 @@ jobs:
run: |
PYTHONPATH="." python test/external/fuzz_shapetracker.py
PYTHONPATH="." python test/external/fuzz_shapetracker_math.py
- name: Test to_movement_ops
run: PYTHONPATH="." python extra/to_movement_ops.py
- name: Use as an external package
run: |
mkdir $HOME/test_external_dir

View File

@ -51,10 +51,6 @@ class TestConvShapetracker(unittest.TestCase):
print(i, i1, i2, si.inputs[0].size, i1==i2)
#self.assertEqual(i1, i2)
for stt in [st, test_st]:
s,va = stt.expr_idxs()
print(s)
print(va)
with self.assertRaises(AssertionError):
assert len(st.views) <= 2

View File

@ -10,18 +10,6 @@ class TestSymbolic(unittest.TestCase):
assert st.shape == (x, 3)
assert st.real_strides() == (3, 1)
def test_expr_idxs(self):
x = Variable("x", 1, 100)
st = ShapeTracker.from_shape((x, 3))
idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
e1, e2 = st.expr_idxs(idxs)
assert e1.render() == "((x*3)+y)"
assert e2.render() == "1"
st = st.permute((1, 0))
e1, e2 = st.expr_idxs(idxs)
assert e1.render() == "((y*3)+x)"
assert e2.render() == "1"
@unittest.expectedFailure
def test_real_strides_0(self):
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
@ -230,22 +218,6 @@ class TestSymbolicPad(unittest.TestCase):
assert t.shape == (9,)
st = t.lazydata.st
print(st)
# TODO: fix this, required for symbolic arange
with self.assertRaises(RuntimeError):
st.expr_idxs()
class TestSymbolicShapeExpr(unittest.TestCase):
def test_symbolic_expr_idxs(self):
# taken from symbolic shape llama
i = Variable("i", 1, 120)
gidx0 = Variable("gidx0", 0, i)
lidx1 = Variable("lidx1", 0, 7)
idx = (gidx0, lidx1, NumNode(1))
shape = (i+1, 8, 4)
strides = (1, (i*4)+4, i+1)
st = ShapeTracker((View.create(shape, strides), ))
idx, _valid = st.expr_idxs(idx)
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
if __name__ == '__main__':
unittest.main()

View File

@ -1,16 +1,19 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.helpers import prod, DEBUG
from tinygrad.dtype import dtypes
from tinygrad.helpers import prod
from tinygrad.shape.shapetracker import ShapeTracker, View
from tinygrad.shape.symbolic import Variable, NumNode
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad.codegen.uopgraph import constant_folder
from itertools import product
def shapetracker_getitem(st, val):
_locals = {"idx0": val, "valid": 1}
idx, valid = st.reshape((st.size,)).expr_idxs()
exec(f"valid={valid.render()};idx0={idx.render()}", None, _locals)
return _locals["idx0"] if _locals["valid"] else -1
def shapetracker_getitem(st:ShapeTracker, val:int):
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)])
idx, valid = graph_rewrite(idx, constant_folder), graph_rewrite(valid, constant_folder)
assert idx.op is UOps.CONST and valid.op is UOps.CONST
return idx.arg, valid.arg
class CheckingShapeTracker:
def __init__(self, shape):
@ -70,10 +73,8 @@ class CheckingShapeTracker:
def contiguous(self): return self.st.contiguous
def assert_same(self):
x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
y = [self[i] for i in range(prod(self.shape))]
idx, valid = self.st.expr_idxs()
if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st)
assert self.st.shape == self.shape
assert x == y, f"mismatch shapetracker:{x} real:{y}"
@ -163,7 +164,6 @@ class TestIndexExpressions2d(unittest.TestCase):
def tearDown(self):
for st, offset, shape, idxs_expr in zip(self.sts, self.offset, self.shapes, self.idxs_exprs):
numel = prod(shape)
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
@ -171,7 +171,6 @@ class TestIndexExpressions2d(unittest.TestCase):
(st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
self.check_bounds(idxs_expr(idxs), offset, numel)
def default_idx(self, shape):
@ -786,14 +785,6 @@ class TestShapeTrackerSize(unittest.TestCase):
strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False)))
self.assertEqual(st.real_size(), 8389632)
class TestIdxs(unittest.TestCase):
def test_check_idx_range(self):
# generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize()
# TODO: use int64
st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),))
with self.assertRaises(AssertionError):
st.expr_idxs()
class TestConsecutive(unittest.TestCase):
@classmethod
def setUpClass(self):

View File

@ -3,7 +3,8 @@ from typing import List
from tinygrad.helpers import prod
from tinygrad.shape.view import View
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.shape.symbolic import Variable
from test.unit.test_shapetracker import shapetracker_getitem
class MultiShapeTracker:
def __init__(self, sts:List[ShapeTracker]): self.sts = sts
@ -19,14 +20,9 @@ class MultiShapeTracker:
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
if st1.shape != st2.shape: return False
if st1 == st2: return True
idx = Variable("idx", 0, prod(st1.shape)-1)
st1_idx, st1_valid = st1.reshape((st1.size,)).expr_idxs([idx])
st2_idx, st2_valid = st2.reshape((st2.size,)).expr_idxs([idx])
for i in range(idx.min, idx.max + 1):
st1_off = sym_infer(st1_idx, {idx: i})
st2_off = sym_infer(st2_idx, {idx: i})
st1_v = sym_infer(st1_valid, {idx: i})
st2_v = sym_infer(st2_valid, {idx: i})
for i in range(0, prod(st1.shape)):
st1_off, st1_v = shapetracker_getitem(st1, i)
st2_off, st2_v = shapetracker_getitem(st2, i)
if st1_v != st2_v or (st1_off != st2_off and st1_v):
print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
print(st1)

View File

@ -2,13 +2,12 @@
from __future__ import annotations
import functools
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, Iterable, Any
from typing import Tuple, List, Optional, Dict, Set, Any
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
from tinygrad.shape.symbolic import Variable, MulNode, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps
from tinygrad.ops import graph_rewrite
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite
from tinygrad.codegen.uopgraph import constant_folder, _get_chain
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
@ -117,21 +116,6 @@ class ShapeTracker:
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
idx, valid = self.views[-1].expr(idxs)
for view in reversed(self.views[0:-1]):
if valid.max == 0: return NumNode(-1), valid
view = view.minify()
acc, idxs = 1, []
for d in reversed(view.shape):
idxs.append((idx//acc)%d)
acc *= d
idx, valid = view.expr(idxs[::-1], valid)
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
return idx, valid
def axis_is_masked(self, axis:int) -> bool:
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in graph_rewrite(valid, constant_folder).sparents if x.op is UOps.RANGE]