ods777/tinygrad_repo/extra/onnx.py

210 lines
10 KiB
Python

from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import prod, getenv, DEBUG, dtypes
from typing import List,Dict
from onnx.onnx_pb import AttributeProto, ModelProto, TensorProto, TypeProto
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
# for onnx < 1.13
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
# global numpy cache for parameters
numpy_cache = {}
def safe_numpy(t) -> np.ndarray:
if not isinstance(t, Tensor): return t
global numpy_cache
if t not in numpy_cache:
if DEBUG >= 3: print("numpy cache miss", t)
tmp = t.numpy()
numpy_cache[t] = tmp if len(tmp.shape) else tmp.reshape(1)
assert len(numpy_cache[t].shape) > 0
return numpy_cache[t]
onnx_ops = importlib.import_module('extra.onnx_ops')
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
def type_parse(type_proto: TypeProto):
ret = []
while True:
attr = type_proto.WhichOneof('value')
if attr == 'tensor_type':
if "dim_value" not in getattr(type_proto, attr).shape.dim.__dir__(): return () # variable type, unable to determine shape
elif not ret:
return tuple([x.dim_value for x in getattr(type_proto, attr).shape.dim])
else:
ret.extend([(x.dim_value,) for x in getattr(type_proto, attr).shape.dim])
return tuple(ret)
elif attr == 'sequence_type':
type_proto = getattr(type_proto, attr).elem_type
ret.append(1)
elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
else: raise Exception(f"unknown attr: {attr}, {type_proto}")
def buffer_parse(inp: TensorProto) -> Tensor:
if inp.data_type in (1,10,6,7):
# TODO: this is shared with below
if len(inp.float_data) > 0:
ret = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
ret = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
elif len(inp.int32_data) > 0:
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
else:
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
else:
raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
return ret
def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int]:
# TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
if a.type == AttributeProto.FLOAT: return float(a.f)
elif a.type == AttributeProto.INT: return int(a.i)
elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}")
else: raise Exception(f"can't parse {a.type} {a}")
def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}
tensors: Dict[str, Tensor] = {}
# get weights and biases
for inp in onnx_model.graph.initializer:
if len(inp.raw_data) > 0:
tensors[inp.name] = buffer_parse(inp)
elif len(inp.float_data) > 0:
tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int64_data) > 0:
tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
elif len(inp.raw_data) == 0:
tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)
else:
print(inp.name, inp.dims, inp.data_type, len(inp.raw_data))
print(inp)
raise Exception("no data")
# preparse the attributes
attribute_dict = {}
domain = ""
for num,n in enumerate(onnx_model.graph.node):
attribute_dict[num] = attribute_to_dict(n.attribute)
if n.domain: domain = n.domain
onnx_model_version = onnx_model.opset_import[0].version
def run_onnx(inputs={}, debug=0):
debug = getenv("DEBUGONNX") or debug
input_tensors: Dict[str,Tensor] = {}
intermediate_tensors: Dict[str,Tensor] = {}
output_tensor_names = [x.name for x in onnx_model.graph.output]
# get inputs
for inp in onnx_model.graph.input:
if inp.name in tensors: continue
shape = type_parse(inp.type)
if inp.name in inputs:
if isinstance(inputs[inp.name], Tensor):
input_tensors[inp.name] = inputs[inp.name]
elif isinstance(inputs[inp.name], list):
input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
else:
input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
if shape: # if only input_tensor is not variable type
input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
else:
raise Exception(f"no data for {inp.name} with shape {shape}")
def fetch_tensor(x: str):
if x in tensors: return tensors[x]
if x in intermediate_tensors: return intermediate_tensors[x]
if x != str(): return input_tensors[x]
return None
for num,n in enumerate(onnx_model.graph.node):
inp: List[Tensor] = []
if debug >= 3: print("inputs:")
for x in n.input:
t = fetch_tensor(x)
if debug >= 3: print(f"\t{x} - {t}")
inp.append(t)
opt: Dict = attribute_dict[num]
if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")
# some ops live here because they require some local variables
if n.op_type == "Split": # have to use n.output for cases when num_outputs is absent
axis = opt.get("axis", 0)
split = None if len(inp) == 1 else [int(x) for x in safe_numpy(inp[1])]
if split is None:
split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
for i in range(inp[0].shape[axis] % len(n.output)):
split[i] += 1
i, ret = 0, []
arg = [(0,x) for x in inp[0].shape]
for s in split:
arg[axis] = (i,i+s)
ret.append(inp[0].shrink(arg=tuple(arg)))
i = i+s
ret = tuple(ret)
elif n.op_type == "Slice": # need to check onnx_model_version
if onnx_model_version < 10:
axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
else:
starts, ends = inp[1:3]
axes = safe_numpy(Tensor.arange(inp[0].ndim, dtype=dtypes.int32) if len(inp) <= 3 else inp[3]).tolist()
steps = safe_numpy(inp[4]) if len(inp) > 4 else [1]*inp[0].ndim
starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
arg = [(0,x,1) for x in inp[0].shape]
for i, axis in enumerate(axes):
axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
starts[i], ends[i] = starts[i] + inp[0].shape[axis] if starts[i] < 0 else starts[i], ends[i] + inp[0].shape[axis] if ends[i] < 0 else ends[i]
starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
arg[axis] = (starts[i], ends[i], steps[i])
new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg]))
elif n.op_type == "Gradient": # need to call backward on intermediate_tensors
assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
y = opt["y"]
intermediate_tensors[y].backward()
ret = tuple([t.grad for t in inp])
elif hasattr(onnx_ops, n.op_type):
fxn = getattr(onnx_ops, n.op_type)
if isinstance(fxn, dict):
for k in sorted(fxn.keys()):
if k <= onnx_model_version:
real_fxn = fxn[k]
else:
real_fxn = fxn
ret = real_fxn(*inp, **opt)
else:
print("UNSUPPORTED", n.op_type, n.input, n.output)
raise Exception(f"op_type {n.op_type} not supported")
if not isinstance(ret, tuple): ret = (ret, )
assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
if debug >= 2: print("outputs:")
for i in range(len(n.output)):
if debug >= 2: print(f"\t{n.output[i]} - {ret[i]}")
intermediate_tensors[n.output[i]] = ret[i]
if num == ONNXLIMIT:
output_tensor_names = n.output
break
return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
return run_onnx