Files
dragonpilot/tinygrad_repo/test/test_sample.py
Vehicle Researcher dd778596b7 openpilot v0.9.8 release
date: 2025-03-15T21:10:51
master commit: fb7b9c0f9420d228f03362970ebcfb7237095cf3
2025-03-18 10:05:17 -07:00

19 lines
616 B
Python

import unittest
import numpy as np
from tinygrad import Tensor, Variable
class TestSample(unittest.TestCase):
def test_sample(self):
X = Tensor.rand(10000, 50).realize()
BS = 16
idxs = np.random.randint(0, X.shape[0], size=(BS))
# this uncovered a bug with arg sort order
batch = [Variable(f'idx{i}', 0, X.shape[0]-1).bind(s) for i,s in enumerate(idxs.tolist())]
x = Tensor.cat(*[X.shrink(((batch[i], batch[i]+1), None)) for i in range(BS)])
print(idxs)
ret = x.numpy()
base = X.numpy()[idxs]
np.testing.assert_equal(ret, base)
# Running this file directly (e.g. `python test_sample.py`) hands control to unittest's CLI runner.
if '__main__' == __name__:
  unittest.main()