diff --git a/extra/onnx.py b/extra/onnx.py
index cae81bd2..814b6580 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -4,8 +4,7 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod
 from tinygrad.nn import batch_normalize
-
-MAX_CONVS = int(os.getenv("MAX_CONVS", -1))
+from tinygrad.ops import DEBUG
 
 def get_run_onnx(onnx_model):
   def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
@@ -38,6 +37,8 @@ def get_run_onnx(onnx_model):
       print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
       print(inp)
      raise Exception("no data")
+    if DEBUG >= 1:
+      print("realize", inp.name)
     tensors[inp.name].realize()
 
   def run_onnx(inputs={}, debug=False):
@@ -60,7 +61,6 @@ def get_run_onnx(onnx_model):
       else:
         raise Exception(f"no data for {inp.name} with shape {shape}")
 
-    conv_count = 0
     for num,n in enumerate(onnx_model.graph.node):
       if debug: print(f"{num}: op {n.op_type}")
       inp = [tensors[x] if x in tensors else (intermediate_tensors[x] if x in intermediate_tensors else input_tensors[x]) for x in n.input]
@@ -71,7 +71,9 @@ def get_run_onnx(onnx_model):
       elif n.op_type == "Sigmoid": ret = inp[0].sigmoid()
       elif n.op_type == "Tanh": ret = inp[0].tanh()
       elif n.op_type == "Softmax": ret = inp[0].softmax()
-      elif n.op_type == "MatMul": ret = inp[0].matmul(inp[1])
+      elif n.op_type == "MatMul":
+        assert inp[1].lazydata.realized is not None
+        ret = inp[0].matmul(inp[1])
       # one liners
       elif n.op_type == "Elu": ret = inp[0].elu(alpha=opt['alpha'])
       elif n.op_type == "Clip": ret = inp[0].clip(*(inp[1:] if len(inp) > 1 else (opt.get('min', -3.4e38), opt.get('max', 3.4e38))))
@@ -110,10 +112,6 @@ def get_run_onnx(onnx_model):
       else:
         x = x.pad2d((opt['pads'][0], opt['pads'][2], opt['pads'][1], opt['pads'][3]))
       ret = x.conv2d(w, b, stride=opt['strides'], groups=opt.get('group', 1))
-      conv_count += 1
-      if conv_count == MAX_CONVS:
-        ret.numpy()
-        break
     elif n.op_type in ["Add", "Sub", "Mul"]:
       # TODO: add this to tinygrad? i don't think it's in torch
       if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
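A note on the extra/onnx.py hunks above: the MAX_CONVS early-exit hack (stop after N convs and force a .numpy()) is removed, replaced by a DEBUG-gated log of each initializer as it is realized, plus an assert that a MatMul weight is already realized. The DEBUG imported from tinygrad.ops is an integer debug level read from the environment; a minimal sketch of that pattern (the standalone getenv read and realize_all helper below are illustrative, not the exact tinygrad source):

```python
import os

# debug level: 0 = silent, >=1 prints progress; read once at import time
DEBUG = int(os.getenv("DEBUG", "0"))

def realize_all(tensors):
  for name, t in tensors.items():
    if DEBUG >= 1:
      print("realize", name)  # mirrors the per-initializer log the patch adds
    t.realize()               # force the lazy buffer to materialize
```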
diff --git a/extra/thneed.py b/extra/thneed.py
index 18e6f19b..b37539c6 100644
--- a/extra/thneed.py
+++ b/extra/thneed.py
@@ -5,7 +5,7 @@ import struct
 import json
 import traceback
 import numpy as np
-from tinygrad.llops.ops_gpu import CL, CLProgram, CLBuffer
+from tinygrad.llops.ops_gpu import CL, CLProgram
 from tinygrad.helpers import prod
 import pyopencl as cl
 import networkx as nx
@@ -185,7 +185,7 @@ class Thneed:
           })
         if needs_load:
           data = np.empty(a.size//4, dtype=np.float32)
-          CL.enqueue_copy(data, a, is_blocking=True)
+          cl.enqueue_copy(CL().cl_queue, data, a, is_blocking=True)
           weights.append(data.tobytes())
       elif isinstance(a, cl.Image):
         needs_load = a in self.buffers_to_save
@@ -195,7 +195,7 @@ class Thneed:
           buf = cl.Buffer(CL().cl_ctx, cl.mem_flags.READ_WRITE, size=size * (2 if FLOAT16 else 1))
 
           # zero out the buffer
-          CL.enqueue_copy(buf, b'\x00'*buf.size, is_blocking=True)
+          cl.enqueue_copy(CL().cl_queue, buf, b'\x00'*buf.size, is_blocking=True)
 
           CLProgram("from_image_strided", """
             __kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
@@ -215,7 +215,7 @@ class Thneed:
 
           if needs_load:
             data = np.empty(size//(2 if FLOAT16 else 4), dtype=np.float32)
-            CL.enqueue_copy(data, buf, is_blocking=True)
+            cl.enqueue_copy(CL().cl_queue, data, buf, is_blocking=True)
             if FLOAT16: data = data.astype(np.float16)
             weights.append(data.tobytes())
       else:
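A note on the extra/thneed.py hunks above: the save path stops going through the CL.enqueue_copy wrapper and calls PyOpenCL's module-level cl.enqueue_copy directly, passing the command queue explicitly (the now-unused CLBuffer import is dropped too). A minimal self-contained round trip using that call, assuming any working OpenCL device (the context and buffer setup here are illustrative):

```python
import numpy as np
import pyopencl as cl

ctx = cl.create_some_context()   # picks an available OpenCL device
queue = cl.CommandQueue(ctx)

src = np.arange(16, dtype=np.float32)
buf = cl.Buffer(ctx, cl.mem_flags.READ_WRITE | cl.mem_flags.COPY_HOST_PTR, hostbuf=src)

# device -> host; is_blocking=True waits for the transfer, as in the patch
dst = np.empty_like(src)
cl.enqueue_copy(queue, dst, buf, is_blocking=True)
assert (dst == src).all()
```

Passing the queue explicitly also makes it obvious which queue a blocking copy serializes against, which the old wrapper hid.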
diff --git a/openpilot/compile.py b/openpilot/compile.py
index 573f89bb..29dbb5e5 100644
--- a/openpilot/compile.py
+++ b/openpilot/compile.py
@@ -85,11 +85,11 @@ def compile(dat, output_fn):
   et = time.monotonic()
   print(f"ran openpilot model in {(et-st)*1000.0:.2f} ms, waited {(mt2-mt)*1000.0:.2f} ms for realize, {(et-mt2)*1000.0:.2f} ms for GPU queue")
 
-  # realize all non GCed tensors (fix for batchnorm folding)
-  import gc
-  gc.collect()
-  for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
-    x.realize()
+  # realize all non GCed tensors (fix for batchnorm folding)
+  import gc
+  gc.collect()
+  for x in [x for x in gc.get_objects() if isinstance(x, Tensor)]:
+    x.realize()
 
   # real run
   inputs, np_inputs = get_random_input_tensors(input_shapes)
@@ -108,6 +108,9 @@ def compile(dat, output_fn):
   CL.CACHE = None
   t.optimize_local_workgroup()
 
+  # save thneed (before run)
+  t.save(output_fn)
+
   print(f"buffers to save: {len(t.buffers_to_save)}, outputs: {t.outputs}")
   t.run()
 
@@ -116,9 +119,6 @@ def compile(dat, output_fn):
   CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
   np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
 
-  # save thneed
-  t.save(output_fn)
-
   # float32 only (fix this)
   FLOAT16 = int(os.getenv("FLOAT16", 0))
   if FLOAT16 == 0:
@@ -132,6 +132,20 @@ def compile(dat, output_fn):
       _, new_np_inputs = get_random_input_tensors(input_shapes)
       new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
 
+      # try old thneed with a different input
+      for k,v in t.inputs.items():
+        CL.enqueue_copy(v, new_np_inputs[k], is_blocking=True)
+
+      t.run()
+      old_thneed_out = np.empty((t.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
+      CL.enqueue_copy(old_thneed_out, t.outputs[0], is_blocking=True)
+
+      # compare thneed (rerun) with torch
+      np.testing.assert_allclose(new_torch_out, old_thneed_out, atol=1e-4, rtol=1e-2)
+
+      # load thneed and try that
+      _, new_np_inputs = get_random_input_tensors(input_shapes)
+      new_torch_out = run_onnx_torch(onnx_model, new_np_inputs).numpy()
+
       nt = Thneed()
       nt.load(output_fn)
@@ -142,6 +156,8 @@ def compile(dat, output_fn):
       nt.run()
       new_thneed_out = np.empty((nt.outputs[0].size//4,), dtype=np.float32).reshape(tinygrad_out.shape)
       CL.enqueue_copy(new_thneed_out, nt.outputs[0], is_blocking=True)
+
+      # compare torch to thneed
       np.testing.assert_allclose(new_torch_out, new_thneed_out, atol=1e-4, rtol=1e-2)
       print("thneed self-test passed!")
    except ModuleNotFoundError:
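A note on the openpilot/compile.py hunks above: t.save(output_fn) moves to before t.run() (per the new "save thneed (before run)" comment, presumably so the on-disk copy reflects pre-run buffer state), and the self-test grows two extra checks: the in-memory program is rerun on fresh random inputs and compared to torch, then a Thneed loaded back from the saved file is run and compared to torch as well, both at atol=1e-4 / rtol=1e-2. The shape of that check, as a hedged sketch (run_reference, run_compiled, and make_inputs are hypothetical stand-ins for run_onnx_torch, a Thneed run, and get_random_input_tensors):

```python
import numpy as np

def self_test(run_reference, run_compiled, make_inputs, n_trials=2):
  """Rerun a compiled program on fresh inputs and compare to a reference."""
  for _ in range(n_trials):
    inputs = make_inputs()            # fresh random inputs each trial
    expected = run_reference(inputs)
    actual = run_compiled(inputs)
    # same tolerances as the patch: loose rtol absorbs accumulated fp error
    np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1e-2)
  print("self-test passed!")
```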