mirror of https://github.com/commaai/tinygrad.git
Support constant expand to symbolic shape (#1411)
This commit is contained in:
parent
6572ca6835
commit
34f348643b
|
@ -85,11 +85,17 @@ class TestSymbolicReshape(unittest.TestCase):
|
|||
with self.assertRaises(AssertionError):
|
||||
t = Tensor.rand(3, 4).reshape(Variable("too_big", 100, 200), 4)
|
||||
|
||||
class TestSymbolicReshape(unittest.TestCase):
|
||||
class TestSymbolicExpand(unittest.TestCase):
|
||||
def test_expand_into_symbols(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
a = Tensor([[1], [2], [3]]).expand((3, vi))
|
||||
assert a.shape == (3, vi)
|
||||
vj = Variable("j", 1, 10)
|
||||
a = a.reshape(3, vi, 1).expand((3, vi, vj))
|
||||
assert a.shape == (3, vi, vj)
|
||||
assert a.shape == (3, vi, vj)
|
||||
|
||||
def test_plus_expands_constant(self):
|
||||
vi = Variable("i", 1, 10)
|
||||
a = Tensor.rand(3, 4).reshape(3, vi)
|
||||
a = a + 1
|
||||
assert a.shape == (3, vi)
|
|
@ -270,6 +270,12 @@ class TestSymbolicVars(unittest.TestCase):
|
|||
assert sym_vars(a+b) == [a, b]
|
||||
assert sym_vars(a*3) == [a]
|
||||
|
||||
class TestSymbolicMinMax(unittest.TestCase):
|
||||
def test_min_max_known(self):
|
||||
a = Variable("a", 1, 8)
|
||||
assert max(1, a) == max(a, 1) == a
|
||||
assert min(1, a) == min(a, 1) == 1
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@ class Node:
|
|||
def hash(self) -> int: return hash(self.key)
|
||||
def __repr__(self): return "<"+self.key+">"
|
||||
def __hash__(self): return self.hash
|
||||
def __bool__(self): return not (self.max == self.min == 0)
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
|
@ -119,7 +120,7 @@ class Node:
|
|||
def ands(nodes:List[Node]) -> Node:
|
||||
if not nodes: return NumNode(1)
|
||||
if len(nodes) == 1: return nodes[0]
|
||||
if any(x.min == x.max == 0 for x in nodes): return NumNode(0)
|
||||
if any(not x for x in nodes): return NumNode(0)
|
||||
|
||||
# filter 1s
|
||||
nodes = [x for x in nodes if x.min != x.max]
|
||||
|
|
Loading…
Reference in New Issue