mirror of https://github.com/commaai/tinygrad.git
Revert hax
This commit is contained in:
parent
8b42a971c6
commit
3a11fa9d71
|
@ -34,7 +34,6 @@ def compile():
|
|||
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||
input_types = {inp.name: np.float32 for inp in onnx_model.graph.input}
|
||||
if 'input_img' in input_shapes:
|
||||
input_shapes['input_img'] = (1, 1812, 1928)
|
||||
input_types['input_img'] = np.uint8
|
||||
else:
|
||||
input_types['input_imgs'] = np.uint8
|
||||
|
@ -42,20 +41,8 @@ def compile():
|
|||
Tensor.manual_seed(100)
|
||||
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
|
||||
print("created tensors")
|
||||
|
||||
# TODO remove this hack from dm
|
||||
if 'input_img' in input_shapes: #DM model
|
||||
def fun_to_jit(kwargs):
|
||||
MODEL_WIDTH = 1440
|
||||
MODEL_HEIGHT = 960
|
||||
v_offset = kwargs['input_img'].shape[1] * 2 // 3 - MODEL_HEIGHT
|
||||
h_offset = (kwargs['input_img'].shape[2] - MODEL_WIDTH) // 2
|
||||
kwargs['input_img'] = kwargs['input_img'][:,v_offset:v_offset+MODEL_HEIGHT, h_offset:h_offset+MODEL_WIDTH].reshape((1,-1))
|
||||
return run_onnx(kwargs)
|
||||
else:
|
||||
fun_to_jit = run_onnx
|
||||
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: fun_to_jit(kwargs), prune=True)
|
||||
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
|
||||
for i in range(3):
|
||||
GlobalCounters.reset()
|
||||
print(f"run {i}")
|
||||
|
|
Loading…
Reference in New Issue