diff --git a/barcodeforge/generate_barcodes.py b/barcodeforge/generate_barcodes.py index f267478..65e20f3 100755 --- a/barcodeforge/generate_barcodes.py +++ b/barcodeforge/generate_barcodes.py @@ -196,7 +196,7 @@ def check_mutation_chain(df_barcodes: pd.DataFrame) -> pd.DataFrame: else: # combining leads to already existing mutation # just add in that mutation - df_barcodes.loc[lin_seq.index, sm[2]] = 1 + df_barcodes.loc[lin_seq.index, sm[2]] += 1 # remove constituent mutations df_barcodes.loc[lin_seq.index, sm[0:2]] -= 1 # drop all unused mutations @@ -207,6 +207,8 @@ def check_mutation_chain(df_barcodes: pd.DataFrame) -> pd.DataFrame: # in case mutation path leads to a return to the reference. df_barcodes = reversion_checking(df_barcodes) seq_muts = identify_chains(df_barcodes) + # The barcode should be a binary sparse matrix + assert df_barcodes.isin([0, 1]).all(axis=None), "Barcode matrix should be binary" return df_barcodes diff --git a/tests/test_generate_barcodes.py b/tests/test_generate_barcodes.py index 0483c58..b078650 100644 --- a/tests/test_generate_barcodes.py +++ b/tests/test_generate_barcodes.py @@ -91,6 +91,35 @@ def test_check_mutation_chain(sample_barcode_data): assert isinstance(chained_df, pd.DataFrame) +def test_check_mutation_chain_repetitve_mutations(): + sample_barcode_data = pd.DataFrame( + {"A225G": [1], "A225T": [1], "C225A": [1], "G225T": [1], "T225C": [2]}, + index=["lineage"], + ) + chained_df = check_mutation_chain(sample_barcode_data.copy()) + df_barcodes_ideal = pd.DataFrame( + {"A225C": [1]}, + index=["lineage"], + ) + pd.testing.assert_frame_equal(chained_df, df_barcodes_ideal) + + +def test_check_mutation_chain_non_binary_values(): + sample_barcode_data = pd.DataFrame( + { + "A225G": [1], + "A225T": [1], + "C225A": [1], + "G225T": [1], + "T225C": [2], + "C123A": [2], + }, + index=["lineage"], + ) + with pytest.raises(AssertionError, match="Barcode matrix should be binary"): + check_mutation_chain(sample_barcode_data.copy()) + + def test_replace_underscore_with_dash(): data = {"value": [1, 2]} df = pd.DataFrame(data, index=["lineage_A", "lineage_B"])