forked from mawei/dp
1
0
Fork 0
dp/tinygrad_repo/tinygrad/helpers.py

209 lines
10 KiB
Python

from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
T = TypeVar("T")
# NOTE: it returns int 1 if x is empty regardless of the type of x
def prod(x:Iterable[T]) -> Union[T,int]:
  """Multiply the elements of `x` together left to right, starting from the integer 1."""
  acc: Union[T,int] = 1
  for v in x: acc = acc * v  # plain `*` (not `*=`) so mutable elements are never modified in place
  return acc
# NOTE: helpers is not allowed to import from anything else in tinygrad
OSX = (platform.system() == "Darwin")  # True when running on macOS
CI = bool(os.getenv("CI", ""))  # True when the CI env var is set to any non-empty string
def dedup(x):
  """Return the distinct items of `x` as a list, keeping each item's first occurrence (items must be hashable)."""
  seen = {}
  for item in x: seen.setdefault(item, None)
  return list(seen)
def argfix(*x):
  """Normalize varargs: f(1, 2) and f((1, 2)) / f([1, 2]) both yield (1, 2).

  Only an exact tuple or list in the first position is unpacked, and any further arguments are then dropped.
  """
  if x and x[0].__class__ in (tuple, list): return tuple(x[0])
  return x
def argsort(x):
  """Return the indices that would sort `x`, in the same container type as `x` (pure-Python numpy.argsort)."""
  # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
  container = type(x)
  order = sorted(range(len(x)), key=x.__getitem__)
  return container(order)
def all_same(items):
  """True if every element compares equal to items[0]; vacuously True for an empty sequence."""
  for x in items:
    if not x == items[0]: return False
  return True
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]:
  """Type guard: True when every element of `t` is an int (bools count, since bool subclasses int)."""
  for s in t:
    if not isinstance(s, int): return False
  return True
def colored(st, color, background=False):
  """Wrap `st` in ANSI color escape codes (replaces the termcolor library).

  `color` is one of black/red/green/yellow/blue/magenta/cyan/white. An all-uppercase name adds 60 (the bright
  90-97 range), and `background=True` adds 10 (the background range). `color=None` returns `st` unchanged.
  """
  if color is None: return st
  names = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
  code = 10*background + 60*(color.upper() == color) + 30 + names.index(color.lower())
  return f"\u001b[{code}m{st}\u001b[0m"
def ansistrip(s):
  """Remove escape sequences of the form ESC[...m (colors) and ESC[K (erase line) from `s`."""
  pattern = '\x1b\\[(K|.*?m)'
  return re.sub(pattern, '', s)
def ansilen(s):
  """Length of `s` as it appears on a terminal, i.e. after ansistrip."""
  visible = ansistrip(s)
  return len(visible)
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]:
  """Repeat a scalar int `cnt` times into a tuple; any non-int (e.g. an existing tuple) is returned unchanged."""
  if not isinstance(x, int): return x
  return (x,)*cnt
def flatten(l:Union[List, Iterator]):
  """Concatenate the sub-iterables of `l` into a single list (one level of nesting only)."""
  flat = []
  for sublist in l:
    for item in sublist: flat.append(item)
  return flat
def fromimport(mod, frm):
  """Dynamic equivalent of `from mod import frm`: import module `mod` and return its attribute `frm`."""
  module = __import__(mod, fromlist=[frm])
  return getattr(module, frm)
def strip_parens(fst:str) -> str:
  """Drop one pair of outer parentheses from `fst` when that pair encloses the whole string.

  "(a+b)" -> "a+b" and "((a))" -> "(a)". "(a)+(b)" and "(a+(b))+(c)" are returned unchanged because the
  leading "(" is closed before the last character. Empty, non-parenthesized or unbalanced input is returned as-is.
  """
  if len(fst) < 2 or fst[0] != '(' or fst[-1] != ')': return fst
  depth = 0
  for i, ch in enumerate(fst):
    if ch == '(': depth += 1
    elif ch == ')': depth -= 1
    # the leading "(" was matched before the end, so the outer pair does not wrap the whole expression
    if depth == 0 and i != len(fst) - 1: return fst
  return fst[1:-1] if depth == 0 else fst
