batchnorm d(var)/d(mean) = 0 (#4430)

* d(var)/d(mean) = 0

* drop the number in test_schedule!
David Hou 2024-05-04 21:25:45 -07:00 committed by GitHub
parent e2eab9c2b3
commit c0a048c044
3 changed files with 3 additions and 3 deletions
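
Why the detach is safe (a reviewer's note, not part of the commit): when mu is chosen as the batch mean itself, the variance is stationary in mu, so the gradient path through batch_mean contributes exactly zero:

\[
\frac{\partial}{\partial \mu}\,\frac{1}{N}\sum_{i=1}^{N}(x_i-\mu)^2
= -\frac{2}{N}\sum_{i=1}^{N}(x_i-\mu)
= -2(\bar{x}-\mu)
= 0 \quad \text{at } \mu=\bar{x}.
\]

Calling batch_mean.detach() therefore cuts a zero-valued branch out of the autodiff graph; the numbers are unchanged, but the backward schedule gets one kernel shorter, as the test update below reflects.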

@@ -49,7 +49,7 @@ class UnsyncedBatchNorm:
     # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
     # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
     batch_mean = x.mean(axis=(1,3,4))
-    y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))
+    y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))  # d(var)/d(mean) = 0
     batch_var = (y*y).mean(axis=(1,3,4))
     batch_invstd = batch_var.add(self.eps).pow(-0.5)
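
This can be checked numerically (a minimal sketch against tinygrad's public Tensor API; the 5D (device, batch, channel, H, W) shape mirrors the UnsyncedBatchNorm hunk above, everything else is illustrative):

import numpy as np
from tinygrad import Tensor

def var_grad(xnp, detach_mean):
  x = Tensor(xnp, requires_grad=True)
  m = x.mean(axis=(1,3,4))
  if detach_mean: m = m.detach()  # the change in this commit
  y = x - m.reshape(m.shape[0], 1, -1, 1, 1)
  (y*y).mean(axis=(1,3,4)).sum().backward()  # scalar loss: sum of per-channel variances
  return x.grad.numpy()

xnp = np.random.randn(2, 3, 4, 5, 5).astype(np.float32)
# identical gradients either way, because d(var)/d(mean) = 0 at the batch mean
assert np.allclose(var_grad(xnp, False), var_grad(xnp, True), atol=1e-5)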

@@ -208,7 +208,7 @@ class TestSchedule(unittest.TestCase):
     opt.zero_grad()
     img_bn.backward()
     # this is too high
-    check_schedule(opt.schedule_step(), 18)
+    check_schedule(opt.schedule_step(), 17)
 
   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
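
The expected count drops from 18 to 17 because backward no longer differentiates through batch_mean inside the variance. A rough way to see the effect in isolation (hedged sketch: treating Tensor.schedule() as the kernel-list accessor is an assumption here, the diff itself only shows the test constant changing):

from tinygrad import Tensor

x = Tensor.randn(4, 8, 16, 16, requires_grad=True)
m = x.mean(axis=(0,2,3))
y = x - m.detach().reshape(1, -1, 1, 1)  # as in the BatchNorm2d hunk below
(y*y).mean().backward()
print(len(x.grad.schedule()))  # assumed API; the detach removes one backward reduce kernel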

@@ -20,7 +20,7 @@ class BatchNorm2d:
     # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
     # There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
     batch_mean = x.mean(axis=(0,2,3))
-    y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
+    y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1]))  # d(var)/d(mean) = 0
     batch_var = (y*y).mean(axis=(0,2,3))
     batch_invstd = batch_var.add(self.eps).pow(-0.5)