2023-04-17 23:21:46 +08:00
|
|
|
import random
|
2023-11-08 09:30:53 +08:00
|
|
|
from tinygrad.helpers import DEBUG, getenv
|
2023-04-17 23:21:46 +08:00
|
|
|
from test.unit.test_shapetracker import CheckingShapeTracker
|
|
|
|
|
|
|
|
def do_permute(st):
  """Apply a random permutation of all of `st`'s axes via st.permute.

  The permutation is drawn with random.shuffle and may be the identity.
  Prints the call when DEBUG >= 1.
  """
  axes = [*range(len(st.shape))]
  random.shuffle(axes)
  order = tuple(axes)
  if DEBUG >= 1: print("st.permute(", order, ")")
  st.permute(order)
|
|
|
|
|
|
|
|
def do_pad(st):
  """Pad one randomly chosen axis of `st` by 0-2 elements on each side via st.pad.

  Every other axis gets (0, 0). Prints the call when DEBUG >= 1.
  """
  ndim = len(st.shape)
  axis = random.randint(0, ndim-1)
  widths = []
  for i in range(ndim):
    if i != axis:
      widths.append((0, 0))
      continue
    # Draw the leading amount before the trailing one so a given seed always yields the same pad.
    before = random.randint(0, 2)
    after = random.randint(0, 2)
    widths.append((before, after))
  pad = tuple(widths)
  if DEBUG >= 1: print("st.pad(", pad, ")")
  st.pad(pad)
|
|
|
|
|
|
|
|
def do_reshape_split_one(st):
  """Split one randomly chosen axis of `st` into two adjacent axes via st.reshape.

  An axis of size d becomes (d // k, k), where k is picked at random from the
  divisors of d in 1..5. Because 1 always divides d, a candidate always exists.
  Prints the call when DEBUG >= 1.
  """
  axis = random.randint(0, len(st.shape)-1)
  size = st.shape[axis]
  divisors = [k for k in range(1, 6) if size % k == 0]
  k = random.choice(divisors)
  head, tail = st.shape[:axis], st.shape[axis+1:]
  shp = head + (size // k, k) + tail
  if DEBUG >= 1: print("st.reshape(", shp, ")")
  st.reshape(shp)
|
|
|
|
|
|
|
|
def do_reshape_combine_two(st):
  """Merge two adjacent axes of `st` into their product via st.reshape.

  A random axis c is chosen and axes c and c+1 are replaced by one axis of size
  shape[c] * shape[c+1]. Returns without doing anything (and without drawing
  any random numbers) when `st` has fewer than two axes. Prints the call when
  DEBUG >= 1.
  """
  ndim = len(st.shape)
  if ndim < 2: return
  axis = random.randint(0, ndim-2)
  merged = st.shape[axis] * st.shape[axis+1]
  shp = st.shape[:axis] + (merged,) + st.shape[axis+2:]
  if DEBUG >= 1: print("st.reshape(", shp, ")")
  st.reshape(shp)
|
|
|
|
|
|
|
|
def do_shrink(st):
  """Shrink one randomly chosen axis of `st` to a random non-empty window via st.shrink.

  The chosen axis of size s gets a window (x, y) with 0 <= x < y <= s; every
  other axis keeps its full extent (0, s). Windows are drawn by rejection
  sampling until x < y holds on every axis. Returns without doing anything
  when any axis has size 0, because no window can then satisfy x < y. Prints
  the call when DEBUG >= 1.
  """
  # With a zero-size axis, (0, 0) can never satisfy x < y, whether or not that
  # axis is the chosen one, so the rejection loop below would spin forever.
  if any(s == 0 for s in st.shape): return
  c = random.randint(0, len(st.shape)-1)
  while True:
    shrink = tuple((random.randint(0,s), random.randint(0,s)) if i == c else (0,s) for i,s in enumerate(st.shape))
    if all(x<y for (x,y) in shrink): break
  if DEBUG >= 1: print("st.shrink(", shrink, ")")
  st.shrink(shrink)
|
|
|
|
|
|
|
|
def do_stride(st):
  """Apply a stride of -2, -1 or 2 to one randomly chosen axis of `st` via st.stride.

  Every other axis keeps stride 1. Prints the call when DEBUG >= 1.
  """
  ndim = len(st.shape)
  axis = random.randint(0, ndim-1)
  step = random.choice([-2, -1, 2])
  stride = tuple(step if i == axis else 1 for i in range(ndim))
  if DEBUG >= 1: print("st.stride(", stride, ")")
  st.stride(stride)
|
|
|
|
|
2023-12-19 08:03:27 +08:00
|
|
|
def do_flip(st):
  """Give one randomly chosen axis of `st` a stride of -1 (all others 1) via st.stride.

  A -1 stride presumably reverses that axis, hence the name. Prints the call
  when DEBUG >= 1.
  """
  ndim = len(st.shape)
  axis = random.randint(0, ndim-1)
  stride = tuple([1]*axis + [-1] + [1]*(ndim-axis-1))
  if DEBUG >= 1: print("st.stride(", stride, ")")
  st.stride(stride)
|
|
|
|
|
2023-04-17 23:21:46 +08:00
|
|
|
def do_expand(st):
  """Broadcast one randomly chosen size-1 axis of `st` to size 2, 3 or 4 via st.expand.

  Returns without doing anything (and without drawing any random numbers)
  when `st` has no size-1 axis. Prints the call when DEBUG >= 1.
  """
  ones = [i for i, s in enumerate(st.shape) if s == 1]
  if not ones: return
  axis = random.choice(ones)
  new_size = random.choice([2, 3, 4])
  expand = tuple(new_size if i == axis else s for i, s in enumerate(st.shape))
  if DEBUG >= 1: print("st.expand(", expand, ")")
  st.expand(expand)
|
|
|
|
|
2023-12-19 08:03:27 +08:00
|
|
|
# Movement ops the fuzzer samples from with random.choice; each one mutates the
# shape tracker it is given. do_flip is not listed here, but do_stride can
# already draw a stride of -1.
shapetracker_ops = [
  do_permute,
  do_pad,
  do_shrink,
  do_reshape_split_one,
  do_reshape_combine_two,
  do_stride,
  do_expand,
]
|
|
|
|
|
2023-04-17 23:21:46 +08:00
|
|
|
if __name__ == "__main__":
  # The seed defaults to 42 but can be overridden with the SEED env var to explore
  # other random sequences or to reproduce a specific failing run.
  random.seed(getenv("SEED", 42))
  # CNT controls how many independent random trials run (default 200).
  for _ in range(getenv("CNT", 200)):
    # Each trial starts from a random 3-axis shape with every dimension in [2, 10].
    st = CheckingShapeTracker((random.randint(2, 10), random.randint(2, 10), random.randint(2, 10)))
    # Apply 8 randomly chosen movement ops, then run the tracker's consistency check.
    for _ in range(8): random.choice(shapetracker_ops)(st)
    st.assert_same()
|