From 490ca2e9a5502f148176f78eeb77fbc2dfcc37e6 Mon Sep 17 00:00:00 2001 From: menonpg Date: Wed, 17 Jun 2026 20:18:12 -0400 Subject: [PATCH] feat(skills): add databricks-memory-prompts skill Add a skill for building AI applications with persistent memory using RAG + RLM (Recursive Language Modeling) architecture from soul.py. Components: - SKILL.md: Main skill with quick start and architecture overview - 1-memory-schema.md: Unity Catalog DDL for decisions, patterns, feedback - 2-vector-search-setup.md: Vector Search index configuration - 3-learning-pipeline.md: Lakeflow pipeline for pattern extraction - 4-mlflow-integration.md: Prompt Registry + Tracing integration The core pattern: 1. Store what you learn (patterns, decisions) in Unity Catalog 2. Retrieve relevant context when building prompts 3. Inject context into prompts before calling the LLM Based on: github.com/menonpg/soul.py Contributed by: ThinkCreate.AI (thinkcreate.ai) --- .../1-memory-schema.md | 205 +++++++++ .../2-vector-search-setup.md | 233 +++++++++++ .../3-learning-pipeline.md | 308 ++++++++++++++ .../4-mlflow-integration.md | 390 ++++++++++++++++++ .../databricks-memory-prompts/SKILL.md | 289 +++++++++++++ 5 files changed, 1425 insertions(+) create mode 100644 databricks-skills/databricks-memory-prompts/1-memory-schema.md create mode 100644 databricks-skills/databricks-memory-prompts/2-vector-search-setup.md create mode 100644 databricks-skills/databricks-memory-prompts/3-learning-pipeline.md create mode 100644 databricks-skills/databricks-memory-prompts/4-mlflow-integration.md create mode 100644 databricks-skills/databricks-memory-prompts/SKILL.md diff --git a/databricks-skills/databricks-memory-prompts/1-memory-schema.md b/databricks-skills/databricks-memory-prompts/1-memory-schema.md new file mode 100644 index 00000000..3a862774 --- /dev/null +++ b/databricks-skills/databricks-memory-prompts/1-memory-schema.md @@ -0,0 +1,205 @@ +# Memory Schema Reference + +Complete DDL for memory tables in 
Unity Catalog.
+
+## Core Tables
+
+### decisions
+
+Explicit choices made during development or operation.
+
+```sql
+CREATE TABLE IF NOT EXISTS my_catalog.memory.decisions (
+    id STRING DEFAULT uuid() NOT NULL,
+    decision STRING NOT NULL COMMENT 'The decision that was made',
+    context STRING COMMENT 'What situation led to this decision',
+    rationale STRING COMMENT 'Why this decision was made',
+    alternatives STRING COMMENT 'Other options considered',
+    created_at TIMESTAMP DEFAULT current_timestamp(),
+    created_by STRING DEFAULT current_user(),
+    confidence DOUBLE DEFAULT 1.0 COMMENT '0.0-1.0 confidence score',
+    tags ARRAY<STRING> COMMENT 'Categorization tags',
+    task_scope STRING COMMENT 'Which task/domain this applies to',
+    expires_at TIMESTAMP COMMENT 'Optional expiration for time-bound decisions',
+    CONSTRAINT decisions_pk PRIMARY KEY (id)
+)
+USING DELTA
+CLUSTER BY (task_scope, created_at)
+COMMENT 'Explicit decisions that inform prompt construction'
+TBLPROPERTIES (
+    'delta.feature.allowColumnDefaults' = 'supported',
+    'delta.autoOptimize.optimizeWrite' = 'true',
+    'delta.autoOptimize.autoCompact' = 'true'
+);
+
+-- Delta tables have no CREATE INDEX; the CLUSTER BY above gives data skipping for task-scoped queries.
+```
+
+### patterns
+
+Learned behaviors extracted from production data or feedback.
+
+```sql
+CREATE TABLE IF NOT EXISTS my_catalog.memory.patterns (
+    id STRING DEFAULT uuid() NOT NULL,
+    pattern STRING NOT NULL COMMENT 'The learned pattern',
+    evidence STRING COMMENT 'Examples or proof of this pattern',
+    source STRING COMMENT 'Where this pattern was learned from',
+    frequency INT DEFAULT 1 COMMENT 'How often this pattern has been observed',
+    first_seen TIMESTAMP DEFAULT current_timestamp(),
+    last_seen TIMESTAMP DEFAULT current_timestamp(),
+    confidence DOUBLE COMMENT 'Confidence score, often based on frequency',
+    tags ARRAY<STRING>,
+    task_scope STRING,
+    superseded_by STRING COMMENT 'ID of pattern that replaces this one',
+    CONSTRAINT patterns_pk PRIMARY KEY (id)
+)
+USING DELTA
+CLUSTER BY (task_scope)
+COMMENT 'Patterns learned from production data and feedback'
+TBLPROPERTIES (
+    'delta.feature.allowColumnDefaults' = 'supported',
+    'delta.autoOptimize.optimizeWrite' = 'true',
+    'delta.autoOptimize.autoCompact' = 'true'
+);
+
+-- Delta tables have no CREATE INDEX; the CLUSTER BY above co-locates rows per task_scope for frequency lookups.
+```
+
+### feedback
+
+User corrections, preferences, and complaints.
+
+```sql
+CREATE TABLE IF NOT EXISTS my_catalog.memory.feedback (
+    id STRING DEFAULT uuid() NOT NULL,
+    run_id STRING COMMENT 'MLflow run ID if available',
+    trace_id STRING COMMENT 'MLflow trace ID if available',
+    input_text STRING COMMENT 'The input that produced the output',
+    input_hash STRING COMMENT 'Hash of input for deduplication',
+    output_text STRING COMMENT 'What the model produced',
+    correction STRING COMMENT 'What the user said it should be',
+    feedback_type STRING NOT NULL COMMENT 'correction, preference, complaint, praise',
+    severity STRING DEFAULT 'medium' COMMENT 'low, medium, high, critical',
+    created_at TIMESTAMP DEFAULT current_timestamp(),
+    created_by STRING,
+    resolved BOOLEAN DEFAULT false,
+    resolved_by_pattern_id STRING COMMENT 'Pattern created to address this',
+    tags ARRAY<STRING>,
+    task_scope STRING,
+    CONSTRAINT feedback_pk PRIMARY KEY (id)
+)
+USING DELTA
+CLUSTER BY (task_scope, created_at)
+COMMENT 'User feedback on model outputs'
+TBLPROPERTIES (
+    'delta.feature.allowColumnDefaults' = 'supported',
+    'delta.autoOptimize.optimizeWrite' = 'true',
+    'delta.autoOptimize.autoCompact' = 'true'
+);
+
+-- Delta tables have no CREATE INDEX; the CLUSTER BY above serves unresolved-feedback scans per task_scope.
+```
+
+## Supporting Tables
+
+### memory_embeddings
+
+Unified embedding table for Vector Search.
+
+```sql
+CREATE TABLE IF NOT EXISTS my_catalog.memory.memory_embeddings (
+    id STRING NOT NULL,
+    memory_type STRING NOT NULL COMMENT 'decision, pattern, or feedback',
+    source_id STRING NOT NULL COMMENT 'ID in source table',
+    content STRING NOT NULL COMMENT 'Text to embed',
+    embedding ARRAY<FLOAT> COMMENT 'Vector embedding',
+    task_scope STRING,
+    confidence DOUBLE,
+    created_at TIMESTAMP DEFAULT current_timestamp(),
+    CONSTRAINT embeddings_pk PRIMARY KEY (id)
+)
+USING DELTA
+COMMENT 'Unified embeddings for semantic memory search'
+TBLPROPERTIES ('delta.feature.allowColumnDefaults' = 'supported');
+
+-- Vector Search index: illustrative pseudo-SQL, not runnable as written. Create the index with the SDK (see 2-vector-search-setup.md).
+/*
+CREATE VECTOR SEARCH INDEX memory_index
+ON my_catalog.memory.memory_embeddings (embedding)
+USING 'databricks-gte-large-en'
+OPTIONS (
+    metric_type = 'cosine',
+    num_dimensions = 1024
+);
+*/
+```
+
+### prompt_memory_links
+
+Track which memories influenced which prompt versions.
+
+```sql
+CREATE TABLE IF NOT EXISTS my_catalog.memory.prompt_memory_links (
+    id STRING DEFAULT uuid() NOT NULL,
+    prompt_name STRING NOT NULL COMMENT 'MLflow prompt registry name',
+    prompt_version INT NOT NULL,
+    memory_type STRING NOT NULL,
+    memory_id STRING NOT NULL,
+    influence_score DOUBLE COMMENT 'How much this memory influenced the prompt',
+    created_at TIMESTAMP DEFAULT current_timestamp(),
+    CONSTRAINT links_pk PRIMARY KEY (id)
+)
+USING DELTA
+COMMENT 'Links between prompts and the memories that shaped them'
+TBLPROPERTIES ('delta.feature.allowColumnDefaults' = 'supported');
+```
+
+## Retention Policies
+
+```sql
+-- Clean up old, low-confidence patterns that have already been superseded (run weekly)
+DELETE FROM my_catalog.memory.patterns
+WHERE confidence < 0.3
+  AND last_seen < current_timestamp() - INTERVAL 90 DAYS
+  AND superseded_by IS NOT NULL;
+
+-- Archive resolved feedback older than 1 year (move to an archive table, not shown here)
+
+-- Decay confidence for stale patterns
+UPDATE my_catalog.memory.patterns
+SET confidence = confidence * 0.95
+WHERE last_seen < current_timestamp() - INTERVAL 30 DAYS
+  AND confidence > 0.1;
+```
+
+## Migration 
Script + +```python +def create_memory_schema(catalog: str, schema: str): + """Create all memory tables in the specified location.""" + from pyspark.sql import SparkSession + spark = SparkSession.builder.getOrCreate() + + # Create schema if not exists + spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}") + + # Read and execute DDL + ddl_statements = [ + # decisions table + f"""CREATE TABLE IF NOT EXISTS {catalog}.{schema}.decisions ...""", + # patterns table + f"""CREATE TABLE IF NOT EXISTS {catalog}.{schema}.patterns ...""", + # feedback table + f"""CREATE TABLE IF NOT EXISTS {catalog}.{schema}.feedback ...""", + # embeddings table + f"""CREATE TABLE IF NOT EXISTS {catalog}.{schema}.memory_embeddings ...""", + # links table + f"""CREATE TABLE IF NOT EXISTS {catalog}.{schema}.prompt_memory_links ...""", + ] + + for ddl in ddl_statements: + spark.sql(ddl) + + print(f"Memory schema created at {catalog}.{schema}") +``` diff --git a/databricks-skills/databricks-memory-prompts/2-vector-search-setup.md b/databricks-skills/databricks-memory-prompts/2-vector-search-setup.md new file mode 100644 index 00000000..966bd6b2 --- /dev/null +++ b/databricks-skills/databricks-memory-prompts/2-vector-search-setup.md @@ -0,0 +1,233 @@ +# Vector Search Setup for Memory + +Configure Databricks Vector Search to enable semantic retrieval over memory tables. + +## Overview + +Memory retrieval works best with semantic search — finding memories by meaning, not just keywords. This guide shows how to set up Vector Search over the unified `memory_embeddings` table. 
+
+## Prerequisites
+
+- Unity Catalog enabled workspace
+- Vector Search endpoint (or create one)
+- Memory tables created (see [1-memory-schema.md](1-memory-schema.md))
+
+## Step 1: Create Vector Search Endpoint
+
+```python
+from databricks.vector_search.client import VectorSearchClient
+
+vsc = VectorSearchClient()
+
+# Create endpoint if it doesn't exist
+try:
+    vsc.create_endpoint(
+        name="memory_search_endpoint",
+        endpoint_type="STANDARD"  # the other valid type is "STORAGE_OPTIMIZED"
+    )
+except Exception as e:
+    if "already exists" not in str(e):
+        raise
+```
+
+## Step 2: Create Sync Pipeline
+
+The embeddings table needs to sync from source tables. Use a Lakeflow pipeline:
+
+```python
+# In a Databricks notebook with DLT enabled
+
+import dlt
+from pyspark.sql.functions import col, lit, concat_ws, coalesce
+
+@dlt.table(
+    name="memory_embeddings_staging",
+    comment="Staging table for memory embeddings"
+)
+def memory_embeddings_staging():
+    # Combine all memory sources
+    decisions = (
+        spark.table("my_catalog.memory.decisions")
+        .select(
+            col("id").alias("source_id"),
+            lit("decision").alias("memory_type"),
+            concat_ws(" | ",
+                col("decision"),
+                coalesce(col("context"), lit("")),
+                coalesce(col("rationale"), lit(""))
+            ).alias("content"),
+            col("task_scope"),
+            col("confidence"),
+            col("created_at")
+        )
+    )
+
+    patterns = (
+        spark.table("my_catalog.memory.patterns")
+        .select(
+            col("id").alias("source_id"),
+            lit("pattern").alias("memory_type"),
+            concat_ws(" | ",
+                col("pattern"),
+                coalesce(col("evidence"), lit(""))
+            ).alias("content"),
+            col("task_scope"),
+            col("confidence"),
+            col("last_seen").alias("created_at")
+        )
+    )
+
+    feedback = (
+        spark.table("my_catalog.memory.feedback")
+        .filter(col("resolved") == False)  # Only unresolved feedback
+        .select(
+            col("id").alias("source_id"),
+            lit("feedback").alias("memory_type"),
+            concat_ws(" | ",
+                col("feedback_type"),
+                coalesce(col("correction"), lit("")),
+                coalesce(col("output_text"), lit(""))
+            ).alias("content"),
+            col("task_scope"),
+            lit(0.5).alias("confidence"),  # Default confidence for feedback
+            col("created_at")
+        )
+    )
+
+    # Match columns by name so a reordered select cannot silently misalign the union
+    return decisions.unionByName(patterns).unionByName(feedback)
+```
+
+## Step 3: Create Vector Search Index
+
+```python
+from databricks.vector_search.client import VectorSearchClient
+
+vsc = VectorSearchClient()
+
+# Create the index with managed embeddings
+index = vsc.create_delta_sync_index(
+    endpoint_name="memory_search_endpoint",
+    source_table_name="my_catalog.memory.memory_embeddings_staging",
+    index_name="my_catalog.memory.memory_index",
+    primary_key="source_id",
+    pipeline_type="TRIGGERED",  # or "CONTINUOUS" for real-time
+    embedding_source_column="content",
+    embedding_model_endpoint_name="databricks-gte-large-en"
+)
+
+print(f"Index created: {index.name}")
+```
+
+## Step 4: Query the Index
+
+```python
+def search_memory(
+    query: str,
+    task_scope: str = None,
+    memory_types: list = None,
+    min_confidence: float = 0.0,
+    k: int = 10
+) -> list:
+    """
+    Search memory for relevant context.
+
+    Args:
+        query: The search query
+        task_scope: Filter to specific task (optional)
+        memory_types: Filter to specific types (optional)
+        min_confidence: Minimum confidence threshold
+        k: Number of results to return
+
+    Returns:
+        List of memory items with scores
+    """
+    vsc = VectorSearchClient()
+    index = vsc.get_index("memory_search_endpoint", "my_catalog.memory.memory_index")
+
+    # Build filters using the standard-endpoint dict syntax:
+    # a list value means IN, and comparison operators are part of the key.
+    filters = {}
+    if task_scope:
+        filters["task_scope"] = task_scope
+    if memory_types:
+        filters["memory_type"] = list(memory_types)
+    if min_confidence > 0:
+        filters["confidence >="] = min_confidence
+
+    # Execute search
+    results = index.similarity_search(
+        query_text=query,
+        columns=["source_id", "memory_type", "content", "confidence", "task_scope"],
+        filters=filters if filters else None,
+        num_results=k
+    )
+
+    # Parse results (the similarity score is appended after the requested columns)
+    memories = []
+    for row in (results.get("result") or {}).get("data_array") or []:
+        memories.append({
+            "id": row[0],
+            "type": row[1],
+            "content": row[2],
+            "confidence": row[3],
+            "task_scope": row[4],
+            "score": row[5] if len(row) > 5 else None
+        })
+
+    return memories
+```
+
+## Step 5: Sync Schedule
+
+For production, set up a scheduled sync:
+
+```python
+# Trigger sync manually
+index.sync()
+
+# Or schedule a job that calls index.sync()
+# A TRIGGERED index only updates when sync() runs; use CONTINUOUS to follow source changes
+```
+
+## Index Management
+
+### Check Sync Status
+
+```python
+status = index.describe()
+print(f"Status: {status['status']['ready']}")
+print(f"Indexed rows: {status['status'].get('indexed_row_count', 'N/A')}")
+```
+
+### Force Resync
+
+```python
+# If embeddings get out of sync
+index.sync()
+```
+
+### Delete and Recreate
+
+```python
+# For schema changes
+vsc.delete_index("memory_search_endpoint", "my_catalog.memory.memory_index")
+# Then recreate with Step 3
+```
+
+## Performance Tuning
+
+| Setting | Development | Production |
+|---------|-------------|------------|
+| Endpoint type | STANDARD | STANDARD (or STORAGE_OPTIMIZED at very large scale) |
+| Pipeline type | 
TRIGGERED | CONTINUOUS | +| num_results | 5-10 | 3-5 (latency-sensitive) | +| Sync frequency | On-demand | Real-time | + +## Troubleshooting + +| Issue | Solution | +|-------|----------| +| Index not ready | Wait for sync to complete; check `index.describe()` | +| Empty results | Verify source tables have data; check filter syntax | +| Slow queries | Reduce `num_results`; use PERFORMANCE endpoint | +| Stale results | Trigger manual sync with `index.sync()` | +| Embedding errors | Ensure `content` column has valid text (no nulls) | diff --git a/databricks-skills/databricks-memory-prompts/3-learning-pipeline.md b/databricks-skills/databricks-memory-prompts/3-learning-pipeline.md new file mode 100644 index 00000000..14d01ee0 --- /dev/null +++ b/databricks-skills/databricks-memory-prompts/3-learning-pipeline.md @@ -0,0 +1,308 @@ +# Continuous Learning Pipeline + +A Lakeflow Declarative Pipeline that extracts patterns from feedback and maintains memory health. + +## Overview + +This pipeline runs on a schedule (daily or weekly) to: +1. Extract patterns from accumulated feedback +2. Update confidence scores based on usage +3. Decay stale memories +4. Generate memory health reports + +## Pipeline Definition + +```yaml +# learning_pipeline.yml +name: memory_learning_pipeline +catalog: ${catalog} +schema: ${schema} + +clusters: + - label: default + num_workers: 2 + spark_version: "15.4.x-scala2.12" + +libraries: + - notebook: + path: /Workspace/memory/learning_pipeline + +schedule: + quartz_cron_expression: "0 0 2 * * ?" 
# Daily at 2 AM + timezone_id: "UTC" +``` + +## Pipeline Notebook + +```python +# learning_pipeline notebook + +import dlt +from pyspark.sql.functions import ( + col, expr, count, avg, max as max_, + current_timestamp, lit, when, concat_ws +) + +# Configuration +PATTERN_MIN_FREQUENCY = 3 # Minimum occurrences to extract pattern +CONFIDENCE_DECAY_RATE = 0.95 # Weekly decay multiplier +STALE_THRESHOLD_DAYS = 30 + +@dlt.table( + name="feedback_aggregates", + comment="Aggregated feedback ready for pattern extraction" +) +def feedback_aggregates(): + """Group similar feedback for pattern extraction.""" + return ( + spark.table("my_catalog.memory.feedback") + .filter(col("resolved") == False) + .filter(col("created_at") > expr("current_timestamp() - INTERVAL 30 DAYS")) + .groupBy("task_scope", "feedback_type") + .agg( + count("*").alias("frequency"), + concat_ws(" || ", expr("collect_list(correction)")).alias("corrections"), + concat_ws(" || ", expr("collect_list(output_text)")).alias("outputs"), + max_("created_at").alias("latest") + ) + .filter(col("frequency") >= PATTERN_MIN_FREQUENCY) + ) + +@dlt.table( + name="extracted_patterns", + comment="Patterns extracted from feedback using LLM" +) +def extracted_patterns(): + """Use LLM to extract patterns from aggregated feedback.""" + return ( + dlt.read("feedback_aggregates") + .withColumn("pattern", expr(f""" + ai_query( + 'databricks-meta-llama-3-3-70b-instruct', + concat( + 'You are analyzing user feedback to extract a reusable pattern. ', + 'Task scope: ', task_scope, '. ', + 'Feedback type: ', feedback_type, '. ', + 'Number of occurrences: ', frequency, '. ', + 'Sample corrections: ', substr(corrections, 1, 500), '. ', + 'Extract ONE concise pattern that would prevent this feedback. ', + 'Format: A single sentence describing what to do or avoid. ', + 'Do not include explanations or preamble.' 
+ ) + ) + """)) + .select( + expr("uuid()").alias("id"), + col("pattern"), + col("corrections").alias("evidence"), + lit("feedback_extraction").alias("source"), + col("frequency"), + col("task_scope"), + expr("0.5 + (0.5 * least(frequency / 10.0, 1.0))").alias("confidence") + ) + ) + +@dlt.table( + name="pattern_updates", + comment="Updates to apply to patterns table" +) +def pattern_updates(): + """ + Merge extracted patterns with existing ones. + - New patterns: insert + - Existing similar patterns: increment frequency + """ + extracted = dlt.read("extracted_patterns") + existing = spark.table("my_catalog.memory.patterns") + + # Find similar patterns using semantic similarity + return ( + extracted.alias("new") + .join( + existing.alias("old"), + expr("ai_similarity(new.pattern, old.pattern) > 0.85"), + "left" + ) + .select( + when(col("old.id").isNotNull(), col("old.id")) + .otherwise(col("new.id")).alias("id"), + when(col("old.id").isNotNull(), col("old.pattern")) + .otherwise(col("new.pattern")).alias("pattern"), + col("new.evidence"), + col("new.source"), + when(col("old.id").isNotNull(), col("old.frequency") + col("new.frequency")) + .otherwise(col("new.frequency")).alias("frequency"), + when(col("old.id").isNotNull(), col("old.first_seen")) + .otherwise(current_timestamp()).alias("first_seen"), + current_timestamp().alias("last_seen"), + when(col("old.id").isNotNull(), + expr("greatest(old.confidence, new.confidence)")) + .otherwise(col("new.confidence")).alias("confidence"), + col("new.task_scope"), + lit(None).alias("superseded_by"), + col("old.id").isNotNull().alias("is_update") + ) + ) + +@dlt.table( + name="confidence_decay", + comment="Patterns with decayed confidence" +) +def confidence_decay(): + """Apply confidence decay to stale patterns.""" + return ( + spark.table("my_catalog.memory.patterns") + .filter(col("last_seen") < expr(f"current_timestamp() - INTERVAL {STALE_THRESHOLD_DAYS} DAYS")) + .filter(col("confidence") > 0.1) + .select( + 
col("id"), + (col("confidence") * CONFIDENCE_DECAY_RATE).alias("new_confidence") + ) + ) + +@dlt.table( + name="memory_health_report", + comment="Daily memory health metrics" +) +def memory_health_report(): + """Generate health metrics for monitoring.""" + decisions = spark.table("my_catalog.memory.decisions") + patterns = spark.table("my_catalog.memory.patterns") + feedback = spark.table("my_catalog.memory.feedback") + + return spark.sql(f""" + SELECT + current_timestamp() AS report_time, + + -- Decisions metrics + (SELECT COUNT(*) FROM my_catalog.memory.decisions) AS total_decisions, + (SELECT COUNT(*) FROM my_catalog.memory.decisions + WHERE created_at > current_timestamp() - INTERVAL 7 DAYS) AS new_decisions_7d, + (SELECT AVG(confidence) FROM my_catalog.memory.decisions) AS avg_decision_confidence, + + -- Patterns metrics + (SELECT COUNT(*) FROM my_catalog.memory.patterns) AS total_patterns, + (SELECT COUNT(*) FROM my_catalog.memory.patterns + WHERE last_seen > current_timestamp() - INTERVAL 7 DAYS) AS active_patterns_7d, + (SELECT AVG(confidence) FROM my_catalog.memory.patterns) AS avg_pattern_confidence, + (SELECT AVG(frequency) FROM my_catalog.memory.patterns) AS avg_pattern_frequency, + + -- Feedback metrics + (SELECT COUNT(*) FROM my_catalog.memory.feedback) AS total_feedback, + (SELECT COUNT(*) FROM my_catalog.memory.feedback WHERE resolved = false) AS unresolved_feedback, + (SELECT COUNT(*) FROM my_catalog.memory.feedback + WHERE created_at > current_timestamp() - INTERVAL 7 DAYS) AS new_feedback_7d, + + -- Health indicators + (SELECT COUNT(*) FROM my_catalog.memory.patterns + WHERE confidence < 0.3) AS low_confidence_patterns, + (SELECT COUNT(*) FROM my_catalog.memory.patterns + WHERE last_seen < current_timestamp() - INTERVAL 90 DAYS) AS stale_patterns + """) +``` + +## Apply Updates + +After the pipeline runs, apply the changes: + +```python +# Run as a separate job after pipeline completes + +from pyspark.sql import SparkSession +spark = 
SparkSession.builder.getOrCreate() + +# Insert new patterns, update existing ones +spark.sql(""" + MERGE INTO my_catalog.memory.patterns AS target + USING my_catalog.memory.pattern_updates AS source + ON target.id = source.id + WHEN MATCHED AND source.is_update = true THEN UPDATE SET + frequency = source.frequency, + last_seen = source.last_seen, + confidence = source.confidence, + evidence = concat(target.evidence, ' || ', source.evidence) + WHEN NOT MATCHED THEN INSERT * +""") + +# Apply confidence decay +spark.sql(""" + MERGE INTO my_catalog.memory.patterns AS target + USING my_catalog.memory.confidence_decay AS source + ON target.id = source.id + WHEN MATCHED THEN UPDATE SET + confidence = source.new_confidence +""") + +# Mark feedback as resolved where patterns were created +spark.sql(""" + UPDATE my_catalog.memory.feedback + SET resolved = true, + resolved_by_pattern_id = ( + SELECT p.id FROM my_catalog.memory.patterns p + WHERE ai_similarity(feedback.correction, p.pattern) > 0.8 + ORDER BY p.last_seen DESC + LIMIT 1 + ) + WHERE resolved = false + AND EXISTS ( + SELECT 1 FROM my_catalog.memory.patterns p + WHERE ai_similarity(feedback.correction, p.pattern) > 0.8 + ) +""") + +print("Memory updates applied successfully") +``` + +## Monitoring + +### Alerts + +Set up alerts on the health report: + +```python +# In a monitoring notebook +health = spark.table("my_catalog.memory.memory_health_report").orderBy(col("report_time").desc()).first() + +alerts = [] +if health.unresolved_feedback > 50: + alerts.append(f"High unresolved feedback: {health.unresolved_feedback}") +if health.low_confidence_patterns > 20: + alerts.append(f"Many low-confidence patterns: {health.low_confidence_patterns}") +if health.stale_patterns > 30: + alerts.append(f"Many stale patterns: {health.stale_patterns}") + +if alerts: + # Send to Slack/email/etc + notify(alerts) +``` + +### Dashboard + +Create a SQL dashboard with: + +```sql +-- Memory growth over time +SELECT DATE(report_time) AS 
date, + total_decisions, + total_patterns, + total_feedback +FROM my_catalog.memory.memory_health_report +ORDER BY date; + +-- Pattern confidence distribution +SELECT + CASE + WHEN confidence >= 0.8 THEN 'High (0.8+)' + WHEN confidence >= 0.5 THEN 'Medium (0.5-0.8)' + ELSE 'Low (<0.5)' + END AS confidence_band, + COUNT(*) AS pattern_count +FROM my_catalog.memory.patterns +GROUP BY 1; + +-- Top patterns by usage +SELECT pattern, frequency, confidence, last_seen +FROM my_catalog.memory.patterns +ORDER BY frequency DESC +LIMIT 20; +``` diff --git a/databricks-skills/databricks-memory-prompts/4-mlflow-integration.md b/databricks-skills/databricks-memory-prompts/4-mlflow-integration.md new file mode 100644 index 00000000..3916acc0 --- /dev/null +++ b/databricks-skills/databricks-memory-prompts/4-mlflow-integration.md @@ -0,0 +1,390 @@ +# MLflow Integration + +Deep integration between memory-aware prompts and MLflow Prompt Registry + Tracing. + +## Overview + +This guide shows how to: +1. Link prompts to the memories that shaped them +2. Log memory context in MLflow traces +3. Use experiments to compare memory variants +4. Track prompt evolution over time + +## Prompt Registry Integration + +### Register with Memory Context + +```python +import mlflow +from typing import Optional +import json + +def register_memory_prompt( + name: str, + template: str, + task_description: str, + enhancer: "MemoryPromptEnhancer", + commit_message: str, + alias: Optional[str] = None +) -> "mlflow.genai.Prompt": + """ + Register a prompt with full memory context tracking. + + This creates: + 1. A new prompt version in the registry + 2. An MLflow run logging the memory context + 3. 
Links in the prompt_memory_links table
+    """
+    # Retrieve relevant memories
+    context = enhancer.retrieve_context(task_description, k=10)
+
+    # Build memory summary for tags
+    memory_summary = {
+        "decisions_count": len(context.get("decisions", [])),
+        "patterns_count": len(context.get("patterns", [])),
+        "feedback_count": len(context.get("feedback", [])),
+        "top_decision": context["decisions"][0]["content"][:100] if context.get("decisions") else "",
+        "top_pattern": context["patterns"][0]["content"][:100] if context.get("patterns") else "",
+    }
+
+    # Register the prompt
+    prompt = mlflow.genai.register_prompt(
+        name=name,
+        template=template,
+        commit_message=commit_message,
+        tags={
+            "memory_enhanced": "true",
+            "task": task_description,
+            **{f"memory.{k}": str(v) for k, v in memory_summary.items()}
+        }
+    )
+
+    # Log detailed context in an MLflow run
+    with mlflow.start_run(run_name=f"prompt_reg_{name.split('.')[-1]}_v{prompt.version}"):
+        mlflow.log_param("prompt_name", name)
+        mlflow.log_param("prompt_version", prompt.version)
+        mlflow.log_param("task_description", task_description)
+
+        # Log full memory context as artifact
+        mlflow.log_dict(context, "memory_context.json")
+
+        # Log individual memory IDs for lineage
+        all_memory_ids = []
+        for mem_type, memories in context.items():
+            for i, mem in enumerate(memories):
+                mlflow.log_param(f"{mem_type}_{i}_id", mem.get("id", "unknown"))
+                all_memory_ids.append((mem_type, mem.get("id")))
+
+        mlflow.log_metric("total_memories_used", len(all_memory_ids))
+
+    # Record links in database
+    _record_memory_links(name, prompt.version, all_memory_ids)
+
+    # Set alias if provided
+    if alias:
+        mlflow.genai.set_prompt_alias(name=name, alias=alias, version=prompt.version)
+
+    return prompt
+
+
+def _record_memory_links(prompt_name: str, version: int, memory_ids: list):
+    """Record which memories influenced this prompt version.
+
+    Args:
+        prompt_name: Registered prompt name the links belong to.
+        version: Prompt version the memories were used for.
+        memory_ids: List of (memory_type, memory_id) tuples.
+
+    Values are bound as named parameters instead of being interpolated into
+    the SQL text, so quotes in a prompt name or memory ID cannot break or
+    inject into the statement.
+    """
+    from pyspark.sql import SparkSession
+    spark = SparkSession.builder.getOrCreate()
+
+    # The context dict is keyed by plural names ("decisions"), while the
+    # memory_type values used elsewhere in this skill are singular ("decision").
+    singular = {"decisions": "decision", "patterns": "pattern"}
+
+    for mem_type, mem_id in memory_ids:
+        if mem_id is None:
+            # memory_id is NOT NULL; skip rather than store the string 'None'
+            continue
+        spark.sql(
+            """
+            INSERT INTO my_catalog.memory.prompt_memory_links
+            (prompt_name, prompt_version, memory_type, memory_id, influence_score)
+            VALUES (:prompt_name, CAST(:prompt_version AS INT), :memory_type, :memory_id, 1.0)
+            """,
+            args={
+                "prompt_name": prompt_name,
+                "prompt_version": int(version),
+                "memory_type": singular.get(mem_type, mem_type),
+                "memory_id": str(mem_id),
+            },
+        )
+```
+
+### Load with Memory Refresh
+
+```python
+def load_memory_prompt(
+    name_or_uri: str,
+    enhancer: "MemoryPromptEnhancer",
+    task_description: str,
+    refresh_memory: bool = False
+) -> tuple:
+    """
+    Load a prompt, optionally refreshing with latest memory.
+
+    Returns:
+        (prompt_template, memory_context)
+    """
+    # Load from registry
+    prompt = mlflow.genai.load_prompt(name_or_uri)
+
+    if refresh_memory:
+        # Get fresh memory context
+        context = enhancer.retrieve_context(task_description)
+        # Enhance the template
+        enhanced_template = enhancer.enhance_prompt(prompt.template, task_description)
+        return enhanced_template, context
+    else:
+        return prompt.template, None
+```
+
+## Tracing Integration
+
+### Memory-Aware Spans
+
+```python
+import mlflow
+from mlflow import trace
+
+@trace
+def memory_enhanced_generation(
+    input_text: str,
+    task: str,
+    enhancer: "MemoryPromptEnhancer",
+    model: str = "databricks-meta-llama-3-3-70b-instruct"
+):
+    """Generate with memory context logged to trace."""
+
+    # Retrieve memory
+    with mlflow.start_span(name="memory_retrieval") as span:
+        context = enhancer.retrieve_context(task)
+        span.set_attributes({
+            "memory.decisions_count": len(context.get("decisions", [])),
+            "memory.patterns_count": len(context.get("patterns", [])),
+            "memory.feedback_count": len(context.get("feedback", [])),
+        })
+
+        # Log individual memory IDs
+        for mem_type, memories in context.items():
+            for i, mem in enumerate(memories[:3]):  # Top 3 per type
+                span.set_attribute(f"memory.{mem_type}_{i}", mem.get("id", ""))
+
+    # Enhance prompt
+    with mlflow.start_span(name="prompt_enhancement") as span:
+        base_prompt = mlflow.genai.load_prompt(f"prompts:/catalog.schema.{task}@production")
+        enhanced = 
enhancer.enhance_prompt(base_prompt.template, task) + span.set_inputs({"base_prompt_length": len(base_prompt.template)}) + span.set_outputs({"enhanced_prompt_length": len(enhanced)}) + + # Generate + with mlflow.start_span(name="llm_generation") as span: + from databricks.sdk import WorkspaceClient + llm = WorkspaceClient().serving_endpoints.get_open_ai_client() + + response = llm.chat.completions.create( + model=model, + messages=[{"role": "user", "content": enhanced.format(input=input_text)}] + ) + + result = response.choices[0].message.content + span.set_inputs({"input_text": input_text[:200]}) + span.set_outputs({"output": result[:200]}) + + return result, context +``` + +### Trace Attributes for Memory + +Standard attributes to include in traces: + +```python +MEMORY_TRACE_ATTRIBUTES = { + # Counts + "memory.total_retrieved": "Total memories retrieved", + "memory.decisions_count": "Number of decisions used", + "memory.patterns_count": "Number of patterns used", + "memory.feedback_count": "Number of feedback items used", + + # Confidence + "memory.avg_confidence": "Average confidence of used memories", + "memory.min_confidence": "Minimum confidence threshold applied", + + # IDs (for lineage) + "memory.decision_ids": "Comma-separated decision IDs", + "memory.pattern_ids": "Comma-separated pattern IDs", + + # Task scope + "memory.task_scope": "Task scope filter applied", +} +``` + +## Experiment Tracking + +### Compare Memory Variants + +```python +def run_memory_experiment( + task: str, + test_inputs: list, + enhancer: "MemoryPromptEnhancer", + variants: list = ["no_memory", "decisions_only", "full_memory"] +): + """ + Run an experiment comparing different memory configurations. 
+ """ + experiment_name = f"/memory_experiments/{task}" + mlflow.set_experiment(experiment_name) + + base_prompt = mlflow.genai.load_prompt(f"prompts:/catalog.schema.{task}@production") + + results = {} + + for variant in variants: + with mlflow.start_run(run_name=variant): + mlflow.log_param("variant", variant) + mlflow.log_param("task", task) + mlflow.log_param("test_inputs_count", len(test_inputs)) + + # Configure memory inclusion + if variant == "no_memory": + template = base_prompt.template + memory_types = [] + elif variant == "decisions_only": + template = enhancer.enhance_prompt( + base_prompt.template, task, include=["decisions"] + ) + memory_types = ["decisions"] + else: # full_memory + template = enhancer.enhance_prompt(base_prompt.template, task) + memory_types = ["decisions", "patterns", "feedback"] + + mlflow.log_param("memory_types", ",".join(memory_types)) + mlflow.log_text(template, "enhanced_prompt.txt") + + # Run inference on test inputs + outputs = [] + latencies = [] + + for input_text in test_inputs: + import time + start = time.time() + + result = generate(template, input_text) # Your generation function + + latencies.append(time.time() - start) + outputs.append(result) + + # Log metrics + mlflow.log_metric("avg_latency_ms", sum(latencies) / len(latencies) * 1000) + mlflow.log_metric("p95_latency_ms", sorted(latencies)[int(len(latencies) * 0.95)] * 1000) + + # Log outputs for evaluation + mlflow.log_dict( + {"inputs": test_inputs, "outputs": outputs}, + "inference_results.json" + ) + + results[variant] = outputs + + return results +``` + +### Evaluate with MLflow Scorers + +```python +from mlflow.genai.scorers import Correctness, Relevance + +def evaluate_memory_impact(experiment_name: str): + """Evaluate memory variants using MLflow scorers.""" + + # Load runs from experiment + runs = mlflow.search_runs(experiment_names=[experiment_name]) + + for _, run in runs.iterrows(): + run_id = run["run_id"] + variant = run["params.variant"] + + # Load 
# Maps the `memory_type` stored on a prompt/memory link to the table holding
# that memory and the column carrying its human-readable content.
# Only these fixed identifiers are ever placed into SQL text; everything that
# comes from data is bound as a query parameter.
_MEMORY_SOURCES = {
    "decision": ("my_catalog.memory.decisions", "decision"),
    "pattern": ("my_catalog.memory.patterns", "pattern"),
    "feedback": ("my_catalog.memory.feedback", "correction"),
}


def get_prompt_evolution(prompt_name: str) -> list:
    """
    Get the evolution of a prompt with memory context at each version.

    Args:
        prompt_name: Name of the registered prompt to inspect.

    Returns:
        A list with one dict per prompt version, each holding ``version``,
        ``commit_message``, ``created_at`` and ``memories_used`` (a list of
        ``{"type", "content", "influence"}`` dicts for the linked memories
        that still exist).

    Raises:
        ValueError: If a link row carries a ``memory_type`` that has no known
            backing table.
    """
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.getOrCreate()

    # Get all versions
    # NOTE(review): assumes search_prompts accepts `name=` and yields objects
    # exposing .version / .commit_message / .creation_time — confirm against
    # the MLflow version in use.
    versions = mlflow.genai.search_prompts(name=prompt_name)

    evolution = []
    for v in versions:
        # Get memory links for this version (values are bound, not interpolated)
        links = spark.sql(
            """
            SELECT memory_type, memory_id, influence_score
            FROM my_catalog.memory.prompt_memory_links
            WHERE prompt_name = :prompt_name AND prompt_version = :prompt_version
            """,
            args={"prompt_name": prompt_name, "prompt_version": v.version},
        ).collect()

        # Get the actual memories
        memories = []
        for link in links:
            source = _MEMORY_SOURCES.get(link.memory_type)
            if source is None:
                raise ValueError(
                    f"Unknown memory_type {link.memory_type!r} linked to "
                    f"prompt {prompt_name!r}"
                )
            table, content_column = source

            # `table` and `content_column` come from the fixed mapping above,
            # so formatting them into the statement is safe.
            mem = spark.sql(
                f"SELECT {content_column} AS content FROM {table} WHERE id = :memory_id",
                args={"memory_id": link.memory_id},
            ).first()

            # first() returns None when the linked memory no longer exists.
            if mem is not None:
                memories.append({
                    "type": link.memory_type,
                    "content": mem["content"],
                    "influence": link.influence_score
                })

        evolution.append({
            "version": v.version,
            "commit_message": v.commit_message,
            "created_at": v.creation_time,
            "memories_used": memories
        })

    return evolution
-- Prompt versions over time with memory counts.
-- One output row per prompt version. The LEFT JOIN keeps versions that have
-- no linked memories; they report 0 in every count column.
-- Every count is over DISTINCT memory ids so the per-type columns always
-- agree with memories_used, even if a link row was written more than once.
-- NOTE(review): prompt_versions is referenced unqualified; confirm which
-- catalog/schema it resolves to in the dashboard's context.
SELECT
    p.prompt_name,
    p.prompt_version,
    p.created_at,
    COUNT(DISTINCT l.memory_id) AS memories_used,
    COUNT(DISTINCT CASE WHEN l.memory_type = 'decision' THEN l.memory_id END) AS decisions,
    COUNT(DISTINCT CASE WHEN l.memory_type = 'pattern' THEN l.memory_id END) AS patterns
FROM prompt_versions AS p
LEFT JOIN my_catalog.memory.prompt_memory_links AS l
    ON p.prompt_name = l.prompt_name
    AND p.prompt_version = l.prompt_version
GROUP BY
    p.prompt_name,
    p.prompt_version,
    p.created_at
ORDER BY
    p.prompt_name,
    p.prompt_version;
+ +**This skill solves that problem.** + +--- + +## How It Works + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ THE MEMORY LOOP │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ 1. USER REPORTS ISSUE │ +│ "555-1234 x789 wasn't redacted" │ +│ │ │ +│ ▼ │ +│ 2. YOU LOG THE PATTERN │ +│ INSERT INTO patterns (pattern, evidence, task_scope) │ +│ VALUES ('Phone extensions need special handling', │ +│ 'Ticket #4521', 'pii_redaction') │ +│ │ │ +│ ▼ │ +│ 3. NEXT PROMPT BUILD │ +│ SELECT pattern FROM patterns WHERE task_scope = 'pii_redaction' │ +│ → Returns: "Phone extensions need special handling" │ +│ │ │ +│ ▼ │ +│ 4. ENHANCED PROMPT │ +│ Base prompt + "Watch for: Phone extensions..." │ +│ → LLM now handles the edge case │ +│ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Prerequisites + +1. **Unity Catalog** — You need a catalog where you can create tables +2. **CREATE TABLE permission** — In the schema where you'll store memory +3. **Serverless SQL Warehouse** or cluster with DBR 15.1+ — For `ai_query()` + +**Important:** You must manually create the memory tables (Step 1 below) before using this pattern. They are not auto-created. + +--- + +## Quick Start + +### Step 1: Create Memory Tables + +Run this SQL in a Databricks notebook. Replace `my_catalog` with your catalog name. 
-- Run in any Databricks notebook
CREATE SCHEMA IF NOT EXISTS my_catalog.memory;

-- Notes on the DDL below:
--   * Free-text columns are STRING; Databricks SQL has no TEXT type.
--   * Column DEFAULTs on Delta tables require the allowColumnDefaults table
--     feature, so it is enabled on every table that declares a DEFAULT.

-- Patterns: Things you learned from production
CREATE TABLE IF NOT EXISTS my_catalog.memory.patterns (
    id STRING DEFAULT uuid(),
    pattern STRING NOT NULL,
    evidence STRING,
    task_scope STRING,
    frequency INT DEFAULT 1,
    confidence DOUBLE DEFAULT 0.8,  -- score used to rank patterns at retrieval time
    created_at TIMESTAMP DEFAULT current_timestamp()
)
TBLPROPERTIES ('delta.feature.allowColumnDefaults' = 'supported');

-- Decisions: Explicit choices you made
CREATE TABLE IF NOT EXISTS my_catalog.memory.decisions (
    id STRING DEFAULT uuid(),
    decision STRING NOT NULL,
    rationale STRING,
    task_scope STRING,
    created_at TIMESTAMP DEFAULT current_timestamp()
)
TBLPROPERTIES ('delta.feature.allowColumnDefaults' = 'supported');

-- Feedback: Raw user corrections (input for pattern extraction)
CREATE TABLE IF NOT EXISTS my_catalog.memory.feedback (
    id STRING DEFAULT uuid(),
    input_text STRING,
    output_text STRING,
    correction STRING,
    task_scope STRING,
    created_at TIMESTAMP DEFAULT current_timestamp()
)
TBLPROPERTIES ('delta.feature.allowColumnDefaults' = 'supported');
def build_prompt_with_memory(base_prompt: str, task_scope: str, session=None) -> str:
    """Enhance a prompt with relevant patterns and decisions from memory.

    Args:
        base_prompt: The prompt text to extend.
        task_scope: Value matched against the ``task_scope`` column of the
            memory tables. It is bound as a query parameter, never spliced
            into the SQL text.
        session: Optional object exposing ``sql(query, args=...)``, normally a
            SparkSession. When omitted, the notebook-provided global ``spark``
            is used.

    Returns:
        ``base_prompt`` followed by a "Design Decisions" section (up to 3
        entries) and a "Learned Patterns" section (up to 5 entries). A section
        is omitted when its query returns no rows, so with empty memory the
        base prompt is returned unchanged.
    """
    if session is None:
        session = spark  # global SparkSession available in Databricks notebooks

    bound = {"task_scope": task_scope}

    # Retrieve patterns
    pattern_rows = session.sql(
        """
        SELECT pattern FROM my_catalog.memory.patterns
        WHERE task_scope = :task_scope
        ORDER BY confidence DESC, created_at DESC
        LIMIT 5
        """,
        args=bound,
    ).collect()
    patterns = [row.pattern for row in pattern_rows]

    # Retrieve decisions
    decision_rows = session.sql(
        """
        SELECT decision FROM my_catalog.memory.decisions
        WHERE task_scope = :task_scope
        ORDER BY created_at DESC
        LIMIT 3
        """,
        args=bound,
    ).collect()
    decisions = [row.decision for row in decision_rows]

    # Build enhanced prompt: decisions first, then patterns.
    parts = [base_prompt]

    if decisions:
        parts.append("\n\n## Design Decisions (follow these)\n")
        parts.extend(f"- {d}\n" for d in decisions)

    if patterns:
        parts.append("\n## Learned Patterns (watch for these)\n")
        parts.extend(f"- {p}\n" for p in patterns)

    return "".join(parts)
+``` + +--- + +## The RAG + RLM Architecture + +This skill uses the same memory architecture as [soul.py](https://github.com/menonpg/soul.py): + +**RAG (Retrieval-Augmented Generation):** +- Store memories in Unity Catalog tables +- Retrieve relevant context via SQL or Vector Search +- Inject into prompts before generation + +**RLM (Recursive Language Modeling):** +- Accumulate raw feedback over time +- Periodically distill feedback into high-confidence patterns +- Patterns with high frequency/confidence surface first + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ RAG + RLM MEMORY STACK │ +├─────────────────────────────────────────────────────────────────┤ +│ │ +│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ +│ │ Feedback │ → │ Patterns │ → │ Decisions │ │ +│ │ (raw) │ │ (distilled)│ │ (explicit) │ │ +│ └─────────────┘ └─────────────┘ └─────────────┘ │ +│ ↑ ↑ ↑ │ +│ User reports Auto-extracted You document │ +│ bugs/corrections from feedback deliberately │ +│ │ +│ Low confidence ←────────────────────────→ High confidence │ +│ │ +└─────────────────────────────────────────────────────────────────┘ +``` + +--- + +## Going Further + +| Feature | Description | Reference | +|---------|-------------|-----------| +| **Full Schema** | Production DDL with indexes, retention, confidence decay | [1-memory-schema.md](1-memory-schema.md) | +| **Vector Search** | Semantic retrieval (find by meaning, not just task_scope) | [2-vector-search-setup.md](2-vector-search-setup.md) | +| **Learning Pipeline** | Auto-extract patterns from feedback using Lakeflow | [3-learning-pipeline.md](3-learning-pipeline.md) | +| **MLflow Integration** | Track which memories shaped each prompt version | [4-mlflow-integration.md](4-mlflow-integration.md) | + +--- + +## Common Issues + +| Problem | Solution | +|---------|----------| +| Prompt gets too long | Limit to 3-5 patterns. 
Use `LIMIT 5` and `ORDER BY confidence DESC` | +| Old patterns conflict with new | Add `superseded_by` column; filter out superseded patterns | +| Need semantic search | Set up Vector Search index (see reference file) | +| Want automatic pattern extraction | Deploy the learning pipeline (see reference file) | + +--- + +## Why This Matters + +> "MLflow tracks what you deployed. Memory tracks what you learned." + +Without memory: +- Session 1: You fix a bug +- Session 10: Someone else hits the same bug +- Session 50: The fix is folklore, not code + +With memory: +- Session 1: You fix a bug, log the pattern +- Session 10: Pattern surfaces automatically +- Session 50: The system has learned 50 sessions' worth of patterns + +**The learning compounds.**