mirror of https://github.com/commaai/tinygrad.git
minor cleanups of onnx_ops (#3116)
This commit is contained in:
parent
fb3f8f7597
commit
152ef7fc79
|
@ -332,7 +332,7 @@ def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None):
|
|||
if isinstance(training_mode, Tensor) and not training_mode.shape: training_mode = safe_numpy(training_mode)
|
||||
if not training_mode: return data, Tensor.ones(*data.shape, dtype=dtypes.bool) # if mask is requested as output it will contain all True's.
|
||||
rng = np.random.RandomState(seed)
|
||||
ratio = ratio.lazydata.realize().toCPU()[0] if isinstance(ratio, Tensor) else ratio
|
||||
ratio = ratio.item() if isinstance(ratio, Tensor) else ratio
|
||||
mask = Tensor((rng.random(data.shape) >= ratio), requires_grad=False, device=data.device)
|
||||
return data * mask * (1/(1.0 - ratio)), mask
|
||||
|
||||
|
@ -356,7 +356,7 @@ def NegativeLogLikelihoodLoss(x: Tensor, target: Tensor, weight=None, ignore_ind
|
|||
weight = (mask * weight).sum(axis=-1)
|
||||
if ignore_index is not None:
|
||||
cond = target == ignore_index
|
||||
weight = cond.where(0, weight) if weight is not None else cond.where(Tensor.zeros(*target.shape), 1)
|
||||
weight = cond.where(0, weight) if weight is not None else cond.where(0, 1)
|
||||
mask = target[:, None, :] == Tensor.arange(C).reshape([1, C] + [1]*(len(x.shape) -2))
|
||||
loss = -(mask * x).sum(axis=1) * (1 if weight is None else weight)
|
||||
if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
|
||||
|
@ -368,14 +368,17 @@ def SoftmaxCrossEntropyLoss(scores: Tensor, labels: Tensor, weights=None, ignore
|
|||
if ignore_index is not None: labels = (labels == ignore_index).where(C+1, labels)
|
||||
mask = labels.unsqueeze(1) == Tensor.arange(C).reshape(1, C, *[1]*len(s_dimensions))
|
||||
y = scores.log_softmax(axis=1)
|
||||
if weights is not None: weights = weights.__getitem__(tuple([labels, *[slice(None)]*(weights.ndim-1)]))
|
||||
loss = (mask * -y).sum(1) if weights is None else (mask * -y).sum(1) * weights
|
||||
loss = (mask * -y).sum(1)
|
||||
if weights is not None:
|
||||
weights = weights[labels, ...]
|
||||
loss = loss * weights
|
||||
if reduction == "mean": loss = loss.sum() / (loss == 0).where(0, 1).sum() if weights is None else loss.sum() / weights.sum()
|
||||
elif reduction == "sum": loss = loss.sum()
|
||||
return loss, y
|
||||
|
||||
def ArrayFeatureExtractor(x: Tensor, indices: Tensor):
|
||||
return x.__getitem__(tuple([slice(None) if i != (x.ndim-1) else indices for i in range(x.ndim)]))
|
||||
return x[tuple([slice(None) if i != (x.ndim-1) else indices for i in range(x.ndim)])]
|
||||
|
||||
def Gather(x: Tensor, indices: Tensor, axis=0):
|
||||
if indices.numel() < 9: # NOTE lessor kernels for smaller indices but kernel number increases depending on size of indices
|
||||
x_sh = list(x.shape)
|
||||
|
@ -385,7 +388,7 @@ def Gather(x: Tensor, indices: Tensor, axis=0):
|
|||
args = [[(0,x) if j != axis else (i,i+1) for j, x in enumerate(x_sh)] for i in indices]
|
||||
return x.shrink(arg=tuple(args[0])).cat(*[x.shrink(arg=tuple(arg)) for arg in args[1:]], dim=axis).reshape(ret_shape)
|
||||
# NOTE faster gather, fixed number of kernels, but exceeds limited kernels for openpilot
|
||||
return x.__getitem__(tuple([slice(None) if i != axis else indices for i in range(x.ndim)]))
|
||||
return x[tuple([slice(None) if i != axis else indices for i in range(x.ndim)])]
|
||||
|
||||
def GatherElements(x: Tensor, indices: Tensor, axis):
|
||||
indices = (indices < 0).where(x.shape[axis], 0) + indices
|
||||
|
@ -541,7 +544,7 @@ def Compress(inp: Tensor, condition: Tensor, axis=None):
|
|||
|
||||
con_np = safe_numpy(condition)
|
||||
con = Tensor(np.arange(condition.shape[0])[con_np]) # no boolean indexing in Tensor
|
||||
return inp.__getitem__(tuple([slice(None) if i != axis else con for i in range(inp.ndim)]))
|
||||
return inp[tuple([slice(None) if i != axis else con for i in range(inp.ndim)])]
|
||||
|
||||
def EyeLike(x: Tensor, dtype=None, k=0):
|
||||
if dtype is None: dtype = x.dtype
|
||||
|
|
Loading…
Reference in New Issue