ngrl stuff

This commit is contained in:
Yassine Yousfi 2022-10-04 09:35:37 -07:00
parent 392e57aea7
commit 82ca9c6666
2 changed files with 10 additions and 10 deletions

View File

@ -88,12 +88,13 @@ def get_run_onnx(onnx_model):
elif n.op_type == "Constant": ret = opt['value']
elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) for x in inp[1].numpy()])
elif n.op_type == "Gather":
# TODO: is this correct? seems to work for simple gather ops
axis = opt['axis']
shape = list(inp[0].shape)
assert axis==0, 'untested for other values'
indices = [shape[axis]+int(x) if x<0 else int(x) for x in inp[1].numpy()]
shape[axis] = 1
ret = inp[0][indices[0]].reshape(shape).cat(*[inp[0][x].reshape(shape) for x in indices[1:]], dim=axis)
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
elif n.op_type == "BatchNormalization":
invstd = inp[4].add(opt.get('epsilon', 1e-5))**-0.5
ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], invstd)
@ -150,7 +151,7 @@ def get_run_onnx(onnx_model):
axis, starts, ends = int(axes.numpy()[0]), int(starts.numpy()[0]), int(ends.numpy()[0])
ends = min(ends, inp[0].shape[axis])
starts = starts + inp[0].shape[axis] if starts < 0 else starts
arg[0] = (starts, ends)
arg[axis] = (starts, ends)
ret = inp[0].slice(arg=arg)
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
@ -162,4 +163,3 @@ def get_run_onnx(onnx_model):
return {outp.name:intermediate_tensors[outp.name] for outp in onnx_model.graph.output}
return run_onnx

View File

@ -32,9 +32,9 @@ def get_random_input_tensors():
np_inputs = {
"input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
"desire": np.zeros((1, 8)),
"desire": np.zeros((1,100, 8)),
"traffic_convention": np.array([[1., 0.]]),
"initial_state": np.random.randn(*(1, 512))
"features_buffer": np.random.randn(*(1, 99, 2048))
#"initial_state": np.zeros((1, 768))
}
@ -300,9 +300,9 @@ def compile(input, output_fn):
"local_work_size": [1 for x in args[0]] if args[1] is None else args[1],
"num_args": len(args)-2,
"args": targs,
"args_size": args_size
"args_size": args_size
})
jdat['outputs'] = [{
"buffer_id": struct.pack("Q", tinygrad_out.lazydata.realized.cl.global_id).decode("latin_1"),
"size": tinygrad_out.lazydata.realized.cl.size,
@ -316,7 +316,7 @@ def compile(input, output_fn):
"name": k
} for k,v in inputs.items()][::-1]
print(jdat['inputs'])
print(f"saving {len([x for x in jdat['objects'] if x['needs_load']])} objects")
print("saving thneed")