Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions isic_challenge_scoring/load_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@


def parse_truth_csv(csv_file_stream: TextIO) -> tuple[pd.DataFrame, pd.DataFrame]:
table = pd.read_csv(csv_file_stream, header=0)
table = pd.read_csv(csv_file_stream, header=0, index_col=False)

table.set_index('image', drop=True, inplace=True, verify_integrity=False)
if 'image' in table.columns:
index_name = 'image'
elif 'lesion_id' in table.columns:
index_name = 'lesion_id'
else:
raise KeyError('Missing column in CSV: "image" or "lesion_id".')

table.set_index(index_name, drop=True, inplace=True, verify_integrity=False)

# Support legacy truth files
if 'score_weight' not in table.columns:
Expand Down Expand Up @@ -42,25 +49,31 @@ def parse_csv(csv_file_stream: TextIO, categories: pd.Index) -> pd.DataFrame:
except UnicodeDecodeError:
raise ScoreError('Could not parse CSV: could not decode file as UTF-8.')

if 'image' not in probabilities.columns:
raise ScoreError("Missing column in CSV: 'image'.")
if 'image' in probabilities.columns:
index_name = 'image'
elif 'lesion_id' in probabilities.columns:
index_name = 'lesion_id'
else:
raise ScoreError('Missing column in CSV: "image" or "lesion_id".')

# Pandas represents strings as 'O' (object)
if probabilities['image'].dtype != np.dtype('O'):
if probabilities[index_name].dtype != np.dtype('O'):
# Coercing to 'U' (unicode) ensures that even NaN values are converted;
# however, the resulting type is still 'O'
probabilities['image'] = probabilities['image'].astype(np.dtype('U'))
probabilities[index_name] = probabilities[index_name].astype(np.dtype('U'))

probabilities['image'] = probabilities['image'].str.replace(
probabilities[index_name] = probabilities[index_name].str.replace(
r'\.jpg$', '', case=False, regex=True
)

if not probabilities['image'].is_unique:
duplicate_images = probabilities['image'][probabilities['image'].duplicated()].unique()
if not probabilities[index_name].is_unique:
duplicate_images = probabilities[index_name][
probabilities[index_name].duplicated()
].unique()
raise ScoreError(f'Duplicate image rows detected in CSV: {duplicate_images.tolist()}.')

# The duplicate check is the same as performed by 'verify_integrity'
probabilities.set_index('image', drop=True, inplace=True, verify_integrity=False)
probabilities.set_index(index_name, drop=True, inplace=True, verify_integrity=False)

missing_columns = categories.difference(probabilities.columns)
if not missing_columns.empty:
Expand Down
24 changes: 19 additions & 5 deletions tests/test_load_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,16 @@
from isic_challenge_scoring.types import ScoreError


def test_parse_truth_csv(categories):
@pytest.mark.parametrize(
'index_column',
[
'image',
'lesion_id',
],
)
def test_parse_truth_csv(categories, index_column):
truth_file_stream = io.StringIO(
'image,MEL,NV,BCC,AKIEC,BKL,DF,VASC,score_weight,validation_weight\n'
f'{index_column},MEL,NV,BCC,AKIEC,BKL,DF,VASC,score_weight,validation_weight\n'
'ISIC_0000123,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0\n'
'ISIC_0000124,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0\n'
'ISIC_0000125,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0\n'
Expand Down Expand Up @@ -67,9 +74,16 @@ def test_parse_truth_csv_legacy(categories):
)


def test_parse_csv(categories):
@pytest.mark.parametrize(
'index_column',
[
'image',
'lesion_id',
],
)
def test_parse_csv(categories, index_column):
prediction_file_stream = io.StringIO(
'image,MEL,NV,BCC,AKIEC,BKL,DF,VASC\n'
f'{index_column},MEL,NV,BCC,AKIEC,BKL,DF,VASC\n'
'ISIC_0000123,1.0,0.0,0.0,0.0,0.0,0.0,0.0\n'
'ISIC_0000124.jpg,0.0,1.0,0.0,0.0,0.0,0.0,0.0\n'
'ISIC_0000125.JPG,0.0,0.0,1.0,0.0,0.0,0.0,0.0\n'
Expand Down Expand Up @@ -204,7 +218,7 @@ def test_parse_csv_missing_index(categories):
'MEL,NV,BCC,AKIEC,BKL,DF,VASC\n' '1.0,0.0,0.0,0.0,0.0,0.0,0.0\n'
)

with pytest.raises(ScoreError, match=r"^Missing column in CSV: 'image'\.$"):
with pytest.raises(ScoreError, match=r'^Missing column in CSV: "image" or "lesion_id"\.$'):
load_csv.parse_csv(prediction_file_stream, categories)


Expand Down