mirror of https://github.com/commaai/tinygrad.git
added onnx group norm (#614)
This commit is contained in:
parent
edc8fbfff2
commit
07e643431c
|
@ -44,6 +44,9 @@ def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_typ
|
|||
mean = x.mean(axis=axis, keepdim=True)
|
||||
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
|
||||
|
||||
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
|
||||
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
|
||||
|
||||
def _padding(pads=None, auto_pad="NOTSET"):
|
||||
assert auto_pad == "NOTSET" # TODO: write this
|
||||
return (pads[1], pads[3], pads[0], pads[2]) if pads is not None else (0,0,0,0)
|
||||
|
|
Loading…
Reference in New Issue