mirror of https://github.com/commaai/tinygrad.git
support inputs
This commit is contained in:
parent
08de1aa636
commit
8440dbfa5d
1
setup.py
1
setup.py
|
@ -27,6 +27,7 @@ setup(name='tinygrad',
|
|||
"pytest",
|
||||
"torch",
|
||||
"tqdm",
|
||||
"onnx",
|
||||
],
|
||||
},
|
||||
include_package_data=True)
|
||||
|
|
|
@ -6,7 +6,7 @@ import onnx
|
|||
from extra.utils import fetch
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def run_onnx(dat):
|
||||
def run_onnx(dat, inputs={}):
|
||||
onnx_model = onnx.load_model(dat)
|
||||
|
||||
def shape_to_tuple(s): return tuple(x.dim_value for x in s.dim)
|
||||
|
@ -21,7 +21,14 @@ def run_onnx(dat):
|
|||
|
||||
# get inputs
|
||||
for inp in onnx_model.graph.input:
|
||||
tensors[inp.name] = Tensor.zeros(*shape_to_tuple(inp.type.tensor_type.shape))
|
||||
shape = shape_to_tuple(inp.type.tensor_type.shape)
|
||||
if inp.name in inputs:
|
||||
input_shape = inputs[inp.name].shape
|
||||
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
|
||||
tensors[inp.name] = Tensor(inputs[inp.name].astype(np.float32))
|
||||
else:
|
||||
print(f"filling {inp.name} shape {shape} with 0")
|
||||
tensors[inp.name] = Tensor.zeros(*shape)
|
||||
|
||||
# get weights and biases
|
||||
for inp in onnx_model.graph.initializer:
|
||||
|
@ -76,7 +83,14 @@ def run_onnx(dat):
|
|||
class TestOpenpilotModel(unittest.TestCase):
|
||||
def test(self):
|
||||
dat = fetch("https://github.com/commaai/openpilot/raw/7da48ebdba5e3cf4c0b8078c934bee9a199f0280/selfdrive/modeld/models/supercombo.onnx")
|
||||
out = run_onnx(io.BytesIO(dat))
|
||||
inputs = {
|
||||
"input_imgs": np.random.randn(*(1, 12, 128, 256)),
|
||||
"big_input_imgs": np.random.randn(*(1, 12, 128, 256)),
|
||||
"desire": np.zeros((1, 8)),
|
||||
"traffic_convention": np.array([[1., 0.]]),
|
||||
"initial_state": np.zeros((1, 512))
|
||||
}
|
||||
out = run_onnx(io.BytesIO(dat), inputs)
|
||||
print(out)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -260,8 +260,8 @@ class Tensor:
|
|||
def transpose(self, order=(1,0)):
|
||||
return self.permute(order=order)
|
||||
|
||||
def flatten(self, dim=0):
|
||||
return self.reshape(shape=tuple(list(self.shape[0:dim]) + [-1]))
|
||||
def flatten(self, start_dim=0):
|
||||
return self.reshape(shape=tuple(list(self.shape[0:start_dim]) + [-1]))
|
||||
|
||||
def _canonicalize_reduce_axis(self, axis):
|
||||
if axis is None: axis = range(len(self.shape))
|
||||
|
|
Loading…
Reference in New Issue