wow how did i think that was okay (#2339)

This commit is contained in:
George Hotz 2023-11-16 21:21:11 -08:00 committed by GitHub
parent 8e22c0d95c
commit 652d2de256
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 12 deletions

View File

@ -1,7 +1,8 @@
import sys, sqlite3, pickle
from tinygrad.helpers import CACHEDB
if __name__ == "__main__":
fn = sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
fn = sys.argv[1] if len(sys.argv) > 1 else CACHEDB
conn = sqlite3.connect(fn)
cur = conn.cursor()
cur.execute("SELECT name FROM sqlite_master WHERE type='table'")

View File

@ -182,14 +182,6 @@ def db_connection():
os.makedirs(CACHEDB.rsplit(os.sep, 1)[0], exist_ok=True)
_db_connection = sqlite3.connect(CACHEDB)
if DEBUG >= 7: _db_connection.set_trace_callback(print)
if diskcache_get("meta", "version") != VERSION:
print("cache is out of date, clearing it")
_db_connection.close()
del _db_connection
os.unlink(CACHEDB)
_db_connection = sqlite3.connect(CACHEDB)
if DEBUG >= 7: _db_connection.set_trace_callback(print)
diskcache_put("meta", "version", VERSION)
return _db_connection
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
@ -198,7 +190,7 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
conn = db_connection()
cur = conn.cursor()
try:
res = cur.execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
except sqlite3.OperationalError:
return None # table doesn't exist
if (val:=res.fetchone()) is not None: return pickle.loads(val[0])
@ -213,9 +205,9 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
if table not in _db_tables:
TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
_db_tables.add(table)
cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
conn.commit()
cur.close()
return val