openpilot compile cleanups

This commit is contained in:
George Hotz 2023-02-20 09:16:03 -08:00
parent ea13504f35
commit 8b0082540b
2 changed files with 11 additions and 35 deletions

View File

@ -6,6 +6,8 @@ if os.getenv("OPT", None) is None:
os.environ['OPT'] = '99'
if os.getenv("GPU", None) is None:
os.environ['GPU'] = '1'
if os.getenv("IMAGE", None) is None:
os.environ['IMAGE'] = '2'
from tinygrad.helpers import getenv
ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
@ -26,34 +28,9 @@ OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/6c5693e965b9c63f8678
np.random.seed(1337)
def get_random_input_tensors(input_shapes):
np_inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
"desire": np.zeros((1,100, 8)),
"traffic_convention": np.array([[1., 0.]]),
#"features_buffer": np.random.randn(*(1, 99, 128))
"features_buffer": np.random.randn(*input_shapes['features_buffer'])
#"initial_state": np.zeros((1, 768))
}
if getenv("ZERO_OUT", 0):
np_inputs = {k:v*0 for k,v in np_inputs.items()}
for k,v in np_inputs.items():
assert v.shape == input_shapes[k], f"{k} shape mismatch, {v.shape} {input_shapes[k]}"
#import pickle
#frames, big_frames, last_state, frame_inputs, policy_outs = pickle.load(open("openpilot/test/frame_0.pkl", "rb"))
#np_inputs["input_imgs"] = frames
#np_inputs["big_input_imgs"] = big_frames
#np_inputs["initial_state"] = last_state[0]
#for i,k in enumerate(np_inputs.keys()):
# dat = open("/home/batman/openpilot/xx/ml_tools/snpe/compile_test_data/dlc_input_%d" % i, "rb").read()
# np_inputs[k] = np.frombuffer(dat, np.float32).reshape(np_inputs[k].shape)
np_inputs = {k:v.astype(np.float32) for k,v in np_inputs.items()}
inputs = {k:Tensor(v.astype(np.float32), requires_grad=False) for k,v in np_inputs.items()}
for _,v in inputs.items(): v.realize()
# this 16 is a random scale factor
inputs = {k:Tensor.randn(*shp, requires_grad=False)*16 for k,shp in input_shapes.items()}
np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
return inputs, np_inputs
from extra.jit import TinyJit
@ -72,10 +49,7 @@ def compile(dat, output_fn):
onnx_model = onnx.load(io.BytesIO(dat))
run_onnx = get_run_onnx(onnx_model)
input_shapes = {}
for inp in onnx_model.graph.input:
input_shapes[inp.name] = tuple(x.dim_value for x in inp.type.tensor_type.shape.dim)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
inputs, np_inputs = get_random_input_tensors(input_shapes)
# run twice to trigger the JIT
@ -90,6 +64,7 @@ def compile(dat, output_fn):
for prg,args in model_exec.jit_cache:
real_clprg = prg.clprg
used_ops += real_clprg.op_estimate
# replace clprg with a fake program to log to cl_cache
prg.clprg = lambda *args: cl_cache.append((real_clprg, args))
prg(*args)
@ -111,7 +86,7 @@ def compile(dat, output_fn):
CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
# float32 only (fix this)
# testing is float32 only (fix this)
FLOAT16 = getenv("FLOAT16", 0)
if FLOAT16 == 0:
try:

View File

@ -3,7 +3,7 @@ import atexit
import itertools
import networkx as nx # type: ignore
from collections import defaultdict
from typing import Dict, List
from typing import Dict, List, Optional
from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType, LazyOp, get_buffers, get_lazyops
from tinygrad.helpers import getenv
@ -38,7 +38,8 @@ def get_sop(op : List[Op]):
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
return str(len(op))
def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : bool = GRAPH):
def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None):
if show_graph is None: show_graph = GRAPH
if not DEBUG and not show_graph: return
op : List[Op] = [x.op for x in get_lazyops(ast)]
inp : List[DeviceBuffer] = get_buffers(ast)