minor cleanups of onnx_ops (#3116)

chenyu 2024-01-14 02:15:24 -05:00 committed by GitHub
parent fb3f8f7597
commit 152ef7fc79
1 changed file with 10 additions and 7 deletions

@@ -332,7 +332,7 @@ def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
   if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode)
   if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
   rng = np.random.RandomState(seed)
-  ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
+  ratio = ratio.item() if isinstance(ratio, Tensor) else ratio
   mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
   return data * mask * (1/(1.0 - ratio)), mask
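
Both sides of this hunk pull a Python scalar out of a 0-d ratio Tensor; .item() simply replaces the manual lazydata.realize().toCPU()[0] chain. A minimal sketch of the pattern, with a made-up ratio of 0.3 and a made-up (2, 3) data shape:

import numpy as np
from tinygrad.tensor import Tensor

ratio = Tensor(0.3)                                           # ONNX hands the dropout ratio over as a scalar Tensor
ratio = ratio.item() if isinstance(ratio, Tensor) else ratio  # -> plain Python float
rng = np.random.RandomState(0)
mask = rng.random((2, 3)) >= ratio                            # keep-mask with the same shape as the data
print(ratio, mask)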
@@ -356,7 +356,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind
     weight = (mask * weight).sum(axis=-1)
   if ignore_index is not None:
     cond = target == ignore_index
-    weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
+    weight = cond.where(0, weight) if weight is not None else cond.where(0, 1)
   mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2))
   loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight)
   if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
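
The simplification here leans on Tensor.where broadcasting plain Python scalars, so building a Tensor.zeros of the target shape is unnecessary. A small sketch of the ignore-mask, assuming a made-up target and an ignore_index of 2:

from tinygrad.tensor import Tensor

target = Tensor([[0, 2, 1], [2, 0, 0]])
cond = target == 2             # True where the label equals ignore_index
weight = cond.where(0, 1)      # scalar branches broadcast to target's shape
print(weight.numpy())          # 0 at ignored positions, 1 everywhere else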
@@ -368,14 +368,17 @@ def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore
   if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
   mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
   y = scores.log_softmax(axis=1)
-  if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
-  loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
+  loss = (mask * -y).sum(1)
+  if weights is not None:
+    weights = weights[labels, ...]
+    loss = loss * weights
   if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
   elif reduction == "sum": loss = loss.sum()
   return loss, y
 def ArrayFeatureExtractor(x: Tensor, indices: Tensor):
-  return x.__getitem__(tuple([slice(None) if i != (x.ndim-1) else indices for i in range(x.ndim)]))
+  return x[tuple([slice(None) if i != (x.ndim-1) else indices for i in range(x.ndim)])]
 def Gather(x: Tensor, indices: Tensor, axis=0):
   if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
     x_sh = list(x.shape)
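
Both changes in this hunk swap an explicit __getitem__(tuple(...)) call for the equivalent bracket indexing, and the per-element class weight is now gathered with weights[labels, ...]. A sketch of that lookup with made-up class weights (C=3) and labels:

from tinygrad.tensor import Tensor

weights = Tensor([0.1, 1.0, 2.0])   # one weight per class
labels = Tensor([[2, 0], [1, 1]])   # integer class labels
per_elem = weights[labels, ...]     # weight of each element's class
print(per_elem.numpy())             # [[2.0, 0.1], [1.0, 1.0]]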
@@ -385,7 +388,7 @@ def Gather(x: Tensor, indices: Tensor, axis=0):
     args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]
     return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
   # NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
-  return x.__getitem__(tuple([slice(None) if i != axis else indices for i in range(x.ndim)]))
+  return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
 def GatherElements(x: Tensor, indices: Tensor, axis):
   indices = (indices < 0).where(x.shape[axis], 0) + indices
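
The fast Gather path builds an index tuple that is slice(None) on every axis except the gather axis, where the indices Tensor goes, which is what the bracket indexing now spells out directly. A sketch with made-up data and axis=1:

from tinygrad.tensor import Tensor

x = Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
indices = Tensor([3, 0])
axis = 1
out = x[tuple(slice(None) if i != axis else indices for i in range(x.ndim))]
print(out.numpy())   # [[3, 0], [7, 4]], i.e. x[:, [3, 0]]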
@@ -541,7 +544,7 @@ def Compress(inp: Tensor, condition: Tensor, axis=None):
   con_np = safe_numpy(condition)
   con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
-  return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))
+  return inp[tuple([slice(None) if i != axis else con for i in range(inp.ndim)])]
 def EyeLike(x: Tensor, dtype=None, k=0):
   if dtype is None: dtype = x.dtype
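
Compress still routes the boolean condition through NumPy to get integer positions, since the comment notes Tensor has no boolean indexing; only the final lookup switched to bracket indexing. A sketch with a made-up input and condition along axis=1:

import numpy as np
from tinygrad.tensor import Tensor

inp = Tensor([[1, 2, 3], [4, 5, 6]])
condition = np.array([True, False, True])
axis = 1
con = Tensor(np.arange(condition.shape[0])[condition])   # boolean mask -> integer positions [0, 2]
out = inp[tuple(slice(None) if i != axis else con for i in range(inp.ndim))]
print(out.numpy())   # [[1, 3], [4, 6]]: columns where condition is True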