fix tests to use render (#7116)

This commit is contained in:
George Hotz 2024-10-17 14:35:22 +08:00 committed by GitHub
parent 9f4ca88218
commit d990a16326
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 22 deletions

View File

@ -59,6 +59,12 @@ class TestHelpers(unittest.TestCase):
self.assertTrue(is_increasing(rng+2))
class TestValidIdxSimplification(unittest.TestCase):
def check(self, val0, sidx, svalid):
val0 = full_graph_rewrite(val0.sink()).src[0]
idx, valid = val0.src[1], val0.src[3]
self.assertEqual(idx.render(simplify=False), sidx)
self.assertEqual(valid.render(simplify=False), svalid)
def test_conv_backward(self):
# DEBUG=4 python3 test/test_ops.py TestOps.test_simple_conv2d
gidx0 = Special("gidx0", 3)
@ -88,23 +94,27 @@ class TestValidIdxSimplification(unittest.TestCase):
# TODO: simplify these
val0 = get_gated_load_uop(alu17&(alu9.lt(7)), alu15+(alu5//10)+(alu9*9))
self.assertEqual(render(val0),
"((((alu2<30)&(alu3<7))&(alu1<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu2//10)*3)+(alu3*63)+(alu0//10)+(alu1*9)]:0.0f)")
self.check(val0,
"(((((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63))+((((gidx0*3)+lidx2)+9)//10))+(((((gidx0*3)+lidx2)+9)%10)*9))",
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+9)%10)<7))")
val1 = get_gated_load_uop(
((alu16&gidx0.lt(1))&alu13.lt(7))&alu7.lt(7),
((((((((((lidx1*10)+gidx0)//3)+3)//10)+alu10)//10)+lidx0)%4)*441)+((((alu6+alu12)//10)%3)*3)+(alu13*63)+(((alu3//10)+2)%3)+(alu7*9)
)
self.assertEqual(render(val1),
"(((((alu2<30)&(gidx0<1))&(((((gidx0+9)//10)+alu0+lidx1+alu1)%10)<7))&((((gidx0*3)+lidx2+7)%10)<7))?data0[(lidx2*9)+(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu2//10)*3)+((alu2%10)*63)+65]:0.0f)") # noqa: E501
self.check(val1,
"(((lidx2*9)+(((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63)))+65)",
"(((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(gidx0<1))&(((((((gidx0+9)//10)+(gidx1*3))+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+7)%10)<7))")
val2 = get_gated_load_uop(alu17&alu1.lt(7), alu15+(gidx0*27)+(lidx2*9))
self.assertEqual(render(val2),
"((((alu0<30)&(alu1<7))&(((gidx0*3)+lidx2)<7))?data0[(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu0//10)*3)+(alu1*63)+(gidx0*27)+(lidx2*9)]:0.0f)") # noqa: E501
self.check(val2,
"(((((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63))+(gidx0*27))+(lidx2*9))",
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((gidx0*3)+lidx2)<7))")
val3 = get_gated_load_uop(alu17&alu8.lt(7), (alu4//10)+alu15+(alu8*9)+1)
self.assertEqual(render(val3),
"((((alu2<30)&(alu3<7))&(alu1<7))?data0[(alu0//10)+(((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+((alu2//10)*3)+(alu3*63)+(alu1*9)+1]:0.0f)")
self.check(val3,
"(((((((gidx0*3)+lidx2)+8)//10)+(((((((gidx1+(ridx0*3))//10)+lidx0)%4)*441)+(((((gidx1*3)+lidx1)+(ridx0*9))//10)*3))+(((((gidx1*3)+lidx1)+(ridx0*9))%10)*63)))+(((((gidx0*3)+lidx2)+8)%10)*9))+1)",
"((((((gidx1*3)+lidx1)+(ridx0*9))<30)&(((((gidx1*3)+lidx1)+(ridx0*9))%10)<7))&(((((gidx0*3)+lidx2)+8)%10)<7))")
def test_cumsum(self):
gidx0 = Special("gidx0", 5)
@ -112,7 +122,9 @@ class TestValidIdxSimplification(unittest.TestCase):
gate = (gidx0*4+lidx0).lt(19).ne(True)
idx = gidx0*4+lidx0-19
load = get_gated_load_uop(gate, idx)
self.assertEqual(render(load), "(((((gidx0*4)+lidx0)<19)!=1)?data0[0]:0.0f)")
self.check(load,
"0",
"((((gidx0*4)+lidx0)<19)!=True)")
def test_simplify_within_valid(self):
ridx0 = Range(0, 4)
@ -124,7 +136,9 @@ class TestValidIdxSimplification(unittest.TestCase):
load = get_gated_load_uop(valid, idx)
# TODO: simplify the valid
# alu0 = ((ridx0*3)+ridx1)
self.assertEqual(render(load), "(((alu0<8)&((((alu0//8)+(ridx2*3)+ridx3)%4)<2))?data0[ridx0+ridx1+ridx2+ridx3]:0.0f)")
self.check(load,
"(((ridx0+ridx1)+ridx2)+ridx3)",
"((((ridx0*3)+ridx1)<8)&(((((((ridx0*3)+ridx1)//8)+(ridx2*3))+ridx3)%4)<2))")
class TestImageSimplification(unittest.TestCase):
def test_idx_gt_c(self):

View File

@ -24,7 +24,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
code_for_op = {**CStyleLanguage.code_for_op, BinaryOps.IDIV: lambda a,b,dtype: f"({a}//{b})"}
rewritten_uop = [uop for uop in uops if uop.op is UOps.STORE][0].src[-1]
fxn = TestRenderer().render("", uops)
return fxn.split("data0[0] = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
return fxn.split("*(data0+0) = ")[1].split(";")[0], rewritten_uop.vmin, rewritten_uop.vmax
def NumNode(val): return UOp.const(dtypes.int, val)
class Node:

View File

@ -961,16 +961,14 @@ symbolic_flat = symbolic+PatternMatcher([
_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])
# for debug
syms = { BinaryOps.ADD: "+", BinaryOps.MUL: "*", BinaryOps.IDIV: "//", BinaryOps.MOD: "%",
BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"}
renderer = PatternMatcher([
(UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
(UPat((UOps.DEFINE_VAR, UOps.SPECIAL), name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),
(UPat(UOps.RANGE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"ridx{x.arg[0]}")),
(UPat(UOps.CONST, name="x"), lambda x: UOp(UOps.NOOP, arg=str(x.arg))),
(UPat(UOps.BIND, src=UPat(UOps.NOOP), name="x"), lambda x: x.src[0]),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.ADD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}+{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MUL, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}*{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.IDIV, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}//{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.MOD, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}%{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}{syms[x.arg]}{x.src[1].arg})")),
])
# *** what was symbolic.py ***

View File

@ -11,7 +11,7 @@ def _render_index(r:CStyleLanguage, buf:UOp, idx:UOp, dtype:DType):
sidx = strip_parens(r[idx]) if idx.arg == BinaryOps.ADD else r[idx]
if dtype.count > 1 and isinstance(buf.dtype, PtrDType):
return f"*(({r.smem_prefix if buf.dtype.local and r.smem_prefix_for_cast else r.buffer_prefix}{r.render_dtype(dtype)}*)({r[buf]}+{sidx}))"
return f"*({r[buf]}+{sidx})" if r.uses_ptr_arithmetic else f"{r[buf]}[{sidx}]"
return f"*({r[buf]}+{sidx})"
base_rewrite = PatternMatcher([
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
@ -81,7 +81,6 @@ class CStyleLanguage(Renderer):
code_for_workitem: Dict[Union[Literal["g"], Literal["l"], Literal["i"]], Callable] = {}
extra_args: List[str] = []
float4: Optional[str] = None
uses_ptr_arithmetic: bool = False
type_map: Dict[DType, str] = {}
infinity: str = "INFINITY"
nan: str = "NAN"
@ -266,7 +265,6 @@ class MetalRenderer(CStyleLanguage):
arg_int_prefix = "constant int&"
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
float4 = "float4"
uses_ptr_arithmetic = True
code_for_workitem = {"g": lambda x: f"gid.{chr(120+int(x))}", "l": lambda x: f"lid.{chr(120+int(x))}"}
# uint3 used for gid/lid - TODO: this should probably be `ushort3 lid [[thread_position_in_threadgroup]]`
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
@ -390,7 +388,6 @@ class AMDRenderer(CStyleLanguage):
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
float4 = "make_float4"
uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),