O(1) VALIDHACKS (#2072)

* first refactoring

* O(1) validhacks

* O(1) validhacks

* Some cleaning

* mypy

* flake8

* Trim trim

* flake8

* clean

* less chaotic

* less chaotic

* flake8

* Symbolic, SumNode include mulnode for gcd

* fix tests

* small optim

* revert

* clean

* clean

* flake8

* small fix

* Add symbolic test
This commit is contained in:
Umut Zengin 2023-10-15 21:26:41 +03:00 committed by GitHub
parent 30933d5bd0
commit 776605f2fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 83 deletions

View File

@ -26,16 +26,15 @@ class TestSymbolic(unittest.TestCase):
def test_ge_divides(self):
expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
self.helper_test_variable(expr, 0, 1, "((idx*4)<512)")
self.helper_test_variable(expr//4, 0, 1, "(idx<128)")
self.helper_test_variable(expr, 0, 1, "(idx<128)")
def test_ge_divides_and(self):
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
self.helper_test_variable(expr//4, 0, 1, "((idx1<128) and (idx2<128))")
self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and (idx1<128))")
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))")
def test_lt_factors(self):
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
@ -140,6 +139,10 @@ class TestSymbolic(unittest.TestCase):
# Floor-dividing the sum (a*4 + 3), with a in [0, 7], by 16 is expected to render as "(a//4)".
# The 0 and 1 arguments are presumably the expected min/max of the result — confirm against helper_test_variable.
def test_sum_div_const_big(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 16, 0, 1, "(a//4)")
# Checks when `sum < const` comparisons are simplified.
# The 0 and 1 arguments are presumably the expected min/max of the comparison — confirm against helper_test_variable.
def test_sum_lt_fold(self):
# b ranges over [0, 3], strictly below the multiplier 4: the comparison is expected to fold to "(a<4)".
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
# b ranges over [0, 4], so it can reach the multiplier 4: the comparison is expected to stay unfolded.
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)")
def test_mod_mul(self):
self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")

View File

