514 changes: 304 additions & 210 deletions py-sdk/inference_logging_client/inference_logging_client/__init__.py

Large diffs are not rendered by default.

171 changes: 103 additions & 68 deletions py-sdk/inference_logging_client/inference_logging_client/cli.py
@@ -1,13 +1,11 @@
"""Command-line interface for inference-logging-client."""

import argparse
import base64
import os
import sys
import base64
import argparse

import pandas as pd

from . import decode_mplog, get_mplog_metadata, get_format_name, format_dataframe_floats
from . import decode_mplog, format_dataframe_floats, get_format_name, get_mplog_metadata
from .types import Format


@@ -29,43 +27,50 @@ def main():

# Output to CSV
inference-logging-client --model-proxy-id my-model --version 1 input.bin -o output.csv
"""
""",
)

parser.add_argument("input", help="Input file containing MPLog bytes (or - for stdin)")
parser.add_argument("--model-proxy-id", "-m", required=True,
help="Model proxy config ID")
parser.add_argument("--version", "-v", type=int, required=True,
help="Schema version")
parser.add_argument("--format", "-f", choices=["proto", "arrow", "parquet", "auto"],
default="auto",
help="Encoding format (default: auto-detect from metadata)")
parser.add_argument("--inference-host", default=None,
help="Inference service host URL (default: reads from INFERENCE_HOST env var or http://localhost:8082)")
parser.add_argument("--hex", action="store_true",
help="Input is hex-encoded string")
parser.add_argument("--base64", action="store_true",
help="Input is base64-encoded string")
parser.add_argument("--no-decompress", action="store_true",
help="Skip automatic zstd decompression")
parser.add_argument("--output", "-o",
help="Output file (CSV format, default: print to stdout)")
parser.add_argument("--json", action="store_true",
help="Output as JSON instead of CSV")

parser.add_argument("--model-proxy-id", "-m", required=True, help="Model proxy config ID")
parser.add_argument("--version", "-v", type=int, required=True, help="Schema version")
parser.add_argument(
"--format",
"-f",
choices=["proto", "arrow", "parquet", "auto"],
default="auto",
help="Encoding format (default: auto-detect from metadata)",
)
parser.add_argument(
"--inference-host",
default=None,
help="Inference service host URL (default: reads from INFERENCE_HOST env var or http://localhost:8082)",
)
parser.add_argument("--hex", action="store_true", help="Input is hex-encoded string")
parser.add_argument("--base64", action="store_true", help="Input is base64-encoded string")
parser.add_argument(
"--no-decompress", action="store_true", help="Skip automatic zstd decompression"
)
parser.add_argument("--output", "-o", help="Output file (CSV format, default: print to stdout)")
parser.add_argument("--json", action="store_true", help="Output as JSON instead of CSV")
parser.add_argument(
"--spark-master",
default="local[*]",
help="Spark master URL (default: local[*])",
)

args = parser.parse_args()

# Read input
if args.input == "-":
data = sys.stdin.buffer.read()
else:
with open(args.input, "rb") as f:
data = f.read()

# Decode input format
if args.hex:
try:
data = bytes.fromhex(data.decode('utf-8').strip())
data = bytes.fromhex(data.decode("utf-8").strip())
except ValueError as e:
print(f"Error: Invalid hex input: {e}", file=sys.stderr)
sys.exit(1)
@@ -75,59 +80,89 @@ def main():
except Exception as e:
print(f"Error: Invalid base64 input: {e}", file=sys.stderr)
sys.exit(1)

# Parse format (None for auto-detection)
format_type = None if args.format == "auto" else Format(args.format)

# Get inference host from argument or environment variable
inference_host = args.inference_host or os.getenv("INFERENCE_HOST", "http://localhost:8082")

# Decode MPLog

# Create SparkSession for CLI
from pyspark.sql import SparkSession

spark = SparkSession.builder \
.appName("inference-logging-client") \
.master(args.spark_master) \
.config("spark.ui.showConsoleProgress", "false") \
.config("spark.driver.memory", "2g") \
.getOrCreate()

# Suppress Spark logging for cleaner CLI output
spark.sparkContext.setLogLevel("ERROR")

try:
# Decode MPLog
df = decode_mplog(
log_data=data,
model_proxy_id=args.model_proxy_id,
version=args.version,
spark=spark,
format_type=format_type,
inference_host=inference_host,
decompress=not args.no_decompress
decompress=not args.no_decompress,
)

# Format floats before output
df = format_dataframe_floats(df)

# Output
if args.output:
if args.json:
# Write as JSON
df.coalesce(1).write.mode("overwrite").json(args.output)
else:
# Write as CSV
df.coalesce(1).write.mode("overwrite").option("header", "true").csv(args.output)
print(f"Output written to {args.output}")
else:
if args.json:
# Collect and print as JSON
import json
rows = [row.asDict() for row in df.collect()]
print(json.dumps(rows, indent=2, default=str))
else:
# Show table
df.show(truncate=False)

# Get metadata for summary
metadata = get_mplog_metadata(data, decompress=not args.no_decompress)
detected_format_name = get_format_name(metadata.format_type)

# Print summary
print("\n--- Summary ---", file=sys.stderr)
print(
f"Format: {detected_format_name} (from metadata)"
if args.format == "auto"
else f"Format: {args.format}",
file=sys.stderr,
)
print(f"Version: {metadata.version}", file=sys.stderr)
print(
f"Compression: {'enabled' if metadata.compression_enabled else 'disabled'}", file=sys.stderr
)
print(f"Rows: {df.count()}", file=sys.stderr)
print(f"Columns: {len(df.columns)}", file=sys.stderr)
col_preview = df.columns[1:5] if len(df.columns) > 1 else []
print(
f"Features: {', '.join(col_preview)}{'...' if len(df.columns) > 5 else ''}",
file=sys.stderr,
)
except Exception as e:
print(f"Error: {e}", file=sys.stderr)
sys.exit(1)

# Format floats before output
df = format_dataframe_floats(df)

# Output
if args.output:
if args.json:
df.to_json(args.output, orient="records", indent=2)
else:
df.to_csv(args.output, index=False)
print(f"Output written to {args.output}")
else:
if args.json:
print(df.to_json(orient="records", indent=2))
else:
# Pretty print for terminal
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', 50)
print(df.to_string(index=False))

# Get metadata for summary
metadata = get_mplog_metadata(data, decompress=not args.no_decompress)
detected_format_name = get_format_name(metadata.format_type)

# Print summary
print(f"\n--- Summary ---", file=sys.stderr)
print(f"Format: {detected_format_name} (from metadata)" if args.format == "auto" else f"Format: {args.format}", file=sys.stderr)
print(f"Version: {metadata.version}", file=sys.stderr)
print(f"Compression: {'enabled' if metadata.compression_enabled else 'disabled'}", file=sys.stderr)
print(f"Rows: {len(df)}", file=sys.stderr)
print(f"Columns: {len(df.columns)}", file=sys.stderr)
print(f"Features: {', '.join(df.columns[1:5])}{'...' if len(df.columns) > 5 else ''}", file=sys.stderr)
finally:
# Stop SparkSession
spark.stop()


if __name__ == "__main__":
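For reviewers who want to exercise the new Spark-backed path outside the CLI, here is a minimal sketch under stated assumptions: it mirrors the SparkSession settings and the decode_mplog call introduced in this diff, while the input path "input.bin", the model-proxy ID "my-model", and version 1 are placeholders lifted from the usage example in the module docstring, not real values.

# Minimal sketch of the new Spark-backed decode flow (not the shipped CLI).
# Assumptions: "input.bin", "my-model", and version 1 are placeholders taken
# from the docstring example above.
from pyspark.sql import SparkSession

from inference_logging_client import decode_mplog, format_dataframe_floats

spark = (
    SparkSession.builder.appName("inference-logging-client")
    .master("local[*]")  # same default as the new --spark-master flag
    .config("spark.ui.showConsoleProgress", "false")
    .config("spark.driver.memory", "2g")
    .getOrCreate()
)
spark.sparkContext.setLogLevel("ERROR")  # quiet Spark logs, as the CLI does

try:
    with open("input.bin", "rb") as f:  # placeholder input file
        data = f.read()
    df = decode_mplog(
        log_data=data,
        model_proxy_id="my-model",  # placeholder ID
        version=1,
        spark=spark,
        format_type=None,  # None => auto-detect from metadata
        inference_host="http://localhost:8082",
        decompress=True,
    )
    format_dataframe_floats(df).show(truncate=False)
finally:
    spark.stop()

One behavioral difference worth keeping in mind when trying --output: Spark's df.coalesce(1).write...csv(path) treats path as a directory and emits part files, whereas the removed pandas to_csv wrote a single file at that path.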