# tinygrad/extra/onnx.py
#
# 220 lines
# 10 KiB
# Python

from __future__ import annotations
from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
import importlib
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import getenv, DEBUG, CI, OSX
from typing import List, Dict
from onnx import AttributeProto, ModelProto, TensorProto, TypeProto
try:
from onnx.helper import tensor_dtype_to_np_dtype
except ImportError:
# for onnx < 1.13
from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
tensor_dtype_to_np_dtype = lambda x: TENSOR_TYPE_TO_NP_TYPE[x]
# module-level memo for parameters: maps a Tensor to the array returned by its .numpy()
numpy_cache = {}
def safe_numpy(t) -> np.ndarray:
  """Return the numpy form of ``t``, memoized in ``numpy_cache``.

  Anything that is not a ``Tensor`` is handed back untouched. For a ``Tensor`` already present in the
  cache the stored array is returned, so ``.numpy()`` is not called again for it.
  """
  if isinstance(t, Tensor):
    if t in numpy_cache:
      return numpy_cache[t]
    if DEBUG >= 3: print("numpy cache miss", t)
    realized = numpy_cache[t] = t.numpy()
    return realized
  return t
# copied from helpers.py
def is_dtype_supported(dtype, device: str = Device.DEFAULT):
  """Tell whether ``dtype`` is usable on ``device`` (``Device.DEFAULT`` unless given).

  The checks run in a fixed order: bfloat16 is always rejected, the WEBGPU/WEBGL devices accept only
  float/int32/uint32, half is rejected on GPU/LLVM/CUDA when ``CI`` is set, and float64 is rejected on
  METAL and on GPU when ``OSX`` is set. Every other combination is accepted.
  """
  if dtype == dtypes.bfloat16:
    return False
  if device in ("WEBGPU", "WEBGL"):
    return dtype in (dtypes.float, dtypes.int32, dtypes.uint32)
  if dtype == dtypes.half:
    rejected_in_ci = CI and device in {"GPU", "LLVM", "CUDA"}
    return not rejected_in_ci
  if dtype == dtypes.float64:
    if device == "METAL":
      return False
    return not (OSX and device == "GPU")
  return True
