diff --git a/notebooks/fine_tuning/01_golf_forecasting.ipynb b/notebooks/fine_tuning/01_golf_forecasting.ipynb index 1016780..1040c33 100644 --- a/notebooks/fine_tuning/01_golf_forecasting.ipynb +++ b/notebooks/fine_tuning/01_golf_forecasting.ipynb @@ -1,491 +1,1140 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "header", - "metadata": {}, - "source": [ - "# Golf Forecasting\n", - "\n", - "Generate a forecasting dataset about professional golf (tournaments, majors, rankings) using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "aaddd17b", - "metadata": {}, - "outputs": [ + "cells": [ { - "data": { - "text/plain": [ - "True" + "cell_type": "markdown", + "id": "header", + "metadata": {}, + "source": [ + "# Golf Forecasting\n", + "\n", + "Generate a forecasting dataset about professional golf (tournaments, majors, rankings) using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%pip install lightningrod-ai python-dotenv pandas\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()\n", - "\n", - "from datetime import datetime\n", - "\n", - "import pandas as pd\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()" - ] - }, - { - "cell_type": "markdown", - "id": "part1", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "sdk-setup", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "id": "a04964c0", - "metadata": {}, - "source": [ - "## Build the pipeline\n", - "\n", - "Configure the pipeline with domain-specific instructions and examples for golf forecasting." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "config", - "metadata": {}, - "outputs": [], - "source": [ - "instructions = \"\"\"\n", - "Generate binary forecasting questions about professional golf across all major tours and events.\n", - "\n", - "Cover what golf fans bet on: tournament outcomes, cuts, matchups, majors, team events, season races, world rankings, and player milestones.\n", - "\n", - "Questions should be specific, verifiable, and span the full probability spectrum.\n", - "\"\"\"\n", - "\n", - "good_examples = [\n", - " \"Will Scottie Scheffler win the 2025 Masters?\",\n", - " \"Will the 2025 US Open winning score be under par?\",\n", - " \"Will Tiger Woods make the cut at the 2025 Masters?\",\n", - " \"Will Rory McIlroy finish top 5 at the 2025 US Open?\",\n", - " \"Will any LIV player win a major championship in 2025?\",\n", - " \"Will Europe win the 2025 Ryder Cup?\",\n", - " \"Will any player win 4+ PGA Tour events in 2025?\",\n", - " \"Will Scottie Scheffler remain world #1 through June 2025?\",\n", - " \"Will a first-time major winner emerge at the 2025 PGA Championship?\",\n", - " \"Will Nelly Korda win the 2025 US Women's Open?\",\n", - "]\n", - "\n", - "bad_examples = [\n", - " \"Will someone win the tournament? (obvious)\",\n", - " \"Will golf be exciting? (subjective)\",\n", - " \"Will there be birdies? (trivial)\",\n", - "]\n", - "\n", - "search_queries = [\n", - " \"PGA Tour\",\n", - " \"LIV Golf\",\n", - " \"LPGA\",\n", - " \"golf major championship\",\n", - " \"Ryder Cup Presidents Cup\",\n", - " \"golf world rankings\",\n", - " \"professional golf\",\n", - " \"women's golf\",\n", - " \"European Tour golf\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "pipeline", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import (\n", - " BinaryAnswerType,\n", - " NewsSeedGenerator,\n", - " ForwardLookingQuestionGenerator,\n", - " NewsContextGenerator,\n", - " WebSearchLabeler,\n", - " QuestionPipeline,\n", - ")\n", - "\n", - "answer_type = BinaryAnswerType()\n", - "\n", - "pipeline = QuestionPipeline(\n", - " seed_generator=NewsSeedGenerator(\n", - " start_date=datetime(2024, 6, 1),\n", - " end_date=datetime(2026, 1, 1),\n", - " interval_duration_days=14,\n", - " search_query=search_queries,\n", - " articles_per_search=10,\n", - " ),\n", - " question_generator=ForwardLookingQuestionGenerator(\n", - " instructions=instructions,\n", - " examples=good_examples,\n", - " bad_examples=bad_examples,\n", - " answer_type=answer_type,\n", - " questions_per_seed=5,\n", - " ),\n", - " context_generators=[\n", - " NewsContextGenerator(\n", - " articles_per_query=3,\n", - " num_search_queries=3,\n", - " num_articles=5,\n", - " )\n", - " ],\n", - " labeler=WebSearchLabeler(answer_type=answer_type),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "324a35cc", - "metadata": {}, - "source": [ - "## Run the pipeline\n", - "\n", - "This will collect news articles, generate questions, and find answers. Use `max_questions` to limit the run for testing." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7f62e9c4", - "metadata": {}, - "outputs": [ + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f25abaaeb92e42f1bca02f0ea69c7f15", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" + "cell_type": "code", + "execution_count": 1, + "id": "aaddd17b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%pip install lightningrod-ai python-dotenv pandas openai\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()\n", + "\n", + "from datetime import datetime\n", + "\n", + "import pandas as pd\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
+      "cell_type": "markdown",
+      "id": "part1",
+      "metadata": {},
+      "source": [
+        "## Set up the client\n",
+        "\n",
+        "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/?redirect=/api) to get your API key and **$50 of free credits**."
+      ]
     },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "100 samples (87.0% valid)\n"
-     ]
-    }
-   ],
-   "source": [
-    "dataset = lr.transforms.run(pipeline, max_questions=100, name=\"Golf forecasting\")\n",
-    "samples = dataset.download()\n",
-    "\n",
-    "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n",
-    "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "fdbi5exhd6c",
-   "metadata": {},
-   "source": [
-    "## Prepare the dataset\n",
-    "\n",
-    "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. We filter by `date_close <= today` to only include questions that have already resolved."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "upload",
-   "metadata": {},
-   "outputs": [
+      "cell_type": "code",
+      "execution_count": 2,
+      "id": "sdk-setup",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from lightningrod import LightningRod\n",
+        "from lightningrod.utils import config\n",
+        "\n",
+        "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n",
+        "lr = LightningRod(api_key=api_key)"
+      ]
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Train: 37 rows, 32.4% yes\n",
-      "Test: 18 rows, 44.4% yes\n"
-     ]
-    }
-   ],
-   "source": [
-    "from lightningrod import filter_and_split\n",
-    "\n",
-    "train_dataset, test_dataset = filter_and_split(\n",
-    "    dataset,\n",
-    "    test_size=0.2,\n",
-    "    split_strategy=\"temporal\",\n",
-    "    days_to_resolution_range=(1, None),  # at least 1 day to resolution\n",
-    ")\n",
-    "\n",
-    "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n",
-    "    data = ds.flattened()\n",
-    "    yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n",
-    "    print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")\n",
-    "    display(pd.DataFrame(data).head())"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "b2e9efba",
-   "metadata": {},
-   "source": [
-    "## Uploading the dataset to HuggingFace\n",
-    "\n",
-    "Once we have a training-ready dataset, we can push it to Hugging Face for sharing or downstream use."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 8,
-   "id": "ae7c826b",
-   "metadata": {},
-   "outputs": [
+      "cell_type": "markdown",
+      "id": "a04964c0",
+      "metadata": {},
+      "source": [
+        "## Build the pipeline\n",
+        "\n",
+        "Configure the pipeline with domain-specific instructions and examples for golf forecasting."
+      ]
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
-      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Train: 37 rows, Test: 18 rows\n",
-      "Columns: ['question_text', 'date_close', 'event_date', 'resolution_criteria', 'prediction_date', 'label', 'answer_type', 'label_confidence'] ...\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 3,
+      "id": "config",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "instructions = \"\"\"\n",
+        "Generate binary forecasting questions about professional golf across all major tours and events.\n",
+        "\n",
+        "Cover what golf fans bet on: tournament outcomes, cuts, matchups, majors, team events, season races, world rankings, and player milestones.\n",
+        "\n",
+        "Questions should be specific, verifiable, and span the full probability spectrum.\n",
+        "\"\"\"\n",
+        "\n",
+        "good_examples = [\n",
+        "    \"Will Scottie Scheffler win the 2025 Masters?\",\n",
+        "    \"Will the 2025 US Open winning score be under par?\",\n",
+        "    \"Will Tiger Woods make the cut at the 2025 Masters?\",\n",
+        "    \"Will Rory McIlroy finish top 5 at the 2025 US Open?\",\n",
+        "    \"Will any LIV player win a major championship in 2025?\",\n",
+        "    \"Will Europe win the 2025 Ryder Cup?\",\n",
+        "    \"Will any player win 4+ PGA Tour events in 2025?\",\n",
+        "    \"Will Scottie Scheffler remain world #1 through June 2025?\",\n",
+        "    \"Will a first-time major winner emerge at the 2025 PGA Championship?\",\n",
+        "    \"Will Nelly Korda win the 2025 US Women's Open?\",\n",
+        "]\n",
+        "\n",
+        "bad_examples = [\n",
+        "    \"Will someone win the tournament? (obvious)\",\n",
+        "    \"Will golf be exciting? (subjective)\",\n",
+        "    \"Will there be birdies? (trivial)\",\n",
+        "]\n",
+        "\n",
+        "search_queries = [\n",
+        "    \"PGA Tour\",\n",
+        "    \"LIV Golf\",\n",
+        "    \"LPGA\",\n",
+        "    \"golf major championship\",\n",
+        "    \"Ryder Cup Presidents Cup\",\n",
+        "    \"golf world rankings\",\n",
+        "    \"professional golf\",\n",
+        "    \"women's golf\",\n",
+        "    \"European Tour golf\",\n",
+        "]"
+      ]
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "453c59f9daf345828e087cc2a47af33f",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Uploading the dataset shards:   0%|          | 0/1 [00:00╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> Pipeline Completed                                                                                          \n",
+              "                                                                                                                 \n",
+              "    Total cost: $47.67                                                                                           \n",
+              "                                                                                                                 \n",
+              "  ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓  \n",
+              " Step                Progress               In  Out  Rejected  Errors  Rejection Reasons   Duration \n",
+              "  ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩  \n",
+              " NewsSeedGenerator… Complete             │  20 │ 200 │        0     0-                  │      19s │  \n",
+              " ForwardLookingQue… Complete             │ 200 │ 965 │       33     0date_close not     │      15s │  \n",
+              "                    │                      │     │     │          │        │ after event_date   │          │  \n",
+              "                    │                      │     │     │          │        │ (33)               │          │  \n",
+              " WebSearchLabelerT… Complete             │ 965 │ 796 │      169     0Undetermined label │    1m 8s │  \n",
+              "                    │                      │     │     │          │        │ (164), Resolution  │          │  \n",
+              "                    │                      │     │     │          │        │ date is before     │          │  \n",
+              "                    │                      │     │     │          │        │ seed creation date │          │  \n",
+              "                    │                      │     │     │          │        │ (5)                │          │  \n",
+              " NewsContextGenera… Complete             │ 796 │ 795 │        1     0<failed_attempts>  │  11m 37s │  \n",
+              "                    │                      │     │     │          │        │                    │          │  \n",
+              "                    │                      │     │     │          │        │ <generation        │          │  \n",
+              "                    │                      │     │     │          │        │ number=\"1\">        │          │  \n",
+              "                    │                      │     │     │          │        │ <exception>        │          │  \n",
+              "                    │                      │     │     │          │        │     Request timed  │          │  \n",
+              "                    │                      │     │     │          │        │ out.               │          │  \n",
+              "                    │                      │     │     │          │        │ </exception>       │          │  \n",
+              "                    │                      │     │     │          │        │ <completion>       │          │  \n",
+              "                    │                      │     │     │          │        │     None           │          │  \n",
+              "                    │                      │     │     │          │        │ </completion>      │          │  \n",
+              "                    │                      │     │     │          │        │ </generation>      │          │  \n",
+              "                    │                      │     │     │          │        │                    │          │  \n",
+              "                    │                      │     │     │          │        │ <generation        │          │  \n",
+              "                    │                      │     │     │          │        │ number=\"2\">        │          │  \n",
+              "                    │                      │     │     │          │        │ <exception>        │          │  \n",
+              "                    │                      │     │     │          │        │     Connection     │          │  \n",
+              "                    │                      │     │     │          │        │ error.             │          │  \n",
+              "                    │                      │     │     │          │        │ </exception>       │          │  \n",
+              "                    │                      │     │     │          │        │ <completion>       │          │  \n",
+              "                    │                      │     │     │          │        │     None           │          │  \n",
+              "                    │                      │     │     │          │        │ </completion>      │          │  \n",
+              "                    │                      │     │     │          │        │ </generation>      │          │  \n",
+              "                    │                      │     │     │          │        │                    │          │  \n",
+              "                    │                      │     │     │          │        │ </failed_attempts> │          │  \n",
+              "                    │                      │     │     │          │        │                    │          │  \n",
+              "                    │                      │     │     │          │        │ <last_exception>   │          │  \n",
+              "                    │                      │     │     │          │        │     Connection     │          │  \n",
+              "                    │                      │     │     │          │        │ error.             │          │  \n",
+              "                    │                      │     │     │          │        │ </last_exception>  │          │  \n",
+              "                    │                      │     │     │          │        │ (1)                │          │  \n",
+              "  └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘  \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[92m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  \u001b[1;92m>> Pipeline Completed\u001b[0m                                                                                          \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m    \u001b[1mTotal cost:\u001b[0m \u001b[92m$47.67\u001b[0m                                                                                           \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  ┃\u001b[1;36m \u001b[0m\u001b[1;36mStep              \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mProgress            \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m In\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mOut\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejected\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mErrors\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejection Reasons \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mDuration\u001b[0m\u001b[1;36m \u001b[0m┃  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mNewsSeedGenerator…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │  20 │ 200 │ \u001b[2m       0\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2m-                 \u001b[0m │      19s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mForwardLookingQue…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │ 200 │ 965 │ \u001b[91m      33\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2mdate_close not    \u001b[0m │      15s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mafter event_date  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m(33)              \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mWebSearchLabelerT…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │ 965 │ 796 │ \u001b[91m     169\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2mUndetermined label\u001b[0m │    1m 8s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m(164), Resolution \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mdate is before    \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mseed creation date\u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m(5)               \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mNewsContextGenera…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │ 796 │ 795 │ \u001b[91m       1\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2m \u001b[0m │  11m 37s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m                  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m       \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m       \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m    Request timed \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mout.              \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m      \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m      \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m    None          \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m     \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m     \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m                  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m       \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m       \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m    Connection    \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2merror.            \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m      \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m      \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m    None          \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m     \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m     \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m                  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m\u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m                  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m    Connection    \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2merror.            \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m(1)               \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "948 samples (78.6% valid)\n"
+          ]
+        }
+      ],
+      "source": [
+        "dataset = lr.transforms.run(pipeline, max_questions=1000, name=\"Golf forecasting\")\n",
+        "samples = dataset.download()\n",
+        "\n",
+        "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n",
+        "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")"
       ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "586e0e7d869d495b9daacd16e3389bcf",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "New Data Upload: |          |  0.00B /  0.00B            "
