From e0831472a0e9b3733c06ab83bf2baef5ce965ab2 Mon Sep 17 00:00:00 2001 From: beobest2 Date: Fri, 8 Oct 2021 17:50:12 +0900 Subject: [PATCH] Implement Index.putmask --- .../reference/pyspark.pandas/indexing.rst | 1 + python/pyspark/pandas/indexes/base.py | 117 +++++++++++++++++- python/pyspark/pandas/indexes/multi.py | 7 ++ python/pyspark/pandas/missing/indexes.py | 1 - .../pyspark/pandas/tests/indexes/test_base.py | 35 ++++++ python/pyspark/pandas/typedef/typehints.py | 2 +- 6 files changed, 160 insertions(+), 3 deletions(-) diff --git a/python/docs/source/reference/pyspark.pandas/indexing.rst b/python/docs/source/reference/pyspark.pandas/indexing.rst index 4168b6712bf3f..924fa7d958051 100644 --- a/python/docs/source/reference/pyspark.pandas/indexing.rst +++ b/python/docs/source/reference/pyspark.pandas/indexing.rst @@ -83,6 +83,7 @@ Modifying and computations Index.min Index.max Index.map + Index.putmask Index.rename Index.repeat Index.take diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index 369934af13858..dd924efb5076a 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -36,8 +36,16 @@ from pandas.api.types import CategoricalDtype, is_hashable from pandas._libs import lib +from pyspark.sql.functions import pandas_udf from pyspark.sql import functions as F, Column -from pyspark.sql.types import FractionalType, IntegralType, TimestampType, TimestampNTZType +from pyspark.sql.types import ( + FractionalType, + IntegralType, + TimestampType, + TimestampNTZType, + BooleanType, + StringType, +) from pyspark import pandas as ps # For running doctests and reference resolution in PyCharm. from pyspark.pandas._typing import Dtype, Label, Name, Scalar @@ -46,6 +54,7 @@ from pyspark.pandas.frame import DataFrame from pyspark.pandas.missing.indexes import MissingPandasLikeIndex from pyspark.pandas.series import Series, first_series +from pyspark.pandas.typedef import infer_pd_series_spark_type from pyspark.pandas.spark import functions as SF from pyspark.pandas.spark.accessors import SparkIndexMethods from pyspark.pandas.utils import ( @@ -1927,6 +1936,112 @@ def argmin(self) -> int: .first()[0] ) + def putmask( + self, mask: Union[Series, "Index", List, Tuple], value: Union[Series, "Index", List, Tuple] + ) -> "Index": + """ + Return a new Index of the values set with the mask. + .. note:: this API can be pretty expensive since it is based on + a global sequence internally. + Parameters + ---------- + mask : array-like + Boolean mask array. It has to be the same shape as the index. + value : array-like + Value to put into the index where mask is True. + Returns + ------- + Index + Examples + ------- + >>> psidx = ps.Index([1, 2, 3, 4, 5]) + >>> psidx + Int64Index([1, 2, 3, 4, 5], dtype='int64') + >>> psidx.putmask(psidx > 3, 100).sort_values() + Int64Index([1, 2, 3, 100, 100], dtype='int64') + >>> psidx.putmask(psidx > 3, ps.Index([100, 200, 300, 400, 500])).sort_values() + Int64Index([1, 2, 3, 400, 500], dtype='int64') + """ + scol_name = self._internal.index_spark_column_names[0] + sdf = self._internal.spark_frame.select(self.spark.column) + + dist_sequence_col_name = verify_temp_column_name(sdf, "__distributed_sequence_column__") + sdf = InternalFrame.attach_distributed_sequence_column( + sdf, column_name=dist_sequence_col_name + ) + + replace_col = verify_temp_column_name(sdf, "__replace_column__") + masking_col = verify_temp_column_name(sdf, "__masking_column__") + + if isinstance(value, (list, tuple, Index, Series)): + if isinstance(value, (list, tuple)): + pd_value = pd.Series(value) + elif isinstance(value, (Series, Index)): + pd_value = value.to_pandas() + + if self.size != pd_value.size: + # TODO: We can't support different size of value for now. + raise ValueError("value and data must be the same size") + + replace_return_type = infer_pd_series_spark_type(pd_value, pd_value.dtype) + + @pandas_udf( + returnType=replace_return_type if replace_return_type else StringType() + ) # type: ignore + def replace_pandas_udf(sequence: pd.Series) -> pd.Series: + return pd_value[sequence] + + sdf = sdf.withColumn(replace_col, replace_pandas_udf(dist_sequence_col_name)) + else: + sdf = sdf.withColumn(replace_col, F.lit(value)) + + if isinstance(mask, (list, tuple)): + pandas_mask = pd.Series(mask) + elif isinstance(mask, (Index, Series)): + pandas_mask = mask.to_pandas() + else: + raise TypeError("Mask data doesn't support type " "{0}".format(type(mask).__name__)) + + if self.size != pandas_mask.size: + raise ValueError("mask and data must be the same size") + + @pandas_udf(returnType=BooleanType()) # type: ignore + def masking_pandas_udf(sequence: pd.Series) -> pd.Series: + return pandas_mask[sequence] + + sdf = sdf.withColumn(masking_col, masking_pandas_udf(dist_sequence_col_name)) + + # spark_frame here looks like below + # +-------------------------------+-----------------+------------------+------------------+ + # |__distributed_sequence_column__|__index_level_0__|__replace_column__|__masking_column__| + # +-------------------------------+-----------------+------------------+------------------+ + # | 0| a| 100| true| + # | 3| d| 400| false| + # | 1| b| 200| true| + # | 2| c| 300| false| + # | 4| e| 500| false| + # +-------------------------------+-----------------+------------------+------------------+ + + cond = F.when(scol_for(sdf, masking_col), scol_for(sdf, replace_col)).otherwise( + scol_for(sdf, scol_name) + ) + sdf = sdf.select(cond.alias(scol_name)) + + if sdf.schema[scol_name].nullable != self._internal.index_fields[0].nullable: + sdf.schema[scol_name].nullable = self._internal.index_fields[0].nullable + sdf = sdf.sql_ctx.createDataFrame(sdf.rdd, sdf.schema) + + internal = InternalFrame( + spark_frame=sdf, + index_spark_columns=[ + scol_for(sdf, col) for col in self._internal.index_spark_column_names + ], + index_names=self._internal.index_names, + index_fields=self._internal.index_fields, + ) + + return DataFrame(internal).index + def set_names( self, names: Union[Name, List[Name]], diff --git a/python/pyspark/pandas/indexes/multi.py b/python/pyspark/pandas/indexes/multi.py index 896ea2af27643..5a9524569112a 100644 --- a/python/pyspark/pandas/indexes/multi.py +++ b/python/pyspark/pandas/indexes/multi.py @@ -1236,6 +1236,13 @@ def map( ) -> "Index": return MissingPandasLikeMultiIndex.map(self, mapper, na_action) + def putmask( + self, + mask: Union[Series, "Index", List, Tuple] = None, + value: Union[Series, "Index", List, Tuple] = None, + ) -> "Index": + return MissingPandasLikeMultiIndex.putmask(self, mask, value) + def _test() -> None: import os diff --git a/python/pyspark/pandas/missing/indexes.py b/python/pyspark/pandas/missing/indexes.py index 4170aa70f7d4c..2bed9681afedb 100644 --- a/python/pyspark/pandas/missing/indexes.py +++ b/python/pyspark/pandas/missing/indexes.py @@ -53,7 +53,6 @@ class MissingPandasLikeIndex(object): groupby = _unsupported_function("groupby") is_ = _unsupported_function("is_") join = _unsupported_function("join") - putmask = _unsupported_function("putmask") ravel = _unsupported_function("ravel") reindex = _unsupported_function("reindex") searchsorted = _unsupported_function("searchsorted") diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 40039983c4c11..c332d7e001ddc 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -2388,6 +2388,41 @@ def test_map(self): lambda: psidx.map({1: 1, 2: 2.0, 3: "three"}), ) + def test_putmask(self): + pidx = pd.Index(["a", "b", "c", "d", "e"]) + psidx = ps.from_pandas(pidx) + + self.assert_eq( + psidx.putmask(psidx < "c", "k").sort_values(), + pidx.putmask(pidx < "c", "k").sort_values(), + ) + self.assert_eq( + psidx.putmask(psidx < "c", ["g", "h", "i", "j", "k"]).sort_values(), + pidx.putmask(pidx < "c", ["g", "h", "i", "j", "k"]).sort_values(), + ) + self.assert_eq( + psidx.putmask(psidx < "c", ("g", "h", "i", "j", "k")).sort_values(), + pidx.putmask(pidx < "c", ("g", "h", "i", "j", "k")).sort_values(), + ) + self.assert_eq( + psidx.putmask(psidx < "c", ps.Index(["g", "h", "i", "j", "k"])).sort_values(), + pidx.putmask(pidx < "c", pd.Index(["g", "h", "i", "j", "k"])).sort_values(), + ) + self.assert_eq( + psidx.putmask(psidx < "c", ps.Series(["g", "h", "i", "j", "k"])).sort_values(), + pidx.putmask(pidx < "c", pd.Series(["g", "h", "i", "j", "k"])).sort_values(), + ) + + self.assertRaises( + ValueError, + lambda: psidx.putmask(psidx < "c", ps.Series(["g", "h"])), + ) + + self.assertRaises( + ValueError, + lambda: psidx.putmask([True, False], ps.Series(["g", "h", "i", "j", "k"])), + ) + def test_multiindex_equal_levels(self): pmidx1 = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")]) pmidx2 = pd.MultiIndex.from_tuples([("b", "y"), ("a", "x"), ("c", "z")]) diff --git a/python/pyspark/pandas/typedef/typehints.py b/python/pyspark/pandas/typedef/typehints.py index 288273a2de9b2..c3714e4119524 100644 --- a/python/pyspark/pandas/typedef/typehints.py +++ b/python/pyspark/pandas/typedef/typehints.py @@ -356,7 +356,7 @@ def infer_pd_series_spark_type( if dtype == np.dtype("object"): if len(pser) == 0 or pser.isnull().all(): return types.NullType() - elif hasattr(pser.iloc[0], "__UDT__"): + elif hasattr(pser, "iloc") and hasattr(pser.iloc[0], "__UDT__"): return pser.iloc[0].__UDT__ else: return from_arrow_type(pa.Array.from_pandas(pser).type, prefer_timestamp_ntz)