diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index b6ca20f5b..11b726623 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -238,6 +238,14 @@ class Config: # is hit first will stop the retry loop. experimental_files_ext_cloud_api_max_retries: int = 3 + # Storage-proxy hostname used for data plane file operations. + # When running inside a Databricks cluster/notebook, the storage proxy can handle + # file operations directly without presigned URLs. + files_ext_storage_proxy_hostname: str = "http://storage-proxy.databricks.com" + + # Timeout in seconds for the storage-proxy health check probe. + files_ext_storage_proxy_probe_timeout: float = 3.0 + def __init__( self, *, diff --git a/databricks/sdk/mixins/files.py b/databricks/sdk/mixins/files.py index 35fc841a0..35d4848b8 100644 --- a/databricks/sdk/mixins/files.py +++ b/databricks/sdk/mixins/files.py @@ -2,6 +2,7 @@ import base64 import datetime +import json import logging import math import os @@ -748,6 +749,156 @@ class DownloadFileResult: """Result of a download to file operation. Currently empty, but can be extended in the future.""" +@dataclass +class _UploadUrl: + """Result of URL resolution for an upload operation.""" + + url: str + headers: dict[str, str] + session: requests.Session + + +class _RequestBuilder(ABC): + """Builds HTTP requests for file upload operations. + + Implementations encapsulate the difference between direct storage-proxy + requests and presigned URL requests. + """ + + @abstractmethod + def build_upload_part(self, path: str, session_token: str, part_number: int, expire_time: str) -> _UploadUrl: + """Build a request for uploading one multipart part.""" + ... + + @abstractmethod + def build_resumable_upload_url(self, path: str, session_token: str) -> _UploadUrl: + """Build a request for resumable upload chunks.""" + ... + + @abstractmethod + def build_abort_url(self, path: str, session_token: str, expire_time: str) -> _UploadUrl: + """Build a request for aborting an upload.""" + ... + + +class _StorageProxyRequestBuilder(_RequestBuilder): + """Builds direct requests to the storage proxy. + + All URLs point directly to the proxy hostname. No presigned URL API calls + are needed. Uses an SDK-authenticated session. 
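+
+    Example of the URL shape produced here (hypothetical path and token; the
+    hostname is whatever was configured, shown with the default):
+
+        builder = _StorageProxyRequestBuilder("http://storage-proxy.databricks.com", session)
+        info = builder.build_upload_part("/Volumes/c/s/v/f.bin", "tok", 1, "")
+        # info.url == "http://storage-proxy.databricks.com/api/2.0/fs/files"
+        #             "/Volumes/c/s/v/f.bin?session_token=tok&upload_type=multipart&part_number=1"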
+ """ + + def __init__(self, hostname: str, session: requests.Session): + self._hostname = hostname + self._session = session + + def build_upload_part(self, path: str, session_token: str, part_number: int, expire_time: str) -> _UploadUrl: + escaped = _escape_multi_segment_path_parameter(path) + base = f"{self._hostname}/api/2.0/fs/files{escaped}" + query = parse.urlencode({"session_token": session_token, "upload_type": "multipart", "part_number": part_number}) + return _UploadUrl( + url=f"{base}?{query}", + headers={"Content-Type": "application/octet-stream"}, + session=self._session, + ) + + def build_resumable_upload_url(self, path: str, session_token: str) -> _UploadUrl: + escaped = _escape_multi_segment_path_parameter(path) + base = f"{self._hostname}/api/2.0/fs/files{escaped}" + query = parse.urlencode({"session_token": session_token, "upload_type": "resumable"}) + return _UploadUrl( + url=f"{base}?{query}", + headers={"Content-Type": "application/octet-stream"}, + session=self._session, + ) + + def build_abort_url(self, path: str, session_token: str, expire_time: str) -> _UploadUrl: + escaped = _escape_multi_segment_path_parameter(path) + base = f"{self._hostname}/api/2.0/fs/files{escaped}" + query = parse.urlencode({"action": "abort-upload", "session_token": session_token}) + return _UploadUrl( + url=f"{base}?{query}", + headers={"Content-Type": "application/json"}, + session=self._session, + ) + + +class _PresignedUrlRequestBuilder(_RequestBuilder): + """Builds requests using presigned URLs from GIG. + + Coordination API calls (create-upload-part-urls, create-resumable-upload-url, + etc.) go to GIG. Data transfer goes to cloud storage via presigned URLs + using an unauthenticated session. + """ + + def __init__(self, api, cloud_session: requests.Session): + self._api = api + self._cloud_session = cloud_session + + def build_upload_part(self, path: str, session_token: str, part_number: int, expire_time: str) -> _UploadUrl: + body = { + "path": path, + "session_token": session_token, + "start_part_number": part_number, + "count": 1, + "expire_time": expire_time, + } + response = self._api.do( + "POST", + "/api/2.0/fs/create-upload-part-urls", + headers={"Content-Type": "application/json"}, + body=body, + ) + upload_part_urls = response.get("upload_part_urls", []) + if len(upload_part_urls) == 0: + raise ValueError(f"Unexpected server response: {response}") + part_info = upload_part_urls[0] + # Validate required fields. Accessing via [] ensures KeyError on missing keys. 
+ url = part_info["url"] + _ = part_info["part_number"] + headers: dict[str, str] = {"Content-Type": "application/octet-stream"} + for h in part_info.get("headers", []): + headers[h["name"]] = h["value"] + return _UploadUrl(url=url, headers=headers, session=self._cloud_session) + + def build_resumable_upload_url(self, path: str, session_token: str) -> _UploadUrl: + body = {"path": path, "session_token": session_token} + response = self._api.do( + "POST", + "/api/2.0/fs/create-resumable-upload-url", + headers={"Content-Type": "application/json"}, + body=body, + ) + url_node = response.get("resumable_upload_url") + if not url_node: + raise ValueError(f"Unexpected server response: {response}") + url = url_node.get("url") + if not url: + raise ValueError(f"Unexpected server response: {response}") + headers: dict[str, str] = {"Content-Type": "application/octet-stream"} + for h in url_node.get("headers", []): + headers[h["name"]] = h["value"] + return _UploadUrl(url=url, headers=headers, session=self._cloud_session) + + def build_abort_url(self, path: str, session_token: str, expire_time: str) -> _UploadUrl: + body = { + "path": path, + "session_token": session_token, + "expire_time": expire_time, + } + response = self._api.do( + "POST", + "/api/2.0/fs/create-abort-upload-url", + headers={"Content-Type": "application/json"}, + body=body, + ) + abort_node = response["abort_upload_url"] + headers: dict[str, str] = {"Content-Type": "application/octet-stream"} + for h in abort_node.get("headers", []): + headers[h["name"]] = h["value"] + return _UploadUrl(url=abort_node["url"], headers=headers, session=self._cloud_session) + + class FilesExt(files.FilesAPI): __doc__ = files.FilesAPI.__doc__ @@ -777,6 +928,109 @@ def __init__(self, api_client, config: Config): super().__init__(api_client) self._config = config.copy() self._multipart_upload_read_ahead_bytes = 1 + self._dp_hostname_available: Optional[bool] = None + self._storage_proxy_session: Optional[requests.Session] = None + + def _probe_storage_proxy(self) -> bool: + """Probes the storage proxy to check if it is reachable. + + Makes a GET request to the probe endpoint using SDK auth. The result + is cached in self._dp_hostname_available after the first call. + """ + proxy_host = self._config.files_ext_storage_proxy_hostname + probe_url = f"{proxy_host}/api/2.0/fs/files/DatabricksInternal/Probes/ping" + try: + headers = self._config.authenticate() + session = self._create_cloud_provider_session() + response = session.request( + "GET", + probe_url, + headers=headers, + timeout=self._config.files_ext_storage_proxy_probe_timeout, + ) + return response.status_code == 200 + except Exception: + return False + + def _create_storage_proxy_session(self) -> requests.Session: + """Creates an HTTP session with SDK auth for storage-proxy requests. + + Unlike _create_cloud_provider_session (which has no auth), this session + includes the SDK authentication callback so that every request to the + storage proxy carries valid credentials. 
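+
+        Sketch of the resulting behavior (the exact header depends on the
+        configured credentials strategy):
+
+            session = self._create_storage_proxy_session()
+            session.get(url)  # config.authenticate() runs per request, e.g.
+                              # injecting "Authorization: Bearer <token>".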
+ """ + session = requests.Session() + config = self._config + + def authenticate(r: requests.PreparedRequest) -> requests.PreparedRequest: + auth_headers = config.authenticate() + r.headers.update(auth_headers) + return r + + session.auth = authenticate + http_adapter = requests.adapters.HTTPAdapter( + config.max_connection_pools or 20, + config.max_connections_per_pool or 20, + pool_block=True, + ) + session.mount("https://", http_adapter) + session.mount("http://", http_adapter) + return session + + def _get_hostname(self) -> tuple[str, _RequestBuilder]: + """Returns the optimal hostname and request builder for file operations. + + Probes the storage proxy on first call and caches the result. If the + proxy is reachable, returns the proxy hostname with a direct request + builder. Otherwise returns the workspace hostname with a presigned URL + request builder. + """ + if self._dp_hostname_available is None: + self._dp_hostname_available = self._probe_storage_proxy() + if self._dp_hostname_available: + self._storage_proxy_session = self._create_storage_proxy_session() + _LOG.info("Storage proxy is available, will use it for file operations.") + else: + _LOG.info("Storage proxy is not available, will use presigned URLs.") + + if self._dp_hostname_available: + builder = _StorageProxyRequestBuilder( + self._config.files_ext_storage_proxy_hostname, + self._storage_proxy_session, + ) + return self._config.files_ext_storage_proxy_hostname, builder + else: + cloud_session = self._create_cloud_provider_session() + builder = _PresignedUrlRequestBuilder(self._api, cloud_session) + return self._config.host, builder + + def _get_url(self, purpose: str, url_context: dict) -> _UploadUrl: + """Creates a URL for a file upload operation. + + Delegates to the appropriate request builder based on whether the + storage proxy is available. 
+ """ + _hostname, builder = self._get_hostname() + if purpose == "uploadPart": + return builder.build_upload_part( + path=url_context["path"], + session_token=url_context["session_token"], + part_number=url_context["part_number"], + expire_time=url_context.get("expire_time", ""), + ) + elif purpose == "resumableUrl": + return builder.build_resumable_upload_url( + path=url_context["path"], + session_token=url_context["session_token"], + ) + elif purpose == "abort": + return builder.build_abort_url( + path=url_context["path"], + session_token=url_context["session_token"], + expire_time=url_context.get("expire_time", ""), + ) + else: + raise ValueError(f"Unknown purpose: {purpose}") def download( self, @@ -1242,20 +1496,26 @@ def _upload_single_thread_with_known_size(self, ctx: _UploadContext, contents: B return self._single_thread_multipart_upload(ctx, contents) def _single_thread_single_shot_upload(self, ctx: _UploadContext, contents: BinaryIO) -> None: - """Upload a file with a known size.""" + """Upload a file using a single PUT request.""" _LOG.debug(f"Using single-shot upload for input stream") - return super().upload(file_path=ctx.target_path, contents=contents, overwrite=ctx.overwrite) + hostname, _ = self._get_hostname() + escaped = _escape_multi_segment_path_parameter(ctx.target_path) + query: dict = {} + if ctx.overwrite is not None: + query["overwrite"] = ctx.overwrite + headers = {"Content-Type": "application/octet-stream"} + self._api.do("PUT", url=f"{hostname}/api/2.0/fs/files{escaped}", query=query, headers=headers, data=contents) def _initiate_multipart_upload(self, ctx: _UploadContext) -> dict: """Initiate a multipart upload and return the response.""" - query = {"action": "initiate-upload"} + hostname, _ = self._get_hostname() + query: dict = {"action": "initiate-upload"} if ctx.overwrite is not None: query["overwrite"] = ctx.overwrite + escaped = _escape_multi_segment_path_parameter(ctx.target_path) # Method _api.do() takes care of retrying and will raise an exception in case of failure. - initiate_upload_response = self._api.do( - "POST", f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", query=query - ) + initiate_upload_response = self._api.do("POST", url=f"{hostname}/api/2.0/fs/files{escaped}", query=query) return initiate_upload_response def _single_thread_multipart_upload(self, ctx: _UploadContext, contents: BinaryIO) -> None: @@ -1272,16 +1532,15 @@ def _single_thread_multipart_upload(self, ctx: _UploadContext, contents: BinaryI initiate_upload_response = self._initiate_multipart_upload(ctx) if initiate_upload_response.get("multipart_upload"): - cloud_provider_session = self._create_cloud_provider_session() session_token = initiate_upload_response["multipart_upload"].get("session_token") if not session_token: raise ValueError(f"Unexpected server response: {initiate_upload_response}") try: - self._perform_multipart_upload(ctx, contents, session_token, pre_read_buffer, cloud_provider_session) + self._perform_multipart_upload(ctx, contents, session_token, pre_read_buffer) except FallbackToUploadUsingFilesApi as e: try: - self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + self._abort_multipart_upload(ctx, session_token) except BaseException as ex: # Ignore abort exceptions as it is a best-effort. 
_LOG.warning(f"Failed to abort upload: {ex}") @@ -1294,7 +1553,7 @@ def _single_thread_multipart_upload(self, ctx: _UploadContext, contents: BinaryI except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") try: - self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + self._abort_multipart_upload(ctx, session_token) except BaseException as ex: # Ignore abort exceptions as it is a best-effort. _LOG.warning(f"Failed to abort upload: {ex}") @@ -1303,11 +1562,10 @@ def _single_thread_multipart_upload(self, ctx: _UploadContext, contents: BinaryI raise e from None elif initiate_upload_response.get("resumable_upload"): - cloud_provider_session = self._create_cloud_provider_session() session_token = initiate_upload_response["resumable_upload"]["session_token"] try: - self._perform_resumable_upload(ctx, contents, session_token, pre_read_buffer, cloud_provider_session) + self._perform_resumable_upload(ctx, contents, session_token, pre_read_buffer) except FallbackToUploadUsingFilesApi as e: _LOG.info(f"Falling back to single-shot upload with Files API: {e}") # Concatenate the buffered part and the rest of the stream. @@ -1329,14 +1587,13 @@ def _parallel_upload_from_stream(self, ctx: _UploadContext, contents: BinaryIO) return self._single_thread_multipart_upload(ctx, contents) elif initiate_upload_response.get("multipart_upload"): session_token = initiate_upload_response["multipart_upload"].get("session_token") - cloud_provider_session = self._create_cloud_provider_session() if not session_token: raise ValueError(f"Unexpected server response: {initiate_upload_response}") try: - self._parallel_multipart_upload_from_stream(ctx, session_token, contents, cloud_provider_session) + self._parallel_multipart_upload_from_stream(ctx, session_token, contents) except FallbackToUploadUsingFilesApi as e: try: - self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + self._abort_multipart_upload(ctx, session_token) except Exception as abort_ex: _LOG.warning(f"Failed to abort upload: {abort_ex}") _LOG.info(f"Falling back to single-shot upload with Files API: {e}") @@ -1346,7 +1603,7 @@ def _parallel_upload_from_stream(self, ctx: _UploadContext, contents: BinaryIO) except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") try: - self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + self._abort_multipart_upload(ctx, session_token) except Exception as abort_ex: _LOG.warning(f"Failed to abort upload: {abort_ex}") finally: @@ -1368,7 +1625,6 @@ def _parallel_upload_from_file( initiate_upload_response = self._initiate_multipart_upload(ctx) if initiate_upload_response.get("multipart_upload"): - cloud_provider_session = self._create_cloud_provider_session() session_token = initiate_upload_response["multipart_upload"].get("session_token") if not session_token: raise ValueError(f"Unexpected server response: {initiate_upload_response}") @@ -1376,7 +1632,7 @@ def _parallel_upload_from_file( self._parallel_multipart_upload_from_file(ctx, session_token) except FallbackToUploadUsingFilesApi as e: try: - self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + self._abort_multipart_upload(ctx, session_token) except Exception as abort_ex: _LOG.warning(f"Failed to abort upload: {abort_ex}") @@ -1388,7 +1644,7 @@ def _parallel_upload_from_file( except Exception as e: _LOG.info(f"Aborting multipart upload on error: {e}") try: - self._abort_multipart_upload(ctx, session_token, cloud_provider_session) + 
self._abort_multipart_upload(ctx, session_token) except Exception as abort_ex: _LOG.warning(f"Failed to abort upload: {abort_ex}") finally: @@ -1420,8 +1676,6 @@ def _parallel_multipart_upload_from_file( part_size = ctx.part_size num_parts = (file_size + part_size - 1) // part_size _LOG.debug(f"Uploading file of size {file_size} bytes in {num_parts} parts using {ctx.parallelism} threads") - cloud_provider_session = self._create_cloud_provider_session() - # Upload one part to verify the upload can proceed. with open(ctx.source_file_path, "rb") as f: f.seek(0) @@ -1430,7 +1684,6 @@ def _parallel_multipart_upload_from_file( try: etag = self._do_upload_one_part( ctx, - cloud_provider_session, 1, 0, first_part_size, @@ -1453,7 +1706,7 @@ def _parallel_multipart_upload_from_file( workers = [ Thread( target=self._upload_file_consumer, - args=(cloud_provider_session, task_queue, etags_result_queue, exception_queue, aborted), + args=(task_queue, etags_result_queue, exception_queue, aborted), ) for _ in range(ctx.parallelism) ] @@ -1492,7 +1745,6 @@ def _parallel_multipart_upload_from_stream( ctx: _UploadContext, session_token: str, content: BinaryIO, - cloud_provider_session: requests.Session, ) -> None: task_queue = Queue(maxsize=ctx.parallelism) # Limit queue size to control memory usage @@ -1510,7 +1762,6 @@ def _parallel_multipart_upload_from_stream( try: etag = self._do_upload_one_part( ctx, - cloud_provider_session, 1, 0, len(pre_read_buffer), @@ -1583,25 +1834,21 @@ def producer() -> None: self._complete_multipart_upload(ctx, etags, session_token) def _complete_multipart_upload(self, ctx, etags, session_token): + hostname, _ = self._get_hostname() + escaped = _escape_multi_segment_path_parameter(ctx.target_path) query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} headers = {"Content-Type": "application/json"} - body: dict = {} - parts = [] - for part_number, etag in sorted(etags.items()): - part = {"part_number": part_number, "etag": etag} - parts.append(part) - body["parts"] = parts + parts = [{"part_number": pn, "etag": et} for pn, et in sorted(etags.items())] self._api.do( "POST", - f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", + url=f"{hostname}/api/2.0/fs/files{escaped}", query=query, headers=headers, - body=body, + body={"parts": parts}, ) def _upload_file_consumer( self, - cloud_provider_session: requests.Session, task_queue: Queue[FilesExt._MultipartUploadPart], etags_queue: Queue[tuple[int, str]], exception_queue: Queue[Exception], @@ -1620,7 +1867,6 @@ def _upload_file_consumer( part_content = BytesIO(f.read(part.part_size)) etag = self._do_upload_one_part( part.ctx, - cloud_provider_session, part.part_index, part.part_offset, part.part_size, @@ -1642,7 +1888,6 @@ def _upload_stream_consumer( all_produced: Event, aborted: Event, ) -> None: - cloud_provider_session = self._create_cloud_provider_session() while not aborted.is_set(): try: part, content = task_queue.get(block=False, timeout=0.1) @@ -1654,7 +1899,6 @@ def _upload_stream_consumer( try: etag = self._do_upload_one_part( part.ctx, - cloud_provider_session, part.part_index, part.part_offset, part.part_size, @@ -1671,7 +1915,6 @@ def _upload_stream_consumer( def _do_upload_one_part( self, ctx: _UploadContext, - cloud_provider_session: requests.Session, part_index: int, part_offset: int, part_size: int, @@ -1683,22 +1926,31 @@ def _do_upload_one_part( # Try to upload the part, retrying if the upload URL expires. 
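+        # Each iteration re-resolves the upload URL via _get_url(): a local
+        # re-build for the storage proxy, a fresh presigned URL otherwise.
+        # A JSONDecodeError while resolving the URL for the first part triggers
+        # fallback to the single-shot Files API upload.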
while True: - body: dict = { - "path": ctx.target_path, - "session_token": session_token, - "start_part_number": part_index, - "count": 1, - "expire_time": self._get_upload_url_expire_time(), - } - - headers = {"Content-Type": "application/json"} - - # Requesting URLs for the same set of parts is an idempotent operation and is safe to retry. try: - # The _api.do() method handles retries and will raise an exception in case of failure. - upload_part_urls_response = self._api.do( - "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body + upload_info = self._get_url( + "uploadPart", + { + "path": ctx.target_path, + "session_token": session_token, + "part_number": part_index, + "expire_time": self._get_upload_url_expire_time(), + }, ) + except (KeyError, ValueError) as e: + # JSONDecodeError is a subclass of ValueError; it indicates a server + # communication failure that should be eligible for fallback, not a + # response structure problem. + if isinstance(e, json.JSONDecodeError): + if is_first_part: + raise FallbackToUploadUsingFilesApi( + None, + f"Failed to obtain upload URL for part {part_index}: {e}, falling back to single shot upload", + ) + else: + raise + # Other ValueError/KeyError indicate an invalid response structure + # and should propagate as a hard failure. + raise except Exception as e: if is_first_part: raise FallbackToUploadUsingFilesApi( @@ -1708,28 +1960,16 @@ def _do_upload_one_part( else: raise e - upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) - if len(upload_part_urls) == 0: - raise ValueError(f"Unexpected server response: {upload_part_urls_response}") - upload_part_url = upload_part_urls[0] - url = upload_part_url["url"] - required_headers = upload_part_url.get("headers", []) - assert part_index == upload_part_url["part_number"] - - headers: dict = {"Content-Type": "application/octet-stream"} - for h in required_headers: - headers[h["name"]] = h["value"] - _LOG.debug(f"Uploading part {part_index}: [{part_offset}, {part_offset + part_size - 1}]") def rewind() -> None: part_content.seek(0, os.SEEK_SET) def perform_upload() -> requests.Response: - return cloud_provider_session.request( + return upload_info.session.request( "PUT", - url, - headers=headers, + upload_info.url, + headers=upload_info.headers, data=part_content, timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) @@ -1763,11 +2003,11 @@ def _perform_multipart_upload( input_stream: BinaryIO, session_token: str, pre_read_buffer: bytes, - cloud_provider_session: requests.Session, ) -> None: - """ - Performs multipart upload using presigned URLs on AWS and Azure: - https://docs.aws.amazon.com/AmazonS3/latest/userguide/mpuoverview.html + """Performs multipart upload on AWS and Azure. + + Uses _get_url() to obtain per-part upload URLs, which abstracts away + whether the upload goes through the storage proxy or via presigned URLs. """ current_part_number = 1 etags: dict = {} @@ -1777,44 +2017,42 @@ def _perform_multipart_upload( # provide each chunk size up front. In case of a non-seekable input stream we need # to buffer a chunk before uploading to know its size. This also allows us to rewind # the stream before retrying on request failure. 
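+        # For example (assuming part_size = 5 MiB): a 12 MiB non-seekable stream
+        # is sent as parts of 5, 5 and 2 MiB; each part is fully buffered, so a
+        # failed PUT can be retried from the start of that part.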
-        # AWS signed chunked upload: https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-streaming.html
-        # https://learn.microsoft.com/en-us/azure/storage/blobs/storage-blobs-tune-upload-download-python#buffering-during-uploads
         chunk_offset = 0
-
-        # This buffer is expected to contain at least multipart_upload_chunk_size bytes.
-        # Note that initially buffer can be bigger (from pre_read_buffer).
         buffer = pre_read_buffer
         retry_count = 0
-        eof = False
-        while not eof:
-            # If needed, buffer the next chunk.
+
+        while True:
             buffer = FilesExt._fill_buffer(buffer, ctx.part_size, input_stream)
             if len(buffer) == 0:
-                # End of stream, no need to request the next block of upload URLs.
                 break
 
-            _LOG.debug(
-                f"Multipart upload: requesting next {ctx.batch_size} upload URLs starting from part {current_part_number}"
-            )
+            actual_chunk_length = min(len(buffer), ctx.part_size)
 
-            body: dict = {
-                "path": ctx.target_path,
-                "session_token": session_token,
-                "start_part_number": current_part_number,
-                "count": ctx.batch_size,
-                "expire_time": self._get_upload_url_expire_time(),
-            }
-
-            headers = {"Content-Type": "application/json"}
-
-            # Requesting URLs for the same set of parts is an idempotent operation, safe to retry.
             try:
-                # Method _api.do() takes care of retrying and will raise an exception in case of failure.
-                upload_part_urls_response = self._api.do(
-                    "POST", "/api/2.0/fs/create-upload-part-urls", headers=headers, body=body
+                upload_info = self._get_url(
+                    "uploadPart",
+                    {
+                        "path": ctx.target_path,
+                        "session_token": session_token,
+                        "part_number": current_part_number,
+                        "expire_time": self._get_upload_url_expire_time(),
+                    },
                 )
+            except (KeyError, ValueError) as e:
+                # JSONDecodeError is a subclass of ValueError; it indicates a server
+                # communication failure that should be eligible for fallback, not a
+                # response structure problem.
+                if isinstance(e, json.JSONDecodeError):
+                    if chunk_offset == 0:
+                        raise FallbackToUploadUsingFilesApi(
+                            buffer, f"Failed to obtain upload URLs: {e}, falling back to single shot upload"
+                        ) from e
+                    else:
+                        raise
+                # Other ValueError/KeyError indicate an invalid response structure
+                # and should propagate as a hard failure.
+ raise except Exception as e: if chunk_offset == 0: raise FallbackToUploadUsingFilesApi( @@ -1823,113 +2061,59 @@ def _perform_multipart_upload( else: raise e - upload_part_urls = upload_part_urls_response.get("upload_part_urls", []) - if len(upload_part_urls) == 0: - raise ValueError(f"Unexpected server response: {upload_part_urls_response}") - - for upload_part_url in upload_part_urls: - buffer = FilesExt._fill_buffer(buffer, ctx.part_size, input_stream) - actual_buffer_length = len(buffer) - if actual_buffer_length == 0: - eof = True - break + _LOG.debug( + f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]" + ) - url = upload_part_url["url"] - required_headers = upload_part_url.get("headers", []) - assert current_part_number == upload_part_url["part_number"] + chunk = BytesIO(buffer[:actual_chunk_length]) - headers: dict = {"Content-Type": "application/octet-stream"} - for h in required_headers: - headers[h["name"]] = h["value"] + def rewind(): + chunk.seek(0, os.SEEK_SET) - actual_chunk_length = min(actual_buffer_length, ctx.part_size) - _LOG.debug( - f"Uploading part {current_part_number}: [{chunk_offset}, {chunk_offset + actual_chunk_length - 1}]" + def perform(): + return upload_info.session.request( + "PUT", + upload_info.url, + headers=upload_info.headers, + data=chunk, + timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) - chunk = BytesIO(buffer[:actual_chunk_length]) - - def rewind(): - chunk.seek(0, os.SEEK_SET) + upload_response = self._retry_cloud_idempotent_operation(perform, rewind) - def perform(): - return cloud_provider_session.request( - "PUT", - url, - headers=headers, - data=chunk, - timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, - ) - - upload_response = self._retry_cloud_idempotent_operation(perform, rewind) - - if upload_response.status_code in (200, 201): - # Chunk upload successful - - chunk_offset += actual_chunk_length - - etag = upload_response.headers.get("ETag", "") - etags[current_part_number] = etag - - # Discard uploaded bytes - buffer = buffer[actual_chunk_length:] - - # Reset retry count when progressing along the stream - retry_count = 0 + if upload_response.status_code in (200, 201): + chunk_offset += actual_chunk_length + etag = upload_response.headers.get("ETag", "") + etags[current_part_number] = etag + buffer = buffer[actual_chunk_length:] + retry_count = 0 - elif FilesExt._is_url_expired_response(upload_response): - if retry_count < self._config.files_ext_multipart_upload_max_retries: - retry_count += 1 - _LOG.debug("Upload URL expired") - # Preserve the buffer so we'll upload the current part again using next upload URL - else: - # don't confuse user with unrelated "Permission denied" error. - raise ValueError(f"Unsuccessful chunk upload: upload URL expired") - - elif upload_response.status_code == 403 and chunk_offset == 0: - # We got 403 failure when uploading the very first chunk (we can't tell if it is Azure for sure yet). - # This might happen due to Azure firewall enabled for the customer bucket. - # Let's fallback to using Files API which might be allowlisted to upload, passing - # currently buffered (but not yet uploaded) part of the stream. - raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}") - elif chunk_offset == 0: - # We got an upload failure when uploading the very first chunk. 
- # Let's fallback to using Files API which might be more reliable in this case, - # passing currently buffered (but not yet uploaded) part of the stream. - raise FallbackToUploadUsingFilesApi( - buffer, - f"Unsuccessful chunk upload: {upload_response.status_code}, falling back to single shot upload", - ) + elif FilesExt._is_url_expired_response(upload_response): + if retry_count < self._config.files_ext_multipart_upload_max_retries: + retry_count += 1 + _LOG.debug("Upload URL expired") + # Preserve the buffer so we upload the current part again with a new URL. + continue else: - message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" - _LOG.warning(message) - mapped_error = _error_mapper(upload_response, {}) - raise mapped_error or ValueError(message) + raise ValueError(f"Unsuccessful chunk upload: upload URL expired") + + elif upload_response.status_code == 403 and chunk_offset == 0: + raise FallbackToUploadUsingFilesApi(buffer, f"Direct upload forbidden: {upload_response.content}") + elif chunk_offset == 0: + raise FallbackToUploadUsingFilesApi( + buffer, + f"Unsuccessful chunk upload: {upload_response.status_code}, falling back to single shot upload", + ) + else: + message = f"Unsuccessful chunk upload. Response status: {upload_response.status_code}, body: {upload_response.content}" + _LOG.warning(message) + mapped_error = _error_mapper(upload_response, {}) + raise mapped_error or ValueError(message) - current_part_number += 1 + current_part_number += 1 _LOG.debug(f"Completing multipart upload after uploading {len(etags)} parts of up to {ctx.part_size} bytes") - - query = {"action": "complete-upload", "upload_type": "multipart", "session_token": session_token} - headers = {"Content-Type": "application/json"} - body: dict = {} - - parts = [] - for etag in sorted(etags.items()): - part = {"part_number": etag[0], "etag": etag[1]} - parts.append(part) - - body["parts"] = parts - - # Completing upload is an idempotent operation, safe to retry. - # Method _api.do() takes care of retrying and will raise an exception in case of failure. - self._api.do( - "POST", - f"/api/2.0/fs/files{_escape_multi_segment_path_parameter(ctx.target_path)}", - query=query, - headers=headers, - body=body, - ) + self._complete_multipart_upload(ctx, etags, session_token) @staticmethod def _fill_buffer(buffer: bytes, desired_min_size: int, input_stream: BinaryIO) -> bytes: @@ -1999,7 +2183,6 @@ def _perform_resumable_upload( input_stream: BinaryIO, session_token: str, pre_read_buffer: bytes, - cloud_provider_session: requests.Session, ) -> None: """ Performs resumable upload on GCP: https://cloud.google.com/storage/docs/performing-resumable-uploads @@ -2028,29 +2211,29 @@ def _perform_resumable_upload( # On the contrary, in multipart upload we can decide to complete upload *after* # last chunk has been sent. - body: dict = {"path": ctx.target_path, "session_token": session_token} - - headers = {"Content-Type": "application/json"} - try: - # Method _api.do() takes care of retrying and will raise an exception in case of failure. - resumable_upload_url_response = self._api.do( - "POST", "/api/2.0/fs/create-resumable-upload-url", headers=headers, body=body - ) + upload_info = self._get_url("resumableUrl", { + "path": ctx.target_path, + "session_token": session_token, + }) + except (KeyError, ValueError) as e: + # JSONDecodeError is a subclass of ValueError; it indicates a server + # communication failure that should be eligible for fallback. 
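+            # For resumable (GCP) uploads nothing has been uploaded at this
+            # point, so falling back with pre_read_buffer is always safe.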
+ if isinstance(e, json.JSONDecodeError): + raise FallbackToUploadUsingFilesApi( + pre_read_buffer, f"Failed to obtain resumable upload URL: {e}, falling back to single shot upload" + ) from e + # Other ValueError/KeyError indicate an invalid response structure + # and should propagate as a hard failure. + raise except Exception as e: raise FallbackToUploadUsingFilesApi( pre_read_buffer, f"Failed to obtain resumable upload URL: {e}, falling back to single shot upload" ) from e - resumable_upload_url_node = resumable_upload_url_response.get("resumable_upload_url") - if not resumable_upload_url_node: - raise ValueError(f"Unexpected server response: {resumable_upload_url_response}") - - resumable_upload_url = resumable_upload_url_node.get("url") - if not resumable_upload_url: - raise ValueError(f"Unexpected server response: {resumable_upload_url_response}") - - required_headers = resumable_upload_url_node.get("headers", []) + resumable_upload_url = upload_info.url + resumable_session = upload_info.session + base_headers = upload_info.headers try: # We will buffer this many bytes: one chunk + read-ahead block. @@ -2083,9 +2266,7 @@ def _perform_resumable_upload( actual_chunk_length = ctx.part_size file_size = "*" - headers: dict = {"Content-Type": "application/octet-stream"} - for h in required_headers: - headers[h["name"]] = h["value"] + headers: dict = dict(base_headers) chunk_last_byte_offset = chunk_offset + actual_chunk_length - 1 content_range_header = f"bytes {chunk_offset}-{chunk_last_byte_offset}/{file_size}" @@ -2094,7 +2275,7 @@ def _perform_resumable_upload( def retrieve_upload_status() -> Optional[requests.Response]: def perform(): - return cloud_provider_session.request( + return resumable_session.request( "PUT", resumable_upload_url, headers={"Content-Range": "bytes */*"}, @@ -2109,7 +2290,7 @@ def perform(): return None try: - upload_response = cloud_provider_session.request( + upload_response = resumable_session.request( "PUT", resumable_upload_url, headers=headers, @@ -2196,7 +2377,7 @@ def perform(): except Exception as e: _LOG.info(f"Aborting resumable upload on error: {e}") try: - self._abort_resumable_upload(resumable_upload_url, required_headers, cloud_provider_session) + self._abort_resumable_upload(resumable_upload_url, base_headers, resumable_session) except BaseException as ex: _LOG.warning(f"Failed to abort upload: {ex}") # ignore, abort is a best-effort @@ -2238,34 +2419,19 @@ def _get_download_url_expire_time(self) -> str: current_time, self._config.files_ext_presigned_download_url_expiration_duration ) - def _abort_multipart_upload( - self, ctx: _UploadContext, session_token: str, cloud_provider_session: requests.Session - ) -> None: + def _abort_multipart_upload(self, ctx: _UploadContext, session_token: str) -> None: """Aborts ongoing multipart upload session to clean up incomplete file.""" - body: dict = { + abort_info = self._get_url("abort", { "path": ctx.target_path, "session_token": session_token, "expire_time": self._get_upload_url_expire_time(), - } - - headers = {"Content-Type": "application/json"} - - # Method _api.do() takes care of retrying and will raise an exception in case of failure. 
- abort_url_response = self._api.do("POST", "/api/2.0/fs/create-abort-upload-url", headers=headers, body=body) - - abort_upload_url_node = abort_url_response["abort_upload_url"] - abort_url = abort_upload_url_node["url"] - required_headers = abort_upload_url_node.get("headers", []) - - headers: dict = {"Content-Type": "application/octet-stream"} - for h in required_headers: - headers[h["name"]] = h["value"] + }) def perform() -> requests.Response: - return cloud_provider_session.request( + return abort_info.session.request( "DELETE", - abort_url, - headers=headers, + abort_info.url, + headers=abort_info.headers, data=b"", timeout=self._config.files_ext_network_transfer_inactivity_timeout_seconds, ) @@ -2276,15 +2442,12 @@ def perform() -> requests.Response: raise ValueError(abort_response) def _abort_resumable_upload( - self, resumable_upload_url: str, required_headers: list, cloud_provider_session: requests.Session + self, resumable_upload_url: str, headers: dict[str, str], session: requests.Session ) -> None: """Aborts ongoing resumable upload session to clean up incomplete file.""" - headers: dict = {} - for h in required_headers: - headers[h["name"]] = h["value"] def perform() -> requests.Response: - return cloud_provider_session.request( + return session.request( "DELETE", resumable_upload_url, headers=headers, diff --git a/tests/test_files.py b/tests/test_files.py index e8861bb9f..33ff16671 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -23,6 +23,7 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config +from databricks.sdk.credentials_provider import credentials_strategy from databricks.sdk.environments import Cloud, DatabricksEnvironment from databricks.sdk.errors.platform import (AlreadyExists, BadRequest, InternalError, NotImplemented, @@ -2847,3 +2848,273 @@ def fast_random_bytes(n: int, chunk_size: int = 1024) -> bytes: chunk = os.urandom(chunk_size) # Repeat it until we reach n bytes return (chunk * (n // chunk_size + 1))[:n] + + +# --------------------------------------------------------------------------- +# Storage-proxy upload tests +# --------------------------------------------------------------------------- + +STORAGE_PROXY_HOST = "http://storage-proxy.databricks.com" + + +def _make_response( + request: requests.Request, status_code: int = 200, body: str = "", headers: Optional[Dict[str, str]] = None +) -> requests.Response: + """Creates a mock response with the request reference set (needed by SDK logger).""" + resp = requests.Response() + resp.status_code = status_code + resp._content = body.encode() if isinstance(body, str) else body + resp.request = request + resp.url = request.url + if headers: + for k, v in headers.items(): + resp.headers[k] = v + return resp + + +class StorageProxyUploadServerState: + """Tracks server state for storage-proxy multipart uploads.""" + + def __init__(self): + self.session_token = f"sp-token-{random.randrange(10000)}" + self.uploaded_parts: Dict[int, bytes] = {} + self.completed = False + + def save_part(self, part_number: int, data: bytes) -> str: + etag = f"etag-sp-{part_number}" + self.uploaded_parts[part_number] = data + return etag + + def get_assembled_content(self) -> Optional[bytes]: + if not self.completed: + return None + parts = sorted(self.uploaded_parts.items()) + return b"".join(data for _, data in parts) + + +def _make_storage_proxy_config() -> Config: + """Creates a Config that will route uploads through the storage proxy.""" + clock = FakeClock() + + 
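+    # A minimal stub credentials strategy for tests: every request gets a static
+    # PAT-style Authorization header. The name "pat" and token are arbitrary.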
@credentials_strategy("pat", [])
+    def pat_credentials(_: Config):
+        return lambda: {"Authorization": "Bearer test-pat-token"}
+
+    config = Config(
+        host="http://localhost",
+        credentials_strategy=pat_credentials,
+        clock=clock,
+    )
+    config.files_ext_storage_proxy_hostname = STORAGE_PROXY_HOST
+    return config
+
+
+def test_storage_proxy_probe_success():
+    """Verify that a successful probe enables the storage-proxy path."""
+    config = _make_storage_proxy_config()
+    w = WorkspaceClient(config=config)
+
+    with requests_mock.Mocker() as m:
+        # Probe succeeds.
+        m.get(
+            f"{STORAGE_PROXY_HOST}/api/2.0/fs/files/DatabricksInternal/Probes/ping",
+            status_code=200,
+        )
+        # Single-shot upload goes to storage proxy.
+        m.put(
+            requests_mock.ANY,
+            status_code=200,
+        )
+
+        data = b"hello world"
+        w.files.upload("/test.txt", io.BytesIO(data), overwrite=True)
+
+        # The PUT should have gone to the storage proxy, not localhost.
+        put_requests = [h for h in m.request_history if h.method == "PUT"]
+        assert len(put_requests) == 1
+        assert put_requests[0].url.startswith(STORAGE_PROXY_HOST)
+
+
+def test_storage_proxy_probe_failure_falls_back():
+    """Verify that a failed probe falls back to presigned URLs via localhost."""
+    config = _make_storage_proxy_config()
+    w = WorkspaceClient(config=config)
+
+    with requests_mock.Mocker() as m:
+        # Probe fails.
+        m.get(
+            f"{STORAGE_PROXY_HOST}/api/2.0/fs/files/DatabricksInternal/Probes/ping",
+            status_code=500,
+        )
+        # Single-shot upload goes to localhost (GIG direct).
+        m.put(
+            requests_mock.ANY,
+            status_code=200,
+        )
+
+        data = b"hello world"
+        w.files.upload("/test.txt", io.BytesIO(data), overwrite=True)
+
+        # The PUT should have gone to localhost, not the storage proxy.
+        put_requests = [h for h in m.request_history if h.method == "PUT"]
+        assert len(put_requests) == 1
+        assert put_requests[0].url.startswith("http://localhost")
+
+
+def test_storage_proxy_multipart_upload():
+    """End-to-end multipart upload through the storage proxy."""
+    config = _make_storage_proxy_config()
+    config.files_ext_multipart_upload_min_stream_size = 1  # Force multipart.
+    w = WorkspaceClient(config=config)
+
+    server_state = StorageProxyUploadServerState()
+    file_content = fast_random_bytes(2 * 1024 * 1024)
+    part_size = 1024 * 1024  # 1 MiB parts.
+
+    with requests_mock.Mocker() as session_mock:
+
+        def custom_matcher(request: requests.Request) -> Optional[requests.Response]:
+            parsed_url = urlparse(request.url)
+            query = parse_qs(parsed_url.query)
+
+            # Probe endpoint.
+            if (
+                request.url.startswith(STORAGE_PROXY_HOST)
+                and parsed_url.path == "/api/2.0/fs/files/DatabricksInternal/Probes/ping"
+                and request.method == "GET"
+            ):
+                return _make_response(request)
+
+            # Initiate upload on the storage proxy.
+            if (
+                request.url.startswith(STORAGE_PROXY_HOST)
+                and parsed_url.path == "/api/2.0/fs/files/test.txt"
+                and query.get("action") == ["initiate-upload"]
+                and request.method == "POST"
+            ):
+                assert "Authorization" in request.headers
+                body = json.dumps({"multipart_upload": {"session_token": server_state.session_token}})
+                return _make_response(request, body=body, headers={"Content-Type": "application/json"})
+
+            # Part upload on the storage proxy.
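+            # These PUTs carry the session_token/upload_type/part_number query
+            # parameters built by _StorageProxyRequestBuilder.build_upload_part.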
+ if ( + request.url.startswith(STORAGE_PROXY_HOST) + and parsed_url.path == "/api/2.0/fs/files/test.txt" + and query.get("upload_type") == ["multipart"] + and request.method == "PUT" + ): + assert "Authorization" in request.headers + part_number = int(query["part_number"][0]) + data = request.body.read() if hasattr(request.body, "read") else request.body + etag = server_state.save_part(part_number, data) + return _make_response(request, headers={"ETag": etag}) + + # Complete upload on the storage proxy. + if ( + request.url.startswith(STORAGE_PROXY_HOST) + and parsed_url.path == "/api/2.0/fs/files/test.txt" + and query.get("action") == ["complete-upload"] + and request.method == "POST" + ): + assert "Authorization" in request.headers + server_state.completed = True + return _make_response(request) + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + w.files.upload( + "/test.txt", + io.BytesIO(file_content), + overwrite=True, + part_size=part_size, + ) + + # Verify the assembled content matches. + assembled = server_state.get_assembled_content() + assert assembled == file_content + + # Verify no calls went to create-upload-part-urls (presigned URL path). + presigned_calls = [ + h for h in session_mock.request_history if "create-upload-part-urls" in (h.url or "") + ] + assert len(presigned_calls) == 0, "Storage proxy path should not call create-upload-part-urls." + + +def test_storage_proxy_single_shot_upload(): + """Single-shot upload through the storage proxy (small file).""" + config = _make_storage_proxy_config() + w = WorkspaceClient(config=config) + + file_content = b"small file content" + uploaded_content = None + + with requests_mock.Mocker() as session_mock: + + def custom_matcher(request: requests.Request) -> Optional[requests.Response]: + nonlocal uploaded_content + parsed_url = urlparse(request.url) + + # Probe endpoint. + if ( + request.url.startswith(STORAGE_PROXY_HOST) + and parsed_url.path == "/api/2.0/fs/files/DatabricksInternal/Probes/ping" + and request.method == "GET" + ): + return _make_response(request) + + # Single-shot PUT to storage proxy. + if ( + request.url.startswith(STORAGE_PROXY_HOST) + and parsed_url.path == "/api/2.0/fs/files/test.txt" + and request.method == "PUT" + ): + assert "Authorization" in request.headers + uploaded_content = request.body.read() if hasattr(request.body, "read") else request.body + return _make_response(request) + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + w.files.upload("/test.txt", io.BytesIO(file_content), overwrite=True) + + assert uploaded_content == file_content + + +def test_storage_proxy_probe_cached(): + """Verify that the probe result is cached across uploads.""" + config = _make_storage_proxy_config() + w = WorkspaceClient(config=config) + probe_count = 0 + + with requests_mock.Mocker() as session_mock: + + def custom_matcher(request: requests.Request) -> Optional[requests.Response]: + nonlocal probe_count + parsed_url = urlparse(request.url) + + # Probe endpoint. + if ( + request.url.startswith(STORAGE_PROXY_HOST) + and parsed_url.path == "/api/2.0/fs/files/DatabricksInternal/Probes/ping" + and request.method == "GET" + ): + probe_count += 1 + return _make_response(request) + + # Single-shot PUT. + if request.method == "PUT": + return _make_response(request) + + return None + + session_mock.add_matcher(matcher=custom_matcher) + + # Upload twice. 
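+        # Only the first upload should hit the probe endpoint; the second reuses
+        # the result cached in FilesExt._dp_hostname_available.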
+ w.files.upload("/test1.txt", io.BytesIO(b"first"), overwrite=True) + w.files.upload("/test2.txt", io.BytesIO(b"second"), overwrite=True) + + # Probe should have been called only once. + assert probe_count == 1, f"Probe was called {probe_count} times, expected 1."
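
Usage sketch (not part of the patch): pointing file operations at a non-default
storage proxy via the two new Config fields. The workspace URL and proxy
hostname below are placeholders; on a Databricks cluster the default hostname
is probed automatically, and on probe failure the SDK falls back to presigned
URLs, so the same code works anywhere.

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.core import Config

    # Auth is resolved from the environment (PAT, OAuth, etc.).
    config = Config(host="https://my-workspace.cloud.databricks.com")
    config.files_ext_storage_proxy_hostname = "http://storage-proxy.internal:8080"  # placeholder
    config.files_ext_storage_proxy_probe_timeout = 1.0  # fail the probe fast if unreachable

    w = WorkspaceClient(config=config)
    # The first file operation probes the proxy once per FilesExt instance; the
    # cached result is reused by subsequent uploads and downloads.
    with open("data.bin", "rb") as f:
        w.files.upload("/Volumes/main/default/vol/data.bin", f, overwrite=True)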