added onnx group norm (#614)

This commit is contained in:
Diogo 2023-02-27 11:11:01 -05:00 committed by GitHub
parent edc8fbfff2
commit 07e643431c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -44,6 +44,9 @@ def LayerNormalization(x: Tensor, scale, bias, axis=-1, epsilon=1e-05, stash_typ
mean = x.mean(axis=axis, keepdim=True)
return x.layernorm(axis, epsilon).mul(scale).add(bias), mean, (x.sub(mean)).pow(2).mean(axis=axis, keepdim=True).add(epsilon).sqrt().reciprocal()
def GroupNormalization(x: Tensor, scale: Tensor, bias: Tensor, num_groups, epsilon=1e-05):
return x.reshape(x.shape[0], num_groups, -1).layernorm(axis=-1, eps=epsilon).mul(scale.unsqueeze(-1)).add(bias.unsqueeze(-1)).reshape(x.shape)
def _padding(pads=None, auto_pad="NOTSET"):
assert auto_pad == "NOTSET" # TODO: write this
return (pads[1], pads[3], pads[0], pads[2]) if pads is not None else (0,0,0,0)