mirror of https://github.com/commaai/tinygrad.git
fix a bunch of tests (#856)
This commit is contained in:
parent
502e33652f
commit
f91f28d9e2
|
@ -137,14 +137,16 @@ def get_run_onnx(onnx_model):
|
|||
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
|
||||
elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
|
||||
# TODO: add this to tinygrad? i don't think it's in torch
|
||||
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
|
||||
inp[1] = inp[1].reshape(inp[0].shape)
|
||||
inp_0 = inp[0] if isinstance(inp[0], Tensor) else Tensor(np.array(inp[0], dtype=np.float32), requires_grad=False)
|
||||
inp_1 = inp[1] if isinstance(inp[1], Tensor) else Tensor(np.array(inp[1], dtype=np.float32), requires_grad=False)
|
||||
if (len(inp_0.shape) != len(inp_1.shape)) and (prod(inp_0.shape) == prod(inp_1.shape)):
|
||||
inp_1 = inp_1.reshape(inp_0.shape)
|
||||
# TODO: is this right?
|
||||
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
|
||||
if n.op_type == "Add": ret = inp[0] + inp[1]
|
||||
if n.op_type == "Sub": ret = inp[0] - inp[1]
|
||||
if n.op_type == "Mul": ret = inp[0] * inp[1]
|
||||
if n.op_type == "Pow": ret = inp[0] ** inp[1]
|
||||
if 'broadcast' in opt: inp_1 = inp_1.reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp_0.shape))])
|
||||
if n.op_type == "Add": ret = inp_0 + inp_1
|
||||
if n.op_type == "Sub": ret = inp_0 - inp_1
|
||||
if n.op_type == "Mul": ret = inp_0 * inp_1
|
||||
if n.op_type == "Pow": ret = inp_0 ** inp_1
|
||||
elif n.op_type == "Split":
|
||||
if 'split' not in opt: opt['split'] = [int(x) for x in safe_numpy(inp[1])] # split can be a tensor
|
||||
if 'axis' not in opt: opt['axis'] = 0
|
||||
|
|
Loading…
Reference in New Issue