mirror of https://github.com/commaai/tinygrad.git
ngrl stuff
This commit is contained in:
parent
392e57aea7
commit
82ca9c6666
|
@ -88,12 +88,13 @@ def get_run_onnx(onnx_model):
|
|||
elif n.op_type == "Constant": ret = opt['value']
|
||||
elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) for x in inp[1].numpy()])
|
||||
elif n.op_type == "Gather":
|
||||
# TODO: is this correct? seems to work for simple gather ops
|
||||
axis = opt['axis']
|
||||
shape = list(inp[0].shape)
|
||||
assert axis==0, 'untested for other values'
|
||||
indices = [shape[axis]+int(x) if x<0 else int(x) for x in inp[1].numpy()]
|
||||
shape[axis] = 1
|
||||
ret = inp[0][indices[0]].reshape(shape).cat(*[inp[0][x].reshape(shape) for x in indices[1:]], dim=axis)
|
||||
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(shape)] for i in indices]
|
||||
ret = inp[0].slice(arg=args[0]).cat(*[inp[0].slice(arg=arg) for arg in args[1:]], dim=axis)
|
||||
ret = ret.reshape([s for i,s in enumerate(shape) if i != axis]) if len(indices) == 1 else ret # squeeze if needed
|
||||
elif n.op_type == "BatchNormalization":
|
||||
invstd = inp[4].add(opt.get('epsilon', 1e-5))**-0.5
|
||||
ret = batch_normalize(inp[0], inp[1], inp[2], inp[3], invstd)
|
||||
|
@ -150,7 +151,7 @@ def get_run_onnx(onnx_model):
|
|||
axis, starts, ends = int(axes.numpy()[0]), int(starts.numpy()[0]), int(ends.numpy()[0])
|
||||
ends = min(ends, inp[0].shape[axis])
|
||||
starts = starts + inp[0].shape[axis] if starts < 0 else starts
|
||||
arg[0] = (starts, ends)
|
||||
arg[axis] = (starts, ends)
|
||||
ret = inp[0].slice(arg=arg)
|
||||
else:
|
||||
print("UNSUPPORTED", n.op_type, n.input, n.output)
|
||||
|
@ -162,4 +163,3 @@ def get_run_onnx(onnx_model):
|
|||
|
||||
return {outp.name:intermediate_tensors[outp.name] for outp in onnx_model.graph.output}
|
||||
return run_onnx
|
||||
|
||||
|
|
|
@ -32,9 +32,9 @@ def get_random_input_tensors():
|
|||
np_inputs = {
|
||||
"input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
|
||||
"big_input_imgs": np.random.randn(*(1, 12, 128, 256))*256,
|
||||
"desire": np.zeros((1, 8)),
|
||||
"desire": np.zeros((1,100, 8)),
|
||||
"traffic_convention": np.array([[1., 0.]]),
|
||||
"initial_state": np.random.randn(*(1, 512))
|
||||
"features_buffer": np.random.randn(*(1, 99, 2048))
|
||||
#"initial_state": np.zeros((1, 768))
|
||||
}
|
||||
|
||||
|
@ -300,9 +300,9 @@ def compile(input, output_fn):
|
|||
"local_work_size": [1 for x in args[0]] if args[1] is None else args[1],
|
||||
"num_args": len(args)-2,
|
||||
"args": targs,
|
||||
"args_size": args_size
|
||||
"args_size": args_size
|
||||
})
|
||||
|
||||
|
||||
jdat['outputs'] = [{
|
||||
"buffer_id": struct.pack("Q", tinygrad_out.lazydata.realized.cl.global_id).decode("latin_1"),
|
||||
"size": tinygrad_out.lazydata.realized.cl.size,
|
||||
|
@ -316,7 +316,7 @@ def compile(input, output_fn):
|
|||
"name": k
|
||||
} for k,v in inputs.items()][::-1]
|
||||
print(jdat['inputs'])
|
||||
|
||||
|
||||
print(f"saving {len([x for x in jdat['objects'] if x['needs_load']])} objects")
|
||||
|
||||
print("saving thneed")
|
||||
|
|
Loading…
Reference in New Issue