No negative (#632)

* behavior is correct without VALIDHACKS

* simple div and mod

* fix tests

* no negative variables

* alt form is correct

* still correct

* bug in mulnode

* at least validhacks works now

* cleanups

* test validhacks, and to_image_idx

* cache compare key

* tests and __neg__
This commit is contained in:
George Hotz 2023-03-03 16:48:14 -08:00 committed by GitHub
parent 8c475ea86a
commit 7a1d96fd76
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 181 additions and 123 deletions

View File

@ -162,7 +162,8 @@ jobs:
- name: Test openpilot model
run: |
ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
UNSAFE_FLOAT4=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
# disabled, this test is flaky
testdocker:

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.helpers import prod
from tinygrad.helpers import prod, all_same
from tinygrad.shape import ShapeTracker, View, ZeroView, merge_views
from tinygrad.codegen.gpu import to_image_idx
@ -66,14 +66,7 @@ class CheckingShapeTracker:
class TestImageShapeTracker(unittest.TestCase):
def test_image(self):
base_shape = (64, 1024, 4)
"""
st = ShapeTracker(shape=(8, 64, 128, 3), views=[
View((1, 66, 130, 32, 1, 1), (0, 4096, 32, 1, 0, 0), -4128),
ZeroView((1, 64, 128, 32, 1, 1), ((0, 1), (-1, 65), (-1, 129), (0, 32), (0, 1), (0, 1))),
View((8, 64, 128, 3), (4, 4160, 32, 4160), 0)])
offsets = [0,32,64]
"""
print(base_shape)
new_view = merge_views(
View((1, 66, 130, 32, 1, 1), (0, 4096, 32, 1, 0, 0), -4128),
@ -88,19 +81,18 @@ class TestImageShapeTracker(unittest.TestCase):
offsets = [0,32,64,96]
print(st.shape)
idys = []
for o in offsets:
print("offset:", o)
idxy, valid = st.expr_idxs(o)
print("idxy:", idxy.render())
print("valids:", [x.render() for x in valid.nodes])
out = to_image_idx(base_shape, idxy, True)
print(out)
#idx = (idxy//4)%base_shape[1]
#idy = (idxy//(4*base_shape[1]))%base_shape[0]
#idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)]
idx, idy = to_image_idx(base_shape, idxy, valid, True)
idys.append(idy)
print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
#print("idx:", idx.render())
#print("idy:", idy.render())
# y index shouldn't be changing
assert all_same(idys)
class TestSimplifyingShapeTracker(unittest.TestCase):
def setUp(self):

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python
import unittest
from tinygrad.shape.symbolic import Variable, divn, modn
from tinygrad.shape.symbolic import Variable, NumNode, Node
class TestSymbolic(unittest.TestCase):
def helper_test_variable(self, v, n, m, s):
@ -24,6 +24,48 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")
def test_div_becomes_num(self):
assert isinstance(Variable("a", 2, 3)//2, NumNode)
def test_var_becomes_num(self):
assert isinstance(Variable("a", 2, 2), NumNode)
def test_equality(self):
idx1 = Variable("idx1", 0, 3)
idx2 = Variable("idx2", 0, 3)
assert idx1 == idx1
assert idx1 != idx2
assert idx1*4 == idx1*4
assert idx1*4 != idx1*3
assert idx1*4 != idx1+4
assert idx1*4 != idx2*4
assert idx1+idx2 == idx1+idx2
assert idx1+idx2 == idx2+idx1
assert idx1+idx2 != idx2
def test_factorize(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
def test_factorize_no_mul(self):
a = Variable("a", 0, 8)
self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
def test_neg(self):
self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
def test_add_1(self):
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)")
def test_add_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)+Variable.num(1), 1, 9, "(1+a)")
def test_sub_1(self):
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)")
def test_sub_num_1(self):
self.helper_test_variable(Variable("a", 0, 8)-Variable.num(1), -1, 7, "(-1+a)")
def test_mul_0(self):
self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
@ -45,6 +87,9 @@ class TestSymbolic(unittest.TestCase):
def test_div_min_max(self):
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
def test_div_neg_min_max(self):
self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)")
def test_sum_div_min_max(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
@ -57,9 +102,9 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_no_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
@unittest.skip("mod max is wrong")
def test_mod_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b*50)%100)")
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
def test_sum_div_const(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 4, 0, 7, "a")
@ -73,6 +118,9 @@ class TestSymbolic(unittest.TestCase):
def test_mul_mul(self):
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
def test_div_div(self):
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
def test_distribute_mul(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
@ -86,11 +134,12 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
def test_big_mod(self):
self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
# NOTE: we no longer support negative variables
#self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
#self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
def test_gt_remove(self):
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0")
@ -107,16 +156,14 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable.ands([Variable.num(1), Variable("a", 0, 1)]), 0, 1, "a")
def test_mod_factor_negative(self):
# this is technically wrong, if b is 0 the output will be negative
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((-1+a)%28)")
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((-1+a)%28)")
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
def test_sum_combine_num(self):
self.helper_test_variable(Variable.sum([Variable.num(29), Variable("a", 0, 10), Variable.num(-23)]), 6, 16, "(6+a)")
def test_div_factor(self):
# TODO: this isn't right
self.helper_test_variable(Variable.sum([Variable.num(-44), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
@ -132,7 +179,7 @@ class TestSymbolic(unittest.TestCase):
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken?
# TODO: why are the negative tests broken? (even if we did support negative variables)
#MIN, MAX = -10, 10
MIN, MAX = 0, 10
# one number
@ -150,15 +197,15 @@ class TestSymbolicNumeric(unittest.TestCase):
self.assertLessEqual(v.min, min(values))
self.assertGreaterEqual(v.max, max(values))
def test_mod_4(self): self.helper_test_numeric(lambda x: modn(x, 4))
def test_div_4(self): self.helper_test_numeric(lambda x: divn(x, 4))
def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: divn(x+1, 2))
def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: modn(x+1, 2))
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: modn(x*2 + 3, 4))
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: divn(x*2 + 3, 4))
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: modn(divn(x*2 + 3, 4), 4))
def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
if __name__ == '__main__':
unittest.main()

