Support `gunzip` in `fetch` (#6176)

* init

* update

* clean

* add type

* clean

* fix import order

* shorten variable names
Eitan Turok 2024-08-19 15:04:40 -04:00 committed by GitHub
parent 705b8066ab
commit 8556d0c642
2 changed files with 26 additions and 5 deletions

test/unit/test_helpers.py

@@ -1,4 +1,4 @@
-import unittest
+import gzip, unittest
 from PIL import Image
 from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, get_contraction, get_shape
@@ -161,6 +161,24 @@ class TestFetch(unittest.TestCase):
     assert pimg.size == (77, 77), pimg.size
     assert img.parent.name == "images"
 
+  def test_fetch_gunzip_valid(self):
+    # compare fetch(gunzip=True) to fetch(gunzip=False) plus decompressing afterwards
+    gzip_url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz'
+    fp_gz = fetch(gzip_url, gunzip=True)
+    fp_no_gz = fetch(gzip_url, gunzip=False)
+    with open(fp_gz, 'rb') as f: content_gz = f.read()
+    with open(fp_no_gz, 'rb') as f: content_no_gz = gzip.decompress(f.read())
+    assert fp_gz.stat().st_size > fp_no_gz.stat().st_size
+    assert isinstance(content_gz, bytes) and isinstance(content_no_gz, bytes)
+    assert len(content_gz) == len(content_no_gz)
+    assert content_gz == content_no_gz
+
+  def test_fetch_gunzip_invalid(self):
+    # given a non-gzipped file, fetch(gunzip=True) fails
+    no_gzip_url: str = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.zip'
+    with self.assertRaises(gzip.BadGzipFile):
+      fetch(no_gzip_url, gunzip=True)
+
 class TestFullyFlatten(unittest.TestCase):
   def test_fully_flatten(self):
     self.assertEqual(fully_flatten([[1, 3], [1, 2]]), [1, 3, 1, 2])

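For reference, the round-trip property these tests pin down can be reproduced with the standard library alone. A minimal sketch, assuming network access and the same GNU mirror URL the tests use:

import gzip, io, urllib.request

url = 'https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz'
with urllib.request.urlopen(url, timeout=10) as r:
  raw = r.read()             # compressed bytes: what fetch(gunzip=False) caches
tar = gzip.decompress(raw)   # decompressed bytes: what fetch(gunzip=True) caches
# streaming decompression (the approach helpers.py takes below) yields identical bytes
assert gzip.GzipFile(fileobj=io.BytesIO(raw)).read() == tar
assert len(tar) > len(raw)   # mirrors the st_size comparison in test_fetch_gunzip_valid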
tinygrad/helpers.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip
 import itertools, urllib.request, subprocess, shutil, math, json, contextvars
 from dataclasses import dataclass
 from typing import Dict, Tuple, Union, List, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable, Sequence
@@ -262,19 +262,22 @@ def diskcache(func):
 # *** http support ***
 
-def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None,
+def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, subdir:Optional[str]=None, gunzip:bool=False,
           allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
   if url.startswith(("/", ".")): return pathlib.Path(url)
   if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
-  else: fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / (name or hashlib.md5(url.encode('utf-8')).hexdigest())
+  else:
+    fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (subdir or "") / \
+      ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
   if not fp.is_file() or not allow_caching:
     with urllib.request.urlopen(url, timeout=10) as r:
       assert r.status == 200
       total_length = int(r.headers.get('content-length', 0))
       progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
       (path := fp.parent).mkdir(parents=True, exist_ok=True)
+      readfile = gzip.GzipFile(fileobj=r) if gunzip else r
       with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
-        while chunk := r.read(16384): progress_bar.update(f.write(chunk))
+        while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
         f.close()
         progress_bar.update(close=True)
         if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
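
With this change, fetch inflates while it downloads: when gunzip=True the HTTP response is wrapped in gzip.GzipFile(fileobj=r), so chunks are decompressed as they stream into the temp file instead of buffering the whole download, and the cache filename gets a ".gunzip" suffix so compressed and decompressed copies of the same URL never collide. A usage sketch (the exact cache location is platform-dependent and illustrative here):

from tinygrad.helpers import fetch

# decompressed on the fly; cached as <cache_dir>/tinygrad/downloads/<md5>.gunzip
tar_path = fetch('https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz', gunzip=True)
# default behavior is unchanged: the raw .tar.gz bytes, cached under <md5> alone
gz_path = fetch('https://ftp.gnu.org/gnu/gzip/gzip-1.13.tar.gz')
assert tar_path != gz_path and tar_path.stat().st_size > gz_path.stat().st_size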