2023-11-08 09:30:53 +08:00
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from tinygrad.tensor import Tensor
|
|
|
|
from tinygrad.shape.symbolic import Variable
|
|
|
|
|
|
|
|
class TestSample(unittest.TestCase):
  """Tests for gathering rows of a tensor through symbolic (Variable-bound) shrinks."""

  def test_sample(self):
    """Gather BS random rows of X via bound Variables and compare against numpy fancy indexing.

    Each sampled row index is wrapped in its own bound Variable, the matching one-row
    slice is taken with `shrink`, and the slices are concatenated. The result must equal
    `X.numpy()[idxs]`.
    """
    X = Tensor.rand(10000, 50).realize()
    BS = 16
    # randint's upper bound is exclusive, so every sampled index is <= X.shape[0]-1,
    # the bound passed to each Variable below.
    idxs = np.random.randint(0, X.shape[0], size=BS)
    # this uncovered a bug with arg sort order
    batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i, s in enumerate(idxs.tolist())]
    # Shrink axis 0 to [v, v+1) for each bound Variable and stack the one-row slices.
    x = Tensor.cat(*[X.shrink(((v, v+1), None)) for v in batch])
    ret = x.numpy()
    base = X.numpy()[idxs]
    # Report the sampled indices only when the comparison fails (rather than printing them
    # on every run), so a failing random draw can still be diagnosed.
    np.testing.assert_equal(ret, base, err_msg=f"idxs={idxs.tolist()}")
|
|
|
|
|
|
|
|
# Entry point when this file is executed directly rather than collected by a test runner.
if __name__ == "__main__": unittest.main()
|