tinygrad/extra/onnx.py

221 lines
10 KiB
Python

from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, DEBUG, dtypes
from typing import List, Dict
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto # onnx 1.50 uses serialized file (see onnx/onnx-ml.proto) as descriptors
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
# for onnx < 1.13
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
# module-wide memo mapping a Tensor to its numpy value (used for parameters)
numpy_cache = {}
def safe_numpy(t) -> np.ndarray:
    """Return the numpy value of `t`, memoised in `numpy_cache`.

    Anything that is not a Tensor is returned unchanged. For a Tensor the
    result of `t.numpy()` is stored in `numpy_cache` keyed by the Tensor
    itself; a zero-dimensional result is reshaped to one element so the
    returned array always has at least one dimension.

    The returned array is the cached object itself, so callers must not
    modify it in place.
    """
    if not isinstance(t, Tensor): return t
    try:
        arr = numpy_cache[t]
    except KeyError:
        if DEBUG >= 3: print("numpy cache miss", t)
        arr = t.numpy()
        if not arr.shape: arr = arr.reshape(1)
        numpy_cache[t] = arr
    assert len(arr.shape) > 0
    return arr
# op implementations are looked up by name on this module at run time
# (via `onnx_ops.tensor_methods`, `hasattr` and `getattr` in get_run_onnx)
onnx_ops = importlib.import_module('extra.onnx_ops')
# index of the graph node after which run_onnx stops and returns that node's
# outputs; the default of -1 never equals a node index, so the whole graph runs.
# NOTE(review): getenv comes from tinygrad.helpers, presumably an environment lookup.
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
    """Pre-process `onnx_model` and return a `run_onnx` closure that evaluates its graph.

    Done once, here: every graph initializer is converted to a Tensor, every
    node's attributes are parsed into a kwargs dict, and the opset version is
    read from `onnx_model.opset_import[0]`.

    Args:
        onnx_model: model whose `graph` (initializer, input, node, output) and
            `opset_import` fields are read.

    Returns:
        `run_onnx(inputs=None, debug=0)`; see its docstring.

    Raises:
        Exception: if an initializer or TENSOR attribute has a data type that
            `buffer_parse` rejects, or a node attribute has a type that
            `attribute_parse` does not handle.
    """
    def type_parse(type_proto: TypeProto):
        """Return the static shape described by `type_proto`, or () if it cannot be determined.

        sequence/optional wrappers are peeled off until a tensor type is
        reached. Each sequence level contributes a leading 1 and, in that case,
        the tensor dims are appended as 1-tuples.
        """
        ret = []
        while True:
            attr = type_proto.WhichOneof('value')
            if attr == 'tensor_type':
                dims = getattr(type_proto, attr).shape.dim
                if "dim_value" not in dims.__dir__(): return ()  # variable type, unable to determine shape
                if not ret: return tuple([x.dim_value for x in dims])
                ret.extend([(x.dim_value,) for x in dims])
                return tuple(ret)
            elif attr == 'sequence_type':
                type_proto = getattr(type_proto, attr).elem_type
                ret.append(1)
            elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
            elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
            elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
            elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
            else: raise Exception(f"unknown attr: {attr}, {type_proto}")

    def buffer_parse(inp: TensorProto) -> Tensor:
        """Convert a TensorProto to a Tensor (requires_grad=False).

        Only data_type codes 1, 10, 6, 7 and 5 are accepted (FLOAT, FLOAT16,
        INT32, INT64 and INT16 in onnx's TensorProto.DataType). The typed
        fields float_data / int64_data / int32_data are preferred, in that
        order; otherwise raw_data is decoded with the numpy dtype matching
        data_type and then cast to float32.

        Raises:
            Exception: for any other data_type.
        """
        if inp.data_type not in (1,10,6,7,5):
            raise Exception(f"bad data type {inp.name} {inp.dims} {inp.data_type}")
        # TODO: this is shared with the initializer loop below
        if len(inp.float_data) > 0:
            arr = np.array(inp.float_data, dtype=np.float32).reshape(inp.dims)
        elif len(inp.int64_data) > 0:
            arr = np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims)
        elif len(inp.int32_data) > 0:
            arr = np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims)
        else:
            # .copy() gives an owned array; np.frombuffer returns a read-only view of the bytes
            arr = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy()
        return Tensor(arr, requires_grad=False)

    def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int] | tuple[str]:
        """Convert one node attribute to a plain Python value (or a Tensor for TENSOR).

        Raises:
            Exception: for GRAPH attributes and any type not listed below.
        """
        # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
        if a.type == AttributeProto.FLOAT: return float(a.f)
        elif a.type == AttributeProto.INT: return int(a.i)
        elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
        elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t)
        elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
        elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
        elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
        elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
        else: raise Exception(f"can't parse {a.type} {a}")

    def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]):
        """Map attribute name -> parsed value for all attributes of one node."""
        return {x.name:attribute_parse(x) for x in a}

    tensors: Dict[str, Tensor] = {}

    # get weights and biases
    for inp in onnx_model.graph.initializer:
        if len(inp.raw_data) > 0:
            tensors[inp.name] = buffer_parse(inp)
        elif len(inp.float_data) > 0:
            tensors[inp.name] = Tensor(np.array(inp.float_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
        elif len(inp.int64_data) > 0:
            tensors[inp.name] = Tensor(np.array(inp.int64_data, dtype=np.int64).reshape(inp.dims), requires_grad=False)
        elif len(inp.int32_data) > 0:
            # without this branch an int32_data initializer silently became an empty tensor
            # NOTE(review): int32_data can also carry narrower types (e.g. float16 bit patterns);
            # they are read as plain int32 here, the same way buffer_parse does - confirm if needed
            tensors[inp.name] = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
        else:
            # none of the handled payload fields is populated: empty float32 tensor
            tensors[inp.name] = Tensor(np.array([], dtype=np.float32), requires_grad=False)

    # preparse the attributes
    attribute_dict = {}
    domain = ""
    for num,n in enumerate(onnx_model.graph.node):
        attribute_dict[num] = attribute_to_dict(n.attribute)
        if n.domain: domain = n.domain  # the last non-empty node domain wins

    # NOTE(review): entry 0 is assumed to be the opset of the default operator domain - confirm
    # for models that import several opsets
    onnx_model_version = onnx_model.opset_import[0].version

    def run_onnx(inputs=None, debug=0):
        """Evaluate the graph on `inputs` and return {graph output name: value}.

        Args:
            inputs: maps graph input names to a Tensor (used as is), a list
                (each element is wrapped in a Tensor) or any other value
                accepted by `Tensor(...)`. Graph inputs that are also
                initializers are not looked up. None means no inputs.
            debug: verbosity; >=1 prints one line per node, >=2 also the node
                outputs, >=3 also the node inputs. A truthy DEBUGONNX value
                from getenv takes precedence.

        Raises:
            Exception: a graph input is missing from `inputs`, a node's op_type
                is unsupported, or a versioned op has no implementation for
                the model's opset version.
            AssertionError: an input does not match its static shape, or an op
                returned fewer values than the node has outputs.
        """
        if inputs is None: inputs = {}
        debug = getenv("DEBUGONNX") or debug
        input_tensors: Dict[str,Tensor] = {}
        intermediate_tensors: Dict[str,Tensor] = {}
        output_tensor_names = [x.name for x in onnx_model.graph.output]

        # get inputs
        for inp in onnx_model.graph.input:
            if inp.name in tensors: continue
            shape = type_parse(inp.type)
            if inp.name not in inputs:
                raise Exception(f"no data for {inp.name} with shape {shape}")
            if isinstance(inputs[inp.name], Tensor):
                input_tensors[inp.name] = inputs[inp.name]
            elif isinstance(inputs[inp.name], list):
                input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
            elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
                # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
                input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True)
            else:
                input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
            if shape: # only checked when the declared type is not variable
                input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
                assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"

        def fetch_tensor(x: str):
            """Resolve a node input name: initializers first, then computed values, then graph inputs.

            The empty name (an omitted optional input) resolves to None.
            """
            if x in tensors: return tensors[x]
            if x in intermediate_tensors: return intermediate_tensors[x]
            if x != str(): return input_tensors[x]
            return None

        for num,n in enumerate(onnx_model.graph.node):
            inp: List[Tensor] = []
            if debug >= 3: print("inputs:")
            for x in n.input:
                t = fetch_tensor(x)
                if debug >= 3: print(f"\t{x} - {t}")
                inp.append(t)
            opt: Dict = attribute_dict[num]
            if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")

            # NOTE some ops live here because they require access to some local variables
            if n.op_type in onnx_ops.tensor_methods:
                ret = getattr(Tensor, n.op_type.lower())(*inp, **opt)
            elif n.op_type == "Split":
                # have to use n.output for cases when num_outputs is absent
                axis = opt.get("axis", 0)
                if len(inp) > 1 and inp[1] is not None:
                    split = [int(x) for x in safe_numpy(inp[1])]
                elif "split" in opt:
                    # opsets before 13 pass the chunk sizes as an attribute instead of an input
                    split = [int(x) for x in opt["split"]]
                else:
                    # even split; the remainder is spread over the leading chunks
                    dim, n_out = inp[0].shape[axis], len(n.output)
                    split = [dim // n_out + (1 if j < dim % n_out else 0) for j in range(n_out)]
                offset, ret = 0, []
                arg = [(0,x) for x in inp[0].shape]
                for s in split:
                    arg[axis] = (offset, offset+s)
                    ret.append(inp[0].shrink(arg=tuple(arg)))
                    offset += s
                ret = tuple(ret)
            # need to check onnx_model_version
            elif n.op_type == "Slice":
                if onnx_model_version < 10:
                    # before opset 10 everything is an attribute and there are no steps
                    axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
                else:
                    starts, ends = inp[1:3]
                    # optional inputs may be missing entirely or be None (empty input name)
                    has_axes = len(inp) > 3 and inp[3] is not None
                    has_steps = len(inp) > 4 and inp[4] is not None
                    axes = safe_numpy(inp[3]).tolist() if has_axes else list(range(inp[0].ndim))
                    # .tolist() copies: steps is modified below and must not alias the numpy_cache entry
                    steps = safe_numpy(inp[4]).tolist() if has_steps else [1]*inp[0].ndim
                    # these are fresh temporaries on every call, so they are read directly rather than
                    # through safe_numpy, which would add a never-reused entry to numpy_cache each run
                    starts = np.atleast_1d(starts.ceil().cast(dtypes.int32).numpy()).tolist()
                    ends = np.atleast_1d(ends.ceil().cast(dtypes.int32).numpy()).tolist()
                arg = [(0,x,1) for x in inp[0].shape]
                for i, axis in enumerate(axes):
                    axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
                    dim = inp[0].shape[axis]
                    # negative indices count from the end of the axis
                    if starts[i] < 0: starts[i] += dim
                    if ends[i] < 0: ends[i] += dim
                    # clamp both bounds into [0, dim]
                    starts[i], ends[i] = max(0, min(starts[i], dim)), max(0, min(ends[i], dim))
                    if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
                    arg[axis] = (starts[i], ends[i], steps[i])
                new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
                if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
                else: ret = inp[0].__getitem__(tuple([slice(s,e,st) for s,e,st in arg]))
            # need to call backward on intermediate_tensors
            elif n.op_type == "Gradient":
                assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
                y = opt["y"]
                intermediate_tensors[y].backward()
                ret = tuple([t.grad for t in inp])
            # onnx_ops.py
            elif hasattr(onnx_ops, n.op_type):
                fxn = getattr(onnx_ops, n.op_type)
                if isinstance(fxn, dict):
                    # versioned op: use the newest implementation not newer than the model's opset.
                    # Fail loudly when there is none instead of reusing a function left over
                    # from a previous node.
                    eligible = [k for k in fxn if k <= onnx_model_version]
                    if not eligible:
                        raise Exception(f"op_type {n.op_type} has no implementation for opset version {onnx_model_version}")
                    real_fxn = fxn[max(eligible)]
                else:
                    real_fxn = fxn
                ret = real_fxn(*inp, **opt)
            else:
                print("UNSUPPORTED", n.op_type, n.input, n.output)
                raise Exception(f"op_type {n.op_type} not supported")

            if not isinstance(ret, tuple): ret = (ret, )
            assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
            if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
            if debug >= 2: print("outputs:")
            for i, out_name in enumerate(n.output):
                if debug >= 2: print(f"\t{out_name} - {ret[i]}")
                intermediate_tensors[out_name] = ret[i]
            if num == ONNXLIMIT:
                # stop early and report this node's outputs instead of the graph outputs
                output_tensor_names = n.output
                break

        return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
    return run_onnx