Support constant expand to symbolic shape (#1411)

This commit is contained in:
chenyu 2023-08-02 21:21:22 -07:00 committed by GitHub
parent 6572ca6835
commit 34f348643b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 3 deletions

View File

@ -85,11 +85,17 @@ class TestSymbolicReshape(unittest.TestCase):
with self.assertRaises(AssertionError):
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
class TestSymbolicReshape(unittest.TestCase):
class TestSymbolicExpand(unittest.TestCase):
def test_expand_into_symbols(self):
vi = Variable("i", 1, 10)
a = Tensor([[1], [2], [3]]).expand((3, vi))
assert a.shape == (3, vi)
vj = Variable("j", 1, 10)
a = a.reshape(3, vi, 1).expand((3, vi, vj))
assert a.shape == (3, vi, vj)
assert a.shape == (3, vi, vj)
def test_plus_expands_constant(self):
vi = Variable("i", 1, 10)
a = Tensor.rand(3, 4).reshape(3, vi)
a = a + 1
assert a.shape == (3, vi)

View File

@ -270,6 +270,12 @@ class TestSymbolicVars(unittest.TestCase):
assert sym_vars(a+b) == [a, b]
assert sym_vars(a*3) == [a]
class TestSymbolicMinMax(unittest.TestCase):
def test_min_max_known(self):
a = Variable("a", 1, 8)
assert max(1, a) == max(a, 1) == a
assert min(1, a) == min(a, 1) == 1
if __name__ == '__main__':
unittest.main()

View File

@ -28,6 +28,7 @@ class Node:
def hash(self) -> int: return hash(self.key)
def __repr__(self): return "<"+self.key+">"
def __hash__(self): return self.hash
def __bool__(self): return not (self.max == self.min == 0)
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
@ -119,7 +120,7 @@ class Node:
def ands(nodes:List[Node]) -> Node:
if not nodes: return NumNode(1)
if len(nodes) == 1: return nodes[0]
if any(x.min == x.max == 0 for x in nodes): return NumNode(0)
if any(not x for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]