From cd31ffa53880fd8aa7c57c9145615e0f619b039f Mon Sep 17 00:00:00 2001 From: dhruvgupta-meesho Date: Tue, 3 Feb 2026 15:28:58 +0530 Subject: [PATCH 1/4] added pyspark logic --- .../inference_logging_client/__init__.py | 379 ++++-- .../inference_logging_client/cli.py | 171 ++- .../inference_logging_client/decoder.py | 202 +-- .../inference_logging_client/exceptions.py | 6 + .../inference_logging_client/formats.py | 159 ++- .../inference_logging_client/io.py | 142 +- .../inference_logging_client/types.py | 7 +- .../inference_logging_client/utils.py | 163 ++- .../inference_logging_client/pyproject.toml | 2 +- py-sdk/inference_logging_client/readme.md | 1210 +++++++++++++++-- 10 files changed, 1844 insertions(+), 597 deletions(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 8ed3f529..447e6482 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -4,39 +4,42 @@ This package provides functionality to: 1. Decode MPLog feature logs from various encoding formats (proto, arrow, parquet) 2. Fetch feature schemas from inference API -3. Convert decoded logs to pandas DataFrames +3. 
Convert decoded logs to Spark DataFrames Main functions: - - decode_mplog: Decode MPLog bytes to a DataFrame - - decode_mplog_dataframe: Decode MPLog features from a DataFrame + - decode_mplog: Decode MPLog bytes to a Spark DataFrame + - decode_mplog_dataframe: Decode MPLog features from a Spark DataFrame - get_mplog_metadata: Extract metadata from MPLog bytes """ import warnings -from typing import Optional +from typing import TYPE_CHECKING, Optional -import pandas as pd +if TYPE_CHECKING: + from pyspark.sql import DataFrame as SparkDataFrame + from pyspark.sql import SparkSession # Check for zstandard availability at import time for clear error messages try: import zstandard as zstd + _ZSTD_AVAILABLE = True except ImportError: _ZSTD_AVAILABLE = False zstd = None -from .types import Format, FeatureInfo, DecodedMPLog, FORMAT_TYPE_MAP -from .io import get_feature_schema, parse_mplog_protobuf, get_mplog_metadata, clear_schema_cache -from .formats import decode_proto_format, decode_arrow_format, decode_parquet_format -from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte from .exceptions import ( - InferenceLoggingError, - SchemaFetchError, - SchemaNotFoundError, DecodeError, FormatError, + InferenceLoggingError, ProtobufError, + SchemaFetchError, + SchemaNotFoundError, ) +from .formats import decode_arrow_format, decode_parquet_format, decode_proto_format +from .io import clear_schema_cache, get_feature_schema, get_mplog_metadata, parse_mplog_protobuf +from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format +from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte __version__ = "0.1.0" @@ -49,6 +52,7 @@ "get_mplog_metadata", "get_feature_schema", "clear_schema_cache", + "format_dataframe_floats", "Format", "FeatureInfo", "DecodedMPLog", @@ -66,18 +70,18 @@ def _decompress_zstd(data: bytes) -> bytes: """Decompress zstd-compressed data. 
- + Args: data: Potentially zstd-compressed bytes - + Returns: Decompressed bytes, or original data if not compressed or zstd unavailable - + Raises: ImportError: If data is zstd-compressed but zstandard is not installed """ # Check for zstd magic number: 0x28 0xB5 0x2F 0xFD - if len(data) >= 4 and data[:4] == b'\x28\xB5\x2F\xFD': + if len(data) >= 4 and data[:4] == b"\x28\xb5\x2f\xfd": if not _ZSTD_AVAILABLE: raise ImportError( "Data appears to be zstd-compressed but the 'zstandard' package is not installed. " @@ -92,60 +96,65 @@ def decode_mplog( log_data: bytes, model_proxy_id: str, version: int, + spark: "SparkSession", format_type: Optional[Format] = None, inference_host: Optional[str] = None, decompress: bool = True, - schema: Optional[list] = None -) -> pd.DataFrame: + schema: Optional[list] = None, +) -> "SparkDataFrame": """ - Main function to decode MPLog bytes to a DataFrame. - + Main function to decode MPLog bytes to a Spark DataFrame. + Args: log_data: The MPLog bytes (possibly compressed) model_proxy_id: The model proxy config ID version: The schema version (0-15) + spark: The SparkSession to use for creating DataFrames format_type: The encoding format (proto, arrow, parquet). If None, auto-detect from metadata. inference_host: The inference service host URL. If None, reads from INFERENCE_HOST env var. decompress: Whether to attempt zstd decompression schema: Optional pre-fetched schema (list of FeatureInfo). If provided, skips schema fetch. 
- + Returns: - pandas DataFrame with entity_id as first column and features as remaining columns - + Spark DataFrame with entity_id as first column and features as remaining columns + Raises: ValueError: If version is out of valid range (0-15) ImportError: If data is zstd-compressed but zstandard is not installed FormatError: If format is unsupported or data cannot be parsed - + Example: + >>> from pyspark.sql import SparkSession >>> import inference_logging_client + >>> spark = SparkSession.builder.appName("decode").getOrCreate() >>> with open("log.bin", "rb") as f: ... data = f.read() >>> df = inference_logging_client.decode_mplog( ... log_data=data, ... model_proxy_id="my-model", - ... version=1 + ... version=1, + ... spark=spark ... ) - >>> print(df.head()) + >>> df.show() """ import os - + # Validate version range if not (0 <= version <= _MAX_SCHEMA_VERSION): raise ValueError( f"Version {version} is out of valid range (0-{_MAX_SCHEMA_VERSION}). " f"Version is encoded in 4 bits of the metadata byte." 
) - + # Read from environment variable if not provided if inference_host is None: inference_host = os.getenv("INFERENCE_HOST", "http://localhost:8082") - + # Attempt decompression if enabled working_data = log_data if decompress: working_data = _decompress_zstd(log_data) - + # If format_type is None, parse the protobuf to get format from metadata detected_format = format_type if detected_format is None: @@ -156,11 +165,11 @@ def decode_mplog( else: # Default to proto if format type is unknown detected_format = Format.PROTO - + # Use provided schema or fetch from inference service if schema is None: schema = get_feature_schema(model_proxy_id, version, inference_host) - + # Decode based on format if detected_format == Format.PROTO: entity_ids, decoded_rows = decode_proto_format(working_data, schema) @@ -170,90 +179,110 @@ def decode_mplog( entity_ids, decoded_rows = decode_parquet_format(working_data, schema) else: raise FormatError(f"Unsupported format: {detected_format}") - + if not decoded_rows: - # Return empty DataFrame with correct columns - columns = ["entity_id"] + [f.name for f in schema] - return pd.DataFrame(columns=columns) - - # Build DataFrame - df = pd.DataFrame(decoded_rows) - - # Insert entity_id as first column - df.insert(0, "entity_id", entity_ids) - - return df + # Return empty DataFrame with correct schema + from pyspark.sql.types import StringType, StructField, StructType + + # Build empty schema with entity_id + feature columns + fields = [StructField("entity_id", StringType(), True)] + for f in schema: + fields.append(StructField(f.name, StringType(), True)) + empty_schema = StructType(fields) + return spark.createDataFrame([], empty_schema) + + # Build rows with entity_id as first field + rows = [] + for entity_id, row_data in zip(entity_ids, decoded_rows): + row = {"entity_id": entity_id} + row.update(row_data) + rows.append(row) + + # Create Spark DataFrame from list of dicts + return spark.createDataFrame(rows) def decode_mplog_dataframe( 
- df: pd.DataFrame, + df: "SparkDataFrame", + spark: "SparkSession", inference_host: Optional[str] = None, decompress: bool = True, features_column: str = "features", metadata_column: str = "metadata", - mp_config_id_column: str = "mp_config_id" -) -> pd.DataFrame: + mp_config_id_column: str = "mp_config_id", +) -> "SparkDataFrame": """ - Decode MPLog features from a DataFrame with specific column structure. - + Decode MPLog features from a Spark DataFrame with specific column structure. + Expected DataFrame columns: - prism_ingested_at, prism_extracted_at, created_at - entities, features, metadata - mp_config_id, parent_entity, tracking_id, user_id - year, month, day, hour - + + Note: This function collects the DataFrame to the driver for processing. + For very large datasets, consider partitioning and processing in smaller batches. + Args: - df: Input DataFrame with MPLog data columns + df: Input Spark DataFrame with MPLog data columns + spark: The SparkSession to use for creating the result DataFrame inference_host: The inference service host URL. If None, reads from INFERENCE_HOST env var. decompress: Whether to attempt zstd decompression features_column: Name of the column containing encoded features (default: "features") metadata_column: Name of the column containing metadata byte (default: "metadata") mp_config_id_column: Name of the column containing model proxy config ID (default: "mp_config_id") - + Returns: - pandas DataFrame with decoded features. Each row from input becomes multiple rows + Spark DataFrame with decoded features. Each row from input becomes multiple rows (one per entity) with entity_id as first column and features as remaining columns. Original row metadata (prism_ingested_at, mp_config_id, etc.) is preserved. 
- + Example: - >>> import pandas as pd + >>> from pyspark.sql import SparkSession >>> import inference_logging_client - >>> df = pd.read_parquet("logs.parquet") - >>> decoded_df = inference_logging_client.decode_mplog_dataframe(df) - >>> print(decoded_df.head()) + >>> spark = SparkSession.builder.appName("decode").getOrCreate() + >>> df = spark.read.parquet("logs.parquet") + >>> decoded_df = inference_logging_client.decode_mplog_dataframe(df, spark) + >>> decoded_df.show() """ - import os - import sys - import json import base64 - + import json + import os + # Read from environment variable if not provided if inference_host is None: inference_host = os.getenv("INFERENCE_HOST", "http://localhost:8082") - + # Track decode errors for summary decode_errors = [] - - if df.empty: - return pd.DataFrame() - + + # Check if DataFrame is empty + if df.count() == 0: + from pyspark.sql.types import StructType + return spark.createDataFrame([], StructType([])) + # Validate required columns required_columns = [features_column, metadata_column, mp_config_id_column] - missing_columns = [col for col in required_columns if col not in df.columns] + df_columns = df.columns + missing_columns = [col for col in required_columns if col not in df_columns] if missing_columns: raise ValueError(f"Missing required columns: {missing_columns}") - + + # Collect to driver for processing + # Note: For large datasets, consider using mapInPandas or processing in partitions + rows = df.collect() + # Pre-fetch schemas for unique (mp_config_id, version) combinations to avoid # redundant HTTP requests during row iteration. 
# Key: (mp_config_id, version) only - host/path intentionally excluded as schemas are canonical schema_cache: dict[tuple[str, int], list[FeatureInfo]] = {} - + # First pass: collect unique (mp_config_id, version) pairs - for idx, row in df.iterrows(): + for row in rows: # Extract metadata byte to get version metadata_data = row[metadata_column] metadata_byte = 0 - if not pd.isna(metadata_data): + if metadata_data is not None: if isinstance(metadata_data, (int, float)): metadata_byte = int(metadata_data) elif isinstance(metadata_data, bytes) and len(metadata_data) > 0: @@ -265,19 +294,19 @@ def decode_mplog_dataframe( metadata_byte = int(metadata_data) except ValueError: pass - + _, version, _ = unpack_metadata_byte(metadata_byte) - + # Skip invalid versions if not (0 <= version <= _MAX_SCHEMA_VERSION): continue - + # Extract mp_config_id mp_config_id = row[mp_config_id_column] - if pd.isna(mp_config_id): + if mp_config_id is None: continue mp_config_id = str(mp_config_id) - + cache_key = (mp_config_id, version) if cache_key not in schema_cache: # Pre-fetch schema and store in local cache @@ -286,15 +315,15 @@ def decode_mplog_dataframe( except Exception as e: # Log warning but don't fail - will be caught again in main loop warnings.warn(f"Failed to pre-fetch schema for {cache_key}: {e}", UserWarning) - + all_decoded_rows = [] - - for idx, row in df.iterrows(): + + for idx, row in enumerate(rows): # Extract features bytes features_data = row[features_column] - if pd.isna(features_data): + if features_data is None: continue - + # Convert features to bytes (handle base64, hex, or raw bytes) features_bytes = None if isinstance(features_data, bytes): @@ -309,19 +338,19 @@ def decode_mplog_dataframe( features_bytes = bytes.fromhex(features_data) except Exception: # Try UTF-8 encoding - features_bytes = features_data.encode('utf-8') + features_bytes = features_data.encode("utf-8") elif isinstance(features_data, (bytearray, memoryview)): features_bytes = 
bytes(features_data) else: continue - + if features_bytes is None or len(features_bytes) == 0: continue - + # Extract metadata byte metadata_data = row[metadata_column] metadata_byte = 0 - if not pd.isna(metadata_data): + if metadata_data is not None: if isinstance(metadata_data, (int, float)): metadata_byte = int(metadata_data) elif isinstance(metadata_data, bytes) and len(metadata_data) > 0: @@ -333,72 +362,91 @@ def decode_mplog_dataframe( metadata_byte = int(metadata_data) except ValueError: pass - + # Extract version from metadata byte _, version, _ = unpack_metadata_byte(metadata_byte) - + # Validate version range if not (0 <= version <= _MAX_SCHEMA_VERSION): warnings.warn( f"Row {idx}: Version {version} extracted from metadata is out of valid range (0-{_MAX_SCHEMA_VERSION}). " f"This may indicate corrupted metadata.", - UserWarning + UserWarning, ) continue - + # Extract mp_config_id mp_config_id = row[mp_config_id_column] - if pd.isna(mp_config_id): + if mp_config_id is None: continue mp_config_id = str(mp_config_id) - + # Lookup cached schema cache_key = (mp_config_id, version) cached_schema = schema_cache.get(cache_key) - - # Decode this row's features using cached schema + + # Decode this row's features try: - decoded_df = decode_mplog( - log_data=features_bytes, - model_proxy_id=mp_config_id, - version=version, - format_type=None, # Auto-detect from metadata - inference_host=inference_host, - decompress=decompress, - schema=cached_schema # Pass cached schema to avoid redundant fetches - ) - + # Attempt decompression if enabled + working_data = features_bytes + if decompress: + working_data = _decompress_zstd(features_bytes) + + # Parse protobuf to get format from metadata + parsed = parse_mplog_protobuf(working_data) + if parsed.format_type in FORMAT_TYPE_MAP: + detected_format = FORMAT_TYPE_MAP[parsed.format_type] + else: + detected_format = Format.PROTO + + # Use cached schema or fetch + feature_schema = cached_schema + if feature_schema is None: + 
feature_schema = get_feature_schema(mp_config_id, version, inference_host) + + # Decode based on format + if detected_format == Format.PROTO: + entity_ids, decoded_feature_rows = decode_proto_format(working_data, feature_schema) + elif detected_format == Format.ARROW: + entity_ids, decoded_feature_rows = decode_arrow_format(working_data, feature_schema) + elif detected_format == Format.PARQUET: + entity_ids, decoded_feature_rows = decode_parquet_format(working_data, feature_schema) + else: + raise FormatError(f"Unsupported format: {detected_format}") + # Add original row metadata to each decoded entity row - if not decoded_df.empty: - # Preserve original metadata columns + if decoded_feature_rows: + # Metadata columns to preserve metadata_columns = [ - "prism_ingested_at", "prism_extracted_at", "created_at", - "mp_config_id", "parent_entity", "tracking_id", "user_id", - "year", "month", "day", "hour" + "prism_ingested_at", + "prism_extracted_at", + "created_at", + "mp_config_id", + "parent_entity", + "tracking_id", + "user_id", + "year", + "month", + "day", + "hour", ] - - for col in metadata_columns: - if col in df.columns: - decoded_df[col] = row[col] - - # Update entity_id from entities column if available and matches count - if "entities" in df.columns and not pd.isna(row["entities"]): - # entities might be a list or string representation + + # Get entities from row if available + entities_val = None + if "entities" in df_columns: entities_val = row["entities"] - if isinstance(entities_val, str): - try: - entities_val = json.loads(entities_val) - except (json.JSONDecodeError, ValueError): + if entities_val is not None: + if isinstance(entities_val, str): + try: + entities_val = json.loads(entities_val) + except (json.JSONDecodeError, ValueError): + entities_val = [entities_val] + elif not isinstance(entities_val, list): entities_val = [entities_val] - elif not isinstance(entities_val, list): - entities_val = [entities_val] - - # Match entities with decoded rows 
(only if counts match) - if len(entities_val) == len(decoded_df): - decoded_df["entity_id"] = entities_val - - # Add parent_entity if it exists - if "parent_entity" in df.columns and not pd.isna(row["parent_entity"]): + + # Process parent_entity + parent_entity_val = None + if "parent_entity" in df_columns and row["parent_entity"] is not None: parent_val = row["parent_entity"] if isinstance(parent_val, str): try: @@ -406,38 +454,67 @@ def decode_mplog_dataframe( except (json.JSONDecodeError, ValueError): parent_val = [parent_val] if isinstance(parent_val, list): - # If list, use first element or join if multiple if len(parent_val) == 1: - decoded_df["parent_entity"] = parent_val[0] + parent_entity_val = parent_val[0] elif len(parent_val) > 1: - decoded_df["parent_entity"] = str(parent_val) + parent_entity_val = str(parent_val) else: - decoded_df["parent_entity"] = None + parent_entity_val = None else: - decoded_df["parent_entity"] = parent_val - - all_decoded_rows.append(decoded_df) + parent_entity_val = parent_val + + for i, (entity_id, feature_row) in enumerate(zip(entity_ids, decoded_feature_rows)): + result_row = {"entity_id": entity_id} + result_row.update(feature_row) + + # Add metadata columns + for col in metadata_columns: + if col in df_columns: + result_row[col] = row[col] + + # Override entity_id from entities column if available and matches count + if entities_val and len(entities_val) == len(entity_ids): + result_row["entity_id"] = entities_val[i] + + # Set parent_entity + if parent_entity_val is not None: + result_row["parent_entity"] = parent_entity_val + + all_decoded_rows.append(result_row) + except Exception as e: # Track error but continue processing other rows decode_errors.append((idx, str(e))) warnings.warn(f"Failed to decode row {idx}: {e}", UserWarning) continue - + if not all_decoded_rows: - return pd.DataFrame() - - # Combine all decoded DataFrames - result_df = pd.concat(all_decoded_rows, ignore_index=True) - + from pyspark.sql.types 
import StructType + return spark.createDataFrame([], StructType([])) + + # Create Spark DataFrame from all decoded rows + result_df = spark.createDataFrame(all_decoded_rows) + # Reorder columns: entity_id first, then metadata columns, then features + result_columns = result_df.columns metadata_cols = ["entity_id"] - for col in ["prism_ingested_at", "prism_extracted_at", "created_at", - "mp_config_id", "parent_entity", "tracking_id", "user_id", - "year", "month", "day", "hour"]: - if col in result_df.columns: + for col in [ + "prism_ingested_at", + "prism_extracted_at", + "created_at", + "mp_config_id", + "parent_entity", + "tracking_id", + "user_id", + "year", + "month", + "day", + "hour", + ]: + if col in result_columns: metadata_cols.append(col) - - feature_cols = [col for col in result_df.columns if col not in metadata_cols] + + feature_cols = [col for col in result_columns if col not in metadata_cols] column_order = metadata_cols + feature_cols - - return result_df[column_order] + + return result_df.select(column_order) diff --git a/py-sdk/inference_logging_client/inference_logging_client/cli.py b/py-sdk/inference_logging_client/inference_logging_client/cli.py index 43f1eeb7..14e64ba0 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/cli.py +++ b/py-sdk/inference_logging_client/inference_logging_client/cli.py @@ -1,13 +1,11 @@ """Command-line interface for inference-logging-client.""" +import argparse +import base64 import os import sys -import base64 -import argparse -import pandas as pd - -from . import decode_mplog, get_mplog_metadata, get_format_name, format_dataframe_floats +from . 
import decode_mplog, format_dataframe_floats, get_format_name, get_mplog_metadata from .types import Format @@ -29,43 +27,50 @@ def main(): # Output to CSV inference-logging-client --model-proxy-id my-model --version 1 input.bin -o output.csv - """ + """, ) - + parser.add_argument("input", help="Input file containing MPLog bytes (or - for stdin)") - parser.add_argument("--model-proxy-id", "-m", required=True, - help="Model proxy config ID") - parser.add_argument("--version", "-v", type=int, required=True, - help="Schema version") - parser.add_argument("--format", "-f", choices=["proto", "arrow", "parquet", "auto"], - default="auto", - help="Encoding format (default: auto-detect from metadata)") - parser.add_argument("--inference-host", default=None, - help="Inference service host URL (default: reads from INFERENCE_HOST env var or http://localhost:8082)") - parser.add_argument("--hex", action="store_true", - help="Input is hex-encoded string") - parser.add_argument("--base64", action="store_true", - help="Input is base64-encoded string") - parser.add_argument("--no-decompress", action="store_true", - help="Skip automatic zstd decompression") - parser.add_argument("--output", "-o", - help="Output file (CSV format, default: print to stdout)") - parser.add_argument("--json", action="store_true", - help="Output as JSON instead of CSV") - + parser.add_argument("--model-proxy-id", "-m", required=True, help="Model proxy config ID") + parser.add_argument("--version", "-v", type=int, required=True, help="Schema version") + parser.add_argument( + "--format", + "-f", + choices=["proto", "arrow", "parquet", "auto"], + default="auto", + help="Encoding format (default: auto-detect from metadata)", + ) + parser.add_argument( + "--inference-host", + default=None, + help="Inference service host URL (default: reads from INFERENCE_HOST env var or http://localhost:8082)", + ) + parser.add_argument("--hex", action="store_true", help="Input is hex-encoded string") + 
parser.add_argument("--base64", action="store_true", help="Input is base64-encoded string") + parser.add_argument( + "--no-decompress", action="store_true", help="Skip automatic zstd decompression" + ) + parser.add_argument("--output", "-o", help="Output file (CSV format, default: print to stdout)") + parser.add_argument("--json", action="store_true", help="Output as JSON instead of CSV") + parser.add_argument( + "--spark-master", + default="local[*]", + help="Spark master URL (default: local[*])", + ) + args = parser.parse_args() - + # Read input if args.input == "-": data = sys.stdin.buffer.read() else: with open(args.input, "rb") as f: data = f.read() - + # Decode input format if args.hex: try: - data = bytes.fromhex(data.decode('utf-8').strip()) + data = bytes.fromhex(data.decode("utf-8").strip()) except ValueError as e: print(f"Error: Invalid hex input: {e}", file=sys.stderr) sys.exit(1) @@ -75,59 +80,89 @@ def main(): except Exception as e: print(f"Error: Invalid base64 input: {e}", file=sys.stderr) sys.exit(1) - + # Parse format (None for auto-detection) format_type = None if args.format == "auto" else Format(args.format) - + # Get inference host from argument or environment variable inference_host = args.inference_host or os.getenv("INFERENCE_HOST", "http://localhost:8082") - - # Decode MPLog + + # Create SparkSession for CLI + from pyspark.sql import SparkSession + + spark = SparkSession.builder \ + .appName("inference-logging-client") \ + .master(args.spark_master) \ + .config("spark.ui.showConsoleProgress", "false") \ + .config("spark.driver.memory", "2g") \ + .getOrCreate() + + # Suppress Spark logging for cleaner CLI output + spark.sparkContext.setLogLevel("ERROR") + try: + # Decode MPLog df = decode_mplog( log_data=data, model_proxy_id=args.model_proxy_id, version=args.version, + spark=spark, format_type=format_type, inference_host=inference_host, - decompress=not args.no_decompress + decompress=not args.no_decompress, + ) + + # Format floats before 
output + df = format_dataframe_floats(df) + + # Output + if args.output: + if args.json: + # Write as JSON + df.coalesce(1).write.mode("overwrite").json(args.output) + else: + # Write as CSV + df.coalesce(1).write.mode("overwrite").option("header", "true").csv(args.output) + print(f"Output written to {args.output}") + else: + if args.json: + # Collect and print as JSON + import json + rows = [row.asDict() for row in df.collect()] + print(json.dumps(rows, indent=2, default=str)) + else: + # Show table + df.show(truncate=False) + + # Get metadata for summary + metadata = get_mplog_metadata(data, decompress=not args.no_decompress) + detected_format_name = get_format_name(metadata.format_type) + + # Print summary + print("\n--- Summary ---", file=sys.stderr) + print( + f"Format: {detected_format_name} (from metadata)" + if args.format == "auto" + else f"Format: {args.format}", + file=sys.stderr, + ) + print(f"Version: {metadata.version}", file=sys.stderr) + print( + f"Compression: {'enabled' if metadata.compression_enabled else 'disabled'}", file=sys.stderr + ) + print(f"Rows: {df.count()}", file=sys.stderr) + print(f"Columns: {len(df.columns)}", file=sys.stderr) + col_preview = df.columns[1:5] if len(df.columns) > 1 else [] + print( + f"Features: {', '.join(col_preview)}{'...' 
if len(df.columns) > 5 else ''}", + file=sys.stderr, ) except Exception as e: print(f"Error: {e}", file=sys.stderr) sys.exit(1) - - # Format floats before output - df = format_dataframe_floats(df) - - # Output - if args.output: - if args.json: - df.to_json(args.output, orient="records", indent=2) - else: - df.to_csv(args.output, index=False) - print(f"Output written to {args.output}") - else: - if args.json: - print(df.to_json(orient="records", indent=2)) - else: - # Pretty print for terminal - pd.set_option('display.max_columns', None) - pd.set_option('display.width', None) - pd.set_option('display.max_colwidth', 50) - print(df.to_string(index=False)) - - # Get metadata for summary - metadata = get_mplog_metadata(data, decompress=not args.no_decompress) - detected_format_name = get_format_name(metadata.format_type) - - # Print summary - print(f"\n--- Summary ---", file=sys.stderr) - print(f"Format: {detected_format_name} (from metadata)" if args.format == "auto" else f"Format: {args.format}", file=sys.stderr) - print(f"Version: {metadata.version}", file=sys.stderr) - print(f"Compression: {'enabled' if metadata.compression_enabled else 'disabled'}", file=sys.stderr) - print(f"Rows: {len(df)}", file=sys.stderr) - print(f"Columns: {len(df.columns)}", file=sys.stderr) - print(f"Features: {', '.join(df.columns[1:5])}{'...' 
if len(df.columns) > 5 else ''}", file=sys.stderr) + finally: + # Stop SparkSession + spark.stop() if __name__ == "__main__": diff --git a/py-sdk/inference_logging_client/inference_logging_client/decoder.py b/py-sdk/inference_logging_client/inference_logging_client/decoder.py index ee68bc9e..32e55341 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/decoder.py +++ b/py-sdk/inference_logging_client/inference_logging_client/decoder.py @@ -5,86 +5,90 @@ from typing import Any from .utils import ( - normalize_type, is_sized_type, get_scalar_size, - format_float, SCALAR_TYPE_SIZES, SIZED_TYPES + normalize_type, + is_sized_type, + get_scalar_size, + format_float, + SCALAR_TYPE_SIZES, + SIZED_TYPES, ) class ByteReader: """Helper class to read bytes sequentially.""" - + def __init__(self, data: bytes): self.data = data self.pos = 0 - + def read(self, n: int) -> bytes: if self.pos + n > len(self.data): raise ValueError(f"Not enough bytes: need {n}, have {len(self.data) - self.pos}") - result = self.data[self.pos:self.pos + n] + result = self.data[self.pos : self.pos + n] self.pos += n return result - + def read_uint8(self) -> int: return self.read(1)[0] - + def read_int8(self) -> int: - return struct.unpack(' int: - return struct.unpack(' int: - return struct.unpack(' int: - return struct.unpack(' int: - return struct.unpack(' int: - return struct.unpack(' int: - return struct.unpack(' float: """Read IEEE 754 half-precision (FP16) float.""" bits = self.read_uint16() - + # Extract IEEE 754 FP16 components sign = (bits >> 15) & 0x1 exponent = (bits >> 10) & 0x1F mantissa = bits & 0x3FF - + if exponent == 0: if mantissa == 0: return -0.0 if sign else 0.0 # Subnormal - return ((-1) ** sign) * (2 ** -14) * (mantissa / 1024.0) + return ((-1) ** sign) * (2**-14) * (mantissa / 1024.0) elif exponent == 31: if mantissa == 0: - return float('-inf') if sign else float('inf') - return float('nan') + return float("-inf") if sign else float("inf") + return float("nan") 
else: return ((-1) ** sign) * (2 ** (exponent - 15)) * (1.0 + mantissa / 1024.0) - + def read_float32(self) -> float: - return struct.unpack(' float: - return struct.unpack(' int: return len(self.data) - self.pos - + def has_more(self) -> bool: return self.pos < len(self.data) def read_varint(reader: ByteReader) -> int: """Read a protobuf varint. - + Raises: ValueError: If varint exceeds maximum allowed size (malformed data) """ @@ -95,7 +99,7 @@ def read_varint(reader: ByteReader) -> int: if shift > max_shift: raise ValueError("Malformed varint: exceeds maximum size") byte = reader.read_uint8() - result |= (byte & 0x7f) << shift + result |= (byte & 0x7F) << shift if not (byte & 0x80): break shift += 7 @@ -120,35 +124,35 @@ def skip_field(reader: ByteReader, wire_type: int): def decode_ieee754_fp16(value_bytes: bytes) -> float: """ Decode IEEE 754 half-precision (FP16) to float. - + IEEE 754 FP16 format: - 1 bit sign - 5 bits exponent (bias 15) - 10 bits mantissa - + This is the format used by feature stores for vector data. 
""" if len(value_bytes) != 2: return 0.0 - bits = struct.unpack('> 15) & 0x1 exponent = (bits >> 10) & 0x1F mantissa = bits & 0x3FF - + if exponent == 0: # Subnormal or zero if mantissa == 0: return -0.0 if sign else 0.0 # Subnormal number - result = ((-1) ** sign) * (2 ** -14) * (mantissa / 1024.0) + result = ((-1) ** sign) * (2**-14) * (mantissa / 1024.0) return format_float(result) elif exponent == 31: # Infinity or NaN if mantissa == 0: - return float('-inf') if sign else float('inf') - return float('nan') + return float("-inf") if sign else float("inf") + return float("nan") else: # Normal number result = ((-1) ** sign) * (2 ** (exponent - 15)) * (1.0 + mantissa / 1024.0) @@ -158,27 +162,27 @@ def decode_ieee754_fp16(value_bytes: bytes) -> float: def decode_scalar_value(value_bytes: bytes, feature_type: str) -> Any: """Decode a scalar value from bytes based on feature type.""" normalized = normalize_type(feature_type) - + if len(value_bytes) == 0: return None - + try: if normalized in {"INT8", "I8"}: - return struct.unpack(' Any: return decode_ieee754_fp16(value_bytes) return None elif normalized in {"FP32", "FLOAT32", "F32", "FLOAT"}: - result = struct.unpack(' Any: def decode_binary_vector(value_bytes: bytes, feature_type: str) -> list | None: """ Decode a binary-encoded vector based on element type. - + Binary vectors are packed element bytes in sequence. - + Returns: list: Decoded vector elements (may be empty for zero-length input). None: If the vector type is unsupported/unknown. 
""" normalized = normalize_type(feature_type) - + if len(value_bytes) == 0: return [] - + result = [] - + # Determine element type and size if "FP16" in normalized or "FLOAT16" in normalized: # FP16 vector: 2 bytes per element, IEEE 754 half-precision format elem_size = 2 for i in range(0, len(value_bytes), elem_size): if i + elem_size <= len(value_bytes): - elem_bytes = value_bytes[i:i + elem_size] + elem_bytes = value_bytes[i : i + elem_size] result.append(decode_ieee754_fp16(elem_bytes)) - + elif "FP32" in normalized or "FLOAT32" in normalized: # FP32 vector: 4 bytes per element elem_size = 4 for i in range(0, len(value_bytes), elem_size): if i + elem_size <= len(value_bytes): - result.append(format_float(struct.unpack(' bool: def decode_vector_or_string(value_bytes: bytes, feature_type: str) -> Any: """Decode a vector or string value from bytes.""" normalized = normalize_type(feature_type) - + # Handle STRING type (including empty bytes -> empty string) if normalized in {"STRING", "STR"}: if len(value_bytes) == 0: return "" try: - return value_bytes.decode('utf-8') + return value_bytes.decode("utf-8") except UnicodeDecodeError: return value_bytes.hex() - + if normalized in {"BYTES"}: # BYTES type: first 2 bytes are the length prefix, remaining bytes are the string content if len(value_bytes) < 2: return None # Read 2-byte little-endian length prefix - length = struct.unpack(' empty vector if len(value_bytes) == 0: return [] - + # For vectors, check if it's JSON or binary encoded # JSON vectors start with '[' (0x5b) if is_likely_json(value_bytes): try: - return json.loads(value_bytes.decode('utf-8')) + return json.loads(value_bytes.decode("utf-8")) except (json.JSONDecodeError, UnicodeDecodeError): pass - + # Try binary vector decoding decoded = decode_binary_vector(value_bytes, feature_type) if decoded is not None: return decoded - + # Fallback to hex if binary decoding returned None (unsupported type) return value_bytes.hex() - + # For non-vector sized types, 
try JSON decode first if is_likely_json(value_bytes): try: - return json.loads(value_bytes.decode('utf-8')) + return json.loads(value_bytes.decode("utf-8")) except (json.JSONDecodeError, UnicodeDecodeError): pass - + # Return hex representation for unknown binary data return value_bytes.hex() @@ -382,14 +386,14 @@ def decode_feature_value(value_bytes: bytes, feature_type: str) -> Any: """Decode a feature value based on its type.""" if value_bytes is None: return None - + # For sized types (VECTOR/STRING), delegate even for empty bytes # decode_vector_or_string handles empty bytes appropriately per type if is_sized_type(feature_type): return decode_vector_or_string(value_bytes, feature_type) - + # For non-sized scalar types, empty bytes means absent value if len(value_bytes) == 0: return None - + return decode_scalar_value(value_bytes, feature_type) diff --git a/py-sdk/inference_logging_client/inference_logging_client/exceptions.py b/py-sdk/inference_logging_client/inference_logging_client/exceptions.py index 9f40c2af..682f1206 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/exceptions.py +++ b/py-sdk/inference_logging_client/inference_logging_client/exceptions.py @@ -3,29 +3,35 @@ class InferenceLoggingError(Exception): """Base exception for inference-logging-client errors.""" + pass class SchemaFetchError(InferenceLoggingError): """Raised when fetching schema from inference service fails.""" + pass class SchemaNotFoundError(InferenceLoggingError): """Raised when no features are found in schema response.""" + pass class DecodeError(InferenceLoggingError): """Raised when decoding feature data fails.""" + pass class FormatError(InferenceLoggingError): """Raised when there's an issue with the data format.""" + pass class ProtobufError(InferenceLoggingError): """Raised when parsing protobuf data fails.""" + pass diff --git a/py-sdk/inference_logging_client/inference_logging_client/formats.py 
b/py-sdk/inference_logging_client/inference_logging_client/formats.py index 858a1438..bb662dcd 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/formats.py +++ b/py-sdk/inference_logging_client/inference_logging_client/formats.py @@ -19,7 +19,7 @@ def decode_proto_features(encoded_bytes: bytes, schema: list[FeatureInfo]) -> dict[str, Any]: """ Decode proto-encoded features for a single entity. - + Proto encoding format: - First byte: generated flag (1 = no generated values, 0 = has generated values) - For each feature in schema order: @@ -28,39 +28,39 @@ def decode_proto_features(encoded_bytes: bytes, schema: list[FeatureInfo]) -> di """ if len(encoded_bytes) == 0: return {f.name: None for f in schema} - + reader = ByteReader(encoded_bytes) result = {} - + # Read generated flag (first byte) # The generated flag indicates: 1 = no generated values, 0 = has generated values # Currently unused but reserved for future feature generation tracking if reader.remaining() < 1: return {f.name: None for f in schema} - + _generated_flag = reader.read_uint8() # Prefixed with _ to indicate intentionally unused - + for feature in schema: if not reader.has_more(): result[feature.name] = None continue - + try: if is_sized_type(feature.feature_type): # Read 2-byte size prefix if reader.remaining() < 2: result[feature.name] = None continue - + size = reader.read_uint16() if size == 0: result[feature.name] = None continue - + if reader.remaining() < size: result[feature.name] = None continue - + value_bytes = reader.read(size) result[feature.name] = decode_vector_or_string(value_bytes, feature.feature_type) else: @@ -69,178 +69,184 @@ def decode_proto_features(encoded_bytes: bytes, schema: list[FeatureInfo]) -> di if size is None: result[feature.name] = None continue - + if reader.remaining() < size: result[feature.name] = None continue - + value_bytes = reader.read(size) result[feature.name] = decode_scalar_value(value_bytes, feature.feature_type) except Exception 
as e: result[feature.name] = f"" - + return result -def decode_proto_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple[list[str], list[dict[str, Any]]]: +def decode_proto_format( + mplog_data: bytes, schema: list[FeatureInfo] +) -> tuple[list[str], list[dict[str, Any]]]: """ Decode proto format MPLog. - + Returns: Tuple of (entity_ids, list of decoded feature dicts) """ parsed = parse_mplog_protobuf(mplog_data) - + # Create a copy to avoid mutating the parsed object's entities list entity_ids = list(parsed.entities) - encoded_features_list = getattr(parsed, '_encoded_features', []) - + encoded_features_list = getattr(parsed, "_encoded_features", []) + decoded_rows = [] for i, encoded_bytes in enumerate(encoded_features_list): decoded = decode_proto_features(encoded_bytes, schema) decoded_rows.append(decoded) - + # Ensure entity_ids matches decoded_rows count original_entity_count = len(entity_ids) while len(entity_ids) < len(decoded_rows): entity_ids.append(f"entity_{len(entity_ids)}") - + if original_entity_count != len(decoded_rows): warnings.warn( f"Entity count mismatch: {original_entity_count} entity IDs for {len(decoded_rows)} rows. " f"Generated synthetic IDs for missing entities.", - UserWarning + UserWarning, ) - - return entity_ids[:len(decoded_rows)], decoded_rows + return entity_ids[: len(decoded_rows)], decoded_rows -def decode_arrow_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple[list[str], list[dict[str, Any]]]: + +def decode_arrow_format( + mplog_data: bytes, schema: list[FeatureInfo] +) -> tuple[list[str], list[dict[str, Any]]]: """ Decode Arrow IPC format MPLog. - + Arrow encoding: - MPLog protobuf wrapper containing Arrow IPC bytes in encoded_features - Arrow record has binary columns named by feature index ("0", "1", ...) 
- Each column contains raw feature value bytes - + Returns: Tuple of (entity_ids, list of decoded feature dicts) """ parsed = parse_mplog_protobuf(mplog_data) - + # Create a copy to avoid mutating the parsed object's entities list entity_ids = list(parsed.entities) - encoded_features_list = getattr(parsed, '_encoded_features', []) - + encoded_features_list = getattr(parsed, "_encoded_features", []) + if not encoded_features_list: return [], [] - + # Warn if multiple blobs exist (only first is used) if len(encoded_features_list) > 1: warnings.warn( f"Arrow format contains {len(encoded_features_list)} encoded feature blobs, " f"but only the first will be processed. This may indicate unexpected data.", - UserWarning + UserWarning, ) - + # Arrow format stores all entities in a single IPC blob arrow_bytes = encoded_features_list[0] - + if len(arrow_bytes) == 0: return entity_ids, [] - + try: reader = pa.ipc.open_stream(io.BytesIO(arrow_bytes)) table = reader.read_all() except Exception as e: raise FormatError(f"Failed to read Arrow IPC data: {e}") - + num_rows = table.num_rows decoded_rows = [] - + for row_idx in range(num_rows): row_data = {} for feature in schema: col_name = str(feature.index) - + if col_name not in table.column_names: row_data[feature.name] = None continue - + column = table.column(col_name) - + if column.is_null()[row_idx].as_py(): row_data[feature.name] = None continue - + # Get binary value value_bytes = column[row_idx].as_py() if value_bytes is None or len(value_bytes) == 0: row_data[feature.name] = None else: row_data[feature.name] = decode_feature_value(value_bytes, feature.feature_type) - + decoded_rows.append(row_data) - + # Ensure entity_ids matches decoded_rows count original_entity_count = len(entity_ids) while len(entity_ids) < len(decoded_rows): entity_ids.append(f"entity_{len(entity_ids)}") - + if original_entity_count != len(decoded_rows): warnings.warn( f"Entity count mismatch: {original_entity_count} entity IDs for 
{len(decoded_rows)} rows. " f"Generated synthetic IDs for missing entities.", - UserWarning + UserWarning, ) - - return entity_ids[:len(decoded_rows)], decoded_rows + + return entity_ids[: len(decoded_rows)], decoded_rows -def decode_parquet_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple[list[str], list[dict[str, Any]]]: +def decode_parquet_format( + mplog_data: bytes, schema: list[FeatureInfo] +) -> tuple[list[str], list[dict[str, Any]]]: """ Decode Parquet format MPLog. - + Parquet encoding: - MPLog protobuf wrapper containing Parquet bytes in encoded_features - Parquet file has a Features column (map[int][]byte) - Each row represents an entity - + Returns: Tuple of (entity_ids, list of decoded feature dicts) """ parsed = parse_mplog_protobuf(mplog_data) - + # Create a copy to avoid mutating the parsed object's entities list entity_ids = list(parsed.entities) - encoded_features_list = getattr(parsed, '_encoded_features', []) - + encoded_features_list = getattr(parsed, "_encoded_features", []) + if not encoded_features_list: return [], [] - + # Warn if multiple blobs exist (only first is used) if len(encoded_features_list) > 1: warnings.warn( f"Parquet format contains {len(encoded_features_list)} encoded feature blobs, " f"but only the first will be processed. 
This may indicate unexpected data.", - UserWarning + UserWarning, ) - + # Parquet format stores all entities in a single blob parquet_bytes = encoded_features_list[0] - + if len(parquet_bytes) == 0: return entity_ids, [] - + try: table = pq.read_table(io.BytesIO(parquet_bytes)) except Exception as e: raise FormatError(f"Failed to read Parquet data: {e}") - + num_rows = table.num_rows decoded_rows = [] - + # Parquet schema uses a map column named "Features" if "Features" not in table.column_names: # Fallback: try reading as columnar format (similar to Arrow) @@ -257,18 +263,20 @@ def decode_parquet_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple if value_bytes is None or len(value_bytes) == 0: row_data[feature.name] = None else: - row_data[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + row_data[feature.name] = decode_feature_value( + value_bytes, feature.feature_type + ) else: row_data[feature.name] = None decoded_rows.append(row_data) else: # Features column format features_col = table.column("Features") - + for row_idx in range(num_rows): row_data = {} feature_data = features_col[row_idx].as_py() - + if feature_data is None: for feature in schema: row_data[feature.name] = None @@ -279,7 +287,9 @@ def decode_parquet_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple if value_bytes is None or len(value_bytes) == 0: row_data[feature.name] = None else: - row_data[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + row_data[feature.name] = decode_feature_value( + value_bytes, feature.feature_type + ) elif isinstance(feature_data, list): # List-based format: list of (key, value) tuples or list of bytes if len(feature_data) > 0: @@ -291,16 +301,23 @@ def decode_parquet_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple if value_bytes is None or len(value_bytes) == 0: row_data[feature.name] = None else: - row_data[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + 
row_data[feature.name] = decode_feature_value( + value_bytes, feature.feature_type + ) else: # List of bytes indexed by position for feature in schema: if feature.index < len(feature_data): value_bytes = feature_data[feature.index] - if value_bytes is None or (isinstance(value_bytes, (bytes, bytearray)) and len(value_bytes) == 0): + if value_bytes is None or ( + isinstance(value_bytes, (bytes, bytearray)) + and len(value_bytes) == 0 + ): row_data[feature.name] = None else: - row_data[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + row_data[feature.name] = decode_feature_value( + value_bytes, feature.feature_type + ) else: row_data[feature.name] = None else: @@ -310,19 +327,19 @@ def decode_parquet_format(mplog_data: bytes, schema: list[FeatureInfo]) -> tuple # Unknown format for feature in schema: row_data[feature.name] = None - + decoded_rows.append(row_data) - + # Ensure entity_ids matches decoded_rows count original_entity_count = len(entity_ids) while len(entity_ids) < len(decoded_rows): entity_ids.append(f"entity_{len(entity_ids)}") - + if original_entity_count != len(decoded_rows): warnings.warn( f"Entity count mismatch: {original_entity_count} entity IDs for {len(decoded_rows)} rows. " f"Generated synthetic IDs for missing entities.", - UserWarning + UserWarning, ) - - return entity_ids[:len(decoded_rows)], decoded_rows + + return entity_ids[: len(decoded_rows)], decoded_rows diff --git a/py-sdk/inference_logging_client/inference_logging_client/io.py b/py-sdk/inference_logging_client/inference_logging_client/io.py index 800bb50f..0a88d2c0 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/io.py +++ b/py-sdk/inference_logging_client/inference_logging_client/io.py @@ -30,57 +30,69 @@ def _fetch_schema_with_retry(url: str, max_retries: int = _MAX_RETRIES) -> dict: """Fetch schema from URL with exponential backoff retry. 
- + Args: url: The URL to fetch from max_retries: Maximum number of retry attempts - + Returns: Parsed JSON response as dict - + Raises: SchemaFetchError: If all retries fail """ last_exception = None - + for attempt in range(max_retries): try: - req = urllib.request.Request(url, headers={ - "Content-Type": "application/json", - "User-Agent": "inference-logging-client/0.1.0" - }) + req = urllib.request.Request( + url, + headers={ + "Content-Type": "application/json", + "User-Agent": "inference-logging-client/0.1.0", + }, + ) with urllib.request.urlopen(req, timeout=30) as response: - return json.loads(response.read().decode('utf-8')) + return json.loads(response.read().decode("utf-8")) except urllib.error.HTTPError as e: - error_body = e.read().decode('utf-8') if e.fp else str(e) + error_body = e.read().decode("utf-8") if e.fp else str(e) # Don't retry on client errors (4xx) if 400 <= e.code < 500: raise SchemaFetchError(f"HTTP Error {e.code} from inference service: {error_body}") - last_exception = SchemaFetchError(f"HTTP Error {e.code} from inference service: {error_body}") + last_exception = SchemaFetchError( + f"HTTP Error {e.code} from inference service: {error_body}" + ) except urllib.error.URLError as e: - last_exception = SchemaFetchError(f"URL Error connecting to inference service: {e.reason}") + last_exception = SchemaFetchError( + f"URL Error connecting to inference service: {e.reason}" + ) except Exception as e: last_exception = SchemaFetchError(f"Unexpected error fetching schema: {e}") - + # Exponential backoff before retry if attempt < max_retries - 1: - sleep_time = _RETRY_BACKOFF_BASE * (2 ** attempt) + sleep_time = _RETRY_BACKOFF_BASE * (2**attempt) time.sleep(sleep_time) - + raise last_exception -def get_feature_schema(model_config_id: str, version: int, inference_host: Optional[str] = None, api_path: Optional[str] = None) -> list[FeatureInfo]: +def get_feature_schema( + model_config_id: str, + version: int, + inference_host: Optional[str] = None, 
+ api_path: Optional[str] = None, +) -> list[FeatureInfo]: """Fetch feature schema from inference API with caching. - + Results are cached per (model_config_id, version) combination only. The inference_host and api_path are NOT part of the cache key because schemas are canonical for a given model config and version, regardless of which host serves them. This avoids redundant API calls for the same model proxy and version across different environments. - + Cache is thread-safe and has a maximum size limit with LRU eviction. - + Args: model_config_id: The model proxy config ID version: The schema version @@ -88,7 +100,7 @@ def get_feature_schema(model_config_id: str, version: int, inference_host: Optio Note: Not included in cache key - schemas are cached by (model_config_id, version) only. api_path: API path. If None, reads from INFERENCE_PATH env var or uses default. Note: Not included in cache key - schemas are cached by (model_config_id, version) only. - + Raises: SchemaFetchError: If schema fetch fails after retries SchemaNotFoundError: If no features found in schema @@ -96,49 +108,52 @@ def get_feature_schema(model_config_id: str, version: int, inference_host: Optio # Read from environment variables if not provided if inference_host is None: inference_host = os.getenv("INFERENCE_HOST", "http://localhost:8082") - + if api_path is None: - api_path = os.getenv("INFERENCE_PATH", "/api/v1/inference/mp-config-registry/get_feature_schema") - + api_path = os.getenv( + "INFERENCE_PATH", "/api/v1/inference/mp-config-registry/get_feature_schema" + ) + # Cache key is (model_config_id, version) only - host/path intentionally excluded # because schemas are canonical for a given model config and version cache_key = (model_config_id, version) - + # Thread-safe cache lookup with _schema_cache_lock: if cache_key in _schema_cache: # Move to end (most recently used) _schema_cache.move_to_end(cache_key) return _schema_cache[cache_key] - + base_url = f"{inference_host}{api_path}" - 
params = urllib.parse.urlencode({ - "model_config_id": model_config_id, - "version": str(version) - }) + params = urllib.parse.urlencode({"model_config_id": model_config_id, "version": str(version)}) url = f"{base_url}?{params}" - + # Fetch with retry data = _fetch_schema_with_retry(url) - + features = [] for idx, component in enumerate(data.get("data", [])): - features.append(FeatureInfo( - name=component["feature_name"], - feature_type=component["feature_type"].upper(), - index=idx - )) - + features.append( + FeatureInfo( + name=component["feature_name"], + feature_type=component["feature_type"].upper(), + index=idx, + ) + ) + if not features: - raise SchemaNotFoundError(f"No features found in schema for model_config_id={model_config_id}, version={version}") - + raise SchemaNotFoundError( + f"No features found in schema for model_config_id={model_config_id}, version={version}" + ) + # Thread-safe cache update with LRU eviction with _schema_cache_lock: # Evict oldest entries if at capacity while len(_schema_cache) >= _SCHEMA_CACHE_MAX_SIZE: _schema_cache.popitem(last=False) _schema_cache[cache_key] = features - + return features @@ -152,30 +167,30 @@ def parse_mplog_protobuf(data: bytes) -> DecodedMPLog: """Parse the outer MPLog protobuf message.""" result = DecodedMPLog() encoded_features_list: list[bytes] = [] - + reader = ByteReader(data) - + while reader.has_more(): tag = read_varint(reader) field_number = tag >> 3 wire_type = tag & 0x7 - + if field_number == 1 and wire_type == 2: # user_id length = read_varint(reader) - result.user_id = reader.read(length).decode('utf-8') - + result.user_id = reader.read(length).decode("utf-8") + elif field_number == 2 and wire_type == 2: # tracking_id length = read_varint(reader) - result.tracking_id = reader.read(length).decode('utf-8') - + result.tracking_id = reader.read(length).decode("utf-8") + elif field_number == 3 and wire_type == 2: # mp_config_id length = read_varint(reader) - result.model_proxy_config_id = 
reader.read(length).decode('utf-8') - + result.model_proxy_config_id = reader.read(length).decode("utf-8") + elif field_number == 4 and wire_type == 2: # entities (repeated string) length = read_varint(reader) - result.entities.append(reader.read(length).decode('utf-8')) - + result.entities.append(reader.read(length).decode("utf-8")) + elif field_number == 5 and wire_type == 2: # features (repeated perEntityFeatures) length = read_varint(reader) feature_bytes = reader.read(length) @@ -190,21 +205,23 @@ def parse_mplog_protobuf(data: bytes) -> DecodedMPLog: encoded_features_list.append(per_entity_reader.read(enc_len)) else: skip_field(per_entity_reader, inner_wire) - + elif field_number == 6 and wire_type == 2: # metadata length = read_varint(reader) metadata_bytes = reader.read(length) if len(metadata_bytes) > 0: result.metadata_byte = metadata_bytes[0] - result.compression_enabled, result.version, result.format_type = unpack_metadata_byte(result.metadata_byte) - + result.compression_enabled, result.version, result.format_type = ( + unpack_metadata_byte(result.metadata_byte) + ) + elif field_number == 7 and wire_type == 2: # parent_entity (repeated string) length = read_varint(reader) - result.parent_entity.append(reader.read(length).decode('utf-8')) - + result.parent_entity.append(reader.read(length).decode("utf-8")) + else: skip_field(reader, wire_type) - + # Attach encoded features for later processing result._encoded_features = encoded_features_list return result @@ -213,23 +230,24 @@ def parse_mplog_protobuf(data: bytes) -> DecodedMPLog: def get_mplog_metadata(log_data: bytes, decompress: bool = True) -> DecodedMPLog: """ Extract metadata from MPLog bytes without full decoding. 
- + Args: log_data: The MPLog bytes (possibly compressed) decompress: Whether to attempt zstd decompression - + Returns: DecodedMPLog with metadata fields populated - + Raises: ImportError: If data is zstd-compressed but zstandard is not installed """ working_data = log_data if decompress: # Check for zstd magic number: 0x28 0xB5 0x2F 0xFD - if len(log_data) >= 4 and log_data[:4] == b'\x28\xB5\x2F\xFD': + if len(log_data) >= 4 and log_data[:4] == b"\x28\xb5\x2f\xfd": try: import zstandard as zstd + decompressor = zstd.ZstdDecompressor() working_data = decompressor.decompress(log_data) except ImportError: @@ -237,5 +255,5 @@ def get_mplog_metadata(log_data: bytes, decompress: bool = True) -> DecodedMPLog "Data appears to be zstd-compressed but the 'zstandard' package is not installed. " "Install it with: pip install zstandard" ) - + return parse_mplog_protobuf(working_data) diff --git a/py-sdk/inference_logging_client/inference_logging_client/types.py b/py-sdk/inference_logging_client/inference_logging_client/types.py index 3ab5e078..7ab0c170 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/types.py +++ b/py-sdk/inference_logging_client/inference_logging_client/types.py @@ -6,14 +6,15 @@ class Format(Enum): """Supported log formats.""" + PROTO = "proto" ARROW = "arrow" PARQUET = "parquet" # Format type constants matching Go encoder (bits 6-7 of metadata byte) -FORMAT_TYPE_PROTO = 0 # 00 -FORMAT_TYPE_ARROW = 1 # 01 +FORMAT_TYPE_PROTO = 0 # 00 +FORMAT_TYPE_ARROW = 1 # 01 FORMAT_TYPE_PARQUET = 2 # 10 # Mapping from format type int to Format enum @@ -27,6 +28,7 @@ class Format(Enum): @dataclass class FeatureInfo: """Feature schema information.""" + name: str feature_type: str index: int @@ -35,6 +37,7 @@ class FeatureInfo: @dataclass class DecodedMPLog: """Container for decoded MPLog data.""" + user_id: str = "" tracking_id: str = "" model_proxy_config_id: str = "" diff --git a/py-sdk/inference_logging_client/inference_logging_client/utils.py 
b/py-sdk/inference_logging_client/inference_logging_client/utils.py index caebdf39..ac2fb445 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/utils.py +++ b/py-sdk/inference_logging_client/inference_logging_client/utils.py @@ -4,58 +4,112 @@ # Feature type to size mapping for scalar types SCALAR_TYPE_SIZES = { - "INT8": 1, "I8": 1, - "INT16": 2, "I16": 2, "SHORT": 2, - "INT32": 4, "I32": 4, "INT": 4, - "INT64": 8, "I64": 8, "LONG": 8, - "UINT8": 1, "U8": 1, - "UINT16": 2, "U16": 2, - "UINT32": 4, "U32": 4, - "UINT64": 8, "U64": 8, - "FP8E5M2": 1, "FP8E4M3": 1, - "FP16": 2, "FLOAT16": 2, "F16": 2, - "FP32": 4, "FLOAT32": 4, "F32": 4, "FLOAT": 4, - "FP64": 8, "FLOAT64": 8, "F64": 8, "DOUBLE": 8, - "BOOL": 1, "BOOLEAN": 1, + "INT8": 1, + "I8": 1, + "INT16": 2, + "I16": 2, + "SHORT": 2, + "INT32": 4, + "I32": 4, + "INT": 4, + "INT64": 8, + "I64": 8, + "LONG": 8, + "UINT8": 1, + "U8": 1, + "UINT16": 2, + "U16": 2, + "UINT32": 4, + "U32": 4, + "UINT64": 8, + "U64": 8, + "FP8E5M2": 1, + "FP8E4M3": 1, + "FP16": 2, + "FLOAT16": 2, + "F16": 2, + "FP32": 4, + "FLOAT32": 4, + "F32": 4, + "FLOAT": 4, + "FP64": 8, + "FLOAT64": 8, + "F64": 8, + "DOUBLE": 8, + "BOOL": 1, + "BOOLEAN": 1, } # Types that use 2-byte size prefix SIZED_TYPES = { - "STRING", "STR", + "STRING", + "STR", "BYTES", # Vector types - "FP8E5M2VECTOR", "FP8E4M3VECTOR", - "FP16VECTOR", "FLOAT16VECTOR", - "FP32VECTOR", "FLOAT32VECTOR", - "FP64VECTOR", "FLOAT64VECTOR", - "INT8VECTOR", "INT16VECTOR", "INT32VECTOR", "INT64VECTOR", - "UINT8VECTOR", "UINT16VECTOR", "UINT32VECTOR", "UINT64VECTOR", - "STRINGVECTOR", "BOOLVECTOR", + "FP8E5M2VECTOR", + "FP8E4M3VECTOR", + "FP16VECTOR", + "FLOAT16VECTOR", + "FP32VECTOR", + "FLOAT32VECTOR", + "FP64VECTOR", + "FLOAT64VECTOR", + "INT8VECTOR", + "INT16VECTOR", + "INT32VECTOR", + "INT64VECTOR", + "UINT8VECTOR", + "UINT16VECTOR", + "UINT32VECTOR", + "UINT64VECTOR", + "STRINGVECTOR", + "BOOLVECTOR", # With underscore - "VECTOR_FP8E5M2", "VECTOR_FP8E4M3", - 
"VECTOR_FP16", "VECTOR_FLOAT16", - "VECTOR_FP32", "VECTOR_FLOAT32", - "VECTOR_FP64", "VECTOR_FLOAT64", - "VECTOR_INT8", "VECTOR_INT16", "VECTOR_INT32", "VECTOR_INT64", - "VECTOR_UINT8", "VECTOR_UINT16", "VECTOR_UINT32", "VECTOR_UINT64", - "VECTOR_STRING", "VECTOR_BOOL", + "VECTOR_FP8E5M2", + "VECTOR_FP8E4M3", + "VECTOR_FP16", + "VECTOR_FLOAT16", + "VECTOR_FP32", + "VECTOR_FLOAT32", + "VECTOR_FP64", + "VECTOR_FLOAT64", + "VECTOR_INT8", + "VECTOR_INT16", + "VECTOR_INT32", + "VECTOR_INT64", + "VECTOR_UINT8", + "VECTOR_UINT16", + "VECTOR_UINT32", + "VECTOR_UINT64", + "VECTOR_STRING", + "VECTOR_BOOL", # DataType prefix variants "DATATYPESTRING", "DATATYPEBYTES", - "DATATYPEFP8E5M2VECTOR", "DATATYPEFP8E4M3VECTOR", - "DATATYPEFP16VECTOR", "DATATYPEFP32VECTOR", "DATATYPEFP64VECTOR", - "DATATYPEINT8VECTOR", "DATATYPEINT16VECTOR", "DATATYPEINT32VECTOR", "DATATYPEINT64VECTOR", - "DATATYPEUINT8VECTOR", "DATATYPEUINT16VECTOR", "DATATYPEUINT32VECTOR", "DATATYPEUINT64VECTOR", - "DATATYPESTRINGVECTOR", "DATATYPEBOOLVECTOR", + "DATATYPEFP8E5M2VECTOR", + "DATATYPEFP8E4M3VECTOR", + "DATATYPEFP16VECTOR", + "DATATYPEFP32VECTOR", + "DATATYPEFP64VECTOR", + "DATATYPEINT8VECTOR", + "DATATYPEINT16VECTOR", + "DATATYPEINT32VECTOR", + "DATATYPEINT64VECTOR", + "DATATYPEUINT8VECTOR", + "DATATYPEUINT16VECTOR", + "DATATYPEUINT32VECTOR", + "DATATYPEUINT64VECTOR", + "DATATYPESTRINGVECTOR", + "DATATYPEBOOLVECTOR", } def normalize_type(feature_type: str) -> str: """Normalize feature type string for consistent comparison. - + Args: feature_type: The feature type string to normalize - + Returns: Normalized uppercase string with underscores and DATATYPE prefix removed. Returns empty string if feature_type is None. 
@@ -68,9 +122,11 @@ def normalize_type(feature_type: str) -> str: def is_sized_type(feature_type: str) -> bool: """Check if the feature type requires a 2-byte size prefix.""" normalized = normalize_type(feature_type) - return (normalized in {"STRING", "STR"} or - "VECTOR" in normalized.upper() or - normalized in SIZED_TYPES) + return ( + normalized in {"STRING", "STR"} + or "VECTOR" in normalized.upper() + or normalized in SIZED_TYPES + ) def get_scalar_size(feature_type: str) -> Optional[int]: @@ -82,11 +138,12 @@ def get_scalar_size(feature_type: str) -> Optional[int]: def format_float(value: float) -> float: """ Format float to 6 decimal places without scientific notation. - + Returns the float value formatted to 6 decimal places. For special values (inf, -inf, nan), returns them as-is. For regular floats, rounds to 6 decimals. """ import math + if math.isnan(value) or math.isinf(value): return value # Round to 6 decimal places and convert back to float @@ -96,15 +153,26 @@ def format_float(value: float) -> float: def format_dataframe_floats(df): """ - Format all float columns in DataFrame to 6 decimal places. + Format all float columns in Spark DataFrame to 6 decimal places. This ensures no scientific notation in output. 
+ + Args: + df: A Spark DataFrame + + Returns: + Spark DataFrame with float columns rounded to 6 decimal places """ - import pandas as pd - df = df.copy() - for col in df.columns: - if df[col].dtype == 'float64' or df[col].dtype == 'float32': - df[col] = df[col].apply(lambda x: format_float(x) if pd.notna(x) else x) - return df + from pyspark.sql import functions as spark_funcs + from pyspark.sql.types import DoubleType, FloatType + + result_df = df + for field in df.schema.fields: + if isinstance(field.dataType, (FloatType, DoubleType)): + result_df = result_df.withColumn( + field.name, + spark_funcs.round(spark_funcs.col(field.name), 6) + ) + return result_df def unpack_metadata_byte(metadata_byte: int) -> tuple[bool, int, int]: @@ -124,7 +192,8 @@ def unpack_metadata_byte(metadata_byte: int) -> tuple[bool, int, int]: def get_format_name(format_type: int) -> str: """Get human-readable format name from format type int.""" - from .types import FORMAT_TYPE_PROTO, FORMAT_TYPE_ARROW, FORMAT_TYPE_PARQUET + from .types import FORMAT_TYPE_ARROW, FORMAT_TYPE_PARQUET, FORMAT_TYPE_PROTO + names = { FORMAT_TYPE_PROTO: "proto", FORMAT_TYPE_ARROW: "arrow", diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml index 9b4ceeee..396ad071 100644 --- a/py-sdk/inference_logging_client/pyproject.toml +++ b/py-sdk/inference_logging_client/pyproject.toml @@ -28,7 +28,7 @@ classifiers = [ ] dependencies = [ - "pandas>=1.3.0", + "pyspark>=3.3.0", "pyarrow>=5.0.0", "zstandard>=0.15.0", ] diff --git a/py-sdk/inference_logging_client/readme.md b/py-sdk/inference_logging_client/readme.md index d077a777..b773af05 100644 --- a/py-sdk/inference_logging_client/readme.md +++ b/py-sdk/inference_logging_client/readme.md @@ -1,182 +1,1194 @@ # Inference Logging Client -A Python package for decoding MPLog feature logs from proto, arrow, or parquet format. +A Python SDK for decoding MPLog feature logs from proto, arrow, or parquet format. 
This client enables you to decode binary-encoded feature data from machine learning inference logging pipelines into Spark DataFrames. -## Features +--- -- Decode MPLog feature logs from multiple encoding formats: - - **Proto**: Custom binary encoding with generated flag + sequential features - - **Arrow**: Arrow IPC format with binary columns - - **Parquet**: Parquet format with feature map -- Automatic format detection from metadata -- Support for zstd compression -- Fetch feature schemas from inference API -- Convert decoded logs to pandas DataFrames -- Command-line interface for easy usage +## Table of Contents + +- [Overview](#overview) +- [Installation](#installation) +- [Quick Start](#quick-start) +- [Configuration](#configuration) +- [Core API Reference](#core-api-reference) + - [decode_mplog()](#decode_mplog) + - [decode_mplog_dataframe()](#decode_mplog_dataframe) + - [get_mplog_metadata()](#get_mplog_metadata) + - [get_feature_schema()](#get_feature_schema) + - [clear_schema_cache()](#clear_schema_cache) +- [Data Types](#data-types) + - [Format Enum](#format-enum) + - [FeatureInfo](#featureinfo) + - [DecodedMPLog](#decodedmplog) +- [Supported Feature Types](#supported-feature-types) +- [Encoding Formats Explained](#encoding-formats-explained) +- [Exception Handling](#exception-handling) +- [Command Line Interface](#command-line-interface) +- [Advanced Usage Examples](#advanced-usage-examples) +- [Architecture & Internals](#architecture--internals) +- [Troubleshooting](#troubleshooting) +- [Development](#development) + +--- + +## Overview + +The Inference Logging Client is designed to decode MPLog (Model Proxy Log) feature data that has been encoded for efficient storage and transmission. 
It supports three encoding formats: + +| Format | Description | Use Case | +|--------|-------------|----------| +| **Proto** | Custom binary encoding with generated flag + sequential features | Default, most compact | +| **Arrow** | Arrow IPC format with binary columns | Columnar analytics | +| **Parquet** | Parquet format with feature map | Long-term storage | + +### Key Features + +- **Multi-format support**: Decode Proto, Arrow, and Parquet encoded logs +- **Automatic format detection**: Detects encoding format from metadata byte +- **Zstd compression support**: Automatic decompression of zstd-compressed data +- **Schema fetching**: Retrieves feature schemas from inference API with caching +- **Spark integration**: Returns data as PySpark DataFrames +- **CLI tool**: Command-line interface for quick decoding +- **Thread-safe caching**: LRU cache for schemas with thread-safe access + +--- ## Installation +### From PyPI + ```bash pip install inference-logging-client ``` +### From Source + +```bash +cd py-sdk/inference_logging_client +pip install -e . 
+```
+
+### With Development Dependencies
+
+```bash
+pip install -e ".[dev]"
+```
+
+### Dependencies
+
+| Package | Version | Purpose |
+|---------|---------|---------|
+| pyspark | >=3.3.0 | Spark DataFrame operations |
+| pyarrow | >=5.0.0 | Arrow/Parquet format support |
+| zstandard | >=0.15.0 | Zstd decompression |
+
+---
+
 ## Quick Start
 
-### Python API
+### Basic Decoding from Bytes
+
+```python
+from pyspark.sql import SparkSession
+import inference_logging_client
+
+# Create SparkSession
+spark = SparkSession.builder \
+    .appName("inference-decode") \
+    .getOrCreate()
+
+# Read binary MPLog data
+with open("inference_log.bin", "rb") as f:
+    data = f.read()
+
+# Decode to Spark DataFrame
+df = inference_logging_client.decode_mplog(
+    log_data=data,
+    model_proxy_id="product-ranking-model",
+    version=1,
+    spark=spark
+)
+
+# View the results
+df.show()
+# |entity_id|feature_price|feature_category|embedding_vector|
+# | prod_123|        29.99|               5| [0.1, 0.2, ...]|
+# | prod_456|        49.99|               3| [0.3, 0.4, ...]|
+
+# Stop SparkSession when done
+spark.stop()
+```
+
+### Decoding from a Spark DataFrame
+
+```python
+from pyspark.sql import SparkSession
+import inference_logging_client
+
+# Create SparkSession
+spark = SparkSession.builder \
+    .appName("inference-decode") \
+    .getOrCreate()
+
+# Read parquet file containing MPLog data
+df = spark.read.parquet("inference_logs.parquet")
+
+# Expected columns: features, metadata, mp_config_id, entities, ...
+print(df.columns)
+# ['prism_ingested_at', 'features', 'metadata', 'mp_config_id', 'entities', ...]
+
+# Decode features from each row
+decoded_df = inference_logging_client.decode_mplog_dataframe(df, spark)
+
+decoded_df.show()
+# |entity_id|prism_ingested_at|mp_config_id|feature_1|feature_2|
+# | user_123| 2024-01-15 10:30|    my-model|       42|     3.14|
+# | user_456| 2024-01-15 10:30|    my-model|       17|     2.71|
+
+spark.stop()
+```
+
+---
+
+## Configuration
+
+### Environment Variables
+
+| Variable | Default | Description |
+|----------|---------|-------------|
+| `INFERENCE_HOST` | `http://localhost:8082` | Inference service base URL |
+| `INFERENCE_PATH` | `/api/v1/inference/mp-config-registry/get_feature_schema` | Schema fetch API path |
+
+### Setting Environment Variables
+
+```bash
+export INFERENCE_HOST="https://inference.prod.example.com"
+export INFERENCE_PATH="/api/v1/inference/mp-config-registry/get_feature_schema"
+```
+
+### Programmatic Configuration
 
 ```python
+from pyspark.sql import SparkSession
 import inference_logging_client
 
-# Decode MPLog from bytes
+spark = SparkSession.builder.appName("decode").getOrCreate()
+
+# Pass host directly to functions
+df = inference_logging_client.decode_mplog(
+    log_data=data,
+    model_proxy_id="my-model",
+    version=1,
+    spark=spark,
+    inference_host="https://inference.staging.example.com"
+)
+```
+
+---
+
+## Core API Reference
+
+### decode_mplog()
+
+Main function to decode MPLog bytes to a Spark DataFrame.
+ +```python +def decode_mplog( + log_data: bytes, + model_proxy_id: str, + version: int, + spark: SparkSession, + format_type: Optional[Format] = None, + inference_host: Optional[str] = None, + decompress: bool = True, + schema: Optional[list] = None +) -> pyspark.sql.DataFrame: +``` + +#### Parameters + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `log_data` | `bytes` | Yes | - | The MPLog bytes (possibly zstd compressed) | +| `model_proxy_id` | `str` | Yes | - | The model proxy config ID for schema lookup | +| `version` | `int` | Yes | - | The schema version (0-15) | +| `spark` | `SparkSession` | Yes | - | The SparkSession to use for creating DataFrames | +| `format_type` | `Format` | No | `None` | Encoding format. If None, auto-detects from metadata | +| `inference_host` | `str` | No | `None` | Inference service URL. Falls back to `INFERENCE_HOST` env | +| `decompress` | `bool` | No | `True` | Whether to attempt zstd decompression | +| `schema` | `list` | No | `None` | Pre-fetched schema to skip API call | + +#### Returns + +`pyspark.sql.DataFrame` with: +- First column: `entity_id` - identifier for each entity +- Remaining columns: decoded feature values + +#### Exceptions + +| Exception | When Raised | +|-----------|-------------| +| `ValueError` | Version out of range (0-15) | +| `ImportError` | Data is zstd-compressed but `zstandard` not installed | +| `FormatError` | Unsupported format or parse error | +| `SchemaFetchError` | Failed to fetch schema from API | +| `SchemaNotFoundError` | No features in schema response | + +#### Example: Basic Usage + +```python +from pyspark.sql import SparkSession +import inference_logging_client + +spark = SparkSession.builder.appName("decode").getOrCreate() + with open("log.bin", "rb") as f: data = f.read() df = inference_logging_client.decode_mplog( log_data=data, + model_proxy_id="recommendation-model", + version=2, + spark=spark +) + 
+print(f"Decoded {df.count()} entities with {len(df.columns) - 1} features") +``` + +#### Example: Explicit Format + +```python +from pyspark.sql import SparkSession +from inference_logging_client import decode_mplog, Format + +spark = SparkSession.builder.appName("decode").getOrCreate() + +df = decode_mplog( + log_data=arrow_encoded_data, model_proxy_id="my-model", - version=1 + version=1, + spark=spark, + format_type=Format.ARROW # Skip auto-detection +) +``` + +#### Example: Pre-fetched Schema (Performance Optimization) + +```python +from pyspark.sql import SparkSession +from inference_logging_client import decode_mplog, get_feature_schema + +spark = SparkSession.builder.appName("decode").getOrCreate() + +# Fetch schema once +schema = get_feature_schema("my-model", 1, "https://inference.example.com") + +# Decode multiple logs with same schema +for log_bytes in batch_of_logs: + df = decode_mplog( + log_data=log_bytes, + model_proxy_id="my-model", + version=1, + spark=spark, + schema=schema # Reuse cached schema + ) + process(df) +``` + +--- + +### decode_mplog_dataframe() + +Decode MPLog features from a Spark DataFrame containing encoded feature data. 
+ +```python +def decode_mplog_dataframe( + df: pyspark.sql.DataFrame, + spark: SparkSession, + inference_host: Optional[str] = None, + decompress: bool = True, + features_column: str = "features", + metadata_column: str = "metadata", + mp_config_id_column: str = "mp_config_id" +) -> pyspark.sql.DataFrame: +``` + +#### Parameters + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `df` | `pyspark.sql.DataFrame` | Yes | - | Input Spark DataFrame with MPLog columns | +| `spark` | `SparkSession` | Yes | - | The SparkSession to use | +| `inference_host` | `str` | No | `None` | Inference service URL | +| `decompress` | `bool` | No | `True` | Attempt zstd decompression | +| `features_column` | `str` | No | `"features"` | Column containing encoded features | +| `metadata_column` | `str` | No | `"metadata"` | Column containing metadata byte | +| `mp_config_id_column` | `str` | No | `"mp_config_id"` | Column containing model proxy ID | + +#### Expected Input DataFrame Columns + +| Column | Type | Required | Description | +|--------|------|----------|-------------| +| `features` | `bytes/str` | Yes | Encoded feature bytes (raw, base64, or hex) | +| `metadata` | `int/bytes` | Yes | Metadata byte for version/format detection | +| `mp_config_id` | `str` | Yes | Model proxy config ID | +| `entities` | `list/str` | No | Entity IDs (JSON list or single value) | +| `prism_ingested_at` | `datetime` | No | Preserved in output | +| `prism_extracted_at` | `datetime` | No | Preserved in output | +| `created_at` | `datetime` | No | Preserved in output | +| `parent_entity` | `str/list` | No | Preserved in output | +| `tracking_id` | `str` | No | Preserved in output | +| `user_id` | `str` | No | Preserved in output | +| `year`, `month`, `day`, `hour` | `int` | No | Partition columns, preserved | + +#### Returns + +`pyspark.sql.DataFrame` with: +- `entity_id`: Entity identifier (one row per entity) +- Metadata columns (if 
present in input) +- Decoded feature columns + +#### Example: Processing Parquet Logs + +```python +from pyspark.sql import SparkSession +import inference_logging_client + +spark = SparkSession.builder.appName("decode").getOrCreate() + +# Read from data lake +df = spark.read.parquet("s3://bucket/inference-logs/dt=2024-01-15/") + +# Decode all rows +decoded = inference_logging_client.decode_mplog_dataframe(df, spark) + +# Analyze features +decoded.groupBy('mp_config_id').avg('feature_score').show() +``` + +#### Example: Custom Column Names + +```python +from pyspark.sql import SparkSession +import inference_logging_client + +spark = SparkSession.builder.appName("decode").getOrCreate() + +# Your DataFrame has different column names +df = spark.read.csv("custom_logs.csv", header=True) + +decoded = inference_logging_client.decode_mplog_dataframe( + df, + spark, + features_column="encoded_data", # Custom name + metadata_column="meta", # Custom name + mp_config_id_column="model_id" # Custom name ) +``` + +--- + +### get_mplog_metadata() + +Extract metadata from MPLog bytes without full decoding. Useful for inspecting format and version. 
-print(df.head()) +```python +def get_mplog_metadata( + log_data: bytes, + decompress: bool = True +) -> DecodedMPLog: ``` -### Decode from DataFrame +#### Parameters + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `log_data` | `bytes` | Yes | - | The MPLog bytes | +| `decompress` | `bool` | No | `True` | Attempt zstd decompression | + +#### Returns + +`DecodedMPLog` dataclass with: +- `user_id`: User identifier +- `tracking_id`: Request tracking ID +- `model_proxy_config_id`: Model proxy config ID +- `entities`: List of entity IDs +- `parent_entity`: List of parent entity IDs +- `metadata_byte`: Raw metadata byte +- `compression_enabled`: Whether compression was enabled +- `version`: Schema version (0-15) +- `format_type`: Format type int (0=proto, 1=arrow, 2=parquet) + +#### Example: Inspect Log Before Decoding ```python -import pandas as pd import inference_logging_client -# Read DataFrame with MPLog columns -df = pd.read_parquet("logs.parquet") +with open("unknown_log.bin", "rb") as f: + data = f.read() -# Decode features -decoded_df = inference_logging_client.decode_mplog_dataframe(df) +metadata = inference_logging_client.get_mplog_metadata(data) -print(decoded_df.head()) +print(f"Model: {metadata.model_proxy_config_id}") +print(f"Version: {metadata.version}") +print(f"Format: {inference_logging_client.get_format_name(metadata.format_type)}") +print(f"Compression: {'enabled' if metadata.compression_enabled else 'disabled'}") +print(f"Entities: {len(metadata.entities)}") ``` -### Command Line Interface +--- -```bash -# Decode with auto-detection -inference-logging-client --model-proxy-id my-model --version 1 input.bin +### get_feature_schema() + +Fetch feature schema from the inference API with automatic caching. 
+ +```python +def get_feature_schema( + model_config_id: str, + version: int, + inference_host: Optional[str] = None, + api_path: Optional[str] = None +) -> list[FeatureInfo]: +``` + +#### Parameters + +| Parameter | Type | Required | Default | Description | +|-----------|------|----------|---------|-------------| +| `model_config_id` | `str` | Yes | - | Model proxy config ID | +| `version` | `int` | Yes | - | Schema version | +| `inference_host` | `str` | No | `None` | Inference service URL | +| `api_path` | `str` | No | `None` | API path for schema endpoint | + +#### Returns + +`list[FeatureInfo]`: List of feature definitions with: +- `name`: Feature name +- `feature_type`: Feature data type (e.g., "FP32", "INT64", "FP32VECTOR") +- `index`: Feature index in encoding order -# Specify format explicitly -inference-logging-client --model-proxy-id my-model --version 1 --format proto input.bin +#### Caching Behavior -# Output to CSV -inference-logging-client --model-proxy-id my-model --version 1 input.bin -o output.csv +- Schemas are cached by `(model_config_id, version)` tuple +- Cache is thread-safe (uses threading.Lock) +- Maximum 100 cached schemas (LRU eviction) +- Host/path are NOT part of cache key (schemas are canonical) -# Decode from stdin (base64) -echo "BASE64_DATA" | inference-logging-client --model-proxy-id my-model --version 1 --base64 - +#### Example: Manual Schema Fetch + +```python +from inference_logging_client import get_feature_schema + +schema = get_feature_schema( + model_config_id="product-ranking", + version=3, + inference_host="https://inference.example.com" +) + +for feature in schema: + print(f" {feature.index}: {feature.name} ({feature.feature_type})") ``` -## Configuration +--- -The package uses environment variables for configuration: +### clear_schema_cache() -- `INFERENCE_HOST`: Inference service host URL (default: `http://localhost:8082`) -- `INFERENCE_PATH`: API path for schema fetching (default: 
`/api/v1/inference/mp-config-registry/get_feature_schema`) +Clear the internal schema cache. Useful for testing or when schemas have changed. -## API Reference +```python +def clear_schema_cache() -> None: +``` -### `decode_mplog()` +#### Example -Decode MPLog bytes to a pandas DataFrame. +```python +from inference_logging_client import clear_schema_cache, get_feature_schema -**Parameters:** -- `log_data` (bytes): The MPLog bytes (possibly compressed) -- `model_proxy_id` (str): The model proxy config ID -- `version` (int): The schema version -- `format_type` (Format, optional): The encoding format. If None, auto-detect from metadata. -- `inference_host` (str, optional): The inference service host URL -- `decompress` (bool): Whether to attempt zstd decompression (default: True) -- `schema` (list, optional): Pre-fetched schema (list of FeatureInfo). If provided, skips schema fetch. +# Clear before testing +clear_schema_cache() -**Returns:** -- `pd.DataFrame`: DataFrame with `entity_id` as first column and features as remaining columns +# This will fetch fresh from API +schema = get_feature_schema("my-model", 1) +``` -### `decode_mplog_dataframe()` +--- -Decode MPLog features from a DataFrame with specific column structure. 
+## Data Types -**Parameters:** -- `df` (pd.DataFrame): Input DataFrame with MPLog data columns -- `inference_host` (str, optional): The inference service host URL -- `decompress` (bool): Whether to attempt zstd decompression (default: True) -- `features_column` (str): Name of the column containing encoded features (default: "features") -- `metadata_column` (str): Name of the column containing metadata byte (default: "metadata") -- `mp_config_id_column` (str): Name of the column containing model proxy config ID (default: "mp_config_id") +### Format Enum -**Returns:** -- `pd.DataFrame`: DataFrame with decoded features, one row per entity +```python +from inference_logging_client import Format -### `get_mplog_metadata()` +class Format(Enum): + PROTO = "proto" # Custom binary encoding + ARROW = "arrow" # Arrow IPC format + PARQUET = "parquet" # Parquet format +``` + +### FeatureInfo + +```python +from inference_logging_client import FeatureInfo -Extract metadata from MPLog bytes without full decoding. 
+@dataclass +class FeatureInfo: + name: str # Feature name (e.g., "user_embedding") + feature_type: str # Type string (e.g., "FP32VECTOR") + index: int # Position in encoded data +``` -**Parameters:** -- `log_data` (bytes): The MPLog bytes (possibly compressed) -- `decompress` (bool): Whether to attempt zstd decompression (default: True) +### DecodedMPLog -**Returns:** -- `DecodedMPLog`: Object with metadata fields populated +```python +from inference_logging_client import DecodedMPLog + +@dataclass +class DecodedMPLog: + user_id: str = "" + tracking_id: str = "" + model_proxy_config_id: str = "" + entities: list[str] = field(default_factory=list) + parent_entity: list[str] = field(default_factory=list) + metadata_byte: int = 0 + compression_enabled: bool = False + version: int = 0 + format_type: int = 0 # 0=proto, 1=arrow, 2=parquet +``` + +--- ## Supported Feature Types ### Scalar Types -- Integer: INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64 -- Float: FP16, FP32, FP64, FP8E5M2, FP8E4M3 -- Boolean: BOOL -- String: STRING + +| Type Aliases | Size | Description | +|-------------|------|-------------| +| `INT8`, `I8` | 1 byte | Signed 8-bit integer | +| `INT16`, `I16`, `SHORT` | 2 bytes | Signed 16-bit integer | +| `INT32`, `I32`, `INT` | 4 bytes | Signed 32-bit integer | +| `INT64`, `I64`, `LONG` | 8 bytes | Signed 64-bit integer | +| `UINT8`, `U8` | 1 byte | Unsigned 8-bit integer | +| `UINT16`, `U16` | 2 bytes | Unsigned 16-bit integer | +| `UINT32`, `U32` | 4 bytes | Unsigned 32-bit integer | +| `UINT64`, `U64` | 8 bytes | Unsigned 64-bit integer | +| `FP16`, `FLOAT16`, `F16` | 2 bytes | IEEE 754 half-precision float | +| `FP32`, `FLOAT32`, `F32`, `FLOAT` | 4 bytes | IEEE 754 single-precision float | +| `FP64`, `FLOAT64`, `F64`, `DOUBLE` | 8 bytes | IEEE 754 double-precision float | +| `FP8E5M2`, `FP8E4M3` | 1 byte | 8-bit floating point (raw byte) | +| `BOOL`, `BOOLEAN` | 1 byte | Boolean value | + +### String Types + +| Type | Description | 
+|------|-------------| +| `STRING`, `STR` | UTF-8 encoded string | +| `BYTES` | Binary bytes with 2-byte length prefix | ### Vector Types -- All scalar types can be vectors (e.g., FP32VECTOR, INT64VECTOR) -- Vectors can be binary-encoded or JSON-encoded -## Encoding Formats +All scalar types have vector variants: + +| Type Pattern | Description | +|--------------|-------------| +| `{TYPE}VECTOR` | e.g., `FP32VECTOR`, `INT64VECTOR` | +| `VECTOR_{TYPE}` | e.g., `VECTOR_FP32`, `VECTOR_INT64` | +| `DATATYPE{TYPE}VECTOR` | e.g., `DATATYPEFP32VECTOR` | + +Vectors can be encoded as: +- **Binary**: Packed element bytes (most common for feature stores) +- **JSON**: JSON array string (fallback) + +#### Example: Working with Vectors + +```python +from pyspark.sql import SparkSession +import inference_logging_client + +spark = SparkSession.builder.appName("decode").getOrCreate() + +df = inference_logging_client.decode_mplog(data, "model", 1, spark) + +# Vector columns contain arrays +df.select("entity_id", "user_embedding").show(truncate=False) + +# Access vector elements with Spark SQL functions +from pyspark.sql import functions as F +df.select("entity_id", F.element_at("user_embedding", 1).alias("first_elem")).show() +``` + +--- + +## Encoding Formats Explained ### Proto Format -- First byte: generated flag -- Scalars: fixed size bytes based on type -- Strings/Vectors: 2-byte little-endian size prefix + data bytes + +The default and most compact encoding format. + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Byte 0: Generated Flag (1 = no generated values) │ +├─────────────────────────────────────────────────────────────┤ +│ Feature 0: [fixed bytes OR 2-byte size + data] │ +├─────────────────────────────────────────────────────────────┤ +│ Feature 1: [fixed bytes OR 2-byte size + data] │ +├─────────────────────────────────────────────────────────────┤ +│ ... 
│ +└─────────────────────────────────────────────────────────────┘ +``` + +- **Scalars**: Fixed size based on type (e.g., 4 bytes for FP32) +- **Strings/Vectors**: 2-byte little-endian size prefix + data ### Arrow Format -- Arrow IPC format with binary columns -- Column names are feature indices ("0", "1", ...) -- Each column contains raw feature value bytes + +Uses Arrow IPC (Inter-Process Communication) format. + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Arrow IPC Stream │ +│ ├── Schema: columns "0", "1", "2", ... (binary type) │ +│ └── RecordBatch │ +│ ├── Column "0": [entity0_feature0_bytes, ...] │ +│ ├── Column "1": [entity0_feature1_bytes, ...] │ +│ └── ... │ +└─────────────────────────────────────────────────────────────┘ +``` + +- Column names are feature indices as strings ("0", "1", "2", ...) +- Each cell contains raw binary feature bytes +- All entities in a single IPC blob ### Parquet Format -- Parquet file with Features column (`map[int][]byte`) -- Each row represents an entity -## Metadata Byte Layout +Uses Parquet columnar format. + +``` +┌─────────────────────────────────────────────────────────────┐ +│ Parquet File │ +│ └── Column "Features": map │ +│ ├── Row 0: {0: bytes, 1: bytes, ...} │ +│ ├── Row 1: {0: bytes, 1: bytes, ...} │ +│ └── ... 
│ +└─────────────────────────────────────────────────────────────┘ +``` + +- Features column is a map from feature index to binary bytes +- Each row represents one entity +- Alternative: columnar format with index-named columns (like Arrow) + +### Metadata Byte Layout + +``` +Bit Layout: +┌─────┬─────┬─────┬─────┬─────┬─────┬─────┬─────┐ +│ 7 │ 6 │ 5 │ 4 │ 3 │ 2 │ 1 │ 0 │ +├─────┴─────┼─────┴─────┴─────┴─────┼─────┼─────┤ +│ Format │ Version │ Rsv │Comp │ +│ (2 bits) │ (4 bits) │ │ │ +└───────────┴───────────────────────┴─────┴─────┘ + +Format Type: + 00 = Proto + 01 = Arrow + 10 = Parquet + 11 = Reserved + +Version: 0-15 (4 bits) +Compression: 0 = disabled, 1 = enabled (zstd) +``` + +--- + +## Exception Handling + +### Exception Hierarchy + +``` +InferenceLoggingError (base) +├── SchemaFetchError # API request failed +├── SchemaNotFoundError # No features in response +├── DecodeError # Feature decoding failed +├── FormatError # Invalid format or parse error +└── ProtobufError # Protobuf parsing failed +``` + +### Example: Comprehensive Error Handling + +```python +from pyspark.sql import SparkSession +from inference_logging_client import ( + decode_mplog, + InferenceLoggingError, + SchemaFetchError, + SchemaNotFoundError, + FormatError, +) + +spark = SparkSession.builder.appName("decode").getOrCreate() + +try: + df = decode_mplog(data, "my-model", 1, spark) +except SchemaFetchError as e: + print(f"Failed to fetch schema: {e}") + # Check network, inference service availability +except SchemaNotFoundError as e: + print(f"Schema not found: {e}") + # Check model_proxy_id and version +except FormatError as e: + print(f"Invalid data format: {e}") + # Check data integrity, correct format +except ValueError as e: + print(f"Invalid parameter: {e}") + # Check version range (0-15) +except ImportError as e: + print(f"Missing dependency: {e}") + # Install zstandard if data is compressed +except InferenceLoggingError as e: + print(f"Decode error: {e}") + # Generic fallback 
+finally: + spark.stop() +``` + +--- + +## Command Line Interface + +### Basic Usage + +```bash +# Decode with auto-detection +inference-logging-client --model-proxy-id my-model --version 1 input.bin + +# Short form +inference-logging-client -m my-model -v 1 input.bin +``` + +### CLI Arguments + +| Argument | Short | Required | Default | Description | +|----------|-------|----------|---------|-------------| +| `input` | - | Yes | - | Input file or `-` for stdin | +| `--model-proxy-id` | `-m` | Yes | - | Model proxy config ID | +| `--version` | `-v` | Yes | - | Schema version | +| `--format` | `-f` | No | `auto` | Format: `proto`, `arrow`, `parquet`, `auto` | +| `--inference-host` | - | No | env/localhost | Inference service URL | +| `--hex` | - | No | - | Input is hex-encoded | +| `--base64` | - | No | - | Input is base64-encoded | +| `--no-decompress` | - | No | - | Skip zstd decompression | +| `--output` | `-o` | No | stdout | Output directory (CSV/JSON) | +| `--json` | - | No | - | Output as JSON | +| `--spark-master` | - | No | `local[*]` | Spark master URL | + +### Examples + +```bash +# Output to CSV directory +inference-logging-client -m my-model -v 1 input.bin -o output_dir + +# Output as JSON +inference-logging-client -m my-model -v 1 input.bin --json -The metadata byte encodes: -- Bit 0: compression flag (0=disabled, 1=enabled) -- Bit 1: reserved -- Bits 2-5: version (4 bits, 0-15) -- Bits 6-7: format type (00=proto, 01=arrow, 10=parquet) +# Read from stdin (base64 encoded) +echo "BASE64_DATA" | inference-logging-client -m my-model -v 1 --base64 - + +# Read from stdin (hex encoded) +cat hex_data.txt | inference-logging-client -m my-model -v 1 --hex - + +# Explicit Arrow format +inference-logging-client -m my-model -v 1 --format arrow input.bin + +# Custom inference host +inference-logging-client -m my-model -v 1 \ + --inference-host https://inference.prod.example.com \ + input.bin + +# Custom Spark master +inference-logging-client -m my-model -v 1 \ + 
--spark-master spark://master:7077 \ + input.bin + +# Skip decompression (for pre-decompressed data) +inference-logging-client -m my-model -v 1 --no-decompress input.bin +``` + +### CLI Output Format + +``` ++----------+----------+----------+----------+ +| entity_id| feature_1| feature_2| feature_3| ++----------+----------+----------+----------+ +| entity_0 | 1.5 | 2.5 | 3.5 | +| entity_1 | 4.5 | 5.5 | 6.5 | ++----------+----------+----------+----------+ + +--- Summary --- +Format: proto (from metadata) +Version: 1 +Compression: disabled +Rows: 2 +Columns: 4 +Features: feature_1, feature_2, feature_3... +``` + +--- + +## Advanced Usage Examples + +### Batch Processing with Schema Reuse + +```python +import os +import glob +from pyspark.sql import SparkSession +import inference_logging_client + +spark = SparkSession.builder.appName("batch-decode").getOrCreate() + +# Pre-fetch schema once +schema = inference_logging_client.get_feature_schema( + "batch-model", 2, "https://inference.example.com" +) + +def process_file(filepath): + with open(filepath, "rb") as f: + data = f.read() + + return inference_logging_client.decode_mplog( + log_data=data, + model_proxy_id="batch-model", + version=2, + spark=spark, + schema=schema # Reuse cached schema + ) + +# Process files sequentially +log_files = glob.glob("/data/logs/*.bin") +all_dfs = [process_file(f) for f in log_files] + +# Union all DataFrames +from functools import reduce +all_data = reduce(lambda a, b: a.union(b), all_dfs) +print(f"Total entities: {all_data.count()}") + +spark.stop() +``` + +### Feature Analysis Pipeline + +```python +from pyspark.sql import SparkSession +from pyspark.sql import functions as F +import inference_logging_client + +spark = SparkSession.builder.appName("analysis").getOrCreate() + +# Decode logs +df = inference_logging_client.decode_mplog_dataframe( + spark.read.parquet("logs.parquet"), + spark +) + +# Analyze vector features +embedding_col = "user_embedding" + +# Get embedding statistics 
+df.select( + F.size(F.col(embedding_col)).alias("dimension"), + F.aggregate(F.col(embedding_col), F.lit(0.0), lambda acc, x: acc + x).alias("sum") +).show() + +# Find entities with unusual embeddings (using array functions) +df.withColumn( + "embedding_norm", + F.sqrt(F.aggregate( + F.col(embedding_col), + F.lit(0.0), + lambda acc, x: acc + x * x + )) +).filter(F.col("embedding_norm") > 10.0).show() + +spark.stop() +``` + +### Integration with Feature Store + +```python +from pyspark.sql import SparkSession +import inference_logging_client + +spark = SparkSession.builder.appName("feature-compare").getOrCreate() + +# Decode inference logs +df = inference_logging_client.decode_mplog(data, "ranking-model", 1, spark) + +# Compare with feature store values +from feature_store import FeatureStoreClient + +fs = FeatureStoreClient() + +# Collect for comparison (for small datasets) +for row in df.collect(): + entity_id = row['entity_id'] + + # Get fresh features from store + fresh_features = fs.get_features(entity_id, ["feature_a", "feature_b"]) + + # Compare logged vs fresh + for feature_name in ["feature_a", "feature_b"]: + logged = row[feature_name] + fresh = fresh_features[feature_name] + + if logged != fresh: + print(f"Drift detected for {entity_id}.{feature_name}:") + print(f" Logged: {logged}") + print(f" Fresh: {fresh}") + +spark.stop() +``` + +### Custom Schema Source + +```python +from pyspark.sql import SparkSession +from inference_logging_client import decode_mplog, FeatureInfo + +spark = SparkSession.builder.appName("custom-schema").getOrCreate() + +# Define schema manually (useful for testing or offline processing) +custom_schema = [ + FeatureInfo(name="user_age", feature_type="INT32", index=0), + FeatureInfo(name="user_score", feature_type="FP32", index=1), + FeatureInfo(name="user_embedding", feature_type="FP32VECTOR", index=2), + FeatureInfo(name="user_category", feature_type="STRING", index=3), +] + +df = decode_mplog( + log_data=data, + 
model_proxy_id="my-model", # Not used when schema provided + version=1, # Not used when schema provided + spark=spark, + schema=custom_schema +) + +spark.stop() +``` + +--- + +## Architecture & Internals + +### Module Structure + +``` +inference_logging_client/ +├── __init__.py # Public API exports, decode_mplog(), decode_mplog_dataframe() +├── __main__.py # Module execution entry point +├── cli.py # Command-line interface +├── decoder.py # Core byte decoding, type conversion +├── exceptions.py # Exception classes +├── formats.py # Proto/Arrow/Parquet format decoders +├── io.py # Schema fetching, protobuf parsing +├── types.py # Data type definitions (Format, FeatureInfo, DecodedMPLog) +└── utils.py # Utility functions (type normalization, formatting) +``` + +### Decoding Flow + +``` + ┌──────────────────┐ + │ MPLog Bytes │ + │ (compressed?) │ + └────────┬─────────┘ + │ + ┌────────▼─────────┐ + │ Zstd Decompress │ + │ (if enabled) │ + └────────┬─────────┘ + │ + ┌────────▼─────────┐ + │ Parse Protobuf │ + │ (outer wrapper) │ + └────────┬─────────┘ + │ + ┌──────────────┼──────────────┐ + │ │ │ + ┌───────▼───────┐ ┌────▼────┐ ┌───────▼───────┐ + │ Proto Format │ │ Arrow │ │ Parquet Format│ + │ Decoder │ │ Decoder │ │ Decoder │ + └───────┬───────┘ └────┬────┘ └───────┬───────┘ + │ │ │ + └──────────────┼──────────────┘ + │ + ┌────────▼─────────┐ + │ Feature Schema │◄──── API Fetch + │ (cached) │ + └────────┬─────────┘ + │ + ┌────────▼─────────┐ + │ Decode Features │ + │ (by type) │ + └────────┬─────────┘ + │ + ┌────────▼─────────┐ + │ Spark DataFrame │ + └──────────────────┘ +``` + +### Schema Cache + +```python +# Thread-safe LRU cache with max 100 entries +# Key: (model_config_id, version) +# Value: list[FeatureInfo] + +# Cache is NOT keyed by host/path because schemas are canonical +# Same model+version = same schema regardless of which host serves it +``` + +--- + +## Troubleshooting + +### Common Issues + +#### "No features found in schema" + +``` 
+SchemaNotFoundError: No features found in schema for model_config_id=xxx, version=1 +``` + +**Causes:** +- Incorrect `model_proxy_id` +- Wrong `version` number +- Schema not yet registered + +**Solutions:** +1. Verify model_proxy_id matches exactly +2. Check available versions in inference service +3. Use `get_mplog_metadata()` to see the version in the data + +#### "Version out of valid range" + +``` +ValueError: Version 16 is out of valid range (0-15) +``` + +**Cause:** Version is encoded in 4 bits (0-15 only) + +**Solution:** Check the version number passed to decode functions + +#### "Data appears to be zstd-compressed but zstandard not installed" + +``` +ImportError: Data appears to be zstd-compressed but the 'zstandard' package is not installed. +``` + +**Solution:** +```bash +pip install zstandard +``` + +#### "Failed to read Arrow IPC data" + +**Causes:** +- Corrupted data +- Wrong format specified +- Incomplete data + +**Solutions:** +1. Use `format_type=None` for auto-detection +2. Check data integrity +3. Try `get_mplog_metadata()` to inspect format + +#### Empty DataFrame Returned + +**Causes:** +- No entities in the log +- All features decoded as None +- Schema mismatch + +**Solutions:** +1. Check `get_mplog_metadata()` to verify entity count +2. Verify schema matches data version +3. 
Check for decode warnings + +### Debug Mode + +```python +import warnings +import logging +from pyspark.sql import SparkSession + +# Enable all warnings +warnings.simplefilter("always") + +# Enable debug logging for HTTP requests +logging.basicConfig(level=logging.DEBUG) + +# Create Spark session with verbose logging +spark = SparkSession.builder \ + .appName("debug") \ + .config("spark.driver.extraJavaOptions", "-Dlog4j.logger.org.apache.spark=DEBUG") \ + .getOrCreate() + +# Inspect before decoding +import inference_logging_client + +metadata = inference_logging_client.get_mplog_metadata(data) +print(f"Format: {metadata.format_type}") +print(f"Version: {metadata.version}") +print(f"Entities: {len(metadata.entities)}") +print(f"Model: {metadata.model_proxy_config_id}") +``` + +--- ## Development +### Setup + ```bash -# Install in development mode -pip install -e . +# Clone repository +git clone https://github.com/Meesho/BharatMLStack.git +cd BharatMLStack/py-sdk/inference_logging_client + +# Create virtual environment +python -m venv venv +source venv/bin/activate -# Install with dev dependencies +# Install in editable mode with dev dependencies pip install -e ".[dev]" +``` + +### Running Tests -# Run tests +```bash pytest -# Format code -black src/ +# With coverage +pytest --cov=inference_logging_client --cov-report=html +``` + +### Code Formatting -# Lint code -ruff check src/ +```bash +# Format with black +black inference_logging_client/ + +# Lint with ruff +ruff check inference_logging_client/ ``` +### Building Package + +```bash +python -m build +``` + +--- + ## License MIT License @@ -188,3 +1200,9 @@ MIT License ## Contributing Contributions are welcome! Please feel free to submit a Pull Request. + +1. Fork the repository +2. Create your feature branch (`git checkout -b feature/amazing-feature`) +3. Commit your changes (`git commit -m 'Add amazing feature'`) +4. Push to the branch (`git push origin feature/amazing-feature`) +5. 
Open a Pull Request From a2ca5304c8bee0f6195022a84d0ba4e891a3f988 Mon Sep 17 00:00:00 2001 From: dhruvgupta-meesho Date: Wed, 4 Feb 2026 11:52:48 +0530 Subject: [PATCH 2/4] changed format to base64 --- .../inference_logging_client/__init__.py | 283 ++++++++++-------- .../inference_logging_client/formats.py | 121 ++++++++ 2 files changed, 271 insertions(+), 133 deletions(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/__init__.py b/py-sdk/inference_logging_client/inference_logging_client/__init__.py index 447e6482..89a7319c 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/__init__.py +++ b/py-sdk/inference_logging_client/inference_logging_client/__init__.py @@ -36,7 +36,14 @@ SchemaFetchError, SchemaNotFoundError, ) -from .formats import decode_arrow_format, decode_parquet_format, decode_proto_format +from .formats import ( + decode_arrow_format, + decode_arrow_features, + decode_parquet_format, + decode_parquet_features, + decode_proto_format, + decode_proto_features, +) from .io import clear_schema_cache, get_feature_schema, get_mplog_metadata, parse_mplog_protobuf from .types import FORMAT_TYPE_MAP, DecodedMPLog, FeatureInfo, Format from .utils import format_dataframe_floats, get_format_name, unpack_metadata_byte @@ -277,23 +284,45 @@ def decode_mplog_dataframe( # Key: (mp_config_id, version) only - host/path intentionally excluded as schemas are canonical schema_cache: dict[tuple[str, int], list[FeatureInfo]] = {} + def _extract_metadata_byte(metadata_data) -> int: + """Extract metadata byte from JSON array with base64-encoded string. 
+ + Expected format: JSON array with single base64-encoded string, e.g., '["BQ=="]' + """ + if metadata_data is None: + return 0 + + # Handle JSON string format + if isinstance(metadata_data, str): + try: + parsed = json.loads(metadata_data) + if isinstance(parsed, list) and len(parsed) > 0: + decoded = base64.b64decode(parsed[0]) + if len(decoded) > 0: + return decoded[0] + except (json.JSONDecodeError, ValueError, TypeError): + pass + return 0 + + # Handle already-parsed list format + if isinstance(metadata_data, list) and len(metadata_data) > 0: + first_item = metadata_data[0] + if isinstance(first_item, str): + try: + decoded = base64.b64decode(first_item) + if len(decoded) > 0: + return decoded[0] + except (ValueError, TypeError): + pass + return 0 + + return 0 + # First pass: collect unique (mp_config_id, version) pairs for row in rows: # Extract metadata byte to get version metadata_data = row[metadata_column] - metadata_byte = 0 - if metadata_data is not None: - if isinstance(metadata_data, (int, float)): - metadata_byte = int(metadata_data) - elif isinstance(metadata_data, bytes) and len(metadata_data) > 0: - metadata_byte = metadata_data[0] - elif isinstance(metadata_data, (bytearray, memoryview)) and len(metadata_data) > 0: - metadata_byte = metadata_data[0] - elif isinstance(metadata_data, str): - try: - metadata_byte = int(metadata_data) - except ValueError: - pass + metadata_byte = _extract_metadata_byte(metadata_data) _, version, _ = unpack_metadata_byte(metadata_byte) @@ -318,50 +347,30 @@ def decode_mplog_dataframe( all_decoded_rows = [] + # Metadata columns to preserve + row_metadata_columns = [ + "prism_ingested_at", + "prism_extracted_at", + "created_at", + "mp_config_id", + "parent_entity", + "tracking_id", + "user_id", + "year", + "month", + "day", + "hour", + ] + for idx, row in enumerate(rows): - # Extract features bytes + # Extract features data features_data = row[features_column] if features_data is None: continue - # Convert features to 
bytes (handle base64, hex, or raw bytes) - features_bytes = None - if isinstance(features_data, bytes): - features_bytes = features_data - elif isinstance(features_data, str): - # Try base64 first - try: - features_bytes = base64.b64decode(features_data) - except Exception: - # Try hex - try: - features_bytes = bytes.fromhex(features_data) - except Exception: - # Try UTF-8 encoding - features_bytes = features_data.encode("utf-8") - elif isinstance(features_data, (bytearray, memoryview)): - features_bytes = bytes(features_data) - else: - continue - - if features_bytes is None or len(features_bytes) == 0: - continue - # Extract metadata byte metadata_data = row[metadata_column] - metadata_byte = 0 - if metadata_data is not None: - if isinstance(metadata_data, (int, float)): - metadata_byte = int(metadata_data) - elif isinstance(metadata_data, bytes) and len(metadata_data) > 0: - metadata_byte = metadata_data[0] - elif isinstance(metadata_data, (bytearray, memoryview)) and len(metadata_data) > 0: - metadata_byte = metadata_data[0] - elif isinstance(metadata_data, str): - try: - metadata_byte = int(metadata_data) - except ValueError: - pass + metadata_byte = _extract_metadata_byte(metadata_data) # Extract version from metadata byte _, version, _ = unpack_metadata_byte(metadata_byte) @@ -385,105 +394,113 @@ def decode_mplog_dataframe( cache_key = (mp_config_id, version) cached_schema = schema_cache.get(cache_key) - # Decode this row's features try: - # Attempt decompression if enabled - working_data = features_bytes - if decompress: - working_data = _decompress_zstd(features_bytes) - - # Parse protobuf to get format from metadata - parsed = parse_mplog_protobuf(working_data) - if parsed.format_type in FORMAT_TYPE_MAP: - detected_format = FORMAT_TYPE_MAP[parsed.format_type] + # Parse features JSON (expected format: JSON array of dicts with encoded_features) + if isinstance(features_data, str): + features_list = json.loads(features_data) else: - detected_format = 
Format.PROTO + features_list = features_data + + if not isinstance(features_list, list): + warnings.warn(f"Row {idx}: features is not a list, skipping", UserWarning) + continue + + # Get entities from row + entities_val = None + if "entities" in df_columns: + entities_raw = row["entities"] + if entities_raw is not None: + if isinstance(entities_raw, str): + try: + entities_val = json.loads(entities_raw) + except (json.JSONDecodeError, ValueError): + entities_val = [entities_raw] + elif isinstance(entities_raw, list): + entities_val = entities_raw + else: + entities_val = [entities_raw] # Use cached schema or fetch feature_schema = cached_schema if feature_schema is None: feature_schema = get_feature_schema(mp_config_id, version, inference_host) - # Decode based on format - if detected_format == Format.PROTO: - entity_ids, decoded_feature_rows = decode_proto_format(working_data, feature_schema) - elif detected_format == Format.ARROW: - entity_ids, decoded_feature_rows = decode_arrow_format(working_data, feature_schema) - elif detected_format == Format.PARQUET: - entity_ids, decoded_feature_rows = decode_parquet_format(working_data, feature_schema) + # Determine format type from metadata byte + # unpack_metadata_byte returns (compression_enabled, version, format_type) + _, _, format_type_num = unpack_metadata_byte(metadata_byte) + if format_type_num in FORMAT_TYPE_MAP: + detected_format = FORMAT_TYPE_MAP[format_type_num] else: - raise FormatError(f"Unsupported format: {detected_format}") - - # Add original row metadata to each decoded entity row - if decoded_feature_rows: - # Metadata columns to preserve - metadata_columns = [ - "prism_ingested_at", - "prism_extracted_at", - "created_at", - "mp_config_id", - "parent_entity", - "tracking_id", - "user_id", - "year", - "month", - "day", - "hour", - ] - - # Get entities from row if available - entities_val = None - if "entities" in df_columns: - entities_val = row["entities"] - if entities_val is not None: - if 
isinstance(entities_val, str): - try: - entities_val = json.loads(entities_val) - except (json.JSONDecodeError, ValueError): - entities_val = [entities_val] - elif not isinstance(entities_val, list): - entities_val = [entities_val] - - # Process parent_entity - parent_entity_val = None - if "parent_entity" in df_columns and row["parent_entity"] is not None: - parent_val = row["parent_entity"] - if isinstance(parent_val, str): - try: - parent_val = json.loads(parent_val) - except (json.JSONDecodeError, ValueError): - parent_val = [parent_val] - if isinstance(parent_val, list): - if len(parent_val) == 1: - parent_entity_val = parent_val[0] - elif len(parent_val) > 1: - parent_entity_val = str(parent_val) - else: - parent_entity_val = None + detected_format = Format.PROTO # Default to proto + + # Process parent_entity + parent_entity_val = None + if "parent_entity" in df_columns and row["parent_entity"] is not None: + parent_val = row["parent_entity"] + if isinstance(parent_val, str): + try: + parent_val = json.loads(parent_val) + except (json.JSONDecodeError, ValueError): + parent_val = [parent_val] + if isinstance(parent_val, list): + if len(parent_val) == 1: + parent_entity_val = parent_val[0] + elif len(parent_val) > 1: + parent_entity_val = str(parent_val) else: - parent_entity_val = parent_val + parent_entity_val = None + else: + parent_entity_val = parent_val + + # Process each entity's features + for i, feature_item in enumerate(features_list): + # Get entity_id from entities array or generate synthetic + entity_id = f"entity_{i}" + if entities_val and i < len(entities_val): + entity_id = str(entities_val[i]) + + # Get and decode base64 encoded_features + encoded_features_b64 = feature_item.get("encoded_features", "") + if not encoded_features_b64: + continue + + try: + encoded_bytes = base64.b64decode(encoded_features_b64) + except (ValueError, TypeError): + continue + + if len(encoded_bytes) == 0: + continue + + # Attempt decompression if enabled + 
working_data = encoded_bytes + if decompress: + working_data = _decompress_zstd(encoded_bytes) - for i, (entity_id, feature_row) in enumerate(zip(entity_ids, decoded_feature_rows)): - result_row = {"entity_id": entity_id} - result_row.update(feature_row) + # Decode features based on format type + if detected_format == Format.ARROW: + decoded_features = decode_arrow_features(working_data, feature_schema) + elif detected_format == Format.PARQUET: + decoded_features = decode_parquet_features(working_data, feature_schema) + else: + # Default to proto format + decoded_features = decode_proto_features(working_data, feature_schema) - # Add metadata columns - for col in metadata_columns: - if col in df_columns: - result_row[col] = row[col] + result_row = {"entity_id": entity_id} + result_row.update(decoded_features) - # Override entity_id from entities column if available and matches count - if entities_val and len(entities_val) == len(entity_ids): - result_row["entity_id"] = entities_val[i] + # Add metadata columns + for col in row_metadata_columns: + if col in df_columns: + result_row[col] = row[col] - # Set parent_entity - if parent_entity_val is not None: - result_row["parent_entity"] = parent_entity_val + # Set parent_entity + if parent_entity_val is not None: + result_row["parent_entity"] = parent_entity_val - all_decoded_rows.append(result_row) + all_decoded_rows.append(result_row) except Exception as e: - # Track error but continue processing other rows decode_errors.append((idx, str(e))) warnings.warn(f"Failed to decode row {idx}: {e}", UserWarning) continue diff --git a/py-sdk/inference_logging_client/inference_logging_client/formats.py b/py-sdk/inference_logging_client/inference_logging_client/formats.py index bb662dcd..39cd86d5 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/formats.py +++ b/py-sdk/inference_logging_client/inference_logging_client/formats.py @@ -16,6 +16,127 @@ from .exceptions import FormatError +def 
decode_arrow_features(encoded_bytes: bytes, schema: list[FeatureInfo]) -> dict[str, Any]: + """ + Decode Arrow IPC-encoded features for a single entity. + + Arrow encoding for single entity: + - Arrow IPC stream with a single row + - Columns named by feature index ("0", "1", ...) + - Each column contains raw feature value bytes + """ + if len(encoded_bytes) == 0: + return {f.name: None for f in schema} + + try: + reader = pa.ipc.open_stream(io.BytesIO(encoded_bytes)) + table = reader.read_all() + except Exception as e: + raise FormatError(f"Failed to read Arrow IPC data: {e}") + + if table.num_rows == 0: + return {f.name: None for f in schema} + + result = {} + for feature in schema: + col_name = str(feature.index) + + if col_name not in table.column_names: + result[feature.name] = None + continue + + column = table.column(col_name) + + if column.is_null()[0].as_py(): + result[feature.name] = None + continue + + value_bytes = column[0].as_py() + if value_bytes is None or len(value_bytes) == 0: + result[feature.name] = None + else: + result[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + + return result + + +def decode_parquet_features(encoded_bytes: bytes, schema: list[FeatureInfo]) -> dict[str, Any]: + """ + Decode Parquet-encoded features for a single entity. 
+ + Parquet encoding for single entity: + - Parquet file with a single row + - Features column (map[int][]byte) or columnar format + """ + if len(encoded_bytes) == 0: + return {f.name: None for f in schema} + + try: + table = pq.read_table(io.BytesIO(encoded_bytes)) + except Exception as e: + raise FormatError(f"Failed to read Parquet data: {e}") + + if table.num_rows == 0: + return {f.name: None for f in schema} + + result = {} + + # Check for Features column (map format) + if "Features" in table.column_names: + features_col = table.column("Features") + feature_data = features_col[0].as_py() + + if feature_data is None: + return {f.name: None for f in schema} + + if isinstance(feature_data, dict): + for feature in schema: + value_bytes = feature_data.get(feature.index) + if value_bytes is None or len(value_bytes) == 0: + result[feature.name] = None + else: + result[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + elif isinstance(feature_data, list): + if len(feature_data) > 0 and isinstance(feature_data[0], tuple): + feature_map = {k: v for k, v in feature_data} + for feature in schema: + value_bytes = feature_map.get(feature.index) + if value_bytes is None or len(value_bytes) == 0: + result[feature.name] = None + else: + result[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + else: + for feature in schema: + if feature.index < len(feature_data): + value_bytes = feature_data[feature.index] + if value_bytes is None or (isinstance(value_bytes, (bytes, bytearray)) and len(value_bytes) == 0): + result[feature.name] = None + else: + result[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + else: + result[feature.name] = None + else: + return {f.name: None for f in schema} + else: + # Columnar format + for feature in schema: + col_name = str(feature.index) + if col_name in table.column_names: + column = table.column(col_name) + if column.is_null()[0].as_py(): + result[feature.name] = None + else: + 
value_bytes = column[0].as_py() + if value_bytes is None or len(value_bytes) == 0: + result[feature.name] = None + else: + result[feature.name] = decode_feature_value(value_bytes, feature.feature_type) + else: + result[feature.name] = None + + return result + + def decode_proto_features(encoded_bytes: bytes, schema: list[FeatureInfo]) -> dict[str, Any]: """ Decode proto-encoded features for a single entity. From ebd6311a0be17c4ec28b6b205e2cd1608ef34a54 Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 4 Feb 2026 14:50:53 +0530 Subject: [PATCH 3/4] Update pyproject.toml --- py-sdk/inference_logging_client/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-sdk/inference_logging_client/pyproject.toml b/py-sdk/inference_logging_client/pyproject.toml index 396ad071..344fcc01 100644 --- a/py-sdk/inference_logging_client/pyproject.toml +++ b/py-sdk/inference_logging_client/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "inference-logging-client" -version = "0.1.0" +version = "0.2.0" description = "Decode MPLog feature logs from proto, arrow, or parquet format" readme = "readme.md" requires-python = ">=3.8" From 3b206fd436f243800b48a7f4ef34269558ccc30f Mon Sep 17 00:00:00 2001 From: Dhruv Gupta Date: Wed, 4 Feb 2026 15:12:46 +0530 Subject: [PATCH 4/4] Update io.py coderabbit --- py-sdk/inference_logging_client/inference_logging_client/io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-sdk/inference_logging_client/inference_logging_client/io.py b/py-sdk/inference_logging_client/inference_logging_client/io.py index 0a88d2c0..dea3d5b2 100644 --- a/py-sdk/inference_logging_client/inference_logging_client/io.py +++ b/py-sdk/inference_logging_client/inference_logging_client/io.py @@ -49,7 +49,7 @@ def _fetch_schema_with_retry(url: str, max_retries: int = _MAX_RETRIES) -> dict: url, headers={ "Content-Type": "application/json", - "User-Agent": "inference-logging-client/0.1.0", + 
"User-Agent": "inference-logging-client/0.2.0", }, ) with urllib.request.urlopen(req, timeout=30) as response: