fix shape

This commit is contained in:
George Hotz 2023-10-31 11:35:03 -07:00
parent a27c9f9de5
commit 5aaa8a0cc1
2 changed files with 11 additions and 13 deletions

View File

@ -3,6 +3,7 @@ from tinygrad.helpers import prod, dtypes
from extra.onnx import safe_numpy
from onnx.helper import tensor_dtype_to_np_dtype
from onnx.onnx_pb import TensorProto
import os
import numpy as np
import functools
from typing import Union, Tuple, Optional, List, Any
@ -103,7 +104,7 @@ def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([]
def Tile(input: Tensor, repeats):
  """Repeat `input` using the per-axis counts stored in `repeats`.

  `repeats` is converted with `safe_numpy` and every entry is cast to a
  Python int before being handed to `Tensor.repeat`.
  """
  counts = list(map(int, safe_numpy(repeats)))
  return input.repeat(counts)
def Range(start: Tensor, limit, delta):
  """Build an arange from `start` up to `limit` in steps of `delta`.

  All three arguments are converted with `safe_numpy` and truncated to
  Python ints; the result is cast to the dtype of `start`.
  """
  first = int(safe_numpy(start))
  stop = int(safe_numpy(limit))
  step = int(safe_numpy(delta))
  sequence = Tensor.arange(start=first, stop=stop, step=step)
  return sequence.cast(dtype=start.dtype)
def Shape(data: Tensor, end=None, start=0):
  """Return the dimensions of `data`, sliced as `[start:end]`, as a Tensor.

  The result dtype is int32 when the path "/TICI" exists as a regular file
  and int64 otherwise.
  NOTE(review): "/TICI" presumably marks a specific target device - confirm.

  This is the single effective definition: an earlier, always-int64 `Shape`
  that sat directly above it was shadowed by this one and has been dropped.
  """
  # TODO: really?
  index_dtype = dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64
  return Tensor(list(data.shape)[start:end], dtype=index_dtype)
def Size(data: Tensor):
  """Return the product of `data`'s dimensions.

  A plain list is multiplied element-wise as-is; anything else is expected
  to expose `.shape`, whose entries are multiplied instead.
  """
  if isinstance(data, list):
    return prod(data)
  return prod(data.shape)
def Flatten(input: Tensor, axis=1):
  """Reshape `input` to two dimensions, splitting at `axis`.

  The first output dimension is the product of the dimensions before
  `axis` (1 when there are none); the second is inferred via -1.
  """
  # the (1,) prefix keeps the product well-defined when the slice is empty
  leading_dims = (1,) + input.shape[:axis]
  return input.reshape(prod(leading_dims), -1)
def Reshape(data: Tensor, shape: Tensor, allowzero=None):
  """Reshape `data` to the dimensions listed in `shape`.

  `shape` is converted with `safe_numpy`. A zero entry is replaced by the
  size `data` already has on that axis; every other entry is cast to int.
  `allowzero` is accepted but not consulted by this implementation.
  """
  resolved = []
  for axis, dim in enumerate(safe_numpy(shape)):
    # zero means "copy this axis from the input"
    resolved.append(data.shape[axis] if dim == 0 else int(dim))
  return data.reshape(resolved)

View File

@ -109,10 +109,6 @@ def thneed_test_onnx(onnx_data, output_fn):
# non thneed
run_onnx = get_run_onnx(onnx_model)
new_tinygrad_out = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).numpy()
for i,(x,y) in enumerate(zip(new_torch_out.flatten().tolist(), new_tinygrad_out.flatten().tolist())):
if abs(x-y) > 100:
print(i, x, y)
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("classic self-test passed!")
else:
@ -142,19 +138,20 @@ if __name__ == "__main__":
schedule, schedule_independent, inputs = get_schedule(onnx_data)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
print(f"{len(schedule_input)} inputs")
schedule = fix_schedule_for_images(schedule)
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
if GRAPH:
for si in schedule_input: log_schedule_item(si)
for si in schedule: log_schedule_item(si)
run_schedule(schedule_independent, disable_logging=True)
run_schedule(schedule_input)
with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
schedule = fix_schedule_for_images(schedule)
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
if GRAPH:
for si in schedule_input: log_schedule_item(si)
for si in schedule: log_schedule_item(si)
GlobalCounters.reset()
run_schedule(schedule)
run_schedule(schedule[:])
output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
schedule_to_thneed(schedule, output_fn)