From 7005275c31b3cfa546a7634baede73641d839125 Mon Sep 17 00:00:00 2001 From: Salvador Escobedo Date: Fri, 30 Jan 2026 12:45:47 -0800 Subject: [PATCH] Optimize validation, matrix operations, reference generation, and process handling --- kb_python/ref.py | 68 +++++----- kb_python/utils.py | 283 +++++++++++++++++++---------------------- kb_python/validate.py | 2 +- tests/test_utils.py | 63 +++++++++ tests/test_validate.py | 4 +- 5 files changed, 228 insertions(+), 192 deletions(-) diff --git a/kb_python/ref.py b/kb_python/ref.py index 1a747d66..887447d6 100755 --- a/kb_python/ref.py +++ b/kb_python/ref.py @@ -2,6 +2,7 @@ import itertools import os import tarfile +from collections import defaultdict from typing import Callable, Dict, List, Optional, Tuple, Union import ngs_tools as ngs @@ -71,6 +72,10 @@ def generate_mismatches(name, sequence): lengths = set() features = {} variants = {} + + # Store all original sequences to check for collisions with variants + original_sequences = set() + # Generate all feature barcode variations before saving to check for collisions. for i, row in df_features.iterrows(): # Check that the first column contains the sequence @@ -83,6 +88,8 @@ def generate_mismatches(name, sequence): lengths.add(len(row.sequence)) features[row['name']] = row.sequence + original_sequences.add(row.sequence) + variants[row['name']] = { name: seq for name, seq in generate_mismatches(row['name'], row.sequence) @@ -103,45 +110,36 @@ def generate_mismatches(name, sequence): ','.join(str(l) for l in lengths) # noqa ) ) - # Find & remove collisions between barcode and variants - for feature in variants.keys(): - _variants = variants[feature] - collisions = set(_variants.values()) & set(features.values()) - if collisions: - # Remove collisions + + # Invert variants: sequence -> list of (feature_name, variant_name) + seq_to_variants = defaultdict(list) + for feature_name, feature_variants in variants.items(): + for variant_name, seq in feature_variants.items(): + seq_to_variants[seq].append((feature_name, variant_name)) + + # Process collisions + for seq, variant_list in seq_to_variants.items(): + # 1. Check collision with original barcodes + if seq in original_sequences: logger.warning( - f'Colision detected between variants of feature barcode {feature} ' - 'and feature barcode(s). These variants will be removed.' + f'Collision detected between variants of feature barcode(s) {",".join(set(v[0] for v in variant_list))} ' + f'and original feature barcode {seq}. These variants will be removed.' ) - variants[feature] = { - name: seq - for name, seq in _variants.items() - if seq not in collisions - } - - # Find & remove collisions between variants - for f1, f2 in itertools.combinations(variants.keys(), 2): - v1 = variants[f1] - v2 = variants[f2] - - collisions = set(v1.values()) & set(v2.values()) - if collisions: + for feature_name, variant_name in variant_list: + if variant_name in variants[feature_name]: + del variants[feature_name][variant_name] + continue + + # 2. Check collision between variants of DIFFERENT features + features_involved = set(v[0] for v in variant_list) + if len(features_involved) > 1: logger.warning( - f'Collision(s) detected between variants of feature barcodes {f1} and {f2}: ' - f'{",".join(collisions)}. These variants will be removed.' + f'Collision(s) detected between variants of feature barcodes {",".join(features_involved)}: ' + f'{seq}. These variants will be removed.' ) - - # Remove collisions - variants[f1] = { - name: seq - for name, seq in v1.items() - if seq not in collisions - } - variants[f2] = { - name: seq - for name, seq in v2.items() - if seq not in collisions - } + for feature_name, variant_name in variant_list: + if variant_name in variants[feature_name]: + del variants[feature_name][variant_name] # Write FASTA with ngs.fasta.Fasta(out_path, 'w') as f: diff --git a/kb_python/utils.py b/kb_python/utils.py index 70a644af..0464fccb 100755 --- a/kb_python/utils.py +++ b/kb_python/utils.py @@ -11,6 +11,7 @@ from urllib.request import urlretrieve import anndata +import numpy as np import ngs_tools as ngs import pandas as pd import scipy.io @@ -171,8 +172,8 @@ def reader(pipe, qu, stop_event, name): stderr_reader.start() while p.poll() is None: - while not out_queue.empty(): - name, line = out_queue.get() + try: + name, line = out_queue.get(timeout=0.1) if stream and not quiet: logger.debug(line) out.append(line) @@ -180,8 +181,8 @@ def reader(pipe, qu, stop_event, name): stdout += f'{line}\n' elif name == 'stderr': stderr += f'{line}\n' - else: - time.sleep(0.1) + except queue.Empty: + pass # Stop readers & flush queue stop_event.set() @@ -520,33 +521,30 @@ def collapse_anndata( if not any(adata.var.index.duplicated()): return adata - var_indices = {} - for i, index in enumerate(adata.var.index): - var_indices.setdefault(index, []).append(i) + # Optimized implementation using matrix multiplication + codes, uniques = pd.factorize(adata.var.index) + n_old = len(codes) + n_new = len(uniques) - # Convert all original matrices to csc for fast column operations - X = sparse.csc_matrix(adata.X) - layers = { - layer: sparse.csc_matrix(adata.layers[layer]) - for layer in adata.layers - } - new_index = [] - # lil_matrix is efficient for row-by-row construction - new_X = sparse.lil_matrix((len(var_indices), adata.shape[0])) - new_layers = {layer: new_X.copy() for layer in adata.layers} - for i, (index, indices) in enumerate(var_indices.items()): - new_index.append(index) - new_X[i] = X[:, indices].sum(axis=1).flatten() - for layer in layers.keys(): - new_layers[layer][i] = layers[layer][:, - indices].sum(axis=1).flatten() + row_indices = np.arange(n_old) + col_indices = codes + data = np.ones(n_old) + + # S maps from old columns to new columns (summing duplicates) + S = sparse.coo_matrix((data, (row_indices, col_indices)), shape=(n_old, n_new)).tocsr() + + X = sparse.csr_matrix(adata.X) + new_X = X @ S + + new_layers = {} + for layer, mat in adata.layers.items(): + new_layers[layer] = sparse.csr_matrix(mat) @ S return anndata.AnnData( - X=new_X.T.tocsr(), - layers={layer: new_layers[layer].T.tocsr() - for layer in new_layers}, + X=new_X, + layers=new_layers, obs=adata.obs.copy(), - var=pd.DataFrame(index=pd.Series(new_index, name=adata.var.index.name)), + var=pd.DataFrame(index=pd.Series(uniques, name=adata.var.index.name)), ) @@ -766,146 +764,123 @@ def do_sum_matrices( ) -> str: """Sums up two matrices given two matrix files. + This implementation uses a 1-pass streaming merge to minimize I/O + and keep memory usage constant (O(1)), allowing it to handle matrices + larger than available RAM. + Args: mtx1_path: First matrix file path mtx2_path: Second matrix file path out_path: Output file path mm: Whether to allow multimapping (i.e. decimals) - header_line: The header line if we have it + header_line: The header line if we have it (Used for recursion) Returns: Output file path """ logger.info('Summing matrices into {}'.format(out_path)) + + if not os.path.exists(mtx1_path) or not os.path.exists(mtx2_path): + raise Exception("Input matrix files do not exist.") + n = 0 - header = [] - with open_as_text(mtx1_path, - 'r') as f1, open_as_text(mtx2_path, - 'r') as f2, open(out_path, - 'w') as out: - eof1 = eof2 = pause1 = pause2 = False - nums = [0, 0, 0] - nums1 = nums2 = to_write = None - if header_line: - out.write("%%MatrixMarket matrix coordinate real general\n%\n") - while not eof1 or not eof2: - s1 = f1.readline() if not eof1 and not pause1 else '%' - s2 = f2.readline() if not eof2 and not pause2 else '%' - if not s1: - pause1 = eof1 = True - if not s2: - pause2 = eof2 = True - _nums1 = _nums2 = [] - if not eof1 and s1[0] != '%': - _nums1 = s1.split() - if not mm: - _nums1[0] = int(_nums1[0]) - _nums1[1] = int(_nums1[1]) - _nums1[2] = int(float(_nums1[2])) - else: - _nums1[0] = int(_nums1[0]) - _nums1[1] = int(_nums1[1]) - _nums1[2] = float(_nums1[2]) - if not eof2 and s2[0] != '%': - _nums2 = s2.split() - if not mm: - _nums2[0] = int(_nums2[0]) - _nums2[1] = int(_nums2[1]) - _nums2[2] = int(float(_nums2[2])) + header = None + # We use a temporary file to store the body while we count n (nnz) + temp_dir = os.path.dirname(out_path) + tmp_body_path = None + + try: + tmp_body_path = get_temporary_filename(temp_dir) + with open_as_text(mtx1_path, 'r') as f1, \ + open_as_text(mtx2_path, 'r') as f2, \ + open(tmp_body_path, 'w') as tmp_body: + + eof1 = eof2 = pause1 = pause2 = False + nums1 = nums2 = to_write = None + + while not eof1 or not eof2: + s1 = f1.readline() if not eof1 and not pause1 else '%' + s2 = f2.readline() if not eof2 and not pause2 else '%' + if not s1: + pause1 = eof1 = True + if not s2: + pause2 = eof2 = True + + _nums1 = _nums2 = [] + if not eof1 and s1[0] != '%': + tokens1 = s1.split() + if not mm: + _nums1 = [int(tokens1[0]), int(tokens1[1]), int(float(tokens1[2]))] + else: + _nums1 = [int(tokens1[0]), int(tokens1[1]), float(tokens1[2])] + if not eof2 and s2[0] != '%': + tokens2 = s2.split() + if not mm: + _nums2 = [int(tokens2[0]), int(tokens2[1]), int(float(tokens2[2]))] + else: + _nums2 = [int(tokens2[0]), int(tokens2[1]), float(tokens2[2])] + + if nums1 is not None: + _nums1, nums1 = nums1, None + if nums2 is not None: + _nums2, nums2 = nums2, None + + if eof1 and eof2: + break + elif eof1: + nums, pause2 = _nums2, False + elif eof2: + nums, pause1 = _nums1, False + elif not _nums1 or not _nums2: + # Skip header comments + continue + elif not header: + if (_nums1[0] != _nums2[0] or _nums1[1] != _nums2[1]): + raise Exception("Summing up two matrix files failed: Headers incompatible") + header = [_nums1[0], _nums1[1]] + continue + elif (_nums1[0] > _nums2[0] + or (_nums1[0] == _nums2[0] and _nums1[1] > _nums2[1])): + nums, pause1, pause2, nums1, nums2 = _nums2, True, False, _nums1, None + elif (_nums2[0] > _nums1[0] + or (_nums2[0] == _nums1[0] and _nums2[1] > _nums1[1])): + nums, pause2, pause1, nums2, nums1 = _nums1, True, False, _nums2, None + elif _nums1[0] == _nums2[0] and _nums1[1] == _nums2[1]: + nums, pause1, pause2, nums1, nums2 = _nums1, False, False, None, None + nums[2] += _nums2[2] else: - _nums2[0] = int(_nums2[0]) - _nums2[1] = int(_nums2[1]) - _nums2[2] = float(_nums2[2]) - if nums1 is not None: - _nums1 = nums1 - nums1 = None - if nums2 is not None: - _nums2 = nums2 - nums2 = None - if eof1 and eof2: - # Both mtxs are done - break - elif eof1: - # mtx1 is done - nums = _nums2 - pause2 = False - elif eof2: - # mtx2 is done - nums = _nums1 - pause1 = False - elif eof1 and eof2: - # Both mtxs are done - break - # elif (len(_nums1) != len(_nums2)): - # # We have a problem - # raise Exception("Summing up two matrix files failed") - elif not _nums1 or not _nums2: - # We have something other than a matrix line - continue - elif not header: - # We are at the header line and need to read it in - if (_nums1[0] != _nums2[0] or _nums1[1] != _nums2[1]): - raise Exception( - "Summing up two matrix files failed: Headers incompatible" - ) + raise Exception("Summing up two matrix files failed: Assertion failed") + + if to_write and to_write[0] == nums[0] and to_write[1] == nums[1]: + to_write[2] += nums[2] else: - header = [_nums1[0], _nums1[1]] - if header_line: - out.write(header_line) - continue - elif (_nums1[0] > _nums2[0] - or (_nums1[0] == _nums2[0] and _nums1[1] > _nums2[1])): - # If we're further in mtx1 than mtx2 - nums = _nums2 - pause1 = True - pause2 = False - nums1 = _nums1 - nums2 = None - elif (_nums2[0] > _nums1[0] - or (_nums2[0] == _nums1[0] and _nums2[1] > _nums1[1])): - # If we're further in mtx2 than mtx1 - nums = _nums1 - pause2 = True - pause1 = False - nums2 = _nums2 - nums1 = None - elif _nums1[0] == _nums2[0] and _nums1[1] == _nums2[1]: - # If we're at the same location in mtx1 and mtx2 - nums = _nums1 - nums[2] += _nums2[2] - pause1 = pause2 = False - nums1 = nums2 = None - else: - # Shouldn't happen - raise Exception( - "Summing up two matrix files failed: Assertion failed" - ) - # Write out a line - _nums_prev = to_write - if (_nums_prev and _nums_prev[0] == nums[0] - and _nums_prev[1] == nums[1]): - nums[2] += _nums_prev[2] - pause1 = pause2 = False - to_write = [nums[0], nums[1], nums[2]] - else: - if to_write: - if header_line: - if mm and to_write[2].is_integer(): - to_write[2] = int(to_write[2]) - out.write( - f'{to_write[0]} {to_write[1]} {to_write[2]}\n' - ) - n += 1 - to_write = [nums[0], nums[1], nums[2]] - if to_write: - if header_line: - if mm and to_write[2].is_integer(): - to_write[2] = int(to_write[2]) - out.write(f'{to_write[0]} {to_write[1]} {to_write[2]}\n') - n += 1 - if not header_line: - header_line = f'{header[0]} {header[1]} {n}\n' - do_sum_matrices(mtx1_path, mtx2_path, out_path, mm, header_line) + if to_write: + val = to_write[2] + if not mm: + val = int(val) + tmp_body.write(f'{to_write[0]} {to_write[1]} {val}\n') + n += 1 + to_write = [nums[0], nums[1], nums[2]] + + if to_write: + val = to_write[2] + if not mm: + val = int(val) + tmp_body.write(f'{to_write[0]} {to_write[1]} {val}\n') + n += 1 + + if header is None: + raise Exception(f"Summing up two matrix files failed: Missing header in {mtx1_path} or {mtx2_path}") + + # Final assembly: Prepend header and copy body + with open(out_path, 'w') as out, open(tmp_body_path, 'r') as body: + out.write("%%MatrixMarket matrix coordinate real general\n%\n") + out.write(f"{header[0]} {header[1]} {n}\n") + shutil.copyfileobj(body, out) + finally: + if tmp_body_path and os.path.exists(tmp_body_path): + os.remove(tmp_body_path) + return out_path diff --git a/kb_python/validate.py b/kb_python/validate.py index ecca731f..fe631c8e 100755 --- a/kb_python/validate.py +++ b/kb_python/validate.py @@ -52,7 +52,7 @@ def validate_mtx(path: str): ValidateError: If the file failed verification """ try: - scipy.io.mmread(path) + scipy.io.mminfo(path) except ValueError: raise ValidateError(f'{path} is not a valid matrix market file') diff --git a/tests/test_utils.py b/tests/test_utils.py index d09f4d28..11f1529d 100755 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -340,3 +340,66 @@ def test_create_10x_feature_barcode_map(self): self.assertTrue(os.path.exists(map_path)) with open(map_path, 'r') as f: self.assertIn('\t', f.readline()) + + def test_do_sum_matrices(self): + import scipy.io + import scipy.sparse + + m1 = scipy.sparse.csr_matrix([[1, 2], [3, 4]]) + m2 = scipy.sparse.csr_matrix([[5, 6], [7, 8]]) + + m1_path = os.path.join(self.temp_dir, 'm1.mtx') + m2_path = os.path.join(self.temp_dir, 'm2.mtx') + out_path = os.path.join(self.temp_dir, 'sum.mtx') + + scipy.io.mmwrite(m1_path, m1) + scipy.io.mmwrite(m2_path, m2) + + utils.do_sum_matrices(m1_path, m2_path, out_path) + + m_sum = scipy.io.mmread(out_path) + expected = np.array([[6, 8], [10, 12]]) + + np.testing.assert_array_equal(m_sum.toarray(), expected) + + # Verify robust integer formatting in output + with open(out_path, 'r') as f: + for line in f: + if line.startswith('%'): + continue + parts = line.split() + if len(parts) == 3: + self.assertTrue(parts[0].isdigit()) + self.assertTrue(parts[1].isdigit()) + # Value might be negative, though not in this test + self.assertTrue(parts[2].lstrip('-').isdigit()) + + def test_do_sum_matrices_complex(self): + import scipy.io + import scipy.sparse + + # Test case with: + # - Overlapping coordinates (1,1) + # - Unique coordinates in m1 (1,2) + # - Unique coordinates in m2 (2,1) + # - Sparse structure + m1 = scipy.sparse.coo_matrix(([1, 2], ([0, 0], [0, 1])), shape=(2, 2)) + m2 = scipy.sparse.coo_matrix(([3, 4], ([0, 1], [0, 0])), shape=(2, 2)) + + # m1: [[1, 2], [0, 0]] + # m2: [[3, 0], [4, 0]] + # sum: [[4, 2], [4, 0]] + + m1_path = os.path.join(self.temp_dir, 'm1_complex.mtx') + m2_path = os.path.join(self.temp_dir, 'm2_complex.mtx') + out_path = os.path.join(self.temp_dir, 'sum_complex.mtx') + + scipy.io.mmwrite(m1_path, m1) + scipy.io.mmwrite(m2_path, m2) + + utils.do_sum_matrices(m1_path, m2_path, out_path) + + m_sum = scipy.io.mmread(out_path) + expected = np.array([[4, 2], [4, 0]]) + + np.testing.assert_array_equal(m_sum.toarray(), expected) diff --git a/tests/test_validate.py b/tests/test_validate.py index ed1cf1af..23bcf192 100755 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -47,8 +47,8 @@ def test_validate_mtx(self): validate.validate_mtx(self.matrix_path) def test_validate_mtx_raises_on_error(self): - with mock.patch('kb_python.validate.scipy.io.mmread') as mmread: - mmread.side_effect = ValueError('test') + with mock.patch('kb_python.validate.scipy.io.mminfo') as mminfo: + mminfo.side_effect = ValueError('test') with self.assertRaises(validate.ValidateError): validate.validate_mtx('path')