mirror of https://github.com/commaai/tinygrad.git
Fix naming conflict with huggingface datasets (#1161)
* Rename in files * Move files * Moved to extra/datasets as suggested * Changes to files * Fixed stupid mistake --------- Co-authored-by: terafo <terafo@protonmail.com>
This commit is contained in:
parent
fd66d1ca00
commit
aa60feda48
|
@ -20,14 +20,14 @@ recognize*
|
|||
disassemblers/applegpu
|
||||
disassemblers/cuda_ioctl_sniffer
|
||||
*.prof
|
||||
datasets/cifar-10-python.tar.gz
|
||||
datasets/librispeech/
|
||||
datasets/imagenet/
|
||||
datasets/kits19/
|
||||
datasets/squad/
|
||||
datasets/img_align_celeba*
|
||||
datasets/open-images-v6-mlperf
|
||||
datasets/kits/
|
||||
datasets/COCO/
|
||||
datasets/audio*
|
||||
extra/datasets/cifar-10-python.tar.gz
|
||||
extra/datasets/librispeech/
|
||||
extra/datasets/imagenet/
|
||||
extra/datasets/kits19/
|
||||
extra/datasets/squad/
|
||||
extra/datasets/img_align_celeba*
|
||||
extra/datasets/open-images-v6-mlperf
|
||||
extra/datasets/kits/
|
||||
extra/datasets/COCO/
|
||||
extra/datasets/audio*
|
||||
venv
|
||||
|
|
|
@ -1,51 +0,0 @@
|
|||
# Python version of https://gist.github.com/antoinebrl/7d00d5cb6c95ef194c737392ef7e476a
|
||||
from extra.utils import download_file
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import tarfile, os
|
||||
|
||||
def imagenet_extract(file, path, small=False):
|
||||
with tarfile.open(name=file) as tar:
|
||||
if small: # Show progressbar only for big files
|
||||
for member in tar.getmembers(): tar.extract(path=path, member=member)
|
||||
else:
|
||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
|
||||
tar.close()
|
||||
|
||||
def imagenet_prepare_val():
|
||||
# Read in the labels file
|
||||
with open(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt", 'r') as f:
|
||||
labels = f.read().splitlines()
|
||||
f.close()
|
||||
# Get a list of images
|
||||
images = os.listdir(Path(__file__).parent.parent / "datasets/imagenet/val")
|
||||
images.sort()
|
||||
# Create folders and move files into those
|
||||
for co,dir in enumerate(labels):
|
||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val" / dir, exist_ok=True)
|
||||
os.replace(Path(__file__).parent.parent / "datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "datasets/imagenet/val" / dir / images[co])
|
||||
os.remove(Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
||||
|
||||
def imagenet_prepare_train():
|
||||
images = os.listdir(Path(__file__).parent.parent / "datasets/imagenet/train")
|
||||
for co,tarf in enumerate(images):
|
||||
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
||||
if Path(Path(__file__).parent.parent / "datasets/imagenet/train" / images[co]).is_file():
|
||||
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/train" / images[co], exist_ok=True)
|
||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/train" / tarf, Path(__file__).parent.parent / "datasets/imagenet/train" / images[co], small=True)
|
||||
os.remove(Path(__file__).parent.parent / "datasets/imagenet/train" / tarf)
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/val", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent.parent / "datasets/imagenet/train", exist_ok=True)
|
||||
download_file("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent.parent / "datasets/imagenet/imagenet_class_index.json")
|
||||
download_file("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent.parent / "datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB
|
||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "datasets/imagenet/val")
|
||||
imagenet_prepare_val()
|
||||
if os.getenv('IMGNET_TRAIN', None) is not None:
|
||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar") #138GB!
|
||||
imagenet_extract(Path(__file__).parent.parent / "datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "datasets/imagenet/train")
|
||||
imagenet_prepare_train()
|
|
@ -188,7 +188,7 @@ Variable | Possible Value(s) | Description
|
|||
---|---|---
|
||||
BS | [8, 16, 32, 64, 128] | batch size to use
|
||||
|
||||
### datasets/imagenet_download.py
|
||||
### extra/datasets/imagenet_download.py
|
||||
|
||||
Variable | Possible Value(s) | Description
|
||||
---|---|---
|
||||
|
|
|
@ -169,11 +169,11 @@ There is a simpler way to do this just by using `get_parameters(net)` from `tiny
|
|||
The parameters are just listed out explicitly here for clarity.
|
||||
|
||||
Now that we have our network, loss function, and optimizer defined all we are missing is the data to train on!
|
||||
There are a couple of dataset loaders in tinygrad located in [/datasets](/datasets).
|
||||
There are a couple of dataset loaders in tinygrad located in [/extra/datasets](/extra/datasets).
|
||||
We will be using the MNIST dataset loader.
|
||||
|
||||
```python
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
```
|
||||
|
||||
Now we have everything we need to start training our neural network.
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# https://siboehm.com/articles/22/CUDA-MMM
|
||||
import time
|
||||
import numpy as np
|
||||
from datasets import fetch_cifar
|
||||
from extra.datasets import fetch_cifar
|
||||
from tinygrad import nn
|
||||
from tinygrad.state import get_parameters
|
||||
from tinygrad.nn import optim
|
||||
|
|
|
@ -5,7 +5,7 @@ import platform
|
|||
from torch import nn
|
||||
from torch import optim
|
||||
|
||||
from datasets import fetch_cifar
|
||||
from extra.datasets import fetch_cifar
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
# allow TF32
|
||||
|
|
|
@ -25,7 +25,7 @@ def eval_resnet():
|
|||
mdljit = TinyJit(mdlrun)
|
||||
|
||||
# evaluation on the mlperf classes of the validation set from imagenet
|
||||
from datasets.imagenet import iterate
|
||||
from extra.datasets.imagenet import iterate
|
||||
from extra.helpers import cross_process
|
||||
|
||||
BS = 64
|
||||
|
@ -56,7 +56,7 @@ def eval_resnet():
|
|||
def eval_unet3d():
|
||||
# UNet3D
|
||||
from models.unet3d import UNet3D
|
||||
from datasets.kits19 import iterate, sliding_window_inference
|
||||
from extra.datasets.kits19 import iterate, sliding_window_inference
|
||||
from examples.mlperf.metrics import get_dice_score
|
||||
mdl = UNet3D()
|
||||
mdl.load_from_pretrained()
|
||||
|
@ -86,7 +86,7 @@ def eval_retinanet():
|
|||
x /= input_std
|
||||
return x
|
||||
|
||||
from datasets.openimages import openimages, iterate
|
||||
from extra.datasets.openimages import openimages, iterate
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
from contextlib import redirect_stdout
|
||||
|
@ -135,7 +135,7 @@ def eval_rnnt():
|
|||
mdl = RNNT()
|
||||
mdl.load_from_pretrained()
|
||||
|
||||
from datasets.librispeech import iterate
|
||||
from extra.datasets.librispeech import iterate
|
||||
from examples.mlperf.metrics import word_error_rate
|
||||
|
||||
LABELS = [" ", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "'"]
|
||||
|
@ -168,7 +168,7 @@ def eval_bert():
|
|||
def run(input_ids, input_mask, segment_ids):
|
||||
return mdl(input_ids, input_mask, segment_ids).realize()
|
||||
|
||||
from datasets.squad import iterate
|
||||
from extra.datasets.squad import iterate
|
||||
from examples.mlperf.helpers import get_bert_qa_prediction
|
||||
from examples.mlperf.metrics import f1_score
|
||||
from transformers import BertTokenizer
|
||||
|
@ -198,7 +198,7 @@ def eval_mrcnn():
|
|||
from tqdm import tqdm
|
||||
from models.mask_rcnn import MaskRCNN
|
||||
from models.resnet import ResNet
|
||||
from datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
|
||||
from extra.datasets.coco import BASEDIR, images, convert_prediction_to_coco_bbox, convert_prediction_to_coco_mask, accumulate_predictions_for_coco, evaluate_predictions_on_coco, iterate
|
||||
from examples.mask_rcnn import compute_prediction_batched, Image
|
||||
mdl = MaskRCNN(ResNet(50, num_classes=None, stride_in_1x1=True))
|
||||
mdl.load_from_pretrained()
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.state import get_parameters
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.nn import optim
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
class LinearGen:
|
||||
def __init__(self):
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad.state import get_parameters
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import BatchNorm2d, optim
|
||||
from tinygrad.helpers import getenv
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
from extra.augment import augment_img
|
||||
from extra.training import train, evaluate, sparse_categorical_crossentropy
|
||||
GPU = getenv("GPU")
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.state import get_parameters
|
|||
from tinygrad.nn import optim
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.tensor import Tensor
|
||||
from datasets import fetch_cifar
|
||||
from extra.datasets import fetch_cifar
|
||||
from models.efficientnet import EfficientNet
|
||||
|
||||
class TinyConvNet:
|
||||
|
@ -46,7 +46,7 @@ if __name__ == "__main__":
|
|||
print(f"training with batch size {BS} for {steps} steps")
|
||||
|
||||
if IMAGENET:
|
||||
from datasets.imagenet import fetch_batch
|
||||
from extra.datasets.imagenet import fetch_batch
|
||||
def loader(q):
|
||||
while 1:
|
||||
try:
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.nn import optim
|
|||
from tinygrad.helpers import getenv
|
||||
from extra.training import train, evaluate
|
||||
from models.resnet import ResNet
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
|
||||
class ComposeTransforms:
|
||||
|
|
|
@ -4,7 +4,7 @@ import os
|
|||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
sys.path.append(os.path.join(os.getcwd(), 'test'))
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
from tqdm import trange
|
||||
|
||||
def augment_img(X, rotate=10, px=3):
|
||||
|
|
|
@ -12,7 +12,7 @@ iou = _mask.iou
|
|||
merge = _mask.merge
|
||||
frPyObjects = _mask.frPyObjects
|
||||
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "datasets" / "COCO"
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "extra" / "datasets" / "COCO"
|
||||
BASEDIR.mkdir(exist_ok=True)
|
||||
|
||||
def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for row in rows}
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
from PIL import Image
|
||||
import functools, pathlib
|
||||
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/imagenet"
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "extra/datasets/imagenet"
|
||||
ci = json.load(open(BASEDIR / "imagenet_class_index.json"))
|
||||
cir = {v[0]: int(k) for k,v in ci.items()}
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
# Python version of https://gist.github.com/antoinebrl/7d00d5cb6c95ef194c737392ef7e476a
|
||||
from extra.utils import download_file
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import tarfile, os
|
||||
|
||||
def imagenet_extract(file, path, small=False):
|
||||
with tarfile.open(name=file) as tar:
|
||||
if small: # Show progressbar only for big files
|
||||
for member in tar.getmembers(): tar.extract(path=path, member=member)
|
||||
else:
|
||||
for member in tqdm(iterable=tar.getmembers(), total=len(tar.getmembers())): tar.extract(path=path, member=member)
|
||||
tar.close()
|
||||
|
||||
def imagenet_prepare_val():
|
||||
# Read in the labels file
|
||||
with open(Path(__file__).parent.parent / "extra/datasets/imagenet/imagenet_2012_validation_synset_labels.txt", 'r') as f:
|
||||
labels = f.read().splitlines()
|
||||
f.close()
|
||||
# Get a list of images
|
||||
images = os.listdir(Path(__file__).parent.parent / "extra/datasets/imagenet/val")
|
||||
images.sort()
|
||||
# Create folders and move files into those
|
||||
for co,dir in enumerate(labels):
|
||||
os.makedirs(Path(__file__).parent.parent / "extra/datasets/imagenet/val" / dir, exist_ok=True)
|
||||
os.replace(Path(__file__).parent.parent / "extra/datasets/imagenet/val" / images[co], Path(__file__).parent.parent / "extra/datasets/imagenet/val" / dir / images[co])
|
||||
os.remove(Path(__file__).parent.parent / "extra/datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
||||
|
||||
def imagenet_prepare_train():
|
||||
images = os.listdir(Path(__file__).parent.parent / "extra/datasets/imagenet/train")
|
||||
for co,tarf in enumerate(images):
|
||||
# for each tar file found. Create a folder with its name. Extract into that folder. Remove tar file
|
||||
if Path(Path(__file__).parent.parent / "extra/datasets/imagenet/train" / images[co]).is_file():
|
||||
images[co] = tarf[:-4] # remove .tar from extracted tar files
|
||||
os.makedirs(Path(__file__).parent.parent / "extra/datasets/imagenet/train" / images[co], exist_ok=True)
|
||||
imagenet_extract(Path(__file__).parent.parent / "extra/datasets/imagenet/train" / tarf, Path(__file__).parent.parent / "extra/datasets/imagenet/train" / images[co], small=True)
|
||||
os.remove(Path(__file__).parent.parent / "extra/datasets/imagenet/train" / tarf)
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.makedirs(Path(__file__).parent.parent / "extra/datasets/imagenet", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent.parent / "extra/datasets/imagenet/val", exist_ok=True)
|
||||
os.makedirs(Path(__file__).parent.parent / "extra/datasets/imagenet/train", exist_ok=True)
|
||||
download_file("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent.parent / "extra/datasets/imagenet/imagenet_class_index.json")
|
||||
download_file("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent.parent / "extra/datasets/imagenet/imagenet_2012_validation_synset_labels.txt")
|
||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "extra/datasets/imagenet/ILSVRC2012_img_val.tar") # 7GB
|
||||
imagenet_extract(Path(__file__).parent.parent / "extra/datasets/imagenet/ILSVRC2012_img_val.tar", Path(__file__).parent.parent / "extra/datasets/imagenet/val")
|
||||
imagenet_prepare_val()
|
||||
if os.getenv('IMGNET_TRAIN', None) is not None:
|
||||
download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "extra/datasets/imagenet/ILSVRC2012_img_train.tar") #138GB!
|
||||
imagenet_extract(Path(__file__).parent.parent / "extra/datasets/imagenet/ILSVRC2012_img_train.tar", Path(__file__).parent.parent / "extra/datasets/imagenet/train")
|
||||
imagenet_prepare_train()
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.nn.functional as F
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
BASEDIR = Path(__file__).parent.parent.resolve() / "datasets" / "kits19" / "data"
|
||||
BASEDIR = Path(__file__).parent.parent.resolve() / "extra" / "datasets" / "kits19" / "data"
|
||||
|
||||
"""
|
||||
To download the dataset:
|
||||
|
@ -19,7 +19,7 @@ cd kits19
|
|||
pip3 install -r requirements.txt
|
||||
python3 -m starter_code.get_imaging
|
||||
cd ..
|
||||
mv kits datasets
|
||||
mv kits extra/datasets
|
||||
```
|
||||
"""
|
||||
|
|
@ -5,7 +5,7 @@ import librosa
|
|||
import soundfile
|
||||
|
||||
"""
|
||||
The dataset has to be downloaded manually from https://www.openslr.org/12/ and put in `datasets/librispeech`.
|
||||
The dataset has to be downloaded manually from https://www.openslr.org/12/ and put in `extra/datasets/librispeech`.
|
||||
For mlperf validation the dev-clean dataset is used.
|
||||
|
||||
Then all the flacs have to be converted to wav using something like:
|
||||
|
@ -13,9 +13,9 @@ Then all the flacs have to be converted to wav using something like:
|
|||
for file in $(find * | grep flac); do ffmpeg -i $file -ar 16k "$(dirname $file)/$(basename $file .flac).wav"; done
|
||||
```
|
||||
|
||||
Then this [file](https://github.com/mlcommons/inference/blob/master/speech_recognition/rnnt/dev-clean-wav.json) has to also be put in `datasets/librispeech`.
|
||||
Then this [file](https://github.com/mlcommons/inference/blob/master/speech_recognition/rnnt/dev-clean-wav.json) has to also be put in `extra/datasets/librispeech`.
|
||||
"""
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/librispeech"
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "extra/datasets/librispeech"
|
||||
with open(BASEDIR / "dev-clean-wav.json") as f:
|
||||
ci = json.load(f)
|
||||
|
|
@ -11,7 +11,7 @@ from tqdm import tqdm
|
|||
import pandas as pd
|
||||
import concurrent.futures
|
||||
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "datasets/open-images-v6-mlperf"
|
||||
BASEDIR = pathlib.Path(__file__).parent.parent / "extra/datasets/open-images-v6-mlperf"
|
||||
BUCKET_NAME = "open-images-dataset"
|
||||
BBOX_ANNOTATIONS_URL = "https://storage.googleapis.com/openimages/v5/validation-annotations-bbox.csv"
|
||||
MAP_CLASSES_URL = "https://storage.googleapis.com/openimages/v5/class-descriptions-boxable.csv"
|
|
@ -1,6 +1,6 @@
|
|||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.tensor import Tensor
|
||||
from datasets.imagenet import iterate, get_val_files
|
||||
from extra.datasets.imagenet import iterate, get_val_files
|
||||
|
||||
if __name__ == "__main__":
|
||||
#sz = len(get_val_files())
|
|
@ -5,7 +5,7 @@ from transformers import BertTokenizer
|
|||
import numpy as np
|
||||
from extra.utils import download_file
|
||||
|
||||
BASEDIR = Path(__file__).parent.parent / "datasets/squad"
|
||||
BASEDIR = Path(__file__).parent.parent / "extra/datasets/squad"
|
||||
def init_dataset():
|
||||
os.makedirs(BASEDIR, exist_ok=True)
|
||||
download_file("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json")
|
|
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python3
|
||||
from examples.hlb_cifar10 import SpeedyResNet, fetch_batch
|
||||
from examples.hlb_cifar10_torch import SpeedyResNet as SpeedyResNetTorch
|
||||
from datasets import fetch_cifar
|
||||
from extra.datasets import fetch_cifar
|
||||
from test.models.test_end2end import compare_tiny_torch
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -6,7 +6,7 @@ from tinygrad.state import get_parameters
|
|||
from tinygrad.nn.optim import Adam
|
||||
from extra.lr_scheduler import MultiStepLR, ReduceLROnPlateau, CosineAnnealingLR, OneCycleLR
|
||||
from extra.training import train, evaluate
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
np.random.seed(1337)
|
||||
Tensor.manual_seed(1337)
|
||||
|
|
|
@ -5,7 +5,7 @@ import numpy as np
|
|||
from tinygrad.state import get_parameters, get_state_dict
|
||||
from tinygrad.nn import optim, Linear, Conv2d, BatchNorm2d
|
||||
from tinygrad.tensor import Tensor
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
def compare_tiny_torch(model, model_torch, X, Y):
|
||||
Tensor.training = True
|
||||
|
|
|
@ -5,7 +5,7 @@ from tinygrad.state import get_parameters
|
|||
from tinygrad.tensor import Tensor, Device
|
||||
from tinygrad.nn import optim, BatchNorm2d
|
||||
from extra.training import train, evaluate
|
||||
from datasets import fetch_mnist
|
||||
from extra.datasets import fetch_mnist
|
||||
|
||||
# load the mnist dataset
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
|
|
Loading…
Reference in New Issue