Revert hax

This commit is contained in:
Bruce Wayne 2024-11-08 09:56:23 -08:00
parent 8b42a971c6
commit 3a11fa9d71
1 changed files with 1 additions and 14 deletions

View File

@ -34,7 +34,6 @@ def compile():
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
input_types = {inp.name: np.float32 for inp in onnx_model.graph.input}
if 'input_img' in input_shapes:
input_shapes['input_img'] = (1, 1812, 1928)
input_types['input_img'] = np.uint8
else:
input_types['input_imgs'] = np.uint8
@ -42,20 +41,8 @@ def compile():
Tensor.manual_seed(100)
new_inputs = {k:Tensor.randn(*shp, dtype=_from_np_dtype(input_types[k])).mul(8).realize() for k,shp in sorted(input_shapes.items())}
print("created tensors")
# TODO remove this hack from dm
if 'input_img' in input_shapes: #DM model
def fun_to_jit(kwargs):
MODEL_WIDTH = 1440
MODEL_HEIGHT = 960
v_offset = kwargs['input_img'].shape[1] * 2 // 3 - MODEL_HEIGHT
h_offset = (kwargs['input_img'].shape[2] - MODEL_WIDTH) // 2
kwargs['input_img'] = kwargs['input_img'][:,v_offset:v_offset+MODEL_HEIGHT, h_offset:h_offset+MODEL_WIDTH].reshape((1,-1))
return run_onnx(kwargs)
else:
fun_to_jit = run_onnx
run_onnx_jit = TinyJit(lambda **kwargs: fun_to_jit(kwargs), prune=True)
run_onnx_jit = TinyJit(lambda **kwargs: run_onnx(kwargs), prune=True)
for i in range(3):
GlobalCounters.reset()
print(f"run {i}")