def merge_dicts(ds:Iterable[Dict]) -> Dict:
  """Merge the dicts in `ds` into one new dict.

  A key may appear in several dicts only if it has the same value in each of them; otherwise an AssertionError
  is raised. Values must be hashable because the key/value pairs are collected into a set.
  """
  ds = list(ds)  # iterated twice below; materialize so a generator is not exhausted by the consistency check
  kvs = set((k,v) for d in ds for k,v in d.items())
  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
  return {k:v for d in ds for k,v in d.items()}
def partition(lst, fxn):
  """Split `lst` into (items where fxn(item) is truthy, items where it is falsy), preserving order."""
  matched: list[Any] = []
  rest: list[Any] = []
  for item in lst:
    if fxn(item): matched.append(item)
    else: rest.append(item)
  return matched, rest
@functools.lru_cache(maxsize=None)
def getenv(key, default=0):
  """Read environment variable `key` and convert it with type(default); if it is unset, `default` is used.

  Results are memoized per (key, default), so later changes to os.environ are not observed.
  NOTE: with a bool default any non-empty string, including "0", converts to True.
  """
  cast = type(default)
  raw = os.getenv(key, default)
  return cast(raw)
class Context(contextlib.ContextDecorator):
  """Temporarily override ContextVar values, e.g. `with Context(DEBUG=2): ...` or as a function decorator.

  Each keyword must name an existing ContextVar. An unknown name raises KeyError on entry before any state is
  changed. On exit every key this context set is restored to its value at entry.
  NOTE: when contexts nest, the outer exit also resets every other ContextVar to its value at the outer entry,
  because the inner entry replaces the outer's saved kwargs on the stack with a full snapshot.
  """
  stack: ClassVar[List[dict[str, int]]] = [{}]
  def __init__(self, **kwargs): self.kwargs = kwargs
  def __enter__(self):
    # Resolve every target first: previously an unknown key raised only after earlier keys had been overwritten,
    # and those overwritten keys were never restored.
    targets = [(ContextVar._cache[k], v) for k,v in self.kwargs.items()]
    Context.stack[-1] = {k:o.value for k,o in ContextVar._cache.items()}  # Store current state.
    for var, v in targets: var.value = v  # Update to new temporary state.
    Context.stack.append(self.kwargs)  # Store the temporary state so we know what to undo later.
  def __exit__(self, *args):
    # pop the keys this context set and restore them from the snapshot that is now on top of the stack
    for k in Context.stack.pop(): ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
class ContextVar:
  """A named integer setting whose initial value comes from the environment (via getenv).

  There is exactly one instance per key: constructing a ContextVar for a key that already exists returns the
  cached instance and ignores `default_value`. Instances are truthy/comparable through their current `.value`.
  """
  _cache: ClassVar[Dict[str, ContextVar]] = {}  # every ContextVar by key; also what Context reads and writes
  value: int
  def __new__(cls, key, default_value):
    if key in ContextVar._cache: return ContextVar._cache[key]
    instance = ContextVar._cache[key] = super().__new__(cls)
    instance.value = getenv(key, default_value)  # env value converted with type(default_value)
    return instance
  def __bool__(self): return bool(self.value)
  def __ge__(self, x): return self.value >= x
  def __gt__(self, x): return self.value > x
  def __lt__(self, x): return self.value < x
  # added so `var <= x` works like the other comparisons instead of raising TypeError
  # (no __eq__ is defined: that would make instances unhashable without also defining __hash__)
  def __le__(self, x): return self.value <= x
