support inputs

This commit is contained in:
George Hotz 2022-06-11 11:21:45 -07:00
parent 08de1aa636
commit 8440dbfa5d
3 changed files with 20 additions and 5 deletions

View File

@ -27,6 +27,7 @@ setup(name='tinygrad',
"pytest",
"torch",
"tqdm",
"onnx",
],
},
include_package_data=True)

View File

@ -6,7 +6,7 @@ import onnx
from extra.utils import fetch
from tinygrad.tensor import Tensor
def run_onnx(dat):
def run_onnx(dat, inputs={}):
onnx_model = onnx.load_model(dat)
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
@ -21,7 +21,14 @@ def run_onnx(dat):
# get inputs
for inp in onnx_model.graph.input:
tensors[inp.name] = Tensor.zeros(*shape_to_tuple(inp.type.tensor_type.shape))
shape = shape_to_tuple(inp.type.tensor_type.shape)
if inp.name in inputs:
input_shape = inputs[inp.name].shape
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
tensors[inp.name] = Tensor(inputs[inp.name].astype(np.float32))
else:
print(f"filling {inp.name} shape {shape} with 0")
tensors[inp.name] = Tensor.zeros(*shape)
# get weights and biases
for inp in onnx_model.graph.initializer:
@ -76,7 +83,14 @@ def run_onnx(dat):
class TestOpenpilotModel(unittest.TestCase):
def test(self):
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
out = run_onnx(io.BytesIO(dat))
inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
"desire": np.zeros((1, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.zeros((1, 512))
}
out = run_onnx(io.BytesIO(dat), inputs)
print(out)
if __name__ == "__main__":

View File

@ -260,8 +260,8 @@ class Tensor:
def transpose(self, order=(1,0)):
return self.permute(order=order)
def flatten(self, dim=0):
return self.reshape(shape=tuple(list(self.shape[0:dim]) + [-1]))
def flatten(self, start_dim=0):
return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
def _canonicalize_reduce_axis(self, axis):
if axis is None: axis = range(len(self.shape))