2023-07-20 00:05:45 +08:00
|
|
|
import itertools
|
2023-04-17 23:21:46 +08:00
|
|
|
import random
|
2024-10-14 04:42:19 +08:00
|
|
|
from tinygrad import Variable, dtypes
|
|
|
|
from tinygrad.ops import UOp
|
2023-07-20 00:05:45 +08:00
|
|
|
from tinygrad.helpers import DEBUG
|
|
|
|
# Seed the global RNG so every run of this fuzzer draws the same tapes and the same sample points.
random.seed(a=42)
|
2023-04-17 23:21:46 +08:00
|
|
|
|
|
|
|
def add_v(expr, rng=None):
  """Add one entry of the module-level list ``v`` to ``expr``.

  ``rng`` is the index into ``v``; when omitted it is drawn uniformly from 0..2.
  Returns ``(expr + v[index], index)`` so the same step can be replayed later.
  """
  index = random.randint(0, 2) if rng is None else rng
  return expr + v[index], index
|
|
|
|
|
|
|
|
def div(expr, rng=None):
  """Apply ``expr // rng``; with no ``rng``, draw a divisor from 1..9 (never zero).

  Returns the result together with the divisor used, so the step can be replayed.
  """
  divisor = rng if rng is not None else random.randint(1, 9)
  return expr // divisor, divisor
|
|
|
|
|
|
|
|
def mul(expr, rng=None):
  """Apply ``expr * rng``; with no ``rng``, draw a factor from -4..4 (zero included).

  Returns ``(product, factor)`` so the step can be replayed.
  """
  factor = random.randint(-4, 4) if rng is None else rng
  product = expr * factor
  return product, factor
|
|
|
|
|
|
|
|
def mod(expr, rng=None):
  """Apply ``expr % rng``; with no ``rng``, draw a modulus from 1..9 (never zero).

  Returns ``(remainder, modulus)`` so the step can be replayed.
  """
  modulus = rng
  if modulus is None:
    modulus = random.randint(1, 9)
  return expr % modulus, modulus
|
|
|
|
|
|
|
|
def add_num(expr, rng=None):
  """Apply ``expr + rng``; with no ``rng``, draw an offset from -4..4.

  Returns ``(sum, offset)`` so the step can be replayed.
  """
  offset = rng if rng is not None else random.randint(-4, 4)
  total = expr + offset
  return total, offset
|
|
|
|
|
2023-06-28 10:34:53 +08:00
|
|
|
def lt(expr, rng=None):
  """Compare ``expr < rng``; with no ``rng``, draw the bound from -4..4.

  Returns ``(comparison, bound)`` so the step can be replayed.
  """
  bound = random.randint(-4, 4) if rng is None else rng
  return (expr < bound), bound
|
|
|
|
|
|
|
|
def ge(expr, rng=None):
  """Compare ``expr >= rng``; with no ``rng``, draw the bound from -4..4.

  Returns ``(comparison, bound)`` so the step can be replayed.
  """
  bound = rng
  if bound is None:
    bound = random.randint(-4, 4)
  return (expr >= bound), bound
|
|
|
|
|
2023-07-24 03:41:22 +08:00
|
|
|
def le(expr, rng=None):
  """Compare ``expr <= rng``; with no ``rng``, draw the bound from -4..4.

  Returns ``(comparison, bound)`` so the step can be replayed.
  """
  bound = rng if rng is not None else random.randint(-4, 4)
  result = expr <= bound
  return result, bound
|
|
|
|
|
|
|
|
def gt(expr, rng=None):
  """Compare ``expr > rng``; with no ``rng``, draw the bound from -4..4.

  Returns ``(comparison, bound)`` so the step can be replayed.
  """
  if rng is not None:
    return (expr > rng), rng
  bound = random.randint(-4, 4)
  return (expr > bound), bound
|
|
|
|
|
2024-10-04 16:42:27 +08:00
|
|
|
# NOTE: these overrides are required for the fuzz test to pass.
# The reference value in the fuzz loop below is computed by replaying the tape on plain Python ints,
# i.e. with Python's own floor-semantics ``%`` and ``//``, so tinygrad's python_alu entries for
# MOD and IDIV are pointed at those same operators.
# NOTE(review): presumably the stock entries use C-style truncating semantics, which differ from
# Python's for negative operands -- confirm against tinygrad.ops.
from tinygrad.ops import python_alu, BinaryOps
python_alu.update({BinaryOps.MOD: lambda x, y: x % y, BinaryOps.IDIV: lambda x, y: x // y})
|
|
|
|
|
2023-04-17 23:21:46 +08:00
|
|
|
if __name__ == "__main__":
|
2023-07-20 00:05:45 +08:00
|
|
|
ops = [add_v, div, mul, add_num, mod]
|
|
|
|
for _ in range(1000):
|
|
|
|
upper_bounds = [*list(range(1, 10)), 16, 32, 64, 128, 256]
|
|
|
|
u1 = Variable("v1", 0, random.choice(upper_bounds))
|
|
|
|
u2 = Variable("v2", 0, random.choice(upper_bounds))
|
|
|
|
u3 = Variable("v3", 0, random.choice(upper_bounds))
|
2023-04-17 23:21:46 +08:00
|
|
|
v = [u1,u2,u3]
|
2023-07-20 00:05:45 +08:00
|
|
|
tape = [random.choice(ops) for _ in range(random.randint(2, 30))]
|
2023-07-24 03:41:22 +08:00
|
|
|
# 10% of the time, add one of lt, le, gt, ge
|
|
|
|
if random.random() < 0.1: tape.append(random.choice([lt, le, gt, ge]))
|
2024-10-14 04:42:19 +08:00
|
|
|
expr = UOp.const(dtypes.int, 0)
|
2023-04-17 23:21:46 +08:00
|
|
|
rngs = []
|
|
|
|
for t in tape:
|
|
|
|
expr, rng = t(expr)
|
2023-07-20 00:05:45 +08:00
|
|
|
if DEBUG >= 1: print(t.__name__, rng)
|
2023-04-17 23:21:46 +08:00
|
|
|
rngs.append(rng)
|
2023-07-20 00:05:45 +08:00
|
|
|
if DEBUG >=1: print(expr)
|
2024-10-04 14:58:51 +08:00
|
|
|
space = list(itertools.product(range(u1.vmin, u1.vmax+1), range(u2.vmin, u2.vmax+1), range(u3.vmin, u3.vmax+1)))
|
2023-07-20 00:05:45 +08:00
|
|
|
volume = len(space)
|
|
|
|
for (v1, v2, v3) in random.sample(space, min(100, volume)):
|
|
|
|
v = [v1,v2,v3]
|
|
|
|
rn = 0
|
|
|
|
for t,r in zip(tape, rngs): rn, _ = t(rn, r)
|
|
|
|
num = eval(expr.render())
|
2024-10-04 16:42:27 +08:00
|
|
|
if num != rn:
|
|
|
|
unsimplified_num = eval(expr.render(simplify=False))
|
|
|
|
assert unsimplified_num == rn, "UNSIMPLIFIED MISMATCH!"
|
|
|
|
assert num == rn, f"mismatched {expr.render()} at {v1=} {v2=} {v3=} = {num} != {rn}\n{expr.render(simplify=False)}"
|
2023-07-20 00:05:45 +08:00
|
|
|
if DEBUG >= 1: print(f"matched {expr.render()} at {v1=} {v2=} {v3=} = {num} == {rn}")
|