add sharded arange test (#4908)

This commit is contained in:
George Hotz 2024-06-11 10:58:33 +02:00 committed by GitHub
parent 798ea61377
commit 35e53c0809
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 0 deletions

View File

@ -48,6 +48,11 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (128,)
(X + X).realize()
def test_sharded_arange(self):
sharded_arange = Tensor.arange(1000).shard(devices_2, 0)
sharded_arange.realize()
np.testing.assert_equal(sharded_arange.numpy(), np.arange(1000))
def test_shard_no_recompile(self):
X = Tensor.ones(256).contiguous().realize()
X.shard_((d0, d1), 0)