+      "cell_type": "markdown",
+      "id": "fdbi5exhd6c",
+      "metadata": {},
+      "source": [
+        "## Prepare the dataset\n",
+        "\n",
+        "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. We filter by `date_close <= today` to only include questions that have already resolved."
       ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "f7672f6275a44c16983dc4a5c2c1df27",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Uploading the dataset shards:   0%|          | 0/1 [00:00\n",
+              "\n",
+              "\n",
+              "  \n",
+              "    \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "    \n",
+              "  \n",
+              "  \n",
+              "    \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "    \n",
+              "    \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "    \n",
+              "    \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "    \n",
+              "    \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "    \n",
+              "    \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "      \n",
+              "    \n",
+              "  \n",
+              "
sample_idis_validquestion_textdate_closeevent_dateresolution_criteriaprediction_datelabelanswer_typelabel_confidence...reasoninganswer_sourcesseed_textseed_urlseed_creation_dateseed_search_querycontextmeta_sample_idmeta_parent_sample_idmeta_processing_time_ms
023a607a2-e9db-45a9-8e33-67cd16a32b56TrueWill the Eastern Michigan University women's g...2025-05-10T00:00:002024-07-15T00:00:00The question resolves to 'Yes' if Eastern Mich...2024-07-15T00:00:000binary1.00...The Eastern Michigan University (EMU) women's ...https://vertexaisearch.cloud.google.com/ground...Eastern Michigan Athletics\\nCaterina Don Named...https://emueagles.com/news/2024/7/9/womens-gol...2024-07-15T00:00:00women's golf[{'rendered_context': '', 'search_query': 'Eas...fa146afa-b53f-48c1-8d2d-a81ce2dec41b0107be94-88d9-4068-a355-ec38b8691376844641.292
12736b5ea-b6b0-4fde-a237-96f6a3d9ee86TrueWill an Arizona Wildcats player be named the B...2025-05-01T00:00:002024-07-15T00:00:00The question resolves to 'Yes' if the Big 12 C...2024-07-15T00:00:001binary1.00...The Arizona Wildcats officially joined the Big...https://vertexaisearch.cloud.google.com/ground...TUCSON, Ariz. – Arizona Women's Golf Head Coac...https://arizonawildcats.com/news/2024/7/15/bra...2024-07-15T00:00:00women's golf[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Arizon...ece95e4d-151b-4af8-936a-5c7e18276b97f7429e77-a524-46dd-ae47-821794e79938988488.435
22eae4276-f449-45e9-8973-e760b6d36d61TrueWill Caterina Don remain in her role as the As...2025-05-31T00:00:002024-07-15T00:00:00The question resolves to 'Yes' if Caterina Don...2024-07-15T00:00:001binary0.95...Caterina Don was hired as the first full-time ...https://vertexaisearch.cloud.google.com/ground...Eastern Michigan Athletics\\nCaterina Don Named...https://emueagles.com/news/2024/7/9/womens-gol...2024-07-15T00:00:00women's golf[{'rendered_context': '', 'search_query': 'Cat...4fbfad5d-b425-4cfe-b04d-1db1279c80a80107be94-88d9-4068-a355-ec38b8691376485377.689
335b2be70-0b7c-4b1b-b82f-2e3f0d3b61d8TrueWill the University of North Carolina women's ...2025-04-15T00:00:002024-07-15T00:00:00The question resolves to Yes if the UNC women'...2024-07-15T00:00:001binary1.00...The University of North Carolina women's golf ...https://vertexaisearch.cloud.google.com/ground...University of North Carolina Athletics\\nNeff's...https://goheels.com/news/2024/7/15/neffs-contr...2024-07-15T00:00:00women's golf[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] 2024-2...05a3aff2-dee2-4111-a0f4-c085ba138679c9839639-2fd2-4f0a-ad52-4566bbde89be1052828.897
4382c5c2e-db8f-4b8c-bbe8-6fc4e55edf53TrueWill the University of Arizona Women's Golf te...2025-04-15T00:00:002024-07-15T00:00:00The question resolves to 'Yes' if the Universi...2024-07-15T00:00:001binary1.00...The University of Arizona Women's Golf team wo...https://vertexaisearch.cloud.google.com/ground...TUCSON, Ariz. – Arizona Women's Golf Head Coac...https://arizonawildcats.com/news/2024/7/15/bra...2024-07-15T00:00:00women's golf[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Arizon...1fd71efb-9506-4c76-b1b6-d8eed1a30ac7f7429e77-a524-46dd-ae47-821794e799381089828.234
\n", + "

5 rows × 21 columns

\n", + "" + ], + "text/plain": [ + " sample_id is_valid \\\n", + "0 23a607a2-e9db-45a9-8e33-67cd16a32b56 True \n", + "1 2736b5ea-b6b0-4fde-a237-96f6a3d9ee86 True \n", + "2 2eae4276-f449-45e9-8973-e760b6d36d61 True \n", + "3 35b2be70-0b7c-4b1b-b82f-2e3f0d3b61d8 True \n", + "4 382c5c2e-db8f-4b8c-bbe8-6fc4e55edf53 True \n", + "\n", + " question_text date_close \\\n", + "0 Will the Eastern Michigan University women's g... 2025-05-10T00:00:00 \n", + "1 Will an Arizona Wildcats player be named the B... 2025-05-01T00:00:00 \n", + "2 Will Caterina Don remain in her role as the As... 2025-05-31T00:00:00 \n", + "3 Will the University of North Carolina women's ... 2025-04-15T00:00:00 \n", + "4 Will the University of Arizona Women's Golf te... 2025-04-15T00:00:00 \n", + "\n", + " event_date resolution_criteria \\\n", + "0 2024-07-15T00:00:00 The question resolves to 'Yes' if Eastern Mich... \n", + "1 2024-07-15T00:00:00 The question resolves to 'Yes' if the Big 12 C... \n", + "2 2024-07-15T00:00:00 The question resolves to 'Yes' if Caterina Don... \n", + "3 2024-07-15T00:00:00 The question resolves to Yes if the UNC women'... \n", + "4 2024-07-15T00:00:00 The question resolves to 'Yes' if the Universi... \n", + "\n", + " prediction_date label answer_type label_confidence ... \\\n", + "0 2024-07-15T00:00:00 0 binary 1.00 ... \n", + "1 2024-07-15T00:00:00 1 binary 1.00 ... \n", + "2 2024-07-15T00:00:00 1 binary 0.95 ... \n", + "3 2024-07-15T00:00:00 1 binary 1.00 ... \n", + "4 2024-07-15T00:00:00 1 binary 1.00 ... \n", + "\n", + " reasoning \\\n", + "0 The Eastern Michigan University (EMU) women's ... \n", + "1 The Arizona Wildcats officially joined the Big... \n", + "2 Caterina Don was hired as the first full-time ... \n", + "3 The University of North Carolina women's golf ... \n", + "4 The University of Arizona Women's Golf team wo... \n", + "\n", + " answer_sources \\\n", + "0 https://vertexaisearch.cloud.google.com/ground... \n", + "1 https://vertexaisearch.cloud.google.com/ground... \n", + "2 https://vertexaisearch.cloud.google.com/ground... \n", + "3 https://vertexaisearch.cloud.google.com/ground... \n", + "4 https://vertexaisearch.cloud.google.com/ground... \n", + "\n", + " seed_text \\\n", + "0 Eastern Michigan Athletics\\nCaterina Don Named... \n", + "1 TUCSON, Ariz. – Arizona Women's Golf Head Coac... \n", + "2 Eastern Michigan Athletics\\nCaterina Don Named... \n", + "3 University of North Carolina Athletics\\nNeff's... \n", + "4 TUCSON, Ariz. – Arizona Women's Golf Head Coac... \n", + "\n", + " seed_url seed_creation_date \\\n", + "0 https://emueagles.com/news/2024/7/9/womens-gol... 2024-07-15T00:00:00 \n", + "1 https://arizonawildcats.com/news/2024/7/15/bra... 2024-07-15T00:00:00 \n", + "2 https://emueagles.com/news/2024/7/9/womens-gol... 2024-07-15T00:00:00 \n", + "3 https://goheels.com/news/2024/7/15/neffs-contr... 2024-07-15T00:00:00 \n", + "4 https://arizonawildcats.com/news/2024/7/15/bra... 2024-07-15T00:00:00 \n", + "\n", + " seed_search_query context \\\n", + "0 women's golf [{'rendered_context': '', 'search_query': 'Eas... \n", + "1 women's golf [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Arizon... \n", + "2 women's golf [{'rendered_context': '', 'search_query': 'Cat... \n", + "3 women's golf [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] 2024-2... \n", + "4 women's golf [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Arizon... \n", + "\n", + " meta_sample_id meta_parent_sample_id \\\n", + "0 fa146afa-b53f-48c1-8d2d-a81ce2dec41b 0107be94-88d9-4068-a355-ec38b8691376 \n", + "1 ece95e4d-151b-4af8-936a-5c7e18276b97 f7429e77-a524-46dd-ae47-821794e79938 \n", + "2 4fbfad5d-b425-4cfe-b04d-1db1279c80a8 0107be94-88d9-4068-a355-ec38b8691376 \n", + "3 05a3aff2-dee2-4111-a0f4-c085ba138679 c9839639-2fd2-4f0a-ad52-4566bbde89be \n", + "4 1fd71efb-9506-4c76-b1b6-d8eed1a30ac7 f7429e77-a524-46dd-ae47-821794e79938 \n", + "\n", + " meta_processing_time_ms \n", + "0 844641.292 \n", + "1 988488.435 \n", + "2 485377.689 \n", + "3 1052828.897 \n", + "4 1089828.234 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test: 143 rows, 32.2% yes\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sample_idis_validquestion_textdate_closeevent_dateresolution_criteriaprediction_datelabelanswer_typelabel_confidence...reasoninganswer_sourcesseed_textseed_urlseed_creation_dateseed_search_querycontextmeta_sample_idmeta_parent_sample_idmeta_processing_time_ms
069d062b4-681c-431f-9e01-e13befba3ea0TrueWill Luke Clanton finish in the top 10 of the ...2025-07-07T00:00:002025-06-24T00:00:00The question resolves to 'Yes' if Luke Clanton...2025-06-24T00:00:000binary1.0...Luke Clanton participated in the 2025 John Dee...https://vertexaisearch.cloud.google.com/ground...Title: No. 15 Ben Griffin, rising star Luke Cl...https://www.wqad.com/article/sports/john-deere...2025-06-24T00:00:00golf world rankings[{'rendered_context': '', 'search_query': 'Luk...7cdbf935-029e-4ce5-bcb4-d6cbf016303ac96f580b-8dff-42a0-90dc-3e5c888680c6489996.714
1c3ef9b69-79e5-42c4-a26c-b0d02bf82abbTrueWill Ben Griffin be ranked in the top 10 of th...2025-07-07T00:00:002025-06-24T00:00:00The question resolves to 'Yes' if Ben Griffin'...2025-06-24T00:00:000binary1.0...Ben Griffin was ranked No. 17 in the Official ...https://vertexaisearch.cloud.google.com/ground...Title: No. 15 Ben Griffin, rising star Luke Cl...https://www.wqad.com/article/sports/john-deere...2025-06-24T00:00:00golf world rankings[{'rendered_context': '', 'search_query': 'Ben...c3a577ab-069d-4857-8c52-9fce6f44e876c96f580b-8dff-42a0-90dc-3e5c888680c6493772.150
2e1d9f94a-8fa0-4caf-a85b-2f2602e7e9aeTrueWill Luke Clanton outscore Ben Griffin in the ...2025-07-04T00:00:002025-06-24T00:00:00The question resolves to 'Yes' if Luke Clanton...2025-06-24T00:00:001binary1.0...The first round of the 2025 John Deere Classic...https://vertexaisearch.cloud.google.com/ground...Title: No. 15 Ben Griffin, rising star Luke Cl...https://www.wqad.com/article/sports/john-deere...2025-06-24T00:00:00golf world rankings[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Player...41c1a085-7b54-4a59-b7c8-145a2c9f225fc96f580b-8dff-42a0-90dc-3e5c888680c6693496.445
3093a04a1-03f4-429c-a8cd-c2f7c8ea098cTrueWill Jordan Smith win the 2025 Italian Open?2025-06-30T00:00:002025-06-25T00:00:00This question resolves to Yes if Jordan Smith ...2025-06-25T00:00:000binary1.0...The 2025 Italian Open (golf) took place from J...https://vertexaisearch.cloud.google.com/ground...Title: 2025 Italian Open betting tips: Our exp...https://www.todays-golfer.com/news-and-events/...2025-06-25T00:00:00European Tour golf[{'rendered_context': '', 'search_query': 'Jor...cebd9023-809d-441b-9ef5-251381880f5c20e82969-aad5-477a-8e2e-cd641f7d7eec493786.983
424bfbd2f-7a6c-4383-82aa-94d73f429685TrueWill Eddie Pepperell win at least one tourname...2025-11-30T00:00:002025-06-25T00:00:00The question resolves to 'Yes' if Eddie Pepper...2025-06-25T00:00:000binary0.9...The close date is 2025-11-30, and the question...https://vertexaisearch.cloud.google.com/ground...Title: Eddie Pepperell feeling refreshed after...https://www.europeantour.com/dpworld-tour/news...2025-06-25T00:00:00European Tour golf[{'rendered_context': '', 'search_query': 'Edd...a43d140a-a541-4537-8aab-6523ecbf79ce3fa59bd7-f4b8-4ea9-b5b2-d8cba7d5ce0d512717.986
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " sample_id is_valid \\\n", + "0 69d062b4-681c-431f-9e01-e13befba3ea0 True \n", + "1 c3ef9b69-79e5-42c4-a26c-b0d02bf82abb True \n", + "2 e1d9f94a-8fa0-4caf-a85b-2f2602e7e9ae True \n", + "3 093a04a1-03f4-429c-a8cd-c2f7c8ea098c True \n", + "4 24bfbd2f-7a6c-4383-82aa-94d73f429685 True \n", + "\n", + " question_text date_close \\\n", + "0 Will Luke Clanton finish in the top 10 of the ... 2025-07-07T00:00:00 \n", + "1 Will Ben Griffin be ranked in the top 10 of th... 2025-07-07T00:00:00 \n", + "2 Will Luke Clanton outscore Ben Griffin in the ... 2025-07-04T00:00:00 \n", + "3 Will Jordan Smith win the 2025 Italian Open? 2025-06-30T00:00:00 \n", + "4 Will Eddie Pepperell win at least one tourname... 2025-11-30T00:00:00 \n", + "\n", + " event_date resolution_criteria \\\n", + "0 2025-06-24T00:00:00 The question resolves to 'Yes' if Luke Clanton... \n", + "1 2025-06-24T00:00:00 The question resolves to 'Yes' if Ben Griffin'... \n", + "2 2025-06-24T00:00:00 The question resolves to 'Yes' if Luke Clanton... \n", + "3 2025-06-25T00:00:00 This question resolves to Yes if Jordan Smith ... \n", + "4 2025-06-25T00:00:00 The question resolves to 'Yes' if Eddie Pepper... \n", + "\n", + " prediction_date label answer_type label_confidence ... \\\n", + "0 2025-06-24T00:00:00 0 binary 1.0 ... \n", + "1 2025-06-24T00:00:00 0 binary 1.0 ... \n", + "2 2025-06-24T00:00:00 1 binary 1.0 ... \n", + "3 2025-06-25T00:00:00 0 binary 1.0 ... \n", + "4 2025-06-25T00:00:00 0 binary 0.9 ... \n", + "\n", + " reasoning \\\n", + "0 Luke Clanton participated in the 2025 John Dee... \n", + "1 Ben Griffin was ranked No. 17 in the Official ... \n", + "2 The first round of the 2025 John Deere Classic... \n", + "3 The 2025 Italian Open (golf) took place from J... \n", + "4 The close date is 2025-11-30, and the question... \n", + "\n", + " answer_sources \\\n", + "0 https://vertexaisearch.cloud.google.com/ground... \n", + "1 https://vertexaisearch.cloud.google.com/ground... \n", + "2 https://vertexaisearch.cloud.google.com/ground... \n", + "3 https://vertexaisearch.cloud.google.com/ground... \n", + "4 https://vertexaisearch.cloud.google.com/ground... \n", + "\n", + " seed_text \\\n", + "0 Title: No. 15 Ben Griffin, rising star Luke Cl... \n", + "1 Title: No. 15 Ben Griffin, rising star Luke Cl... \n", + "2 Title: No. 15 Ben Griffin, rising star Luke Cl... \n", + "3 Title: 2025 Italian Open betting tips: Our exp... \n", + "4 Title: Eddie Pepperell feeling refreshed after... \n", + "\n", + " seed_url seed_creation_date \\\n", + "0 https://www.wqad.com/article/sports/john-deere... 2025-06-24T00:00:00 \n", + "1 https://www.wqad.com/article/sports/john-deere... 2025-06-24T00:00:00 \n", + "2 https://www.wqad.com/article/sports/john-deere... 2025-06-24T00:00:00 \n", + "3 https://www.todays-golfer.com/news-and-events/... 2025-06-25T00:00:00 \n", + "4 https://www.europeantour.com/dpworld-tour/news... 2025-06-25T00:00:00 \n", + "\n", + " seed_search_query context \\\n", + "0 golf world rankings [{'rendered_context': '', 'search_query': 'Luk... \n", + "1 golf world rankings [{'rendered_context': '', 'search_query': 'Ben... \n", + "2 golf world rankings [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Player... \n", + "3 European Tour golf [{'rendered_context': '', 'search_query': 'Jor... \n", + "4 European Tour golf [{'rendered_context': '', 'search_query': 'Edd... \n", + "\n", + " meta_sample_id meta_parent_sample_id \\\n", + "0 7cdbf935-029e-4ce5-bcb4-d6cbf016303a c96f580b-8dff-42a0-90dc-3e5c888680c6 \n", + "1 c3a577ab-069d-4857-8c52-9fce6f44e876 c96f580b-8dff-42a0-90dc-3e5c888680c6 \n", + "2 41c1a085-7b54-4a59-b7c8-145a2c9f225f c96f580b-8dff-42a0-90dc-3e5c888680c6 \n", + "3 cebd9023-809d-441b-9ef5-251381880f5c 20e82969-aad5-477a-8e2e-cd641f7d7eec \n", + "4 a43d140a-a541-4537-8aab-6523ecbf79ce 3fa59bd7-f4b8-4ea9-b5b2-d8cba7d5ce0d \n", + "\n", + " meta_processing_time_ms \n", + "0 489996.714 \n", + "1 493772.150 \n", + "2 693496.445 \n", + "3 493786.983 \n", + "4 512717.986 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from lightningrod import prepare_for_training, FilterParams, SplitParams\n", + "\n", + "train_dataset, test_dataset = prepare_for_training(\n", + " dataset,\n", + " filter=FilterParams(days_to_resolution_range=(1, None)),\n", + " split=SplitParams(test_size=0.2),\n", + ")\n", + "\n", + "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n", + " data = ds.flattened()\n", + " yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n", + " print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")\n", + " display(pd.DataFrame(data).head())" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "41a845cd3c3a44f3bd221b60dc5316ac", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Creating parquet from Arrow format: 0%| | 0/1 [00:00╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + " \n", + " >> Training COMPLETED \n", + " \n", + " Job: Golf forecasting \n", + " \n", + " Reward: latest -0.9948 avg -0.8261 (11 steps) (higher is better) \n", + " \n", + " Cost: $0.19 \n", + " \n", + " \n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m Golf forecasting \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -0.9948 avg -0.8261 (11 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.19 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job 6c82c197-0627-4ee8-954c-d5ddb93e66f2 completed with status: COMPLETED\n", + "Trained model ID: checkpoint:6c82c197-0627-4ee8-954c-d5ddb93e66f2\n" + ] + } + ], + "source": [ + "job = lr.training.run(config, dataset=train_dataset, name=\"Golf forecasting\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Trained model ID: {job.model_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4c74226d", + "metadata": {}, + "source": [ + "## Inference with your trained model\n", + "\n", + "Use `lr.predict()` to run inference with your trained model." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/datasets/bart/golf-forecasting-demo/commit/75aa82cb9a4bc3ff669cf29a4d2e0f260f6aac58', commit_message='Upload dataset', commit_description='', oid='75aa82cb9a4bc3ff669cf29a4d2e0f260f6aac58', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bart/golf-forecasting-demo', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bart/golf-forecasting-demo'), pr_revision=None, pr_num=None)" + "cell_type": "code", + "execution_count": null, + "id": "ba1dcfc5", + "metadata": {}, + "outputs": [], + "source": [ + "print(lr.predict(job.model_id, \"Will Scottie Scheffler win the 2026 Masters?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "76781ccb", + "metadata": {}, + "source": [ + "## Run evals on trained model\n", + "\n", + "Run test evals on your trained model against the test dataset. The eval job runs the model on the dataset and reports metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "0e81dc80", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> Eval COMPLETED                                                                                              \n",
+              "                                                                                                                 \n",
+              "    ID: cd970c00-d5b9-4db3-ac1f-f4815960abb0                                                                     \n",
+              "    Model: checkpoint:6c82c197-0627-4ee8-954c-d5ddb93e66f2                                                       \n",
+              "    Dataset: 708f1623-6f06-4897-bb2e-dd58b7aebd45                                                                \n",
+              "                                                                                                                 \n",
+              "  ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓                                                                    \n",
+              " Metric                  base  trained \n",
+              "  ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩                                                                    \n",
+              " brier_score         │  0.2784 │  0.2377 │                                                                    \n",
+              " ece                 │  0.2207 │  0.1597 │                                                                    \n",
+              " mean_reward         │ -0.9210 │ -0.8026 │                                                                    \n",
+              " mean_valid_reward   │ -0.9210 │ -0.8026 │                                                                    \n",
+              " n_samples           │     143 │     143 │                                                                    \n",
+              " n_valid             │     143 │     143 │                                                                    \n",
+              " parse_rate          │  1.0000 │  1.0000 │                                                                    \n",
+              " total_cost          │  0.0084 │  0.0084 │                                                                    \n",
+              " total_input_tokens  │  115968 │  115968 │                                                                    \n",
+              " total_output_tokens │    1403 │    1416 │                                                                    \n",
+              "  └─────────────────────┴─────────┴─────────┘                                                                    \n",
+              "                                                                                                                 \n",
+              "    Cost:  $0.02                                                                                                 \n",
+              "                                                                                                                 \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Eval COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mID:\u001b[0m cd970c00-d5b9-4db3-ac1f-f4815960abb0 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mModel:\u001b[0m checkpoint:6c82c197-0627-4ee8-954c-d5ddb93e66f2 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mDataset:\u001b[0m 708f1623-6f06-4897-bb2e-dd58b7aebd45 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mMetric \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m base\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mtrained\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mbrier_score \u001b[0m\u001b[2m \u001b[0m│ 0.2784 │ 0.2377 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mece \u001b[0m\u001b[2m \u001b[0m│ 0.2207 │ 0.1597 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_reward \u001b[0m\u001b[2m \u001b[0m│ -0.9210 │ -0.8026 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_valid_reward \u001b[0m\u001b[2m \u001b[0m│ -0.9210 │ -0.8026 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_samples \u001b[0m\u001b[2m \u001b[0m│ 143 │ 143 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_valid \u001b[0m\u001b[2m \u001b[0m│ 143 │ 143 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mparse_rate \u001b[0m\u001b[2m \u001b[0m│ 1.0000 │ 1.0000 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_cost \u001b[0m\u001b[2m \u001b[0m│ 0.0084 │ 0.0084 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_input_tokens \u001b[0m\u001b[2m \u001b[0m│ 115968 │ 115968 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_output_tokens\u001b[0m\u001b[2m \u001b[0m│ 1403 │ 1416 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m └─────────────────────┴─────────┴─────────┘ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.02 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "62718605", + "metadata": {}, + "source": [ + "> Note: the trained model checkpoint will only be available for 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" } - ], - "source": [ - "%pip install datasets -q\n", - "\n", - "from datasets import Dataset, DatasetDict\n", - "from lightningrod.utils import config\n", - "\n", - "dataset = DatasetDict({\n", - " \"train\": Dataset.from_list(train_dataset.flattened()),\n", - " \"test\": Dataset.from_list(test_dataset.flattened()),\n", - "})\n", - "print(f\"Train: {len(dataset['train'])} rows, Test: {len(dataset['test'])} rows\")\n", - "print(\"Columns:\", dataset[\"train\"].column_names[:8], \"...\")\n", - "\n", - "DATASET_PATH = f\"{config.get_config_value('HF_USERNAME')}/golf-forecasting-demo\"\n", - "dataset.push_to_hub(DATASET_PATH, token=config.get_config_value(\"HF_ACCESS_TOKEN\"))" - ] - }, - { - "cell_type": "markdown", - "id": "part2", - "metadata": {}, - "source": [ - "## Model Training\n", - "\n", - "We used the generated dataset above to fine-tune a forecasting model via RL on 3,178 forecasting questions, surpassing GPT-5 performance.\n", - "\n", - "**For more details on methods, results, and data:**\n", - "- **[Golf-Forecaster Model](https://huggingface.co/LightningRodLabs/Golf-Forecaster)**\n", - "- **[Golf-Forecaster Dataset](https://huggingface.co/datasets/LightningRodLabs/GolfForecasting)**\n", - "\n", - "![Brier Skill Score](https://huggingface.co/datasets/LightningRodLabs/GolfForecasting/resolve/main/brier_skill_score.png)\n", - "\n", - "**Coming Soon:** Seamlessly generate datasets, fine-tune, and evaluate your own forecasting models end-to-end on the Lightningrod platform.\n", - " \n", - "\ud83d\udc49 [Sign up to get early access and updates.](https://lightningrod.ai/)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python (lightningrod-sdk)", - "language": "python", - "name": "lightningrod-sdk" + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/fine_tuning/02_trump_forecasting.ipynb b/notebooks/fine_tuning/02_trump_forecasting.ipynb index 4cb0092..31e5712 100644 --- a/notebooks/fine_tuning/02_trump_forecasting.ipynb +++ b/notebooks/fine_tuning/02_trump_forecasting.ipynb @@ -1,479 +1,1168 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "7d4d26cb", - "metadata": {}, - "source": [ - "# WWTD-2025 (What Would Trump Do?)\n", - "\n", - "Generate a forecasting dataset about Trump's actions, decisions, and statements using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments\u2014including evaluation with and without context." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6f2c4443", - "metadata": {}, - "outputs": [ + "cells": [ { - "data": { - "text/plain": [ - "True" + "cell_type": "markdown", + "id": "7d4d26cb", + "metadata": {}, + "source": [ + "# WWTD-2025 (What Would Trump Do?)\n", + "\n", + "Generate a forecasting dataset about Trump's actions, decisions, and statements using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments—including evaluation with and without context." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%pip install lightningrod-ai python-dotenv pandas\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()\n", - "\n", - "from datetime import datetime\n", - "\n", - "import pandas as pd\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()" - ] - }, - { - "cell_type": "markdown", - "id": "f523d274", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "7ca71c31", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "id": "082d4f24", - "metadata": {}, - "source": [ - "## Build the pipeline\n", - "\n", - "Configure the pipeline with domain-specific instructions and examples for Trump-related forecasting." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "faccbe0a", - "metadata": {}, - "outputs": [], - "source": [ - "instructions = \"\"\"\n", - "Generate binary forecasting questions about Trump's actions, decisions, positions, and statements.\n", - "Questions should be diverse, related to the content, and should evenly cover the full range from very likely to very unlikely.\n", - "Horizon: outcomes should be known within 2 months of the question date, and may be known much sooner.\n", - "Criteria: binary outcome, exact dates, self-contained, verifiable via web search, newsworthy.\n", - "\"\"\"\n", - "\n", - "good_examples = [\n", - " \"Will Trump impose 25% tariffs on all goods from Canada by February 1, 2025?\",\n", - " \"Will Trump issue pardons to January 6 defendants within his first week in office?\",\n", - " \"Will Pete Hegseth be confirmed as Secretary of Defense by February 15, 2025?\",\n", - " \"Will Trump sign an executive order to keep TikTok operational in the US by January 31, 2025?\",\n", - " \"Will Kash Patel be confirmed as FBI Director by March 1, 2025?\",\n", - "]\n", - "\n", - "bad_examples = [\n", - " \"Will Trump do something controversial? (too vague)\",\n", - " \"Will Trump be in the news? (obvious)\",\n", - " \"Will tariffs be imposed? (needs specifics)\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "4ce2d710", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import (\n", - " BinaryAnswerType,\n", - " NewsSeedGenerator,\n", - " ForwardLookingQuestionGenerator,\n", - " NewsContextGenerator,\n", - " WebSearchLabeler,\n", - " QuestionPipeline,\n", - ")\n", - "\n", - "answer_type = BinaryAnswerType()\n", - "\n", - "pipeline = QuestionPipeline(\n", - " seed_generator=NewsSeedGenerator(\n", - " start_date=datetime(2025, 1, 1),\n", - " end_date=datetime(2026, 1, 1),\n", - " interval_duration_days=7,\n", - " search_query=[\n", - " \"Donald Trump domestic policy agenda\",\n", - " \"Donald Trump trade and tariff actions\",\n", - " \"Donald Trump foreign policy decisions\",\n", - " \"Donald Trump interviews and press appearances\",\n", - " \"Donald Trump lawsuits and court rulings\",\n", - " ],\n", - " articles_per_search=10,\n", - " ),\n", - " question_generator=ForwardLookingQuestionGenerator(\n", - " instructions=instructions,\n", - " examples=good_examples,\n", - " bad_examples=bad_examples,\n", - " answer_type=answer_type,\n", - " questions_per_seed=20,\n", - " ),\n", - " context_generators=[\n", - " NewsContextGenerator(\n", - " articles_per_query=3,\n", - " num_search_queries=1,\n", - " num_articles=5,\n", - " )\n", - " ],\n", - " labeler=WebSearchLabeler(answer_type=answer_type),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "1603b3de", - "metadata": {}, - "source": [ - "## Run the pipeline\n", - "\n", - "This will collect news articles, generate questions, and find answers. Use `max_questions` to limit the run for testing." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4de4b87c", - "metadata": {}, - "outputs": [ + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "80f2463c20044d2c9ab2ace831e2adff", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" + "cell_type": "code", + "execution_count": 1, + "id": "6f2c4443", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%pip install lightningrod-ai python-dotenv pandas openai\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()\n", + "\n", + "from datetime import datetime\n", + "\n", + "import pandas as pd\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "
\n"
-      ],
-      "text/plain": []
-     },
-     "metadata": {},
-     "output_type": "display_data"
+      "cell_type": "markdown",
+      "id": "f523d274",
+      "metadata": {},
+      "source": [
+        "## Set up the client\n",
+        "\n",
+        "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/?redirect=/api) to get your API key and **$50 of free credits**."
+      ]
     },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "178 samples (46.1% valid)\n"
-     ]
-    }
-   ],
-   "source": [
-    "dataset = lr.transforms.run(pipeline, max_questions=500, name=\"WWTD-2025\")\n",
-    "\n",
-    "samples = dataset.download()\n",
-    "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n",
-    "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "91866bf7",
-   "metadata": {},
-   "source": [
-    "## Prepare the dataset\n",
-    "\n",
-    "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. We filter by `date_close <= today` to only include questions that have already resolved."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "5e4d3f2a",
-   "metadata": {},
-   "outputs": [
+      "cell_type": "code",
+      "execution_count": 2,
+      "id": "7ca71c31",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "from lightningrod import LightningRod\n",
+        "from lightningrod.utils import config\n",
+        "\n",
+        "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n",
+        "lr = LightningRod(api_key=api_key)"
+      ]
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Train: 39 rows, 12.8% yes\n",
-      "Test: 13 rows, 38.5% yes\n"
-     ]
-    }
-   ],
-   "source": [
-    "from lightningrod import filter_and_split\n",
-    "\n",
-    "train_dataset, test_dataset = filter_and_split(\n",
-    "    dataset,\n",
-    "    test_size=0.2,\n",
-    "    split_strategy=\"temporal\",\n",
-    "    days_to_resolution_range=(1, 60),  # horizon within 2 months\n",
-    ")\n",
-    "\n",
-    "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n",
-    "    data = ds.flattened()\n",
-    "    yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n",
-    "    print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")\n",
-    "    display(pd.DataFrame(data).head())"
-   ]
-  },
-  {
-   "cell_type": "markdown",
-   "id": "0e799cfe",
-   "metadata": {},
-   "source": [
-    "## Uploading the dataset to HuggingFace\n",
-    "\n",
-    "Once we have a training-ready dataset, we can push it to Hugging Face for sharing or downstream use."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 13,
-   "id": "91093200",
-   "metadata": {},
-   "outputs": [
+      "cell_type": "markdown",
+      "id": "082d4f24",
+      "metadata": {},
+      "source": [
+        "## Build the pipeline\n",
+        "\n",
+        "Configure the pipeline with domain-specific instructions and examples for Trump-related forecasting."
+      ]
+    },
     {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\n",
-      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n",
-      "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n",
-      "Note: you may need to restart the kernel to use updated packages.\n",
-      "Train: 58 rows, Test: 17 rows\n",
-      "Columns: ['question_text', 'date_close', 'event_date', 'resolution_criteria', 'prediction_date', 'label', 'answer_type', 'label_confidence'] ...\n"
-     ]
+      "cell_type": "code",
+      "execution_count": 3,
+      "id": "faccbe0a",
+      "metadata": {},
+      "outputs": [],
+      "source": [
+        "instructions = \"\"\"\n",
+        "Generate binary forecasting questions about Trump's actions, decisions, positions, and statements.\n",
+        "Questions should be diverse, related to the content, and should evenly cover the full range from very likely to very unlikely.\n",
+        "Horizon: outcomes should be known within 2 months of the question date, and may be known much sooner.\n",
+        "Criteria: binary outcome, exact dates, self-contained, verifiable via web search, newsworthy.\n",
+        "\"\"\"\n",
+        "\n",
+        "good_examples = [\n",
+        "    \"Will Trump impose 25% tariffs on all goods from Canada by February 1, 2025?\",\n",
+        "    \"Will Trump issue pardons to January 6 defendants within his first week in office?\",\n",
+        "    \"Will Pete Hegseth be confirmed as Secretary of Defense by February 15, 2025?\",\n",
+        "    \"Will Trump sign an executive order to keep TikTok operational in the US by January 31, 2025?\",\n",
+        "    \"Will Kash Patel be confirmed as FBI Director by March 1, 2025?\",\n",
+        "]\n",
+        "\n",
+        "bad_examples = [\n",
+        "    \"Will Trump do something controversial? (too vague)\",\n",
+        "    \"Will Trump be in the news? (obvious)\",\n",
+        "    \"Will tariffs be imposed? (needs specifics)\",\n",
+        "]"
+      ]
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "0b949581ce354458a957d2f6047eb184",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Uploading the dataset shards:   0%|          | 0/1 [00:00╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> Pipeline Completed                                                                                          \n",
+              "                                                                                                                 \n",
+              "    Total cost: $1.90                                                                                            \n",
+              "                                                                                                                 \n",
+              "  ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓  \n",
+              " Step                Progress               In  Out  Rejected  Errors  Rejection Reasons   Duration \n",
+              "  ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩  \n",
+              " NewsSeedGenerator… Complete             │  10 │  95 │        0     0-                  │       1s │  \n",
+              " ForwardLookingQue… Complete             │  95 │ 442 │       31     0date_close not     │       1s │  \n",
+              "                    │                      │     │     │          │        │ after event_date   │          │  \n",
+              "                    │                      │     │     │          │        │ (31)               │          │  \n",
+              " WebSearchLabelerT… Complete             │ 442 │ 380 │       62     0Resolution date is │       4s │  \n",
+              "                    │                      │     │     │          │        │ before seed        │          │  \n",
+              "                    │                      │     │     │          │        │ creation date      │          │  \n",
+              "                    │                      │     │     │          │        │ (36), Undetermined │          │  \n",
+              "                    │                      │     │     │          │        │ label (25), Low    │          │  \n",
+              "                    │                      │     │     │          │        │ confidence: 0.80 < │          │  \n",
+              "                    │                      │     │     │          │        │ 0.9 (1)            │          │  \n",
+              " NewsContextGenera… Complete             │ 380 │ 380 │        0     0-                  │      57s │  \n",
+              "  └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘  \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[92m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  \u001b[1;92m>> Pipeline Completed\u001b[0m                                                                                          \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m    \u001b[1mTotal cost:\u001b[0m \u001b[92m$1.90\u001b[0m                                                                                            \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  ┃\u001b[1;36m \u001b[0m\u001b[1;36mStep              \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mProgress            \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m In\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mOut\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejected\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mErrors\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejection Reasons \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mDuration\u001b[0m\u001b[1;36m \u001b[0m┃  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mNewsSeedGenerator…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │  10 │  95 │ \u001b[2m       0\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2m-                 \u001b[0m │       1s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mForwardLookingQue…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │  95 │ 442 │ \u001b[91m      31\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2mdate_close not    \u001b[0m │       1s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mafter event_date  \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m(31)              \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mWebSearchLabelerT…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │ 442 │ 380 │ \u001b[91m      62\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2mResolution date is\u001b[0m │       4s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mbefore seed       \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mcreation date     \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m(36), Undetermined\u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mlabel (25), Low   \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2mconfidence: 0.80 <\u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m                    \u001b[0m│                      │     │     │          │        │ \u001b[2m0.9 (1)           \u001b[0m │          │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  │\u001b[1m \u001b[0m\u001b[1mNewsContextGenera…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete            \u001b[0m │ 380 │ 380 │ \u001b[2m       0\u001b[0m │ \u001b[2m     0\u001b[0m │ \u001b[2m-                 \u001b[0m │      57s │  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m  └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘  \u001b[92m│\u001b[0m\n",
+              "\u001b[92m│\u001b[0m                                                                                                                 \u001b[92m│\u001b[0m\n",
+              "\u001b[92m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "473 samples (80.3% valid)\n"
+          ]
+        }
+      ],
+      "source": [
+        "dataset = lr.transforms.run(pipeline, max_questions=500, name=\"WWTD-2025\")\n",
+        "\n",
+        "samples = dataset.download()\n",
+        "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n",
+        "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")"
       ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "54bf4eaee9e34b4db6e2b55054f4fc8f",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "New Data Upload: |          |  0.00B /  0.00B            "
