The Wayback Machine - https://web.archive.org/web/20220212130938/https://github.com/pytorch/text/commit/325e235fccfefe68cb08637a55b100131f2e0394
Skip to content
Permalink
Browse files
removing dependence on iopath (#1381)
  • Loading branch information
parmeet committed Aug 24, 2021
1 parent e3799a6 commit 325e235fccfefe68cb08637a55b100131f2e0394
@@ -8,7 +8,6 @@ dependencies:
- dataclasses
- nltk
- requests
- iopath
- revtok
- pytest
- pytest-cov
@@ -3,13 +3,11 @@ channels:
dependencies:
- flake8>=3.7.9
- codecov
- pywin32
- pip
- pip:
- dataclasses
- nltk
- requests
- iopath
- revtok
- pytest
- pytest-cov
@@ -1,3 +1,2 @@
sphinx==2.4.4
iopath
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
@@ -180,7 +180,6 @@ setup_pip_pytorch_version() {
# You MUST have populated PYTORCH_VERSION_SUFFIX beforehand.
setup_conda_pytorch_constraint() {
CONDA_CHANNEL_FLAGS=${CONDA_CHANNEL_FLAGS:-}
CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c iopath"
if [[ -z "$PYTORCH_VERSION" ]]; then
export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly"
export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | python -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")"
@@ -20,7 +20,6 @@ requirements:
run:
- python
- requests
- iopath
- tqdm
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}

@@ -3,7 +3,6 @@ tqdm

# Downloading data and other files
requests
iopath

# Optional NLP tools
nltk
@@ -1,18 +1,6 @@
from typing import List, Optional, Union, IO, Dict, Any
import requests
import os
import logging
import uuid
import re
import shutil
from tqdm import tqdm
from iopath.common.file_io import (
PathHandler,
PathManager,
get_cache_dir,
file_lock,
HTTPURLHandler,
)


def _stream_response(r, chunk_size=16 * 1024):
@@ -54,118 +42,16 @@ def _get_response_from_google_drive(url):
return response, filename


class GoogleDrivePathHandler(PathHandler):
    """
    Path handler for Google Drive links: the remote file is fetched into a
    local cache directory and later requests are served from that copy.
    """

    MAX_FILENAME_LEN = 250

    def __init__(self) -> None:
        # Maps each requested URL to the local file holding its contents.
        self.cache_map: Dict[str, str] = {}

    def _get_supported_prefixes(self) -> List[str]:
        return ["https://drive.google.com"]

    def _get_local_path(
        self,
        path: str,
        force: bool = False,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        Return a local file path holding the contents of the Google Drive URL.

        The remote resource is requested when ``force`` is set, when the URL
        has no entry in ``cache_map``, or when the recorded file no longer
        exists on disk. A file already present at the target location is
        reused rather than overwritten.

        Args:
            path (str): A URI supported by this PathHandler
            force (bool): Skip the in-memory cache lookup.
            cache_dir (Optional[str]): Passed through to ``get_cache_dir`` to
                pick the directory the file is stored in.
        Returns:
            str: path of the locally cached file.
        """
        self._check_kwargs(kwargs)

        already_cached = (
            not force
            and path in self.cache_map
            and os.path.exists(self.cache_map[path])
        )
        if already_cached:
            return self.cache_map[path]

        logger = logging.getLogger(__name__)
        cache_root = get_cache_dir(cache_dir)

        response, filename = _get_response_from_google_drive(path)
        if len(filename) > self.MAX_FILENAME_LEN:
            # Over-long names are cut to 100 characters plus a random suffix.
            filename = "_".join((filename[:100], uuid.uuid4().hex))

        target = os.path.join(cache_root, filename)
        # The lock is held while checking for and writing the target file.
        with file_lock(target):
            if not os.path.isfile(target):
                logger.info("Downloading {} ...".format(path))
                with open(target, 'wb') as sink:
                    for block in _stream_response(response):
                        sink.write(block)
        logger.info("URL {} cached in {}".format(path, target))
        self.cache_map[path] = target
        return self.cache_map[path]

    def _open(
        self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a Google Drive URL for reading via its locally cached copy.

        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Either 'r' (the default) or 'rb'; any other mode is
                rejected.
            buffering (int): Must be left at -1; not supported by this
                PathHandler.
        Returns:
            file: a file-like object opened on the cached local file.
        """
        self._check_kwargs(kwargs)
        assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
            self.__class__.__name__, mode
        )
        assert (
            buffering == -1
        ), f"{self.__class__.__name__} does not support the `buffering` argument"
        # force=False: reuse the cached copy when one is already recorded.
        return open(self._get_local_path(path, force=False), mode)


# NOTE(review): this region is a rendered diff hunk whose +/- markers and
# indentation were lost, so lines from the old and new versions of the file
# are interleaved. The "old side" / "new side" labels below are inferred from
# the hunk header line counts (-54,118 +42,16) and the commit title - confirm
# against the actual commit before relying on them.
#
# Appears to be the old side: a PathHandler that wraps a PathManager on which
# the HTTP and Google Drive handlers are registered.
class CombinedInternalPathhandler(PathHandler):
def __init__(self):
path_manager = PathManager()
path_manager.register_handler(HTTPURLHandler())
path_manager.register_handler(GoogleDrivePathHandler())
self.path_manager = path_manager

def _get_supported_prefixes(self) -> List[str]:
return ["https://", "http://"]

# Reads the target location from kwargs["destination"], resolves the URL
# through the wrapped PathManager, moves the resulting file to that location
# and returns it. cache_dir is accepted but not used in the visible body.
def _get_local_path(
self,
path: str,
force: bool = False,
cache_dir: Optional[str] = None,
**kwargs: Any,
) -> str:

destination = kwargs["destination"]

local_path = self.path_manager.get_local_path(path, force)

shutil.move(local_path, destination)

return destination
# Appears to be the new side: a plain class with no iopath base. For URLs that
# do not contain 'drive.google.com' it issues a streaming requests.get with a
# browser-like User-Agent header; otherwise it goes through
# _get_response_from_google_drive. The body seems to continue at the
# `with open(destination, 'wb')` lines further down.
class DownloadManager:
def get_local_path(self, url, destination):
if 'drive.google.com' not in url:
response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
else:
response, filename = _get_response_from_google_drive(url)

# Appears to be the old side again: a PathHandler-style _open that accepts
# only "r"/"rb" and the default buffering, then opens the local copy
# (presumably a method of CombinedInternalPathhandler - confirm).
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
self._check_kwargs(kwargs)
assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
self.__class__.__name__, mode
)
assert (
buffering == -1
), f"{self.__class__.__name__} does not support the `buffering` argument"
local_path = self._get_local_path(path, force=False)
return open(local_path, mode)
# Appears to be the new side: writes each chunk yielded by _stream_response
# for `response` into the file at `destination`.
with open(destination, 'wb') as f:
for chunk in _stream_response(response):
f.write(chunk)


# The first two statements build the module-level manager from PathManager
# (old side, inferred); the last one binds a DownloadManager instance to the
# same name (new side, inferred).
_DATASET_DOWNLOAD_MANAGER = PathManager()
_DATASET_DOWNLOAD_MANAGER.register_handler(CombinedInternalPathhandler())
_DATASET_DOWNLOAD_MANAGER = DownloadManager()

0 comments on commit 325e235

Please sign in to comment.