2023-10-23 01:56:57 +08:00
|
|
|
import unittest
|
2023-10-24 22:49:22 +08:00
|
|
|
import pickle
|
2024-05-06 02:19:01 +08:00
|
|
|
from tinygrad.helpers import diskcache_get, diskcache_put, diskcache, diskcache_clear
|
2023-10-23 01:56:57 +08:00
|
|
|
|
2023-10-31 04:49:49 +08:00
|
|
|
def remote_get(table, q, k):
  """Child-process helper: read key ``k`` from cache table ``table`` and send the result back on queue ``q``."""
  value = diskcache_get(table, k)
  q.put(value)
|
|
|
|
def remote_put(table, k, v):
  """Child-process helper: store ``v`` under key ``k`` in cache table ``table``.

  The return value of diskcache_put is deliberately discarded; this helper returns None.
  """
  diskcache_put(table, k, v)
|
2023-10-23 01:56:57 +08:00
|
|
|
|
|
|
|
class DiskCache(unittest.TestCase):
  """Tests for the diskcache_get / diskcache_put / diskcache / diskcache_clear helpers.

  Each test uses its own table name so tests do not step on each other's keys.
  NOTE(review): the cache appears to be shared on-disk state (test_clear_cache is skipped because it
  drops the cache table), so assertions are written to pass regardless of leftovers from earlier runs.
  """

  def test_putget(self):
    """A value can be read back, and a second put on the same key replaces it."""
    table = "test_putget"
    diskcache_put(table, "hello", "world")
    self.assertEqual(diskcache_get(table, "hello"), "world")
    diskcache_put(table, "hello", "world2")
    self.assertEqual(diskcache_get(table, "hello"), "world2")

  def test_putcomplex(self):
    """Non-string values (here a mixed tuple) round-trip unchanged."""
    table = "test_putcomplex"
    diskcache_put(table, "k", ("complex", 123, "object"))
    ret = diskcache_get(table, "k")
    self.assertEqual(ret, ("complex", 123, "object"))

  def test_getotherprocess(self):
    """A value written in this process is visible to diskcache_get in a child process."""
    table = "test_getotherprocess"
    from multiprocessing import Process, Queue
    diskcache_put(table, "k", "getme")
    q = Queue()
    p = Process(target=remote_get, args=(table, q, "k"))
    p.start()
    # Drain the queue BEFORE joining: a child that has put data on a multiprocessing.Queue does not exit
    # until that data is flushed to the pipe, so join() first can deadlock. The timeout turns a crashed
    # child (which never puts anything) into a test failure instead of a hang.
    try:
      ret = q.get(timeout=60)
    finally:
      p.join(timeout=60)
    self.assertEqual(ret, "getme")
    # exitcode is None if the bounded join above timed out, non-zero if the child raised
    self.assertEqual(p.exitcode, 0)

  def test_putotherprocess(self):
    """A value written by a child process is visible to diskcache_get in this process."""
    table = "test_putotherprocess"
    from multiprocessing import Process
    p = Process(target=remote_put, args=(table, "k", "remote"))
    p.start()
    p.join(timeout=60)
    # check the child finished cleanly first, so a crash there is reported as such rather than as a value mismatch
    self.assertEqual(p.exitcode, 0)
    self.assertEqual(diskcache_get(table, "k"), "remote")

  def test_no_table(self):
    """Reading from a table that was never written returns None rather than raising."""
    self.assertIsNone(diskcache_get("faketable", "k"))

  def test_ret(self):
    """diskcache_put returns the value it stored."""
    table = "test_ret"
    self.assertEqual(diskcache_put(table, "key", ("vvs",)), ("vvs",))

  def test_non_str_key(self):
    """Non-string keys work, and the int key 4 and the string key "4" address the same entry."""
    table = "test_non_str_key"
    diskcache_put(table, 4, 5)
    self.assertEqual(diskcache_get(table, 4), 5)
    self.assertEqual(diskcache_get(table, "4"), 5)

  def test_decorator(self):
    """The @diskcache decorator returns cached results without re-invoking the wrapped function."""
    calls = 0
    @diskcache
    def hello(x):
      nonlocal calls
      calls += 1
      return "world"+x
    self.assertEqual(hello("bob"), "worldbob")
    self.assertEqual(hello("billy"), "worldbilly")
    # calls may be 0 here if an earlier run already cached these inputs; only the delta below matters
    kcalls = calls
    self.assertEqual(hello("bob"), "worldbob")
    self.assertEqual(hello("billy"), "worldbilly")
    self.assertEqual(kcalls, calls)

  def test_dict_key(self):
    """Dict keys (including bytes values) are supported, and keys differing in any field stay distinct."""
    table = "test_dict_key"
    fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")}
    fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")}
    fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")}
    diskcache_put(table, fancy_key, 5)
    self.assertEqual(diskcache_get(table, fancy_key), 5)
    diskcache_put(table, fancy_key2, 8)
    self.assertEqual(diskcache_get(table, fancy_key2), 8)
    # writing fancy_key2 must not have clobbered fancy_key
    self.assertEqual(diskcache_get(table, fancy_key), 5)
    # fancy_key3 differs from fancy_key2 only in the pickled payload and was never written
    self.assertIsNone(diskcache_get(table, fancy_key3))

  def test_table_name(self):
    """Table names containing ':' and '-' are accepted."""
    table = "test_gfx1010:xnack-"
    diskcache_put(table, "key", "test")
    self.assertEqual(diskcache_get(table, "key"), "test")

  @unittest.skip("disabled by default because this drops cache table")
  def test_clear_cache(self):
    """diskcache_clear removes entries from every table and is safe to call repeatedly."""
    # clear cache to start
    diskcache_clear()
    tables = [f"test_clear_cache:{i}" for i in range(3)]
    for table in tables:
      # check no entries
      self.assertIsNone(diskcache_get(table, "k"))
    for table in tables:
      diskcache_put(table, "k", "test")
      # check insertion
      self.assertEqual(diskcache_get(table, "k"), "test")

    diskcache_clear()
    for table in tables:
      # check no entries again
      self.assertIsNone(diskcache_get(table, "k"))

    # calling multiple times is fine
    diskcache_clear()
    diskcache_clear()
    diskcache_clear()
|
|
|
|
|
2023-10-23 01:56:57 +08:00
|
|
|
# Run the test cases above when this module is executed directly as a script.
if __name__ == "__main__":
  unittest.main()
|