fix a bunch of tests (#856)

This commit is contained in:
Friedrich Carl Eichenroth 2023-05-30 02:48:26 +02:00 committed by GitHub
parent 502e33652f
commit f91f28d9e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 7 deletions

View File

@ -137,14 +137,16 @@ def get_run_onnx(onnx_model):
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
elif n.op_type in ["Add", "Sub", "Mul", "Pow"]:
# TODO: add this to tinygrad? i don't think it's in torch
if len(inp[0].shape) != len(inp[1].shape) and prod(inp[0].shape) == prod(inp[1].shape):
inp[1] = inp[1].reshape(inp[0].shape)
inp_0 = inp[0] if isinstance(inp[0], Tensor) else Tensor(np.array(inp[0], dtype=np.float32), requires_grad=False)
inp_1 = inp[1] if isinstance(inp[1], Tensor) else Tensor(np.array(inp[1], dtype=np.float32), requires_grad=False)
if (len(inp_0.shape) != len(inp_1.shape)) and (prod(inp_0.shape) == prod(inp_1.shape)):
inp_1 = inp_1.reshape(inp_0.shape)
# TODO: is this right?
if 'broadcast' in opt: inp[1] = inp[1].reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp[0].shape))])
if n.op_type == "Add": ret = inp[0] + inp[1]
if n.op_type == "Sub": ret = inp[0] - inp[1]
if n.op_type == "Mul": ret = inp[0] * inp[1]
if n.op_type == "Pow": ret = inp[0] ** inp[1]
if 'broadcast' in opt: inp_1 = inp_1.reshape([-1 if i == opt['broadcast'] else 1 for i in range(len(inp_0.shape))])
if n.op_type == "Add": ret = inp_0 + inp_1
if n.op_type == "Sub": ret = inp_0 - inp_1
if n.op_type == "Mul": ret = inp_0 * inp_1
if n.op_type == "Pow": ret = inp_0 ** inp_1
elif n.op_type == "Split":
if 'split' not in opt: opt['split'] = [int(x) for x in safe_numpy(inp[1])] # split can be a tensor
if 'axis' not in opt: opt['axis'] = 0