mirror of https://github.com/commaai/tinygrad.git
wow how did i think that was okay (#2339)
This commit is contained in:
parent
8e22c0d95c
commit
652d2de256
|
@ -1,7 +1,8 @@
|
|||
import sys, sqlite3, pickle
|
||||
from tinygrad.helpers import CACHEDB
|
||||
|
||||
if __name__ == "__main__":
|
||||
fn = sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
|
||||
fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB
|
||||
conn = sqlite3.connect(fn)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
|
|
|
@ -182,14 +182,6 @@ def db_connection():
|
|||
os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
|
||||
_db_connection = sqlite3.connect(CACHEDB)
|
||||
if DEBUG >= 7: _db_connection.set_trace_callback(print)
|
||||
if diskcache_get("meta", "version") != VERSION:
|
||||
print("cache is out of date, clearing it")
|
||||
_db_connection.close()
|
||||
del _db_connection
|
||||
os.unlink(CACHEDB)
|
||||
_db_connection = sqlite3.connect(CACHEDB)
|
||||
if DEBUG >= 7: _db_connection.set_trace_callback(print)
|
||||
diskcache_put("meta", "version", VERSION)
|
||||
return _db_connection
|
||||
|
||||
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
|
||||
|
@ -198,7 +190,7 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
|
|||
conn = db_connection()
|
||||
cur = conn.cursor()
|
||||
try:
|
||||
res = cur.execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
|
||||
res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
|
||||
except sqlite3.OperationalError:
|
||||
return None # table doesn't exist
|
||||
if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
|
||||
|
@ -213,9 +205,9 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
|
|||
if table not in _db_tables:
|
||||
TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
|
||||
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
|
||||
cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
|
||||
cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
|
||||
_db_tables.add(table)
|
||||
cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
|
||||
cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
|
||||
conn.commit()
|
||||
cur.close()
|
||||
return val
|
||||
|
|
Loading…
Reference in New Issue