# src: onnx/mapping.py
# Lookup from the integer TensorProto.data_type code to the tinygrad dtype used to hold that data.
# not supported: STRING = 8 COMPLEX64 = 14, COMPLEX128 = 15
# NOTE: 17, 18, 19, 20 are float8, 10 is half; all of them are stored as dtypes.float here
DTYPE_MAP = {
  1: dtypes.float,
  2: dtypes.uint8,
  3: dtypes.int8,
  4: dtypes.uint16,
  5: dtypes.int16,
  6: dtypes.int32,
  7: dtypes.int64,
  9: dtypes.bool,
  10: dtypes.float,
  11: dtypes.double,
  12: dtypes.uint32,
  13: dtypes.uint64,
  16: dtypes.bfloat16,
  **dict.fromkeys((17, 18, 19, 20), dtypes.float),
}
# TODO: fix buffer_parse to use this and fix get_weight_and_biases to only use buffer_parse
# operator implementations, loaded by module path; run_onnx looks ops up on this module by their ONNX op_type name
onnx_ops = importlib.import_module('extra.onnx_ops')
# node index at which run_onnx stops early and returns that node's outputs instead of the graph outputs.
# getenv is given -1 as its second argument (presumably the fallback when ONNXLIMIT is unset — confirm in
# tinygrad.helpers); -1 never equals an enumerate() index, so with that value the whole graph runs.
ONNXLIMIT = getenv("ONNXLIMIT", -1)
def get_run_onnx(onnx_model: ModelProto):
  """Build a callable that executes ``onnx_model`` node by node on tinygrad Tensors.

  The graph initializers and every node's attributes are parsed once, when this function is called.
  The returned ``run_onnx(inputs, debug)`` closure binds the caller's inputs, walks
  ``onnx_model.graph.node`` in order and returns a dict mapping output names to the computed values.

  Raises (at build time) whatever ``buffer_parse`` / ``attribute_parse`` raise for unsupported
  initializer data types or attribute kinds, and ``IndexError`` when the model has no ``opset_import``.
  """
  def type_parse(type_proto: TypeProto):
    """Return the shape declared by ``type_proto`` as a tuple, or ``()`` when it cannot be determined.

    Sequence and optional wrappers are unwrapped in a loop; each sequence level contributes a leading
    ``1`` and the element dims are then appended as 1-tuples.
    """
    ret = []
    while True:
      attr = type_proto.WhichOneof('value')
      if attr == 'tensor_type':
        # NOTE(review): this looks at the attribute names of the repeated `dim` container itself — confirm it separates static from variable shapes as intended
        if "dim_value" not in type_proto.tensor_type.shape.dim.__dir__(): return () # variable type, unable to determine shape
        elif not ret:
          return tuple([x.dim_value for x in type_proto.tensor_type.shape.dim])
        else:
          ret.extend([(x.dim_value,) for x in type_proto.tensor_type.shape.dim])
          return tuple(ret)
      elif attr == 'sequence_type':
        type_proto = getattr(type_proto, attr).elem_type
        ret.append(1)
      elif attr == 'map_type': raise NotImplementedError(f"map_type is not implemented: {type_proto}")
      elif attr == 'opaque_type': raise NotImplementedError(f"opaque_type is not implemented: {type_proto}")
      elif attr == 'sparse_tensor_type': raise NotImplementedError(f"sparse_tensor_type is not implemented: {type_proto}")
      elif attr == 'optional_type': type_proto = getattr(type_proto, attr).elem_type
      else: raise Exception(f"unknown attr: {attr}, {type_proto}")

  def buffer_parse(inp: TensorProto) -> Tensor:
    """Convert a TensorProto into a Tensor with ``requires_grad=False``.

    Data is taken from ``float_data`` / ``int32_data`` / ``int64_data`` (first non-empty one), else from
    ``raw_data``; when none holds data a ``Tensor(None)`` is returned. The target dtype comes from
    ``DTYPE_MAP`` and falls back to float32 when ``is_dtype_supported`` rejects it.
    Raises ``Exception`` for data types 8, 14 and 15.
    """
    if inp.data_type in (8,14,15): raise Exception(f"data type not supported {inp.name} {inp.dims} {inp.data_type}")
    dtype = DTYPE_MAP[inp.data_type] if is_dtype_supported(DTYPE_MAP[inp.data_type]) else dtypes.float32
    if dat := list(inp.float_data) or list(inp.int32_data) or list(inp.int64_data):
      return Tensor(dat, dtype=dtype, requires_grad=False).reshape(tuple(inp.dims))
    if len(inp.raw_data) > 0:
      # decode with the dtype stored in the proto, then convert to the (possibly substituted) target dtype
      raw = np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type))
      return Tensor(raw.astype(dtype.np).copy(), requires_grad=False).reshape(tuple(inp.dims))
    return Tensor(None, requires_grad=False)

  def attribute_parse(a: AttributeProto) -> float | int | str | Tensor | tuple[float] | tuple[int] | tuple[str]:
    """Convert one AttributeProto into a plain Python value or a Tensor.

    Raises ``Exception`` for GRAPH attributes and for any attribute type not listed below.
    """
    # TODO: this is not complete, see onnx/onnx_ml_pb2.pyi for a complete list
    if a.type == AttributeProto.FLOAT: return float(a.f)
    elif a.type == AttributeProto.INT: return int(a.i)
    elif a.type == AttributeProto.STRING: return a.s.decode("utf-8")
    elif a.type == AttributeProto.TENSOR: return buffer_parse(a.t) # TENSOR
    elif a.type == AttributeProto.FLOATS: return tuple(float(x) for x in a.floats)
    elif a.type == AttributeProto.INTS: return tuple(int(x) for x in a.ints)
    elif a.type == AttributeProto.STRINGS: return tuple(x.decode("utf-8") for x in a.strings)
    elif a.type == AttributeProto.GRAPH: raise Exception(f"graph not implemented: {a.g}\n likely an OP requiring control flow")
    else: raise Exception(f"can't parse {a.type} {a}")

  def attribute_to_dict(a: RepeatedCompositeFieldContainer[AttributeProto]): return {x.name:attribute_parse(x) for x in a}

  tensors: Dict[str, Tensor] = {}

  # get weights and biases
  for inp in onnx_model.graph.initializer:
    tensors[inp.name] = buffer_parse(inp)

  # preparse the attributes
  attribute_dict = {}
  domain = ""
  for num,n in enumerate(onnx_model.graph.node):
    attribute_dict[num] = attribute_to_dict(n.attribute)
    if n.domain: domain = n.domain

  # opset_import has one entry per operator-set domain and the standard operators live in the default
  # domain ("" or "ai.onnx"), so use that entry's version; fall back to the first entry when there is none
  onnx_model_version = next((o.version for o in onnx_model.opset_import if o.domain in ("", "ai.onnx")), onnx_model.opset_import[0].version)

  def run_onnx(inputs=None, debug=0):
    """Run the graph and return ``{output_name: value}``.

    inputs: dict keyed by graph input name; each value is a Tensor, a list (one Tensor is built per
      element) or anything accepted by the ``Tensor`` constructor. ``None`` means no inputs.
    debug: verbosity level (>=1 ops, >=2 outputs, >=3 inputs); a truthy ``getenv("DEBUGONNX")`` wins.

    Raises ``Exception`` when a graph input has no data or an op is unsupported, ``NotImplementedError``
    when a versioned op has no implementation for the model's opset, and ``AssertionError`` on an input
    shape mismatch or when an op returns fewer values than the node declares outputs.
    """
    if inputs is None: inputs = {}
    debug = getenv("DEBUGONNX") or debug
    input_tensors: Dict[str,Tensor] = {}
    intermediate_tensors: Dict[str,Tensor] = {}
    output_tensor_names = [x.name for x in onnx_model.graph.output]

    # get inputs
    for inp in onnx_model.graph.input:
      if inp.name in tensors: continue
      shape = type_parse(inp.type)
      if inp.name in inputs:
        if isinstance(inputs[inp.name], Tensor):
          input_tensors[inp.name] = inputs[inp.name]
        elif isinstance(inputs[inp.name], list):
          input_tensors[inp.name] = [Tensor(i, requires_grad=False) for i in inputs[inp.name]]
        elif domain == "ai.onnx.preview.training": # not sure if in real use the domain is "ai.onnx.preview.training"
          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=True) # TODO there isn't a good way to parse which inp requires_grad, some are manually turned off in optimizer ops
        else:
          input_tensors[inp.name] = Tensor(inputs[inp.name], requires_grad=False)
        if shape: # if only input_tensor is not variable type
          input_shape = input_tensors[inp.name].shape if isinstance(input_tensors[inp.name], Tensor) else (1, *[i.shape for i in input_tensors[inp.name]])
          assert input_shape == shape, f"wrong shape for input {inp.name}, {input_shape} isn't {shape}"
      else:
        raise Exception(f"no data for {inp.name} with shape {shape}")

    def fetch_tensor(x: str):
      """Look ``x`` up in initializers, then intermediates, then inputs; an empty name yields ``None``."""
      if x in tensors: return tensors[x]
      if x in intermediate_tensors: return intermediate_tensors[x]
      if x != "": return input_tensors[x]
      return None

    for num,n in enumerate(onnx_model.graph.node):
      inp: List[Tensor] = []
      if debug >= 3: print("inputs:")
      for x in n.input:
        t = fetch_tensor(x)
        if debug >= 3: print(f"\t{x} - {t}")
        inp.append(t)
      opt: Dict = attribute_dict[num]
      if debug >= 1: print(f"{num}: op {n.op_type} shape {[x.shape if isinstance(x, Tensor) else x for x in inp]} opt {opt}")

      # NOTE some ops live here because they require access to some local variables
      # have to use n.output for cases when num_outputs is absent
      if n.op_type in onnx_ops.tensor_methods:
        ret = getattr(Tensor, n.op_type.lower())(*inp, **opt)
      elif n.op_type == "Split":
        axis = opt.get("axis", 0)
        # an omitted optional input is fetched as None (empty input name), treat it like a missing one
        split_sizes = inp[1] if len(inp) > 1 else None
        if split_sizes is not None:
          split = [int(x) for x in safe_numpy(split_sizes)]
        elif "split" in opt:
          # opsets before 13 pass the sizes as an attribute instead of an input
          split = [int(x) for x in opt["split"]]
        else:
          # no sizes given: split evenly over the declared outputs, earlier outputs take the remainder
          split = [inp[0].shape[axis] // len(n.output)] * len(n.output)
          for i in range(inp[0].shape[axis] % len(n.output)):
            split[i] += 1
        i, ret = 0, []
        arg = [None] * inp[0].ndim
        for s in split:
          arg[axis] = (i,i+s)
          ret.append(inp[0].shrink(arg=tuple(arg)))
          i = i+s
        ret = tuple(ret)

      # need to check onnx_model_version
      elif n.op_type == "Slice":
        if onnx_model_version < 10:
          # before opset 10 starts/ends/axes are attributes and there are no steps
          axes, ends, starts, steps = list(opt.get("axes", range(inp[0].ndim))), list(opt["ends"]), list(opt["starts"]), [1]*inp[0].ndim
        else:
          starts, ends = inp[1:3]
          # axes and steps are optional inputs: either missing from the list or fetched as None
          axes_in = inp[3] if len(inp) > 3 else None
          steps_in = inp[4] if len(inp) > 4 else None
          axes = safe_numpy(Tensor.arange(inp[0].ndim) if axes_in is None else axes_in.cast(dtypes.int32)).tolist()
          steps = [1]*inp[0].ndim if steps_in is None else safe_numpy(steps_in.cast(dtypes.int32)).tolist()
          starts, ends = safe_numpy(starts.ceil().cast(dtypes.int32)).tolist(), safe_numpy(ends.ceil().cast(dtypes.int32)).tolist()
        arg = [(0,x,1) for x in inp[0].shape]
        for i, axis in enumerate(axes):
          axis = int(axis) + inp[0].ndim if axis < 0 else int(axis)
          # negative bounds count from the end of the axis, then both are clamped into [0, dim]
          if starts[i] < 0: starts[i] += inp[0].shape[axis]
          if ends[i] < 0: ends[i] += inp[0].shape[axis]
          starts[i], ends[i] = max(0, min(starts[i], inp[0].shape[axis])), max(0, min(ends[i], inp[0].shape[axis]))
          if starts[i] > ends[i] and steps[i] >= 0: steps[i] = -steps[i]
          arg[axis] = (starts[i], ends[i], steps[i])
        new_shape = tuple((s, e) if st > 0 else (e+1, s+1) for s, e, st in arg)
        if any(s==e for s,e in new_shape): ret = inp[0].shrink(new_shape)
        else: ret = inp[0][tuple([slice(s,e,st) for s,e,st in arg])]

      # need to call backward on intermediate_tensors
      elif n.op_type == "Gradient":
        assert len(opt["xs"]) == len(inp), f"len(opt['xs']):{len(opt['xs'])}, len(inp):{len(inp)} output and input has to match"
        y = opt["y"]
        intermediate_tensors[y].backward()
        ret = tuple([t.grad for t in inp])

      # onnx_ops.py
      elif hasattr(onnx_ops, n.op_type):
        fxn = getattr(onnx_ops, n.op_type)
        if isinstance(fxn, dict):
          # versioned op: take the newest implementation that is not newer than the model's opset.
          # Failing here keeps a function picked for an earlier node from being reused by accident.
          usable = [k for k in fxn if k <= onnx_model_version]
          if not usable:
            raise NotImplementedError(f"op_type {n.op_type} has no implementation for opset version {onnx_model_version}, available versions: {sorted(fxn)}")
          real_fxn = fxn[max(usable)]
        else:
          real_fxn = fxn
        ret = real_fxn(*inp, **opt)
      else:
        print("UNSUPPORTED", n.op_type, n.input, n.output)
        raise Exception(f"op_type {n.op_type} not supported")

      if not isinstance(ret, tuple): ret = (ret, )
      assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
      if debug >= 2: print([x.shape if isinstance(x, Tensor) else None for x in ret])
      if debug >= 2: print("outputs:")
      for i, out_name in enumerate(n.output):
        if debug >= 2: print(f"\t{out_name} - {ret[i]}")
        intermediate_tensors[out_name] = ret[i]
      if num == ONNXLIMIT:
        # stop early and report this node's outputs instead of the graph outputs
        output_tensor_names = n.output
        break

    return {outp:intermediate_tensors[outp] for outp in output_tensor_names}
  return run_onnx