View File

@ -3,9 +3,9 @@ from collections import defaultdict
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
from tinygrad.codegen.ast import ASTKernel, Token, Types
from tinygrad.shape.symbolic import Node, ModNode, DivNode, render_python
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, Variable, render_python
from tinygrad.shape import ShapeTracker
from tinygrad.helpers import getenv, DEBUG, prod
from tinygrad.helpers import getenv, DEBUG, prod, partition
# div is different in cl than python
render_cl = render_python.copy()
@ -25,11 +25,25 @@ class GPULanguage(NamedTuple):
extra_args : List[str] = []
float4 : Optional[str] = None
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, validhacks=False):
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
idy = (idxy//(4*base_shape[1]))
if validhacks and valid.min == 0:
idx = (idxy//4) + (idy*-base_shape[1])
# find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
if isinstance(idx, SumNode):
unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
assert len(unfactored) <= 1
idx = Variable.sum(idx_nodes)
unfactored = (Variable.sum(unfactored) // base_shape[1])
idy += unfactored
# ugh really...
if idx.min >= base_shape[1]//2:
idx -= base_shape[1]
idy += 1
else:
idx = (idxy//4)%base_shape[1]
idy = (idxy//(4*base_shape[1]))%base_shape[0]
if validhacks: idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)]
return f"(int2)({idx.render(render_cl)}, {idy.render(render_cl)})"
#print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
return idx, idy
class GPUCodegen(ASTKernel):
lang : GPULanguage = GPULanguage()
@ -66,7 +80,8 @@ class GPUCodegen(ASTKernel):
v = Token(f"{self.lang.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4)
if hasattr(self.bufs[buf_index]._buf, "IMAGE"):
assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
self.kernel.append(f"write_imagef(data{buf_index}, {to_image_idx(self.bufs[buf_index]._base_shape, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid)
self.kernel.append(f"write_imagef(data{buf_index}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
elif v.typ == Types.FLOAT4:
self.kernel.append(f"(({self.lang.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
else:
@ -97,7 +112,8 @@ class GPUCodegen(ASTKernel):
ldr = const
elif hasattr(self.bufs[buf_index]._buf, "IMAGE"):
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {to_image_idx(self.bufs[buf_index]._base_shape, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS)
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
elif should_upcast and can_merge:
ldr = Token(f"(({self.lang.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
else:

View File

@ -1,48 +1,61 @@
from __future__ import annotations
import math
from typing import List, Dict, Callable, Type
import math, itertools, functools
from typing import List, Dict, Callable, Type, Union
from tinygrad.helpers import partition, all_same
# python has different behavior for negative mod and div than c
def divn(x, a): return x//a if isinstance(x, Node) else int(x/a)
def modn(x, a): return x%a if isinstance(x, Node) else (-((-x)%a) if x < 0 else x%a)
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code is outputs is agnostic, and will never have negative numbers in div or mod
def create_node(typ:Type[Node], *args):
ret = typ(*args)
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {typ} {args}"
if ret.min == ret.max: return NumNode(ret.min)
return ret
class Node:
b: int
min: int
max: int
def render(self, ops=None, ctx=None):
def render(self, ops=None, ctx=None) -> str:
if ops is None: ops = render_python
if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
assert isinstance(self, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
def __add__(self, b:int): return Variable.sum([self, Variable.num(b)]) if b != 0 else self
def __sub__(self, b:int): return self+-b
def __ge__(self, b:int): return GeNode(self, b)
def __lt__(self, b:int): return LtNode(self, b)
@functools.cached_property
def key(self) -> str: return self.render()
def __repr__(self): return "<"+self.key+">"
def __eq__(self, other:object) -> bool:
if not isinstance(other, Node): return NotImplemented
return self.key == other.key
def __neg__(self): return self*-1
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
def __sub__(self, b:Union[Node, int]): return self+-b
def __ge__(self, b:int): return create_node(GeNode, self, b)
def __lt__(self, b:int): return create_node(LtNode, self, b)
def __mul__(self, b:int):
if b == 0: return NumNode(0)
elif b == 1: return self
if isinstance(self, MulNode): return MulNode(self.a, self.b*b)
# distribute mul into sum
if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes])
return MulNode(self, b)
if isinstance(self, MulNode): return self.a*(self.b*b) # two muls is one mul
if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum
return create_node(MulNode, self, b)
# *** complex ops ***
def __floordiv__(self, b:int):
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
if isinstance(self, MulNode) and modn(self.b, b) == 0: return self.a*divn(self.b, b)
if isinstance(self, MulNode) and modn(b, self.b) == 0: return self.a//divn(b, self.b)
if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
if isinstance(self, SumNode):
factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
nofactor = []
# ugh, i doubt this is universally right
for x in tmp_nofactor:
if isinstance(x, NumNode):
if modn(x.b, b) != x.b:
factors.append(Variable.num(x.b - modn(x.b, b))) # python does floor division
nofactor.append(Variable.num(modn(x.b, b)))
if (x.b%b) != x.b:
factors.append(Variable.num(x.b - (x.b%b))) # python does floor division
nofactor.append(Variable.num(x.b%b))
else:
nofactor.append(x)
gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor]
@ -58,24 +71,28 @@ class Node:
for m in muls:
if m > 1 and b%m == 0:
return (self//m)//(b//m)
return DivNode(self, b)
if self.min < 0:
offset = self.min//b
return (self+offset*b)//b - offset
return create_node(DivNode, self, b)
def __mod__(self, b:int):
assert b > 0
if b == 1: return NumNode(0)
if isinstance(self, SumNode):
new_nodes = []
for x in self.nodes:
if isinstance(x, NumNode): new_nodes.append(Variable.num(modn(x.b, b)))
elif isinstance(x, MulNode): new_nodes.append(x.a * modn(x.b, b))
if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
else: new_nodes.append(x)
a = Variable.sum(new_nodes)
elif isinstance(self, MulNode):
a = self.a * modn(self.b, b)
a = self.a * (self.b%b)
else:
a = self
if a.min >= 0 and a.max < b: return a
if a.min == a.max: return Variable.num(modn(a.min, b))
return ModNode(a, b)
if a.min < 0: return (a + ((a.min//b)*b)) % b
return create_node(ModNode, a, b)
@staticmethod
def num(num:int) -> Node: return NumNode(num)
@ -90,28 +107,35 @@ class Node:
# combine any numbers inside a sum
nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
num_sum = sum([x.b for x in num_nodes])
# TODO: these can't be merged due to image indexing. it's not clear which idx to group the offset with
if num_sum >= 0: nodes.append(NumNode(num_sum))
else:
lte_0, rest = partition(num_nodes, lambda x: x.b <= 0)
nodes += [NumNode(x.b) for x in sorted(lte_0, key=lambda x:x.b) if x.b != 0]
if len(rest): nodes += [NumNode(sum([x.b for x in rest]))]
nodes.append(NumNode(sum([x.b for x in num_nodes])))
# combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things)
nodes, mul_nodes = partition(nodes, lambda x: not isinstance(x, MulNode))
mul_nodes += [MulNode(x, 1) for x in nodes]
mul_nodes = sorted(mul_nodes, key=lambda x: x.a.render()) # group by equality (ugh, uses render!)
new_nodes = [k * sum(x.b for x in g) for k, g in itertools.groupby(mul_nodes, key=lambda x: x.a)]
nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in new_nodes]
# filter 0s
nodes = [x for x in nodes if x.min != 0 or x.max != 0]
return SumNode(nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0))
return create_node(SumNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0))
@staticmethod
def ands(nodes:List[Node]) -> Node:
if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0)
# filter 1s
nodes = [x for x in nodes if x.min != x.max]
return AndNode(nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
return create_node(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
# 4 basic node types
class Variable(Node):
def __new__(cls, expr:str, nmin:int, nmax:int):
assert nmin >= 0 and nmin <= nmax
if nmin == nmax: return NumNode(nmin)
return super().__new__(cls)
def __init__(self, expr:str, nmin:int, nmax:int):
self.expr, self.min, self.max = expr, nmin, nmax
@ -135,51 +159,29 @@ class RedNode(Node):
class GeNode(OpNode): minmax = staticmethod(lambda a,b: (int(a.min >= b), int(a.max >= b)))
class LtNode(OpNode): minmax = staticmethod(lambda a,b: (int(a.max < b), int(a.min < b)))
class MulNode(OpNode): minmax = staticmethod(lambda a,b: (a.min*b, a.max*b))
class DivNode(OpNode): minmax = staticmethod(lambda a,b: (divn(a.min, b), divn(a.max, b)))
# given a number in the range [amin, amax] (inclusive)
# what are the min and max of that number after modding it by b?
# aka a fast version of:
#values = [modn(rv, b) for rv in range(amin, amax+1)]
#return min(values), max(values)
class MulNode(OpNode): minmax = staticmethod(lambda a,b: (a.min*b, a.max*b) if b >= 0 else (a.max*b, a.min*b))
class DivNode(OpNode):
@staticmethod
def minmax(a, b):
assert a.min >= 0
return a.min//b, a.max//b
# you have 3 included ranges
# range 1 from min1 -> max1 (smaller than a mod)
# range 1 from a.min -> max1 (smaller than a mod)
# range 2 from max1 -> min2
# range 3 from min2 -> max2 (smaller than a mod)
def modrange_negative(amin, amax, b):
assert amin<0 and amax<0
min1, max1 = amin, math.ceil(amin/b)*b
min2, max2 = math.floor(amax/b)*b, amax
if max1 > min2: return (modn(min1, b), modn(max2, b)) # range 2 doesn't exist, min1 -> max2 is smaller than a mod
if max1 < min2: return (-b+1, 0) # range 2 is the full distance
if min2 == max2: return (modn(min1, b), 0) # range 1 is the only valid
return (-b+1, 0) # range 1 and 3 are valid
def modrange_positive(amin, amax, b):
assert amin>=0 and amax>=0
min1, max1 = amin, math.ceil(amin/b)*b
min2, max2 = math.floor(amax/b)*b, amax
if max1 > min2: return (modn(min1, b), modn(max2, b)) # range 2 doesn't exist, min1 -> max2 is smaller than a mod
# range 3 from min2 -> a.max (smaller than a mod)
class ModNode(OpNode):
@staticmethod
def minmax(a, b):
assert a.min >= 0
#values = [x%b for x in range(a.min, a.max+1)]
#return min(values), max(values)
max1, min2 = math.ceil(a.min/b)*b, math.floor(a.max/b)*b
if max1 < min2: return (0, b-1) # range 2 is the full distance
if min1 == max1: return (0, modn(max2, b)) # range 3 is the only valid
if max1 > min2: return (a.min%b, a.max%b) # range 2 doesn't exist, a.min -> a.max is smaller than a mod
if a.min == max1: return (0, a.max%b) # range 3 is the only valid
return (0, b-1) # range 1 and 3 are valid
def modrange(amin, amax, b):
if amin < 0 and amax < 0:
return modrange_negative(amin, amax, b)
if amin >= 0 and amax >= 0:
return modrange_positive(amin, amax, b)
if amin < 0 and amax >= 0:
min1, max1 = modrange_negative(amin, -1, b)
min2, max2 = modrange_positive(0, amax, b)
return min(min1, min2), max(max1, max2)
class ModNode(OpNode): minmax = staticmethod(lambda a,b: modrange(a.min, a.max, b))
# reduce nodes
class SumNode(RedNode): minmax = staticmethod(lambda nodes: (sum([x.min for x in nodes]), sum([x.max for x in nodes])))