fixup test_pattern_matcher (#5712)

This commit is contained in:
chenyu 2024-07-25 13:48:52 -04:00 committed by GitHub
parent 9ceb3a3d1f
commit 05e02ddfb3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 20 additions and 12 deletions

View File

@ -2,7 +2,7 @@ import unittest, itertools
from test.helpers import TestUOps from test.helpers import TestUOps
from tinygrad.dtype import dtypes from tinygrad.dtype import dtypes
from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401 from tinygrad.ops import BinaryOps, TernaryOps, ReduceOps, UnaryOps # noqa: F401
from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat, _match from tinygrad.codegen.uops import UOps, UOp, PatternMatcher, UPat
from tinygrad.codegen.uopgraph import UOpGraph, constant_folder from tinygrad.codegen.uopgraph import UOpGraph, constant_folder
class TestPatternMatcher(TestUOps): class TestPatternMatcher(TestUOps):
@ -47,18 +47,25 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c4), None) self.assertEqual(matcher.rewrite(c4), None)
self.assertEqual(matcher.rewrite(c5), None) self.assertEqual(matcher.rewrite(c5), None)
@unittest.skip("this is not supported any more") def test_filter_arg(self):
def test_arg_set(self): matcher = PatternMatcher([
matcher = PatternMatcher([(UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, {-1, 1}), UPat(UOps.CONST, 2)), name="x"), lambda x: x)]) (UPat(UOps.ALU, BinaryOps.MUL, (UPat(UOps.CONST, name="c"), UPat(UOps.CONST, 2)), name="x"),
lambda x,c: x if c.arg in {1, -1} else None)
])
y1 = UOp(UOps.CONST, dtypes.int, arg=1) y1 = UOp(UOps.CONST, dtypes.int, arg=1)
y2 = UOp(UOps.CONST, dtypes.int, arg=2) y2 = UOp(UOps.CONST, dtypes.int, arg=2)
y3 = UOp(UOps.CONST, dtypes.int, arg=-1) y3 = UOp(UOps.CONST, dtypes.int, arg=-1)
c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL) c1 = UOp(UOps.ALU, dtypes.int, (y1, y2), BinaryOps.MUL)
c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL) c2 = UOp(UOps.ALU, dtypes.int, (y2, y2), BinaryOps.MUL)
c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL) c3 = UOp(UOps.ALU, dtypes.int, (y3, y2), BinaryOps.MUL)
# c4 = UOp(UOps.ALU, dtypes.int, (y2, y1), BinaryOps.MUL)
# c5 = UOp(UOps.ALU, dtypes.int, (y2, y3), BinaryOps.MUL)
self.assertEqual(matcher.rewrite(c1), c1) self.assertEqual(matcher.rewrite(c1), c1)
self.assertEqual(matcher.rewrite(c2), None) self.assertEqual(matcher.rewrite(c2), None)
self.assertEqual(matcher.rewrite(c3), c3) self.assertEqual(matcher.rewrite(c3), c3)
# TODO: match these
# self.assertEqual(matcher.rewrite(c4), c4)
# self.assertEqual(matcher.rewrite(c5), c5)
def test_dup_name(self): def test_dup_name(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)]) matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST, name="y"), UPat(UOps.CONST, name="y"))), lambda x, y: x)])
@ -77,7 +84,7 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c2), None) self.assertEqual(matcher.rewrite(c2), None)
def test_dtype_set(self): def test_dtype_set(self):
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=set([dtypes.float32, dtypes.float64])), lambda x: x)]) matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype={dtypes.float32, dtypes.float64}), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0) c2 = UOp(UOps.CONST, dtypes.float64, arg=1.0)
c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0) c3 = UOp(UOps.CONST, dtypes.float16, arg=1.0)
@ -87,7 +94,7 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c3), None) self.assertEqual(matcher.rewrite(c3), None)
self.assertEqual(matcher.rewrite(c4), None) self.assertEqual(matcher.rewrite(c4), None)
def test_vin_one(self): def test_src_one(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)]) matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.CONST))), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@ -101,7 +108,7 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c4), c4) self.assertEqual(matcher.rewrite(c4), c4)
self.assertEqual(matcher.rewrite(c5), None) self.assertEqual(matcher.rewrite(c5), None)
def test_vin_permutations(self): def test_src_permutations(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)]) matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=[UPat(UOps.CONST), UPat(UOps.ALU)]), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@ -114,7 +121,7 @@ class TestPatternMatcher(TestUOps):
self.assertEqual(matcher.rewrite(c5), c5) self.assertEqual(matcher.rewrite(c5), c5)
self.assertEqual(matcher.rewrite(c6), None) self.assertEqual(matcher.rewrite(c6), None)
def test_vin_repeat(self): def test_src_repeat(self):
matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)]) matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=UPat(UOps.CONST)), lambda x: x)])
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
@ -140,10 +147,11 @@ class TestPatternMatcher(TestUOps):
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0) c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
u1 = (c1 + c2) + c1 u1 = (c1 + c2) + c1
u2 = (c2 + c1) + c1 u2 = (c2 + c1) + c1
pat = UPat(UOps.ALU, src = (UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b'))) matcher = PatternMatcher([
# TODO: why is this calling a private function? (UPat(UOps.ALU, src=[UPat(UOps.ALU, src=[UPat(name='a'), UPat(name='b')]), UPat(name='b')]), lambda a,b: b)
assert _match(u1, pat, {}) ])
assert _match(u2, pat, {}) self.assertIsNotNone(matcher.rewrite(u1))
self.assertIsNotNone(matcher.rewrite(u2))
@unittest.skip("no longer supported") @unittest.skip("no longer supported")
def test_rewrite_graph_folds(self): def test_rewrite_graph_folds(self):