+      "cell_type": "markdown",
+      "id": "91866bf7",
+      "metadata": {},
+      "source": [
+        "## Prepare the dataset\n",
+        "\n",
+        "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. We filter by `date_close <= today` to only include questions that have already resolved."
       ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
     },
     {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "6e52a8a016c54295a661a53facb6eb69",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Uploading the dataset shards:   0%|          | 0/1 [00:00╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> prepare_for_training                                                                                        \n",
+              "                                                                                                                 \n",
+              "    Starting with 473 samples                                                                                    \n",
+              "                                                                                                                 \n",
+              "    Filter:  Dropped 93 invalid, 103 horizon → 277 remain                                                        \n",
+              "    Dedup:   277 remain (0 duplicates)                                                                           \n",
+              "    Split:   Splits: 167 train | 56 test (0 dropped, no prediction_date)                                         \n",
+              "             54 train samples removed for leakage                                                                \n",
+              "                                                                                                                 \n",
+              "  ⚠ Unhealthy dataset                                                                                            \n",
+              "                                                                                                                 \n",
+              "  Only 167 train samples remain after preparation. This is below the recommended minimum of 200 for effective    \n",
+              "  training.                                                                                                      \n",
+              "                                                                                                                 \n",
+              "    Tips:                                                                                                        \n",
+              "      • Increase max_questions in lr.transforms.run() to generate more samples.                                  \n",
+              "      • Increase questions_per_seed in your question generator (ForwardLookingQuestionGenerator or               \n",
+              "  QuestionGenerator) to produce more questions from each seed article.Add more search queries to your seed       \n",
+              "  generator to diversify seed sources.                                                                           \n",
+              "      • Widen the seed generator date range (start_date to end_date) to capture more events.                     \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "\n"
+            ],
+            "text/plain": [
+              "\u001b[33m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m  \u001b[1;33m>> prepare_for_training\u001b[0m                                                                                        \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m    \u001b[2mStarting with 473 samples\u001b[0m                                                                                    \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m    \u001b[1mFilter:\u001b[0m  Dropped 93 invalid, 103 horizon → 277 remain                                                        \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m    \u001b[1mDedup:\u001b[0m   277 remain (0 duplicates)                                                                           \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m    \u001b[1mSplit:\u001b[0m   Splits: 167 train | 56 test (0 dropped, no prediction_date)                                         \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m             \u001b[33m54 train samples removed for leakage\u001b[0m                                                                \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m  \u001b[1;33m⚠ Unhealthy dataset\u001b[0m                                                                                            \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m  \u001b[1mOnly 167 train samples remain after preparation. This is below the recommended minimum of 200 for effective \u001b[0m   \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m  \u001b[1mtraining.\u001b[0m                                                                                                      \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m    \u001b[2mTips:\u001b[0m                                                                                                        \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m      • Increase max_questions in lr.transforms.run() to generate more samples.                                  \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m      • Increase questions_per_seed in your question generator (ForwardLookingQuestionGenerator or               \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m  QuestionGenerator) to produce more questions from each seed article.Add more search queries to your seed       \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m  generator to diversify seed sources.                                                                           \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m      • Widen the seed generator date range (start_date to end_date) to capture more events.                     \u001b[33m│\u001b[0m\n",
+              "\u001b[33m│\u001b[0m                                                                                                                 \u001b[33m│\u001b[0m\n",
+              "\u001b[33m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n"
+            ]
+          },
+          "metadata": {},
+          "output_type": "display_data"
+        },
+        {
+          "name": "stdout",
+          "output_type": "stream",
+          "text": [
+            "167\n",
+            "Train: 167 rows, 21.0% yes\n"
+          ]
+        },
+        {
+          "data": {
+            "text/html": [
+              "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sample_idis_validquestion_textdate_closeevent_dateresolution_criteriaprediction_datelabelanswer_typelabel_confidence...reasoninganswer_sourcesseed_textseed_urlseed_creation_dateseed_search_querycontextmeta_sample_idmeta_parent_sample_idmeta_processing_time_ms
007511299-78d6-4020-8efe-d7b5b865f826TrueWill the 11th Circuit Court of Appeals issue a...2025-02-15T00:00:002025-01-08T00:00:00This question resolves to 'Yes' if the U.S. Co...2025-01-08T00:00:001binary1.00...On January 9, 2025, the U.S. Court of Appeals ...https://vertexaisearch.cloud.google.com/ground...Title: The Situation: Ending the Trump Cases t...https://www.lawfaremedia.org/article/the-situa...2025-01-08T00:00:00Donald Trump lawsuits and court rulings[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Judge ...59824a03-d075-4de4-bd89-17ec5f0651f1da9d2197-889d-4fe4-ae4e-960ab3f9726f16820.019
13618250c-9aa7-4e3e-b1f4-d59351e25415TrueWill the criminal charges against Carlos De Ol...2025-03-05T00:00:002025-01-08T00:00:00This question resolves to 'Yes' if a federal c...2025-01-08T00:00:001binary1.00...The criminal charges against Carlos De Oliveir...https://vertexaisearch.cloud.google.com/ground...Title: The Situation: Ending the Trump Cases t...https://www.lawfaremedia.org/article/the-situa...2025-01-08T00:00:00Donald Trump lawsuits and court rulings[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Trump ...89a53581-5832-4797-9a76-4a85aefb8993da9d2197-889d-4fe4-ae4e-960ab3f9726f170592.880
26f5808f8-5fbe-40e3-a902-2959e0159960TrueWill Justice Juan Merchan sentence Donald Trum...2025-03-01T00:00:002025-01-08T00:00:00This question resolves to 'Yes' if Justice Jua...2025-01-08T00:00:000binary1.00...The close date for this question is 2025-03-01...https://vertexaisearch.cloud.google.com/ground...Title: The Situation: Ending the Trump Cases t...https://www.lawfaremedia.org/article/the-situa...2025-01-08T00:00:00Donald Trump lawsuits and court rulings[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Judge ...da9d2197-889d-4fe4-ae4e-960ab3f9726fb988692d-28ac-4e59-932f-089b30c1fdff19099.822
3811942a3-0ce0-4a23-8ccf-cb9b6b038a78TrueWill the full, unredacted Special Counsel repo...2025-03-01T00:00:002025-01-08T00:00:00This question resolves to 'Yes' if the Departm...2025-01-08T00:00:000binary1.00...Special Counsel Jack Smith submitted a two-vol...https://vertexaisearch.cloud.google.com/ground...Title: The Situation: Ending the Trump Cases t...https://www.lawfaremedia.org/article/the-situa...2025-01-08T00:00:00Donald Trump lawsuits and court rulings[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Trump ...d53c18aa-e58f-47cc-ad8d-07284f3837b5da9d2197-889d-4fe4-ae4e-960ab3f9726f210318.406
4abbac7a6-09fd-4d34-b588-0a3df04f2f37TrueWill Donald Trump grant a formal presidential ...2025-02-28T00:00:002025-01-08T00:00:00This question resolves to 'Yes' if the White H...2025-01-08T00:00:000binary0.95...The close date for this question is 2025-02-28...https://vertexaisearch.cloud.google.com/ground...Title: The Situation: Ending the Trump Cases t...https://www.lawfaremedia.org/article/the-situa...2025-01-08T00:00:00Donald Trump lawsuits and court rulings[{'rendered_context': '', 'search_query': 'Tru...c1f9b0fa-f364-4cda-a9ca-a41f1dcdcf3cda9d2197-889d-4fe4-ae4e-960ab3f9726f16623.478
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " sample_id is_valid \\\n", + "0 07511299-78d6-4020-8efe-d7b5b865f826 True \n", + "1 3618250c-9aa7-4e3e-b1f4-d59351e25415 True \n", + "2 6f5808f8-5fbe-40e3-a902-2959e0159960 True \n", + "3 811942a3-0ce0-4a23-8ccf-cb9b6b038a78 True \n", + "4 abbac7a6-09fd-4d34-b588-0a3df04f2f37 True \n", + "\n", + " question_text date_close \\\n", + "0 Will the 11th Circuit Court of Appeals issue a... 2025-02-15T00:00:00 \n", + "1 Will the criminal charges against Carlos De Ol... 2025-03-05T00:00:00 \n", + "2 Will Justice Juan Merchan sentence Donald Trum... 2025-03-01T00:00:00 \n", + "3 Will the full, unredacted Special Counsel repo... 2025-03-01T00:00:00 \n", + "4 Will Donald Trump grant a formal presidential ... 2025-02-28T00:00:00 \n", + "\n", + " event_date resolution_criteria \\\n", + "0 2025-01-08T00:00:00 This question resolves to 'Yes' if the U.S. Co... \n", + "1 2025-01-08T00:00:00 This question resolves to 'Yes' if a federal c... \n", + "2 2025-01-08T00:00:00 This question resolves to 'Yes' if Justice Jua... \n", + "3 2025-01-08T00:00:00 This question resolves to 'Yes' if the Departm... \n", + "4 2025-01-08T00:00:00 This question resolves to 'Yes' if the White H... \n", + "\n", + " prediction_date label answer_type label_confidence ... \\\n", + "0 2025-01-08T00:00:00 1 binary 1.00 ... \n", + "1 2025-01-08T00:00:00 1 binary 1.00 ... \n", + "2 2025-01-08T00:00:00 0 binary 1.00 ... \n", + "3 2025-01-08T00:00:00 0 binary 1.00 ... \n", + "4 2025-01-08T00:00:00 0 binary 0.95 ... \n", + "\n", + " reasoning \\\n", + "0 On January 9, 2025, the U.S. Court of Appeals ... \n", + "1 The criminal charges against Carlos De Oliveir... \n", + "2 The close date for this question is 2025-03-01... \n", + "3 Special Counsel Jack Smith submitted a two-vol... \n", + "4 The close date for this question is 2025-02-28... \n", + "\n", + " answer_sources \\\n", + "0 https://vertexaisearch.cloud.google.com/ground... \n", + "1 https://vertexaisearch.cloud.google.com/ground... \n", + "2 https://vertexaisearch.cloud.google.com/ground... \n", + "3 https://vertexaisearch.cloud.google.com/ground... \n", + "4 https://vertexaisearch.cloud.google.com/ground... \n", + "\n", + " seed_text \\\n", + "0 Title: The Situation: Ending the Trump Cases t... \n", + "1 Title: The Situation: Ending the Trump Cases t... \n", + "2 Title: The Situation: Ending the Trump Cases t... \n", + "3 Title: The Situation: Ending the Trump Cases t... \n", + "4 Title: The Situation: Ending the Trump Cases t... \n", + "\n", + " seed_url seed_creation_date \\\n", + "0 https://www.lawfaremedia.org/article/the-situa... 2025-01-08T00:00:00 \n", + "1 https://www.lawfaremedia.org/article/the-situa... 2025-01-08T00:00:00 \n", + "2 https://www.lawfaremedia.org/article/the-situa... 2025-01-08T00:00:00 \n", + "3 https://www.lawfaremedia.org/article/the-situa... 2025-01-08T00:00:00 \n", + "4 https://www.lawfaremedia.org/article/the-situa... 2025-01-08T00:00:00 \n", + "\n", + " seed_search_query \\\n", + "0 Donald Trump lawsuits and court rulings \n", + "1 Donald Trump lawsuits and court rulings \n", + "2 Donald Trump lawsuits and court rulings \n", + "3 Donald Trump lawsuits and court rulings \n", + "4 Donald Trump lawsuits and court rulings \n", + "\n", + " context \\\n", + "0 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Judge ... \n", + "1 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Trump ... \n", + "2 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Judge ... \n", + "3 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Trump ... \n", + "4 [{'rendered_context': '', 'search_query': 'Tru... \n", + "\n", + " meta_sample_id meta_parent_sample_id \\\n", + "0 59824a03-d075-4de4-bd89-17ec5f0651f1 da9d2197-889d-4fe4-ae4e-960ab3f9726f \n", + "1 89a53581-5832-4797-9a76-4a85aefb8993 da9d2197-889d-4fe4-ae4e-960ab3f9726f \n", + "2 da9d2197-889d-4fe4-ae4e-960ab3f9726f b988692d-28ac-4e59-932f-089b30c1fdff \n", + "3 d53c18aa-e58f-47cc-ad8d-07284f3837b5 da9d2197-889d-4fe4-ae4e-960ab3f9726f \n", + "4 c1f9b0fa-f364-4cda-a9ca-a41f1dcdcf3c da9d2197-889d-4fe4-ae4e-960ab3f9726f \n", + "\n", + " meta_processing_time_ms \n", + "0 16820.019 \n", + "1 170592.880 \n", + "2 19099.822 \n", + "3 210318.406 \n", + "4 16623.478 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "56\n", + "Test: 56 rows, 37.5% yes\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sample_idis_validquestion_textdate_closeevent_dateresolution_criteriaprediction_datelabelanswer_typelabel_confidence...reasoninganswer_sourcesseed_textseed_urlseed_creation_dateseed_search_querycontextmeta_sample_idmeta_parent_sample_idmeta_processing_time_ms
07960bd90-44c8-4cd0-bac0-3a17265101e5TrueWill Donald Trump announce a complete exemptio...2025-12-01T00:00:002025-10-08T00:00:00The question is answered 'Yes' if the US Presi...2025-10-08T00:00:000binary1.00...The close date for this question is 2025-12-01...https://vertexaisearch.cloud.google.com/ground...Title: 'We will get an even better deal,' Carn...https://www.cbc.ca/news/politics/carney-even-b...2025-10-08T00:00:00Donald Trump trade and tariff actions[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] U.S.-C...ecd757d7-fa3c-47ec-bdc6-81cdf88c1b65a3db9248-0a4f-4d51-81d1-7c21057856cf12056.629
1b562085f-b58d-4284-ae0b-fcfdeb90f90cTrueWill the United States and Canada sign a forma...2025-12-01T00:00:002025-10-08T00:00:00A formal bilateral agreement or signed Memoran...2025-10-08T00:00:000binary0.95...Based on the provided reports regarding the 20...https://vertexaisearch.cloud.google.com/ground...Title: 'We will get an even better deal,' Carn...https://www.cbc.ca/news/politics/carney-even-b...2025-10-08T00:00:00Donald Trump trade and tariff actions[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Carney...a3db9248-0a4f-4d51-81d1-7c21057856cf993816e2-887d-4db0-99c7-f45e93776a8a255122.468
263dc110c-6073-4f98-abcb-57b8933b8fbfTrueWill the Trump administration hold a new offsh...2026-04-01T00:00:002025-11-26T00:00:00The question resolves to 'Yes' if the Departme...2025-11-26T00:00:001binary1.00...The Trump administration held the first new of...https://vertexaisearch.cloud.google.com/ground...Title: nytimes.com\\n\\nURL Source: https://www....https://www.nytimes.com/2025/11/26/climate/tru...2025-11-26T00:00:00Donald Trump domestic policy agenda[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Interi...87e57caa-9696-4eb8-a7ed-cd4060b0649eb4bfbb84-63db-48d6-b63d-c809b20ed5f06996.226
32581c035-5519-434f-aed1-c51326b11e73TrueWill Donald Trump and Javier Milei hold a join...2026-01-01T00:00:002025-11-27T00:00:00The question resolves to 'Yes' if Donald Trump...2025-11-27T00:00:000binary0.95...The close date for this question is 2026-01-01...https://vertexaisearch.cloud.google.com/ground...Title: The Paradox of Europe's Trumpian Right:...https://www.foreignaffairs.com/europe/paradox-...2025-11-27T00:00:00Donald Trump domestic policy agenda[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Milei ...3a88f43d-6a2c-422f-b7dd-5b6039a293757977998c-8ccb-444f-a382-1fcd8e322c2311967.560
43bd13f7e-25a8-40c0-8b2a-d9f77ef9f827TrueWill the United States official executive bran...2026-01-20T00:00:002025-11-27T00:00:00The question resolves to 'Yes' if the U.S. gov...2025-11-27T00:00:001binary0.95...Between the question date (2025-11-27) and the...https://vertexaisearch.cloud.google.com/ground...Title: The Paradox of Europe's Trumpian Right:...https://www.foreignaffairs.com/europe/paradox-...2025-11-27T00:00:00Donald Trump domestic policy agenda[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Presid...d15a72f1-d8ac-4fc9-9bde-732514e9008c7977998c-8ccb-444f-a382-1fcd8e322c2366899.224
\n", + "

5 rows × 21 columns

\n", + "
" + ], + "text/plain": [ + " sample_id is_valid \\\n", + "0 7960bd90-44c8-4cd0-bac0-3a17265101e5 True \n", + "1 b562085f-b58d-4284-ae0b-fcfdeb90f90c True \n", + "2 63dc110c-6073-4f98-abcb-57b8933b8fbf True \n", + "3 2581c035-5519-434f-aed1-c51326b11e73 True \n", + "4 3bd13f7e-25a8-40c0-8b2a-d9f77ef9f827 True \n", + "\n", + " question_text date_close \\\n", + "0 Will Donald Trump announce a complete exemptio... 2025-12-01T00:00:00 \n", + "1 Will the United States and Canada sign a forma... 2025-12-01T00:00:00 \n", + "2 Will the Trump administration hold a new offsh... 2026-04-01T00:00:00 \n", + "3 Will Donald Trump and Javier Milei hold a join... 2026-01-01T00:00:00 \n", + "4 Will the United States official executive bran... 2026-01-20T00:00:00 \n", + "\n", + " event_date resolution_criteria \\\n", + "0 2025-10-08T00:00:00 The question is answered 'Yes' if the US Presi... \n", + "1 2025-10-08T00:00:00 A formal bilateral agreement or signed Memoran... \n", + "2 2025-11-26T00:00:00 The question resolves to 'Yes' if the Departme... \n", + "3 2025-11-27T00:00:00 The question resolves to 'Yes' if Donald Trump... \n", + "4 2025-11-27T00:00:00 The question resolves to 'Yes' if the U.S. gov... \n", + "\n", + " prediction_date label answer_type label_confidence ... \\\n", + "0 2025-10-08T00:00:00 0 binary 1.00 ... \n", + "1 2025-10-08T00:00:00 0 binary 0.95 ... \n", + "2 2025-11-26T00:00:00 1 binary 1.00 ... \n", + "3 2025-11-27T00:00:00 0 binary 0.95 ... \n", + "4 2025-11-27T00:00:00 1 binary 0.95 ... \n", + "\n", + " reasoning \\\n", + "0 The close date for this question is 2025-12-01... \n", + "1 Based on the provided reports regarding the 20... \n", + "2 The Trump administration held the first new of... \n", + "3 The close date for this question is 2026-01-01... \n", + "4 Between the question date (2025-11-27) and the... \n", + "\n", + " answer_sources \\\n", + "0 https://vertexaisearch.cloud.google.com/ground... \n", + "1 https://vertexaisearch.cloud.google.com/ground... \n", + "2 https://vertexaisearch.cloud.google.com/ground... \n", + "3 https://vertexaisearch.cloud.google.com/ground... \n", + "4 https://vertexaisearch.cloud.google.com/ground... \n", + "\n", + " seed_text \\\n", + "0 Title: 'We will get an even better deal,' Carn... \n", + "1 Title: 'We will get an even better deal,' Carn... \n", + "2 Title: nytimes.com\\n\\nURL Source: https://www.... \n", + "3 Title: The Paradox of Europe's Trumpian Right:... \n", + "4 Title: The Paradox of Europe's Trumpian Right:... \n", + "\n", + " seed_url seed_creation_date \\\n", + "0 https://www.cbc.ca/news/politics/carney-even-b... 2025-10-08T00:00:00 \n", + "1 https://www.cbc.ca/news/politics/carney-even-b... 2025-10-08T00:00:00 \n", + "2 https://www.nytimes.com/2025/11/26/climate/tru... 2025-11-26T00:00:00 \n", + "3 https://www.foreignaffairs.com/europe/paradox-... 2025-11-27T00:00:00 \n", + "4 https://www.foreignaffairs.com/europe/paradox-... 2025-11-27T00:00:00 \n", + "\n", + " seed_search_query \\\n", + "0 Donald Trump trade and tariff actions \n", + "1 Donald Trump trade and tariff actions \n", + "2 Donald Trump domestic policy agenda \n", + "3 Donald Trump domestic policy agenda \n", + "4 Donald Trump domestic policy agenda \n", + "\n", + " context \\\n", + "0 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] U.S.-C... \n", + "1 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Carney... \n", + "2 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Interi... \n", + "3 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Milei ... \n", + "4 [{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Presid... \n", + "\n", + " meta_sample_id meta_parent_sample_id \\\n", + "0 ecd757d7-fa3c-47ec-bdc6-81cdf88c1b65 a3db9248-0a4f-4d51-81d1-7c21057856cf \n", + "1 a3db9248-0a4f-4d51-81d1-7c21057856cf 993816e2-887d-4db0-99c7-f45e93776a8a \n", + "2 87e57caa-9696-4eb8-a7ed-cd4060b0649e b4bfbb84-63db-48d6-b63d-c809b20ed5f0 \n", + "3 3a88f43d-6a2c-422f-b7dd-5b6039a29375 7977998c-8ccb-444f-a382-1fcd8e322c23 \n", + "4 d15a72f1-d8ac-4fc9-9bde-732514e9008c 7977998c-8ccb-444f-a382-1fcd8e322c23 \n", + "\n", + " meta_processing_time_ms \n", + "0 12056.629 \n", + "1 255122.468 \n", + "2 6996.226 \n", + "3 11967.560 \n", + "4 66899.224 \n", + "\n", + "[5 rows x 21 columns]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from lightningrod import prepare_for_training, FilterParams, SplitParams\n", + "\n", + "train_dataset, test_dataset = prepare_for_training(\n", + " dataset,\n", + " filter=FilterParams(days_to_resolution_range=(1, 60)),\n", + " split=SplitParams(test_size=0.2),\n", + ")\n", + "\n", + "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n", + " data = ds.flattened()\n", + " print(len(data))\n", + " yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n", + " print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")\n", + " display(pd.DataFrame(data).head())" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "1f06b9ae97284e6a805550a9da18a4e1", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Creating parquet from Arrow format: 0%| | 0/1 [00:00╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + " \n", + " >> Training COMPLETED \n", + " \n", + " Job: WWTD-2025 \n", + " \n", + " Reward: latest -1.3368 avg -0.8626 (6 steps) (higher is better) \n", + " \n", + " Cost: $0.11 \n", + " \n", + " \n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m WWTD-2025 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -1.3368 avg -0.8626 (6 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.11 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job 371e5ff4-ebf5-43af-8809-d26c14edb8e2 completed with status: COMPLETED\n", + "Trained model ID: checkpoint:371e5ff4-ebf5-43af-8809-d26c14edb8e2\n" + ] + } + ], + "source": [ + "job = lr.training.run(config, dataset=train_dataset, name=\"WWTD-2025\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Trained model ID: {job.model_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "6487c1d9", + "metadata": {}, + "source": [ + "## Inference with your trained model\n", + "\n", + "Use `lr.predict()` to run inference with your trained model." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/datasets/bart/wwtd-forecasting-demo/commit/cd7bfd6d7addc58cf2c3ac8f6677219a1ce91a91', commit_message='Upload dataset', commit_description='', oid='cd7bfd6d7addc58cf2c3ac8f6677219a1ce91a91', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bart/wwtd-forecasting-demo', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bart/wwtd-forecasting-demo'), pr_revision=None, pr_num=None)" + "cell_type": "code", + "execution_count": null, + "id": "f013b514", + "metadata": {}, + "outputs": [], + "source": [ + "print(lr.predict(job.model_id, \"Will Trump impose 25% tariffs on all goods from Canada by February 1, 2027?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "09e1c0d7", + "metadata": {}, + "source": [ + "## Run evals on trained model\n", + "\n", + "Run test evals on your trained model against the test dataset. The eval job runs the model on the dataset and reports metrics." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "853e7904", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> Eval COMPLETED                                                                                              \n",
+              "                                                                                                                 \n",
+              "    ID: 00602447-5872-4732-93f9-b0d99459da1a                                                                     \n",
+              "    Model: checkpoint:13fa02ec-27f4-47a9-84c9-762d91a1904a                                                       \n",
+              "    Dataset: 82186c26-a309-43a6-9543-37bdda38d41d                                                                \n",
+              "                                                                                                                 \n",
+              "  ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓                                                        \n",
+              " Metric                  base  trained  benchmark \n",
+              "  ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩                                                        \n",
+              " brier_score         │  0.2333 │  0.1877 │    0.1555 │                                                        \n",
+              " ece                 │  0.1451 │  0.0963 │    0.0590 │                                                        \n",
+              " mean_reward         │ -0.7840 │ -0.6048 │   -0.4966 │                                                        \n",
+              " mean_valid_reward   │ -0.7840 │ -0.6048 │   -0.4966 │                                                        \n",
+              " n_samples           │     113 │     113 │       113 │                                                        \n",
+              " n_valid             │     113 │     113 │       113 │                                                        \n",
+              " parse_rate          │  1.0000 │  1.0000 │    1.0000 │                                                        \n",
+              " total_cost          │  0.0068 │  0.0068 │         — │                                                        \n",
+              " total_input_tokens  │   93344 │   93344 │     88060 │                                                        \n",
+              " total_output_tokens │    1111 │    1101 │     28947 │                                                        \n",
+              "  └─────────────────────┴─────────┴─────────┴───────────┘                                                        \n",
+              "                                                                                                                 \n",
+              "    Cost:  $0.01                                                                                                 \n",
+              "                                                                                                                 \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Eval COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mID:\u001b[0m 00602447-5872-4732-93f9-b0d99459da1a \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mModel:\u001b[0m checkpoint:13fa02ec-27f4-47a9-84c9-762d91a1904a \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mDataset:\u001b[0m 82186c26-a309-43a6-9543-37bdda38d41d \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mMetric \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m base\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mtrained\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mbenchmark\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mbrier_score \u001b[0m\u001b[2m \u001b[0m│ 0.2333 │ 0.1877 │ 0.1555 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mece \u001b[0m\u001b[2m \u001b[0m│ 0.1451 │ 0.0963 │ 0.0590 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_reward \u001b[0m\u001b[2m \u001b[0m│ -0.7840 │ -0.6048 │ -0.4966 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_valid_reward \u001b[0m\u001b[2m \u001b[0m│ -0.7840 │ -0.6048 │ -0.4966 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_samples \u001b[0m\u001b[2m \u001b[0m│ 113 │ 113 │ 113 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_valid \u001b[0m\u001b[2m \u001b[0m│ 113 │ 113 │ 113 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mparse_rate \u001b[0m\u001b[2m \u001b[0m│ 1.0000 │ 1.0000 │ 1.0000 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_cost \u001b[0m\u001b[2m \u001b[0m│ 0.0068 │ 0.0068 │ — │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_input_tokens \u001b[0m\u001b[2m \u001b[0m│ 93344 │ 93344 │ 88060 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_output_tokens\u001b[0m\u001b[2m \u001b[0m│ 1111 │ 1101 │ 28947 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m └─────────────────────┴─────────┴─────────┴───────────┘ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.01 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset, benchmark_model_id=\"openai/gpt-5.2\")" + ] + }, + { + "cell_type": "markdown", + "id": "96d20f89", + "metadata": {}, + "source": [ + "> Note: the trained model checkpoint will only be available for 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" } - ], - "source": [ - "%pip install datasets -q\n", - "\n", - "from datasets import Dataset, DatasetDict\n", - "from lightningrod.utils import config\n", - "\n", - "dataset = DatasetDict({\n", - " \"train\": Dataset.from_list(train_dataset.flattened()),\n", - " \"test\": Dataset.from_list(test_dataset.flattened()),\n", - "})\n", - "print(f\"Train: {len(dataset['train'])} rows, Test: {len(dataset['test'])} rows\")\n", - "print(\"Columns:\", dataset[\"train\"].column_names[:8], \"...\")\n", - "\n", - "DATASET_PATH = f\"{config.get_config_value('HF_USERNAME')}/wwtd-forecasting-demo\"\n", - "dataset.push_to_hub(DATASET_PATH, token=config.get_config_value(\"HF_ACCESS_TOKEN\"))" - ] - }, - { - "cell_type": "markdown", - "id": "49a3c7f8", - "metadata": {}, - "source": [ - "## Model Training\n", - "\n", - "We used the generated dataset above to fine-tune a forecasting model via RL on 2,790 questions, surpassing GPT-5 performance.\n", - "\n", - "**For more details on methods, results, and data:**\n", - "- **[Trump-Forecaster Model](https://huggingface.co/LightningRodLabs/Trump-Forecaster)**\n", - "- **[Trump-Forecaster Dataset](https://huggingface.co/datasets/LightningRodLabs/WWTD-2025)**\n", - "\n", - "![Brier Skill Score](https://huggingface.co/datasets/LightningRodLabs/WWTD-2025/resolve/main/brier_skill_score.png)\n", - "\n", - "**Coming Soon:** Seamlessly generate datasets, fine-tune, and evaluate your own forecasting models end-to-end on the Lightningrod platform.\n", - " \n", - "\ud83d\udc49 [Sign up to get early access and updates.](https://lightningrod.ai/)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python (lightningrod-sdk)", - "language": "python", - "name": "lightningrod-sdk" + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/getting_started/05_fine_tuning.ipynb b/notebooks/getting_started/05_fine_tuning.ipynb index f2b1e86..21edf1c 100644 --- a/notebooks/getting_started/05_fine_tuning.ipynb +++ b/notebooks/getting_started/05_fine_tuning.ipynb @@ -1,290 +1,356 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "4dde071c", - "metadata": {}, - "source": [ - "# Training API\n", - "\n", - "Fine-tune forecasting models on your Lightning Rod datasets. This notebook walks through the full training workflow: generating a dataset, estimating cost, creating a training job, and monitoring progress.\n", - "\n", - "The training API supports LoRA fine-tuning with configurable base models, training steps, batch size, and rank." - ] - }, - { - "cell_type": "markdown", - "id": "fae3f735", - "metadata": {}, - "source": [ - "## Install the SDK" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1a6a1e2", - "metadata": {}, - "outputs": [], - "source": [ - "%pip install lightningrod-ai python-dotenv openai\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()" - ] - }, - { - "cell_type": "markdown", - "id": "7490b222", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**.\n", - "\n", - "- **Google Colab**: Go to the Secrets section (key icon in left sidebar) and add a secret named `LIGHTNINGROD_API_KEY`\n", - "- **Local Jupyter**: Set the `LIGHTNINGROD_API_KEY` environment variable, or you'll be prompted to enter it" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "7a023c6c", - "metadata": {}, - "outputs": [], - "source": [ - "from dotenv import load_dotenv\n", - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "load_dotenv()\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "id": "9c49c320", - "metadata": {}, - "source": [ - "## Prepare the dataset\n", - "\n", - "Training requires a dataset ID from a pipeline run. Run one of the other notebooks first to generate a dataset - each one prints the **Dataset ID** after `transforms.run()` — copy it into the cell below." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "da381a43", - "metadata": {}, - "outputs": [], - "source": [ - "dataset_id = config.get_config_value(\"LIGHTNINGROD_DATASET_ID\")\n", - "\n", - "dataset = lr.datasets.get(dataset_id)\n", - "_ = dataset.download()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9418ddf", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import filter_and_split\n", - "\n", - "train_dataset, test_dataset = filter_and_split(\n", - " dataset,\n", - " test_size=0.2,\n", - " days_to_resolution_range=(90, None),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "01e98826", - "metadata": {}, - "source": [ - "## Estimate training cost\n", - "\n", - "Before starting a job, use `estimate_cost` to see the expected cost and token usage." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4283478f", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Estimated cost: $0.02\n", - "Effective steps: 3\n", - "Train tokens: 63,349\n", - "Notes: Estimate uses per-answer-type output token estimates; actual may vary\n" - ] - } - ], - "source": [ - "from lightningrod import TrainingConfig\n", - "\n", - "config = TrainingConfig(\n", - " base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n", - " training_steps=50,\n", - ")\n", - "cost_estimate = lr.training.estimate_cost(config, dataset=train_dataset)\n", - "print(f\"Estimated cost: ${cost_estimate.total_cost_dollars:.2f}\")\n", - "print(f\"Effective steps: {cost_estimate.effective_steps}\")\n", - "print(f\"Train tokens: {cost_estimate.train_tokens:,}\")\n", - "print(f\"Notes: {cost_estimate.notes}\")" - ] - }, - { - "cell_type": "markdown", - "id": "ef160c9c", - "metadata": {}, - "source": [ - "## Start training\n", - "\n", - "`run` creates a job and polls until completion with a live progress display.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "b7800672", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "id": "4dde071c", + "metadata": {}, + "source": [ + "# Training API\n", + "\n", + "Fine-tune forecasting models on your Lightning Rod datasets. This notebook walks through the full training workflow: generating a dataset, estimating cost, creating a training job, and monitoring progress.\n", + "\n", + "The training API supports LoRA fine-tuning with configurable base models, training steps, batch size, and rank." + ] + }, + { + "cell_type": "markdown", + "id": "fae3f735", + "metadata": {}, + "source": [ + "## Install the SDK" + ] + }, { - "data": { - "text/html": [ - "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
-       "                                                                                                                 \n",
-       "  >> Training COMPLETED                                                                                          \n",
-       "                                                                                                                 \n",
-       "    Job: Forecasting fine-tune                                                                                   \n",
-       "                                                                                                                 \n",
-       "    Reward: latest -0.4030  avg -0.8912  (3 steps)  (higher is better)                                           \n",
-       "                                                                                                                 \n",
-       "    Cost:  $0.01                                                                                                 \n",
-       "                                                                                                                 \n",
-       "                                                                                                                 \n",
-       "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
-       "
\n" + "cell_type": "code", + "execution_count": 1, + "id": "c1a6a1e2", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install lightningrod-ai python-dotenv openai\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()" + ] + }, + { + "cell_type": "markdown", + "id": "7490b222", + "metadata": {}, + "source": [ + "## Set up the client\n", + "\n", + "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/?redirect=/api) to get your API key and **$50 of free credits**.\n", + "\n", + "- **Google Colab**: Go to the Secrets section (key icon in left sidebar) and add a secret named `LIGHTNINGROD_API_KEY`\n", + "- **Local Jupyter**: Set the `LIGHTNINGROD_API_KEY` environment variable, or you'll be prompted to enter it" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7a023c6c", + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "from lightningrod import LightningRod\n", + "from lightningrod.utils import config\n", + "\n", + "load_dotenv()\n", + "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", + "\n", + "lr = LightningRod(api_key=api_key)" + ] + }, + { + "cell_type": "markdown", + "id": "9c49c320", + "metadata": {}, + "source": [ + "## Prepare the dataset\n", + "\n", + "Training requires a dataset ID from a pipeline run. Run one of the other notebooks first to generate a dataset - each one prints the **Dataset ID** after `transforms.run()` — copy it into the cell below." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "da381a43", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_id = config.get_config_value(\"LIGHTNINGROD_DATASET_ID\")\n", + "\n", + "dataset = lr.datasets.get(dataset_id)\n", + "_ = dataset.download()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9418ddf", + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import prepare_for_training, FilterParams, SplitParams\n", + "\n", + "train_dataset, test_dataset = prepare_for_training(\n", + " dataset,\n", + " filter=FilterParams(days_to_resolution_range=(90, None)),\n", + " split=SplitParams(test_size=0.2),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "01e98826", + "metadata": {}, + "source": [ + "## Estimate training cost\n", + "\n", + "Before starting a job, use `estimate_cost` to see the expected cost and token usage." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4283478f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Estimated cost: $0.02\n", + "Effective steps: 3\n", + "Train tokens: 63,349\n", + "Notes: Estimate uses per-answer-type output token estimates; actual may vary\n" + ] + } ], - "text/plain": [ - "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m Forecasting fine-tune \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -0.4030 avg -0.8912 (3 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.01 \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + "source": [ + "from lightningrod import TrainingConfig\n", + "\n", + "config = TrainingConfig(\n", + " base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n", + " training_steps=50,\n", + ")\n", + "cost_estimate = lr.training.estimate_cost(config, dataset=train_dataset)\n", + "print(f\"Estimated cost: ${cost_estimate.total_cost_dollars:.2f}\")\n", + "print(f\"Effective steps: {cost_estimate.effective_steps}\")\n", + "print(f\"Train tokens: {cost_estimate.train_tokens:,}\")\n", + "print(f\"Notes: {cost_estimate.notes}\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Job bd114679-610a-4334-8802-13d047a7bc30 completed with status: COMPLETED\n", - "Trained model ID: checkpoint:bd114679-610a-4334-8802-13d047a7bc30\n" - ] - } - ], - "source": [ - "job = lr.training.run(config, dataset=train_dataset, name=\"Forecasting fine-tune\")\n", - "print(f\"Job {job.id} completed with status: {job.status}\")\n", - "print(f\"Trained model ID: {job.model_id}\")" - ] - }, - { - "cell_type": "markdown", - "id": "4fe2792c", - "metadata": {}, - "source": [ - "## Inference with your trained model\n", - "\n", - "Use `lr.predict()` to run inference with your trained model. You can also use the OpenAI-compatible API directly — see [08_foresight_model.ipynb](08_foresight_model.ipynb) for the pre-trained foresight model.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "50744b98", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "id": "ef160c9c", + "metadata": {}, + "source": [ + "## Start training\n", + "\n", + "`run` creates a job and polls until completion with a live progress display.\n" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.3\n" - ] + "cell_type": "code", + "execution_count": 6, + "id": "b7800672", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> Training COMPLETED                                                                                          \n",
+              "                                                                                                                 \n",
+              "    Job: Forecasting fine-tune                                                                                   \n",
+              "                                                                                                                 \n",
+              "    Reward: latest -0.3752  avg -0.8808  (3 steps)  (higher is better)                                           \n",
+              "                                                                                                                 \n",
+              "    Cost:  $0.01                                                                                                 \n",
+              "                                                                                                                 \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m Forecasting fine-tune \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -0.3752 avg -0.8808 (3 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.01 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job 73184285-cd39-4afe-af8e-ceab345d80dc completed with status: COMPLETED\n", + "Trained model ID: checkpoint:73184285-cd39-4afe-af8e-ceab345d80dc\n" + ] + } + ], + "source": [ + "job = lr.training.run(config, dataset=train_dataset, name=\"Forecasting fine-tune\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Trained model ID: {job.model_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4fe2792c", + "metadata": {}, + "source": [ + "## Inference with your trained model\n", + "\n", + "Use `lr.predict()` to run inference with your trained model. You can also use the OpenAI-compatible API directly — see [08_foresight_model.ipynb](08_foresight_model.ipynb) for the pre-trained foresight model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "50744b98", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.35\n" + ] + } + ], + "source": [ + "print(lr.predict(job.model_id, \"Will the Fed cut rates by 25bp in March 2026?\"))\n" + ] + }, + { + "cell_type": "markdown", + "id": "c8d360d3", + "metadata": {}, + "source": [ + "## Run evals on trained model\n", + "\n", + "Run test evals on your trained model against a test dataset. The eval job runs the model on the dataset and reports metrics. Use the same dataset for a quick check, or a separate test split for production." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9dd52fd4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n",
+              "                                                                                                                 \n",
+              "  >> Eval COMPLETED                                                                                              \n",
+              "                                                                                                                 \n",
+              "    ID: 4a78c3a1-a1a6-4288-9cd1-974066ddcc66                                                                     \n",
+              "    Model: checkpoint:73184285-cd39-4afe-af8e-ceab345d80dc                                                       \n",
+              "    Dataset: e87e04c3-4c0d-49ab-97bf-b30d724395d3                                                                \n",
+              "                                                                                                                 \n",
+              "  ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓                                                                    \n",
+              " Metric                  base  trained \n",
+              "  ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩                                                                    \n",
+              " brier_score         │  0.1925 │  0.1961 │                                                                    \n",
+              " ece                 │  0.0671 │  0.0510 │                                                                    \n",
+              " mean_reward         │ -0.5715 │ -0.5797 │                                                                    \n",
+              " mean_valid_reward   │ -0.5715 │ -0.5797 │                                                                    \n",
+              " n_samples           │      81 │      81 │                                                                    \n",
+              " n_valid             │      81 │      81 │                                                                    \n",
+              " parse_rate          │  1.0000 │  1.0000 │                                                                    \n",
+              " total_cost          │  0.0016 │  0.0016 │                                                                    \n",
+              " total_input_tokens  │   20154 │   20154 │                                                                    \n",
+              " total_output_tokens │     787 │     787 │                                                                    \n",
+              "  └─────────────────────┴─────────┴─────────┘                                                                    \n",
+              "                                                                                                                 \n",
+              "    Cost:  $0.00                                                                                                 \n",
+              "                                                                                                                 \n",
+              "                                                                                                                 \n",
+              "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n",
+              "
\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Eval COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mID:\u001b[0m 4a78c3a1-a1a6-4288-9cd1-974066ddcc66 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mModel:\u001b[0m checkpoint:73184285-cd39-4afe-af8e-ceab345d80dc \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mDataset:\u001b[0m e87e04c3-4c0d-49ab-97bf-b30d724395d3 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mMetric \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m base\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mtrained\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mbrier_score \u001b[0m\u001b[2m \u001b[0m│ 0.1925 │ 0.1961 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mece \u001b[0m\u001b[2m \u001b[0m│ 0.0671 │ 0.0510 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_reward \u001b[0m\u001b[2m \u001b[0m│ -0.5715 │ -0.5797 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_valid_reward \u001b[0m\u001b[2m \u001b[0m│ -0.5715 │ -0.5797 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_samples \u001b[0m\u001b[2m \u001b[0m│ 81 │ 81 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_valid \u001b[0m\u001b[2m \u001b[0m│ 81 │ 81 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mparse_rate \u001b[0m\u001b[2m \u001b[0m│ 1.0000 │ 1.0000 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_cost \u001b[0m\u001b[2m \u001b[0m│ 0.0016 │ 0.0016 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_input_tokens \u001b[0m\u001b[2m \u001b[0m│ 20154 │ 20154 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_output_tokens\u001b[0m\u001b[2m \u001b[0m│ 787 │ 787 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m └─────────────────────┴─────────┴─────────┘ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.00 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "72c2631e", + "metadata": {}, + "source": [ + "> Note: the trained model checkpoint will only be available for the period of 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" } - ], - "source": [ - "print(lr.predict(job.model_id, \"Will the Fed cut rates by 25bp in March 2026?\"))\n" - ] - }, - { - "cell_type": "markdown", - "id": "c8d360d3", - "metadata": {}, - "source": [ - "## Run evals on trained model\n", - "\n", - "Run test evals on your trained model against a test dataset. The eval job runs the model on the dataset and reports metrics. Use the same dataset for a quick check, or a separate test split for production." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dd52fd4", - "metadata": {}, - "outputs": [], - "source": [ - "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "72c2631e", - "metadata": {}, - "source": [ - "> Note: the trained model checkpoint will only be available for the period of 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python (lightningrod-sdk)", - "language": "python", - "name": "lightningrod-sdk" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/lightningrod/__init__.py b/src/lightningrod/__init__.py index 0804f0e..9f1c75a 100644 --- a/src/lightningrod/__init__.py +++ b/src/lightningrod/__init__.py @@ -9,7 +9,7 @@ from lightningrod import preprocessing, training, utils from lightningrod.utils.sample import create_sample from lightningrod.utils.models import open_router_model -from lightningrod.training import filter_and_split +from lightningrod.training import prepare_for_training, FilterParams, DedupParams, SplitParams from lightningrod.training.client import TrainingConfig from lightningrod._generated.models import ( TransformJob, @@ -97,7 +97,10 @@ "Rollout", "RolloutScorer", "RolloutGenerator", - "filter_and_split", + "prepare_for_training", + "FilterParams", + "DedupParams", + "SplitParams", "TrainingConfig", "Sample", "SampleMeta", diff --git a/src/lightningrod/_display.py b/src/lightningrod/_display.py index ccd418c..790bc4e 100644 --- a/src/lightningrod/_display.py +++ b/src/lightningrod/_display.py @@ -466,13 +466,13 @@ def _build_invalid_samples_error_message(original_message: str) -> Group: renderables.append(_safe_markup("[bold]Next steps:[/bold]")) renderables.append(_safe_markup(" • Check the dataset samples to see specific failure reasons in the 'meta.filter_reason' field")) - renderables.append(_safe_markup(" • Adjust and retry the transform pipeline (e.g., lower confidence thresholds, relax filter criteria)")) + renderables.append(_safe_markup(" • Adjust and retry the transform pipeline (e.g., try a wider date range)")) renderables.append(_safe_markup(" • If the problem persists, contact support or open a GitHub issue: [link=https://github.com/lightning-rod-labs/lightningrod-python-sdk/issues]https://github.com/lightning-rod-labs/lightningrod-python-sdk/issues[/link]")) return Group(*renderables) -def display_error(message: str, title: str = "Error", job: Any = None) -> None: +def display_error(message: str, title: str = "Error", job: Any = None, response_body: str | None = None) -> None: console = Console() renderables: list[RenderableType] = [] @@ -484,6 +484,11 @@ def display_error(message: str, title: str = "Error", job: Any = None) -> None: else: renderables.append(_safe_markup(f"[bold]{message}[/bold]")) + if response_body is not None and response_body.strip(): + renderables.append(Text("")) + renderables.append(_safe_markup("[bold]Response body:[/bold]")) + renderables.append(Text(response_body.strip()[:2000], style="dim")) + if job is not None: cost_lines = _build_transform_cost_lines(job) if isinstance(job, TransformJob) else _build_cost_lines(job) if cost_lines: @@ -493,6 +498,79 @@ def display_error(message: str, title: str = "Error", job: Any = None) -> None: console.print(Panel(Group(*renderables), border_style="bright_red", padding=(1, 2))) +def display_prepare_report(report: Any, verbose: bool = True) -> None: + """Render a PrepareReport as a Rich panel. Used inside Jupyter notebooks.""" + from lightningrod.training.samples import PrepareReport + assert isinstance(report, PrepareReport) + stats = report.stats + console = Console() + renderables: list[RenderableType] = [] + + border = "bright_green" if report.is_healthy else "yellow" + header_style = "bold bright_green" if report.is_healthy else "bold yellow" + renderables.append(_safe_markup(f"[{header_style}]>> prepare_for_training[/{header_style}]")) + renderables.append(Text("")) + + if verbose or not report.is_healthy: + renderables.append(_safe_markup(f" [dim]Starting with {stats.total} samples[/dim]")) + renderables.append(Text("")) + + parts = [] + if stats.filter_invalid: + parts.append(f"{stats.filter_invalid} invalid") + if stats.filter_horizon: + part = f"{stats.filter_horizon} horizon" + if stats.filter_missing_resolution_date or stats.filter_missing_prediction_date: + sub = [] + if stats.filter_missing_resolution_date: + sub.append(f"{stats.filter_missing_resolution_date} missing resolution date") + if stats.filter_missing_prediction_date: + sub.append(f"{stats.filter_missing_prediction_date} missing prediction date") + part += f" ({', '.join(sub)})" + parts.append(part) + if stats.filter_context: + parts.append(f"{stats.filter_context} missing context") + filter_line = ( + f" [bold]Filter:[/bold] Dropped {', '.join(parts)} → {stats.filter_kept} remain" + if parts else + f" [bold]Filter:[/bold] {stats.filter_kept} remain (0 dropped)" + ) + renderables.append(_safe_markup(filter_line)) + + if stats.dedup_removed > 0: + renderables.append(_safe_markup( + f" [bold]Dedup:[/bold] Removed {stats.dedup_removed} duplicates " + f"({stats.dedup_kept + stats.dedup_removed} → {stats.dedup_kept})" + )) + for k, c in stats.dedup_top_collisions: + q = repr(k[0])[:60] + ("..." if len(repr(k[0])) > 60 else "") + renderables.append(Text(f" ({q}, {k[1]}): {c} samples → 1", style="dim")) + else: + renderables.append(_safe_markup(f" [bold]Dedup:[/bold] {stats.dedup_kept} remain (0 duplicates)")) + + split_detail = f"Splits: {stats.split_train_after} train | {stats.split_test_after} test ({stats.split_no_sort_key} dropped, no prediction_date)" + renderables.append(_safe_markup(f" [bold]Split:[/bold] {split_detail}")) + n_leaked = stats.split_train_before - stats.split_train_after + if n_leaked: + renderables.append(_safe_markup( + f" [yellow]{n_leaked} train samples removed for leakage[/yellow]" + )) + + if not report.is_healthy: + renderables.append(Text("")) + renderables.append(_safe_markup("[bold yellow]⚠ Unhealthy dataset[/bold yellow]")) + for issue in report.issues: + renderables.append(Text("")) + renderables.append(Text(issue.message, style="bold")) + if issue.tips: + renderables.append(Text("")) + renderables.append(_safe_markup(" [dim]Tips:[/dim]")) + for tip in issue.tips: + renderables.append(Text(f" • {tip}")) + + console.print(Panel(Group(*renderables), border_style=border, padding=(1, 2))) + + def display_warning(message: str, title: str = "Warning", job: Any = None) -> None: console = Console() renderables: list[RenderableType] = [] diff --git a/src/lightningrod/training/__init__.py b/src/lightningrod/training/__init__.py index 2b7d7e4..40a7073 100644 --- a/src/lightningrod/training/__init__.py +++ b/src/lightningrod/training/__init__.py @@ -4,10 +4,13 @@ from lightningrod.training.samples import ( deduplicate_samples, filter_samples, - filter_and_split, + prepare_for_training, train_test_split, to_messages, to_record, + FilterParams, + DedupParams, + SplitParams, ) __all__ = [ @@ -15,10 +18,13 @@ "print_eval", "TrainingClient", "TrainingConfig", - "filter_and_split", + "prepare_for_training", "train_test_split", "deduplicate_samples", "filter_samples", "to_record", "to_messages", + "FilterParams", + "DedupParams", + "SplitParams", ] diff --git a/src/lightningrod/training/samples.py b/src/lightningrod/training/samples.py index 4977dc2..1d90afd 100644 --- a/src/lightningrod/training/samples.py +++ b/src/lightningrod/training/samples.py @@ -32,9 +32,19 @@ @dataclass class PrepareStats: - """Tracks metrics collected during prepare_for_training.""" + """Tracks metrics collected during prepare_for_training. + + Fields are grouped by pipeline stage. Invariants: + filter_kept = total - filter_invalid - filter_horizon - filter_context + dedup_kept = filter_kept - dedup_removed + split_train_before + split_test_after + split_no_sort_key = dedup_kept (temporal) + split_train_excluded = split_train_before - split_train_after (inferred; train leakage only) + """ + + # ── Input ──────────────────────────────────────────────────────────────── total: int = 0 + # ── Stage 1: filter_samples ─────────────────────────────────────────────── filter_invalid: int = 0 filter_horizon: int = 0 filter_context: int = 0 @@ -42,16 +52,61 @@ class PrepareStats: filter_missing_prediction_date: int = 0 filter_kept: int = 0 + # ── Stage 2: deduplicate_samples ───────────────────────────────────────── dedup_removed: int = 0 dedup_kept: int = 0 dedup_top_collisions: list[tuple[tuple[Any, ...], int]] = field(default_factory=list) - split_strategy: str = "" - split_test_size: float | None = None + # ── Stage 3: train_test_split ───────────────────────────────────────────── split_no_sort_key: int = 0 - split_leaky: int = 0 - split_train: int = 0 - split_test: int = 0 + + split_train_before: int = 0 + split_train_after: int = 0 + split_test_after: int = 0 + + +@dataclass +class FilterParams: + """Parameters for :func:`filter_samples`.""" + days_to_resolution_range: DaysToResolutionRange = None + drop_missing_context: bool = False + + +@dataclass +class DedupParams: + """Parameters for :func:`deduplicate_samples`.""" + key_fn: Callable[[Sample], tuple[Any, ...]] | None = None + + +@dataclass +class SplitParams: + """Parameters for :func:`train_test_split`.""" + strategy: str = "temporal" + test_size: float | None = 0.2 + test_start: str | None = None + random_state: int = 196 + sort_key: Callable[[Sample], str | None] | None = None + leakage_keys: list[Callable[[Sample], str | None]] | None = None + filter_leaky_train: bool = True + + +@dataclass +class PrepareIssue: + """A single detected problem in the prepared dataset, with actionable tips.""" + message: str + tips: list[str] = field(default_factory=list) + + +@dataclass +class PrepareReport: + """Full report produced by :func:`prepare_for_training`, covering all pipeline stages.""" + stats: PrepareStats + issues: list[PrepareIssue] = field(default_factory=list) + + @property + def is_healthy(self) -> bool: + return not self.issues + def _validate_days_to_resolution_range(value: Any) -> None: if value is None: @@ -85,14 +140,14 @@ def _parse_date(value: Any) -> Optional[date]: def filter_samples( samples: list[Sample], - days_to_resolution_range: DaysToResolutionRange = None, - drop_missing_context: bool = True, + params: FilterParams | None = None, stats: PrepareStats | None = None, ) -> list[Sample]: """Filter samples by validity, horizon, and optional context presence.""" - _validate_days_to_resolution_range(days_to_resolution_range) - min_horizon = days_to_resolution_range[0] if days_to_resolution_range else None - max_horizon = days_to_resolution_range[1] if days_to_resolution_range else None + params = params or FilterParams() + _validate_days_to_resolution_range(params.days_to_resolution_range) + min_horizon = params.days_to_resolution_range[0] if params.days_to_resolution_range else None + max_horizon = params.days_to_resolution_range[1] if params.days_to_resolution_range else None n_invalid = n_horizon = n_context = n_missing_resolution_date = n_missing_prediction_date = 0 filtered: list[Sample] = [] @@ -133,7 +188,7 @@ def filter_samples( if max_horizon is not None and horizon_days > max_horizon: n_horizon += 1 continue - if drop_missing_context: + if params.drop_missing_context: if not sample.context: n_context += 1 continue @@ -236,28 +291,23 @@ def get_resolution_date(sample: Sample) -> str | None: def train_test_split( samples: list[Sample], - *, - split_strategy: str = "temporal", - test_start: str | None = None, - test_size: float | None = None, - random_state: int = 196, - sort_key: Callable[[Sample], str | None] | None = None, - leakage_keys: list[Callable[[Sample], str | None]] | None = None, - filter_leaky_train: bool = True, + params: SplitParams | None = None, stats: PrepareStats | None = None, ) -> tuple[list[str], list[str]]: """Split samples into train/test by temporal order or random shuffle, with optional leakage filtering. Returns (train_ids, test_ids) for memory efficiency.""" - temporal_split = split_strategy == "temporal" + params = params or SplitParams() + temporal_split = params.strategy == "temporal" if temporal_split: - if (test_start is None) == (test_size is None): - raise ValueError("Provide exactly one of test_start or test_size when split_strategy='temporal'") + if (params.test_start is None) == (params.test_size is None): + raise ValueError("Provide exactly one of test_start or test_size when strategy='temporal'") else: - if test_size is None: - raise ValueError("test_size is required when split_strategy='random'") - if test_start is not None: - raise ValueError("test_start is only valid when split_strategy='temporal'") + if params.test_size is None: + raise ValueError("test_size is required when strategy='random'") + if params.test_start is not None: + raise ValueError("test_start is only valid when strategy='temporal'") + sort_key = params.sort_key if sort_key is None: def default_sort_key(sample: Sample) -> str | None: if not sample.question: @@ -271,24 +321,23 @@ def default_sort_key(sample: Sample) -> str | None: return None sort_key = default_sort_key - if leakage_keys is None: - leakage_keys = _default_leakage_keys() + leakage_keys = params.leakage_keys or _default_leakage_keys() if temporal_split: valid_samples = [r for r in samples if sort_key(r) is not None] n_no_sort_key = len(samples) - len(valid_samples) sorted_samples = sorted(valid_samples, key=sort_key) - if test_size is not None: - split_idx = int(len(sorted_samples) * (1 - test_size)) + if params.test_size is not None: + split_idx = int(len(sorted_samples) * (1 - params.test_size)) train, test = sorted_samples[:split_idx], sorted_samples[split_idx:] else: - assert test_start is not None - train = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) < test_start] - test = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) >= test_start] + assert params.test_start is not None + train = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) < params.test_start] + test = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) >= params.test_start] n_leaky = 0 - if filter_leaky_train and test: + if params.filter_leaky_train and test: test_cutoff = sort_key(test[0]) if test_cutoff is not None: def is_safe(row: Sample) -> bool: @@ -303,28 +352,26 @@ def is_safe(row: Sample) -> bool: n_leaky = train_before - len(train) if stats is not None: - stats.split_strategy = "temporal" stats.split_no_sort_key = n_no_sort_key - stats.split_leaky = n_leaky - stats.split_train = len(train) - stats.split_test = len(test) + stats.split_train_before = len(train) + n_leaky + stats.split_train_after = len(train) + stats.split_test_after = len(test) return [s.id for s in train], [s.id for s in test] shuffled = list(samples) - rng = random.Random(random_state) if random_state is not None else random + rng = random.Random(params.random_state) if params.random_state is not None else random rng.shuffle(shuffled) - assert test_size is not None - split_idx = int(len(shuffled) * (1 - test_size)) + assert params.test_size is not None + split_idx = int(len(shuffled) * (1 - params.test_size)) train = shuffled[:split_idx] test = shuffled[split_idx:] if stats is not None: - stats.split_strategy = "random" - stats.split_test_size = test_size - stats.split_train = len(train) - stats.split_test = len(test) + stats.split_train_before = len(train) + stats.split_train_after = len(train) + stats.split_test_after = len(test) return [s.id for s in train], [s.id for s in test] @@ -343,11 +390,12 @@ def _default_dedup_key(sample: Sample) -> tuple[Any, ...]: def deduplicate_samples( samples: list[Sample], - key_fn: Callable[[Sample], tuple[Any, ...]] | None = None, + params: DedupParams | None = None, stats: PrepareStats | None = None, ) -> list[Sample]: """Remove duplicate samples by (question_text, resolution_date) or custom key.""" - key_fn_local: Callable[[Sample], tuple[Any, ...]] = key_fn or _default_dedup_key + params = params or DedupParams() + key_fn_local: Callable[[Sample], tuple[Any, ...]] = params.key_fn or _default_dedup_key seen: set[tuple[Any, ...]] = set() key_counts: dict[tuple[Any, ...], int] = {} result: list[Sample] = [] @@ -582,59 +630,195 @@ def _render_context(context: list[Union[NewsContext, RAGContext]]) -> str: return "\n\n".join(rendered_sections) -def _print_stats(stats: PrepareStats) -> None: - print(f"[prepare_for_training] Starting with {stats.total} samples") - - parts = [] - if stats.filter_invalid: - parts.append(f"{stats.filter_invalid} invalid") - if stats.filter_horizon: - part = f"{stats.filter_horizon} horizon" - if stats.filter_missing_resolution_date or stats.filter_missing_prediction_date: - sub = [] - if stats.filter_missing_resolution_date: - sub.append(f"{stats.filter_missing_resolution_date} missing resolution date") - if stats.filter_missing_prediction_date: - sub.append(f"{stats.filter_missing_prediction_date} missing prediction date") - part += f" ({', '.join(sub)})" - parts.append(part) - if stats.filter_context: - parts.append(f"{stats.filter_context} missing context") - if parts: - print(f"[filter] Dropped {', '.join(parts)} → {stats.filter_kept} remain") - else: - print(f"[filter] {stats.filter_kept} remain (0 dropped)") +def _build_report(stats: PrepareStats, split: SplitParams, filter: FilterParams) -> PrepareReport: + """Build a structured PrepareReport from pipeline stats and params. Pure — no side effects.""" + issues: list[PrepareIssue] = [] + split_train_excluded = stats.split_train_before - stats.split_train_after - if stats.dedup_removed > 0: - print(f"[dedup] Removed {stats.dedup_removed} duplicates ({stats.dedup_kept + stats.dedup_removed} → {stats.dedup_kept}). Top colliding keys:") - for k, c in stats.dedup_top_collisions: - q = repr(k[0])[:60] + ("..." if len(repr(k[0])) > 60 else "") - print(f" ({q}, {k[1]}): {c} samples → 1") - else: - print(f"[dedup] {stats.dedup_kept} remain (0 duplicates)") + # Issue: majority of train samples leaked into test period + majority_train_leaked = ( + split.filter_leaky_train + and split.strategy == "temporal" + and stats.split_train_before > 0 + and split_train_excluded > stats.split_train_before // 2 + ) + if majority_train_leaked: + pct = int(100 * split_train_excluded / stats.split_train_before) + tips: list[str] = [] + max_horizon = filter.days_to_resolution_range[1] if filter.days_to_resolution_range else None + if max_horizon is not None: + tips.append( + f"Extend the seed generator date range to start earlier — the range should span at least " + f"{max_horizon * 2} days so questions generated near the start resolve well before the test window." + ) + else: + tips.append( + "Extend the seed generator (date) filter range to start earlier — questions generated near the start " + "will resolve well before the test window. Aim for at least 2× your max resolution horizon." + ) + tips.append( + "Generate more samples by increasing max_questions in lr.transforms.run() or removing the limit, " + "or increase questions_per_seed in your question generator config. " + "A larger, temporally well-spread dataset naturally pushes the split cutoff far enough back." + ) + tips.append( + "If very few seeds were returned by the pipeline (check the run summary table), the search queries " + "may not surface results across the full date range. Try more diverse search queries, increase " + "articles_per_search, or shorten interval_duration_days." + ) + issues.append(PrepareIssue( + message=( + f"{split_train_excluded}/{stats.split_train_before} train samples ({pct}%) were removed " + "for temporal leakage — the date_close or resolution_date of train questions extends into the test period." + ), + tips=tips, + )) + + # Issue: too few train samples for effective training + MIN_TRAIN_SAMPLES = 200 + if stats.split_train_after < MIN_TRAIN_SAMPLES and stats.split_train_after > 0: + issues.append(PrepareIssue( + message=( + f"Only {stats.split_train_after} train samples remain after preparation. " + f"This is below the recommended minimum of {MIN_TRAIN_SAMPLES} for effective training." + ), + tips=[ + "Increase max_questions in lr.transforms.run() to generate more samples.", + "Increase questions_per_seed in your question generator (ForwardLookingQuestionGenerator or QuestionGenerator) to produce more questions from each seed article." + "Add more search queries to your seed generator to diversify seed sources.", + "Widen the seed generator date range (start_date to end_date) to capture more events.", + ], + )) + + # Issue: too few test samples for reliable evaluation + MIN_TEST_SAMPLES = 50 + if stats.split_test_after < MIN_TEST_SAMPLES and stats.split_test_after > 0: + issues.append(PrepareIssue( + message=( + f"Only {stats.split_test_after} test samples remain after preparation. " + f"This is below the recommended minimum of ~{MIN_TEST_SAMPLES} for reliable evaluation." + ), + tips=[ + "Generate more samples overall — test samples come from the most recent portion of your date range.", + "Ensure your seed generator date range extends close to the present so recent events appear in the test set.", + ], + )) + + # Issue: high invalid rate (>30% of samples were invalid) + HIGH_INVALID_THRESHOLD = 0.30 + if stats.total > 0 and stats.filter_invalid / stats.total > HIGH_INVALID_THRESHOLD: + pct = int(100 * stats.filter_invalid / stats.total) + issues.append(PrepareIssue( + message=( + f"{stats.filter_invalid}/{stats.total} samples ({pct}%) were marked invalid. " + "This suggests issues with dataset generation configuration." + ), + tips=[ + "Add more examples and bad_examples to guide the question generator toward more sensible questions.", + "Check the labeler configuration — if WebSearchLabeler can't find resolution info, samples are marked invalid.", + "Inspect a few invalid samples with dataset.flattened() to identify patterns.", + ], + )) + + # Issue: high dedup rate (>40% removed as duplicates) + HIGH_DEDUP_THRESHOLD = 0.40 + if stats.filter_kept > 0 and stats.dedup_removed / stats.filter_kept > HIGH_DEDUP_THRESHOLD: + pct = int(100 * stats.dedup_removed / stats.filter_kept) + issues.append(PrepareIssue( + message=( + f"{stats.dedup_removed}/{stats.filter_kept} samples ({pct}%) were duplicates. " + "The pipeline is generating repetitive or similar questions." + ), + tips=[ + "Add more diverse search queries to your seed generator to surface different source articles.", + "Increase interval_duration_days to spread seeds across more time periods.", + "Add bad_examples to your question generator showing the repetitive patterns to avoid.", + "Use more specific instructions in your question generator to encourage variety.", + ], + )) + + # Issue: high horizon filter rate (>50% filtered out by horizon) + HIGH_HORIZON_THRESHOLD = 0.50 + if stats.total > 0 and stats.filter_horizon / stats.total > HIGH_HORIZON_THRESHOLD: + pct = int(100 * stats.filter_horizon / stats.total) + horizon_desc = "" + if filter.days_to_resolution_range: + min_h, max_h = filter.days_to_resolution_range + if min_h is not None and max_h is not None: + horizon_desc = f" (required: {min_h}-{max_h} days)" + elif min_h is not None: + horizon_desc = f" (required: ≥{min_h} days)" + elif max_h is not None: + horizon_desc = f" (required: ≤{max_h} days)" + issues.append(PrepareIssue( + message=( + f"{stats.filter_horizon}/{stats.total} samples ({pct}%) fell outside the resolution horizon{horizon_desc}. " + ), + tips=[ + "Widen your days_to_resolution_range in FilterParams if your use case allows longer/shorter horizons.", + "Adjust your seed generator date range — questions from very recent seeds may not have resolved yet, " + "while questions from old seeds may exceed your max horizon.", + "Check that your question generator is producing questions with appropriate resolution timelines for your target horizon." + ], + )) + + return PrepareReport(stats=stats, issues=issues) + + +def _print_report(report: PrepareReport, verbose: bool) -> None: + """Print the report to stdout and raise if unhealthy (non-notebook path).""" + stats = report.stats + if verbose: + print(f"[prepare_for_training] Starting with {stats.total} samples") + + parts = [] + if stats.filter_invalid: + parts.append(f"{stats.filter_invalid} invalid") + if stats.filter_horizon: + part = f"{stats.filter_horizon} horizon" + if stats.filter_missing_resolution_date or stats.filter_missing_prediction_date: + sub = [] + if stats.filter_missing_resolution_date: + sub.append(f"{stats.filter_missing_resolution_date} missing resolution date") + if stats.filter_missing_prediction_date: + sub.append(f"{stats.filter_missing_prediction_date} missing prediction date") + part += f" ({', '.join(sub)})" + parts.append(part) + if stats.filter_context: + parts.append(f"{stats.filter_context} missing context") + print(f"[filter] Dropped {', '.join(parts)} → {stats.filter_kept} remain" if parts else f"[filter] {stats.filter_kept} remain (0 dropped)") + + if stats.dedup_removed > 0: + print(f"[dedup] Removed {stats.dedup_removed} duplicates ({stats.dedup_kept + stats.dedup_removed} → {stats.dedup_kept}). Top colliding keys:") + for k, c in stats.dedup_top_collisions: + q = repr(k[0])[:60] + ("..." if len(repr(k[0])) > 60 else "") + print(f" ({q}, {k[1]}): {c} samples → 1") + else: + print(f"[dedup] {stats.dedup_kept} remain (0 duplicates)") - if stats.split_strategy == "temporal": if stats.split_no_sort_key: print(f"[split] {stats.split_no_sort_key} samples had no prediction_date (dropped)") - if stats.split_leaky: - print(f"[split] {stats.split_leaky} train samples removed for leakage") - print(f"[split] Temporal split: {stats.split_train} train, {stats.split_test} test") - else: - print(f"[split] Random split (test_size={stats.split_test_size}): {stats.split_train} train, {stats.split_test} test") + n_leaked = stats.split_train_before - stats.split_train_after + if n_leaked: + print(f"[split] {n_leaked} train samples removed for leakage") + print(f"[split] Temporal split: {stats.split_train_after} train, {stats.split_test_after} test") + if not report.is_healthy: + lines = ["[prepare_for_training] Unhealthy split detected."] + for issue in report.issues: + lines.append(issue.message) + if issue.tips: + lines.append("Tips:\n" + "\n".join(f" - {t}" for t in issue.tips)) + raise ValueError("\n\n".join(lines)) -def filter_and_split( + +def prepare_for_training( dataset: "SampleDataset", *, - test_size: float = 0.2, - split_strategy: str = "temporal", - test_start: str | None = None, - drop_missing_context: bool = False, - days_to_resolution_range: DaysToResolutionRange = None, - random_state: int = 196, - filter_leaky_train: bool = True, - deduplicate_key_fn: Callable[[Sample], tuple[Any, ...]] | None = None, - verbose: bool = False, + filter: FilterParams | None = None, + dedup: DedupParams | None = None, + split: SplitParams | None = None, + verbose: bool = True, ) -> tuple["SampleDataset", "SampleDataset"]: """Prepare a dataset for model training: filter, deduplicate, split into train/test. @@ -644,42 +828,30 @@ def filter_and_split( Args: dataset: SampleDataset to prepare (samples are fetched via dataset.samples()). - test_size: Fraction of samples for the test set (0.0–1.0). Default 0.2. - split_strategy: 'temporal' (default) or 'random'. - test_start: ISO date string for temporal splits. Provide exactly one of - test_start or test_size for temporal splits. - drop_missing_context: If True, exclude samples with no context. - days_to_resolution_range: Optional (min_days, max_days) tuple. - random_state: Seed for reproducible random splits. - filter_leaky_train: When True and temporal, remove temporal leakage. - deduplicate_key_fn: Optional function to customize deduplication key. + filter: Controls validity filtering and horizon range. See :class:`FilterParams`. + dedup: Controls deduplication key. See :class:`DedupParams`. + split: Controls train/test split strategy, size, and leakage filtering. See :class:`SplitParams`. verbose: When True, print step-by-step stats. Returns: (train_dataset, test_dataset): SampleDatasets ready for training/eval. """ + filter = filter or FilterParams() + dedup = dedup or DedupParams() + split = split or SplitParams() + samples = dataset.samples() stats = PrepareStats(total=len(samples)) - filtered = filter_samples( - samples, - days_to_resolution_range=days_to_resolution_range, - drop_missing_context=drop_missing_context, - stats=stats, - ) - deduped = deduplicate_samples(filtered, key_fn=deduplicate_key_fn, stats=stats) - - train_ids, test_ids = train_test_split( - deduped, - split_strategy=split_strategy, - test_start=test_start, - filter_leaky_train=filter_leaky_train, - test_size=test_size, - random_state=random_state, - stats=stats, - ) + filtered = filter_samples(samples, filter, stats=stats) + deduped = deduplicate_samples(filtered, dedup, stats=stats) + train_ids, test_ids = train_test_split(deduped, split, stats=stats) - if verbose: - _print_stats(stats) + report = _build_report(stats, split=split, filter=filter) + from lightningrod._display import _is_notebook, display_prepare_report + if _is_notebook(): + display_prepare_report(report, verbose=verbose) + else: + _print_report(report, verbose=verbose) return dataset.subset(train_ids), dataset.subset(test_ids) \ No newline at end of file