Support `gunzip` in `fetch` (#6176)

* init

* update

* clean

* add type

* clean

* fix import order

* shorten variable names
This commit is contained in:
Eitan Turok 2024-08-19 15:04:40 -04:00 committed by GitHub
parent 705b8066ab
commit 8556d0c642
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 5 deletions

View File

@@ -1,4 +1,4 @@
import unittest
import gzip, unittest
from PIL import Image
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
@@ -161,6 +161,24 @@ class TestFetch(unittest.TestCase):
assert pimg.size == (77, 77), pimg.size
assert img.parent.name == "images"
def test_fetch_gunzip_valid(self):
  # fetching with gunzip=True must produce exactly the bytes obtained by
  # fetching the raw .gz file and decompressing it manually afterwards
  url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz'
  path_unzipped = fetch(url, gunzip=True)
  path_raw = fetch(url, gunzip=False)
  with open(path_unzipped, 'rb') as f: data_unzipped = f.read()
  with open(path_raw, 'rb') as f: data_manual = gzip.decompress(f.read())
  # the decompressed download occupies more space on disk than the compressed one
  assert path_unzipped.stat().st_size > path_raw.stat().st_size
  assert isinstance(data_unzipped, bytes) and isinstance(data_manual, bytes)
  assert len(data_unzipped) == len(data_manual)
  assert data_unzipped == data_manual
def test_fetch_gunzip_invalid(self):
  # asking fetch to gunzip a file that is not gzip-compressed must raise
  url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.zip'
  with self.assertRaises(gzip.BadGzipFile): fetch(url, gunzip=True)
class TestFullyFlatten(unittest.TestCase):
def test_fully_flatten(self):
  # a nested list flattens in row-major order
  expected = [1, 3, 1, 2]
  self.assertEqual(fully_flatten([[1, 3], [1, 2]]), expected)

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
import itertools, urllib.request, subprocess, shutil, math, json, contextvars
from dataclasses import dataclass
from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
@@ -262,19 +262,22 @@ def diskcache(func):
# *** http support ***
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
if url.startswith(("/", ".")): return pathlib.Path(url)
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
else:
fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / \
((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
if not fp.is_file() or not allow_caching:
with urllib.request.urlopen(url, timeout=10) as r:
assert r.status == 200
total_length = int(r.headers.get('content-length', 0))
progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
(path := fp.parent).mkdir(parents=True, exist_ok=True)
readfile = gzip.GzipFile(fileobj=r) if gunzip else r
with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
while chunk := r.read(16384): progress_bar.update(f.write(chunk))
while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
f.close()
progress_bar.update(close=True)
if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")