# global settings that Context(...) can override temporarily
DEBUG = ContextVar("DEBUG", 0)
IMAGE = ContextVar("IMAGE", 0)
BEAM = ContextVar("BEAM", 0)
NOOPT = ContextVar("NOOPT", 0)
# plain values read once from the environment
GRAPH = getenv("GRAPH", 0)
GRAPHPATH = getenv("GRAPHPATH", "/tmp/net")
class Timing(contextlib.ContextDecorator):
  """Measure the wall-clock time of a `with` block (or a decorated function) using time.perf_counter_ns.

  On exit the elapsed nanoseconds are stored in `self.et`. When `enabled`, a line "<prefix><ms> ms" is printed,
  followed by the string returned by `on_exit(et)` if `on_exit` is given.
  """
  def __init__(self, prefix="", on_exit=None, enabled=True): self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled
  def __enter__(self):
    self.st = time.perf_counter_ns()
    return self  # lets `with Timing() as t:` read t.et afterwards (previously bound None)
  def __exit__(self, exc_type, exc_val, exc_tb):
    self.et = time.perf_counter_ns() - self.st
    if self.enabled: print(f"{self.prefix}{self.et*1e-6:.2f} ms"+(self.on_exit(self.et) if self.on_exit else ""))
# **** tinygrad now supports dtypes! *****
class DType(NamedTuple):
  """Immutable element-type description; equality and hashing are plain tuple comparison over all five fields."""
  priority: int  # this determines when things get upcasted
  itemsize: int  # bytes per element, covering all lanes for vector dtypes (e.g. float4 is 4*4)
  name: str
  np: Optional[type]  # TODO: someday this will be removed with the "remove numpy" project
  sz: int = 1  # number of lanes; > 1 only for the internal vector dtypes
  def __repr__(self): return f"dtypes.{INVERSE_DTYPES_DICT[self]}"

# dependent typing?
class ImageDType(DType):
  """DType that also carries a `shape`, which takes part in equality and hashing.

  `sz` is not forwarded by __new__, so it is always the default 1.
  """
  def __new__(cls, priority, itemsize, name, np, shape):
    return super().__new__(cls, priority, itemsize, name, np)
  def __init__(self, priority, itemsize, name, np, shape):
    self.shape: Tuple[int, ...] = shape  # arbitrary arg for the dtype, used in image for the shape
    super().__init__()
  def __repr__(self): return f"dtypes.{self.name}({self.shape})"
  # TODO: fix this to not need these
  def __hash__(self): return hash((super().__hash__(), self.shape))
  def __eq__(self, x):
    # Only another ImageDType can be equal. Previously tuple.__eq__ returned NotImplemented (truthy) for non-tuples,
    # and a plain DType/PtrDType with matching fields got through too; both then failed on `x.shape`.
    return isinstance(x, ImageDType) and super().__eq__(x) and self.shape == x.shape
  def __ne__(self, x): return not self.__eq__(x)

