mirror of https://github.com/commaai/tinygrad.git
fix tests to use render (#7116)
This commit is contained in:
parent
9f4ca88218
commit
d990a16326
|
@ -59,6 +59,12 @@ class TestHelpers(unittest.TestCase):
|
|||
self.assertTrue(is_increasing(rng+2))
|
||||
|
||||
class TestValidIdxSimplification(unittest.TestCase):
|
||||
def check(self, val0, sidx, svalid):
|
||||
val0 = full_graph_rewrite(val0.sink()).src[0]
|
||||
idx, valid = val0.src[1], val0.src[3]
|
||||
self.assertEqual(idx.render(simplify=False), sidx)
|
||||
self.assertEqual(valid.render(simplify=False), svalid)
|
||||
|
||||
def test_conv_backward(self):
|
||||
# DEBUG=4 python3 test/test_ops.py TestOps.test_simple_conv2d
|
||||
gidx0 = Special("gidx0", 3)
|
||||
|
@ -88,23 +94,27 @@ class TestValidIdxSimplification(unittest.TestCase):
|
|||
|
||||
# TODO: simplify these
|
||||
val0 = get_gated_load_uop(alu17&(alu9.lt(7)), alu15+(alu5//10)+(alu9*9))
|
||||
self.assertEqual(render(val0),
|
||||
"((((alu2<30)&(alu3<7))&(alu1<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu2//10)*3)+(alu3*63)+(alu0//10)+(alu1*9)]:0.0f)")
|
||||
self.check(val0,
|
||||
"(((((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63))+((((gidx0*3)+lidx2)+9)//10))+(((((gidx0*3)+lidx2)+9)%10)*9))",
|
||||
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+9)%10)<7))")
|
||||
|
||||
val1 = get_gated_load_uop(
|
||||
((alu16&gidx0.lt(1))&alu13.lt(7))&alu7.lt(7),
|
||||
((((((((((lidx1*10)+gidx0)//3)+3)//10)+alu10)//10)+lidx0)%4)*441)+((((alu6+alu12)//10)%3)*3)+(alu13*63)+(((alu3//10)+2)%3)+(alu7*9)
|
||||
)
|
||||
self.assertEqual(render(val1),
|
||||
"(((((alu2<30)&(gidx0<1))&(((((gidx0+9)//10)+alu0+lidx1+alu1)%10)<7))&((((gidx0*3)+lidx2+7)%10)<7))?data0[(lidx2*9)+(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu2//10)*3)+((alu2%10)*63)+65]:0.0f)") # noqa: E501
|
||||
self.check(val1,
|
||||
"(((lidx2*9)+(((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63)))+65)",
|
||||
"(((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(gidx0<1))&(((((((gidx0+9)//10)+(gidx1*3))+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+7)%10)<7))")
|
||||
|
||||
val2 = get_gated_load_uop(alu17&alu1.lt(7), alu15+(gidx0*27)+(lidx2*9))
|
||||
self.assertEqual(render(val2),
|
||||
"((((alu0<30)&(alu1<7))&(((gidx0*3)+lidx2)<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu0//10)*3)+(alu1*63)+(gidx0*27)+(lidx2*9)]:0.0f)") # noqa: E501
|
||||
self.check(val2,
|
||||
"(((((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63))+(gidx0*27))+(lidx2*9))",
|
||||
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((gidx0*3)+lidx2)<7))")
|
||||
|
||||
val3 = get_gated_load_uop(alu17&alu8.lt(7), (alu4//10)+alu15+(alu8*9)+1)
|
||||
self.assertEqual(render(val3),
|
||||
"((((alu2<30)&(alu3<7))&(alu1<7))?data0[(alu0//10)+(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu2//10)*3)+(alu3*63)+(alu1*9)+1]:0.0f)")
|
||||
self.check(val3,
|
||||
"(((((((gidx0*3)+lidx2)+8)//10)+(((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63)))+(((((gidx0*3)+lidx2)+8)%10)*9))+1)",
|
||||
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+8)%10)<7))")
|
||||
|
||||
def test_cumsum(self):
|
||||
gidx0 = Special("gidx0", 5)
|
||||
|
@ -112,7 +122,9 @@ class TestValidIdxSimplification(unittest.TestCase):
|
|||
gate = (gidx0*4+lidx0).lt(19).ne(True)
|
||||
idx = gidx0*4+lidx0-19
|
||||
load = get_gated_load_uop(gate, idx)
|
||||
self.assertEqual(render(load), "(((((gidx0*4)+lidx0)<19)!=1)?data0[0]:0.0f)")
|
||||
self.check(load,
|
||||
"0",
|
||||
"((((gidx0*4)+lidx0)<19)!=True)")
|
||||
|
||||
def test_simplify_within_valid(self):
|
||||
ridx0 = Range(0, 4)
|
||||
|
@ -124,7 +136,9 @@ class TestValidIdxSimplification(unittest.TestCase):
|
|||
load = get_gated_load_uop(valid, idx)
|
||||
# TODO: simplify the valid
|
||||
# alu0 = ((ridx0*3)+ridx1)
|
||||
self.assertEqual(render(load), "(((alu0<8)&((((alu0//8)+(ridx2*3)+ridx3)%4)<2))?data0[ridx0+ridx1+ridx2+ridx3]:0.0f)")
|
||||
self.check(load,
|
||||
"(((ridx0+ridx1)+ridx2)+ridx3)",
|
||||
"((((ridx0*3)+ridx1)<8)&(((((((ridx0*3)+ridx1)//8)+(ridx2*3))+ridx3)%4)<2))")
|
||||
|
||||
class TestImageSimplification(unittest.TestCase):
|
||||
def test_idx_gt_c(self):
|
||||
|
|
|
@ -24,7 +24,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
|
|||
code_for_op = {**CStyleLanguage.code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
|
||||
rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
|
||||
fxn = TestRenderer().render("", uops)
|
||||
return fxn.split("data0[0] = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
|
||||
return fxn.split("*(data0+0) = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
|
||||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
class Node:
|
||||
|
|
|
@ -961,16 +961,14 @@ symbolic_flat = symbolic+PatternMatcher([
|
|||
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
|
||||
|
||||
# for debug
|
||||
syms = { BinaryOps.ADD: "+", BinaryOps.MUL: "*", BinaryOps.IDIV: "//", BinaryOps.MOD: "%",
|
||||
BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
|
||||
renderer = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
|
||||
(UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
|
||||
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
|
||||
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
|
||||
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.ADD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}+{x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MUL, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.IDIV, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}//{x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MOD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}%{x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
|
||||
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
|
||||
])
|
||||
|
||||
# *** what was symbolic.py ***
|
||||
|
|
|
@ -11,7 +11,7 @@ def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
|
|||
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
|
||||
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
|
||||
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
|
||||
return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
|
||||
return f"*({r[buf]}+{sidx})"
|
||||
|
||||
base_rewrite = PatternMatcher([
|
||||
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
|
||||
|
@ -81,7 +81,6 @@ class CStyleLanguage(Renderer):
|
|||
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
|
||||
extra_args: List[str] = []
|
||||
float4: Optional[str] = None
|
||||
uses_ptr_arithmetic: bool = False
|
||||
type_map: Dict[DType, str] = {}
|
||||
infinity: str = "INFINITY"
|
||||
nan: str = "NAN"
|
||||
|
@ -266,7 +265,6 @@ class MetalRenderer(CStyleLanguage):
|
|||
arg_int_prefix = "constant int&"
|
||||
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
|
||||
float4 = "float4"
|
||||
uses_ptr_arithmetic = True
|
||||
code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
|
||||
# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
|
||||
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
||||
|
@ -390,7 +388,6 @@ class AMDRenderer(CStyleLanguage):
|
|||
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
|
||||
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
|
||||
float4 = "make_float4"
|
||||
uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
|
||||
type_map = {dtypes.bfloat16: "hip_bfloat16"}
|
||||
extra_matcher = PatternMatcher([
|
||||
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
|
||||
|
|
Loading…
Reference in New Issue