diff --git a/accel/opencl/preprocessing.py b/accel/opencl/preprocessing.py index 4cc306aa..120b2ab0 100644 --- a/accel/opencl/preprocessing.py +++ b/accel/opencl/preprocessing.py @@ -62,7 +62,12 @@ def preprocessing_op(x,w,C): w = w.contiguous_op() # early realize on the weights - w.realize().image + bw = w + while getattr(bw, 'op', None) and len(bw.op.src) == 1: + bw = bw.op.src[0] + if bw.realized: + # weights are static + w.realize().image return x,w,C def postprocessing_op(ret, C, C_initial): diff --git a/extra/onnx.py b/extra/onnx.py index efa3cf85..b5d28a9e 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -87,10 +87,7 @@ def get_run_onnx(onnx_model): elif n.op_type == "Sigmoid": ret = inp[0].sigmoid() elif n.op_type == "Tanh": ret = inp[0].tanh() elif n.op_type == "Softmax": ret = inp[0].softmax() - elif n.op_type == "MatMul": - if inp[1].lazydata.realized is not None: - print("WARNING: matmul weights are not static") - ret = inp[0].matmul(inp[1]) + elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1]) # one liners elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha']) elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt.get('min', -3.4e38), opt.get('max', 3.4e38)))) diff --git a/extra/thneed.py b/extra/thneed.py index b37539c6..6c44074a 100644 --- a/extra/thneed.py +++ b/extra/thneed.py @@ -281,7 +281,8 @@ class Thneed: local_cl_cache = [] for prg, args in self.cl_cache: args = list(args) - if args[1] is None and len(args[0]) == 2: + # TODO: WTF is wrong with to_image? + if args[1] is None and len(args[0]) == 2 and 'to_image' not in prg.name: args[1] = [min(MAX_WORKGROUP, args[0][0]), 1] try: e = prg.clprg(CL().cl_queue, *args) diff --git a/openpilot/compile.py b/openpilot/compile.py index 29dbb5e5..b66c551f 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -93,6 +93,7 @@ def compile(dat, output_fn): # real run inputs, np_inputs = get_random_input_tensors(input_shapes) + print("***** REAL RUN *****") tinygrad_out = run_onnx(inputs)['outputs'] # note, since CL.CACHE is enabled, it doesn't actually run the kernels