support inputs

This commit is contained in:
George Hotz 2022-06-11 11:21:45 -07:00
parent 08de1aa636
commit 8440dbfa5d
3 changed files with 20 additions and 5 deletions

View File

@ -27,6 +27,7 @@ setup(name='tinygrad',
"pytest", "pytest",
"torch", "torch",
"tqdm", "tqdm",
"onnx",
], ],
}, },
include_package_data=True) include_package_data=True)

View File

@ -6,7 +6,7 @@ import onnx
from extra.utils import fetch from extra.utils import fetch
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
def run_onnx(dat): def run_onnx(dat, inputs={}):
onnx_model = onnx.load_model(dat) onnx_model = onnx.load_model(dat)
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim) def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
@ -21,7 +21,14 @@ def run_onnx(dat):
# get inputs # get inputs
for inp in onnx_model.graph.input: for inp in onnx_model.graph.input:
tensors[inp.name] = Tensor.zeros(*shape_to_tuple(inp.type.tensor_type.shape)) shape = shape_to_tuple(inp.type.tensor_type.shape)
if inp.name in inputs:
input_shape = inputs[inp.name].shape
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
tensors[inp.name] = Tensor(inputs[inp.name].astype(np.float32))
else:
print(f"filling {inp.name} shape {shape} with 0")
tensors[inp.name] = Tensor.zeros(*shape)
# get weights and biases # get weights and biases
for inp in onnx_model.graph.initializer: for inp in onnx_model.graph.initializer:
@ -76,7 +83,14 @@ def run_onnx(dat):
class TestOpenpilotModel(unittest.TestCase): class TestOpenpilotModel(unittest.TestCase):
def test(self): def test(self):
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx") dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
out = run_onnx(io.BytesIO(dat)) inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
"desire": np.zeros((1, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.zeros((1, 512))
}
out = run_onnx(io.BytesIO(dat), inputs)
print(out) print(out)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -260,8 +260,8 @@ class Tensor:
def transpose(self, order=(1,0)): def transpose(self, order=(1,0)):
return self.permute(order=order) return self.permute(order=order)
def flatten(self, dim=0): def flatten(self, start_dim=0):
return self.reshape(shape=tuple(list(self.shape[0:dim]) + [-1])) return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
def _canonicalize_reduce_axis(self, axis): def _canonicalize_reduce_axis(self, axis):
if axis is None: axis = range(len(self.shape)) if axis is None: axis = range(len(self.shape))