mirror of https://github.com/commaai/tinygrad.git
remove expr_idxs [run_process_replay] (#6567)
* remove expr_idxs [run_process_replay] * goodbye that test
This commit is contained in:
parent
9ebbedc37f
commit
67a03e72bb
|
@ -130,8 +130,6 @@ jobs:
|
|||
run: |
|
||||
PYTHONPATH="." python test/external/fuzz_shapetracker.py
|
||||
PYTHONPATH="." python test/external/fuzz_shapetracker_math.py
|
||||
- name: Test to_movement_ops
|
||||
run: PYTHONPATH="." python extra/to_movement_ops.py
|
||||
- name: Use as an external package
|
||||
run: |
|
||||
mkdir $HOME/test_external_dir
|
||||
|
|
|
@ -51,10 +51,6 @@ class TestConvShapetracker(unittest.TestCase):
|
|||
print(i, i1, i2, si.inputs[0].size, i1==i2)
|
||||
#self.assertEqual(i1, i2)
|
||||
|
||||
for stt in [st, test_st]:
|
||||
s,va = stt.expr_idxs()
|
||||
print(s)
|
||||
print(va)
|
||||
with self.assertRaises(AssertionError):
|
||||
assert len(st.views) <= 2
|
||||
|
||||
|
|
|
@ -10,18 +10,6 @@ class TestSymbolic(unittest.TestCase):
|
|||
assert st.shape == (x, 3)
|
||||
assert st.real_strides() == (3, 1)
|
||||
|
||||
def test_expr_idxs(self):
|
||||
x = Variable("x", 1, 100)
|
||||
st = ShapeTracker.from_shape((x, 3))
|
||||
idxs = [Variable("x", 0, 100), Variable("y", 0, 100)]
|
||||
e1, e2 = st.expr_idxs(idxs)
|
||||
assert e1.render() == "((x*3)+y)"
|
||||
assert e2.render() == "1"
|
||||
st = st.permute((1, 0))
|
||||
e1, e2 = st.expr_idxs(idxs)
|
||||
assert e1.render() == "((y*3)+x)"
|
||||
assert e2.render() == "1"
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_real_strides_0(self):
|
||||
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
|
||||
|
@ -230,22 +218,6 @@ class TestSymbolicPad(unittest.TestCase):
|
|||
assert t.shape == (9,)
|
||||
st = t.lazydata.st
|
||||
print(st)
|
||||
# TODO: fix this, required for symbolic arange
|
||||
with self.assertRaises(RuntimeError):
|
||||
st.expr_idxs()
|
||||
|
||||
class TestSymbolicShapeExpr(unittest.TestCase):
|
||||
def test_symbolic_expr_idxs(self):
|
||||
# taken from symbolic shape llama
|
||||
i = Variable("i", 1, 120)
|
||||
gidx0 = Variable("gidx0", 0, i)
|
||||
lidx1 = Variable("lidx1", 0, 7)
|
||||
idx = (gidx0, lidx1, NumNode(1))
|
||||
shape = (i+1, 8, 4)
|
||||
strides = (1, (i*4)+4, i+1)
|
||||
st = ShapeTracker((View.create(shape, strides), ))
|
||||
idx, _valid = st.expr_idxs(idx)
|
||||
assert idx.render() == "((lidx1*((i*4)+4))+1+gidx0+i)"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -1,16 +1,19 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.shape.symbolic import Variable, NumNode
|
||||
from tinygrad.ops import UOp, UOps, graph_rewrite
|
||||
from tinygrad.codegen.uopgraph import constant_folder
|
||||
from itertools import product
|
||||
|
||||
def shapetracker_getitem(st, val):
|
||||
_locals = {"idx0": val, "valid": 1}
|
||||
idx, valid = st.reshape((st.size,)).expr_idxs()
|
||||
exec(f"valid={valid.render()};idx0={idx.render()}", None, _locals)
|
||||
return _locals["idx0"] if _locals["valid"] else -1
|
||||
def shapetracker_getitem(st:ShapeTracker, val:int):
|
||||
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)])
|
||||
idx, valid = graph_rewrite(idx, constant_folder), graph_rewrite(valid, constant_folder)
|
||||
assert idx.op is UOps.CONST and valid.op is UOps.CONST
|
||||
return idx.arg, valid.arg
|
||||
|
||||
class CheckingShapeTracker:
|
||||
def __init__(self, shape):
|
||||
|
@ -70,10 +73,8 @@ class CheckingShapeTracker:
|
|||
def contiguous(self): return self.st.contiguous
|
||||
|
||||
def assert_same(self):
|
||||
x = [shapetracker_getitem(self.st, i) for i in range(prod(self.st.shape))]
|
||||
x = [(v[0] if (v:=shapetracker_getitem(self.st, i))[1] else -1) for i in range(prod(self.st.shape))]
|
||||
y = [self[i] for i in range(prod(self.shape))]
|
||||
idx, valid = self.st.expr_idxs()
|
||||
if DEBUG >= 1: print(x, y, self.st.shape, self.shape, idx.render(), valid.render(), self.st)
|
||||
assert self.st.shape == self.shape
|
||||
assert x == y, f"mismatch shapetracker:{x} real:{y}"
|
||||
|
||||
|
@ -163,7 +164,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
|||
def tearDown(self):
|
||||
for st, offset, shape, idxs_expr in zip(self.sts, self.offset, self.shapes, self.idxs_exprs):
|
||||
numel = prod(shape)
|
||||
assert idxs_expr(self.default_idxs(st.shape)) == st.expr_idxs(None)[0]
|
||||
self.check_bounds(idxs_expr(self.default_idxs(st.shape)), offset, numel)
|
||||
idx0s = [(0,0), (0, min(1, st.shape[0]-1)), (0, st.shape[0]-1), (min(3, st.shape[0]-1), min(6, st.shape[0]-1)), (st.shape[0]-1, st.shape[0]-1)]
|
||||
idx1s = [(0,0), (0, min(1, st.shape[1]-1)), (0, st.shape[1]-1), (min(3, st.shape[1]-1), min(6, st.shape[1]-1)), (st.shape[1]-1, st.shape[1]-1)]
|
||||
|
@ -171,7 +171,6 @@ class TestIndexExpressions2d(unittest.TestCase):
|
|||
(st.shape[2]-1, st.shape[2]-1)] if len(st.shape) == 3 else [None for _ in idx0s]
|
||||
for idx0, idx1, idx2 in product(idx0s, idx1s, idx2s):
|
||||
idxs = [Variable(f"idx{i}", idx[0], idx[1]) for i, idx in enumerate((idx0, idx1, idx2)) if idx is not None]
|
||||
assert idxs_expr(idxs) == st.expr_idxs(idxs)[0]
|
||||
self.check_bounds(idxs_expr(idxs), offset, numel)
|
||||
|
||||
def default_idx(self, shape):
|
||||
|
@ -786,14 +785,6 @@ class TestShapeTrackerSize(unittest.TestCase):
|
|||
strides=(0, 128, 0, 4096, 1), offset=0, mask=None, contiguous=False)))
|
||||
self.assertEqual(st.real_size(), 8389632)
|
||||
|
||||
class TestIdxs(unittest.TestCase):
|
||||
def test_check_idx_range(self):
|
||||
# generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize()
|
||||
# TODO: use int64
|
||||
st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),))
|
||||
with self.assertRaises(AssertionError):
|
||||
st.expr_idxs()
|
||||
|
||||
class TestConsecutive(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(self):
|
||||
|
|
|
@ -3,7 +3,8 @@ from typing import List
|
|||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import Variable, sym_infer
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from test.unit.test_shapetracker import shapetracker_getitem
|
||||
|
||||
class MultiShapeTracker:
|
||||
def __init__(self, sts:List[ShapeTracker]): self.sts = sts
|
||||
|
@ -19,14 +20,9 @@ class MultiShapeTracker:
|
|||
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
|
||||
if st1.shape != st2.shape: return False
|
||||
if st1 == st2: return True
|
||||
idx = Variable("idx", 0, prod(st1.shape)-1)
|
||||
st1_idx, st1_valid = st1.reshape((st1.size,)).expr_idxs([idx])
|
||||
st2_idx, st2_valid = st2.reshape((st2.size,)).expr_idxs([idx])
|
||||
for i in range(idx.min, idx.max + 1):
|
||||
st1_off = sym_infer(st1_idx, {idx: i})
|
||||
st2_off = sym_infer(st2_idx, {idx: i})
|
||||
st1_v = sym_infer(st1_valid, {idx: i})
|
||||
st2_v = sym_infer(st2_valid, {idx: i})
|
||||
for i in range(0, prod(st1.shape)):
|
||||
st1_off, st1_v = shapetracker_getitem(st1, i)
|
||||
st2_off, st2_v = shapetracker_getitem(st2, i)
|
||||
if st1_v != st2_v or (st1_off != st2_off and st1_v):
|
||||
print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
|
||||
print(st1)
|
||||
|
|
|
@ -2,13 +2,12 @@
|
|||
from __future__ import annotations
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, Any
|
||||
from typing import Tuple, List, Optional, Dict, Set, Any
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps
|
||||
from tinygrad.ops import graph_rewrite
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite
|
||||
from tinygrad.codegen.uopgraph import constant_folder, _get_chain
|
||||
|
||||
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
||||
|
@ -117,21 +116,6 @@ class ShapeTracker:
|
|||
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]: return [i for i,st in enumerate(self.real_strides(ignore_valid)) if st == 1]
|
||||
|
||||
def expr_idxs(self, idxs:Optional[Iterable[Node]]=None) -> Tuple[Node, Node]:
|
||||
idxs = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(self.shape)] if idxs is None else list(idxs)
|
||||
idx, valid = self.views[-1].expr(idxs)
|
||||
for view in reversed(self.views[0:-1]):
|
||||
if valid.max == 0: return NumNode(-1), valid
|
||||
view = view.minify()
|
||||
acc, idxs = 1, []
|
||||
for d in reversed(view.shape):
|
||||
idxs.append((idx//acc)%d)
|
||||
acc *= d
|
||||
idx, valid = view.expr(idxs[::-1], valid)
|
||||
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
|
||||
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
|
||||
return idx, valid
|
||||
|
||||
def axis_is_masked(self, axis:int) -> bool:
|
||||
_, valid = self.to_indexed_uops()
|
||||
return axis in [x.arg for x in graph_rewrite(valid, constant_folder).sparents if x.op is UOps.RANGE]
|
||||
|
|
Loading…
Reference in New Issue