mirror of https://github.com/commaai/tinygrad.git
O(1) VALIDHACKS (#2072)
* first refactoring * O(1) validhacks * O(1) validhacks * Some cleaning * mypy * flake8 * Trim trim * flake8 * clean * less chaotic * less chaotic * flake8 * Symbolic, SumNode include mulnode for gcd * fix tests * smal optim * revert * clean * clean * flake8 * small fix * Add symbolic test
This commit is contained in:
parent
30933d5bd0
commit
776605f2fc
|
@ -26,16 +26,15 @@ class TestSymbolic(unittest.TestCase):
|
|||
|
||||
def test_ge_divides(self):
|
||||
expr = (Variable("idx", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512
|
||||
self.helper_test_variable(expr, 0, 1, "((idx*4)<512)")
|
||||
self.helper_test_variable(expr//4, 0, 1, "(idx<128)")
|
||||
self.helper_test_variable(expr, 0, 1, "(idx<128)")
|
||||
|
||||
def test_ge_divides_and(self):
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512])
|
||||
self.helper_test_variable(expr//4, 0, 1, "((idx1<128) and (idx2<128))")
|
||||
self.helper_test_variable(expr, 0, 1, "((idx1<128) and (idx2<128))")
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3)) < 512,
|
||||
(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7)) < 512])
|
||||
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and (idx1<128))")
|
||||
self.helper_test_variable(expr//4, 0, 1, "((((FLOAT8_INDEX//4)+idx2)<128) and ((idx1//4)<32))")
|
||||
|
||||
def test_lt_factors(self):
|
||||
expr = Variable.ands([(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256)) < 512])
|
||||
|
@ -140,6 +139,10 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_sum_div_const_big(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 16, 0, 1, "(a//4)")
|
||||
|
||||
def test_sum_lt_fold(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 3)]) < 16, 0, 1, "(a<4)")
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7) * 4, Variable("b", 0, 4)]) < 16, 0, 1, "(((a*4)+b)<16)")
|
||||
|
||||
def test_mod_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)%9, 0, 5, "a")
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Tuple, Union, Dict, Set
|
||||
from typing import List, Tuple, Dict, Any
|
||||
from tinygrad.helpers import ImageDType, prod, IMAGE, getenv, dtypes, DEBUG
|
||||
|
||||
# *** image Tensor function replacements ***
|
||||
|
@ -140,85 +140,51 @@ def fix_schedule_for_images(schedule:List[ScheduleItem]):
|
|||
|
||||
# *** images have weird indexing requirements ***
|
||||
|
||||
from tinygrad.shape.symbolic import Node, AndNode, MulNode, Variable, NumNode, ModNode, SumNode, LtNode
|
||||
from tinygrad.shape.symbolic import Node, AndNode, Variable, NumNode, SumNode, LtNode
|
||||
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node) -> Tuple[Tuple[Node, Node], Node]:
|
||||
# This part is substituting variables by just looking at single var LtNodes in valid
|
||||
# Basically if var[0-5] < 3 -> var[0-2]
|
||||
if valid.min == 0:
|
||||
nodes: List = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
var_dict = {var:[var.min, var.max] for var in valid.vars()}
|
||||
|
||||
for nd in nodes:
|
||||
var_range = var_dict[nd.vars()[0]]
|
||||
if isinstance(nd.a, MulNode):
|
||||
if nd.a.b < 0:
|
||||
var_range[0] = (nd.b // nd.a.b) + 1
|
||||
elif nd.a.b > 0:
|
||||
var_range[1] = (nd.b // nd.a.b) - 1 if nd.b % nd.a.b == 0 else nd.b // nd.a.b
|
||||
elif isinstance(nd.a, Variable):
|
||||
var_range[1] = nd.b - 1
|
||||
# We do not allow NumNode because it is constant
|
||||
# TODO: Remove mx != mn
|
||||
sub_dict: Dict[Union[Variable, NumNode], Node] = {v:Variable(v.expr, mn, mx) for v, (mn, mx) in var_dict.items() if mx != mn}
|
||||
valid, idxy = valid.substitute(sub_dict), idxy.substitute(sub_dict)
|
||||
idx = (idxy // 4) % base_shape[1]
|
||||
idy = (idxy // (4 * base_shape[1]))
|
||||
|
||||
idx, idy = (idxy // 4) % base_shape[1], (idxy // (4 * base_shape[1]))
|
||||
idx_vars, idy_vars, val_vars = set(idx.vars()), set(idy.vars()), set(valid.vars())
|
||||
|
||||
# Simplify ModNode if possible # test_padded_conv_transpose2d, needs much more thinking
|
||||
if valid.min == 0 and isinstance(idx, ModNode) and isinstance(idx.a, SumNode):
|
||||
if valid.min == 0 and isinstance(idxy, SumNode):
|
||||
nodes = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
same_dict: Dict[Node, List[Tuple[int, Node]]] = {}
|
||||
idx_nodes = idx.a.flat_components
|
||||
val_dict: Dict[Node, Any] = {}
|
||||
idxy_flat_var = [(i, i.vars()[0]) for i in idxy.flat_components if not isinstance(i, NumNode)]
|
||||
|
||||
for node in nodes:
|
||||
if not isinstance(node, LtNode) or not isinstance(node.a, SumNode): continue
|
||||
assert isinstance(node, LtNode)
|
||||
node_flat, node_vars = node.a.flat_components if isinstance(node.a, SumNode) else [node.a], node.vars()
|
||||
same_sym = [i for (i, var) in idxy_flat_var if var in node_vars]
|
||||
first, second = sorted(same_sym)[0], sorted(node_flat)[0]
|
||||
f_b = 1 if isinstance(first, Variable) else first.b
|
||||
s_b = 1 if isinstance(second, Variable) else second.b
|
||||
sig = -1 if s_b < 0 else 1
|
||||
key_node = sig*node.a
|
||||
if key_node not in val_dict: val_dict[key_node] = [key_node.min, key_node.max, abs(f_b//s_b)]
|
||||
val_dict[key_node][(sig + 1)//2] = sig*(node.b - 1)
|
||||
|
||||
nd_flat, nd_vars = node.a.flat_components, node.vars()
|
||||
fakes = {}
|
||||
for cnt, (key_node, (mnn, mxn, multip)) in enumerate(val_dict.items()):
|
||||
fake_var = Variable("fake_" + str(cnt), mnn, mxn)
|
||||
fakes[fake_var] = key_node
|
||||
idxy += multip*(fake_var - key_node)
|
||||
|
||||
same = [x for x in idx_nodes if (x.a if isinstance(x, MulNode) else x) in nd_vars]
|
||||
idx = (idxy // 4) % base_shape[1]
|
||||
idy = (idxy // (4 * base_shape[1]))
|
||||
|
||||
if len(same) != len(nd_vars): continue
|
||||
fake_rep = {fake: node for fake, node in fakes.items()}
|
||||
|
||||
first_b, second_b = nd_flat[0].b if isinstance(nd_flat[0], MulNode) else 1, same[0].b if isinstance(same[0], MulNode) else 1
|
||||
k, same_sum = second_b//first_b, Variable.sum(same)
|
||||
idx = idx.substitute(fake_rep)
|
||||
idy = idy.substitute(fake_rep)
|
||||
|
||||
if k*(node.a) == same_sum: same_dict[same_sum] = same_dict.get(same_sum, []) + [(k, node)]
|
||||
|
||||
for key in same_dict.keys():
|
||||
same, mnn, mxn = key.flat_components, key.min, key.max # type: ignore # Same is sumnode because node.a is SumNode
|
||||
for k, node in same_dict[key]: # TODO: This part may need more thinking
|
||||
if k < 0: mnn = (-k)*max((-node.b) + 1, min([-lal.b if isinstance(lal, MulNode) else 1 for lal in same]))
|
||||
else: mxn = (node.b - 1)*k
|
||||
|
||||
fake_var = Variable("valid_fake", mnn, mxn)
|
||||
total = (Variable.sum([x for x in idx_nodes if x not in same]) + fake_var) % idx.b
|
||||
idx = total.substitute({fake_var: key})
|
||||
# TODO: If idx has no ModNode we may be able to remove the valid node, but removing it needs careful thinking
|
||||
|
||||
# Simplify SumNodes
|
||||
# This part just removes valid nodes if the node is exactly the same as idx or idy
|
||||
# idx = 3*a + b (+ 5), valid = 3*a + b < 10 # Valid will be removed as idx will go out of bounds
|
||||
# Check for var intersection, removing valid can affect other index
|
||||
if valid.min == 0 and not idx_vars.intersection(idy_vars):
|
||||
nds = valid.nodes if isinstance(valid, AndNode) else [valid]
|
||||
flats = [id.flat_components for id in (idx, idy) if isinstance(id, SumNode)]
|
||||
sym_sums = [Variable.sum([i for i in flat if not isinstance(i, NumNode)]) for flat in flats]
|
||||
ones = [node for sym_sum in sym_sums for node in nds if (node.a == sym_sum) or (-(node.a) == sym_sum)] # type: ignore # AndNode always consists of LtNode
|
||||
valid = Variable.ands([i for i in nds if i not in ones])
|
||||
|
||||
# This is the slow part
|
||||
# This part is for brute forcing all possible values of idx, idy and valid
|
||||
# If valid is both 0 and 1 for the same (idx, idy) we can not delete the valid
|
||||
if getenv("VALIDHACKS", 1) and valid.min == 0 and not isinstance(idx, ModNode):
|
||||
variables = tuple(val_vars | idy_vars | idx_vars)
|
||||
val_infer, idx_infer, idy_infer = valid.expand(variables), idx.expand(variables), idy.expand(variables)
|
||||
val_dict: Dict[int, Set[Tuple[int,int]]] = {0:set(), 1:set()}
|
||||
|
||||
for v, x, y in zip(val_infer, idx_infer, idy_infer): val_dict[v.min].add((x.min, y.min))
|
||||
|
||||
if not val_dict[1].intersection(val_dict[0]): valid = NumNode(1)
|
||||
idy_vars, idx_vars, ones = set(idy.vars()), set(idx.vars()), []
|
||||
for node in nodes:
|
||||
node_vars = set(node.vars())
|
||||
if not node_vars & (idx_vars | idy_vars): continue #There is simplified NumNode which can not go outside the bounds
|
||||
# NOTE: Why is only idy problematic, and not idx?
|
||||
if idy_vars == node_vars or idy_vars & node_vars == set(): ones.append(node)
|
||||
valid = Variable.ands([i for i in nodes if i not in ones])
|
||||
|
||||
if DEBUG>=5: print("to_image_idx", base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
|
||||
return (idx, idy), valid
|
||||
|
|
|
@ -289,17 +289,15 @@ class SumNode(RedNode):
|
|||
if isinstance(x, NumNode): b -= x.b
|
||||
else: new_sum.append(x)
|
||||
lhs = Node.sum(new_sum)
|
||||
if isinstance(lhs, SumNode):
|
||||
muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
if muls:
|
||||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
mul_gcd = muls[0].b
|
||||
for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
|
||||
if b%mul_gcd == 0:
|
||||
all_others = Variable.sum(others)
|
||||
if all_others.min >= 0 and all_others.max < mul_gcd:
|
||||
# TODO: should we divide both by mul_gcd here?
|
||||
lhs = Variable.sum(muls)
|
||||
nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
|
||||
muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
|
||||
if muls:
|
||||
# NOTE: gcd in python 3.8 takes exactly 2 args
|
||||
mul_gcd = b
|
||||
for x in muls: mul_gcd = gcd(mul_gcd, x.b) # type: ignore # mypy cannot tell x.b is int here
|
||||
all_others = Variable.sum(others)
|
||||
if all_others.min >= 0 and all_others.max < mul_gcd:
|
||||
lhs, b = Variable.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
|
||||
return Node.__lt__(lhs, b)
|
||||
|
||||
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: return Variable.sum([node.substitute(var_vals) for node in self.nodes])
|
||||
|
|
Loading…
Reference in New Issue