move state to nn/state (#1619)

George Hotz 2023-08-22 07:36:24 -07:00 committed by GitHub
parent 1e93fd5449
commit 718ced296c
35 changed files with 64 additions and 65 deletions
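In practice the change is a mechanical rename of the module path; an illustrative before/after for the symbols touched in this commit (a sketch, not a line taken from the diff) looks like:

```python
# before this commit
from tinygrad.state import get_parameters, get_state_dict, load_state_dict, safe_save, safe_load, torch_load

# after this commit
from tinygrad.nn.state import get_parameters, get_state_dict, load_state_dict, safe_save, safe_load, torch_load
```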

View File

@@ -145,7 +145,7 @@ def sparse_categorical_crossentropy(out, Y, ignore_index=-1):
loss_mask = Y != ignore_index
num_classes = out.shape[-1]
y_counter = Tensor.arange(num_classes, requires_grad=False).unsqueeze(0).expand(Y.numel(), num_classes)
-y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0)
+y = (y_counter == Y.flatten().reshape(-1, 1)).where(-1.0, 0)
y = y * loss_mask.reshape(-1, 1)
y = y.reshape(*Y.shape, num_classes)
return out.log_softmax().mul(y).sum() / loss_mask.sum()
@@ -165,7 +165,7 @@ opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)
We can see that we are passing in the parameters of our neural network to the optimizer.
This is due to the fact that the optimizer needs to know which parameters to update.
-There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.state` which will return a list of all the parameters in the neural network.
+There is a simpler way to do this just by using `get_parameters(net)` from `tinygrad.nn.state` which will return a list of all the parameters in the neural network.
The parameters are just listed out explicitly here for clarity.
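As a hedged sketch of what the quoted docs describe (assuming the same `net` and `SGD` as in the quickstart snippet above), the explicit parameter list can be replaced with:

```python
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters

# collects every parameter tensor reachable from `net`,
# equivalent here to SGD([net.l1.weight, net.l2.weight], lr=3e-4)
opt = SGD(get_parameters(net), lr=3e-4)
```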
Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
@@ -291,7 +291,7 @@ The standard weight format for tinygrad is [safetensors](https://github.com/hugg
There are functions in [state.py](/tinygrad/state.py) to save and load models to and from this format.
```python
-from tinygrad.state import safe_save, safe_load, get_state_dict, load_state_dict
+from tinygrad.nn.state import safe_save, safe_load, get_state_dict, load_state_dict
# first we need the state dict of our model
state_dict = get_state_dict(net)
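# illustrative continuation, not part of the diff: a hedged sketch of the
# save/load round trip with the renamed module (filename is arbitrary)
safe_save(state_dict, "model.safetensors")            # write the weights to disk
load_state_dict(net, safe_load("model.safetensors"))  # restore them later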

View File

@@ -3,7 +3,7 @@ import gc
import time
from tqdm import trange
from models.efficientnet import EfficientNet
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters

View File

@@ -1,6 +1,6 @@
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
-from tinygrad.state import safe_save
+from tinygrad.nn.state import safe_save
from extra.utils import fetch
from extra.export_model import export_model
from tinygrad.helpers import getenv

View File

@@ -1,7 +1,7 @@
from typing import Optional, Tuple
from numpy.typing import NDArray
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.helpers import getenv

View File

@@ -132,7 +132,7 @@ class GPT2:
@staticmethod
def build(model_size="gpt2"):
import tiktoken
-from tinygrad.state import torch_load, load_state_dict, get_state_dict
+from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from extra.utils import fetch_as_file
tokenizer = tiktoken.get_encoding("gpt2")

View File

@@ -14,7 +14,7 @@ import random
import numpy as np
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
from tinygrad import nn
-from tinygrad.state import get_state_dict
+from tinygrad.nn.state import get_state_dict
from tinygrad.nn import optim
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor

View File

@@ -251,7 +251,7 @@ class LLaMa:
sp_model = SentencePieceProcessor(model_file=str(tokenizer_path))
assert sp_model.vocab_size() == VOCAB_SIZE
-from tinygrad.state import torch_load, load_state_dict
+from tinygrad.nn.state import torch_load, load_state_dict
params = MODEL_PARAMS[model_gen][model_size]
model = Transformer(**params["args"], linear=AbsmaxQuantizedLinear) if quantize else Transformer(**params["args"])
weights = concat_weights([torch_load(filename) for filename in [f"{model_path}/{model_size}/consolidated.{i:02d}.pth" for i in range(params["files"])]])

View File

@@ -3,7 +3,7 @@ import numpy as np
from tqdm import trange
import torch
from torchvision.utils import make_grid, save_image
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv
from tinygrad.nn import optim

View File

@@ -2,7 +2,7 @@
#inspired by https://github.com/Matuzas77/MNIST-0.17/blob/master/MNIST_final_solution.ipynb
import sys
import numpy as np
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import BatchNorm2d, optim
from tinygrad.helpers import getenv

View File

@@ -2,7 +2,7 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
if __name__ == "__main__":
Tensor.training = True

View File

@@ -7,7 +7,7 @@ from typing import Tuple, Optional, Type
from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, getenv
-from tinygrad.state import torch_load
+from tinygrad.nn.state import torch_load
from examples.vits import ResidualCouplingBlock, PosteriorEncoder, Encoder, ResBlock1, ResBlock2, LRELU_SLOPE, sequence_mask, split, download_if_not_present, get_hparams_from_file, load_checkpoint, weight_norm, HParams
from examples.sovits_helpers import preprocess
import soundfile

View File

@@ -12,7 +12,7 @@ from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes, GlobalCounters
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file
-from tinygrad.state import torch_load, load_state_dict, get_state_dict
+from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
class AttnBlock:
def __init__(self, in_channels):

View File

@@ -3,7 +3,7 @@ import time
from multiprocessing import Process, Queue
import numpy as np
from tqdm import trange
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor

View File

@@ -2,7 +2,7 @@
import numpy as np
from PIL import Image
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.helpers import getenv
from extra.training import train, evaluate

View File

@@ -2,7 +2,7 @@
import numpy as np
import random
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import Adam
from extra.training import train, evaluate
from models.transformer import Transformer

View File

@@ -5,7 +5,7 @@ from typing import List
from extra.utils import download_file
from tinygrad import nn
from tinygrad.helpers import dtypes
-from tinygrad.state import torch_load
+from tinygrad.nn.state import torch_load
from tinygrad.tensor import Tensor
from unidecode import unidecode

View File

@@ -7,7 +7,7 @@ import multiprocessing
import numpy as np
from typing import Optional
from extra.utils import download_file
-from tinygrad.state import torch_load, load_state_dict
+from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv
import tinygrad.nn as nn
from tinygrad.tensor import Tensor

View File

@@ -8,7 +8,7 @@ import cv2
from collections import defaultdict
import os
import time, io, sys
-from tinygrad.state import safe_load, load_state_dict
+from tinygrad.nn.state import safe_load, load_state_dict
#Model architecture from https://github.com/ultralytics/ultralytics/issues/189

View File

@@ -2,7 +2,7 @@ from typing import Tuple, Dict, List
from tinygrad.helpers import DType
from tinygrad.tensor import Device, Tensor
from tinygrad.jit import TinyJit
-from tinygrad.state import get_state_dict
+from tinygrad.nn.state import get_state_dict
import json
def compile_net(run:TinyJit, special_names:Dict[int,str]) -> Tuple[Dict[str,str],List[Tuple[str,List[str],List[int]]],Dict[str,Tuple[int,DType,int]],Dict[str,Tensor]]:

View File

@@ -143,7 +143,7 @@ class EfficientNet:
}
from extra.utils import fetch_as_file
-from tinygrad.state import torch_load
+from tinygrad.nn.state import torch_load
b0 = torch_load(fetch_as_file(model_urls[self.number]))
for k,v in b0.items():
if k.endswith("num_batches_tracked"): continue

View File

@@ -7,7 +7,7 @@ from tinygrad import nn
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from extra.utils import get_child, download_file
-from tinygrad.state import torch_load
+from tinygrad.nn.state import torch_load
from models.resnet import ResNet
from models.retinanet import nms as _box_nms

View File

@@ -2,7 +2,7 @@
import unittest, gc
import numpy as np
from tinygrad.tensor import Tensor
-from tinygrad.state import get_parameters, get_state_dict
+from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod

View File

@@ -36,7 +36,7 @@ from models.convnext import ConvNeXt
from models.efficientnet import EfficientNet
from models.resnet import ResNet18
from models.vit import ViT
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestInferenceMinKernels(unittest.TestCase):

View File

@@ -5,7 +5,7 @@ from examples.llama import Transformer, MODEL_PARAMS
from test.test_net_speed import start_profile, stop_profile
from tinygrad.tensor import Tensor
from tinygrad.lazy import Device
-from tinygrad.state import get_state_dict
+from tinygrad.nn.state import get_state_dict
from tinygrad.ops import Compiled
from tinygrad.helpers import dtypes, prod
from tinygrad.runtime.lib import RawBuffer

View File

@@ -6,29 +6,29 @@ import unittest
import io, cv2, os
import onnxruntime as ort
import ultralytics
-from tinygrad.state import safe_load, load_state_dict
+from tinygrad.nn.state import safe_load, load_state_dict
class TestYOLOv8(unittest.TestCase):
def test_all_load_weights(self):
for variant in ['n', 's', 'm', 'l', 'x']:
weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
download_file(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors', weights_location)
-depth, width, ratio = get_variant_multiples(variant)
-TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
+depth, width, ratio = get_variant_multiples(variant)
+TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
state_dict = safe_load(weights_location)
load_state_dict(TinyYolov8, state_dict)
print(f'successfully loaded weights for yolov{variant}')
def test_predictions(self):
test_image_urls = ['https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg', 'https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg']
variant = 'n'
weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
-depth, width, ratio = get_variant_multiples(variant)
-TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
+depth, width, ratio = get_variant_multiples(variant)
+TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
state_dict = safe_load(weights_location)
load_state_dict(TinyYolov8, state_dict)
for i in range(len(test_image_urls)):
img_stream = io.BytesIO(fetch(test_image_urls[i]))
img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1)
@@ -37,41 +37,40 @@ class TestYOLOv8(unittest.TestCase):
post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img])
labels = label_predictions(post_predictions)
assert labels == {5: 1, 0: 4, 11: 1} if i == 0 else labels == {0: 13, 29: 1, 32: 1}
def test_forward_pass_torch_onnx(self):
variant = 'n'
-weights_location_onnx = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.onnx'
-weights_location_pt = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.pt'
-weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
+weights_location_onnx = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.onnx'
+weights_location_pt = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.pt'
+weights_location = Path(__file__).parent.parent.parent / "weights" / f'yolov8{variant}.safetensors'
download_file(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', weights_location_pt)
# the ultralytics export prints a lot of unnecessary things
if not os.path.isfile(weights_location_onnx):
-model = ultralytics.YOLO(model=weights_location_pt, task='Detect')
-model.export(format="onnx",imgsz=[640, 480])
+model = ultralytics.YOLO(model=weights_location_pt, task='Detect')
+model.export(format="onnx",imgsz=[640, 480])
-depth, width, ratio = get_variant_multiples(variant)
-TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
+depth, width, ratio = get_variant_multiples(variant)
+TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80)
state_dict = safe_load(weights_location)
load_state_dict(TinyYolov8, state_dict)
image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg')).read(), np.uint8)]
orig_image = [cv2.imdecode(image_location[0], 1)]
input_image = preprocess(orig_image)
onnx_session = ort.InferenceSession(weights_location_onnx)
onnx_input_name = onnx_session.get_inputs()[0].name
onnx_output_name = onnx_session.get_outputs()[0].name
onnx_output = onnx_session.run([onnx_output_name], {onnx_input_name: input_image.numpy()})
tiny_output = TinyYolov8(input_image)
-# currently rtol is 0.025 because there is a 1-2% difference in our predictions
-# because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch.
-# This difference does not make a difference "visually".
+# currently rtol is 0.025 because there is a 1-2% difference in our predictions
+# because of the zero padding in SPPF module (line 280) maxpooling layers rather than the -infinity in torch.
+# This difference does not make a difference "visually".
np.testing.assert_allclose(onnx_output[0], tiny_output.numpy(), atol=5e-4, rtol=0.025)
if __name__ == '__main__':
unittest.main()
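The rtol comment above refers to boundary handling in max pooling: with zero padding, a border window whose real activations are all negative gets clamped to 0, whereas -infinity padding keeps the true maximum. A toy illustration of that effect (plain NumPy, not code from the repo):

```python
import numpy as np

window = np.array([-0.3, -0.7])        # real activations in a border pooling window
print(max(window.max(), 0.0))          # zero padding      -> 0.0
print(max(window.max(), -np.inf))      # -infinity padding -> -0.3
```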

View File

@@ -1,5 +1,5 @@
import unittest
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d, BatchNorm2d, optim
@@ -33,7 +33,7 @@ class TestBatchnorm(unittest.TestCase):
return self.c2(self.c(x)).relu()
lm = LilModel()
model_step(lm)
def test_two_conv_bn(self):
class LilModel:
def __init__(self):

View File

@@ -2,7 +2,7 @@ import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.nn.optim import Adam
from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
from extra.training import train, evaluate

View File

@@ -6,9 +6,9 @@ from unittest.mock import patch, MagicMock
import torch
import numpy as np
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv
from extra.utils import fetch, temp, download_file
-from tinygrad.state import torch_load
+from tinygrad.nn.state import torch_load
from PIL import Image
@unittest.skipIf(getenv("CI", "") != "", "no internet tests in CI")
@@ -33,7 +33,7 @@ class TestFetchRelative(unittest.TestCase):
os.chdir(self.tempdir.name)
with open('test_file.txt', 'x') as f:
f.write("12345")
def tearDown(self):
os.chdir(self.working_dir)
self.tempdir.cleanup()
@@ -41,7 +41,7 @@ class TestFetchRelative(unittest.TestCase):
#test ./
def test_fetch_relative_dotslash(self):
self.assertEqual(b'12345', fetch("./test_file.txt"))
#test ../
def test_fetch_relative_dotdotslash(self):
os.mkdir('test_file_path')
@@ -92,7 +92,7 @@ class TestUtils(unittest.TestCase):
)
if isfloat16: model = model.half()
-path = temp(f"test_load_{isfloat16}.pt")
+path = temp(f"test_load_{isfloat16}.pt")
torch.save(model.state_dict(), path)
model2 = torch_load(path)

View File

@@ -13,7 +13,7 @@ def get_question_samp(bsz, seq_len, vocab_size, seed):
return in_ids, mask, seg_ids
def set_equal_weights(mdl, torch_mdl):
-from tinygrad.state import get_state_dict
+from tinygrad.nn.state import get_state_dict
state, torch_state = get_state_dict(mdl), torch_mdl.state_dict()
assert len(state) == len(torch_state)
for k, v in state.items():

View File

@@ -2,7 +2,7 @@ import torch
from torch import nn
import unittest
import numpy as np
-from tinygrad.state import get_parameters, get_state_dict
+from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
from tinygrad.tensor import Tensor
from extra.datasets import fetch_mnist

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python
import unittest
import numpy as np
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.tensor import Tensor, Device
from tinygrad.nn import optim, BatchNorm2d
from extra.training import train, evaluate

View File

@@ -1,7 +1,7 @@
import unittest, time
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.lazy import Device

View File

@@ -1,7 +1,7 @@
import unittest
import time
import numpy as np
-from tinygrad.state import get_parameters
+from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad.tensor import Device
from tinygrad.helpers import getenv

View File

@@ -2,7 +2,7 @@ import pathlib
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
-from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
+from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.helpers import Timing