class PtrDType(DType):
  """Field-for-field copy of `dt` that reprs as ptr.<repr of dt>; it compares equal to `dt` (tuple equality)."""
  def __new__(cls, dt:DType): return super().__new__(cls, dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
  def __repr__(self): return f"ptr.{super().__repr__()}"
class dtypes:
  # Namespace holding the predefined DType constants plus helpers that classify them.
  # NOTE: definition order in this class body is significant. DTYPES_DICT is built from dtypes.__dict__ in order,
  # and INVERSE_DTYPES_DICT keeps the LAST name bound to each DType. So the aliases `half`, `float` and `double`
  # are the names DType.__repr__ prints for float16/float32/float64. Do not reorder.
  # signed and unsigned scalar integers (the internal _int2 and _arg_int32 are not included)
  @staticmethod # static methods on top, or bool in the type info will refer to dtypes.bool
  def is_int(x: DType)-> bool: return x in (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  # scalar floats plus the internal float vector dtypes; NOTE(review): bfloat16 is not listed — confirm intended
  @staticmethod
  def is_float(x: DType) -> bool: return x in (dtypes.float16, dtypes.float32, dtypes.float64, dtypes._half4, dtypes._float2, dtypes._float4)
  @staticmethod
  def is_unsigned(x: DType) -> bool: return x in (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
  # DType registered under numpy's dtype name for x (e.g. "float32"); raises KeyError if there is no such entry
  @staticmethod
  def from_np(x) -> DType: return DTYPES_DICT[np.dtype(x).name]
  # returns DTYPES_DICT itself (not a copy): every name -> DType defined in this class, aliases included
  @staticmethod
  def fields() -> Dict[str, DType]: return DTYPES_DICT
  # scalar dtypes: DType(priority, itemsize in bytes, name, numpy type); aliases must stay right after their target
  bool: Final[DType] = DType(0, 1, "bool", np.bool_)
  float16: Final[DType] = DType(0, 2, "half", np.float16)
  half = float16
  float32: Final[DType] = DType(4, 4, "float", np.float32)
  float = float32
  float64: Final[DType] = DType(0, 8, "double", np.float64)
  double = float64
  int8: Final[DType] = DType(0, 1, "char", np.int8)
  int16: Final[DType] = DType(1, 2, "short", np.int16)
  int32: Final[DType] = DType(2, 4, "int", np.int32)
  int64: Final[DType] = DType(3, 8, "long", np.int64)
  uint8: Final[DType] = DType(0, 1, "unsigned char", np.uint8)
  uint16: Final[DType] = DType(1, 2, "unsigned short", np.uint16)
  uint32: Final[DType] = DType(2, 4, "unsigned int", np.uint32)
  uint64: Final[DType] = DType(3, 8, "unsigned long", np.uint64)
  # NOTE: bfloat16 isn't supported in numpy
  bfloat16: Final[DType] = DType(0, 2, "__bf16", None)
  # NOTE: these are internal dtypes, should probably check for that
  # (the last field is the lane count; itemsize covers all lanes)
  _int2: Final[DType] = DType(2, 4*2, "int2", None, 2)
  _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
  _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
  _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)
  # NOTE: these are image dtypes (priority 100); `shp` is stored as the ImageDType's .shape
  @staticmethod
  def imageh(shp): return ImageDType(100, 2, "imageh", np.float16, shp)
  @staticmethod
  def imagef(shp): return ImageDType(100, 4, "imagef", np.float32, shp)
# HACK: staticmethod objects only became callable in Python 3.10, so on older versions (e.g. 3.8) compare the class too.
# Name -> DType for every constant in `dtypes` (aliases included), in class definition order.
DTYPES_DICT = {name: dt for name, dt in dtypes.__dict__.items()
               if not (name.startswith('__') or callable(dt) or dt.__class__ == staticmethod)}
# DType -> name; for aliased DTypes the name defined last in `dtypes` wins.
INVERSE_DTYPES_DICT = {dt: name for name, dt in DTYPES_DICT.items()}
class GlobalCounters:
  """Process-wide counters kept as class attributes; reset() clears all of them except mem_used/mem_cached."""
  global_ops: ClassVar[int] = 0
  global_mem: ClassVar[int] = 0
  time_sum_s: ClassVar[float] = 0.0
  kernel_count: ClassVar[int] = 0
  mem_used: ClassVar[int] = 0     # NOTE: this is not reset
  mem_cached: ClassVar[int] = 0   # NOTE: this is not reset
  @staticmethod
  def reset():
    GlobalCounters.global_ops = 0
    GlobalCounters.global_mem = 0
    GlobalCounters.time_sum_s = 0.0
    GlobalCounters.kernel_count = 0
# *** universal database cache ***
# CACHEDB: path of the sqlite cache file (env CACHEDB). CACHELEVEL: env CACHELEVEL, default 2.
# NOTE(review): CACHELEVEL is not read anywhere in this part of the file — confirm where it is used.
CACHEDB, CACHELEVEL = getenv("CACHEDB", "/tmp/tinygrad_cache"), getenv("CACHELEVEL", 2)
# VERSION: db_connection() wipes a cache whose stored ("meta", "version") differs from it.
# _db_connection: the shared connection, opened lazily by db_connection().
VERSION, _db_connection = 6, None
def db_connection():
  """Return the shared sqlite3 connection to CACHEDB, opening it on first use.

  On first open the stored ("meta", "version") entry is compared with VERSION. If it differs or is missing, the
  connection is closed, the cache file is deleted and recreated, and the current VERSION is stored.
  """
  global _db_connection
  if _db_connection is None:
    def _connect():
      # open CACHEDB; at DEBUG >= 5 every executed SQL statement is printed
      conn = sqlite3.connect(CACHEDB)
      if DEBUG >= 5: conn.set_trace_callback(print)
      return conn
    _db_connection = _connect()
    # diskcache_get calls db_connection() again; _db_connection is already set, so that returns immediately
    if diskcache_get("meta", "version") != VERSION:
      print("cache is out of date, clearing it")
      _db_connection.close()  # release the stale handle before deleting its file (previously leaked; unlink fails on Windows)
      with contextlib.suppress(FileNotFoundError): os.unlink(CACHEDB)  # e.g. ":memory:" has no file to delete
      _db_connection = _connect()
      diskcache_put("meta", "version", VERSION)
  return _db_connection
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
  """Return the unpickled value stored under `key` in cache table `table`, or None if no row matches.

  A str/int key is shorthand for {"key": key}; a dict key must match every column it names. None is also
  returned when opening or querying raises sqlite3.OperationalError (e.g. the table has not been created yet).
  """
  cols = {"key": key} if isinstance(key, (str,int)) else key
  try:
    cursor = db_connection().cursor()
    where = ' AND '.join([f'{x}=?' for x in cols.keys()])
    res = cursor.execute(f"SELECT val FROM {table} WHERE {where}", tuple(cols.values()))
  except sqlite3.OperationalError:
    return None # table doesn't exist
  row = res.fetchone()
  if row is None: return None
  # NOTE(security): pickle.loads runs arbitrary code stored in the blob. CACHEDB defaults to a file in the shared
  # /tmp directory, so its contents are only trustworthy if no other user can write to that file.
  return pickle.loads(row[0])
_db_tables = set()  # tables already ensured to exist (CREATE TABLE IF NOT EXISTS) during this process
def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
  """Pickle `val`, store it under `key` in cache table `table` (replacing any existing row), commit, and return `val`.

  A str/int key is shorthand for {"key": key}; a dict key becomes the table's primary-key columns. On first use in
  this process the table is created, with column types taken from the Python type of each key value. A key value
  whose type is missing from TYPES raises KeyError.
  """
  if isinstance(key, (str,int)): key = {"key": key}
  conn = db_connection()
  with contextlib.closing(conn.cursor()) as cur:  # the cursor is closed even if a statement raises
    if table not in _db_tables:
      TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
      ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
      cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
      _db_tables.add(table)
    cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
    conn.commit()
  return val
def diskcache(func):
  """Decorator: memoize `func` in the sqlite disk cache, in table cache_<func.__name__>.

  The cache key is the sha256 of the pickled (args, kwargs), so arguments must be picklable. Any stored result is
  returned, including falsy ones such as b"", 0 or [], which were previously recomputed on every call. Because
  diskcache_get reports a miss as None, a function that returns None is still recomputed each time.
  """
  @functools.wraps(func)  # sets __wrapped__ (as the old setattr did) and also copies __name__/__doc__
  def wrapper(*args, **kwargs) -> Any:
    table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
    if (ret:=diskcache_get(table, key)) is not None: return ret
    return diskcache_put(table, key, func(*args, **kwargs))
  return wrapper