diff --git a/openpilot/compile.py b/openpilot/compile.py index 43463a86..093a1938 100644 --- a/openpilot/compile.py +++ b/openpilot/compile.py @@ -6,6 +6,8 @@ if os.getenv("OPT", None) is None: os.environ['OPT'] = '99' if os.getenv("GPU", None) is None: os.environ['GPU'] = '1' +if os.getenv("IMAGE", None) is None: + os.environ['IMAGE'] = '2' from tinygrad.helpers import getenv ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0) @@ -26,34 +28,9 @@ OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/6c5693e965b9c63f8678 np.random.seed(1337) def get_random_input_tensors(input_shapes): - np_inputs = { - "input_imgs": np.random.randn(*(1, 12, 128, 256))*256, - "big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256, - "desire": np.zeros((1,100, 8)), - "traffic_convention": np.array([[1., 0.]]), - #"features_buffer": np.random.randn(*(1, 99, 128)) - "features_buffer": np.random.randn(*input_shapes['features_buffer']) - #"initial_state": np.zeros((1, 768)) - } - if getenv("ZERO_OUT", 0): - np_inputs = {k:v*0 for k,v in np_inputs.items()} - - for k,v in np_inputs.items(): - assert v.shape == input_shapes[k], f"{k} shape mismatch, {v.shape} {input_shapes[k]}" - - #import pickle - #frames, big_frames, last_state, frame_inputs, policy_outs = pickle.load(open("openpilot/test/frame_0.pkl", "rb")) - #np_inputs["input_imgs"] = frames - #np_inputs["big_input_imgs"] = big_frames - #np_inputs["initial_state"] = last_state[0] - - #for i,k in enumerate(np_inputs.keys()): - # dat = open("/home/batman/openpilot/xx/ml_tools/snpe/compile_test_data/dlc_input_%d" % i, "rb").read() - # np_inputs[k] = np.frombuffer(dat, np.float32).reshape(np_inputs[k].shape) - - np_inputs = {k:v.astype(np.float32) for k,v in np_inputs.items()} - inputs = {k:Tensor(v.astype(np.float32), requires_grad=False) for k,v in np_inputs.items()} - for _,v in inputs.items(): v.realize() + # this 16 is a random scale factor + inputs = {k:Tensor.randn(*shp, requires_grad=False)*16 for k,shp in input_shapes.items()} + np_inputs = {k:v.realize().numpy() for k,v in inputs.items()} return inputs, np_inputs from extra.jit import TinyJit @@ -72,10 +49,7 @@ def compile(dat, output_fn): onnx_model = onnx.load(io.BytesIO(dat)) run_onnx = get_run_onnx(onnx_model) - - input_shapes = {} - for inp in onnx_model.graph.input: - input_shapes[inp.name] = tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) + input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} inputs, np_inputs = get_random_input_tensors(input_shapes) # run twice to trigger the JIT @@ -90,6 +64,7 @@ def compile(dat, output_fn): for prg,args in model_exec.jit_cache: real_clprg = prg.clprg used_ops += real_clprg.op_estimate + # replace clprg with a fake program to log to cl_cache prg.clprg = lambda *args: cl_cache.append((real_clprg, args)) prg(*args) @@ -111,7 +86,7 @@ def compile(dat, output_fn): CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True) np.testing.assert_allclose(thneed_out, tinygrad_out.numpy()) - # float32 only (fix this) + # testing is float32 only (fix this) FLOAT16 = getenv("FLOAT16", 0) if FLOAT16 == 0: try: diff --git a/tinygrad/graph.py b/tinygrad/graph.py index f5e57040..87c24aef 100644 --- a/tinygrad/graph.py +++ b/tinygrad/graph.py @@ -3,7 +3,7 @@ import atexit import itertools import networkx as nx # type: ignore from collections import defaultdict -from typing import Dict, List +from typing import Dict, List, Optional from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType, LazyOp, get_buffers, get_lazyops from tinygrad.helpers import getenv @@ -38,7 +38,8 @@ def get_sop(op : List[Op]): if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:2] for y in op][::-1]) return str(len(op)) -def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : bool = GRAPH): +def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None): + if show_graph is None: show_graph = GRAPH if not DEBUG and not show_graph: return op : List[Op] = [x.op for x in get_lazyops(ast)] inp : List[DeviceBuffer] = get_buffers(ast)