mirror of https://github.com/commaai/tinygrad.git
openpilot compile cleanups
This commit is contained in:
parent
ea13504f35
commit
8b0082540b
|
@ -6,6 +6,8 @@ if os.getenv("OPT", None) is None:
|
||||||
os.environ['OPT'] = '99'
|
os.environ['OPT'] = '99'
|
||||||
if os.getenv("GPU", None) is None:
|
if os.getenv("GPU", None) is None:
|
||||||
os.environ['GPU'] = '1'
|
os.environ['GPU'] = '1'
|
||||||
|
if os.getenv("IMAGE", None) is None:
|
||||||
|
os.environ['IMAGE'] = '2'
|
||||||
|
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
|
ALLOWED_KERNEL_COUNT = getenv("ALLOWED_KERNEL_COUNT", 0)
|
||||||
|
@ -26,34 +28,9 @@ OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/6c5693e965b9c63f8678
|
||||||
|
|
||||||
np.random.seed(1337)
|
np.random.seed(1337)
|
||||||
def get_random_input_tensors(input_shapes):
|
def get_random_input_tensors(input_shapes):
|
||||||
np_inputs = {
|
# this 16 is a random scale factor
|
||||||
"input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
|
inputs = {k:Tensor.randn(*shp, requires_grad=False)*16 for k,shp in input_shapes.items()}
|
||||||
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
|
np_inputs = {k:v.realize().numpy() for k,v in inputs.items()}
|
||||||
"desire": np.zeros((1,100, 8)),
|
|
||||||
"traffic_convention": np.array([[1., 0.]]),
|
|
||||||
#"features_buffer": np.random.randn(*(1, 99, 128))
|
|
||||||
"features_buffer": np.random.randn(*input_shapes['features_buffer'])
|
|
||||||
#"initial_state": np.zeros((1, 768))
|
|
||||||
}
|
|
||||||
if getenv("ZERO_OUT", 0):
|
|
||||||
np_inputs = {k:v*0 for k,v in np_inputs.items()}
|
|
||||||
|
|
||||||
for k,v in np_inputs.items():
|
|
||||||
assert v.shape == input_shapes[k], f"{k} shape mismatch, {v.shape} {input_shapes[k]}"
|
|
||||||
|
|
||||||
#import pickle
|
|
||||||
#frames, big_frames, last_state, frame_inputs, policy_outs = pickle.load(open("openpilot/test/frame_0.pkl", "rb"))
|
|
||||||
#np_inputs["input_imgs"] = frames
|
|
||||||
#np_inputs["big_input_imgs"] = big_frames
|
|
||||||
#np_inputs["initial_state"] = last_state[0]
|
|
||||||
|
|
||||||
#for i,k in enumerate(np_inputs.keys()):
|
|
||||||
# dat = open("/home/batman/openpilot/xx/ml_tools/snpe/compile_test_data/dlc_input_%d" % i, "rb").read()
|
|
||||||
# np_inputs[k] = np.frombuffer(dat, np.float32).reshape(np_inputs[k].shape)
|
|
||||||
|
|
||||||
np_inputs = {k:v.astype(np.float32) for k,v in np_inputs.items()}
|
|
||||||
inputs = {k:Tensor(v.astype(np.float32), requires_grad=False) for k,v in np_inputs.items()}
|
|
||||||
for _,v in inputs.items(): v.realize()
|
|
||||||
return inputs, np_inputs
|
return inputs, np_inputs
|
||||||
|
|
||||||
from extra.jit import TinyJit
|
from extra.jit import TinyJit
|
||||||
|
@ -72,10 +49,7 @@ def compile(dat, output_fn):
|
||||||
|
|
||||||
onnx_model = onnx.load(io.BytesIO(dat))
|
onnx_model = onnx.load(io.BytesIO(dat))
|
||||||
run_onnx = get_run_onnx(onnx_model)
|
run_onnx = get_run_onnx(onnx_model)
|
||||||
|
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
|
||||||
input_shapes = {}
|
|
||||||
for inp in onnx_model.graph.input:
|
|
||||||
input_shapes[inp.name] = tuple(x.dim_value for x in inp.type.tensor_type.shape.dim)
|
|
||||||
|
|
||||||
inputs, np_inputs = get_random_input_tensors(input_shapes)
|
inputs, np_inputs = get_random_input_tensors(input_shapes)
|
||||||
# run twice to trigger the JIT
|
# run twice to trigger the JIT
|
||||||
|
@ -90,6 +64,7 @@ def compile(dat, output_fn):
|
||||||
for prg,args in model_exec.jit_cache:
|
for prg,args in model_exec.jit_cache:
|
||||||
real_clprg = prg.clprg
|
real_clprg = prg.clprg
|
||||||
used_ops += real_clprg.op_estimate
|
used_ops += real_clprg.op_estimate
|
||||||
|
# replace clprg with a fake program to log to cl_cache
|
||||||
prg.clprg = lambda *args: cl_cache.append((real_clprg, args))
|
prg.clprg = lambda *args: cl_cache.append((real_clprg, args))
|
||||||
prg(*args)
|
prg(*args)
|
||||||
|
|
||||||
|
@ -111,7 +86,7 @@ def compile(dat, output_fn):
|
||||||
CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
|
CL.enqueue_copy(thneed_out, t.outputs[0], is_blocking=True)
|
||||||
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
|
np.testing.assert_allclose(thneed_out, tinygrad_out.numpy())
|
||||||
|
|
||||||
# float32 only (fix this)
|
# testing is float32 only (fix this)
|
||||||
FLOAT16 = getenv("FLOAT16", 0)
|
FLOAT16 = getenv("FLOAT16", 0)
|
||||||
if FLOAT16 == 0:
|
if FLOAT16 == 0:
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -3,7 +3,7 @@ import atexit
|
||||||
import itertools
|
import itertools
|
||||||
import networkx as nx # type: ignore
|
import networkx as nx # type: ignore
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List
|
from typing import Dict, List, Optional
|
||||||
from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType, LazyOp, get_buffers, get_lazyops
|
from tinygrad.ops import DeviceBuffer, DEBUG, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, Op, OpType, LazyOp, get_buffers, get_lazyops
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
|
|
||||||
|
@ -38,7 +38,8 @@ def get_sop(op : List[Op]):
|
||||||
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
|
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:2] for y in op][::-1])
|
||||||
return str(len(op))
|
return str(len(op))
|
||||||
|
|
||||||
def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : bool = GRAPH):
|
def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None):
|
||||||
|
if show_graph is None: show_graph = GRAPH
|
||||||
if not DEBUG and not show_graph: return
|
if not DEBUG and not show_graph: return
|
||||||
op : List[Op] = [x.op for x in get_lazyops(ast)]
|
op : List[Op] = [x.op for x in get_lazyops(ast)]
|
||||||
inp : List[DeviceBuffer] = get_buffers(ast)
|
inp : List[DeviceBuffer] = get_buffers(ast)
|
||||||
|
|
Loading…
Reference in New Issue