diff --git a/dask_ml/metrics/regression.py b/dask_ml/metrics/regression.py index 0c1b21b59..ccb7e7bca 100644 --- a/dask_ml/metrics/regression.py +++ b/dask_ml/metrics/regression.py @@ -162,18 +162,13 @@ def r2_score( numerator = (weight * (y_true - y_pred) ** 2).sum(axis=0, dtype="f8") denominator = (weight * (y_true - y_true.mean(axis=0)) ** 2).sum(axis=0, dtype="f8") - nonzero_denominator = denominator != 0 - nonzero_numerator = numerator != 0 - valid_score = nonzero_denominator & nonzero_numerator - output_chunks = getattr(y_true, "chunks", [None, None])[1] - output_scores = da.ones([y_true.shape[1]], chunks=output_chunks) - with np.errstate(all="ignore"): - output_scores[valid_score] = 1 - ( - numerator[valid_score] / denominator[valid_score] - ) - output_scores[nonzero_numerator & ~nonzero_denominator] = 0.0 - - result = output_scores.mean(axis=0) + score = da.where( + numerator == 0, + 1.0, + da.where(denominator != 0, 1 - numerator / denominator, 0.0), + ) + + result = score.mean(axis=0) if compute: result = result.compute() return result diff --git a/tests/metrics/test_regression.py b/tests/metrics/test_regression.py index af775e168..d15580c7d 100644 --- a/tests/metrics/test_regression.py +++ b/tests/metrics/test_regression.py @@ -116,3 +116,19 @@ def test_regression_metrics_do_not_support_weighted_multioutput(metric_pairs): with pytest.raises((NotImplementedError, ValueError), match=error_msg): _ = m1(a, b, multioutput=weights) + + +def test_r2_score_with_different_chunk_patterns(): + """Test r2_score with different chunking configurations.""" + # Create arrays with compatible but different chunk patterns + a = da.random.uniform(size=(100,), chunks=25) # 4 chunks + b = da.random.uniform(size=(100,), chunks=20) # 5 chunks + result = dask_ml.metrics.r2_score(a, b) + assert isinstance(result, float) + # Create arrays with different chunk patterns + a_multi = da.random.uniform(size=(100, 3), chunks=(25, 3)) # 4 chunks + b_multi = da.random.uniform(size=(100, 3), chunks=(20, 3)) # 5 chunks + result_multi = dask_ml.metrics.r2_score( + a_multi, b_multi, multioutput="uniform_average" + ) + assert isinstance(result_multi, float)