mirror of https://github.com/commaai/tinygrad.git
test_image_valid.py -> test_simplify_valid_idx.py (#6724)
restructure the tests, will use the same file for non-image tests
This commit is contained in:
parent
e0d8685c99
commit
14524eeddc
|
@ -1,16 +1,20 @@
|
|||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
from tinygrad.codegen.uopgraph import linearize_uop, full_graph_rewrite, is_increasing
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps, BinaryOps
|
||||
|
||||
def render(image_shape, valid:UOp, idx:UOp) -> str:
|
||||
uops = linearize_uop(full_graph_rewrite(UOp(UOps.LOAD, dtypes.float.vec(4), (
|
||||
def get_load_image_uop(image_shape:Tuple[int, ...], valid:UOp, idx:UOp):
|
||||
return UOp(UOps.LOAD, dtypes.float.vec(4), (
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef(image_shape), arg=0),
|
||||
idx,
|
||||
UOp(UOps.VECTORIZE, dtypes.float.vec(4), src=(UOp.const(dtypes.float, 0),)*4),
|
||||
valid
|
||||
)).sink()))
|
||||
))
|
||||
|
||||
def render(uop:UOp) -> str:
|
||||
uops = linearize_uop(full_graph_rewrite(uop.sink()))
|
||||
from tinygrad.renderer.cstyle import OpenCLRenderer
|
||||
class TestRenderer(OpenCLRenderer):
|
||||
code_for_op = {**OpenCLRenderer().code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
|
@ -45,46 +49,59 @@ class TestHelpers(unittest.TestCase):
|
|||
self.assertTrue(is_increasing(rng))
|
||||
self.assertTrue(is_increasing(rng+2))
|
||||
|
||||
class TestValidSimplification(unittest.TestCase):
|
||||
class TestValidIdxSimplification(unittest.TestCase):
|
||||
def test_conv_backward(self):
|
||||
pass
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def test_idx_gt_c(self):
|
||||
# (idx1 < c+1).ne(True) ? (..., idx1-1+c) : 0 can drop the valid
|
||||
# (idx1 < c+1).ne(True) -> idx > c
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(1).ne(True), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1))),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))")
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(1).ne(True), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-2))),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-2))))")
|
||||
shape = (10, 10, 4)
|
||||
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1)))
|
||||
self.assertEqual(render(load), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))")
|
||||
load = get_load_image_uop(shape, (gidx1).lt(1).ne(True), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-2)))
|
||||
self.assertEqual(render(load), "read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-2))))")
|
||||
|
||||
# should match any one of the AND clause and drop the matched statement from valid
|
||||
valid = (gidx0).lt(1).ne(True) & (gidx1).lt(1).ne(True)
|
||||
self.assertEqual(render((10, 10, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0+1, gidx1-1))),
|
||||
load = get_load_image_uop(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0+1, gidx1-1)))
|
||||
self.assertEqual(render(load),
|
||||
"(((gidx0<1)!=1)?read_imagef(data0, smp, (int2)((gidx0+1),(gidx1+(-1)))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
valid = (gidx1).lt(1).ne(True) & (gidx1).lt(1).ne(True)
|
||||
self.assertEqual(render((10, 10, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1))),
|
||||
load = get_load_image_uop(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1-1)))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+(-1))))")
|
||||
|
||||
def test_idx_lt_bound(self):
|
||||
# (idx1 < image_bound) ? (..., idx1) : 0 can drop the valid
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
|
||||
load = get_load_image_uop((10, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1)))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
|
||||
# same thing, valid has a div
|
||||
self.assertEqual(render((10, 10, 4), (gidx1//2).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
|
||||
load = get_load_image_uop((10, 10, 4), (gidx1//2).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1)))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,gidx1))")
|
||||
# 10x20 image, not out of bound
|
||||
self.assertEqual(render((20, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1))),
|
||||
load = get_load_image_uop((20, 10, 4), (gidx1).lt(10), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1)))
|
||||
self.assertEqual(render(load),
|
||||
"((gidx1<10)?read_imagef(data0, smp, (int2)(gidx0,gidx1)):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
def test_generic_idx_lt_bound(self):
|
||||
# (idx1 < image_bound - c) ? (..., idx1 + c) : 0 can drop the valid
|
||||
gidx0 = Special("gidx0", 32)
|
||||
gidx1 = Special("gidx1", 32)
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(8), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+2))),
|
||||
shape = (10, 10, 4)
|
||||
load = get_load_image_uop(shape, (gidx1).lt(8), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+2)))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+2)))")
|
||||
self.assertEqual(render((10, 10, 4), (gidx1).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+5))),
|
||||
load = get_load_image_uop(shape, (gidx1).lt(5), UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0, gidx1+5)))
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(gidx0,(gidx1+5)))")
|
||||
|
||||
def test_valid_empty_set(self):
|
||||
|
@ -93,11 +110,13 @@ class TestValidSimplification(unittest.TestCase):
|
|||
shape = (32, 32, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (gidx0%2, gidx1+2))
|
||||
# not empty
|
||||
self.assertEqual(render(shape, (gidx0).lt(8), idx),
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8), idx)
|
||||
self.assertEqual(render(load),
|
||||
"((gidx0<8)?read_imagef(data0, smp, (int2)((gidx0%2),(gidx1+2))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
# empty
|
||||
self.assertRaises(IndexError, lambda: render(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx))
|
||||
load = get_load_image_uop(shape, (gidx0).lt(8) & (gidx0).lt(8).ne(True), idx)
|
||||
self.assertRaises(IndexError, lambda: render(load))
|
||||
|
||||
def test_openpilot_conv1(self):
|
||||
# first conv in openpilot
|
||||
|
@ -118,7 +137,8 @@ class TestValidSimplification(unittest.TestCase):
|
|||
shape = (128, 1536, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu4+1530)%1536, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(((idx1*48)+(ridx2*6)+ridx0+(-6)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
def test_openpilot_conv2(self):
|
||||
|
@ -138,8 +158,9 @@ class TestValidSimplification(unittest.TestCase):
|
|||
valid = (((idx2*2)+ridx1).lt(1).ne(True))&(((idx1*8)+ridx2).lt(1).ne(True))
|
||||
shape = (128, 768, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((alu3+765)%768, alu1+((idx1+((ridx2+7)//8)+31)//32)+(-2)))
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(((idx1*24)+(ridx2*3)+ridx0+(-3)),((idx2*2)+ridx1+(-1))))")
|
||||
|
||||
def test_openpilot_conv3(self):
|
||||
|
@ -158,39 +179,43 @@ class TestValidSimplification(unittest.TestCase):
|
|||
shape = (8, 1024, 4)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu6+832)%1024),(alu2+((idx1+((ridx1+5)//8)+1)//2)+(-4))))
|
||||
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
# TODO: simplify idx
|
||||
# alu0 = ((idx2*2)+ridx0)
|
||||
self.assertEqual(render(shape, valid, idx),
|
||||
self.assertEqual(render(load),
|
||||
"(((alu0<11)&((((idx1*8)+ridx1)<3)!=1))?read_imagef(data0, smp, (int2)((((idx1*512)+(ridx1*64)+idx0+832)%1024),(alu0+((idx1+((ridx1+5)//8)+1)//2)+(-4)))):(float4)(0.0f,0.0f,0.0f,0.0f))") # noqa: E501
|
||||
|
||||
def test_simplify1(self):
|
||||
# idx has the form (A % m, A // m + k) and valid has (c0 < A) and (A < c1)
|
||||
gidx = Special("gidx", 512)
|
||||
valid = gidx.lt(488) & (gidx).lt(480).ne(True)
|
||||
idx = ((gidx*3+18)%26, (gidx*3+18)//26-56)
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((gidx*3+18)%26, (gidx*3+18)//26-56))
|
||||
load = get_load_image_uop((1, 26, 4), valid, idx)
|
||||
# alu0 is ((gidx*3)+18)
|
||||
self.assertEqual(render((1, 26, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)(((gidx*3)+(-1438)),0))")
|
||||
|
||||
def test_simplify2(self):
|
||||
# from GPU=1 DEBUG=4 FORWARD_ONLY=1 IMAGE=2 python3 test/test_ops.py TestOps.test_simple_padding_conv2d
|
||||
lidx = Special("lidx", 4)
|
||||
valid = lidx.lt(3) & lidx.lt(1).ne(True)
|
||||
idx = ((lidx+1)%2, (lidx+1)//2-1)
|
||||
self.assertEqual(render((1, 2, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((lidx+1)%2, (lidx+1)//2-1))
|
||||
load = get_load_image_uop((1, 2, 4), valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)((lidx+(-1)),0))")
|
||||
|
||||
def test_simplify3(self):
|
||||
# from openpilot
|
||||
idx0 = Special("idx0", 265)
|
||||
valid = idx0.lt(201).ne(True)
|
||||
idx = ((idx0+55)%64, (idx0+55)//64-4)
|
||||
self.assertEqual(render((1, 64, 4), valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((idx0+55)%64, (idx0+55)//64-4))
|
||||
load = get_load_image_uop((1, 64, 4), valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"read_imagef(data0, smp, (int2)((idx0+(-201)),0))")
|
||||
|
||||
def test_simplify4(self):
|
||||
idx0 = Special("idx0", 512)
|
||||
data1_shape = (4, 64, 4)
|
||||
shape = (4, 64, 4)
|
||||
alu2 = ((idx0*4+1)%32)
|
||||
alu3 = ((idx0*4+2)%32)
|
||||
alu4 = ((idx0*4+3)%32)
|
||||
|
@ -199,17 +224,24 @@ class TestValidSimplification(unittest.TestCase):
|
|||
alu9 = idx0.lt(256)
|
||||
|
||||
# TODO: can this be simplified further?
|
||||
load = get_load_image_uop(shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu2*8))%64),(alu2//8))))
|
||||
# alu0 = (((idx0*4)+1)%32)
|
||||
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu2*8))%64),(alu2//8)))),
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((idx0//32)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
load = get_load_image_uop(shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu3*8))%64),(alu3//8))))
|
||||
# alu0 = (((idx0*4)+2)%32)
|
||||
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu3*8))%64),(alu3//8)))),
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((idx0//32)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
load = get_load_image_uop(shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu4*8))%64),(alu4//8))))
|
||||
# alu0 = (((idx0*4)+3)%32)
|
||||
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu4*8))%64),(alu4//8)))),
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((idx0//32)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
load = get_load_image_uop(shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8))))
|
||||
# alu0 = ((idx0*4)%32)
|
||||
self.assertEqual(render(data1_shape, alu9, UOp(UOps.VECTORIZE, dtypes.int.vec(2), (((alu8+(alu5*8))%64),(alu5//8)))),
|
||||
self.assertEqual(render(load),
|
||||
"((idx0<256)?read_imagef(data0, smp, (int2)((((idx0//32)+(alu0*8))%64),(alu0//8))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
def test_simplify5(self):
|
||||
|
@ -221,10 +253,11 @@ class TestValidSimplification(unittest.TestCase):
|
|||
alu1 = (idx1*256)+alu0
|
||||
alu2 = idx1//3
|
||||
alu3 = ((alu1+1)%768)
|
||||
idx = ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10))
|
||||
idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), ((idx0+((((alu3//640)+alu2)%8)*16)+128),((alu3//64)%10)))
|
||||
valid = alu3.lt(640)
|
||||
|
||||
self.assertEqual(render(shape, valid, UOp(UOps.VECTORIZE, dtypes.int.vec(2), idx)),
|
||||
load = get_load_image_uop(shape, valid, idx)
|
||||
self.assertEqual(render(load),
|
||||
"((alu0<640)?read_imagef(data0, smp, (int2)((idx0+((idx1//3)*16)+128),(alu0//64))):(float4)(0.0f,0.0f,0.0f,0.0f))")
|
||||
|
||||
if __name__ == '__main__':
|
Loading…
Reference in New Issue