From 7f92000a42d894945ab87554949845dcf94572a8 Mon Sep 17 00:00:00 2001 From: Brian Helba Date: Sat, 19 Jul 2025 15:50:00 -0400 Subject: [PATCH] Add support for `lesion_id` in classification ground truth --- isic_challenge_scoring/load_csv.py | 33 +++++++++++++++++++++--------- tests/test_load_csv.py | 24 +++++++++++++++++----- 2 files changed, 42 insertions(+), 15 deletions(-) diff --git a/isic_challenge_scoring/load_csv.py b/isic_challenge_scoring/load_csv.py index 3bfbf17..16c8b54 100644 --- a/isic_challenge_scoring/load_csv.py +++ b/isic_challenge_scoring/load_csv.py @@ -7,9 +7,16 @@ def parse_truth_csv(csv_file_stream: TextIO) -> tuple[pd.DataFrame, pd.DataFrame]: - table = pd.read_csv(csv_file_stream, header=0) + table = pd.read_csv(csv_file_stream, header=0, index_col=False) - table.set_index('image', drop=True, inplace=True, verify_integrity=False) + if 'image' in table.columns: + index_name = 'image' + elif 'lesion_id' in table.columns: + index_name = 'lesion_id' + else: + raise KeyError('Missing column in CSV: "image" or "lesion_id".') + + table.set_index(index_name, drop=True, inplace=True, verify_integrity=False) # Support legacy truth files if 'score_weight' not in table.columns: @@ -42,25 +49,31 @@ def parse_csv(csv_file_stream: TextIO, categories: pd.Index) -> pd.DataFrame: except UnicodeDecodeError: raise ScoreError('Could not parse CSV: could not decode file as UTF-8.') - if 'image' not in probabilities.columns: - raise ScoreError("Missing column in CSV: 'image'.") + if 'image' in probabilities.columns: + index_name = 'image' + elif 'lesion_id' in probabilities.columns: + index_name = 'lesion_id' + else: + raise ScoreError('Missing column in CSV: "image" or "lesion_id".') # Pandas represents strings as 'O' (object) - if probabilities['image'].dtype != np.dtype('O'): + if probabilities[index_name].dtype != np.dtype('O'): # Coercing to 'U' (unicode) ensures that even NaN values are converted; # however, the resulting type is still 'O' - probabilities['image'] = probabilities['image'].astype(np.dtype('U')) + probabilities[index_name] = probabilities[index_name].astype(np.dtype('U')) - probabilities['image'] = probabilities['image'].str.replace( + probabilities[index_name] = probabilities[index_name].str.replace( r'\.jpg$', '', case=False, regex=True ) - if not probabilities['image'].is_unique: - duplicate_images = probabilities['image'][probabilities['image'].duplicated()].unique() + if not probabilities[index_name].is_unique: + duplicate_images = probabilities[index_name][ + probabilities[index_name].duplicated() + ].unique() raise ScoreError(f'Duplicate image rows detected in CSV: {duplicate_images.tolist()}.') # The duplicate check is the same as performed by 'verify_integrity' - probabilities.set_index('image', drop=True, inplace=True, verify_integrity=False) + probabilities.set_index(index_name, drop=True, inplace=True, verify_integrity=False) missing_columns = categories.difference(probabilities.columns) if not missing_columns.empty: diff --git a/tests/test_load_csv.py b/tests/test_load_csv.py index 653d78f..82424a2 100644 --- a/tests/test_load_csv.py +++ b/tests/test_load_csv.py @@ -7,9 +7,16 @@ from isic_challenge_scoring.types import ScoreError -def test_parse_truth_csv(categories): +@pytest.mark.parametrize( + 'index_column', + [ + 'image', + 'lesion_id', + ], +) +def test_parse_truth_csv(categories, index_column): truth_file_stream = io.StringIO( - 'image,MEL,NV,BCC,AKIEC,BKL,DF,VASC,score_weight,validation_weight\n' + f'{index_column},MEL,NV,BCC,AKIEC,BKL,DF,VASC,score_weight,validation_weight\n' 'ISIC_0000123,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0\n' 'ISIC_0000124,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0\n' 'ISIC_0000125,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0\n' @@ -67,9 +74,16 @@ def test_parse_truth_csv_legacy(categories): ) -def test_parse_csv(categories): +@pytest.mark.parametrize( + 'index_column', + [ + 'image', + 'lesion_id', + ], +) +def test_parse_csv(categories, index_column): prediction_file_stream = io.StringIO( - 'image,MEL,NV,BCC,AKIEC,BKL,DF,VASC\n' + f'{index_column},MEL,NV,BCC,AKIEC,BKL,DF,VASC\n' 'ISIC_0000123,1.0,0.0,0.0,0.0,0.0,0.0,0.0\n' 'ISIC_0000124.jpg,0.0,1.0,0.0,0.0,0.0,0.0,0.0\n' 'ISIC_0000125.JPG,0.0,0.0,1.0,0.0,0.0,0.0,0.0\n' @@ -204,7 +218,7 @@ def test_parse_csv_missing_index(categories): 'MEL,NV,BCC,AKIEC,BKL,DF,VASC\n' '1.0,0.0,0.0,0.0,0.0,0.0,0.0\n' ) - with pytest.raises(ScoreError, match=r"^Missing column in CSV: 'image'\.$"): + with pytest.raises(ScoreError, match=r'^Missing column in CSV: "image" or "lesion_id"\.$'): load_csv.parse_csv(prediction_file_stream, categories)