diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 611ef547..005551bb 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -1,14 +1,13 @@ import unittest, math from tinygrad import Tensor, Device, dtypes from tinygrad.engine.schedule import create_schedule -from tinygrad.features.multi import MultiLazyBuffer from tinygrad.helpers import CI from tinygrad.ops import BufferOps import numpy as np def _check_ast_count(desired_count:int, t:Tensor): # NOTE: this has side effect because everything can be scheduled only once - schedule = create_schedule(t.lazydata.lbs if isinstance(t.lazydata, MultiLazyBuffer) else [t.lazydata]) + schedule = create_schedule(t.lazydata.lbs) asts = [s for s in schedule if s.ast[0].op is BufferOps.STORE] assert len(asts) == desired_count diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 6fff66a8..35be3fe7 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -138,7 +138,7 @@ class Tensor: @staticmethod def corealize(lst:Iterable[Tensor]): - run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst]))) + run_schedule(*create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))) def realize(self) -> Tensor: Tensor.corealize([self])