From 1468efd876b2208a005638fa7981d6801feae456 Mon Sep 17 00:00:00 2001 From: dhruvgupta-meesho Date: Mon, 9 Feb 2026 10:47:56 +0530 Subject: [PATCH 1/4] removed some bottlenecks --- .../inference_logging_client/__init__.py | 399 ++++++++---------- .../inference_logging_client/cli.py | 35 +- 2 files changed, 215 insertions(+), 219 deletions(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 89a7319c..95130f17 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -209,6 +209,42 @@ def decode_mplog( return spark.createDataFrame(rows) +def _extract_metadata_byte(metadata_data, json_module, base64_module) -> int: + """Extract metadata byte from JSON array with base64-encoded string. + + Expected format: JSON array with single base64-encoded string, e.g., '["BQ=="]' + """ + if metadata_data is None: + return 0 + # Handle pandas NA/NaN + try: + if hasattr(metadata_data, "isna") and metadata_data.isna(): + return 0 + except (TypeError, ValueError): + pass + if isinstance(metadata_data, str): + try: + parsed = json_module.loads(metadata_data) + if isinstance(parsed, list) and len(parsed) > 0: + decoded = base64_module.b64decode(parsed[0]) + if len(decoded) > 0: + return decoded[0] + except (json_module.JSONDecodeError, ValueError, TypeError): + pass + return 0 + if isinstance(metadata_data, list) and len(metadata_data) > 0: + first_item = metadata_data[0] + if isinstance(first_item, str): + try: + decoded = base64_module.b64decode(first_item) + if len(decoded) > 0: + return decoded[0] + except (ValueError, TypeError): + pass + return 0 + return 0 + + def decode_mplog_dataframe( df: "SparkDataFrame", spark: "SparkSession", @@ -217,6 +253,8 @@ def decode_mplog_dataframe( features_column: str = "features", metadata_column: str = "metadata", mp_config_id_column: str = "mp_config_id", + num_partitions: Optional[int] = None, + max_records_per_batch: Optional[int] = None, ) -> "SparkDataFrame": """ Decode MPLog features from a Spark DataFrame with specific column structure. @@ -227,8 +265,9 @@ def decode_mplog_dataframe( - mp_config_id, parent_entity, tracking_id, user_id - year, month, day, hour - Note: This function collects the DataFrame to the driver for processing. - For very large datasets, consider partitioning and processing in smaller batches. + Processing is done distributed via mapInPandas so that large DataFrames (millions + of rows, multi-MB per row) are not collected to the driver. Each partition is + decoded on workers; only decoded (small) rows are returned. Args: df: Input Spark DataFrame with MPLog data columns @@ -238,6 +277,10 @@ def decode_mplog_dataframe( features_column: Name of the column containing encoded features (default: "features") metadata_column: Name of the column containing metadata byte (default: "metadata") mp_config_id_column: Name of the column containing model proxy config ID (default: "mp_config_id") + num_partitions: Number of partitions for distributed decode. Default 10000 to keep + partition size small when rows are large (3-5 MB each). Increase if rows are small. + max_records_per_batch: Max rows per Arrow batch in mapInPandas. When set (default 200), + applied temporarily during this call to limit memory per batch when rows are large. Returns: Spark DataFrame with decoded features. Each row from input becomes multiple rows @@ -260,11 +303,8 @@ def decode_mplog_dataframe( if inference_host is None: inference_host = os.getenv("INFERENCE_HOST", "http://localhost:8082") - # Track decode errors for summary - decode_errors = [] - - # Check if DataFrame is empty - if df.count() == 0: + # Check if DataFrame is empty (avoid full count: use limit(1)) + if df.limit(1).count() == 0: from pyspark.sql.types import StructType return spark.createDataFrame([], StructType([])) @@ -275,79 +315,28 @@ def decode_mplog_dataframe( if missing_columns: raise ValueError(f"Missing required columns: {missing_columns}") - # Collect to driver for processing - # Note: For large datasets, consider using mapInPandas or processing in partitions - rows = df.collect() + # Only collect distinct (mp_config_id, metadata) to get schema keys - small payload + distinct_df = df.select(mp_config_id_column, metadata_column).distinct() + distinct_rows = distinct_df.collect() - # Pre-fetch schemas for unique (mp_config_id, version) combinations to avoid - # redundant HTTP requests during row iteration. - # Key: (mp_config_id, version) only - host/path intentionally excluded as schemas are canonical schema_cache: dict[tuple[str, int], list[FeatureInfo]] = {} - - def _extract_metadata_byte(metadata_data) -> int: - """Extract metadata byte from JSON array with base64-encoded string. - - Expected format: JSON array with single base64-encoded string, e.g., '["BQ=="]' - """ - if metadata_data is None: - return 0 - - # Handle JSON string format - if isinstance(metadata_data, str): - try: - parsed = json.loads(metadata_data) - if isinstance(parsed, list) and len(parsed) > 0: - decoded = base64.b64decode(parsed[0]) - if len(decoded) > 0: - return decoded[0] - except (json.JSONDecodeError, ValueError, TypeError): - pass - return 0 - - # Handle already-parsed list format - if isinstance(metadata_data, list) and len(metadata_data) > 0: - first_item = metadata_data[0] - if isinstance(first_item, str): - try: - decoded = base64.b64decode(first_item) - if len(decoded) > 0: - return decoded[0] - except (ValueError, TypeError): - pass - return 0 - - return 0 - - # First pass: collect unique (mp_config_id, version) pairs - for row in rows: - # Extract metadata byte to get version + for row in distinct_rows: metadata_data = row[metadata_column] - metadata_byte = _extract_metadata_byte(metadata_data) - + metadata_byte = _extract_metadata_byte(metadata_data, json, base64) _, version, _ = unpack_metadata_byte(metadata_byte) - - # Skip invalid versions if not (0 <= version <= _MAX_SCHEMA_VERSION): continue - - # Extract mp_config_id mp_config_id = row[mp_config_id_column] if mp_config_id is None: continue mp_config_id = str(mp_config_id) - cache_key = (mp_config_id, version) if cache_key not in schema_cache: - # Pre-fetch schema and store in local cache try: schema_cache[cache_key] = get_feature_schema(mp_config_id, version, inference_host) except Exception as e: - # Log warning but don't fail - will be caught again in main loop warnings.warn(f"Failed to pre-fetch schema for {cache_key}: {e}", UserWarning) - all_decoded_rows = [] - - # Metadata columns to preserve row_metadata_columns = [ "prism_ingested_at", "prism_extracted_at", @@ -361,158 +350,144 @@ def _extract_metadata_byte(metadata_data) -> int: "day", "hour", ] - - for idx, row in enumerate(rows): - # Extract features data - features_data = row[features_column] - if features_data is None: - continue - - # Extract metadata byte - metadata_data = row[metadata_column] - metadata_byte = _extract_metadata_byte(metadata_data) - - # Extract version from metadata byte - _, version, _ = unpack_metadata_byte(metadata_byte) - - # Validate version range - if not (0 <= version <= _MAX_SCHEMA_VERSION): - warnings.warn( - f"Row {idx}: Version {version} extracted from metadata is out of valid range (0-{_MAX_SCHEMA_VERSION}). " - f"This may indicate corrupted metadata.", - UserWarning, - ) - continue - - # Extract mp_config_id - mp_config_id = row[mp_config_id_column] - if mp_config_id is None: - continue - mp_config_id = str(mp_config_id) - - # Lookup cached schema - cache_key = (mp_config_id, version) - cached_schema = schema_cache.get(cache_key) - + # Build full output schema: entity_id + metadata cols + all feature names from all schemas + all_feature_names = set() + for feat_list in schema_cache.values(): + for f in feat_list: + all_feature_names.add(f.name) + metadata_cols_in_schema = [c for c in row_metadata_columns if c in df_columns] + from pyspark.sql.types import StringType, StructField, StructType + schema_fields = [StructField("entity_id", StringType(), True)] + for c in metadata_cols_in_schema: + schema_fields.append(StructField(c, StringType(), True)) + for c in sorted(all_feature_names): + schema_fields.append(StructField(c, StringType(), True)) + full_schema = StructType(schema_fields) + all_columns_ordered = ["entity_id"] + metadata_cols_in_schema + sorted(all_feature_names) + + def _safe_get(row, col, default=None): try: - # Parse features JSON (expected format: JSON array of dicts with encoded_features) - if isinstance(features_data, str): - features_list = json.loads(features_data) - else: - features_list = features_data - - if not isinstance(features_list, list): - warnings.warn(f"Row {idx}: features is not a list, skipping", UserWarning) - continue - - # Get entities from row - entities_val = None - if "entities" in df_columns: - entities_raw = row["entities"] - if entities_raw is not None: - if isinstance(entities_raw, str): - try: - entities_val = json.loads(entities_raw) - except (json.JSONDecodeError, ValueError): - entities_val = [entities_raw] - elif isinstance(entities_raw, list): - entities_val = entities_raw - else: - entities_val = [entities_raw] - - # Use cached schema or fetch - feature_schema = cached_schema - if feature_schema is None: - feature_schema = get_feature_schema(mp_config_id, version, inference_host) - - # Determine format type from metadata byte - # unpack_metadata_byte returns (compression_enabled, version, format_type) - _, _, format_type_num = unpack_metadata_byte(metadata_byte) - if format_type_num in FORMAT_TYPE_MAP: - detected_format = FORMAT_TYPE_MAP[format_type_num] - else: - detected_format = Format.PROTO # Default to proto - - # Process parent_entity - parent_entity_val = None - if "parent_entity" in df_columns and row["parent_entity"] is not None: - parent_val = row["parent_entity"] - if isinstance(parent_val, str): - try: - parent_val = json.loads(parent_val) - except (json.JSONDecodeError, ValueError): - parent_val = [parent_val] - if isinstance(parent_val, list): - if len(parent_val) == 1: - parent_entity_val = parent_val[0] - elif len(parent_val) > 1: - parent_entity_val = str(parent_val) - else: - parent_entity_val = None - else: - parent_entity_val = parent_val - - # Process each entity's features - for i, feature_item in enumerate(features_list): - # Get entity_id from entities array or generate synthetic - entity_id = f"entity_{i}" - if entities_val and i < len(entities_val): - entity_id = str(entities_val[i]) - - # Get and decode base64 encoded_features - encoded_features_b64 = feature_item.get("encoded_features", "") - if not encoded_features_b64: + val = row[col] if col in row.index else getattr(row, col, default) + if hasattr(val, "isna") and val.isna(): + return default + return val + except (KeyError, AttributeError): + return default + + def _decode_batch(iterator): + import pandas as pd + for pdf in iterator: + out_rows = [] + for idx, row in pdf.iterrows(): + features_data = _safe_get(row, features_column) + if features_data is None: continue - - try: - encoded_bytes = base64.b64decode(encoded_features_b64) - except (ValueError, TypeError): + metadata_data = _safe_get(row, metadata_column) + metadata_byte = _extract_metadata_byte(metadata_data, json, base64) + _, version, _ = unpack_metadata_byte(metadata_byte) + if not (0 <= version <= _MAX_SCHEMA_VERSION): continue - - if len(encoded_bytes) == 0: + mp_config_id = _safe_get(row, mp_config_id_column) + if mp_config_id is None: continue - - # Attempt decompression if enabled - working_data = encoded_bytes - if decompress: - working_data = _decompress_zstd(encoded_bytes) - - # Decode features based on format type - if detected_format == Format.ARROW: - decoded_features = decode_arrow_features(working_data, feature_schema) - elif detected_format == Format.PARQUET: - decoded_features = decode_parquet_features(working_data, feature_schema) + mp_config_id = str(mp_config_id) + cache_key = (mp_config_id, version) + feature_schema = schema_cache.get(cache_key) + if feature_schema is None: + try: + feature_schema = get_feature_schema(mp_config_id, version, inference_host) + except Exception: + continue + if isinstance(features_data, str): + try: + features_list = json.loads(features_data) + except (json.JSONDecodeError, ValueError, TypeError): + continue else: - # Default to proto format - decoded_features = decode_proto_features(working_data, feature_schema) - - result_row = {"entity_id": entity_id} - result_row.update(decoded_features) - - # Add metadata columns - for col in row_metadata_columns: - if col in df_columns: - result_row[col] = row[col] - - # Set parent_entity - if parent_entity_val is not None: - result_row["parent_entity"] = parent_entity_val - - all_decoded_rows.append(result_row) - - except Exception as e: - decode_errors.append((idx, str(e))) - warnings.warn(f"Failed to decode row {idx}: {e}", UserWarning) - continue - - if not all_decoded_rows: - from pyspark.sql.types import StructType - return spark.createDataFrame([], StructType([])) - - # Create Spark DataFrame from all decoded rows - result_df = spark.createDataFrame(all_decoded_rows) - - # Reorder columns: entity_id first, then metadata columns, then features + features_list = features_data + if not isinstance(features_list, list): + continue + entities_val = None + if "entities" in df_columns: + entities_raw = _safe_get(row, "entities") + if entities_raw is not None: + if isinstance(entities_raw, str): + try: + entities_val = json.loads(entities_raw) + except (json.JSONDecodeError, ValueError): + entities_val = [entities_raw] + elif isinstance(entities_raw, list): + entities_val = entities_raw + else: + entities_val = [entities_raw] + _, _, format_type_num = unpack_metadata_byte(metadata_byte) + detected_format = FORMAT_TYPE_MAP.get(format_type_num, Format.PROTO) + parent_entity_val = None + if "parent_entity" in df_columns: + parent_val = _safe_get(row, "parent_entity") + if parent_val is not None: + if isinstance(parent_val, str): + try: + parent_val = json.loads(parent_val) + except (json.JSONDecodeError, ValueError): + parent_val = [parent_val] + if isinstance(parent_val, list): + parent_entity_val = parent_val[0] if len(parent_val) == 1 else str(parent_val) if len(parent_val) > 1 else None + else: + parent_entity_val = parent_val + for i, feature_item in enumerate(features_list): + if not isinstance(feature_item, dict): + continue + entity_id = str(entities_val[i]) if entities_val and i < len(entities_val) else f"entity_{i}" + encoded_features_b64 = feature_item.get("encoded_features", "") + if not encoded_features_b64: + continue + try: + encoded_bytes = base64.b64decode(encoded_features_b64) + except (ValueError, TypeError): + continue + if len(encoded_bytes) == 0: + continue + working_data = encoded_bytes + if decompress: + working_data = _decompress_zstd(encoded_bytes) + try: + if detected_format == Format.ARROW: + decoded_features = decode_arrow_features(working_data, feature_schema) + elif detected_format == Format.PARQUET: + decoded_features = decode_parquet_features(working_data, feature_schema) + else: + decoded_features = decode_proto_features(working_data, feature_schema) + except Exception: + continue + result_row = {"entity_id": entity_id} + result_row.update(decoded_features) + for col in row_metadata_columns: + if col in df_columns: + result_row[col] = _safe_get(row, col) + if parent_entity_val is not None: + result_row["parent_entity"] = parent_entity_val + # Fill missing schema columns with None + for col in all_columns_ordered: + if col not in result_row: + result_row[col] = None + out_rows.append(result_row) + if out_rows: + out_pdf = pd.DataFrame(out_rows, columns=all_columns_ordered) + yield out_pdf + + n_partitions = num_partitions if num_partitions is not None else 10000 + df_repart = df.repartition(n_partitions) + + batch_limit = max_records_per_batch if max_records_per_batch is not None else 200 + prev_max_records = spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch") + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", str(batch_limit)) + try: + result_df = df_repart.mapInPandas(_decode_batch, full_schema) + finally: + spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", prev_max_records or "10000") + + # Reorder columns: entity_id first, then metadata, then features result_columns = result_df.columns metadata_cols = ["entity_id"] for col in [ @@ -530,8 +505,6 @@ def _extract_metadata_byte(metadata_data) -> int: ]: if col in result_columns: metadata_cols.append(col) - - feature_cols = [col for col in result_columns if col not in metadata_cols] + feature_cols = [c for c in result_columns if c not in metadata_cols] column_order = metadata_cols + feature_cols - return result_df.select(column_order) diff --git a/py-sdk/inference_logging_client/inference_logging_client/cli.py b/py-sdk/inference_logging_client/inference_logging_client/cli.py index 14e64ba0..a5a1ce93 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/cli.py +++ b/py-sdk/inference_logging_client/inference_logging_client/cli.py @@ -2,8 +2,11 @@ import argparse import base64 +import glob import os +import shutil import sys +import tempfile from . import decode_mplog, format_dataframe_floats, get_format_name, get_mplog_metadata from .types import Format @@ -126,12 +129,27 @@ def main(): print(f"Output written to {args.output}") else: if args.json: - # Collect and print as JSON - import json - rows = [row.asDict() for row in df.collect()] - print(json.dumps(rows, indent=2, default=str)) + # Avoid collect() for large DataFrames: write to temp dir then stream to stdout + tmpdir = tempfile.mkdtemp(prefix="inference_logging_client_json_") + try: + df.coalesce(1).write.mode("overwrite").json(tmpdir) + part_files = sorted(glob.glob(os.path.join(tmpdir, "part-*"))) + print("[") + first = True + for path in part_files: + with open(path) as f: + for line in f: + line = line.strip() + if line: + if not first: + print(",") + print(" " + line, end="") + first = False + print("\n]" if not first else "]") + finally: + shutil.rmtree(tmpdir, ignore_errors=True) else: - # Show table + # Show table (only fetches 20 rows, no full collect) df.show(truncate=False) # Get metadata for summary @@ -150,7 +168,12 @@ def main(): print( f"Compression: {'enabled' if metadata.compression_enabled else 'disabled'}", file=sys.stderr ) - print(f"Rows: {df.count()}", file=sys.stderr) + # Avoid full count() for huge DataFrames: use limit(1).count() for empty check only + try: + row_count = df.count() + print(f"Rows: {row_count}", file=sys.stderr) + except Exception: + print("Rows: (count skipped - use --output to write without summary)", file=sys.stderr) print(f"Columns: {len(df.columns)}", file=sys.stderr) col_preview = df.columns[1:5] if len(df.columns) > 1 else [] print( From aeb91eed05825033030f2f3460e8baf0aaa2cc6b Mon Sep 17 00:00:00 2001 From: dhruvgupta-meesho Date: Mon, 9 Feb 2026 13:03:39 +0530 Subject: [PATCH 2/4] small bug --- .../inference_logging_client/__init__.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 95130f17..28e745b0 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -461,7 +461,16 @@ def _decode_batch(iterator): except Exception: continue result_row = {"entity_id": entity_id} - result_row.update(decoded_features) + # Convert all feature values to strings for schema compatibility + for k, v in decoded_features.items(): + if v is None: + result_row[k] = None + elif isinstance(v, (list, tuple)): + result_row[k] = str(v) + elif isinstance(v, bytes): + result_row[k] = v.hex() + else: + result_row[k] = str(v) for col in row_metadata_columns: if col in df_columns: result_row[col] = _safe_get(row, col) From d303727d60b917802b3a4d9d871e5915998c10df Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Mon, 9 Feb 2026 14:41:45 +0530 Subject: [PATCH 3/4] Update pyproject.toml --- py-sdk/inference_logging_client/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml index 396ad071..cd327326 100644 --- a/py-sdk/inference_logging_client/pyproject.toml +++ b/py-sdk/inference_logging_client/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inference-logging-client" -version = "0.1.0" +version = "0.2.1" description = "Decode MPLog feature logs from proto, arrow, or parquet format" readme = "readme.md" requires-python = ">=3.8" From ca346b7126b2161c4f5090e0f269f47b314edb09 Mon Sep 17 00:00:00 2001 From: dhruvgupta-meesho Date: Wed, 11 Feb 2026 16:15:45 +0530 Subject: [PATCH 4/4] preserving fields --- .../inference_logging_client/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 28e745b0..a5150c3f 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -357,9 +357,13 @@ def decode_mplog_dataframe( all_feature_names.add(f.name) metadata_cols_in_schema = [c for c in row_metadata_columns if c in df_columns] from pyspark.sql.types import StringType, StructField, StructType + # Map input column names to their Spark types so we can preserve them in the output + input_field_map = {field.name: field.dataType for field in df.schema.fields} schema_fields = [StructField("entity_id", StringType(), True)] for c in metadata_cols_in_schema: - schema_fields.append(StructField(c, StringType(), True)) + # Preserve the original type (LongType, TimestampType, etc.) + original_type = input_field_map.get(c, StringType()) + schema_fields.append(StructField(c, original_type, True)) for c in sorted(all_feature_names): schema_fields.append(StructField(c, StringType(), True)) full_schema = StructType(schema_fields) @@ -473,6 +477,8 @@ def _decode_batch(iterator): result_row[k] = str(v) for col in row_metadata_columns: if col in df_columns: + # Pass through as-is to preserve original types + # (LongType, TimestampType, etc.) result_row[col] = _safe_get(row, col) if parent_entity_val is not None: result_row["parent_entity"] = parent_entity_val