improve test_dropout_on_shard (#4912)

tested some basic properties of the dropout output, and minor formatting for a few Tensor.training setups
chenyu 2024-06-11 11:36:02 -04:00 committed by GitHub
parent 7f03420d05
commit b886d250fb
4 changed files with 12 additions and 8 deletions


@@ -104,7 +104,8 @@ class ExternalTestOptim(unittest.TestCase):
   def setUp(self):
     self.old_training = Tensor.training
     Tensor.training = True
-  def tearDown(self): Tensor.training = self.old_training
+  def tearDown(self):
+    Tensor.training = self.old_training
 
   def _test_optim(self, tinygrad_optim, tensorflow_optim, steps, opts, atol, rtol, tiny_sched=None, tf_sched=None, schedopts=None, do_optim=True):
     for x,y in zip(step(tinygrad_optim, steps=steps, kwargs=opts, scheduler=tiny_sched, schedopts=schedopts, do_optim=do_optim),


@@ -504,11 +504,12 @@ class TestMultiTensor(unittest.TestCase):
     t_none.assign(t_zero)
 
   def test_dropout_on_shard(self):
-    Tensor.training = True
-    X = Tensor.ones(256).to(devices_2)
-    output = X.dropout(0.5)
-    output.numpy()
-    Tensor.training = False
+    with Tensor.train():
+      X = Tensor.ones(256).to(devices_2)
+      output = X.dropout(0.5).numpy()
+      unique, counts = np.unique(output, return_counts=True)
+      assert set(unique) == {0, 2}, unique
+      assert 100 < counts[0] < 156, counts[0]
 
   def test_broadcast_const(self):
     devices = (d0, d1, d2, d3)
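
Note on the new assertions: with p=0.5, (inverted) dropout zeroes each element with probability 0.5 and scales the survivors by 1/(1-p) = 2, so an all-ones input can only produce 0s and 2s, and the number of zeros among 256 elements is Binomial(256, 0.5) with mean 128 and standard deviation 8; the 100 < counts[0] < 156 window is roughly a +/-3.5-sigma band. A minimal NumPy-only sketch of the same property check (simulate_inverted_dropout is an illustrative stand-in, not the tinygrad API):

import numpy as np

def simulate_inverted_dropout(x: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
  # zero each element with probability p, scale survivors by 1/(1-p) (inverted dropout)
  mask = rng.random(x.shape) >= p
  return x * mask / (1.0 - p)

rng = np.random.default_rng(0)
output = simulate_inverted_dropout(np.ones(256), p=0.5, rng=rng)
unique, counts = np.unique(output, return_counts=True)
assert set(unique) == {0, 2}, unique      # kept ones become 1 / (1 - 0.5) = 2
assert 100 < counts[0] < 156, counts[0]   # zeros ~ Binomial(256, 0.5): mean 128, std 8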


@@ -46,7 +46,8 @@ class TestOptim(unittest.TestCase):
   def setUp(self):
     self.old_training = Tensor.training
     Tensor.training = True
-  def tearDown(self): Tensor.training = self.old_training
+  def tearDown(self):
+    Tensor.training = self.old_training
 
   def _test_optim(self, tinygrad_optim, torch_optim, steps, opts, atol, rtol):
     for x,y in zip(step(Tensor, tinygrad_optim, steps, **opts),


@@ -56,7 +56,8 @@ class TestLrScheduler(unittest.TestCase):
   def setUp(self):
     self.old_training = Tensor.training
     Tensor.training = True
-  def tearDown(self): Tensor.training = self.old_training
+  def tearDown(self):
+    Tensor.training = self.old_training
 
   def _test_lr_scheduler(self, tinygrad_sched, torch_sched, epochs, opts, atol, rtol, adam=True):
     accs = opts.pop('accs', None)