diff --git a/streaming/base/format/base/writer.py b/streaming/base/format/base/writer.py index 7cc3add3d..f64f1b65a 100644 --- a/streaming/base/format/base/writer.py +++ b/streaming/base/format/base/writer.py @@ -64,6 +64,9 @@ class Writer(ABC): file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. retry (int): Number of times to retry uploading a file to a remote location. Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ format: str = '' # Name of the format (like "mds", "csv", "json", etc). @@ -100,7 +103,8 @@ def __init__(self, # Validate keyword arguments invalid_kwargs = [ - arg for arg in kwargs.keys() if arg not in ('progress_bar', 'max_workers', 'retry') + arg for arg in kwargs.keys() + if arg not in ('progress_bar', 'max_workers', 'retry', 'exist_ok') ] if invalid_kwargs: raise ValueError(f'Invalid Writer argument(s): {invalid_kwargs} ') @@ -120,6 +124,18 @@ def __init__(self, kwargs.get('retry', 2)) self.local = self.cloud_writer.local self.remote = self.cloud_writer.remote + + if os.path.exists(self.local) and len(os.listdir(self.local)) != 0: + if kwargs.get('exist_ok', False): + logger.warning(f'Directory {self.local} exists and not empty since you provided ' + + f'`exist_ok=True`.') + else: + raise FileExistsError(f'Directory is not empty: {self.local}. If you still want ' + + f'to use this directory without emptying the content, ' + + f'please provide `exist_ok=True`.') + # Create the local directory if it does not exist. + os.makedirs(self.local, exist_ok=True) + # `max_workers`: The maximum number of threads that can be executed in parallel. # One thread is responsible for uploading one shard file to a remote location. 
self.executor = ThreadPoolExecutor(max_workers=kwargs.get('max_workers', None)) @@ -380,6 +396,9 @@ class JointWriter(Writer): file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. retry (int): Number of times to retry uploading a file to a remote location. Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ def __init__(self, @@ -466,6 +485,9 @@ class SplitWriter(Writer): file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. retry (int): Number of times to retry uploading a file to a remote location. Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ extra_bytes_per_shard = 0 diff --git a/streaming/base/format/json/writer.py b/streaming/base/format/json/writer.py index aae9d1d28..d7a0ac6e3 100644 --- a/streaming/base/format/json/writer.py +++ b/streaming/base/format/json/writer.py @@ -45,6 +45,11 @@ class JSONWriter(SplitWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + retry (int): Number of times to retry uploading a file to a remote location. + Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. 
""" format = 'json' diff --git a/streaming/base/format/mds/writer.py b/streaming/base/format/mds/writer.py index e82fc02a8..ee2c0f78f 100644 --- a/streaming/base/format/mds/writer.py +++ b/streaming/base/format/mds/writer.py @@ -45,6 +45,11 @@ class MDSWriter(JointWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + retry (int): Number of times to retry uploading a file to a remote location. + Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ format = 'mds' diff --git a/streaming/base/format/xsv/writer.py b/streaming/base/format/xsv/writer.py index 2888597b2..b5cce2721 100644 --- a/streaming/base/format/xsv/writer.py +++ b/streaming/base/format/xsv/writer.py @@ -46,6 +46,11 @@ class XSVWriter(SplitWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + retry (int): Number of times to retry uploading a file to a remote location. + Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ format = 'xsv' @@ -164,6 +169,11 @@ class CSVWriter(XSVWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. 
+ retry (int): Number of times to retry uploading a file to a remote location. + Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ format = 'csv' @@ -230,6 +240,11 @@ class TSVWriter(XSVWriter): max_workers (int): Maximum number of threads used to upload output dataset files in parallel to a remote location. One thread is responsible for uploading one shard file to a remote location. Default to ``min(32, (os.cpu_count() or 1) + 4)``. + retry (int): Number of times to retry uploading a file to a remote location. + Default to ``2``. + exist_ok (bool): If the local directory exists and is not empty, whether to proceed + anyway or raise an error. `False` raises an error. `True` logs a warning and + continues without emptying the directory. Defaults to `False`. """ format = 'tsv' diff --git a/streaming/base/storage/upload.py b/streaming/base/storage/upload.py index dacb5747a..8457d5ac5 100644 --- a/streaming/base/storage/upload.py +++ b/streaming/base/storage/upload.py @@ -57,8 +57,7 @@ def get(cls, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> Any: """Instantiate a cloud provider uploader or a local uploader based on remote path. Args: @@ -75,8 +74,6 @@ def get(cls, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. Returns: CloudUploader: An instance of sub-class. 
@@ -89,8 +86,8 @@ def get(cls, prefix = os.path.join(path.parts[0], path.parts[1]) if prefix == 'dbfs:/Volumes': provider_prefix = prefix - return getattr(sys.modules[__name__], - UPLOADERS[provider_prefix])(out, keep_local, progress_bar, retry, exist_ok) + return getattr(sys.modules[__name__], UPLOADERS[provider_prefix])(out, keep_local, + progress_bar, retry) def _validate(self, out: Union[str, Tuple[str, str]]) -> None: """Validate the `out` argument. @@ -124,8 +121,7 @@ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: + retry: int = 2) -> None: """Initialize and validate local and remote path. Args: @@ -142,8 +138,6 @@ def __init__(self, progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. Raises: FileExistsError: Local directory must be empty. @@ -166,16 +160,6 @@ def __init__(self, self.local = out[0] self.remote = out[1] - if os.path.exists(self.local) and len(os.listdir(self.local)) != 0: - if not exist_ok: - raise FileExistsError(f'Directory is not empty: {self.local}') - else: - logger.warning( - f'Directory {self.local} exists and not empty. But continue to mkdir since exist_ok is set to be True.' - ) - - os.makedirs(self.local, exist_ok=True) - def upload_file(self, filename: str): """Upload file from local instance to remote instance. @@ -225,17 +209,14 @@ class S3Uploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. 
- exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) import boto3 from botocore.config import Config @@ -346,17 +327,14 @@ class GCSUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) if 'GCS_KEY' in os.environ and 'GCS_SECRET' in os.environ: import boto3 @@ -494,17 +472,14 @@ class OCIUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. 
""" def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) import oci @@ -631,17 +606,14 @@ class AzureUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) from azure.storage.blob import BlobServiceClient @@ -719,17 +691,14 @@ class AzureDataLakeUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. 
""" def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) from azure.storage.filedatalake import DataLakeServiceClient @@ -804,17 +773,14 @@ class DatabricksUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) self.client = self._create_workspace_client() def _create_workspace_client(self): @@ -843,17 +809,14 @@ class DatabricksUnityCatalogUploader(DatabricksUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. 
""" def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) def upload_file(self, filename: str): """Upload file from local instance to Databricks Unity Catalog. @@ -892,17 +855,14 @@ class DBFSUploader(DatabricksUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. """ def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) self.dbfs_path = self.remote.lstrip('dbfs:') # pyright: ignore self.check_folder_exists() @@ -962,17 +922,14 @@ class LocalUploader(CloudUploader): progress_bar (bool): Display TQDM progress bars for uploading output dataset files to a remote location. Default to ``False``. retry (int): Number of times to retry uploading a file. Defaults to ``2``. - exist_ok (bool): When exist_ok = False, raise error if the local part of ``out`` already - exists and has contents. Defaults to ``False``. 
""" def __init__(self, out: Union[str, Tuple[str, str]], keep_local: bool = False, progress_bar: bool = False, - retry: int = 2, - exist_ok: bool = False) -> None: - super().__init__(out, keep_local, progress_bar, retry, exist_ok) + retry: int = 2) -> None: + super().__init__(out, keep_local, progress_bar, retry) # Create remote directory if it doesn't exist if self.remote: os.makedirs(self.remote, exist_ok=True) diff --git a/streaming/base/util.py b/streaming/base/util.py index e86876ee1..d92dd61e4 100644 --- a/streaming/base/util.py +++ b/streaming/base/util.py @@ -284,7 +284,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], # This is the index json file name, e.g., it is index.json as of 0.6.0 index_basename = get_index_basename() - cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + cu = CloudUploader.get(out, keep_local=True) # Remove duplicates, and strip '/' from right if any index_file_urls = list(OrderedDict.fromkeys(index_file_urls)) @@ -297,7 +297,7 @@ def _merge_index_from_list(index_file_urls: List[Union[str, Tuple[str, str]]], # Prepare a temp folder to download index.json from remote if necessary. Removed in the end. with tempfile.TemporaryDirectory() as temp_root: - logging.warning(f'A temporary folder {temp_root} is created to store index files') + logging.debug(f'A temporary folder {temp_root} is created to store index files') # Copy files to a temporary directory. 
Download if necessary partitions = [] @@ -394,10 +394,10 @@ def not_merged_index(index_file_path: str, out: str): logger.warning('No MDS dataset folder specified, no index merged') return - cu = CloudUploader.get(out, exist_ok=True, keep_local=True) + cu = CloudUploader.get(out, keep_local=True) local_index_files = [] - cl = CloudUploader.get(cu.local, exist_ok=True, keep_local=True) + cl = CloudUploader.get(cu.local, keep_local=True) for file in cl.list_objects(): if file.endswith('.json') and not_merged_index(file, cu.local): local_index_files.append(file) diff --git a/tests/test_upload.py b/tests/test_upload.py index f7406cf1d..a50489b3e 100644 --- a/tests/test_upload.py +++ b/tests/test_upload.py @@ -80,23 +80,9 @@ def test_invalid_out_parameter_type(self, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = CloudUploader.get(out=out) - def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty file at specified location - with open(local_file_path, 'w') as _: - pass - _ = CloudUploader.get(out=local) - - def test_local_directory_is_created(self, local_remote_dir: Tuple[str, str]): - local, _ = local_remote_dir - _ = CloudUploader(out=local) - assert os.path.exists(local) - def test_delete_local_file(self, local_remote_dir: Tuple[str, str]): local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) local_file_path = os.path.join(local, 'file.txt') cw = CloudUploader.get(out=local) # Creating an empty file at specified location @@ -117,7 +103,7 @@ def test_check_bucket_exists_exception(self, out: str): def test_list_objects_from_local_gets_called(self, mocked_requests: Mock, remote_local_dir: Any): mock_remote_dir, _ = remote_local_dir() - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, 
keep_local=True) + cu = CloudUploader.get(mock_remote_dir, keep_local=True) cu.list_objects() mocked_requests.assert_called_once() @@ -142,21 +128,12 @@ def test_invalid_remote_list(self, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = S3Uploader(out=out) - def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty file at specified location - with open(local_file_path, 'w') as _: - pass - _ = S3Uploader(out=local) - @pytest.mark.usefixtures('s3_client', 's3_test') def test_upload_file(self, local_remote_dir: Tuple[str, str]): with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: filename = tmp.name.split(os.sep)[-1] local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) remote = 's3://streaming-test-bucket/path' local_file_path = os.path.join(local, filename) s3w = S3Uploader(out=(local, remote)) @@ -170,6 +147,7 @@ def test_upload_file_to_r2(self, local_remote_dir: Tuple[str, str]): with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: filename = tmp.name.split(os.sep)[-1] local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) remote = 's3://streaming-test-bucket/path' local_file_path = os.path.join(local, filename) s3w = S3Uploader(out=(local, remote)) @@ -194,14 +172,14 @@ def test_list_objects_from_s3(self, remote_local_dir: Any): client = boto3.client('s3', region_name='us-east-1') client.put_object(Bucket=MY_BUCKET, Key=os.path.join(MY_PREFIX, file_name), Body='') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + cu = CloudUploader.get(mock_remote_dir, keep_local=True) objs = cu.list_objects(mock_remote_dir) assert isinstance(objs, list) @pytest.mark.usefixtures('s3_client', 's3_test', 'remote_local_dir') def 
test_clienterror_exception(self, remote_local_dir: Any): mock_remote_dir, _ = remote_local_dir(cloud_prefix='s3://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + cu = CloudUploader.get(mock_remote_dir, keep_local=True) objs = cu.list_objects() if objs: assert (len(objs) == 0) @@ -210,7 +188,7 @@ def test_clienterror_exception(self, remote_local_dir: Any): def test_invalid_cloud_prefix(self, remote_local_dir: Any): with pytest.raises(ValueError): mock_remote_dir, _ = remote_local_dir(cloud_prefix='s9://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + cu = CloudUploader.get(mock_remote_dir, keep_local=True) _ = cu.list_objects() @pytest.mark.usefixtures('s3_client', 's3_test') @@ -218,6 +196,7 @@ def test_extra_args(self, local_remote_dir: Tuple[str, str]): with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: filename = tmp.name.split(os.sep)[-1] local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) remote = 's3://streaming-test-bucket/path' local_file_path = os.path.join(local, filename) s3w = S3Uploader(out=(local, remote)) @@ -274,22 +253,12 @@ def test_invalid_remote_list(self, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = GCSUploader(out=out) - @pytest.mark.usefixtures('gcs_hmac_credentials') - def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty file at specified location - with open(local_file_path, 'w') as _: - pass - _ = GCSUploader(out=local) - @pytest.mark.usefixtures('gcs_hmac_client', 'gcs_test') def test_upload_file(self, local_remote_dir: Tuple[str, str]): with tempfile.NamedTemporaryFile(delete=True, suffix='.txt') as tmp: filename = tmp.name.split(os.sep)[-1] local, _ = local_remote_dir + 
os.makedirs(local, exist_ok=True) remote = 'gs://streaming-test-bucket/path' local_file_path = os.path.join(local, filename) gcsw = GCSUploader(out=(local, remote)) @@ -349,14 +318,14 @@ def test_no_authentication(self, out: str): def test_invalid_cloud_prefix(self, remote_local_dir: Any): with pytest.raises(ValueError): mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs9://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + cu = CloudUploader.get(mock_remote_dir, keep_local=True) _ = cu.list_objects() def test_no_credentials_error(self, remote_local_dir: Any): """Ensure we raise a value error correctly if we have no credentials available.""" with pytest.raises(ValueError): mock_remote_dir, _ = remote_local_dir(cloud_prefix='gs://') - cu = CloudUploader.get(mock_remote_dir, exist_ok=True, keep_local=True) + cu = CloudUploader.get(mock_remote_dir, keep_local=True) _ = cu.list_objects() @@ -381,16 +350,6 @@ def test_invalid_remote_list(self, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = AzureUploader(out=out) - def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty file at specified location - with open(local_file_path, 'w') as _: - pass - _ = AzureUploader(out=local) - class TestAzureDataLakeUploader: @@ -414,16 +373,6 @@ def test_invalid_remote_list(self, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = AzureDataLakeUploader(out=out) - def test_local_directory_is_empty(self, local_remote_dir: Tuple[str, str]): - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty 
file at specified location - with open(local_file_path, 'w') as _: - pass - _ = AzureDataLakeUploader(out=local) - class TestDatabricksUnityCatalogUploader: @@ -443,19 +392,6 @@ def test_invalid_remote_list(self, mock_create_client: Mock, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = DatabricksUnityCatalogUploader(out=out) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') - def test_local_directory_is_empty(self, mock_create_client: Mock, - local_remote_dir: Tuple[str, str]): - mock_create_client.side_effect = None - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty file at specified location - with open(local_file_path, 'w') as _: - pass - _ = DatabricksUnityCatalogUploader(out=local) - class TestDBFSUploader: @@ -474,24 +410,12 @@ def test_invalid_remote_list(self, mock_create_client: Mock, out: Any): with pytest.raises(ValueError, match=f'Invalid Cloud provider prefix.*'): _ = DBFSUploader(out=out) - @patch('streaming.base.storage.upload.DatabricksUploader._create_workspace_client') - def test_local_directory_is_empty(self, mock_create_client: Mock, - local_remote_dir: Tuple[str, str]): - with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): - mock_create_client.side_effect = None - local, _ = local_remote_dir - os.makedirs(local, exist_ok=True) - local_file_path = os.path.join(local, 'file.txt') - # Creating an empty file at specified location - with open(local_file_path, 'w') as _: - pass - _ = DBFSUploader(out=local) - class TestLocalUploader: def test_upload_file(self, local_remote_dir: Tuple[str, str]): local, remote = local_remote_dir + os.makedirs(local, exist_ok=True) filename = 'file.txt' local_file_path = os.path.join(local, filename) remote_file_path = os.path.join(remote, filename) @@ -511,6 
+435,7 @@ def test_instantiation_remote_none(self, local_remote_dir: Tuple[str, str]): def test_upload_file_remote_none(self, local_remote_dir: Tuple[str, str]): local, remote = local_remote_dir + os.makedirs(local, exist_ok=True) filename = 'file.txt' local_file_path = os.path.join(local, filename) remote_file_path = os.path.join(remote, filename) @@ -523,6 +448,7 @@ def test_upload_file_remote_none(self, local_remote_dir: Tuple[str, str]): def test_upload_file_from_local_to_remote(self, local_remote_dir: Tuple[str, str]): local, remote = local_remote_dir + os.makedirs(local, exist_ok=True) filename = 'file.txt' local_file_path = os.path.join(local, filename) remote_file_path = os.path.join(remote, filename) diff --git a/tests/test_util.py b/tests/test_util.py index e59f75911..ec886379f 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -137,13 +137,13 @@ def integrity_check(out: Union[str, Tuple[str, str]], def get_expected(mds_root: str): n_shard_files = 0 - cu = CloudUploader.get(mds_root, exist_ok=True, keep_local=True) + cu = CloudUploader.get(mds_root, keep_local=True) for o in cu.list_objects(): if o.endswith('.mds'): n_shard_files += 1 return n_shard_files - cu = CloudUploader.get(out, keep_local=True, exist_ok=True) + cu = CloudUploader.get(out, keep_local=True) with tempfile.TemporaryDirectory() as temp_dir: if cu.remote: @@ -210,7 +210,7 @@ def not_merged_index(index_file_path: str, out: str): mds_kwargs = {'out': mds_out, 'columns': {'id': 'int', 'name': 'str'}, 'keep_local': True} dataframeToMDS(df, merge_index=False, mds_kwargs=mds_kwargs) - local_cu = CloudUploader.get(local, exist_ok=True, keep_local=True) + local_cu = CloudUploader.get(local, keep_local=True) local_index_files = [ o for o in local_cu.list_objects() if o.endswith('.json') and not_merged_index(o, local) ] diff --git a/tests/test_writer.py b/tests/test_writer.py index 188a6b40b..f0d215bc6 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -256,3 +256,28 @@ def 
test_dataset_iter_determinism(self, local_remote_dir: Tuple[str, str], num_s # Ensure sample iterator is deterministic for before, after in zip(dataset, mds_dataset): assert before == after + + +def test_dir_exist_and_non_empty(local_remote_dir: Tuple[str, str]): + local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) + local_file_path = os.path.join(local, 'file.txt') + # Creating an empty file at specified location + with open(local_file_path, 'w') as _: + pass + columns = {'tokens': 'bytes'} + with pytest.raises(FileExistsError, match=f'Directory is not empty.*'): + _ = MDSWriter(out=local, columns=columns) + + +def test_dir_exist_and_non_empty_and_overwrite(caplog: Any, local_remote_dir: Tuple[str, str]): + caplog.set_level(logging.WARNING) + local, _ = local_remote_dir + os.makedirs(local, exist_ok=True) + local_file_path = os.path.join(local, 'file.txt') + # Creating an empty file at specified location + with open(local_file_path, 'w') as _: + pass + columns = {'tokens': 'bytes'} + _ = MDSWriter(out=local, columns=columns, exist_ok=True) + assert 'exists and not empty since you' in caplog.text