# mirror of https://github.com/commaai/tinygrad.git
# 25 lines, 890 B, Python
import unittest
|
|
from tinygrad.helpers import CI
|
|
from examples.mamba import Mamba, generate
|
|
from transformers import AutoTokenizer
|
|
|
|
# Prompt text handed to generate() by every test in this module.
PROMPT = "Why is gravity "

# Tokenizer handed to generate() together with each Mamba model.
# NOTE: this call runs at import time, so it executes even when the
# test class below is skipped.
TOKENIZER = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neox-20b",
)
|
|
|
|
@unittest.skipIf(CI, "model is slow for CI")
class TestMamba(unittest.TestCase):
    """Check short Mamba generations for PROMPT against recorded text.

    The whole class is skipped when ``CI`` is truthy.
    """

    def _check_generation(self, size, expected):
        """Generate 10 tokens from PROMPT and require an exact match.

        Args:
            size: Tag passed to ``Mamba.from_pretrained`` (e.g. ``'130m'``).
            expected: Full text that ``generate`` must return.
        """
        mamba = Mamba.from_pretrained(size)
        produced = generate(mamba, TOKENIZER, PROMPT, n_tokens_to_gen=10)
        self.assertEqual(expected, produced)
        # Drop our reference to the model before returning; presumably so
        # the weights can be released before the next test (intent is not
        # stated in this file).
        del mamba

    def test_mamba_130M(self):
        """The '130m' model must reproduce its recorded continuation."""
        self._check_generation(
            '130m',
            '''Why is gravity \nnot a good idea?\n\nA:''',
        )

    def test_mamba_370M(self):
        """The '370m' model must reproduce its recorded continuation."""
        self._check_generation(
            '370m',
            '''Why is gravity \nso important?\nBecause it's the only''',
        )
|
|
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|