mirror of https://github.com/commaai/tinygrad.git
fix shape
This commit is contained in:
parent
a27c9f9de5
commit
5aaa8a0cc1
|
@ -3,6 +3,7 @@ from tinygrad.helpers import prod, dtypes
|
|||
from extra.onnx import safe_numpy
|
||||
from onnx.helper import tensor_dtype_to_np_dtype
|
||||
from onnx.onnx_pb import TensorProto
|
||||
import os
|
||||
import numpy as np
|
||||
import functools
|
||||
from typing import Union, Tuple, Optional, List, Any
|
||||
|
@ -103,7 +104,7 @@ def OptionalGetElement(x: Tensor=None): return x if x is not None else Tensor([]
|
|||
|
||||
def Tile(input: Tensor, repeats): return input.repeat([int(x) for x in safe_numpy(repeats)])
|
||||
def Range(start: Tensor, limit, delta): return Tensor.arange(start=int(safe_numpy(start)), stop=int(safe_numpy(limit)), step=int(safe_numpy(delta))).cast(dtype=start.dtype)
|
||||
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int64)
|
||||
def Shape(data: Tensor, end=None, start=0): return Tensor(list(data.shape)[start:end], dtype=dtypes.int32 if os.path.isfile("/TICI") else dtypes.int64) # TODO: really?
|
||||
def Size(data: Tensor): return prod(data if isinstance(data, list) else data.shape)
|
||||
def Flatten(input: Tensor, axis=1): return input.reshape(prod((1,) + input.shape[0:axis]), -1)
|
||||
def Reshape(data: Tensor, shape: Tensor, allowzero=None): return data.reshape([int(x) if x != 0 else data.shape[i] for i,x in enumerate(safe_numpy(shape))])
|
||||
|
|
|
@ -109,10 +109,6 @@ def thneed_test_onnx(onnx_data, output_fn):
|
|||
# non thneed
|
||||
run_onnx = get_run_onnx(onnx_model)
|
||||
new_tinygrad_out = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).numpy()
|
||||
for i,(x,y) in enumerate(zip(new_torch_out.flatten().tolist(), new_tinygrad_out.flatten().tolist())):
|
||||
if abs(x-y) > 100:
|
||||
print(i, x, y)
|
||||
|
||||
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
print("classic self-test passed!")
|
||||
else:
|
||||
|
@ -142,19 +138,20 @@ if __name__ == "__main__":
|
|||
schedule, schedule_independent, inputs = get_schedule(onnx_data)
|
||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
|
||||
print(f"{len(schedule_input)} inputs")
|
||||
schedule = fix_schedule_for_images(schedule)
|
||||
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
|
||||
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
|
||||
|
||||
if GRAPH:
|
||||
for si in schedule_input: log_schedule_item(si)
|
||||
for si in schedule: log_schedule_item(si)
|
||||
|
||||
run_schedule(schedule_independent, disable_logging=True)
|
||||
run_schedule(schedule_input)
|
||||
with Context(DEBUG=2, BEAM=getenv("LATEBEAM")):
|
||||
schedule = fix_schedule_for_images(schedule)
|
||||
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
|
||||
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
|
||||
|
||||
if GRAPH:
|
||||
for si in schedule_input: log_schedule_item(si)
|
||||
for si in schedule: log_schedule_item(si)
|
||||
|
||||
GlobalCounters.reset()
|
||||
run_schedule(schedule)
|
||||
run_schedule(schedule[:])
|
||||
|
||||
output_fn = sys.argv[2] if len(sys.argv) >= 3 else "/tmp/output.thneed"
|
||||
schedule_to_thneed(schedule, output_fn)
|
||||
|
|
Loading…
Reference in New Issue