move memory_planner to memory.py [pr] (#7079)

George Hotz 2024-10-16 10:04:35 +08:00 committed by GitHub
parent bddba5897a
commit 26df50cf43
6 changed files with 61 additions and 53 deletions


@@ -6,7 +6,8 @@ Device.DEFAULT = "CLANG"
from train_gpt2 import GPT, GPTConfig
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_kernel, memory_planner, run_schedule
from tinygrad.engine.realize import get_kernel, run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.ops import MetaOps, UOps
TIMING = getenv("TIMING")


@@ -16,7 +16,8 @@ from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.device import Buffer
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG, tqdm
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner, memory_planner
from tinygrad.engine.realize import run_schedule, lower_schedule, ExecItem, CompiledRunner
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule
from tinygrad.ops import UOps
from tinygrad.tensor import _to_np_dtype


@@ -8,7 +8,8 @@ from tinygrad.device import Buffer, Compiled, Device
from tinygrad.dtype import DType
from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from dataclasses import dataclass
from weakref import WeakKeyDictionary
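
The hunk above keeps importing the lower-level _internal_memory_planner, which works on plain groups of Buffers rather than on ScheduleItems and returns a dict mapping each planned buffer to its (possibly shared) replacement. A minimal sketch of that interface, assuming a working CLANG backend; the sizes and the two-step liveness below are made up for illustration and are not part of this commit:

# Hedged sketch: two buffers whose lifetimes do not overlap can share storage.
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes
from tinygrad.engine.memory import _internal_memory_planner

a = Buffer("CLANG", 1024, dtypes.float)  # unallocated, so it is a planning candidate
b = Buffer("CLANG", 1024, dtypes.float)
# a is live only in step 0 and b only in step 1, so b can be assigned a's storage
assigned = _internal_memory_planner([[a], [b]])
print(assigned[b] is a)  # expected: True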

tinygrad/engine/memory.py (new file, 51 lines)

@@ -0,0 +1,51 @@
from typing import List, Union, Tuple, Dict
from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG
from tinygrad.ops import UOps
# **************** memory planning ****************
def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}
  first_appearance, last_appearance = {}, {}
  for i,u in enumerate(buffers):
    for buf in u:
      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
      if buf.base not in first_appearance: first_appearance[buf.base] = i
      last_appearance[buf.base] = i
  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
  def find_replace_buffer(buf, st, en):
    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if a replacement is not found.
    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
  for i,u in enumerate(buffers):
    for buf in u:
      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
      else: assigned[buf] = assigned.get(buf, buf)
  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
          f"{len(ak)} -> {len(av)} bufs")
  return assigned
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g. transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([si.bufs for si in schedule],
                                      noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs})
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
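
memory_planner keeps its old signature and behavior, so the other hunks in this commit only repoint an import. A rough end-to-end sketch of where the planner sits; the Tensor expression is illustrative and not part of this commit:

# Hedged sketch: plan buffer reuse for a schedule, then execute it.
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner

out = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).relu()
schedule = create_schedule([out.lazydata])  # List[ScheduleItem]
schedule = memory_planner(schedule)         # intermediates may now alias each other's Buffers
run_schedule(schedule)                      # execute the planned schedule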


@@ -1,9 +1,7 @@
from typing import List, Dict, Optional, cast, Generator, Tuple, Union
from typing import List, Dict, Optional, cast, Generator, Tuple
import time, pprint
from collections import defaultdict
from dataclasses import dataclass, replace
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
from tinygrad.helpers import NO_MEMORY_PLANNER
from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA
from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint
from tinygrad.dtype import dtypes
from tinygrad.device import Device, Buffer
@@ -216,48 +214,3 @@ def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, i
  for ei in lower_schedule(schedule):
    if len(capturing) and CAPTURING: capturing[0].add(ei)
    ei.run(var_vals, do_update_stats=do_update_stats)
# **************** memory planning ****************
def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]]], noopt_buffers=None, debug_prefix="") -> Dict[Buffer, Buffer]:
  if NO_MEMORY_PLANNER: return {}
  first_appearance, last_appearance = {}, {}
  for i,u in enumerate(buffers):
    for buf in u:
      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
      if buf.base not in first_appearance: first_appearance[buf.base] = i
      last_appearance[buf.base] = i
  # Sort buffers by size in descending order, prioritizing largest buffers for allocation first.
  # Track free segments, each containing (start, stop, and buffer that could be reused on this segment).
  free_segs: Dict[Tuple, List[Tuple[int, int, Buffer]]] = defaultdict(list) # Dict[buffer key, Tuple[start, end, buffer to reuse on the seg]]
  def find_replace_buffer(buf, st, en):
    key = (buf.device, buf.dtype, buf.options) + ((buf.nbytes,) if not hasattr(Device[buf.device].allocator, "offset") else tuple())
    default_buf = (0, len(buffers) - 1, buf) # will return the buffer itself if the replace one is not found.
    seg_st, seg_en, seg_buf = next((free_segs[key].pop(i) for i,(sst,sen,_) in enumerate(free_segs[key]) if sst <= st and en <= sen), default_buf)
    free_segs[key] += [(seg_st, st - 1, seg_buf)] if st - 1 >= seg_st else []
    free_segs[key] += [(en + 1, seg_en, seg_buf)] if seg_en >= en + 1 else []
    return seg_buf if seg_buf.nbytes == buf.nbytes else Buffer(buf.device, buf.size, buf.dtype, base=seg_buf)
  buffer_requests = sorted([(first_appearance[buf], last_appearance[buf], buf) for buf in first_appearance.keys()], key=lambda x: -x[2].nbytes)
  assigned = {buf:find_replace_buffer(buf, st, en) for st, en, buf in buffer_requests}
  for i,u in enumerate(buffers):
    for buf in u:
      if buf.is_allocated() or buf.lb_refcount > 0 or (noopt_buffers is not None and buf.base in noopt_buffers): continue
      if buf._base is not None: assigned[buf] = Buffer(buf.device, buf.size, buf.dtype, base=assigned.get(buf.base, buf.base).base, offset=buf.offset)
      else: assigned[buf] = assigned.get(buf, buf)
  if DEBUG >= 1 and len(ak:=dedup(x for x in assigned.keys() if x._base is None)) != len(av:=dedup(x for x in assigned.values() if x._base is None)):
    print(debug_prefix+f"memory reduced from {sum([x.nbytes for x in ak])/1e6:.2f} MB -> {sum([x.nbytes for x in av])/1e6:.2f} MB,",
          f"{len(ak)} -> {len(av)} bufs")
  return assigned
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
  # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
  assigned = _internal_memory_planner([si.bufs for si in schedule],
                                      noopt_buffers={b for si in schedule if si.ast.op is not UOps.SINK for b in si.bufs})
  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]


@@ -12,7 +12,8 @@ from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
# **** start with two base classes, Tensor and Function ****
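
The caller in this last hunk now takes memory_planner from the new module alongside create_schedule_with_vars and run_schedule; presumably it chains them in that order. A sketch under that assumption, with the contiguous Tensor only there to give the schedule a kernel:

# Hedged sketch: the symbolic-variable flavor of the same pipeline.
from tinygrad import Tensor
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.memory import memory_planner

t = Tensor.ones(8, 8).contiguous()
schedule, var_vals = create_schedule_with_vars([t.lazydata])
run_schedule(memory_planner(schedule), var_vals)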