mirror of https://github.com/commaai/tinygrad.git
No negative (#632)
* behavior is correct without VALIDHACKS * simple div and mod * fix tests * no negative variables * alt form is correct * still correct * bug in mulnode * at least validhacks works now * cleanups * test validhacks, and to_image_idx * cache compare key * tests and __neg__
This commit is contained in:
parent
8c475ea86a
commit
7a1d96fd76
|
@ -162,7 +162,8 @@ jobs:
|
|||
- name: Test openpilot model
|
||||
run: |
|
||||
ALLOWED_KERNEL_COUNT=197 FLOAT16=1 VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
UNSAFE_FLOAT4=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
VALIDHACKS=1 DEBUGCL=1 GPU=1 IMAGE=2 python3 openpilot/compile.py
|
||||
|
||||
# disabled, this test is flaky
|
||||
testdocker:
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.helpers import prod, all_same
|
||||
from tinygrad.shape import ShapeTracker, View, ZeroView, merge_views
|
||||
from tinygrad.codegen.gpu import to_image_idx
|
||||
|
||||
|
@ -66,14 +66,7 @@ class CheckingShapeTracker:
|
|||
class TestImageShapeTracker(unittest.TestCase):
|
||||
def test_image(self):
|
||||
base_shape = (64, 1024, 4)
|
||||
|
||||
"""
|
||||
st = ShapeTracker(shape=(8, 64, 128, 3), views=[
|
||||
View((1, 66, 130, 32, 1, 1), (0, 4096, 32, 1, 0, 0), -4128),
|
||||
ZeroView((1, 64, 128, 32, 1, 1), ((0, 1), (-1, 65), (-1, 129), (0, 32), (0, 1), (0, 1))),
|
||||
View((8, 64, 128, 3), (4, 4160, 32, 4160), 0)])
|
||||
offsets = [0,32,64]
|
||||
"""
|
||||
print(base_shape)
|
||||
|
||||
new_view = merge_views(
|
||||
View((1, 66, 130, 32, 1, 1), (0, 4096, 32, 1, 0, 0), -4128),
|
||||
|
@ -88,19 +81,18 @@ class TestImageShapeTracker(unittest.TestCase):
|
|||
offsets = [0,32,64,96]
|
||||
|
||||
print(st.shape)
|
||||
idys = []
|
||||
for o in offsets:
|
||||
print("offset:", o)
|
||||
idxy, valid = st.expr_idxs(o)
|
||||
print("idxy:", idxy.render())
|
||||
print("valids:", [x.render() for x in valid.nodes])
|
||||
out = to_image_idx(base_shape, idxy, True)
|
||||
print(out)
|
||||
#idx = (idxy//4)%base_shape[1]
|
||||
#idy = (idxy//(4*base_shape[1]))%base_shape[0]
|
||||
#idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)]
|
||||
idx, idy = to_image_idx(base_shape, idxy, valid, True)
|
||||
idys.append(idy)
|
||||
print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
|
||||
|
||||
#print("idx:", idx.render())
|
||||
#print("idy:", idy.render())
|
||||
# y index shouldn't be changing
|
||||
assert all_same(idys)
|
||||
|
||||
class TestSimplifyingShapeTracker(unittest.TestCase):
|
||||
def setUp(self):
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from tinygrad.shape.symbolic import Variable, divn, modn
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, Node
|
||||
|
||||
class TestSymbolic(unittest.TestCase):
|
||||
def helper_test_variable(self, v, n, m, s):
|
||||
|
@ -24,6 +24,48 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
|
||||
self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")
|
||||
|
||||
def test_div_becomes_num(self):
|
||||
assert isinstance(Variable("a", 2, 3)//2, NumNode)
|
||||
|
||||
def test_var_becomes_num(self):
|
||||
assert isinstance(Variable("a", 2, 2), NumNode)
|
||||
|
||||
def test_equality(self):
|
||||
idx1 = Variable("idx1", 0, 3)
|
||||
idx2 = Variable("idx2", 0, 3)
|
||||
assert idx1 == idx1
|
||||
assert idx1 != idx2
|
||||
assert idx1*4 == idx1*4
|
||||
assert idx1*4 != idx1*3
|
||||
assert idx1*4 != idx1+4
|
||||
assert idx1*4 != idx2*4
|
||||
assert idx1+idx2 == idx1+idx2
|
||||
assert idx1+idx2 == idx2+idx1
|
||||
assert idx1+idx2 != idx2
|
||||
|
||||
def test_factorize(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a*2+a*3, 0, 8*5, "(a*5)")
|
||||
|
||||
def test_factorize_no_mul(self):
|
||||
a = Variable("a", 0, 8)
|
||||
self.helper_test_variable(a+a*3, 0, 8*4, "(a*4)")
|
||||
|
||||
def test_neg(self):
|
||||
self.helper_test_variable(-Variable("a", 0, 8), -8, 0, "(a*-1)")
|
||||
|
||||
def test_add_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)+1, 1, 9, "(1+a)")
|
||||
|
||||
def test_add_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)+Variable.num(1), 1, 9, "(1+a)")
|
||||
|
||||
def test_sub_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-1, -1, 7, "(-1+a)")
|
||||
|
||||
def test_sub_num_1(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)-Variable.num(1), -1, 7, "(-1+a)")
|
||||
|
||||
def test_mul_0(self):
|
||||
self.helper_test_variable(Variable("a", 0, 8)*0, 0, 0, "0")
|
||||
|
||||
|
@ -45,6 +87,9 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_div_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // 2, 0, 3, "(a//2)")
|
||||
|
||||
def test_div_neg_min_max(self):
|
||||
self.helper_test_variable(Variable("a", 0, 7) // -2, -3, 0, "((a//2)*-1)")
|
||||
|
||||
def test_sum_div_min_max(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7), Variable("b", 0, 3)]) // 2, 0, 5, "((a+b)//2)")
|
||||
|
||||
|
@ -57,9 +102,9 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_sum_div_no_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
|
||||
|
||||
@unittest.skip("mod max is wrong")
|
||||
def test_mod_factor(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 50, "((b*50)%100)")
|
||||
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
|
||||
def test_sum_div_const(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*4, Variable.num(3)]) // 4, 0, 7, "a")
|
||||
|
@ -73,6 +118,9 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_mul_mul(self):
|
||||
self.helper_test_variable((Variable("a", 0, 5)*10)*9, 0, 5*10*9, "(a*90)")
|
||||
|
||||
def test_div_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 1800)//10)//9, 0, 20, "(a//90)")
|
||||
|
||||
def test_distribute_mul(self):
|
||||
self.helper_test_variable(Variable.sum([Variable("a", 0, 3), Variable("b", 0, 5)])*3, 0, 24, "((a*3)+(b*3))")
|
||||
|
||||
|
@ -86,11 +134,12 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Variable("a", 0, 6)%100, 0, 6, "a")
|
||||
|
||||
def test_big_mod(self):
|
||||
self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
|
||||
# NOTE: we no longer support negative variables
|
||||
#self.helper_test_variable(Variable("a", -20, 20)%10, -9, 9, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -20, 0)%10, -9, 0, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -20, 1)%10, -9, 1, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", 0, 20)%10, 0, 9, "(a%10)")
|
||||
self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
#self.helper_test_variable(Variable("a", -1, 20)%10, -1, 9, "(a%10)")
|
||||
|
||||
def test_gt_remove(self):
|
||||
self.helper_test_variable(Variable("a", 0, 6) >= 25, 0, 0, "0")
|
||||
|
@ -107,16 +156,14 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Variable.ands([Variable.num(1), Variable("a", 0, 1)]), 0, 1, "a")
|
||||
|
||||
def test_mod_factor_negative(self):
|
||||
# this is technically wrong, if b is 0 the output will be negative
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, -1, 9, "((-1+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, -1, 27, "((-1+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 10), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-29), Variable("a", 0, 100), Variable("b", 0, 10)*28]) % 28, 0, 27, "((27+a)%28)")
|
||||
|
||||
def test_sum_combine_num(self):
|
||||
self.helper_test_variable(Variable.sum([Variable.num(29), Variable("a", 0, 10), Variable.num(-23)]), 6, 16, "(6+a)")
|
||||
|
||||
def test_div_factor(self):
|
||||
# TODO: this isn't right
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-44), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
|
||||
|
||||
def test_mul_div(self):
|
||||
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
|
||||
|
@ -132,7 +179,7 @@ class TestSymbolic(unittest.TestCase):
|
|||
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
# TODO: why are the negative tests broken?
|
||||
# TODO: why are the negative tests broken? (even if we did support negative variables)
|
||||
#MIN, MAX = -10, 10
|
||||
MIN, MAX = 0, 10
|
||||
# one number
|
||||
|
@ -150,15 +197,15 @@ class TestSymbolicNumeric(unittest.TestCase):
|
|||
self.assertLessEqual(v.min, min(values))
|
||||
self.assertGreaterEqual(v.max, max(values))
|
||||
|
||||
def test_mod_4(self): self.helper_test_numeric(lambda x: modn(x, 4))
|
||||
def test_div_4(self): self.helper_test_numeric(lambda x: divn(x, 4))
|
||||
def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: divn(x+1, 2))
|
||||
def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: modn(x+1, 2))
|
||||
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
|
||||
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
|
||||
def test_plus_1_div_2(self): self.helper_test_numeric(lambda x: (x+1)//2)
|
||||
def test_plus_1_mod_2(self): self.helper_test_numeric(lambda x: (x+1)%2)
|
||||
def test_times_2(self): self.helper_test_numeric(lambda x: x*2)
|
||||
def test_times_2_plus_3(self): self.helper_test_numeric(lambda x: x*2 + 3)
|
||||
def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: modn(x*2 + 3, 4))
|
||||
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: divn(x*2 + 3, 4))
|
||||
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: modn(divn(x*2 + 3, 4), 4))
|
||||
def test_times_2_plus_3_mod_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)%4)
|
||||
def test_times_2_plus_3_div_4(self): self.helper_test_numeric(lambda x: (x*2 + 3)//4)
|
||||
def test_times_2_plus_3_div_4_mod_4(self): self.helper_test_numeric(lambda x: ((x*2 + 3)//4)%4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -3,9 +3,9 @@ from collections import defaultdict
|
|||
from typing import Optional, List, Tuple, Dict, Set, Final, NamedTuple
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LazyOp, Op, ASTRunner
|
||||
from tinygrad.codegen.ast import ASTKernel, Token, Types
|
||||
from tinygrad.shape.symbolic import Node, ModNode, DivNode, render_python
|
||||
from tinygrad.shape.symbolic import Node, MulNode, DivNode, SumNode, Variable, render_python
|
||||
from tinygrad.shape import ShapeTracker
|
||||
from tinygrad.helpers import getenv, DEBUG, prod
|
||||
from tinygrad.helpers import getenv, DEBUG, prod, partition
|
||||
|
||||
# div is different in cl than python
|
||||
render_cl = render_python.copy()
|
||||
|
@ -25,11 +25,25 @@ class GPULanguage(NamedTuple):
|
|||
extra_args : List[str] = []
|
||||
float4 : Optional[str] = None
|
||||
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, validhacks=False):
|
||||
def to_image_idx(base_shape:Tuple[int, ...], idxy:Node, valid:Node, validhacks=False) -> Tuple[Node, Node]:
|
||||
idy = (idxy//(4*base_shape[1]))
|
||||
if validhacks and valid.min == 0:
|
||||
idx = (idxy//4) + (idy*-base_shape[1])
|
||||
# find the ones in idx that didn't factorize and remove them (TODO: this is not universal)
|
||||
if isinstance(idx, SumNode):
|
||||
unfactored, idx_nodes = partition(idx.nodes, lambda x: isinstance(x, MulNode) and x.b == -base_shape[1])
|
||||
assert len(unfactored) <= 1
|
||||
idx = Variable.sum(idx_nodes)
|
||||
unfactored = (Variable.sum(unfactored) // base_shape[1])
|
||||
idy += unfactored
|
||||
# ugh really...
|
||||
if idx.min >= base_shape[1]//2:
|
||||
idx -= base_shape[1]
|
||||
idy += 1
|
||||
else:
|
||||
idx = (idxy//4)%base_shape[1]
|
||||
idy = (idxy//(4*base_shape[1]))%base_shape[0]
|
||||
if validhacks: idx, idy = [x.a if isinstance(x, ModNode) and x.a.max < x.b*2 else x for x in (idx, idy)]
|
||||
return f"(int2)({idx.render(render_cl)}, {idy.render(render_cl)})"
|
||||
#print(base_shape, idx.min, idx.max, idy.min, idy.max, idx, idy)
|
||||
return idx, idy
|
||||
|
||||
class GPUCodegen(ASTKernel):
|
||||
lang : GPULanguage = GPULanguage()
|
||||
|
@ -66,7 +80,8 @@ class GPUCodegen(ASTKernel):
|
|||
v = Token(f"{self.lang.float4}({','.join([to_store[o+j].tok for j in range(4)])})", Types.FLOAT4)
|
||||
if hasattr(self.bufs[buf_index]._buf, "IMAGE"):
|
||||
assert v.typ == Types.FLOAT4, "Image requires upcasting to FLOAT4"
|
||||
self.kernel.append(f"write_imagef(data{buf_index}, {to_image_idx(self.bufs[buf_index]._base_shape, idxy)}, {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
|
||||
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid)
|
||||
self.kernel.append(f"write_imagef(data{buf_index}, (int2)({idx.render(render_cl)}, {idy.render(render_cl)}), {v.tok}); /* {self.bufs[buf_index]._base_shape} */\n")
|
||||
elif v.typ == Types.FLOAT4:
|
||||
self.kernel.append(f"(({self.lang.buffer_prefix}float4*)data{buf_index})[{(idxy//4).render(render_cl)}] = {v.tok};\n")
|
||||
else:
|
||||
|
@ -97,7 +112,8 @@ class GPUCodegen(ASTKernel):
|
|||
ldr = const
|
||||
elif hasattr(self.bufs[buf_index]._buf, "IMAGE"):
|
||||
assert should_upcast and can_merge, f"Image requires upcasting to FLOAT4 {self.buftokens[buf_index]}"
|
||||
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, {to_image_idx(self.bufs[buf_index]._base_shape, idxy, VALIDHACKS)}) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
|
||||
idx, idy = to_image_idx(self.bufs[buf_index]._base_shape, idxy, valid, VALIDHACKS)
|
||||
ldr = Token(f"read_imagef({self.buftokens[buf_index].tok}, smp, (int2)({idx.render(render_cl)}, {idy.render(render_cl)})) /* {self.bufs[buf_index]._base_shape} */", Types.FLOAT4)
|
||||
elif should_upcast and can_merge:
|
||||
ldr = Token(f"(({self.lang.buffer_prefix}float4*){self.buftokens[buf_index].tok})[{(idxy//4).render(render_cl)}]", Types.FLOAT4)
|
||||
else:
|
||||
|
|
|
@ -1,48 +1,61 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
from typing import List, Dict, Callable, Type
|
||||
import math, itertools, functools
|
||||
from typing import List, Dict, Callable, Type, Union
|
||||
from tinygrad.helpers import partition, all_same
|
||||
|
||||
# python has different behavior for negative mod and div than c
|
||||
def divn(x, a): return x//a if isinstance(x, Node) else int(x/a)
|
||||
def modn(x, a): return x%a if isinstance(x, Node) else (-((-x)%a) if x < 0 else x%a)
|
||||
# NOTE: Python has different behavior for negative mod and floor div than c
|
||||
# symbolic matches the Python behavior, but the code is outputs is agnostic, and will never have negative numbers in div or mod
|
||||
|
||||
def create_node(typ:Type[Node], *args):
|
||||
ret = typ(*args)
|
||||
assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {typ} {args}"
|
||||
if ret.min == ret.max: return NumNode(ret.min)
|
||||
return ret
|
||||
|
||||
class Node:
|
||||
b: int
|
||||
min: int
|
||||
max: int
|
||||
def render(self, ops=None, ctx=None):
|
||||
def render(self, ops=None, ctx=None) -> str:
|
||||
if ops is None: ops = render_python
|
||||
if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
|
||||
assert isinstance(self, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def __add__(self, b:int): return Variable.sum([self, Variable.num(b)]) if b != 0 else self
|
||||
def __sub__(self, b:int): return self+-b
|
||||
def __ge__(self, b:int): return GeNode(self, b)
|
||||
def __lt__(self, b:int): return LtNode(self, b)
|
||||
@functools.cached_property
|
||||
def key(self) -> str: return self.render()
|
||||
def __repr__(self): return "<"+self.key+">"
|
||||
def __eq__(self, other:object) -> bool:
|
||||
if not isinstance(other, Node): return NotImplemented
|
||||
return self.key == other.key
|
||||
def __neg__(self): return self*-1
|
||||
def __add__(self, b:Union[Node, int]): return Variable.sum([self, b if isinstance(b, Node) else Variable.num(b)])
|
||||
def __sub__(self, b:Union[Node, int]): return self+-b
|
||||
def __ge__(self, b:int): return create_node(GeNode, self, b)
|
||||
def __lt__(self, b:int): return create_node(LtNode, self, b)
|
||||
def __mul__(self, b:int):
|
||||
if b == 0: return NumNode(0)
|
||||
elif b == 1: return self
|
||||
if isinstance(self, MulNode): return MulNode(self.a, self.b*b)
|
||||
# distribute mul into sum
|
||||
if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes])
|
||||
return MulNode(self, b)
|
||||
if isinstance(self, MulNode): return self.a*(self.b*b) # two muls is one mul
|
||||
if isinstance(self, SumNode): return Variable.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
return create_node(MulNode, self, b)
|
||||
|
||||
# *** complex ops ***
|
||||
|
||||
def __floordiv__(self, b:int):
|
||||
assert b != 0
|
||||
if b < 0: return (self//-b)*-1
|
||||
if b == 1: return self
|
||||
if isinstance(self, MulNode) and modn(self.b, b) == 0: return self.a*divn(self.b, b)
|
||||
if isinstance(self, MulNode) and modn(b, self.b) == 0: return self.a//divn(b, self.b)
|
||||
if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
|
||||
if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
|
||||
if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
|
||||
if isinstance(self, SumNode):
|
||||
factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
|
||||
nofactor = []
|
||||
# ugh, i doubt this is universally right
|
||||
for x in tmp_nofactor:
|
||||
if isinstance(x, NumNode):
|
||||
if modn(x.b, b) != x.b:
|
||||
factors.append(Variable.num(x.b - modn(x.b, b))) # python does floor division
|
||||
nofactor.append(Variable.num(modn(x.b, b)))
|
||||
if (x.b%b) != x.b:
|
||||
factors.append(Variable.num(x.b - (x.b%b))) # python does floor division
|
||||
nofactor.append(Variable.num(x.b%b))
|
||||
else:
|
||||
nofactor.append(x)
|
||||
gcd = [math.gcd(x.b, b) if isinstance(x, (MulNode, NumNode)) else None for x in nofactor]
|
||||
|
@ -58,24 +71,28 @@ class Node:
|
|||
for m in muls:
|
||||
if m > 1 and b%m == 0:
|
||||
return (self//m)//(b//m)
|
||||
return DivNode(self, b)
|
||||
if self.min < 0:
|
||||
offset = self.min//b
|
||||
return (self+offset*b)//b - offset
|
||||
return create_node(DivNode, self, b)
|
||||
|
||||
def __mod__(self, b:int):
|
||||
assert b > 0
|
||||
if b == 1: return NumNode(0)
|
||||
if isinstance(self, SumNode):
|
||||
new_nodes = []
|
||||
for x in self.nodes:
|
||||
if isinstance(x, NumNode): new_nodes.append(Variable.num(modn(x.b, b)))
|
||||
elif isinstance(x, MulNode): new_nodes.append(x.a * modn(x.b, b))
|
||||
if isinstance(x, NumNode): new_nodes.append(Variable.num(x.b%b))
|
||||
elif isinstance(x, MulNode): new_nodes.append(x.a * (x.b%b))
|
||||
else: new_nodes.append(x)
|
||||
a = Variable.sum(new_nodes)
|
||||
elif isinstance(self, MulNode):
|
||||
a = self.a * modn(self.b, b)
|
||||
a = self.a * (self.b%b)
|
||||
else:
|
||||
a = self
|
||||
if a.min >= 0 and a.max < b: return a
|
||||
if a.min == a.max: return Variable.num(modn(a.min, b))
|
||||
return ModNode(a, b)
|
||||
if a.min < 0: return (a + ((a.min//b)*b)) % b
|
||||
return create_node(ModNode, a, b)
|
||||
|
||||
@staticmethod
|
||||
def num(num:int) -> Node: return NumNode(num)
|
||||
|
@ -90,28 +107,35 @@ class Node:
|
|||
|
||||
# combine any numbers inside a sum
|
||||
nodes, num_nodes = partition(nodes, lambda x: not isinstance(x, NumNode))
|
||||
num_sum = sum([x.b for x in num_nodes])
|
||||
# TODO: these can't be merged due to image indexing. it's not clear which idx to group the offset with
|
||||
if num_sum >= 0: nodes.append(NumNode(num_sum))
|
||||
else:
|
||||
lte_0, rest = partition(num_nodes, lambda x: x.b <= 0)
|
||||
nodes += [NumNode(x.b) for x in sorted(lte_0, key=lambda x:x.b) if x.b != 0]
|
||||
if len(rest): nodes += [NumNode(sum([x.b for x in rest]))]
|
||||
nodes.append(NumNode(sum([x.b for x in num_nodes])))
|
||||
|
||||
# combine any MulNodes that factorize (big hack sticking the MulNode(x, 1) on things)
|
||||
nodes, mul_nodes = partition(nodes, lambda x: not isinstance(x, MulNode))
|
||||
mul_nodes += [MulNode(x, 1) for x in nodes]
|
||||
mul_nodes = sorted(mul_nodes, key=lambda x: x.a.render()) # group by equality (ugh, uses render!)
|
||||
new_nodes = [k * sum(x.b for x in g) for k, g in itertools.groupby(mul_nodes, key=lambda x: x.a)]
|
||||
nodes = [x if not isinstance(x, MulNode) or x.b != 1 else x.a for x in new_nodes]
|
||||
|
||||
# filter 0s
|
||||
nodes = [x for x in nodes if x.min != 0 or x.max != 0]
|
||||
return SumNode(nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0))
|
||||
return create_node(SumNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(0))
|
||||
|
||||
@staticmethod
|
||||
def ands(nodes:List[Node]) -> Node:
|
||||
if any((x.min == 0 and x.max == 0) for x in nodes): return NumNode(0)
|
||||
|
||||
# filter 1s
|
||||
nodes = [x for x in nodes if x.min != x.max]
|
||||
return AndNode(nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
||||
return create_node(AndNode, nodes) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
|
||||
|
||||
# 4 basic node types
|
||||
|
||||
class Variable(Node):
|
||||
def __new__(cls, expr:str, nmin:int, nmax:int):
|
||||
assert nmin >= 0 and nmin <= nmax
|
||||
if nmin == nmax: return NumNode(nmin)
|
||||
return super().__new__(cls)
|
||||
|
||||
def __init__(self, expr:str, nmin:int, nmax:int):
|
||||
self.expr, self.min, self.max = expr, nmin, nmax
|
||||
|
||||
|
@ -135,51 +159,29 @@ class RedNode(Node):
|
|||
|
||||
class GeNode(OpNode): minmax = staticmethod(lambda a,b: (int(a.min >= b), int(a.max >= b)))
|
||||
class LtNode(OpNode): minmax = staticmethod(lambda a,b: (int(a.max < b), int(a.min < b)))
|
||||
class MulNode(OpNode): minmax = staticmethod(lambda a,b: (a.min*b, a.max*b))
|
||||
class DivNode(OpNode): minmax = staticmethod(lambda a,b: (divn(a.min, b), divn(a.max, b)))
|
||||
|
||||
# given a number in the range [amin, amax] (inclusive)
|
||||
# what are the min and max of that number after modding it by b?
|
||||
|
||||
# aka a fast version of:
|
||||
#values = [modn(rv, b) for rv in range(amin, amax+1)]
|
||||
#return min(values), max(values)
|
||||
class MulNode(OpNode): minmax = staticmethod(lambda a,b: (a.min*b, a.max*b) if b >= 0 else (a.max*b, a.min*b))
|
||||
class DivNode(OpNode):
|
||||
@staticmethod
|
||||
def minmax(a, b):
|
||||
assert a.min >= 0
|
||||
return a.min//b, a.max//b
|
||||
|
||||
# you have 3 included ranges
|
||||
# range 1 from min1 -> max1 (smaller than a mod)
|
||||
# range 1 from a.min -> max1 (smaller than a mod)
|
||||
# range 2 from max1 -> min2
|
||||
# range 3 from min2 -> max2 (smaller than a mod)
|
||||
|
||||
def modrange_negative(amin, amax, b):
|
||||
assert amin<0 and amax<0
|
||||
min1, max1 = amin, math.ceil(amin/b)*b
|
||||
min2, max2 = math.floor(amax/b)*b, amax
|
||||
if max1 > min2: return (modn(min1, b), modn(max2, b)) # range 2 doesn't exist, min1 -> max2 is smaller than a mod
|
||||
if max1 < min2: return (-b+1, 0) # range 2 is the full distance
|
||||
if min2 == max2: return (modn(min1, b), 0) # range 1 is the only valid
|
||||
return (-b+1, 0) # range 1 and 3 are valid
|
||||
|
||||
def modrange_positive(amin, amax, b):
|
||||
assert amin>=0 and amax>=0
|
||||
min1, max1 = amin, math.ceil(amin/b)*b
|
||||
min2, max2 = math.floor(amax/b)*b, amax
|
||||
if max1 > min2: return (modn(min1, b), modn(max2, b)) # range 2 doesn't exist, min1 -> max2 is smaller than a mod
|
||||
# range 3 from min2 -> a.max (smaller than a mod)
|
||||
class ModNode(OpNode):
|
||||
@staticmethod
|
||||
def minmax(a, b):
|
||||
assert a.min >= 0
|
||||
#values = [x%b for x in range(a.min, a.max+1)]
|
||||
#return min(values), max(values)
|
||||
max1, min2 = math.ceil(a.min/b)*b, math.floor(a.max/b)*b
|
||||
if max1 < min2: return (0, b-1) # range 2 is the full distance
|
||||
if min1 == max1: return (0, modn(max2, b)) # range 3 is the only valid
|
||||
if max1 > min2: return (a.min%b, a.max%b) # range 2 doesn't exist, a.min -> a.max is smaller than a mod
|
||||
if a.min == max1: return (0, a.max%b) # range 3 is the only valid
|
||||
return (0, b-1) # range 1 and 3 are valid
|
||||
|
||||
def modrange(amin, amax, b):
|
||||
if amin < 0 and amax < 0:
|
||||
return modrange_negative(amin, amax, b)
|
||||
if amin >= 0 and amax >= 0:
|
||||
return modrange_positive(amin, amax, b)
|
||||
if amin < 0 and amax >= 0:
|
||||
min1, max1 = modrange_negative(amin, -1, b)
|
||||
min2, max2 = modrange_positive(0, amax, b)
|
||||
return min(min1, min2), max(max1, max2)
|
||||
|
||||
class ModNode(OpNode): minmax = staticmethod(lambda a,b: modrange(a.min, a.max, b))
|
||||
|
||||
# reduce nodes
|
||||
|
||||
class SumNode(RedNode): minmax = staticmethod(lambda nodes: (sum([x.min for x in nodes]), sum([x.max for x in nodes])))
|
||||
|
|
Loading…
Reference in New Issue