mirror of https://github.com/commaai/tinygrad.git
move state to nn/state (#1619)
This commit is contained in:
parent
1e93fd5449
commit
718ced296c
|
@ -145,7 +145,7 @@ def sparse_categorical_crossentropy(out, Y, ignore_index=-1):
|
|||
loss_mask = Y != ignore_index
|
||||
num_classes = out.shape[-1]
|
||||
y_counter = Tensor.arange(num_classes, requires_grad=False).unsqueeze(0).expand(Y.numel(), num_classes)
|
||||
y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0)
|
||||
y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0)
|
||||
y = y * loss_mask.reshape(-1, 1)
|
||||
y = y.reshape(*Y.shape, num_classes)
|
||||
return out.log_softmax().mul(y).sum() / loss_mask.sum()
|
||||
|
@ -165,7 +165,7 @@ opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)
|
|||
|
||||
We can see that we are passing in the parameters of our neural network to the optimizer.
|
||||
This is due to the fact that the optimizer needs to know which parameters to update.
|
||||
There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.state` which will return a list of all the parameters in the neural network.
|
||||
There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.nn.state` which will return a list of all the parameters in the neural network.
|
||||
The parameters are just listed out explicitly here for clarity.
|
||||
|
||||
Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
|
||||
|
@ -291,7 +291,7 @@ The standard weight format for tinygrad is [safetensors](https://github.com/hugg
|
|||
There are functions in [state.py](/tinygrad/nn/state.py) to save and load models to and from this format.
|
||||
|
||||
```python
|
||||
from tinygrad.state import safe_save, safe_load, get_state_dict, load_state_dict
|
||||
from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
|
||||
|
||||
# first we need the state dict of our model
|
||||
state_dict = get_state_dict(net)
|
||||
|
|
|
@ -3,7 +3,7 @@ import gc
|
|||
import time
|
||||
from tqdm import trange
|
||||
from models.efficientnet import EfficientNet
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from models.efficientnet import EfficientNet
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import safe_save
|
||||
from tinygrad.nn.state import safe_save
|
||||
from extra.utils import fetch
|
||||
from extra.export_model import export_model
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Tuple
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -132,7 +132,7 @@ class GPT2:
|
|||
@staticmethod
|
||||
def build(model_size="gpt2"):
|
||||
import tiktoken
|
||||
from tinygrad.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
from extra.utils import fetch_as_file
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ import random
|
|||
import numpy as np
|
||||
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
|
||||
from tinygrad import nn
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.tensor import Tensor
|
||||
|
|
|
@ -251,7 +251,7 @@ class LLaMa:
|
|||
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
|
||||
assert sp_model.vocab_size() == VOCAB_SIZE
|
||||
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
params = MODEL_PARAMS[model_gen][model_size]
|
||||
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
|
||||
weights = concat_weights([torch_load(filename) for filename in [f"{model_path}/{model_size}/consolidated.{i:02d}.pth" for i in range(params["files"])]])
|
||||
|
|
|
@ -3,7 +3,7 @@ import numpy as np
|
|||
from tqdm import trange
|
||||
import torch
|
||||
from torchvision.utils import make_grid, save_image
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.nn import optim
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
|
||||
import sys
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import BatchNorm2d, optim
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.training = True
|
||||
|
|
|
@ -7,7 +7,7 @@ from typing import Tuple, Optional, Type
|
|||
from tinygrad import nn
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes, getenv
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, download_if_not_present, get_hparams_from_file, load_checkpoint, weight_norm, HParams
|
||||
from examples.sovits_helpers import preprocess
|
||||
import soundfile
|
||||
|
|
|
@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
|
|||
from tinygrad.helpers import dtypes, GlobalCounters
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict, get_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
|
|
|
@ -3,7 +3,7 @@ import time
|
|||
from multiprocessing import Process, Queue
|
||||
import numpy as np
|
||||
from tqdm import trange
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.tensor import Tensor
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.training import train, evaluate
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import numpy as np
|
||||
import random
|
||||
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.training import train, evaluate
|
||||
from models.transformer import Transformer
|
||||
|
|
|
@ -5,7 +5,7 @@ from typing import List
|
|||
from extra.utils import download_file
|
||||
from tinygrad import nn
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from tinygrad.tensor import Tensor
|
||||
from unidecode import unidecode
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ import multiprocessing
|
|||
import numpy as np
|
||||
from typing import Optional
|
||||
from extra.utils import download_file
|
||||
from tinygrad.state import torch_load, load_state_dict
|
||||
from tinygrad.nn.state import torch_load, load_state_dict
|
||||
from tinygrad.helpers import getenv
|
||||
import tinygrad.nn as nn
|
||||
from tinygrad.tensor import Tensor
|
||||
|
|
|
@ -8,7 +8,7 @@ import cv2
|
|||
from collections import defaultdict
|
||||
import os
|
||||
import time, io, sys
|
||||
from tinygrad.state import safe_load, load_state_dict
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
|
||||
|
||||
#Model architecture from https://github.com/ultralytics/ultralytics/issues/189
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Tuple, Dict, List
|
|||
from tinygrad.helpers import DType
|
||||
from tinygrad.tensor import Device, Tensor
|
||||
from tinygrad.jit import TinyJit
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
import json
|
||||
|
||||
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:
|
||||
|
|
|
@ -143,7 +143,7 @@ class EfficientNet:
|
|||
}
|
||||
|
||||
from extra.utils import fetch_as_file
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
b0 = torch_load(fetch_as_file(model_urls[self.number]))
|
||||
for k,v in b0.items():
|
||||
if k.endswith("num_batches_tracked"): continue
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad import nn
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import dtypes
|
||||
from extra.utils import get_child, download_file
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from models.resnet import ResNet
|
||||
from models.retinanet import nms as _box_nms
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
import unittest, gc
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
|
||||
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
|
|
|
@ -36,7 +36,7 @@ from models.convnext import ConvNeXt
|
|||
from models.efficientnet import EfficientNet
|
||||
from models.resnet import ResNet18
|
||||
from models.vit import ViT
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
|
||||
class TestInferenceMinKernels(unittest.TestCase):
|
||||
|
|
|
@ -5,7 +5,7 @@ from examples.llama import Transformer, MODEL_PARAMS
|
|||
from test.test_net_speed import start_profile, stop_profile
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import Device
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
from tinygrad.ops import Compiled
|
||||
from tinygrad.helpers import dtypes, prod
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
|
|
|
@ -6,29 +6,29 @@ import unittest
|
|||
import io, cv2, os
|
||||
import onnxruntime as ort
|
||||
import ultralytics
|
||||
from tinygrad.state import safe_load, load_state_dict
|
||||
from tinygrad.nn.state import safe_load, load_state_dict
|
||||
|
||||
class TestYOLOv8(unittest.TestCase):
|
||||
def test_all_load_weights(self):
|
||||
for variant in ['n', 's', 'm', 'l', 'x']:
|
||||
weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
|
||||
download_file(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors', weights_location)
|
||||
|
||||
depth, width, ratio = get_variant_multiples(variant)
|
||||
TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
|
||||
depth, width, ratio = get_variant_multiples(variant)
|
||||
TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
state_dict = safe_load(weights_location)
|
||||
load_state_dict(TinyYolov8, state_dict)
|
||||
print(f'successfully loaded weights for yolov{variant}')
|
||||
|
||||
|
||||
def test_predictions(self):
|
||||
test_image_urls = ['https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg', 'https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg']
|
||||
variant = 'n'
|
||||
weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
|
||||
depth, width, ratio = get_variant_multiples(variant)
|
||||
TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
depth, width, ratio = get_variant_multiples(variant)
|
||||
TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
state_dict = safe_load(weights_location)
|
||||
load_state_dict(TinyYolov8, state_dict)
|
||||
|
||||
|
||||
for i in range(len(test_image_urls)):
|
||||
img_stream = io.BytesIO(fetch(test_image_urls[i]))
|
||||
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
|
||||
|
@ -37,41 +37,40 @@ class TestYOLOv8(unittest.TestCase):
|
|||
post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img])
|
||||
labels = label_predictions(post_predictions)
|
||||
assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 13, 29: 1, 32: 1}
|
||||
|
||||
|
||||
def test_forward_pass_torch_onnx(self):
|
||||
variant = 'n'
|
||||
weights_location_onnx = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.onnx'
|
||||
weights_location_pt = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.pt'
|
||||
weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
|
||||
weights_location_onnx = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.onnx'
|
||||
weights_location_pt = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.pt'
|
||||
weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
|
||||
|
||||
download_file(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', weights_location_pt)
|
||||
# the ultralytics export prints a lot of unnecessary things
|
||||
if not os.path.isfile(weights_location_onnx):
|
||||
model = ultralytics.YOLO(model=weights_location_pt, task='Detect')
|
||||
model.export(format="onnx",imgsz=[640, 480])
|
||||
model = ultralytics.YOLO(model=weights_location_pt, task='Detect')
|
||||
model.export(format="onnx",imgsz=[640, 480])
|
||||
|
||||
depth, width, ratio = get_variant_multiples(variant)
|
||||
TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
depth, width, ratio = get_variant_multiples(variant)
|
||||
TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
|
||||
state_dict = safe_load(weights_location)
|
||||
load_state_dict(TinyYolov8, state_dict)
|
||||
|
||||
|
||||
image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg')).read(), np.uint8)]
|
||||
orig_image = [cv2.imdecode(image_location[0], 1)]
|
||||
|
||||
|
||||
input_image = preprocess(orig_image)
|
||||
|
||||
|
||||
onnx_session = ort.InferenceSession(weights_location_onnx)
|
||||
onnx_input_name = onnx_session.get_inputs()[0].name
|
||||
onnx_output_name = onnx_session.get_outputs()[0].name
|
||||
onnx_output = onnx_session.run([onnx_output_name], {onnx_input_name: input_image.numpy()})
|
||||
|
||||
tiny_output = TinyYolov8(input_image)
|
||||
|
||||
# currently rtol is 0.025 because there is a 1-2% difference in our predictions
|
||||
# because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch.
|
||||
# This difference does not make a difference "visually".
|
||||
|
||||
# currently rtol is 0.025 because there is a 1-2% difference in our predictions
|
||||
# because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch.
|
||||
# This difference does not make a difference "visually".
|
||||
np.testing.assert_allclose(onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
import unittest
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d, BatchNorm2d, optim
|
||||
|
||||
|
@ -33,7 +33,7 @@ class TestBatchnorm(unittest.TestCase):
|
|||
return self.c2(self.c(x)).relu()
|
||||
lm = LilModel()
|
||||
model_step(lm)
|
||||
|
||||
|
||||
def test_two_conv_bn(self):
|
||||
class LilModel:
|
||||
def __init__(self):
|
||||
|
|
|
@ -2,7 +2,7 @@ import numpy as np
|
|||
import torch
|
||||
import unittest
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
|
||||
from extra.training import train, evaluate
|
||||
|
|
|
@ -6,9 +6,9 @@ from unittest.mock import patch, MagicMock
|
|||
|
||||
import torch
|
||||
import numpy as np
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.utils import fetch, temp, download_file
|
||||
from tinygrad.state import torch_load
|
||||
from tinygrad.nn.state import torch_load
|
||||
from PIL import Image
|
||||
|
||||
@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
|
||||
|
@ -33,7 +33,7 @@ class TestFetchRelative(unittest.TestCase):
|
|||
os.chdir(self.tempdir.name)
|
||||
with open('test_file.txt', 'x') as f:
|
||||
f.write("12345")
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
os.chdir(self.working_dir)
|
||||
self.tempdir.cleanup()
|
||||
|
@ -41,7 +41,7 @@ class TestFetchRelative(unittest.TestCase):
|
|||
#test ./
|
||||
def test_fetch_relative_dotslash(self):
|
||||
self.assertEqual(b'12345', fetch("./test_file.txt"))
|
||||
|
||||
|
||||
#test ../
|
||||
def test_fetch_relative_dotdotslash(self):
|
||||
os.mkdir('test_file_path')
|
||||
|
@ -92,7 +92,7 @@ class TestUtils(unittest.TestCase):
|
|||
)
|
||||
if isfloat16: model = model.half()
|
||||
|
||||
path = temp(f"test_load_{isfloat16}.pt")
|
||||
path = temp(f"test_load_{isfloat16}.pt")
|
||||
torch.save(model.state_dict(), path)
|
||||
model2 = torch_load(path)
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ def get_question_samp(bsz, seq_len, vocab_size, seed):
|
|||
return in_ids, mask, seg_ids
|
||||
|
||||
def set_equal_weights(mdl, torch_mdl):
|
||||
from tinygrad.state import get_state_dict
|
||||
from tinygrad.nn.state import get_state_dict
|
||||
state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
|
||||
assert len(state) == len(torch_state)
|
||||
for k, v in state.items():
|
||||
|
|
|
@ -2,7 +2,7 @@ import torch
|
|||
from torch import nn
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
|
||||
from tinygrad.tensor import Tensor
|
||||
from extra.datasets import fetch_mnist
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import optim, BatchNorm2d
|
||||
from extra.training import train, evaluate
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest, time
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
|
||||
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
|
||||
from tinygrad.lazy import Device
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
from tinygrad.tensor import Device
|
||||
from tinygrad.helpers import getenv
|
||||
|
|
|
@ -2,7 +2,7 @@ import pathlib
|
|||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from tinygrad.helpers import Timing
|
||||
|
|
Loading…
Reference in New Issue