@ -1,4 +1,4 @@
from typing import List, Tuple, Union, Dict, Set
from typing import List, Tuple, Dict, Any
from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG
# *** image Tensor function replacements ***
@ -140,85 +140,51 @@ def fix_schedule_for_images(schedule:List[ScheduleItem]):
# *** images have weird indexing requirements ***
from tinygrad.shape.symbolic import Node, AndNode, MulNode, Variable, NumNode, ModNode, SumNode, LtNode
from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
# This part is substituting variables by just looking at single var LtNodes in valid
# Basically if var[0-5] < 3 -> var[0-2]
if valid.min == 0:
nodes: List = valid.nodes if isinstance(valid, AndNode) else [valid]
var_dict = {var:[var.min, var.max] for var in valid.vars()}
for nd in nodes:
var_range = var_dict[nd.vars()[0]]
if isinstance(nd.a, MulNode):
if nd.a.b < 0:
var_range[0] = (nd.b // nd.a.b) + 1
elif nd.a.b > 0:
var_range[1] = (nd.b // nd.a.b) - 1 if nd.b % nd.a.b == 0 else nd.b // nd.a.b
elif isinstance(nd.a, Variable):
var_range[1] = nd.b - 1
# We do not allow NumNode because it is constant
# TODO: Remove mx != mn
sub_dict: Dict[Union[Variable, NumNode], Node] = {v:Variable(v.expr, mn, mx) for v, (mn, mx) in var_dict.items() if mx != mn}
valid, idxy = valid.substitute(sub_dict), idxy.substitute(sub_dict)
idx = (idxy // 4) % base_shape[1]
idy = (idxy // (4 * base_shape[1]))
idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
idx_vars, idy_vars, val_vars = set(idx.vars()), set(idy.vars()), set(valid.vars())
# Simplify ModNode if possible # test_padded_conv_transpose2d, Needs much more thinking
if valid.min == 0 and isinstance(idx, ModNode) and isinstance(idx.a, SumNode):
if valid.min == 0 and isinstance(idxy, SumNode):
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
same_dict: Dict[Node, List[Tuple[int, Node]]] = {}
idx_nodes = idx.a.flat_components
val_dict: Dict[Node, Any] = {}
idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
for node in nodes:
if not isinstance(node, LtNode) or not isinstance(node.a, SumNode): continue
assert isinstance(node, LtNode)
node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
first, second = sorted(same_sym)[0], sorted(node_flat)[0]
f_b = 1 if isinstance(first, Variable) else first.b
s_b = 1 if isinstance(second, Variable) else second.b
sig = -1 if s_b < 0 else 1
key_node = sig*node.a
if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)]
val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1)
nd_flat, nd_vars = node.a.flat_components, node.vars()
fakes = {}
for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
fake_var = Variable("fake_" + str(cnt), mnn, mxn)
fakes[fake_var] = key_node
idxy += multip*(fake_var - key_node)
same = [x for x in idx_nodes if (x.a if isinstance(x, MulNode) else x) in nd_vars]
idx = (idxy // 4) % base_shape[1]
idy = (idxy // (4 * base_shape[1]))
if len(same) != len(nd_vars): continue
fake_rep = {fake: node for fake, node in fakes.items()}
first_b, second_b = nd_flat[0].b if isinstance(nd_flat[0], MulNode) else 1, same[0].b if isinstance(same[0], MulNode) else 1
k, same_sum = second_b//first_b, Variable.sum(same)
idx = idx.substitute(fake_rep)
idy = idy.substitute(fake_rep)
if k*(node.a) == same_sum: same_dict[same_sum] = same_dict.get(same_sum, []) + [(k, node)]
for key in same_dict.keys():
same, mnn, mxn = key.flat_components, key.min, key.max # type: ignore # Same is sumnode because node.a is SumNode
for k, node in same_dict[key]: # TODO: This part may need more thinking
if k < 0: mnn = (-k)*max((-node.b) + 1, min([-lal.b if isinstance(lal, MulNode) else 1 for lal in same]))
else: mxn = (node.b - 1)*k
fake_var = Variable("valid_fake", mnn, mxn)
total = (Variable.sum([x for x in idx_nodes if x not in same]) + fake_var) % idx.b
idx = total.substitute({fake_var: key})
# TODO: If idx has no ModNode we may be able to remove the valid node, but removing it needs careful thinking
# Simplify SumNodes
# This part just removes valid nodes if node is exactly same as idx or idy
# idx = 3*a + b (+ 5), valid = 3*a + b < 10 # Valid will be removed as idx will go out of bounds
# Check for var intersection, removing valid can affect other index
if valid.min == 0 and not idx_vars.intersection(idy_vars):
nds = valid.nodes if isinstance(valid, AndNode) else [valid]
flats = [id.flat_components for id in (idx, idy) if isinstance(id, SumNode)]
sym_sums = [Variable.sum([i for i in flat if not isinstance(i, NumNode)]) for flat in flats]
ones = [node for sym_sum in sym_sums for node in nds if (node.a == sym_sum) or (-(node.a) == sym_sum)] # type: ignore # AndNode always consists of LtNode
valid = Variable.ands([i for i in nds if i not in ones])
# This is the slow part
# This part is for brute forcing all possible values of idx, idy and valid
# If valid is both 0 and 1 for the same (idx, idy) we can not delete the valid
if getenv("VALIDHACKS", 1) and valid.min == 0 and not isinstance(idx, ModNode):
variables = tuple(val_vars | idy_vars | idx_vars)
val_infer, idx_infer, idy_infer = valid.expand(variables), idx.expand(variables), idy.expand(variables)
val_dict: Dict[int, Set[Tuple[int,int]]] = {0:set(), 1:set()}
for v, x, y in zip(val_infer, idx_infer, idy_infer): val_dict[v.min].add((x.min, y.min))
if not val_dict[1].intersection(val_dict[0]): valid = NumNode(1)
idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), []
for node in nodes:
node_vars = set(node.vars())
if not node_vars & (idx_vars | idy_vars): continue #There is simplified NumNode which can not go outside the bounds
# NOTE: Why is only idy problematic, and not idx?
if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
valid = Variable.ands([i for i in nodes if i not in ones])
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
return (idx, idy), valid

View File

@ -289,17 +289,15 @@ class SumNode(RedNode):
if isinstance(x, NumNode): b -= x.b
else: new_sum.append(x)
lhs = Node.sum(new_sum)
if isinstance(lhs, SumNode):
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if muls:
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = muls[0].b
for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
if b%mul_gcd == 0:
all_others = Variable.sum(others)
if all_others.min >= 0 and all_others.max < mul_gcd:
# TODO: should we divide both by mul_gcd here?
lhs = Variable.sum(muls)
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
if muls:
# NOTE: gcd in python 3.8 takes exactly 2 args
mul_gcd = b
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
all_others = Variable.sum(others)
if all_others.min >= 0 and all_others.max < mul_gcd:
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
return Node.__lt__(lhs, b)
# Apply the var_vals substitution to every node in self.nodes and rebuild the result with Variable.sum.
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])