diff --git a/notebooks/fine_tuning/01_golf_forecasting.ipynb b/notebooks/fine_tuning/01_golf_forecasting.ipynb index 1016780..1040c33 100644 --- a/notebooks/fine_tuning/01_golf_forecasting.ipynb +++ b/notebooks/fine_tuning/01_golf_forecasting.ipynb @@ -1,491 +1,1140 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "header", - "metadata": {}, - "source": [ - "# Golf Forecasting\n", - "\n", - "Generate a forecasting dataset about professional golf (tournaments, majors, rankings) using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments." - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "aaddd17b", - "metadata": {}, - "outputs": [ + "cells": [ { - "data": { - "text/plain": [ - "True" + "cell_type": "markdown", + "id": "header", + "metadata": {}, + "source": [ + "# Golf Forecasting\n", + "\n", + "Generate a forecasting dataset about professional golf (tournaments, majors, rankings) using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%pip install lightningrod-ai python-dotenv pandas\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()\n", - "\n", - "from datetime import datetime\n", - "\n", - "import pandas as pd\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()" - ] - }, - { - "cell_type": "markdown", - "id": "part1", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." 
- ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "sdk-setup", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "id": "a04964c0", - "metadata": {}, - "source": [ - "## Build the pipeline\n", - "\n", - "Configure the pipeline with domain-specific instructions and examples for golf forecasting." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "config", - "metadata": {}, - "outputs": [], - "source": [ - "instructions = \"\"\"\n", - "Generate binary forecasting questions about professional golf across all major tours and events.\n", - "\n", - "Cover what golf fans bet on: tournament outcomes, cuts, matchups, majors, team events, season races, world rankings, and player milestones.\n", - "\n", - "Questions should be specific, verifiable, and span the full probability spectrum.\n", - "\"\"\"\n", - "\n", - "good_examples = [\n", - " \"Will Scottie Scheffler win the 2025 Masters?\",\n", - " \"Will the 2025 US Open winning score be under par?\",\n", - " \"Will Tiger Woods make the cut at the 2025 Masters?\",\n", - " \"Will Rory McIlroy finish top 5 at the 2025 US Open?\",\n", - " \"Will any LIV player win a major championship in 2025?\",\n", - " \"Will Europe win the 2025 Ryder Cup?\",\n", - " \"Will any player win 4+ PGA Tour events in 2025?\",\n", - " \"Will Scottie Scheffler remain world #1 through June 2025?\",\n", - " \"Will a first-time major winner emerge at the 2025 PGA Championship?\",\n", - " \"Will Nelly Korda win the 2025 US Women's Open?\",\n", - "]\n", - "\n", - "bad_examples = [\n", - " \"Will someone win the tournament? (obvious)\",\n", - " \"Will golf be exciting? (subjective)\",\n", - " \"Will there be birdies? 
(trivial)\",\n", - "]\n", - "\n", - "search_queries = [\n", - " \"PGA Tour\",\n", - " \"LIV Golf\",\n", - " \"LPGA\",\n", - " \"golf major championship\",\n", - " \"Ryder Cup Presidents Cup\",\n", - " \"golf world rankings\",\n", - " \"professional golf\",\n", - " \"women's golf\",\n", - " \"European Tour golf\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "pipeline", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import (\n", - " BinaryAnswerType,\n", - " NewsSeedGenerator,\n", - " ForwardLookingQuestionGenerator,\n", - " NewsContextGenerator,\n", - " WebSearchLabeler,\n", - " QuestionPipeline,\n", - ")\n", - "\n", - "answer_type = BinaryAnswerType()\n", - "\n", - "pipeline = QuestionPipeline(\n", - " seed_generator=NewsSeedGenerator(\n", - " start_date=datetime(2024, 6, 1),\n", - " end_date=datetime(2026, 1, 1),\n", - " interval_duration_days=14,\n", - " search_query=search_queries,\n", - " articles_per_search=10,\n", - " ),\n", - " question_generator=ForwardLookingQuestionGenerator(\n", - " instructions=instructions,\n", - " examples=good_examples,\n", - " bad_examples=bad_examples,\n", - " answer_type=answer_type,\n", - " questions_per_seed=5,\n", - " ),\n", - " context_generators=[\n", - " NewsContextGenerator(\n", - " articles_per_query=3,\n", - " num_search_queries=3,\n", - " num_articles=5,\n", - " )\n", - " ],\n", - " labeler=WebSearchLabeler(answer_type=answer_type),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "324a35cc", - "metadata": {}, - "source": [ - "## Run the pipeline\n", - "\n", - "This will collect news articles, generate questions, and find answers. Use `max_questions` to limit the run for testing." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "7f62e9c4", - "metadata": {}, - "outputs": [ + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f25abaaeb92e42f1bca02f0ea69c7f15", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" + "cell_type": "code", + "execution_count": 1, + "id": "aaddd17b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%pip install lightningrod-ai python-dotenv pandas openai\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()\n", + "\n", + "from datetime import datetime\n", + "\n", + "import pandas as pd\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "
\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" + "cell_type": "markdown", + "id": "part1", + "metadata": {}, + "source": [ + "## Set up the client\n", + "\n", + "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/?redirect=/api) to get your API key and **$50 of free credits**." + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "100 samples (87.0% valid)\n" - ] - } - ], - "source": [ - "dataset = lr.transforms.run(pipeline, max_questions=100, name=\"Golf forecasting\")\n", - "samples = dataset.download()\n", - "\n", - "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n", - "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")" - ] - }, - { - "cell_type": "markdown", - "id": "fdbi5exhd6c", - "metadata": {}, - "source": [ - "## Prepare the dataset\n", - "\n", - "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. We filter by `date_close <= today` to only include questions that have already resolved." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "upload", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 2, + "id": "sdk-setup", + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import LightningRod\n", + "from lightningrod.utils import config\n", + "\n", + "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", + "lr = LightningRod(api_key=api_key)" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: 37 rows, 32.4% yes\n", - "Test: 18 rows, 44.4% yes\n" - ] - } - ], - "source": [ - "from lightningrod import filter_and_split\n", - "\n", - "train_dataset, test_dataset = filter_and_split(\n", - " dataset,\n", - " test_size=0.2,\n", - " split_strategy=\"temporal\",\n", - " days_to_resolution_range=(1, None), # at least 1 day to resolution\n", - ")\n", - "\n", - "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n", - " data = ds.flattened()\n", - " yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n", - " print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")\n", - " display(pd.DataFrame(data).head())" - ] - }, - { - "cell_type": "markdown", - "id": "b2e9efba", - "metadata": {}, - "source": [ - "## Uploading the dataset to HuggingFace\n", - "\n", - "Once we have a training-ready dataset, we can push it to Hugging Face for sharing or downstream use." - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "ae7c826b", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "id": "a04964c0", + "metadata": {}, + "source": [ + "## Build the pipeline\n", + "\n", + "Configure the pipeline with domain-specific instructions and examples for golf forecasting." 
+ ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n", - "Train: 37 rows, Test: 18 rows\n", - "Columns: ['question_text', 'date_close', 'event_date', 'resolution_criteria', 'prediction_date', 'label', 'answer_type', 'label_confidence'] ...\n" - ] + "cell_type": "code", + "execution_count": 3, + "id": "config", + "metadata": {}, + "outputs": [], + "source": [ + "instructions = \"\"\"\n", + "Generate binary forecasting questions about professional golf across all major tours and events.\n", + "\n", + "Cover what golf fans bet on: tournament outcomes, cuts, matchups, majors, team events, season races, world rankings, and player milestones.\n", + "\n", + "Questions should be specific, verifiable, and span the full probability spectrum.\n", + "\"\"\"\n", + "\n", + "good_examples = [\n", + " \"Will Scottie Scheffler win the 2025 Masters?\",\n", + " \"Will the 2025 US Open winning score be under par?\",\n", + " \"Will Tiger Woods make the cut at the 2025 Masters?\",\n", + " \"Will Rory McIlroy finish top 5 at the 2025 US Open?\",\n", + " \"Will any LIV player win a major championship in 2025?\",\n", + " \"Will Europe win the 2025 Ryder Cup?\",\n", + " \"Will any player win 4+ PGA Tour events in 2025?\",\n", + " \"Will Scottie Scheffler remain world #1 through June 2025?\",\n", + " \"Will a first-time major winner emerge at the 2025 PGA Championship?\",\n", + " \"Will Nelly Korda win the 2025 US Women's Open?\",\n", + "]\n", + "\n", + "bad_examples = [\n", + " \"Will someone win the tournament? 
(obvious)\",\n", + " \"Will golf be exciting? (subjective)\",\n", + " \"Will there be birdies? (trivial)\",\n", + "]\n", + "\n", + "search_queries = [\n", + " \"PGA Tour\",\n", + " \"LIV Golf\",\n", + " \"LPGA\",\n", + " \"golf major championship\",\n", + " \"Ryder Cup Presidents Cup\",\n", + " \"golf world rankings\",\n", + " \"professional golf\",\n", + " \"women's golf\",\n", + " \"European Tour golf\",\n", + "]" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "453c59f9daf345828e087cc2a47af33f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Uploading the dataset shards: 0%| | 0/1 [00:00, ? shards/s]" + "cell_type": "code", + "execution_count": 4, + "id": "pipeline", + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import (\n", + " BinaryAnswerType,\n", + " NewsSeedGenerator,\n", + " ForwardLookingQuestionGenerator,\n", + " NewsContextGenerator,\n", + " WebSearchLabeler,\n", + " QuestionPipeline,\n", + ")\n", + "\n", + "answer_type = BinaryAnswerType()\n", + "\n", + "pipeline = QuestionPipeline(\n", + " seed_generator=NewsSeedGenerator(\n", + " start_date=datetime(2024, 6, 1),\n", + " end_date=datetime(2026, 1, 1),\n", + " interval_duration_days=14,\n", + " search_query=search_queries,\n", + " articles_per_search=10,\n", + " ),\n", + " question_generator=ForwardLookingQuestionGenerator(\n", + " instructions=instructions,\n", + " examples=good_examples,\n", + " bad_examples=bad_examples,\n", + " answer_type=answer_type,\n", + " questions_per_seed=5,\n", + " ),\n", + " context_generators=[\n", + " NewsContextGenerator(\n", + " articles_per_query=3,\n", + " num_search_queries=3,\n", + " num_articles=5,\n", + " )\n", + " ],\n", + " labeler=WebSearchLabeler(answer_type=answer_type),\n", + ")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "317b75c5d8fb4a9589a7e06b866f0bbc", - "version_major": 2, - 
"version_minor": 0 - }, - "text/plain": [ - "Creating parquet from Arrow format: 0%| | 0/1 [00:00, ?ba/s]" + "cell_type": "markdown", + "id": "324a35cc", + "metadata": {}, + "source": [ + "## Run the pipeline\n", + "\n", + "This will collect news articles, generate questions, and find answers. Use `max_questions` to limit the run for testing." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "83958fad318c406b986919f051bf3c1a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing Files (0 / 0): | | 0.00B / 0.00B " + "cell_type": "code", + "execution_count": 5, + "id": "7f62e9c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Pipeline Completed │\n", + "│ │\n", + "│ Total cost: $47.67 │\n", + "│ │\n", + "│ ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ │\n", + "│ ┃ Step ┃ Progress ┃ In ┃ Out ┃ Rejected ┃ Errors ┃ Rejection Reasons ┃ Duration ┃ │\n", + "│ ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ │\n", + "│ │ NewsSeedGenerator… │ Complete │ 20 │ 200 │ 0 │ 0 │ - │ 19s │ │\n", + "│ │ ForwardLookingQue… │ Complete │ 200 │ 965 │ 33 │ 0 │ date_close not │ 15s │ │\n", + "│ │ │ │ │ │ │ │ after event_date │ │ │\n", + "│ │ │ │ │ │ │ │ (33) │ │ │\n", + "│ │ WebSearchLabelerT… │ Complete │ 965 │ 796 │ 169 │ 0 │ Undetermined label │ 1m 8s │ │\n", + "│ │ │ │ │ │ │ │ (164), Resolution │ │ │\n", + "│ │ │ │ │ │ │ │ date is before │ │ │\n", + "│ │ │ │ │ │ │ │ seed creation date │ │ │\n", + "│ │ │ │ │ │ │ │ (5) │ │ │\n", + "│ │ NewsContextGenera… │ Complete │ 796 │ 795 │ 1 │ 0 │ <failed_attempts> │ 11m 37s │ │\n", + "│ │ │ │ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ │ <generation │ │ │\n", + 
"│ │ │ │ │ │ │ │ number=\"1\"> │ │ │\n", + "│ │ │ │ │ │ │ │ <exception> │ │ │\n", + "│ │ │ │ │ │ │ │ Request timed │ │ │\n", + "│ │ │ │ │ │ │ │ out. │ │ │\n", + "│ │ │ │ │ │ │ │ </exception> │ │ │\n", + "│ │ │ │ │ │ │ │ <completion> │ │ │\n", + "│ │ │ │ │ │ │ │ None │ │ │\n", + "│ │ │ │ │ │ │ │ </completion> │ │ │\n", + "│ │ │ │ │ │ │ │ </generation> │ │ │\n", + "│ │ │ │ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ │ <generation │ │ │\n", + "│ │ │ │ │ │ │ │ number=\"2\"> │ │ │\n", + "│ │ │ │ │ │ │ │ <exception> │ │ │\n", + "│ │ │ │ │ │ │ │ Connection │ │ │\n", + "│ │ │ │ │ │ │ │ error. │ │ │\n", + "│ │ │ │ │ │ │ │ </exception> │ │ │\n", + "│ │ │ │ │ │ │ │ <completion> │ │ │\n", + "│ │ │ │ │ │ │ │ None │ │ │\n", + "│ │ │ │ │ │ │ │ </completion> │ │ │\n", + "│ │ │ │ │ │ │ │ </generation> │ │ │\n", + "│ │ │ │ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ │ </failed_attempts> │ │ │\n", + "│ │ │ │ │ │ │ │ │ │ │\n", + "│ │ │ │ │ │ │ │ <last_exception> │ │ │\n", + "│ │ │ │ │ │ │ │ Connection │ │ │\n", + "│ │ │ │ │ │ │ │ error. 
│ │ │\n", + "│ │ │ │ │ │ │ │ </last_exception> │ │ │\n", + "│ │ │ │ │ │ │ │ (1) │ │ │\n", + "│ └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[92m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[1;92m>> Pipeline Completed\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[1mTotal cost:\u001b[0m \u001b[92m$47.67\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mStep \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mProgress \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m In\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mOut\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejected\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mErrors\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejection Reasons \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mDuration\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mNewsSeedGenerator…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 20 │ 200 │ \u001b[2m 0\u001b[0m │ \u001b[2m 
0\u001b[0m │ \u001b[2m- \u001b[0m │ 19s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mForwardLookingQue…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 200 │ 965 │ \u001b[91m 33\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2mdate_close not \u001b[0m │ 15s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mafter event_date \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2m(33) \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mWebSearchLabelerT…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 965 │ 796 │ \u001b[91m 169\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2mUndetermined label\u001b[0m │ 1m 8s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2m(164), Resolution \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mdate is before \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mseed creation date\u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2m(5) \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mNewsContextGenera…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 796 │ 795 │ \u001b[91m 1\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2m
| \n", + " | sample_id | \n", + "is_valid | \n", + "question_text | \n", + "date_close | \n", + "event_date | \n", + "resolution_criteria | \n", + "prediction_date | \n", + "label | \n", + "answer_type | \n", + "label_confidence | \n", + "... | \n", + "reasoning | \n", + "answer_sources | \n", + "seed_text | \n", + "seed_url | \n", + "seed_creation_date | \n", + "seed_search_query | \n", + "context | \n", + "meta_sample_id | \n", + "meta_parent_sample_id | \n", + "meta_processing_time_ms | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "23a607a2-e9db-45a9-8e33-67cd16a32b56 | \n", + "True | \n", + "Will the Eastern Michigan University women's g... | \n", + "2025-05-10T00:00:00 | \n", + "2024-07-15T00:00:00 | \n", + "The question resolves to 'Yes' if Eastern Mich... | \n", + "2024-07-15T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The Eastern Michigan University (EMU) women's ... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Eastern Michigan Athletics\\nCaterina Don Named... | \n", + "https://emueagles.com/news/2024/7/9/womens-gol... | \n", + "2024-07-15T00:00:00 | \n", + "women's golf | \n", + "[{'rendered_context': '', 'search_query': 'Eas... | \n", + "fa146afa-b53f-48c1-8d2d-a81ce2dec41b | \n", + "0107be94-88d9-4068-a355-ec38b8691376 | \n", + "844641.292 | \n", + "
| 1 | \n", + "2736b5ea-b6b0-4fde-a237-96f6a3d9ee86 | \n", + "True | \n", + "Will an Arizona Wildcats player be named the B... | \n", + "2025-05-01T00:00:00 | \n", + "2024-07-15T00:00:00 | \n", + "The question resolves to 'Yes' if the Big 12 C... | \n", + "2024-07-15T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The Arizona Wildcats officially joined the Big... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "TUCSON, Ariz. – Arizona Women's Golf Head Coac... | \n", + "https://arizonawildcats.com/news/2024/7/15/bra... | \n", + "2024-07-15T00:00:00 | \n", + "women's golf | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Arizon... | \n", + "ece95e4d-151b-4af8-936a-5c7e18276b97 | \n", + "f7429e77-a524-46dd-ae47-821794e79938 | \n", + "988488.435 | \n", + "
| 2 | \n", + "2eae4276-f449-45e9-8973-e760b6d36d61 | \n", + "True | \n", + "Will Caterina Don remain in her role as the As... | \n", + "2025-05-31T00:00:00 | \n", + "2024-07-15T00:00:00 | \n", + "The question resolves to 'Yes' if Caterina Don... | \n", + "2024-07-15T00:00:00 | \n", + "1 | \n", + "binary | \n", + "0.95 | \n", + "... | \n", + "Caterina Don was hired as the first full-time ... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Eastern Michigan Athletics\\nCaterina Don Named... | \n", + "https://emueagles.com/news/2024/7/9/womens-gol... | \n", + "2024-07-15T00:00:00 | \n", + "women's golf | \n", + "[{'rendered_context': '', 'search_query': 'Cat... | \n", + "4fbfad5d-b425-4cfe-b04d-1db1279c80a8 | \n", + "0107be94-88d9-4068-a355-ec38b8691376 | \n", + "485377.689 | \n", + "
| 3 | \n", + "35b2be70-0b7c-4b1b-b82f-2e3f0d3b61d8 | \n", + "True | \n", + "Will the University of North Carolina women's ... | \n", + "2025-04-15T00:00:00 | \n", + "2024-07-15T00:00:00 | \n", + "The question resolves to Yes if the UNC women'... | \n", + "2024-07-15T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The University of North Carolina women's golf ... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "University of North Carolina Athletics\\nNeff's... | \n", + "https://goheels.com/news/2024/7/15/neffs-contr... | \n", + "2024-07-15T00:00:00 | \n", + "women's golf | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] 2024-2... | \n", + "05a3aff2-dee2-4111-a0f4-c085ba138679 | \n", + "c9839639-2fd2-4f0a-ad52-4566bbde89be | \n", + "1052828.897 | \n", + "
| 4 | \n", + "382c5c2e-db8f-4b8c-bbe8-6fc4e55edf53 | \n", + "True | \n", + "Will the University of Arizona Women's Golf te... | \n", + "2025-04-15T00:00:00 | \n", + "2024-07-15T00:00:00 | \n", + "The question resolves to 'Yes' if the Universi... | \n", + "2024-07-15T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The University of Arizona Women's Golf team wo... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "TUCSON, Ariz. – Arizona Women's Golf Head Coac... | \n", + "https://arizonawildcats.com/news/2024/7/15/bra... | \n", + "2024-07-15T00:00:00 | \n", + "women's golf | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Arizon... | \n", + "1fd71efb-9506-4c76-b1b6-d8eed1a30ac7 | \n", + "f7429e77-a524-46dd-ae47-821794e79938 | \n", + "1089828.234 | \n", + "
5 rows × 21 columns
\n", + "| \n", + " | sample_id | \n", + "is_valid | \n", + "question_text | \n", + "date_close | \n", + "event_date | \n", + "resolution_criteria | \n", + "prediction_date | \n", + "label | \n", + "answer_type | \n", + "label_confidence | \n", + "... | \n", + "reasoning | \n", + "answer_sources | \n", + "seed_text | \n", + "seed_url | \n", + "seed_creation_date | \n", + "seed_search_query | \n", + "context | \n", + "meta_sample_id | \n", + "meta_parent_sample_id | \n", + "meta_processing_time_ms | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "69d062b4-681c-431f-9e01-e13befba3ea0 | \n", + "True | \n", + "Will Luke Clanton finish in the top 10 of the ... | \n", + "2025-07-07T00:00:00 | \n", + "2025-06-24T00:00:00 | \n", + "The question resolves to 'Yes' if Luke Clanton... | \n", + "2025-06-24T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.0 | \n", + "... | \n", + "Luke Clanton participated in the 2025 John Dee... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: No. 15 Ben Griffin, rising star Luke Cl... | \n", + "https://www.wqad.com/article/sports/john-deere... | \n", + "2025-06-24T00:00:00 | \n", + "golf world rankings | \n", + "[{'rendered_context': '', 'search_query': 'Luk... | \n", + "7cdbf935-029e-4ce5-bcb4-d6cbf016303a | \n", + "c96f580b-8dff-42a0-90dc-3e5c888680c6 | \n", + "489996.714 | \n", + "
| 1 | \n", + "c3ef9b69-79e5-42c4-a26c-b0d02bf82abb | \n", + "True | \n", + "Will Ben Griffin be ranked in the top 10 of th... | \n", + "2025-07-07T00:00:00 | \n", + "2025-06-24T00:00:00 | \n", + "The question resolves to 'Yes' if Ben Griffin'... | \n", + "2025-06-24T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.0 | \n", + "... | \n", + "Ben Griffin was ranked No. 17 in the Official ... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: No. 15 Ben Griffin, rising star Luke Cl... | \n", + "https://www.wqad.com/article/sports/john-deere... | \n", + "2025-06-24T00:00:00 | \n", + "golf world rankings | \n", + "[{'rendered_context': '', 'search_query': 'Ben... | \n", + "c3a577ab-069d-4857-8c52-9fce6f44e876 | \n", + "c96f580b-8dff-42a0-90dc-3e5c888680c6 | \n", + "493772.150 | \n", + "
| 2 | \n", + "e1d9f94a-8fa0-4caf-a85b-2f2602e7e9ae | \n", + "True | \n", + "Will Luke Clanton outscore Ben Griffin in the ... | \n", + "2025-07-04T00:00:00 | \n", + "2025-06-24T00:00:00 | \n", + "The question resolves to 'Yes' if Luke Clanton... | \n", + "2025-06-24T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.0 | \n", + "... | \n", + "The first round of the 2025 John Deere Classic... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: No. 15 Ben Griffin, rising star Luke Cl... | \n", + "https://www.wqad.com/article/sports/john-deere... | \n", + "2025-06-24T00:00:00 | \n", + "golf world rankings | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Player... | \n", + "41c1a085-7b54-4a59-b7c8-145a2c9f225f | \n", + "c96f580b-8dff-42a0-90dc-3e5c888680c6 | \n", + "693496.445 | \n", + "
| 3 | \n", + "093a04a1-03f4-429c-a8cd-c2f7c8ea098c | \n", + "True | \n", + "Will Jordan Smith win the 2025 Italian Open? | \n", + "2025-06-30T00:00:00 | \n", + "2025-06-25T00:00:00 | \n", + "This question resolves to Yes if Jordan Smith ... | \n", + "2025-06-25T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.0 | \n", + "... | \n", + "The 2025 Italian Open (golf) took place from J... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: 2025 Italian Open betting tips: Our exp... | \n", + "https://www.todays-golfer.com/news-and-events/... | \n", + "2025-06-25T00:00:00 | \n", + "European Tour golf | \n", + "[{'rendered_context': '', 'search_query': 'Jor... | \n", + "cebd9023-809d-441b-9ef5-251381880f5c | \n", + "20e82969-aad5-477a-8e2e-cd641f7d7eec | \n", + "493786.983 | \n", + "
| 4 | \n", + "24bfbd2f-7a6c-4383-82aa-94d73f429685 | \n", + "True | \n", + "Will Eddie Pepperell win at least one tourname... | \n", + "2025-11-30T00:00:00 | \n", + "2025-06-25T00:00:00 | \n", + "The question resolves to 'Yes' if Eddie Pepper... | \n", + "2025-06-25T00:00:00 | \n", + "0 | \n", + "binary | \n", + "0.9 | \n", + "... | \n", + "The close date is 2025-11-30, and the question... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: Eddie Pepperell feeling refreshed after... | \n", + "https://www.europeantour.com/dpworld-tour/news... | \n", + "2025-06-25T00:00:00 | \n", + "European Tour golf | \n", + "[{'rendered_context': '', 'search_query': 'Edd... | \n", + "a43d140a-a541-4537-8aab-6523ecbf79ce | \n", + "3fa59bd7-f4b8-4ea9-b5b2-d8cba7d5ce0d | \n", + "512717.986 | \n", + "
5 rows × 21 columns
\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Training COMPLETED │\n", + "│ │\n", + "│ Job: Golf forecasting │\n", + "│ │\n", + "│ Reward: latest -0.9948 avg -0.8261 (11 steps) (higher is better) │\n", + "│ │\n", + "│ Cost: $0.19 │\n", + "│ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m Golf forecasting \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -0.9948 avg -0.8261 (11 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.19 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job 6c82c197-0627-4ee8-954c-d5ddb93e66f2 completed with status: COMPLETED\n", + "Trained model ID: checkpoint:6c82c197-0627-4ee8-954c-d5ddb93e66f2\n" + ] + } + ], + "source": [ + "job = lr.training.run(config, dataset=train_dataset, name=\"Golf forecasting\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Trained model ID: 
{job.model_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4c74226d", + "metadata": {}, + "source": [ + "## Inference with your trained model\n", + "\n", + "Use `lr.predict()` to run inference with your trained model." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/datasets/bart/golf-forecasting-demo/commit/75aa82cb9a4bc3ff669cf29a4d2e0f260f6aac58', commit_message='Upload dataset', commit_description='', oid='75aa82cb9a4bc3ff669cf29a4d2e0f260f6aac58', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bart/golf-forecasting-demo', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bart/golf-forecasting-demo'), pr_revision=None, pr_num=None)" + "cell_type": "code", + "execution_count": null, + "id": "ba1dcfc5", + "metadata": {}, + "outputs": [], + "source": [ + "print(lr.predict(job.model_id, \"Will Scottie Scheffler win the 2026 Masters?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "76781ccb", + "metadata": {}, + "source": [ + "## Run evals on trained model\n", + "\n", + "Run test evals on your trained model against the test dataset. The eval job runs the model on the dataset and reports metrics." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "0e81dc80", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Eval COMPLETED │\n", + "│ │\n", + "│ ID: cd970c00-d5b9-4db3-ac1f-f4815960abb0 │\n", + "│ Model: checkpoint:6c82c197-0627-4ee8-954c-d5ddb93e66f2 │\n", + "│ Dataset: 708f1623-6f06-4897-bb2e-dd58b7aebd45 │\n", + "│ │\n", + "│ ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ │\n", + "│ ┃ Metric ┃ base ┃ trained ┃ │\n", + "│ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ │\n", + "│ │ brier_score │ 0.2784 │ 0.2377 │ │\n", + "│ │ ece │ 0.2207 │ 0.1597 │ │\n", + "│ │ mean_reward │ -0.9210 │ -0.8026 │ │\n", + "│ │ mean_valid_reward │ -0.9210 │ -0.8026 │ │\n", + "│ │ n_samples │ 143 │ 143 │ │\n", + "│ │ n_valid │ 143 │ 143 │ │\n", + "│ │ parse_rate │ 1.0000 │ 1.0000 │ │\n", + "│ │ total_cost │ 0.0084 │ 0.0084 │ │\n", + "│ │ total_input_tokens │ 115968 │ 115968 │ │\n", + "│ │ total_output_tokens │ 1403 │ 1416 │ │\n", + "│ └─────────────────────┴─────────┴─────────┘ │\n", + "│ │\n", + "│ Cost: $0.02 │\n", + "│ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Eval COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mID:\u001b[0m cd970c00-d5b9-4db3-ac1f-f4815960abb0 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mModel:\u001b[0m checkpoint:6c82c197-0627-4ee8-954c-d5ddb93e66f2 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mDataset:\u001b[0m 708f1623-6f06-4897-bb2e-dd58b7aebd45 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m 
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mMetric \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m base\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mtrained\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mbrier_score \u001b[0m\u001b[2m \u001b[0m│ 0.2784 │ 0.2377 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mece \u001b[0m\u001b[2m \u001b[0m│ 0.2207 │ 0.1597 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_reward \u001b[0m\u001b[2m \u001b[0m│ -0.9210 │ -0.8026 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_valid_reward \u001b[0m\u001b[2m \u001b[0m│ -0.9210 │ -0.8026 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_samples \u001b[0m\u001b[2m \u001b[0m│ 143 │ 143 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_valid \u001b[0m\u001b[2m \u001b[0m│ 143 │ 143 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mparse_rate \u001b[0m\u001b[2m \u001b[0m│ 1.0000 │ 1.0000 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_cost \u001b[0m\u001b[2m \u001b[0m│ 0.0084 │ 0.0084 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_input_tokens \u001b[0m\u001b[2m \u001b[0m│ 115968 │ 115968 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_output_tokens\u001b[0m\u001b[2m \u001b[0m│ 1403 │ 1416 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m └─────────────────────┴─────────┴─────────┘ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.02 
\u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "62718605", + "metadata": {}, + "source": [ + "> Note: the trained model checkpoint will only be available for 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" } - ], - "source": [ - "%pip install datasets -q\n", - "\n", - "from datasets import Dataset, DatasetDict\n", - "from lightningrod.utils import config\n", - "\n", - "dataset = DatasetDict({\n", - " \"train\": Dataset.from_list(train_dataset.flattened()),\n", - " \"test\": Dataset.from_list(test_dataset.flattened()),\n", - "})\n", - "print(f\"Train: {len(dataset['train'])} rows, Test: {len(dataset['test'])} rows\")\n", - "print(\"Columns:\", dataset[\"train\"].column_names[:8], \"...\")\n", - "\n", - "DATASET_PATH = f\"{config.get_config_value('HF_USERNAME')}/golf-forecasting-demo\"\n", - "dataset.push_to_hub(DATASET_PATH, token=config.get_config_value(\"HF_ACCESS_TOKEN\"))" - ] - }, - { - "cell_type": "markdown", - "id": "part2", - "metadata": {}, - "source": [ - "## Model Training\n", - "\n", - "We used the generated dataset above to fine-tune a forecasting model via RL on 3,178 forecasting questions, surpassing GPT-5 performance.\n", - "\n", - "**For more details on methods, results, and data:**\n", - "- **[Golf-Forecaster Model](https://huggingface.co/LightningRodLabs/Golf-Forecaster)**\n", - "- **[Golf-Forecaster Dataset](https://huggingface.co/datasets/LightningRodLabs/GolfForecasting)**\n", - "\n", - "\n", - "\n", 
- "**Coming Soon:** Seamlessly generate datasets, fine-tune, and evaluate your own forecasting models end-to-end on the Lightningrod platform.\n", - " \n", - "\ud83d\udc49 [Sign up to get early access and updates.](https://lightningrod.ai/)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python (lightningrod-sdk)", - "language": "python", - "name": "lightningrod-sdk" + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/fine_tuning/02_trump_forecasting.ipynb b/notebooks/fine_tuning/02_trump_forecasting.ipynb index 4cb0092..31e5712 100644 --- a/notebooks/fine_tuning/02_trump_forecasting.ipynb +++ b/notebooks/fine_tuning/02_trump_forecasting.ipynb @@ -1,479 +1,1168 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "7d4d26cb", - "metadata": {}, - "source": [ - "# WWTD-2025 (What Would Trump Do?)\n", - "\n", - "Generate a forecasting dataset about Trump's actions, decisions, and statements using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments\u2014including evaluation with and without context." 
- ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "6f2c4443", - "metadata": {}, - "outputs": [ + "cells": [ { - "data": { - "text/plain": [ - "True" + "cell_type": "markdown", + "id": "7d4d26cb", + "metadata": {}, + "source": [ + "# WWTD-2025 (What Would Trump Do?)\n", + "\n", + "Generate a forecasting dataset about Trump's actions, decisions, and statements using the LightningRod SDK. This example showcases dataset generation, preparation with SDK utils, and training results from our experiments—including evaluation with and without context." ] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "%pip install lightningrod-ai python-dotenv pandas\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()\n", - "\n", - "from datetime import datetime\n", - "\n", - "import pandas as pd\n", - "from dotenv import load_dotenv\n", - "\n", - "load_dotenv()" - ] - }, - { - "cell_type": "markdown", - "id": "f523d274", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "7ca71c31", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "id": "082d4f24", - "metadata": {}, - "source": [ - "## Build the pipeline\n", - "\n", - "Configure the pipeline with domain-specific instructions and examples for Trump-related forecasting." 
- ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "faccbe0a", - "metadata": {}, - "outputs": [], - "source": [ - "instructions = \"\"\"\n", - "Generate binary forecasting questions about Trump's actions, decisions, positions, and statements.\n", - "Questions should be diverse, related to the content, and should evenly cover the full range from very likely to very unlikely.\n", - "Horizon: outcomes should be known within 2 months of the question date, and may be known much sooner.\n", - "Criteria: binary outcome, exact dates, self-contained, verifiable via web search, newsworthy.\n", - "\"\"\"\n", - "\n", - "good_examples = [\n", - " \"Will Trump impose 25% tariffs on all goods from Canada by February 1, 2025?\",\n", - " \"Will Trump issue pardons to January 6 defendants within his first week in office?\",\n", - " \"Will Pete Hegseth be confirmed as Secretary of Defense by February 15, 2025?\",\n", - " \"Will Trump sign an executive order to keep TikTok operational in the US by January 31, 2025?\",\n", - " \"Will Kash Patel be confirmed as FBI Director by March 1, 2025?\",\n", - "]\n", - "\n", - "bad_examples = [\n", - " \"Will Trump do something controversial? (too vague)\",\n", - " \"Will Trump be in the news? (obvious)\",\n", - " \"Will tariffs be imposed? 
(needs specifics)\",\n", - "]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "4ce2d710", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import (\n", - " BinaryAnswerType,\n", - " NewsSeedGenerator,\n", - " ForwardLookingQuestionGenerator,\n", - " NewsContextGenerator,\n", - " WebSearchLabeler,\n", - " QuestionPipeline,\n", - ")\n", - "\n", - "answer_type = BinaryAnswerType()\n", - "\n", - "pipeline = QuestionPipeline(\n", - " seed_generator=NewsSeedGenerator(\n", - " start_date=datetime(2025, 1, 1),\n", - " end_date=datetime(2026, 1, 1),\n", - " interval_duration_days=7,\n", - " search_query=[\n", - " \"Donald Trump domestic policy agenda\",\n", - " \"Donald Trump trade and tariff actions\",\n", - " \"Donald Trump foreign policy decisions\",\n", - " \"Donald Trump interviews and press appearances\",\n", - " \"Donald Trump lawsuits and court rulings\",\n", - " ],\n", - " articles_per_search=10,\n", - " ),\n", - " question_generator=ForwardLookingQuestionGenerator(\n", - " instructions=instructions,\n", - " examples=good_examples,\n", - " bad_examples=bad_examples,\n", - " answer_type=answer_type,\n", - " questions_per_seed=20,\n", - " ),\n", - " context_generators=[\n", - " NewsContextGenerator(\n", - " articles_per_query=3,\n", - " num_search_queries=1,\n", - " num_articles=5,\n", - " )\n", - " ],\n", - " labeler=WebSearchLabeler(answer_type=answer_type),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "1603b3de", - "metadata": {}, - "source": [ - "## Run the pipeline\n", - "\n", - "This will collect news articles, generate questions, and find answers. Use `max_questions` to limit the run for testing." 
- ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "4de4b87c", - "metadata": {}, - "outputs": [ + }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "80f2463c20044d2c9ab2ace831e2adff", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Output()" + "cell_type": "code", + "execution_count": 1, + "id": "6f2c4443", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 1, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "%pip install lightningrod-ai python-dotenv pandas openai\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()\n", + "\n", + "from datetime import datetime\n", + "\n", + "import pandas as pd\n", + "from dotenv import load_dotenv\n", + "\n", + "load_dotenv()" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/html": [ - "\n" - ], - "text/plain": [] - }, - "metadata": {}, - "output_type": "display_data" + "cell_type": "markdown", + "id": "f523d274", + "metadata": {}, + "source": [ + "## Set up the client\n", + "\n", + "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/?redirect=/api) to get your API key and **$50 of free credits**." + ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "178 samples (46.1% valid)\n" - ] - } - ], - "source": [ - "dataset = lr.transforms.run(pipeline, max_questions=500, name=\"WWTD-2025\")\n", - "\n", - "samples = dataset.download()\n", - "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n", - "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")" - ] - }, - { - "cell_type": "markdown", - "id": "91866bf7", - "metadata": {}, - "source": [ - "## Prepare the dataset\n", - "\n", - "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. 
We filter by `date_close <= today` to only include questions that have already resolved." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5e4d3f2a", - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 2, + "id": "7ca71c31", + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import LightningRod\n", + "from lightningrod.utils import config\n", + "\n", + "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", + "lr = LightningRod(api_key=api_key)" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: 39 rows, 12.8% yes\n", - "Test: 13 rows, 38.5% yes\n" - ] - } - ], - "source": [ - "from lightningrod import filter_and_split\n", - "\n", - "train_dataset, test_dataset = filter_and_split(\n", - " dataset,\n", - " test_size=0.2,\n", - " split_strategy=\"temporal\",\n", - " days_to_resolution_range=(1, 60), # horizon within 2 months\n", - ")\n", - "\n", - "for name, ds in [(\"Train\", train_dataset), (\"Test\", test_dataset)]:\n", - " data = ds.flattened()\n", - " yes_count = sum(1 for s in data if s.get(\"label\") in (1, \"1\", 1.0))\n", - " print(f\"{name}: {len(data)} rows, {yes_count/len(data)*100:.1f}% yes\")\n", - " display(pd.DataFrame(data).head())" - ] - }, - { - "cell_type": "markdown", - "id": "0e799cfe", - "metadata": {}, - "source": [ - "## Uploading the dataset to HuggingFace\n", - "\n", - "Once we have a training-ready dataset, we can push it to Hugging Face for sharing or downstream use." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "91093200", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "id": "082d4f24", + "metadata": {}, + "source": [ + "## Build the pipeline\n", + "\n", + "Configure the pipeline with domain-specific instructions and examples for Trump-related forecasting." 
+ ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip available: \u001b[0m\u001b[31;49m22.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m26.0.1\u001b[0m\n", - "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", - "Note: you may need to restart the kernel to use updated packages.\n", - "Train: 58 rows, Test: 17 rows\n", - "Columns: ['question_text', 'date_close', 'event_date', 'resolution_criteria', 'prediction_date', 'label', 'answer_type', 'label_confidence'] ...\n" - ] + "cell_type": "code", + "execution_count": 3, + "id": "faccbe0a", + "metadata": {}, + "outputs": [], + "source": [ + "instructions = \"\"\"\n", + "Generate binary forecasting questions about Trump's actions, decisions, positions, and statements.\n", + "Questions should be diverse, related to the content, and should evenly cover the full range from very likely to very unlikely.\n", + "Horizon: outcomes should be known within 2 months of the question date, and may be known much sooner.\n", + "Criteria: binary outcome, exact dates, self-contained, verifiable via web search, newsworthy.\n", + "\"\"\"\n", + "\n", + "good_examples = [\n", + " \"Will Trump impose 25% tariffs on all goods from Canada by February 1, 2025?\",\n", + " \"Will Trump issue pardons to January 6 defendants within his first week in office?\",\n", + " \"Will Pete Hegseth be confirmed as Secretary of Defense by February 15, 2025?\",\n", + " \"Will Trump sign an executive order to keep TikTok operational in the US by January 31, 2025?\",\n", + " \"Will Kash Patel be confirmed as FBI Director by March 1, 2025?\",\n", + "]\n", + "\n", + "bad_examples = [\n", + " \"Will Trump do something controversial? (too vague)\",\n", + " \"Will Trump be in the news? 
(obvious)\",\n", + " \"Will tariffs be imposed? (needs specifics)\",\n", + "]" + ] }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "0b949581ce354458a957d2f6047eb184", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Uploading the dataset shards: 0%| | 0/1 [00:00, ? shards/s]" + "cell_type": "code", + "execution_count": 4, + "id": "4ce2d710", + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import (\n", + " BinaryAnswerType,\n", + " NewsSeedGenerator,\n", + " ForwardLookingQuestionGenerator,\n", + " NewsContextGenerator,\n", + " WebSearchLabeler,\n", + " QuestionPipeline,\n", + ")\n", + "\n", + "answer_type = BinaryAnswerType()\n", + "\n", + "pipeline = QuestionPipeline(\n", + " seed_generator=NewsSeedGenerator(\n", + " start_date=datetime(2025, 1, 1),\n", + " end_date=datetime(2026, 1, 1),\n", + " interval_duration_days=7,\n", + " search_query=[\n", + " \"Donald Trump domestic policy agenda\",\n", + " \"Donald Trump trade and tariff actions\",\n", + " \"Donald Trump foreign policy decisions\",\n", + " \"Donald Trump interviews and press appearances\",\n", + " \"Donald Trump lawsuits and court rulings\",\n", + " ],\n", + " articles_per_search=10,\n", + " ),\n", + " question_generator=ForwardLookingQuestionGenerator(\n", + " instructions=instructions,\n", + " examples=good_examples,\n", + " bad_examples=bad_examples,\n", + " answer_type=answer_type,\n", + " questions_per_seed=5,\n", + " ),\n", + " context_generators=[\n", + " NewsContextGenerator(\n", + " articles_per_query=3,\n", + " num_search_queries=1,\n", + " num_articles=5,\n", + " )\n", + " ],\n", + " labeler=WebSearchLabeler(answer_type=answer_type),\n", + ")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "77f47298cc57447fbfa3df175aa04bd4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Creating parquet from Arrow format: 0%| | 
0/1 [00:00, ?ba/s]" + "cell_type": "markdown", + "id": "1603b3de", + "metadata": {}, + "source": [ + "## Run the pipeline\n", + "\n", + "This will collect news articles, generate questions, and find answers. Use `max_questions` to limit the run for testing." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cab85d43c65442b5a21d56bd6bfb9895", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Processing Files (0 / 0): | | 0.00B / 0.00B " + "cell_type": "code", + "execution_count": 5, + "id": "4de4b87c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Pipeline Completed │\n", + "│ │\n", + "│ Total cost: $1.90 │\n", + "│ │\n", + "│ ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ │\n", + "│ ┃ Step ┃ Progress ┃ In ┃ Out ┃ Rejected ┃ Errors ┃ Rejection Reasons ┃ Duration ┃ │\n", + "│ ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ │\n", + "│ │ NewsSeedGenerator… │ Complete │ 10 │ 95 │ 0 │ 0 │ - │ 1s │ │\n", + "│ │ ForwardLookingQue… │ Complete │ 95 │ 442 │ 31 │ 0 │ date_close not │ 1s │ │\n", + "│ │ │ │ │ │ │ │ after event_date │ │ │\n", + "│ │ │ │ │ │ │ │ (31) │ │ │\n", + "│ │ WebSearchLabelerT… │ Complete │ 442 │ 380 │ 62 │ 0 │ Resolution date is │ 4s │ │\n", + "│ │ │ │ │ │ │ │ before seed │ │ │\n", + "│ │ │ │ │ │ │ │ creation date │ │ │\n", + "│ │ │ │ │ │ │ │ (36), Undetermined │ │ │\n", + "│ │ │ │ │ │ │ │ label (25), Low │ │ │\n", + "│ │ │ │ │ │ │ │ confidence: 0.80 < │ │ │\n", + "│ │ │ │ │ │ │ │ 0.9 (1) │ │ │\n", + "│ │ NewsContextGenera… │ Complete │ 380 │ 380 │ 0 │ 0 │ - │ 57s │ │\n", + "│ └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[92m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[1;92m>> Pipeline Completed\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[1mTotal cost:\u001b[0m \u001b[92m$1.90\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + 
"\u001b[92m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━┳━━━━━┳━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mStep \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mProgress \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m In\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mOut\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejected\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mErrors\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mRejection Reasons \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mDuration\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━╇━━━━━╇━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mNewsSeedGenerator…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 10 │ 95 │ \u001b[2m 0\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2m- \u001b[0m │ 1s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mForwardLookingQue…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 95 │ 442 │ \u001b[91m 31\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2mdate_close not \u001b[0m │ 1s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mafter event_date \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2m(31) \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mWebSearchLabelerT…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 442 │ 380 │ \u001b[91m 62\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2mResolution date is\u001b[0m │ 4s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ 
│ │ │ \u001b[2mbefore seed \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mcreation date \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2m(36), Undetermined\u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mlabel (25), Low \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2mconfidence: 0.80 <\u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m│ │ │ │ │ │ \u001b[2m0.9 (1) \u001b[0m │ │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m │\u001b[1m \u001b[0m\u001b[1mNewsContextGenera…\u001b[0m\u001b[1m \u001b[0m│ \u001b[1;92mComplete \u001b[0m │ 380 │ 380 │ \u001b[2m 0\u001b[0m │ \u001b[2m 0\u001b[0m │ \u001b[2m- \u001b[0m │ 57s │ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m └────────────────────┴──────────────────────┴─────┴─────┴──────────┴────────┴────────────────────┴──────────┘ \u001b[92m│\u001b[0m\n", + "\u001b[92m│\u001b[0m \u001b[92m│\u001b[0m\n", + "\u001b[92m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "473 samples (80.3% valid)\n" + ] + } + ], + "source": [ + "dataset = lr.transforms.run(pipeline, max_questions=500, name=\"WWTD-2025\")\n", + "\n", + "samples = dataset.download()\n", + "pct = (sum(1 for s in samples if s.is_valid is True) / len(samples) * 100) if samples else 0\n", + "print(f\"{len(samples)} samples ({pct:.1f}% valid)\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "54bf4eaee9e34b4db6e2b55054f4fc8f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "New Data Upload: | | 0.00B / 
0.00B " + "cell_type": "markdown", + "id": "91866bf7", + "metadata": {}, + "source": [ + "## Prepare the dataset\n", + "\n", + "Use SDK utils to filter valid samples, deduplicate, and split into train/test sets. We filter by `date_close <= today` to only include questions that have already resolved." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "6e52a8a016c54295a661a53facb6eb69", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Uploading the dataset shards: 0%| | 0/1 [00:00, ? shards/s]" + "cell_type": "code", + "execution_count": 6, + "id": "5e4d3f2a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> prepare_for_training │\n", + "│ │\n", + "│ Starting with 473 samples │\n", + "│ │\n", + "│ Filter: Dropped 93 invalid, 103 horizon → 277 remain │\n", + "│ Dedup: 277 remain (0 duplicates) │\n", + "│ Split: Splits: 167 train | 56 test (0 dropped, no prediction_date) │\n", + "│ 54 train samples removed for leakage │\n", + "│ │\n", + "│ ⚠ Unhealthy dataset │\n", + "│ │\n", + "│ Only 167 train samples remain after preparation. This is below the recommended minimum of 200 for effective │\n", + "│ training. │\n", + "│ │\n", + "│ Tips: │\n", + "│ • Increase max_questions in lr.transforms.run() to generate more samples. │\n", + "│ • Increase questions_per_seed in your question generator (ForwardLookingQuestionGenerator or │\n", + "│ QuestionGenerator) to produce more questions from each seed article.Add more search queries to your seed │\n", + "│ generator to diversify seed sources. │\n", + "│ • Widen the seed generator date range (start_date to end_date) to capture more events. 
│\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[33m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1;33m>> prepare_for_training\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[2mStarting with 473 samples\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1mFilter:\u001b[0m Dropped 93 invalid, 103 horizon → 277 remain \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1mDedup:\u001b[0m 277 remain (0 duplicates) \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1mSplit:\u001b[0m Splits: 167 train | 56 test (0 dropped, no prediction_date) \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m54 train samples removed for leakage\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1;33m⚠ Unhealthy dataset\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1mOnly 167 train samples remain after preparation. This is below the recommended minimum of 200 for effective \u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[1mtraining.\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[2mTips:\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m • Increase max_questions in lr.transforms.run() to generate more samples. 
\u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m • Increase questions_per_seed in your question generator (ForwardLookingQuestionGenerator or \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m QuestionGenerator) to produce more questions from each seed article.Add more search queries to your seed \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m generator to diversify seed sources. \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m • Widen the seed generator date range (start_date to end_date) to capture more events. \u001b[33m│\u001b[0m\n", + "\u001b[33m│\u001b[0m \u001b[33m│\u001b[0m\n", + "\u001b[33m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "167\n", + "Train: 167 rows, 21.0% yes\n" + ] + }, + { + "data": { + "text/html": [ + "
| \n", + " | sample_id | \n", + "is_valid | \n", + "question_text | \n", + "date_close | \n", + "event_date | \n", + "resolution_criteria | \n", + "prediction_date | \n", + "label | \n", + "answer_type | \n", + "label_confidence | \n", + "... | \n", + "reasoning | \n", + "answer_sources | \n", + "seed_text | \n", + "seed_url | \n", + "seed_creation_date | \n", + "seed_search_query | \n", + "context | \n", + "meta_sample_id | \n", + "meta_parent_sample_id | \n", + "meta_processing_time_ms | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "07511299-78d6-4020-8efe-d7b5b865f826 | \n", + "True | \n", + "Will the 11th Circuit Court of Appeals issue a... | \n", + "2025-02-15T00:00:00 | \n", + "2025-01-08T00:00:00 | \n", + "This question resolves to 'Yes' if the U.S. Co... | \n", + "2025-01-08T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "On January 9, 2025, the U.S. Court of Appeals ... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Situation: Ending the Trump Cases t... | \n", + "https://www.lawfaremedia.org/article/the-situa... | \n", + "2025-01-08T00:00:00 | \n", + "Donald Trump lawsuits and court rulings | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Judge ... | \n", + "59824a03-d075-4de4-bd89-17ec5f0651f1 | \n", + "da9d2197-889d-4fe4-ae4e-960ab3f9726f | \n", + "16820.019 | \n", + "
| 1 | \n", + "3618250c-9aa7-4e3e-b1f4-d59351e25415 | \n", + "True | \n", + "Will the criminal charges against Carlos De Ol... | \n", + "2025-03-05T00:00:00 | \n", + "2025-01-08T00:00:00 | \n", + "This question resolves to 'Yes' if a federal c... | \n", + "2025-01-08T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The criminal charges against Carlos De Oliveir... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Situation: Ending the Trump Cases t... | \n", + "https://www.lawfaremedia.org/article/the-situa... | \n", + "2025-01-08T00:00:00 | \n", + "Donald Trump lawsuits and court rulings | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Trump ... | \n", + "89a53581-5832-4797-9a76-4a85aefb8993 | \n", + "da9d2197-889d-4fe4-ae4e-960ab3f9726f | \n", + "170592.880 | \n", + "
| 2 | \n", + "6f5808f8-5fbe-40e3-a902-2959e0159960 | \n", + "True | \n", + "Will Justice Juan Merchan sentence Donald Trum... | \n", + "2025-03-01T00:00:00 | \n", + "2025-01-08T00:00:00 | \n", + "This question resolves to 'Yes' if Justice Jua... | \n", + "2025-01-08T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The close date for this question is 2025-03-01... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Situation: Ending the Trump Cases t... | \n", + "https://www.lawfaremedia.org/article/the-situa... | \n", + "2025-01-08T00:00:00 | \n", + "Donald Trump lawsuits and court rulings | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Judge ... | \n", + "da9d2197-889d-4fe4-ae4e-960ab3f9726f | \n", + "b988692d-28ac-4e59-932f-089b30c1fdff | \n", + "19099.822 | \n", + "
| 3 | \n", + "811942a3-0ce0-4a23-8ccf-cb9b6b038a78 | \n", + "True | \n", + "Will the full, unredacted Special Counsel repo... | \n", + "2025-03-01T00:00:00 | \n", + "2025-01-08T00:00:00 | \n", + "This question resolves to 'Yes' if the Departm... | \n", + "2025-01-08T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "Special Counsel Jack Smith submitted a two-vol... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Situation: Ending the Trump Cases t... | \n", + "https://www.lawfaremedia.org/article/the-situa... | \n", + "2025-01-08T00:00:00 | \n", + "Donald Trump lawsuits and court rulings | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Trump ... | \n", + "d53c18aa-e58f-47cc-ad8d-07284f3837b5 | \n", + "da9d2197-889d-4fe4-ae4e-960ab3f9726f | \n", + "210318.406 | \n", + "
| 4 | \n", + "abbac7a6-09fd-4d34-b588-0a3df04f2f37 | \n", + "True | \n", + "Will Donald Trump grant a formal presidential ... | \n", + "2025-02-28T00:00:00 | \n", + "2025-01-08T00:00:00 | \n", + "This question resolves to 'Yes' if the White H... | \n", + "2025-01-08T00:00:00 | \n", + "0 | \n", + "binary | \n", + "0.95 | \n", + "... | \n", + "The close date for this question is 2025-02-28... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Situation: Ending the Trump Cases t... | \n", + "https://www.lawfaremedia.org/article/the-situa... | \n", + "2025-01-08T00:00:00 | \n", + "Donald Trump lawsuits and court rulings | \n", + "[{'rendered_context': '', 'search_query': 'Tru... | \n", + "c1f9b0fa-f364-4cda-a9ca-a41f1dcdcf3c | \n", + "da9d2197-889d-4fe4-ae4e-960ab3f9726f | \n", + "16623.478 | \n", + "
5 rows × 21 columns
\n", + "| \n", + " | sample_id | \n", + "is_valid | \n", + "question_text | \n", + "date_close | \n", + "event_date | \n", + "resolution_criteria | \n", + "prediction_date | \n", + "label | \n", + "answer_type | \n", + "label_confidence | \n", + "... | \n", + "reasoning | \n", + "answer_sources | \n", + "seed_text | \n", + "seed_url | \n", + "seed_creation_date | \n", + "seed_search_query | \n", + "context | \n", + "meta_sample_id | \n", + "meta_parent_sample_id | \n", + "meta_processing_time_ms | \n", + "
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", + "7960bd90-44c8-4cd0-bac0-3a17265101e5 | \n", + "True | \n", + "Will Donald Trump announce a complete exemptio... | \n", + "2025-12-01T00:00:00 | \n", + "2025-10-08T00:00:00 | \n", + "The question is answered 'Yes' if the US Presi... | \n", + "2025-10-08T00:00:00 | \n", + "0 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The close date for this question is 2025-12-01... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: 'We will get an even better deal,' Carn... | \n", + "https://www.cbc.ca/news/politics/carney-even-b... | \n", + "2025-10-08T00:00:00 | \n", + "Donald Trump trade and tariff actions | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] U.S.-C... | \n", + "ecd757d7-fa3c-47ec-bdc6-81cdf88c1b65 | \n", + "a3db9248-0a4f-4d51-81d1-7c21057856cf | \n", + "12056.629 | \n", + "
| 1 | \n", + "b562085f-b58d-4284-ae0b-fcfdeb90f90c | \n", + "True | \n", + "Will the United States and Canada sign a forma... | \n", + "2025-12-01T00:00:00 | \n", + "2025-10-08T00:00:00 | \n", + "A formal bilateral agreement or signed Memoran... | \n", + "2025-10-08T00:00:00 | \n", + "0 | \n", + "binary | \n", + "0.95 | \n", + "... | \n", + "Based on the provided reports regarding the 20... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: 'We will get an even better deal,' Carn... | \n", + "https://www.cbc.ca/news/politics/carney-even-b... | \n", + "2025-10-08T00:00:00 | \n", + "Donald Trump trade and tariff actions | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Carney... | \n", + "a3db9248-0a4f-4d51-81d1-7c21057856cf | \n", + "993816e2-887d-4db0-99c7-f45e93776a8a | \n", + "255122.468 | \n", + "
| 2 | \n", + "63dc110c-6073-4f98-abcb-57b8933b8fbf | \n", + "True | \n", + "Will the Trump administration hold a new offsh... | \n", + "2026-04-01T00:00:00 | \n", + "2025-11-26T00:00:00 | \n", + "The question resolves to 'Yes' if the Departme... | \n", + "2025-11-26T00:00:00 | \n", + "1 | \n", + "binary | \n", + "1.00 | \n", + "... | \n", + "The Trump administration held the first new of... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: nytimes.com\\n\\nURL Source: https://www.... | \n", + "https://www.nytimes.com/2025/11/26/climate/tru... | \n", + "2025-11-26T00:00:00 | \n", + "Donald Trump domestic policy agenda | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Interi... | \n", + "87e57caa-9696-4eb8-a7ed-cd4060b0649e | \n", + "b4bfbb84-63db-48d6-b63d-c809b20ed5f0 | \n", + "6996.226 | \n", + "
| 3 | \n", + "2581c035-5519-434f-aed1-c51326b11e73 | \n", + "True | \n", + "Will Donald Trump and Javier Milei hold a join... | \n", + "2026-01-01T00:00:00 | \n", + "2025-11-27T00:00:00 | \n", + "The question resolves to 'Yes' if Donald Trump... | \n", + "2025-11-27T00:00:00 | \n", + "0 | \n", + "binary | \n", + "0.95 | \n", + "... | \n", + "The close date for this question is 2026-01-01... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Paradox of Europe's Trumpian Right:... | \n", + "https://www.foreignaffairs.com/europe/paradox-... | \n", + "2025-11-27T00:00:00 | \n", + "Donald Trump domestic policy agenda | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Milei ... | \n", + "3a88f43d-6a2c-422f-b7dd-5b6039a29375 | \n", + "7977998c-8ccb-444f-a382-1fcd8e322c23 | \n", + "11967.560 | \n", + "
| 4 | \n", + "3bd13f7e-25a8-40c0-8b2a-d9f77ef9f827 | \n", + "True | \n", + "Will the United States official executive bran... | \n", + "2026-01-20T00:00:00 | \n", + "2025-11-27T00:00:00 | \n", + "The question resolves to 'Yes' if the U.S. gov... | \n", + "2025-11-27T00:00:00 | \n", + "1 | \n", + "binary | \n", + "0.95 | \n", + "... | \n", + "Between the question date (2025-11-27) and the... | \n", + "https://vertexaisearch.cloud.google.com/ground... | \n", + "Title: The Paradox of Europe's Trumpian Right:... | \n", + "https://www.foreignaffairs.com/europe/paradox-... | \n", + "2025-11-27T00:00:00 | \n", + "Donald Trump domestic policy agenda | \n", + "[{'rendered_context': '---\n", + "ARTICLES\n", + "[1] Presid... | \n", + "d15a72f1-d8ac-4fc9-9bde-732514e9008c | \n", + "7977998c-8ccb-444f-a382-1fcd8e322c23 | \n", + "66899.224 | \n", + "
5 rows × 21 columns
\n", + "╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Training COMPLETED │\n", + "│ │\n", + "│ Job: WWTD-2025 │\n", + "│ │\n", + "│ Reward: latest -1.3368 avg -0.8626 (6 steps) (higher is better) │\n", + "│ │\n", + "│ Cost: $0.11 │\n", + "│ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m WWTD-2025 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -1.3368 avg -0.8626 (6 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.11 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job 371e5ff4-ebf5-43af-8809-d26c14edb8e2 completed with status: COMPLETED\n", + "Trained model ID: checkpoint:371e5ff4-ebf5-43af-8809-d26c14edb8e2\n" + ] + } + ], + "source": [ + "job = lr.training.run(config, dataset=train_dataset, name=\"WWTD-2025\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Trained model ID: {job.model_id}\")" + ] + }, + { 
+ "cell_type": "markdown", + "id": "6487c1d9", + "metadata": {}, + "source": [ + "## Inference with your trained model\n", + "\n", + "Use `lr.predict()` to run inference with your trained model." ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "data": { - "text/plain": [ - "CommitInfo(commit_url='https://huggingface.co/datasets/bart/wwtd-forecasting-demo/commit/cd7bfd6d7addc58cf2c3ac8f6677219a1ce91a91', commit_message='Upload dataset', commit_description='', oid='cd7bfd6d7addc58cf2c3ac8f6677219a1ce91a91', pr_url=None, repo_url=RepoUrl('https://huggingface.co/datasets/bart/wwtd-forecasting-demo', endpoint='https://huggingface.co', repo_type='dataset', repo_id='bart/wwtd-forecasting-demo'), pr_revision=None, pr_num=None)" + "cell_type": "code", + "execution_count": null, + "id": "f013b514", + "metadata": {}, + "outputs": [], + "source": [ + "print(lr.predict(job.model_id, \"Will Trump impose 25% tariffs on all goods from Canada by February 1, 2027?\"))" + ] + }, + { + "cell_type": "markdown", + "id": "09e1c0d7", + "metadata": {}, + "source": [ + "## Run evals on trained model\n", + "\n", + "Run test evals on your trained model against the test dataset. The eval job runs the model on the dataset and reports metrics." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "853e7904", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Eval COMPLETED │\n", + "│ │\n", + "│ ID: 00602447-5872-4732-93f9-b0d99459da1a │\n", + "│ Model: checkpoint:13fa02ec-27f4-47a9-84c9-762d91a1904a │\n", + "│ Dataset: 82186c26-a309-43a6-9543-37bdda38d41d │\n", + "│ │\n", + "│ ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓ │\n", + "│ ┃ Metric ┃ base ┃ trained ┃ benchmark ┃ │\n", + "│ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩ │\n", + "│ │ brier_score │ 0.2333 │ 0.1877 │ 0.1555 │ │\n", + "│ │ ece │ 0.1451 │ 0.0963 │ 0.0590 │ │\n", + "│ │ mean_reward │ -0.7840 │ -0.6048 │ -0.4966 │ │\n", + "│ │ mean_valid_reward │ -0.7840 │ -0.6048 │ -0.4966 │ │\n", + "│ │ n_samples │ 113 │ 113 │ 113 │ │\n", + "│ │ n_valid │ 113 │ 113 │ 113 │ │\n", + "│ │ parse_rate │ 1.0000 │ 1.0000 │ 1.0000 │ │\n", + "│ │ total_cost │ 0.0068 │ 0.0068 │ — │ │\n", + "│ │ total_input_tokens │ 93344 │ 93344 │ 88060 │ │\n", + "│ │ total_output_tokens │ 1111 │ 1101 │ 28947 │ │\n", + "│ └─────────────────────┴─────────┴─────────┴───────────┘ │\n", + "│ │\n", + "│ Cost: $0.01 │\n", + "│ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Eval COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mID:\u001b[0m 00602447-5872-4732-93f9-b0d99459da1a \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mModel:\u001b[0m checkpoint:13fa02ec-27f4-47a9-84c9-762d91a1904a \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mDataset:\u001b[0m 82186c26-a309-43a6-9543-37bdda38d41d 
\u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━┓ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mMetric \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m base\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mtrained\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mbenchmark\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━┩ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mbrier_score \u001b[0m\u001b[2m \u001b[0m│ 0.2333 │ 0.1877 │ 0.1555 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mece \u001b[0m\u001b[2m \u001b[0m│ 0.1451 │ 0.0963 │ 0.0590 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_reward \u001b[0m\u001b[2m \u001b[0m│ -0.7840 │ -0.6048 │ -0.4966 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_valid_reward \u001b[0m\u001b[2m \u001b[0m│ -0.7840 │ -0.6048 │ -0.4966 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_samples \u001b[0m\u001b[2m \u001b[0m│ 113 │ 113 │ 113 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_valid \u001b[0m\u001b[2m \u001b[0m│ 113 │ 113 │ 113 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mparse_rate \u001b[0m\u001b[2m \u001b[0m│ 1.0000 │ 1.0000 │ 1.0000 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_cost \u001b[0m\u001b[2m \u001b[0m│ 0.0068 │ 0.0068 │ — │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_input_tokens \u001b[0m\u001b[2m \u001b[0m│ 93344 │ 93344 │ 88060 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m 
\u001b[0m\u001b[2mtotal_output_tokens\u001b[0m\u001b[2m \u001b[0m│ 1111 │ 1101 │ 28947 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m └─────────────────────┴─────────┴─────────┴───────────┘ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.01 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset, benchmark_model_id=\"openai/gpt-5.2\")" + ] + }, + { + "cell_type": "markdown", + "id": "96d20f89", + "metadata": {}, + "source": [ + "> Note: the trained model checkpoint will only be available for 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." 
] - }, - "execution_count": null, - "metadata": {}, - "output_type": "execute_result" } - ], - "source": [ - "%pip install datasets -q\n", - "\n", - "from datasets import Dataset, DatasetDict\n", - "from lightningrod.utils import config\n", - "\n", - "dataset = DatasetDict({\n", - " \"train\": Dataset.from_list(train_dataset.flattened()),\n", - " \"test\": Dataset.from_list(test_dataset.flattened()),\n", - "})\n", - "print(f\"Train: {len(dataset['train'])} rows, Test: {len(dataset['test'])} rows\")\n", - "print(\"Columns:\", dataset[\"train\"].column_names[:8], \"...\")\n", - "\n", - "DATASET_PATH = f\"{config.get_config_value('HF_USERNAME')}/wwtd-forecasting-demo\"\n", - "dataset.push_to_hub(DATASET_PATH, token=config.get_config_value(\"HF_ACCESS_TOKEN\"))" - ] - }, - { - "cell_type": "markdown", - "id": "49a3c7f8", - "metadata": {}, - "source": [ - "## Model Training\n", - "\n", - "We used the generated dataset above to fine-tune a forecasting model via RL on 2,790 questions, surpassing GPT-5 performance.\n", - "\n", - "**For more details on methods, results, and data:**\n", - "- **[Trump-Forecaster Model](https://huggingface.co/LightningRodLabs/Trump-Forecaster)**\n", - "- **[Trump-Forecaster Dataset](https://huggingface.co/datasets/LightningRodLabs/WWTD-2025)**\n", - "\n", - "\n", - "\n", - "**Coming Soon:** Seamlessly generate datasets, fine-tune, and evaluate your own forecasting models end-to-end on the Lightningrod platform.\n", - " \n", - "\ud83d\udc49 [Sign up to get early access and updates.](https://lightningrod.ai/)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python (lightningrod-sdk)", - "language": "python", - "name": "lightningrod-sdk" + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": 
"python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" + } }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/getting_started/05_fine_tuning.ipynb b/notebooks/getting_started/05_fine_tuning.ipynb index f2b1e86..21edf1c 100644 --- a/notebooks/getting_started/05_fine_tuning.ipynb +++ b/notebooks/getting_started/05_fine_tuning.ipynb @@ -1,290 +1,356 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "4dde071c", - "metadata": {}, - "source": [ - "# Training API\n", - "\n", - "Fine-tune forecasting models on your Lightning Rod datasets. This notebook walks through the full training workflow: generating a dataset, estimating cost, creating a training job, and monitoring progress.\n", - "\n", - "The training API supports LoRA fine-tuning with configurable base models, training steps, batch size, and rank." 
- ] - }, - { - "cell_type": "markdown", - "id": "fae3f735", - "metadata": {}, - "source": [ - "## Install the SDK" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c1a6a1e2", - "metadata": {}, - "outputs": [], - "source": [ - "%pip install lightningrod-ai python-dotenv openai\n", - "\n", - "from IPython.display import clear_output\n", - "clear_output()" - ] - }, - { - "cell_type": "markdown", - "id": "7490b222", - "metadata": {}, - "source": [ - "## Set up the client\n", - "\n", - "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/sign-up?redirect=/api) to get your API key and **$50 of free credits**.\n", - "\n", - "- **Google Colab**: Go to the Secrets section (key icon in left sidebar) and add a secret named `LIGHTNINGROD_API_KEY`\n", - "- **Local Jupyter**: Set the `LIGHTNINGROD_API_KEY` environment variable, or you'll be prompted to enter it" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "7a023c6c", - "metadata": {}, - "outputs": [], - "source": [ - "from dotenv import load_dotenv\n", - "from lightningrod import LightningRod\n", - "from lightningrod.utils import config\n", - "\n", - "load_dotenv()\n", - "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", - "\n", - "lr = LightningRod(api_key=api_key)" - ] - }, - { - "cell_type": "markdown", - "id": "9c49c320", - "metadata": {}, - "source": [ - "## Prepare the dataset\n", - "\n", - "Training requires a dataset ID from a pipeline run. Run one of the other notebooks first to generate a dataset - each one prints the **Dataset ID** after `transforms.run()` — copy it into the cell below." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "da381a43", - "metadata": {}, - "outputs": [], - "source": [ - "dataset_id = config.get_config_value(\"LIGHTNINGROD_DATASET_ID\")\n", - "\n", - "dataset = lr.datasets.get(dataset_id)\n", - "_ = dataset.download()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c9418ddf", - "metadata": {}, - "outputs": [], - "source": [ - "from lightningrod import filter_and_split\n", - "\n", - "train_dataset, test_dataset = filter_and_split(\n", - " dataset,\n", - " test_size=0.2,\n", - " days_to_resolution_range=(90, None),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "01e98826", - "metadata": {}, - "source": [ - "## Estimate training cost\n", - "\n", - "Before starting a job, use `estimate_cost` to see the expected cost and token usage." - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "4283478f", - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Estimated cost: $0.02\n", - "Effective steps: 3\n", - "Train tokens: 63,349\n", - "Notes: Estimate uses per-answer-type output token estimates; actual may vary\n" - ] - } - ], - "source": [ - "from lightningrod import TrainingConfig\n", - "\n", - "config = TrainingConfig(\n", - " base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n", - " training_steps=50,\n", - ")\n", - "cost_estimate = lr.training.estimate_cost(config, dataset=train_dataset)\n", - "print(f\"Estimated cost: ${cost_estimate.total_cost_dollars:.2f}\")\n", - "print(f\"Effective steps: {cost_estimate.effective_steps}\")\n", - "print(f\"Train tokens: {cost_estimate.train_tokens:,}\")\n", - "print(f\"Notes: {cost_estimate.notes}\")" - ] - }, - { - "cell_type": "markdown", - "id": "ef160c9c", - "metadata": {}, - "source": [ - "## Start training\n", - "\n", - "`run` creates a job and polls until completion with a live progress display.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, 
- "id": "b7800672", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "id": "4dde071c", + "metadata": {}, + "source": [ + "# Training API\n", + "\n", + "Fine-tune forecasting models on your Lightning Rod datasets. This notebook walks through the full training workflow: generating a dataset, estimating cost, creating a training job, and monitoring progress.\n", + "\n", + "The training API supports LoRA fine-tuning with configurable base models, training steps, batch size, and rank." + ] + }, + { + "cell_type": "markdown", + "id": "fae3f735", + "metadata": {}, + "source": [ + "## Install the SDK" + ] + }, { - "data": { - "text/html": [ - "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", - "│ │\n", - "│ >> Training COMPLETED │\n", - "│ │\n", - "│ Job: Forecasting fine-tune │\n", - "│ │\n", - "│ Reward: latest -0.4030 avg -0.8912 (3 steps) (higher is better) │\n", - "│ │\n", - "│ Cost: $0.01 │\n", - "│ │\n", - "│ │\n", - "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", - "\n" + "cell_type": "code", + "execution_count": 1, + "id": "c1a6a1e2", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install lightningrod-ai python-dotenv openai\n", + "\n", + "from IPython.display import clear_output\n", + "clear_output()" + ] + }, + { + "cell_type": "markdown", + "id": "7490b222", + "metadata": {}, + "source": [ + "## Set up the client\n", + "\n", + "Sign up at [dashboard.lightningrod.ai](https://dashboard.lightningrod.ai/?redirect=/api) to get your API key and **$50 of free credits**.\n", + "\n", + "- **Google Colab**: Go to the Secrets section (key icon in left sidebar) and add a secret named `LIGHTNINGROD_API_KEY`\n", + "- **Local Jupyter**: Set the `LIGHTNINGROD_API_KEY` environment variable, or you'll be prompted to enter it" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7a023c6c", + "metadata": {}, + "outputs": [], + "source": [ + "from dotenv import load_dotenv\n", + "from lightningrod import LightningRod\n", + "from lightningrod.utils import config\n", + "\n", + "load_dotenv()\n", + "api_key = config.get_config_value(\"LIGHTNINGROD_API_KEY\")\n", + "\n", + "lr = LightningRod(api_key=api_key)" + ] + }, + { + "cell_type": "markdown", + "id": "9c49c320", + "metadata": {}, + "source": [ + "## Prepare the dataset\n", + "\n", + "Training requires a dataset ID from a pipeline run. Run one of the other notebooks first to generate a dataset - each one prints the **Dataset ID** after `transforms.run()` — copy it into the cell below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "da381a43", + "metadata": {}, + "outputs": [], + "source": [ + "dataset_id = config.get_config_value(\"LIGHTNINGROD_DATASET_ID\")\n", + "\n", + "dataset = lr.datasets.get(dataset_id)\n", + "_ = dataset.download()\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9418ddf", + "metadata": {}, + "outputs": [], + "source": [ + "from lightningrod import prepare_for_training, FilterParams, SplitParams\n", + "\n", + "train_dataset, test_dataset = prepare_for_training(\n", + " dataset,\n", + " filter=FilterParams(days_to_resolution_range=(90, None)),\n", + " split=SplitParams(test_size=0.2),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "01e98826", + "metadata": {}, + "source": [ + "## Estimate training cost\n", + "\n", + "Before starting a job, use `estimate_cost` to see the expected cost and token usage." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4283478f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Estimated cost: $0.02\n", + "Effective steps: 3\n", + "Train tokens: 63,349\n", + "Notes: Estimate uses per-answer-type output token estimates; actual may vary\n" + ] + } ], - "text/plain": [ - "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m Forecasting fine-tune \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -0.4030 avg -0.8912 (3 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m 
\u001b[1mCost:\u001b[0m $0.01 \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", - "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + "source": [ + "from lightningrod import TrainingConfig\n", + "\n", + "config = TrainingConfig(\n", + " base_model=\"Qwen/Qwen3-4B-Instruct-2507\",\n", + " training_steps=50,\n", + ")\n", + "cost_estimate = lr.training.estimate_cost(config, dataset=train_dataset)\n", + "print(f\"Estimated cost: ${cost_estimate.total_cost_dollars:.2f}\")\n", + "print(f\"Effective steps: {cost_estimate.effective_steps}\")\n", + "print(f\"Train tokens: {cost_estimate.train_tokens:,}\")\n", + "print(f\"Notes: {cost_estimate.notes}\")" ] - }, - "metadata": {}, - "output_type": "display_data" }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Job bd114679-610a-4334-8802-13d047a7bc30 completed with status: COMPLETED\n", - "Trained model ID: checkpoint:bd114679-610a-4334-8802-13d047a7bc30\n" - ] - } - ], - "source": [ - "job = lr.training.run(config, dataset=train_dataset, name=\"Forecasting fine-tune\")\n", - "print(f\"Job {job.id} completed with status: {job.status}\")\n", - "print(f\"Trained model ID: {job.model_id}\")" - ] - }, - { - "cell_type": "markdown", - "id": "4fe2792c", - "metadata": {}, - "source": [ - "## Inference with your trained model\n", - "\n", - "Use `lr.predict()` to run inference with your trained model. 
You can also use the OpenAI-compatible API directly — see [08_foresight_model.ipynb](08_foresight_model.ipynb) for the pre-trained foresight model.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "50744b98", - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "id": "ef160c9c", + "metadata": {}, + "source": [ + "## Start training\n", + "\n", + "`run` creates a job and polls until completion with a live progress display.\n" + ] + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Training COMPLETED │\n", + "│ │\n", + "│ Job: Forecasting fine-tune │\n", + "│ │\n", + "│ Reward: latest -0.3752 avg -0.8808 (3 steps) (higher is better) │\n", + "│ │\n", + "│ Cost: $0.01 │\n", + "│ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Training COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mJob:\u001b[0m Forecasting fine-tune \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mReward:\u001b[0m latest -0.3752 avg -0.8808 (3 steps) \u001b[2m(higher is better)\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.01 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Job 73184285-cd39-4afe-af8e-ceab345d80dc completed with status: COMPLETED\n", + "Trained model ID: checkpoint:73184285-cd39-4afe-af8e-ceab345d80dc\n" + ] + } + ], + "source": [ + "job = lr.training.run(config, dataset=train_dataset, name=\"Forecasting fine-tune\")\n", + "print(f\"Job {job.id} completed with status: {job.status}\")\n", + "print(f\"Trained model ID: 
{job.model_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4fe2792c", + "metadata": {}, + "source": [ + "## Inference with your trained model\n", + "\n", + "Use `lr.predict()` to run inference with your trained model. You can also use the OpenAI-compatible API directly — see [08_foresight_model.ipynb](08_foresight_model.ipynb) for the pre-trained foresight model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "50744b98", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\n", + "│ │\n", + "│ >> Eval COMPLETED │\n", + "│ │\n", + "│ ID: 4a78c3a1-a1a6-4288-9cd1-974066ddcc66 │\n", + "│ Model: checkpoint:73184285-cd39-4afe-af8e-ceab345d80dc │\n", + "│ Dataset: e87e04c3-4c0d-49ab-97bf-b30d724395d3 │\n", + "│ │\n", + "│ ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ │\n", + "│ ┃ Metric ┃ base ┃ trained ┃ │\n", + "│ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ │\n", + "│ │ brier_score │ 0.1925 │ 0.1961 │ │\n", + "│ │ ece │ 0.0671 │ 0.0510 │ │\n", + "│ │ mean_reward │ -0.5715 │ -0.5797 │ │\n", + "│ │ mean_valid_reward │ -0.5715 │ -0.5797 │ │\n", + "│ │ n_samples │ 81 │ 81 │ │\n", + "│ │ n_valid │ 81 │ 81 │ │\n", + "│ │ parse_rate │ 1.0000 │ 1.0000 │ │\n", + "│ │ total_cost │ 0.0016 │ 0.0016 │ │\n", + "│ │ total_input_tokens │ 20154 │ 20154 │ │\n", + "│ │ total_output_tokens │ 787 │ 787 │ │\n", + "│ └─────────────────────┴─────────┴─────────┘ │\n", + "│ │\n", + "│ Cost: $0.00 │\n", + "│ │\n", + "│ │\n", + "╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\n", + "\n" + ], + "text/plain": [ + "\u001b[94m╭─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1;92m>> Eval COMPLETED\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mID:\u001b[0m 4a78c3a1-a1a6-4288-9cd1-974066ddcc66 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mModel:\u001b[0m checkpoint:73184285-cd39-4afe-af8e-ceab345d80dc \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mDataset:\u001b[0m e87e04c3-4c0d-49ab-97bf-b30d724395d3 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ 
\u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┃\u001b[1;36m \u001b[0m\u001b[1;36mMetric \u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36m base\u001b[0m\u001b[1;36m \u001b[0m┃\u001b[1;36m \u001b[0m\u001b[1;36mtrained\u001b[0m\u001b[1;36m \u001b[0m┃ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mbrier_score \u001b[0m\u001b[2m \u001b[0m│ 0.1925 │ 0.1961 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mece \u001b[0m\u001b[2m \u001b[0m│ 0.0671 │ 0.0510 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_reward \u001b[0m\u001b[2m \u001b[0m│ -0.5715 │ -0.5797 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mmean_valid_reward \u001b[0m\u001b[2m \u001b[0m│ -0.5715 │ -0.5797 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_samples \u001b[0m\u001b[2m \u001b[0m│ 81 │ 81 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mn_valid \u001b[0m\u001b[2m \u001b[0m│ 81 │ 81 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mparse_rate \u001b[0m\u001b[2m \u001b[0m│ 1.0000 │ 1.0000 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_cost \u001b[0m\u001b[2m \u001b[0m│ 0.0016 │ 0.0016 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_input_tokens \u001b[0m\u001b[2m \u001b[0m│ 20154 │ 20154 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m │\u001b[2m \u001b[0m\u001b[2mtotal_output_tokens\u001b[0m\u001b[2m \u001b[0m│ 787 │ 787 │ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m └─────────────────────┴─────────┴─────────┘ \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[1mCost:\u001b[0m $0.00 \u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m 
\u001b[94m│\u001b[0m\n", + "\u001b[94m│\u001b[0m \u001b[94m│\u001b[0m\n", + "\u001b[94m╰─────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "72c2631e", + "metadata": {}, + "source": [ + "> Note: the trained model checkpoint will only be available for the period of 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python (lightningrod-sdk)", + "language": "python", + "name": "lightningrod-sdk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.2" } - ], - "source": [ - "print(lr.predict(job.model_id, \"Will the Fed cut rates by 25bp in March 2026?\"))\n" - ] - }, - { - "cell_type": "markdown", - "id": "c8d360d3", - "metadata": {}, - "source": [ - "## Run evals on trained model\n", - "\n", - "Run test evals on your trained model against a test dataset. The eval job runs the model on the dataset and reports metrics. Use the same dataset for a quick check, or a separate test split for production." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9dd52fd4", - "metadata": {}, - "outputs": [], - "source": [ - "eval_job = lr.evals.run(model_id=job.model_id, dataset=test_dataset)" - ] - }, - { - "cell_type": "markdown", - "id": "72c2631e", - "metadata": {}, - "source": [ - "> Note: the trained model checkpoint will only be available for the period of 7 days. If you wish to host this model long-term, reach out to us at support@lightningrod.ai." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python (lightningrod-sdk)", - "language": "python", - "name": "lightningrod-sdk" }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.2" - } - }, - "nbformat": 4, - "nbformat_minor": 5 + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/src/lightningrod/__init__.py b/src/lightningrod/__init__.py index 0804f0e..9f1c75a 100644 --- a/src/lightningrod/__init__.py +++ b/src/lightningrod/__init__.py @@ -9,7 +9,7 @@ from lightningrod import preprocessing, training, utils from lightningrod.utils.sample import create_sample from lightningrod.utils.models import open_router_model -from lightningrod.training import filter_and_split +from lightningrod.training import prepare_for_training, FilterParams, DedupParams, SplitParams from lightningrod.training.client import TrainingConfig from lightningrod._generated.models import ( TransformJob, @@ -97,7 +97,10 @@ "Rollout", "RolloutScorer", "RolloutGenerator", - "filter_and_split", + "prepare_for_training", + "FilterParams", + "DedupParams", + "SplitParams", "TrainingConfig", "Sample", "SampleMeta", diff --git a/src/lightningrod/_display.py b/src/lightningrod/_display.py index ccd418c..790bc4e 100644 --- a/src/lightningrod/_display.py +++ b/src/lightningrod/_display.py @@ -466,13 +466,13 @@ def _build_invalid_samples_error_message(original_message: str) -> Group: renderables.append(_safe_markup("[bold]Next steps:[/bold]")) renderables.append(_safe_markup(" • Check the dataset samples to see specific failure reasons in the 'meta.filter_reason' field")) - renderables.append(_safe_markup(" • Adjust and retry the transform pipeline (e.g., lower confidence thresholds, relax filter criteria)")) + renderables.append(_safe_markup(" • Adjust and retry the transform pipeline 
(e.g., try a wider date range)")) renderables.append(_safe_markup(" • If the problem persists, contact support or open a GitHub issue: [link=https://github.com/lightning-rod-labs/lightningrod-python-sdk/issues]https://github.com/lightning-rod-labs/lightningrod-python-sdk/issues[/link]")) return Group(*renderables) -def display_error(message: str, title: str = "Error", job: Any = None) -> None: +def display_error(message: str, title: str = "Error", job: Any = None, response_body: str | None = None) -> None: console = Console() renderables: list[RenderableType] = [] @@ -484,6 +484,11 @@ def display_error(message: str, title: str = "Error", job: Any = None) -> None: else: renderables.append(_safe_markup(f"[bold]{message}[/bold]")) + if response_body is not None and response_body.strip(): + renderables.append(Text("")) + renderables.append(_safe_markup("[bold]Response body:[/bold]")) + renderables.append(Text(response_body.strip()[:2000], style="dim")) + if job is not None: cost_lines = _build_transform_cost_lines(job) if isinstance(job, TransformJob) else _build_cost_lines(job) if cost_lines: @@ -493,6 +498,79 @@ def display_error(message: str, title: str = "Error", job: Any = None) -> None: console.print(Panel(Group(*renderables), border_style="bright_red", padding=(1, 2))) +def display_prepare_report(report: Any, verbose: bool = True) -> None: + """Render a PrepareReport as a Rich panel. 
Used inside Jupyter notebooks.""" + from lightningrod.training.samples import PrepareReport + assert isinstance(report, PrepareReport) + stats = report.stats + console = Console() + renderables: list[RenderableType] = [] + + border = "bright_green" if report.is_healthy else "yellow" + header_style = "bold bright_green" if report.is_healthy else "bold yellow" + renderables.append(_safe_markup(f"[{header_style}]>> prepare_for_training[/{header_style}]")) + renderables.append(Text("")) + + if verbose or not report.is_healthy: + renderables.append(_safe_markup(f" [dim]Starting with {stats.total} samples[/dim]")) + renderables.append(Text("")) + + parts = [] + if stats.filter_invalid: + parts.append(f"{stats.filter_invalid} invalid") + if stats.filter_horizon: + part = f"{stats.filter_horizon} horizon" + if stats.filter_missing_resolution_date or stats.filter_missing_prediction_date: + sub = [] + if stats.filter_missing_resolution_date: + sub.append(f"{stats.filter_missing_resolution_date} missing resolution date") + if stats.filter_missing_prediction_date: + sub.append(f"{stats.filter_missing_prediction_date} missing prediction date") + part += f" ({', '.join(sub)})" + parts.append(part) + if stats.filter_context: + parts.append(f"{stats.filter_context} missing context") + filter_line = ( + f" [bold]Filter:[/bold] Dropped {', '.join(parts)} → {stats.filter_kept} remain" + if parts else + f" [bold]Filter:[/bold] {stats.filter_kept} remain (0 dropped)" + ) + renderables.append(_safe_markup(filter_line)) + + if stats.dedup_removed > 0: + renderables.append(_safe_markup( + f" [bold]Dedup:[/bold] Removed {stats.dedup_removed} duplicates " + f"({stats.dedup_kept + stats.dedup_removed} → {stats.dedup_kept})" + )) + for k, c in stats.dedup_top_collisions: + q = repr(k[0])[:60] + ("..." 
if len(repr(k[0])) > 60 else "") + renderables.append(Text(f" ({q}, {k[1]}): {c} samples → 1", style="dim")) + else: + renderables.append(_safe_markup(f" [bold]Dedup:[/bold] {stats.dedup_kept} remain (0 duplicates)")) + + split_detail = f"Splits: {stats.split_train_after} train | {stats.split_test_after} test ({stats.split_no_sort_key} dropped, no prediction_date)" + renderables.append(_safe_markup(f" [bold]Split:[/bold] {split_detail}")) + n_leaked = stats.split_train_before - stats.split_train_after + if n_leaked: + renderables.append(_safe_markup( + f" [yellow]{n_leaked} train samples removed for leakage[/yellow]" + )) + + if not report.is_healthy: + renderables.append(Text("")) + renderables.append(_safe_markup("[bold yellow]⚠ Unhealthy dataset[/bold yellow]")) + for issue in report.issues: + renderables.append(Text("")) + renderables.append(Text(issue.message, style="bold")) + if issue.tips: + renderables.append(Text("")) + renderables.append(_safe_markup(" [dim]Tips:[/dim]")) + for tip in issue.tips: + renderables.append(Text(f" • {tip}")) + + console.print(Panel(Group(*renderables), border_style=border, padding=(1, 2))) + + def display_warning(message: str, title: str = "Warning", job: Any = None) -> None: console = Console() renderables: list[RenderableType] = [] diff --git a/src/lightningrod/training/__init__.py b/src/lightningrod/training/__init__.py index 2b7d7e4..40a7073 100644 --- a/src/lightningrod/training/__init__.py +++ b/src/lightningrod/training/__init__.py @@ -4,10 +4,13 @@ from lightningrod.training.samples import ( deduplicate_samples, filter_samples, - filter_and_split, + prepare_for_training, train_test_split, to_messages, to_record, + FilterParams, + DedupParams, + SplitParams, ) __all__ = [ @@ -15,10 +18,13 @@ "print_eval", "TrainingClient", "TrainingConfig", - "filter_and_split", + "prepare_for_training", "train_test_split", "deduplicate_samples", "filter_samples", "to_record", "to_messages", + "FilterParams", + "DedupParams", + 
"SplitParams", ] diff --git a/src/lightningrod/training/samples.py b/src/lightningrod/training/samples.py index 4977dc2..1d90afd 100644 --- a/src/lightningrod/training/samples.py +++ b/src/lightningrod/training/samples.py @@ -32,9 +32,19 @@ @dataclass class PrepareStats: - """Tracks metrics collected during prepare_for_training.""" + """Tracks metrics collected during prepare_for_training. + + Fields are grouped by pipeline stage. Invariants: + filter_kept = total - filter_invalid - filter_horizon - filter_context + dedup_kept = filter_kept - dedup_removed + split_train_before + split_test_after + split_no_sort_key = dedup_kept (temporal) + split_train_excluded = split_train_before - split_train_after (inferred; train leakage only) + """ + + # ── Input ──────────────────────────────────────────────────────────────── total: int = 0 + # ── Stage 1: filter_samples ─────────────────────────────────────────────── filter_invalid: int = 0 filter_horizon: int = 0 filter_context: int = 0 @@ -42,16 +52,61 @@ class PrepareStats: filter_missing_prediction_date: int = 0 filter_kept: int = 0 + # ── Stage 2: deduplicate_samples ───────────────────────────────────────── dedup_removed: int = 0 dedup_kept: int = 0 dedup_top_collisions: list[tuple[tuple[Any, ...], int]] = field(default_factory=list) - split_strategy: str = "" - split_test_size: float | None = None + # ── Stage 3: train_test_split ───────────────────────────────────────────── split_no_sort_key: int = 0 - split_leaky: int = 0 - split_train: int = 0 - split_test: int = 0 + + split_train_before: int = 0 + split_train_after: int = 0 + split_test_after: int = 0 + + +@dataclass +class FilterParams: + """Parameters for :func:`filter_samples`.""" + days_to_resolution_range: DaysToResolutionRange = None + drop_missing_context: bool = False + + +@dataclass +class DedupParams: + """Parameters for :func:`deduplicate_samples`.""" + key_fn: Callable[[Sample], tuple[Any, ...]] | None = None + + +@dataclass +class SplitParams: + 
"""Parameters for :func:`train_test_split`.""" + strategy: str = "temporal" + test_size: float | None = 0.2 + test_start: str | None = None + random_state: int = 196 + sort_key: Callable[[Sample], str | None] | None = None + leakage_keys: list[Callable[[Sample], str | None]] | None = None + filter_leaky_train: bool = True + + +@dataclass +class PrepareIssue: + """A single detected problem in the prepared dataset, with actionable tips.""" + message: str + tips: list[str] = field(default_factory=list) + + +@dataclass +class PrepareReport: + """Full report produced by :func:`prepare_for_training`, covering all pipeline stages.""" + stats: PrepareStats + issues: list[PrepareIssue] = field(default_factory=list) + + @property + def is_healthy(self) -> bool: + return not self.issues + def _validate_days_to_resolution_range(value: Any) -> None: if value is None: @@ -85,14 +140,14 @@ def _parse_date(value: Any) -> Optional[date]: def filter_samples( samples: list[Sample], - days_to_resolution_range: DaysToResolutionRange = None, - drop_missing_context: bool = True, + params: FilterParams | None = None, stats: PrepareStats | None = None, ) -> list[Sample]: """Filter samples by validity, horizon, and optional context presence.""" - _validate_days_to_resolution_range(days_to_resolution_range) - min_horizon = days_to_resolution_range[0] if days_to_resolution_range else None - max_horizon = days_to_resolution_range[1] if days_to_resolution_range else None + params = params or FilterParams() + _validate_days_to_resolution_range(params.days_to_resolution_range) + min_horizon = params.days_to_resolution_range[0] if params.days_to_resolution_range else None + max_horizon = params.days_to_resolution_range[1] if params.days_to_resolution_range else None n_invalid = n_horizon = n_context = n_missing_resolution_date = n_missing_prediction_date = 0 filtered: list[Sample] = [] @@ -133,7 +188,7 @@ def filter_samples( if max_horizon is not None and horizon_days > max_horizon: n_horizon += 1 
continue - if drop_missing_context: + if params.drop_missing_context: if not sample.context: n_context += 1 continue @@ -236,28 +291,23 @@ def get_resolution_date(sample: Sample) -> str | None: def train_test_split( samples: list[Sample], - *, - split_strategy: str = "temporal", - test_start: str | None = None, - test_size: float | None = None, - random_state: int = 196, - sort_key: Callable[[Sample], str | None] | None = None, - leakage_keys: list[Callable[[Sample], str | None]] | None = None, - filter_leaky_train: bool = True, + params: SplitParams | None = None, stats: PrepareStats | None = None, ) -> tuple[list[str], list[str]]: """Split samples into train/test by temporal order or random shuffle, with optional leakage filtering. Returns (train_ids, test_ids) for memory efficiency.""" - temporal_split = split_strategy == "temporal" + params = params or SplitParams() + temporal_split = params.strategy == "temporal" if temporal_split: - if (test_start is None) == (test_size is None): - raise ValueError("Provide exactly one of test_start or test_size when split_strategy='temporal'") + if (params.test_start is None) == (params.test_size is None): + raise ValueError("Provide exactly one of test_start or test_size when strategy='temporal'") else: - if test_size is None: - raise ValueError("test_size is required when split_strategy='random'") - if test_start is not None: - raise ValueError("test_start is only valid when split_strategy='temporal'") + if params.test_size is None: + raise ValueError("test_size is required when strategy='random'") + if params.test_start is not None: + raise ValueError("test_start is only valid when strategy='temporal'") + sort_key = params.sort_key if sort_key is None: def default_sort_key(sample: Sample) -> str | None: if not sample.question: @@ -271,24 +321,23 @@ def default_sort_key(sample: Sample) -> str | None: return None sort_key = default_sort_key - if leakage_keys is None: - leakage_keys = _default_leakage_keys() + leakage_keys = 
params.leakage_keys or _default_leakage_keys() if temporal_split: valid_samples = [r for r in samples if sort_key(r) is not None] n_no_sort_key = len(samples) - len(valid_samples) sorted_samples = sorted(valid_samples, key=sort_key) - if test_size is not None: - split_idx = int(len(sorted_samples) * (1 - test_size)) + if params.test_size is not None: + split_idx = int(len(sorted_samples) * (1 - params.test_size)) train, test = sorted_samples[:split_idx], sorted_samples[split_idx:] else: - assert test_start is not None - train = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) < test_start] - test = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) >= test_start] + assert params.test_start is not None + train = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) < params.test_start] + test = [r for r in sorted_samples if sort_key(r) is not None and sort_key(r) >= params.test_start] n_leaky = 0 - if filter_leaky_train and test: + if params.filter_leaky_train and test: test_cutoff = sort_key(test[0]) if test_cutoff is not None: def is_safe(row: Sample) -> bool: @@ -303,28 +352,26 @@ def is_safe(row: Sample) -> bool: n_leaky = train_before - len(train) if stats is not None: - stats.split_strategy = "temporal" stats.split_no_sort_key = n_no_sort_key - stats.split_leaky = n_leaky - stats.split_train = len(train) - stats.split_test = len(test) + stats.split_train_before = len(train) + n_leaky + stats.split_train_after = len(train) + stats.split_test_after = len(test) return [s.id for s in train], [s.id for s in test] shuffled = list(samples) - rng = random.Random(random_state) if random_state is not None else random + rng = random.Random(params.random_state) if params.random_state is not None else random rng.shuffle(shuffled) - assert test_size is not None - split_idx = int(len(shuffled) * (1 - test_size)) + assert params.test_size is not None + split_idx = int(len(shuffled) * (1 - params.test_size)) train = 
shuffled[:split_idx] test = shuffled[split_idx:] if stats is not None: - stats.split_strategy = "random" - stats.split_test_size = test_size - stats.split_train = len(train) - stats.split_test = len(test) + stats.split_train_before = len(train) + stats.split_train_after = len(train) + stats.split_test_after = len(test) return [s.id for s in train], [s.id for s in test] @@ -343,11 +390,12 @@ def _default_dedup_key(sample: Sample) -> tuple[Any, ...]: def deduplicate_samples( samples: list[Sample], - key_fn: Callable[[Sample], tuple[Any, ...]] | None = None, + params: DedupParams | None = None, stats: PrepareStats | None = None, ) -> list[Sample]: """Remove duplicate samples by (question_text, resolution_date) or custom key.""" - key_fn_local: Callable[[Sample], tuple[Any, ...]] = key_fn or _default_dedup_key + params = params or DedupParams() + key_fn_local: Callable[[Sample], tuple[Any, ...]] = params.key_fn or _default_dedup_key seen: set[tuple[Any, ...]] = set() key_counts: dict[tuple[Any, ...], int] = {} result: list[Sample] = [] @@ -582,59 +630,195 @@ def _render_context(context: list[Union[NewsContext, RAGContext]]) -> str: return "\n\n".join(rendered_sections) -def _print_stats(stats: PrepareStats) -> None: - print(f"[prepare_for_training] Starting with {stats.total} samples") - - parts = [] - if stats.filter_invalid: - parts.append(f"{stats.filter_invalid} invalid") - if stats.filter_horizon: - part = f"{stats.filter_horizon} horizon" - if stats.filter_missing_resolution_date or stats.filter_missing_prediction_date: - sub = [] - if stats.filter_missing_resolution_date: - sub.append(f"{stats.filter_missing_resolution_date} missing resolution date") - if stats.filter_missing_prediction_date: - sub.append(f"{stats.filter_missing_prediction_date} missing prediction date") - part += f" ({', '.join(sub)})" - parts.append(part) - if stats.filter_context: - parts.append(f"{stats.filter_context} missing context") - if parts: - print(f"[filter] Dropped {', 
'.join(parts)} → {stats.filter_kept} remain") - else: - print(f"[filter] {stats.filter_kept} remain (0 dropped)") +def _build_report(stats: PrepareStats, split: SplitParams, filter: FilterParams) -> PrepareReport: + """Build a structured PrepareReport from pipeline stats and params. Pure — no side effects.""" + issues: list[PrepareIssue] = [] + split_train_excluded = stats.split_train_before - stats.split_train_after - if stats.dedup_removed > 0: - print(f"[dedup] Removed {stats.dedup_removed} duplicates ({stats.dedup_kept + stats.dedup_removed} → {stats.dedup_kept}). Top colliding keys:") - for k, c in stats.dedup_top_collisions: - q = repr(k[0])[:60] + ("..." if len(repr(k[0])) > 60 else "") - print(f" ({q}, {k[1]}): {c} samples → 1") - else: - print(f"[dedup] {stats.dedup_kept} remain (0 duplicates)") + # Issue: majority of train samples leaked into test period + majority_train_leaked = ( + split.filter_leaky_train + and split.strategy == "temporal" + and stats.split_train_before > 0 + and split_train_excluded > stats.split_train_before // 2 + ) + if majority_train_leaked: + pct = int(100 * split_train_excluded / stats.split_train_before) + tips: list[str] = [] + max_horizon = filter.days_to_resolution_range[1] if filter.days_to_resolution_range else None + if max_horizon is not None: + tips.append( + f"Extend the seed generator date range to start earlier — the range should span at least " + f"{max_horizon * 2} days so questions generated near the start resolve well before the test window." + ) + else: + tips.append( + "Extend the seed generator (date) filter range to start earlier — questions generated near the start " + "will resolve well before the test window. Aim for at least 2× your max resolution horizon." + ) + tips.append( + "Generate more samples by increasing max_questions in lr.transforms.run() or removing the limit, " + "or increase questions_per_seed in your question generator config. 
" + "A larger, temporally well-spread dataset naturally pushes the split cutoff far enough back." + ) + tips.append( + "If very few seeds were returned by the pipeline (check the run summary table), the search queries " + "may not surface results across the full date range. Try more diverse search queries, increase " + "articles_per_search, or shorten interval_duration_days." + ) + issues.append(PrepareIssue( + message=( + f"{split_train_excluded}/{stats.split_train_before} train samples ({pct}%) were removed " + "for temporal leakage — the date_close or resolution_date of train questions extends into the test period." + ), + tips=tips, + )) + + # Issue: too few train samples for effective training + MIN_TRAIN_SAMPLES = 200 + if stats.split_train_after < MIN_TRAIN_SAMPLES and stats.split_train_after > 0: + issues.append(PrepareIssue( + message=( + f"Only {stats.split_train_after} train samples remain after preparation. " + f"This is below the recommended minimum of {MIN_TRAIN_SAMPLES} for effective training." + ), + tips=[ + "Increase max_questions in lr.transforms.run() to generate more samples.", + "Increase questions_per_seed in your question generator (ForwardLookingQuestionGenerator or QuestionGenerator) to produce more questions from each seed article.", + "Add more search queries to your seed generator to diversify seed sources.", + "Widen the seed generator date range (start_date to end_date) to capture more events.", + ], + )) + + # Issue: too few test samples for reliable evaluation + MIN_TEST_SAMPLES = 50 + if stats.split_test_after < MIN_TEST_SAMPLES and stats.split_test_after > 0: + issues.append(PrepareIssue( + message=( + f"Only {stats.split_test_after} test samples remain after preparation. " + f"This is below the recommended minimum of ~{MIN_TEST_SAMPLES} for reliable evaluation." 
+ ), + tips=[ + "Generate more samples overall — test samples come from the most recent portion of your date range.", + "Ensure your seed generator date range extends close to the present so recent events appear in the test set.", + ], + )) + + # Issue: high invalid rate (>30% of samples were invalid) + HIGH_INVALID_THRESHOLD = 0.30 + if stats.total > 0 and stats.filter_invalid / stats.total > HIGH_INVALID_THRESHOLD: + pct = int(100 * stats.filter_invalid / stats.total) + issues.append(PrepareIssue( + message=( + f"{stats.filter_invalid}/{stats.total} samples ({pct}%) were marked invalid. " + "This suggests issues with dataset generation configuration." + ), + tips=[ + "Add more examples and bad_examples to guide the question generator toward more sensible questions.", + "Check the labeler configuration — if WebSearchLabeler can't find resolution info, samples are marked invalid.", + "Inspect a few invalid samples with dataset.flattened() to identify patterns.", + ], + )) + + # Issue: high dedup rate (>40% removed as duplicates) + HIGH_DEDUP_THRESHOLD = 0.40 + if stats.filter_kept > 0 and stats.dedup_removed / stats.filter_kept > HIGH_DEDUP_THRESHOLD: + pct = int(100 * stats.dedup_removed / stats.filter_kept) + issues.append(PrepareIssue( + message=( + f"{stats.dedup_removed}/{stats.filter_kept} samples ({pct}%) were duplicates. " + "The pipeline is generating repetitive or similar questions." 
+ ), + tips=[ + "Add more diverse search queries to your seed generator to surface different source articles.", + "Increase interval_duration_days to spread seeds across more time periods.", + "Add bad_examples to your question generator showing the repetitive patterns to avoid.", + "Use more specific instructions in your question generator to encourage variety.", + ], + )) + + # Issue: high horizon filter rate (>50% filtered out by horizon) + HIGH_HORIZON_THRESHOLD = 0.50 + if stats.total > 0 and stats.filter_horizon / stats.total > HIGH_HORIZON_THRESHOLD: + pct = int(100 * stats.filter_horizon / stats.total) + horizon_desc = "" + if filter.days_to_resolution_range: + min_h, max_h = filter.days_to_resolution_range + if min_h is not None and max_h is not None: + horizon_desc = f" (required: {min_h}-{max_h} days)" + elif min_h is not None: + horizon_desc = f" (required: ≥{min_h} days)" + elif max_h is not None: + horizon_desc = f" (required: ≤{max_h} days)" + issues.append(PrepareIssue( + message=( + f"{stats.filter_horizon}/{stats.total} samples ({pct}%) fell outside the resolution horizon{horizon_desc}. " + ), + tips=[ + "Widen your days_to_resolution_range in FilterParams if your use case allows longer/shorter horizons.", + "Adjust your seed generator date range — questions from very recent seeds may not have resolved yet, " + "while questions from old seeds may exceed your max horizon.", + "Check that your question generator is producing questions with appropriate resolution timelines for your target horizon." 
+ ], + )) + + return PrepareReport(stats=stats, issues=issues) + + +def _print_report(report: PrepareReport, verbose: bool) -> None: + """Print the report to stdout and raise if unhealthy (non-notebook path).""" + stats = report.stats + if verbose: + print(f"[prepare_for_training] Starting with {stats.total} samples") + + parts = [] + if stats.filter_invalid: + parts.append(f"{stats.filter_invalid} invalid") + if stats.filter_horizon: + part = f"{stats.filter_horizon} horizon" + if stats.filter_missing_resolution_date or stats.filter_missing_prediction_date: + sub = [] + if stats.filter_missing_resolution_date: + sub.append(f"{stats.filter_missing_resolution_date} missing resolution date") + if stats.filter_missing_prediction_date: + sub.append(f"{stats.filter_missing_prediction_date} missing prediction date") + part += f" ({', '.join(sub)})" + parts.append(part) + if stats.filter_context: + parts.append(f"{stats.filter_context} missing context") + print(f"[filter] Dropped {', '.join(parts)} → {stats.filter_kept} remain" if parts else f"[filter] {stats.filter_kept} remain (0 dropped)") + + if stats.dedup_removed > 0: + print(f"[dedup] Removed {stats.dedup_removed} duplicates ({stats.dedup_kept + stats.dedup_removed} → {stats.dedup_kept}). Top colliding keys:") + for k, c in stats.dedup_top_collisions: + q = repr(k[0])[:60] + ("..." 
if len(repr(k[0])) > 60 else "") + print(f" ({q}, {k[1]}): {c} samples → 1") + else: + print(f"[dedup] {stats.dedup_kept} remain (0 duplicates)") - if stats.split_strategy == "temporal": if stats.split_no_sort_key: print(f"[split] {stats.split_no_sort_key} samples had no prediction_date (dropped)") - if stats.split_leaky: - print(f"[split] {stats.split_leaky} train samples removed for leakage") - print(f"[split] Temporal split: {stats.split_train} train, {stats.split_test} test") - else: - print(f"[split] Random split (test_size={stats.split_test_size}): {stats.split_train} train, {stats.split_test} test") + n_leaked = stats.split_train_before - stats.split_train_after + if n_leaked: + print(f"[split] {n_leaked} train samples removed for leakage") + print(f"[split] {stats.split_train_after} train, {stats.split_test_after} test") + if not report.is_healthy: + lines = ["[prepare_for_training] Unhealthy split detected."] + for issue in report.issues: + lines.append(issue.message) + if issue.tips: + lines.append("Tips:\n" + "\n".join(f" - {t}" for t in issue.tips)) + raise ValueError("\n\n".join(lines)) -def filter_and_split( + +def prepare_for_training( dataset: "SampleDataset", *, - test_size: float = 0.2, - split_strategy: str = "temporal", - test_start: str | None = None, - drop_missing_context: bool = False, - days_to_resolution_range: DaysToResolutionRange = None, - random_state: int = 196, - filter_leaky_train: bool = True, - deduplicate_key_fn: Callable[[Sample], tuple[Any, ...]] | None = None, - verbose: bool = False, + filter: FilterParams | None = None, + dedup: DedupParams | None = None, + split: SplitParams | None = None, + verbose: bool = True, ) -> tuple["SampleDataset", "SampleDataset"]: """Prepare a dataset for model training: filter, deduplicate, split into train/test. @@ -644,42 +828,30 @@ def filter_and_split( Args: dataset: SampleDataset to prepare (samples are fetched via dataset.samples()).
- test_size: Fraction of samples for the test set (0.0–1.0). Default 0.2. - split_strategy: 'temporal' (default) or 'random'. - test_start: ISO date string for temporal splits. Provide exactly one of - test_start or test_size for temporal splits. - drop_missing_context: If True, exclude samples with no context. - days_to_resolution_range: Optional (min_days, max_days) tuple. - random_state: Seed for reproducible random splits. - filter_leaky_train: When True and temporal, remove temporal leakage. - deduplicate_key_fn: Optional function to customize deduplication key. + filter: Controls validity filtering and horizon range. See :class:`FilterParams`. + dedup: Controls deduplication key. See :class:`DedupParams`. + split: Controls train/test split strategy, size, and leakage filtering. See :class:`SplitParams`. verbose: When True, print step-by-step stats. Returns: (train_dataset, test_dataset): SampleDatasets ready for training/eval. """ + filter = filter or FilterParams() + dedup = dedup or DedupParams() + split = split or SplitParams() + samples = dataset.samples() stats = PrepareStats(total=len(samples)) - filtered = filter_samples( - samples, - days_to_resolution_range=days_to_resolution_range, - drop_missing_context=drop_missing_context, - stats=stats, - ) - deduped = deduplicate_samples(filtered, key_fn=deduplicate_key_fn, stats=stats) - - train_ids, test_ids = train_test_split( - deduped, - split_strategy=split_strategy, - test_start=test_start, - filter_leaky_train=filter_leaky_train, - test_size=test_size, - random_state=random_state, - stats=stats, - ) + filtered = filter_samples(samples, filter, stats=stats) + deduped = deduplicate_samples(filtered, dedup, stats=stats) + train_ids, test_ids = train_test_split(deduped, split, stats=stats) - if verbose: - _print_stats(stats) + report = _build_report(stats, split=split, filter=filter) + from lightningrod._display import _is_notebook, display_prepare_report + if _is_notebook(): + 
display_prepare_report(report, verbose=verbose) + else: + _print_report(report, verbose=verbose) return dataset.subset(train_ids), dataset.subset(test_ids) \ No newline at end of file