move lazy to engine [pr] (#6886)

* move lazy to engine [pr]

* engine.lazy
George Hotz authored 2024-10-04 23:19:26 +08:00, committed by GitHub
parent 6b063450df
commit 4df5c7a4ef
20 changed files with 20 additions and 20 deletions
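
The change is a pure module move: tinygrad/lazy.py becomes tinygrad/engine/lazy.py, and every in-tree importer is updated accordingly. Out-of-tree code that has to run on both sides of #6886 can use a small compatibility shim; this is a sketch, not part of the PR:

# sketch: import LazyBuffer from whichever location this tinygrad checkout provides
try:
  from tinygrad.engine.lazy import LazyBuffer  # layout after this commit
except ImportError:
  from tinygrad.lazy import LazyBuffer         # layout before this commit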

@@ -76,7 +76,7 @@ assert out.as_buffer().cast('I')[0] == 5
 print("******** third, the LazyBuffer ***********")
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule

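The hunk above is from the walkthrough that introduces the LazyBuffer layer ("third, the LazyBuffer"); only its import changes. For orientation, driving that layer by hand looks roughly like this with the new path (a sketch assuming a single-device Tensor, whose lazydata is a plain LazyBuffer, and the create_schedule/run_schedule helpers imported above):

from tinygrad import Tensor
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule

out = Tensor([2]) + Tensor([3])             # nothing runs yet: this only builds a LazyBuffer graph
assert isinstance(out.lazydata, LazyBuffer)
schedule = create_schedule([out.lazydata])  # linearize the graph into ScheduleItems
run_schedule(schedule)                      # compile and execute the kernels
print(out.numpy())                          # -> [5]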

@@ -11,7 +11,7 @@ There is a good [bunch of tutorials](https://mesozoic-egg.github.io/tinygrad-not
 Everything in [Tensor](../tensor/index.md) is syntactic sugar around [function.py](function.md), where the forwards and backwards passes are implemented for the different functions. There's about 25 of them, implemented using about 20 basic ops. Those basic ops go on to construct a graph of:
-::: tinygrad.lazy.LazyBuffer
+::: tinygrad.engine.lazy.LazyBuffer
     options:
         show_source: false

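The documentation hunk above describes the path from Tensor through function.py down to a graph of LazyBuffer; the only change is where that class now lives. A minimal illustration of that layering, assuming Relu is one of the Function subclasses in function.py and that Function.apply is the entry point Tensor methods use:

from tinygrad import Tensor
from tinygrad.function import Relu           # one of the ~25 Function subclasses
from tinygrad.engine.lazy import LazyBuffer

x = Tensor([-1.0, 0.0, 2.0])
y = Relu.apply(x)                            # what Tensor.relu() dispatches to
assert isinstance(y.lazydata, LazyBuffer)    # the op only extended the lazy graph
print(y.numpy())                             # -> [0. 0. 2.]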

@@ -1,7 +1,7 @@
 # TODO: move the GRAPH and DEBUG stuff to here
 import gc
 from tinygrad.helpers import prod
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.device import Buffer
 from tinygrad import Tensor, GlobalCounters

@@ -6,7 +6,7 @@ from tinygrad import Tensor
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.helpers import Context, getenv, to_function_name
 from tinygrad.engine.schedule import _get_output_groups, _lower_lazybuffer
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.ops import UOp, UOps
 if __name__ == "__main__":

@@ -1,6 +1,6 @@
 import time
 from tinygrad import Tensor, Device, GlobalCounters, TinyJit
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.ops import ReduceOps
 from tinygrad.multi import MultiLazyBuffer, all_reduce
 from tinygrad.engine.schedule import create_schedule

@@ -4,7 +4,7 @@ from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
 from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
 from tinygrad.ops import MetaOps
 from tinygrad.tensor import Tensor, _to_np_dtype

@@ -5,7 +5,7 @@ from typing import DefaultDict, List, Set, Tuple
 from test.external.process_replay.helpers import print_diff
 from tinygrad.engine.schedule import LBScheduleItem, ScheduleItem
 from tinygrad.helpers import CI, DEBUG, Context, ContextVar, colored, diskcache_put, fetch, getenv
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item
 CAPTURING_PROCESS_REPLAY = ContextVar("CAPTURING_PROCESS_REPLAY", getenv("RUN_PROCESS_REPLAY"))

@@ -4,7 +4,7 @@ from test.external.process_replay.diff_schedule import CAPTURING_PROCESS_REPLAY,
 from tinygrad import Tensor, nn
 from tinygrad.helpers import Context
 from tinygrad.engine.schedule import _graph_schedule
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 class TestDiffSchedule(unittest.TestCase):
   def setUp(self):

@@ -9,7 +9,7 @@ from tinygrad.dtype import dtypes
 # *** first, we implement the atan2 op at the lowest level ***
 # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
-from tinygrad.lazy import Buffer, create_lazybuffer
+from tinygrad.engine.lazy import Buffer, create_lazybuffer
 from tinygrad.device import Device
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.engine.realize import CompiledRunner
@@ -32,7 +32,7 @@ def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(
 # In general, it is also optional to write a backward function, just your backward pass won't work without it
 from tinygrad.ops import MetaOps
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.tensor import Function
 class ATan2(Function):

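The comment in the second hunk above notes that backward is optional, but gradients will not work without it. For orientation, the forward/backward shape of a Function looks roughly like the following; this is a sketch modeled on the elementwise ops in function.py as of this commit, and the LazyBuffer.e / needs_input_grad details are assumptions about that version's internals:

from typing import Optional, Tuple
from tinygrad.ops import BinaryOps
from tinygrad.tensor import Function
from tinygrad.engine.lazy import LazyBuffer

class Mul(Function):
  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y                      # save inputs for the backward pass
    return x.e(BinaryOps.MUL, y)               # build the lazy elementwise op
  def backward(self, grad:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    # d(x*y)/dx = y, d(x*y)/dy = x; skip inputs that don't need a gradient
    return self.y.e(BinaryOps.MUL, grad) if self.needs_input_grad[0] else None, \
           self.x.e(BinaryOps.MUL, grad) if self.needs_input_grad[1] else None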

@@ -3,7 +3,7 @@ import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.ops import UOps
-from tinygrad.lazy import LazyBuffer, MetaOps
+from tinygrad.engine.lazy import LazyBuffer, MetaOps
 from tinygrad.engine.schedule import create_schedule
 class TestLazyBuffer(unittest.TestCase):

@@ -10,7 +10,7 @@ from tinygrad.renderer.cstyle import CStyleLanguage
 from tinygrad.ops import BinaryOps, UOp, UOps
 from tinygrad.renderer import Program
 from tinygrad.tensor import Tensor, _to_np_dtype
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
   for x in inputs: x.realize()

@@ -20,7 +20,7 @@ from tinygrad.codegen.kernel import Kernel, verify_ast
 from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, reduceop_fusor, st_fixup
 from tinygrad.engine.realize import CompiledRunner, run_schedule
 from test.helpers import ast_const, is_dtype_supported, Context, timeit
-from tinygrad.lazy import LazyBuffer, view_supported_devices
+from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
 from extra.models.llama import precompute_freqs_cis
 class KernelCountException(Exception): pass

@@ -1,7 +1,7 @@
 import unittest
 from tinygrad import Device, dtypes, Tensor
 from tinygrad.device import Buffer
-from tinygrad.lazy import view_supported_devices
+from tinygrad.engine.lazy import view_supported_devices
 @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported")
 class TestSubBuffer(unittest.TestCase):

@@ -4,7 +4,7 @@ from typing import List, Any, DefaultDict
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, UOps, UOp
 from tinygrad.device import Device
 from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 with contextlib.suppress(ImportError): import networkx as nx

@@ -2,7 +2,7 @@ from __future__ import annotations
 from typing import TypeVar, Generic, Callable, List, Tuple, Union, Dict, cast, Optional, Any
 import functools, itertools, collections
 from tinygrad.tensor import Tensor
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType

@@ -7,7 +7,7 @@ from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV
   GlobalCounters, all_same, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata, unwrap
 from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ImageDType, dtypes
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
 from tinygrad.shape.view import View, strides_for_shape

@@ -5,7 +5,7 @@ from tinygrad.helpers import argsort
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import ReduceOps, resolve
 from tinygrad.tensor import Function
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.symbolic import sint
 class Contiguous(Function):

@@ -4,7 +4,7 @@ import functools, itertools, operator
 from tinygrad.helpers import all_same, all_int, dedup, prod, DEBUG, RING, getenv
 from tinygrad.dtype import DType
 from tinygrad.ops import REDUCE_ALU, BinaryOps, MetaOps, UnaryOps, TernaryOps, ReduceOps, MathTrait
-from tinygrad.lazy import LazyBuffer
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import sint
 def all_reduce(op: ReduceOps, lbs: List[LazyBuffer]) -> List[LazyBuffer]:

@@ -8,11 +8,11 @@ from collections import defaultdict
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
 from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
-from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
 from tinygrad.device import Device, Buffer, BufferOptions
 from tinygrad.shape.symbolic import sint, Variable
+from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars