mirror of https://github.com/commaai/tinygrad.git
support inputs
This commit is contained in:
parent
08de1aa636
commit
8440dbfa5d
1
setup.py
1
setup.py
|
@ -27,6 +27,7 @@ setup(name='tinygrad',
|
||||||
"pytest",
|
"pytest",
|
||||||
"torch",
|
"torch",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
|
"onnx",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
include_package_data=True)
|
include_package_data=True)
|
||||||
|
|
|
@ -6,7 +6,7 @@ import onnx
|
||||||
from extra.utils import fetch
|
from extra.utils import fetch
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
|
|
||||||
def run_onnx(dat):
|
def run_onnx(dat, inputs={}):
|
||||||
onnx_model = onnx.load_model(dat)
|
onnx_model = onnx.load_model(dat)
|
||||||
|
|
||||||
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
|
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
|
||||||
|
@ -21,7 +21,14 @@ def run_onnx(dat):
|
||||||
|
|
||||||
# get inputs
|
# get inputs
|
||||||
for inp in onnx_model.graph.input:
|
for inp in onnx_model.graph.input:
|
||||||
tensors[inp.name] = Tensor.zeros(*shape_to_tuple(inp.type.tensor_type.shape))
|
shape = shape_to_tuple(inp.type.tensor_type.shape)
|
||||||
|
if inp.name in inputs:
|
||||||
|
input_shape = inputs[inp.name].shape
|
||||||
|
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
|
||||||
|
tensors[inp.name] = Tensor(inputs[inp.name].astype(np.float32))
|
||||||
|
else:
|
||||||
|
print(f"filling {inp.name} shape {shape} with 0")
|
||||||
|
tensors[inp.name] = Tensor.zeros(*shape)
|
||||||
|
|
||||||
# get weights and biases
|
# get weights and biases
|
||||||
for inp in onnx_model.graph.initializer:
|
for inp in onnx_model.graph.initializer:
|
||||||
|
@ -76,7 +83,14 @@ def run_onnx(dat):
|
||||||
class TestOpenpilotModel(unittest.TestCase):
|
class TestOpenpilotModel(unittest.TestCase):
|
||||||
def test(self):
|
def test(self):
|
||||||
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
|
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
|
||||||
out = run_onnx(io.BytesIO(dat))
|
inputs = {
|
||||||
|
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
|
||||||
|
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
|
||||||
|
"desire": np.zeros((1, 8)),
|
||||||
|
"traffic_convention": np.array([[1., 0.]]),
|
||||||
|
"initial_state": np.zeros((1, 512))
|
||||||
|
}
|
||||||
|
out = run_onnx(io.BytesIO(dat), inputs)
|
||||||
print(out)
|
print(out)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
@ -260,8 +260,8 @@ class Tensor:
|
||||||
def transpose(self, order=(1,0)):
|
def transpose(self, order=(1,0)):
|
||||||
return self.permute(order=order)
|
return self.permute(order=order)
|
||||||
|
|
||||||
def flatten(self, dim=0):
|
def flatten(self, start_dim=0):
|
||||||
return self.reshape(shape=tuple(list(self.shape[0:dim]) + [-1]))
|
return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
|
||||||
|
|
||||||
def _canonicalize_reduce_axis(self, axis):
|
def _canonicalize_reduce_axis(self, axis):
|
||||||
if axis is None: axis = range(len(self.shape))
|
if axis is None: axis = range(len(self.shape))
|
||||||
|
|
Loading…
Reference in New Issue