mirror of https://github.com/commaai/tinygrad.git
Support `gunzip` in `fetch` (#6176)
* init * update * clean * add type * clean * fix import order * shorten variable names
This commit is contained in:
parent
705b8066ab
commit
8556d0c642
|
@ -1,4 +1,4 @@
|
||||||
import unittest
|
import gzip, unittest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from tinygrad.helpers import Context, ContextVar
|
from tinygrad.helpers import Context, ContextVar
|
||||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
|
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
|
||||||
|
@ -161,6 +161,24 @@ class TestFetch(unittest.TestCase):
|
||||||
assert pimg.size == (77, 77), pimg.size
|
assert pimg.size == (77, 77), pimg.size
|
||||||
assert img.parent.name == "images"
|
assert img.parent.name == "images"
|
||||||
|
|
||||||
|
def test_fetch_gunzip_valid(self):
|
||||||
|
# compare fetch(gunzip=True) to fetch(gunzip=False) plus decompressing afterwards
|
||||||
|
gzip_url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz'
|
||||||
|
fp_gz = fetch(gzip_url, gunzip=True)
|
||||||
|
fp_no_gz = fetch(gzip_url, gunzip=False)
|
||||||
|
with open(fp_gz, 'rb') as f: content_gz = f.read()
|
||||||
|
with open(fp_no_gz, 'rb') as f: content_no_gz = gzip.decompress(f.read())
|
||||||
|
assert fp_gz.stat().st_size > fp_no_gz.stat().st_size
|
||||||
|
assert isinstance(content_gz, bytes) and isinstance(content_no_gz, bytes)
|
||||||
|
assert len(content_gz) == len(content_no_gz)
|
||||||
|
assert content_gz == content_no_gz
|
||||||
|
|
||||||
|
def test_fetch_gunzip_invalid(self):
|
||||||
|
# given a non-gzipped file, fetch(gunzip=True) fails
|
||||||
|
no_gzip_url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.zip'
|
||||||
|
with self.assertRaises(gzip.BadGzipFile):
|
||||||
|
fetch(no_gzip_url, gunzip=True)
|
||||||
|
|
||||||
class TestFullyFlatten(unittest.TestCase):
|
class TestFullyFlatten(unittest.TestCase):
|
||||||
def test_fully_flatten(self):
|
def test_fully_flatten(self):
|
||||||
self.assertEqual(fully_flatten([[1, 3], [1, 2]]), [1, 3, 1, 2])
|
self.assertEqual(fully_flatten([[1, 3], [1, 2]]), [1, 3, 1, 2])
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys
|
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
|
||||||
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
|
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
|
||||||
|
@ -262,19 +262,22 @@ def diskcache(func):
|
||||||
|
|
||||||
# *** http support ***
|
# *** http support ***
|
||||||
|
|
||||||
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
|
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
|
||||||
allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
|
allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
|
||||||
if url.startswith(("/", ".")): return pathlib.Path(url)
|
if url.startswith(("/", ".")): return pathlib.Path(url)
|
||||||
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
|
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
|
||||||
else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
|
else:
|
||||||
|
fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / \
|
||||||
|
((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
|
||||||
if not fp.is_file() or not allow_caching:
|
if not fp.is_file() or not allow_caching:
|
||||||
with urllib.request.urlopen(url, timeout=10) as r:
|
with urllib.request.urlopen(url, timeout=10) as r:
|
||||||
assert r.status == 200
|
assert r.status == 200
|
||||||
total_length = int(r.headers.get('content-length', 0))
|
total_length = int(r.headers.get('content-length', 0))
|
||||||
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
|
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
|
||||||
(path := fp.parent).mkdir(parents=True, exist_ok=True)
|
(path := fp.parent).mkdir(parents=True, exist_ok=True)
|
||||||
|
readfile = gzip.GzipFile(fileobj=r) if gunzip else r
|
||||||
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
|
||||||
while chunk := r.read(16384): progress_bar.update(f.write(chunk))
|
while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
|
||||||
f.close()
|
f.close()
|
||||||
progress_bar.update(close=True)
|
progress_bar.update(close=True)
|
||||||
if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
|
if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
|
||||||
|
|
Loading…
Reference in New Issue