from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
  from typing_extensions import TypeGuard

T = TypeVar("T")
# NOTE: it returns int 1 if x is empty regardless of the type of x
def prod(x:Iterable[T]) -> Union[T,int]: return functools.reduce(operator.__mul__, x, 1)
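# Illustrative examples:
#   prod([2, 3, 4]) == 24
#   prod([]) == 1   # empty input: reduce falls back to the int initializer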

# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = platform.system() == "Darwin"
CI = os.getenv("CI", "") != ""

def dedup(x): return list(dict.fromkeys(x)) # retains list order
def argfix(*x): return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
def argsort(x): return type(x)(sorted(range(len(x)), key=x.__getitem__)) # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
def all_same(items): return all(x == items[0] for x in items)
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]: return all(isinstance(s, int) for s in t)
def colored(st, color, background=False): return f"\u001b[{10*background+60*(color.upper() == color)+30+['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white'].index(color.lower())}m{st}\u001b[0m" if color is not None else st # replace the termcolor library with one line
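# Example (illustrative): colored("pass", "green") == "\u001b[32mpass\u001b[0m"; an upper-case
# color name ("GREEN") selects the bright variant (escape codes in the 90s).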
def ansistrip(s): return re.sub('\x1b\\[(K|.*?m)', '', s)
def ansilen(s): return len(ansistrip(s))
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l:Union[List, Iterator]): return [item for sublist in l for item in sublist]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def strip_parens(fst): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
def merge_dicts(ds:Iterable[Dict]) -> Dict:
  assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
  return {k:v for d in ds for k,v in d.items()}
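# Example (illustrative): merge_dicts([{"a": 1}, {"b": 2}]) == {"a": 1, "b": 2}, while
# merge_dicts([{"a": 1}, {"a": 2}]) trips the assert because "a" maps to two different values.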
def partition(lst, fxn):
  a: list[Any] = []
  b: list[Any] = []
  for s in lst: (a if fxn(s) else b).append(s)
  return a,b
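# Example (illustrative): partition([1, 2, 3, 4], lambda x: x % 2 == 0) == ([2, 4], [1, 3])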

@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
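# Example (illustrative): the default's type coerces the environment string, so with DEBUG=2
# exported, getenv("DEBUG") == 2 (int), while getenv("GRAPHPATH", "/tmp/net") stays a str.
# Results are memoized by lru_cache, so later changes to os.environ are not picked up.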

class Context(contextlib.ContextDecorator):
  stack: ClassVar[List[dict[str, int]]] = [{}]
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
    Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()} # Store current state.
    for k,v in self.kwargs.items(): ContextVar._cache[k].value = v # Update to new temporary state.
    Context.stack.append(self.kwargs) # Store the temporary state so we know what to undo later.
  def __exit__(self, *args):
    for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)

class ContextVar:
  _cache: ClassVar[Dict[str, ContextVar]] = {}
  value: int
  def __new__(cls, key, default_value):
    if key in ContextVar._cache: return ContextVar._cache[key]
    instance = ContextVar._cache[key] = super().__new__(cls)
    instance.value = getenv(key, default_value)
    return instance
  def __bool__(self): return bool(self.value)
  def __ge__(self, x): return self.value >= x
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x

DEBUG, IMAGE, BEAM, NOOPT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0)
GRAPH, GRAPHPATH = getenv("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net")
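# Usage sketch (illustrative): each ContextVar reads its initial value from the environment, and
# Context temporarily overrides it for the duration of a with-block (or decorated function):
#   with Context(DEBUG=2):
#     ...   # here DEBUG >= 2 is True; the previous value is restored on exit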

class Timing(contextlib.ContextDecorator):
  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
  def __enter__(self): self.st = time.perf_counter_ns()
  def __exit__(self, exc_type, exc_val, exc_tb):
    self.et = time.perf_counter_ns() - self.st
    if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
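# Usage sketch (illustrative): as a context manager it prints the elapsed wall time on exit, e.g.
#   with Timing("compile: "):
#     build_kernel()   # hypothetical workload; prints something like "compile: 12.34 ms"
# If on_exit is given, it receives the elapsed nanoseconds and its return string is appended.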

# **** tinygrad now supports dtypes! ****

class DType(NamedTuple):
  priority: int # this determines when things get upcasted
  itemsize: int
  name: str
  np: Optional[type] # TODO: someday this will be removed with the "remove numpy" project
  sz: int = 1
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
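# Example (illustrative): dtypes.float32 below is DType(priority=4, itemsize=4, name="float",
# np=np.float32, sz=1); priority is compared when operands of different dtypes are combined.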

# dependent typing?
class ImageDType(DType):
  def __new__(cls, priority, itemsize, name, np, shape):
    return super().__new__(cls, priority, itemsize, name, np)
  def __init__(self, priority, itemsize, name, np, shape):
    self.shape: Tuple[int, ...] = shape # arbitrary arg for the dtype, used in image for the shape
    super().__init__()
  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
  # TODO: fix this to not need these
  def __hash__(self): return hash((super().__hash__(), self.shape))
  def __eq__(self, x): return super().__eq__(x) and self.shape == x.shape
  def __ne__(self, x): return super().__ne__(x) or self.shape != x.shape

class PtrDType(DType):
  def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
  def __repr__(self): return f"ptr.{super().__repr__()}"

class dtypes:
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
  def is_int(x: DType) -> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  @staticmethod
  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
  @staticmethod
  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  @staticmethod
  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
  @staticmethod
  def fields() -> Dict[str, DType]: return DTYPES_DICT
  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
  float16: Final[DType] = DType(0, 2, "half", np.float16)
  half = float16
  float32: Final[DType] = DType(4, 4, "float", np.float32)
  float = float32
  float64: Final[DType] = DType(0, 8, "double", np.float64)
  double = float64
  int8: Final[DType] = DType(0, 1, "char", np.int8)
  int16: Final[DType] = DType(1, 2, "short", np.int16)
  int32: Final[DType] = DType(2, 4, "int", np.int32)
  int64: Final[DType] = DType(3, 8, "long", np.int64)
  uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)

  # NOTE: bfloat16 isn't supported in numpy
  bfloat16: Final[DType] = DType(0, 2, "__bf16", None)

  # NOTE: these are internal dtypes, should probably check for that
  _int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
  _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
  _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
  _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)

  # NOTE: these are image dtypes
  @staticmethod
  def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp)
  @staticmethod
  def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)

# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
INVERSE_DTYPES_DICT = {v:k for k,v in DTYPES_DICT.items()}
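# Example (illustrative): dtypes.from_np(np.float32) is dtypes.float32 and
# dtypes.fields()["int32"].itemsize == 4. Aliases such as "float" and "half" are also entries in
# DTYPES_DICT, so INVERSE_DTYPES_DICT maps a DType back to the last name bound to it in the class body.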

class GlobalCounters:
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0 # NOTE: this is not reset
  mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
  @staticmethod
  def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0
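# Usage sketch (illustrative): kernels bump these counters as they run; a caller can reset them
# around a region and derive throughput afterwards, e.g.
#   GlobalCounters.reset()
#   run_workload()   # hypothetical
#   gflops = GlobalCounters.global_ops / GlobalCounters.time_sum_s / 1e9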

# *** universal database cache ***

CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)

VERSION = 6
_db_connection = None
def db_connection():
  global _db_connection
  if _db_connection is None:
    _db_connection = sqlite3.connect(CACHEDB)
    if DEBUG >= 5: _db_connection.set_trace_callback(print)
    if diskcache_get("meta", "version") != VERSION:
      print("cache is out of date, clearing it")
      os.unlink(CACHEDB)
      _db_connection = sqlite3.connect(CACHEDB)
      if DEBUG >= 5: _db_connection.set_trace_callback(print)
      diskcache_put("meta", "version", VERSION)
  return _db_connection

def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
  if isinstance(key, (str,int)): key = {"key": key}
  try:
    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
  except sqlite3.OperationalError:
    return None # table doesn't exist
  if (val:=res.fetchone()) is not None:
    return pickle.loads(val[0])
  return None

_db_tables = set()
def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  cur = conn.cursor()
  if table not in _db_tables:
    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
    ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
    _db_tables.add(table)
  cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
  conn.commit()
  cur.close()
  return val
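# Example (illustrative, table and key names made up): diskcache_put("compile_test", "kernel_a", b"...")
# pickles the value into table compile_test under key "kernel_a"; diskcache_get("compile_test", "kernel_a")
# round-trips it, and a missing table or key yields None.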

def diskcache(func):
  def wrapper(*args, **kwargs) -> bytes:
    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
    if (ret:=diskcache_get(table, key)): return ret
    return diskcache_put(table, key, func(*args, **kwargs))
  setattr(wrapper, "__wrapped__", func)
  return wrapper
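# Usage sketch (illustrative): wrap a pure, expensive function so repeat calls with the same
# picklable arguments are served from the sqlite cache:
#   @diskcache
#   def compile_kernel(src):   # hypothetical function
#     return really_slow_compile(src)
# Note the truthiness check above: a cached falsy result is recomputed rather than returned.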