mirror of https://github.com/commaai/tinygrad.git
batchnorm d(var)/d(mean) = 0 (#4430)
* d(var)/d(mean) = 0
* drop the number in test_schedule!
This commit is contained in:
parent
e2eab9c2b3
commit
c0a048c044
|
@@ -49,7 +49,7 @@ class UnsyncedBatchNorm:
|
|||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||
batch_mean = x.mean(axis=(1,3,4))
|
||||
-    y = (x - batch_mean.reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))
|
||||
+    y = (x - batch_mean.detach().reshape(shape=[batch_mean.shape[0], 1, -1, 1, 1]))  # d(var)/d(mean) = 0
|
||||
batch_var = (y*y).mean(axis=(1,3,4))
|
||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||
|
||||
|
|
|
@@ -208,7 +208,7 @@ class TestSchedule(unittest.TestCase):
|
|||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
# this is too high
|
||||
-    check_schedule(opt.schedule_step(), 18)
|
||||
+    check_schedule(opt.schedule_step(), 17)
|
||||
|
||||
def test_fold_conv_relu(self):
|
||||
c1 = nn.Conv2d(3,16,3)
|
||||
|
|
|
@@ -20,7 +20,7 @@ class BatchNorm2d:
|
|||
# https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
|
||||
# There's "online" algorithms that fix this, like https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_Online_algorithm
|
||||
batch_mean = x.mean(axis=(0,2,3))
|
||||
-    y = (x - batch_mean.reshape(shape=[1, -1, 1, 1]))
|
||||
+    y = (x - batch_mean.detach().reshape(shape=[1, -1, 1, 1]))  # d(var)/d(mean) = 0
|
||||
batch_var = (y*y).mean(axis=(0,2,3))
|
||||
batch_invstd = batch_var.add(self.eps).pow(-0.5)
|
||||
|
||||
|
|
Loading…
Reference in New Issue