Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .github/workflows/build_and_release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,17 @@ jobs:
with:
attestations: false
packages-dir: .github/.internal_dspyai/dist/
# Publish to dspy-runtime (minimal-dependency build from the same source tree)
- name: Update version in pyproject-runtime.toml
run: sed -i '/#replace_package_version_marker/{n;s/version *= *"[^"]*"/version="${{ needs.extract-tag.outputs.version }}"/;}' pyproject-runtime.toml
- name: Build dspy-runtime distribution
run: |
rm -rf dist
bash scripts/build_dspy_runtime.sh
- name: Publish distribution 📦 to PyPI (dspy-runtime)
uses: pypa/gh-action-pypi-publish@cef221092ed1bacb1cc03d23a2d87d1d172e277b # release/v1
with:
attestations: false
- uses: stefanzweifel/git-auto-commit-action@04702edda442b2e678b25b537cec683a1493fcb9 # v5 # auto commit changes to release branch
with:
commit_message: Update versions
Expand Down
3 changes: 2 additions & 1 deletion dspy/adapters/json_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import json_repair
import pydantic
import regex
from pydantic.fields import FieldInfo

from dspy.adapters.chat_adapter import ChatAdapter, FieldInfoWithName
Expand Down Expand Up @@ -146,6 +145,8 @@ def format_assistant_message_content(
return self.format_field_with_value(fields_with_values, role="assistant")

def parse(self, signature: type[Signature], completion: str) -> dict[str, Any]:
import regex

fields = json_repair.loads(completion)

if not isinstance(fields, dict):
Expand Down
6 changes: 4 additions & 2 deletions dspy/clients/_litellm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import os

from dspy.utils.lazy_import import require

_litellm = None


Expand All @@ -20,8 +22,8 @@ def get_litellm():
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

import litellm
from litellm._logging import verbose_logger
litellm = require("litellm", extra="litellm", feature="dspy.LM")
verbose_logger = require("litellm._logging", extra="litellm", feature="dspy.LM").verbose_logger

litellm.telemetry = False
litellm.cache = None
Expand Down
27 changes: 15 additions & 12 deletions dspy/clients/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from datetime import datetime
from typing import Any

import openai

from dspy.clients.provider import Provider, TrainingJob
from dspy.clients.utils_finetune import TrainDataFormat, TrainingStatus, save_data
from dspy.utils.lazy_import import require


def _openai():
    """Lazily import and return the ``openai`` package.

    Delegates to ``require`` so a missing optional dependency raises a
    helpful install hint instead of a bare ImportError.
    """
    module = require("openai", extra="full", feature="the OpenAI finetuning provider")
    return module


class TrainingJobOpenAI(TrainingJob):
Expand All @@ -22,13 +25,13 @@ def cancel(self):
err_msg = "Jobs that are complete cannot be canceled."
err_msg += f" Job with ID {self.provider_job_id} is done."
raise Exception(err_msg)
openai.fine_tuning.jobs.cancel(self.provider_job_id)
_openai().fine_tuning.jobs.cancel(self.provider_job_id)
self.provider_job_id = None

# Delete the provider file
if self.provider_file_id is not None:
if OpenAIProvider.does_file_exist(self.provider_file_id):
openai.files.delete(self.provider_file_id)
_openai().files.delete(self.provider_file_id)
self.provider_file_id = None

# Call the super's cancel method after the custom cancellation logic
Expand Down Expand Up @@ -104,7 +107,7 @@ def does_job_exist(job_id: str) -> bool:
try:
# TODO(nit): This call may fail for other reasons. We should check
# the error message to ensure that the job does not exist.
openai.fine_tuning.jobs.retrieve(job_id)
_openai().fine_tuning.jobs.retrieve(job_id)
return True
except Exception:
return False
Expand All @@ -114,7 +117,7 @@ def does_file_exist(file_id: str) -> bool:
try:
# TODO(nit): This call may fail for other reasons. We should check
# the error message to ensure that the file does not exist.
openai.files.retrieve(file_id)
_openai().files.retrieve(file_id)
return True
except Exception:
return False
Expand Down Expand Up @@ -147,7 +150,7 @@ def get_training_status(job_id: str) -> TrainingStatus:
assert OpenAIProvider.does_job_exist(job_id), err_msg

# Retrieve the provider's job and report the status
provider_job = openai.fine_tuning.jobs.retrieve(job_id)
provider_job = _openai().fine_tuning.jobs.retrieve(job_id)
provider_status = provider_job.status
status = provider_status_to_training_status[provider_status]

Expand All @@ -166,7 +169,7 @@ def validate_data_format(data_format: TrainDataFormat):
@staticmethod
def upload_data(data_path: str) -> str:
# Upload the data to the provider
provider_file = openai.files.create(
provider_file = _openai().files.create(
file=open(data_path, "rb"),
purpose="fine-tune",
)
Expand All @@ -175,7 +178,7 @@ def upload_data(data_path: str) -> str:
@staticmethod
def _start_remote_training(train_file_id: str, model: str, train_kwargs: dict[str, Any] | None = None) -> str:
train_kwargs = train_kwargs or {}
provider_job = openai.fine_tuning.jobs.create(
provider_job = _openai().fine_tuning.jobs.create(
model=model,
training_file=train_file_id,
hyperparameters=train_kwargs,
Expand All @@ -194,7 +197,7 @@ def wait_for_job(
while not done:
# Report estimated time if not already reported
if not reported_estimated_time:
remote_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
remote_job = _openai().fine_tuning.jobs.retrieve(job.provider_job_id)
timestamp = remote_job.estimated_finish
if timestamp:
estimated_finish_dt = datetime.fromtimestamp(timestamp)
Expand All @@ -203,7 +206,7 @@ def wait_for_job(
reported_estimated_time = True

# Get new events
page = openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job.provider_job_id, limit=1)
page = _openai().fine_tuning.jobs.list_events(fine_tuning_job_id=job.provider_job_id, limit=1)
new_event = page.data[0] if page.data else None
if new_event and new_event.id != cur_event_id:
dt = datetime.fromtimestamp(new_event.created_at)
Expand All @@ -222,6 +225,6 @@ def get_trained_model(job):
err_msg += f" Must be {TrainingStatus.succeeded} to retrieve model."
raise Exception(err_msg)

provider_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
provider_job = _openai().fine_tuning.jobs.retrieve(job.provider_job_id)
finetuned_model = provider_job.fine_tuned_model
return finetuned_model
15 changes: 11 additions & 4 deletions dspy/dsp/utils/dpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import logging
import unicodedata

import regex

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -157,6 +155,8 @@ def __init__(self, **kwargs):
Args:
annotators: None or empty set (only tokenizes).
"""
import regex

self._regexp = regex.compile(
"(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS),
flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE,
Expand Down Expand Up @@ -225,11 +225,18 @@ def locate_answers(tokenized_answers, text):
return occurrences


STokenizer = SimpleTokenizer()
_STokenizer = None


def _get_stokenizer():
    """Return the module-level SimpleTokenizer, constructing it on first use."""
    global _STokenizer
    if _STokenizer is not None:
        return _STokenizer
    # First call: build and cache the singleton so later calls are cheap.
    _STokenizer = SimpleTokenizer()
    return _STokenizer


def DPR_tokenize(text): # noqa: N802
return STokenizer.tokenize(unicodedata.normalize("NFD", text))
return _get_stokenizer().tokenize(unicodedata.normalize("NFD", text))


def DPR_normalize(text): # noqa: N802
Expand Down
6 changes: 4 additions & 2 deletions dspy/streaming/streaming_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from queue import Queue
from typing import TYPE_CHECKING, Any

import jiter

from dspy.adapters.chat_adapter import ChatAdapter
from dspy.adapters.json_adapter import JSONAdapter
from dspy.adapters.types import Type
Expand Down Expand Up @@ -245,6 +243,8 @@ def _json_adapter_handle_stream_chunk(self, token: str, chunk_message: str) -> S
# If the parse doesn't raise an error, that means the accumulated tokens is a valid json object. Because
# we add an extra "{" to the beginning of the field_accumulated_messages, so we know the streaming is
# finished.
import jiter

jiter.from_json(self.json_adapter_state["field_accumulated_messages"].encode("utf-8"))
self.stream_end = True
last_token = self.flush()
Expand All @@ -259,6 +259,8 @@ def _json_adapter_handle_stream_chunk(self, token: str, chunk_message: str) -> S
pass

try:
import jiter
Comment thread
greptile-apps[bot] marked this conversation as resolved.

parsed = jiter.from_json(
self.json_adapter_state["field_accumulated_messages"].encode("utf-8"),
partial_mode="trailing-strings",
Expand Down
17 changes: 11 additions & 6 deletions dspy/teleprompt/gepa/gepa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
import logging
import random
from dataclasses import dataclass
from typing import Any, Literal, Optional, Protocol, Union
from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, Union

from gepa import GEPAResult
from gepa.core.adapter import ProposalFn
from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector
if TYPE_CHECKING:
from gepa import GEPAResult
from gepa.core.adapter import ProposalFn
from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector

from dspy.clients.lm import LM
from dspy.primitives import Example, Module, Prediction
Expand Down Expand Up @@ -491,7 +492,11 @@ def compile(
- trainset: The training set to use for reflective updates.
- valset: The validation set to use for tracking Pareto scores. If not provided, GEPA will use the trainset for both.
"""
from gepa import GEPAResult, optimize
from dspy.utils.lazy_import import require

gepa = require("gepa", extra="gepa", feature="dspy.GEPA")
GEPAResult = gepa.GEPAResult
optimize = gepa.optimize

from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, LoggerAdapter

Expand Down Expand Up @@ -575,7 +580,7 @@ def feedback_fn(
# Build the seed candidate: map each predictor name to its current instruction
seed_candidate = {name: pred.signature.instructions for name, pred in student.named_predictors()}

gepa_result: GEPAResult = optimize(
gepa_result: "GEPAResult" = optimize(
seed_candidate=seed_candidate,
trainset=trainset,
valset=valset,
Expand Down
33 changes: 27 additions & 6 deletions dspy/teleprompt/gepa/gepa_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import logging
import random
from typing import Any, Callable, Protocol, TypedDict

from gepa import EvaluationBatch, GEPAAdapter
from gepa.core.adapter import ProposalFn
from gepa.strategies.instruction_proposal import InstructionProposalSignature
from typing import TYPE_CHECKING, Any, Callable, Protocol, TypedDict

import dspy
from dspy.adapters.chat_adapter import ChatAdapter
Expand All @@ -13,10 +9,31 @@
from dspy.evaluate import Evaluate
from dspy.primitives import Example, Prediction
from dspy.teleprompt.bootstrap_trace import FailedPrediction, TraceData
from dspy.utils.lazy_import import optional, require

if TYPE_CHECKING:
from gepa import EvaluationBatch, GEPAAdapter
from gepa.core.adapter import ProposalFn

logger = logging.getLogger(__name__)


def _require_gepa():
    """Fail fast with an install hint when the optional ``gepa`` package is absent."""
    # The returned module is discarded; only the raise-if-missing side effect matters.
    _ = require("gepa", extra="gepa", feature="dspy.GEPA")


def _get_gepa_adapter_base():
    """Return the parametrized GEPAAdapter base class, or ``object`` when gepa is absent.

    Falling back to ``object`` lets ``DspyAdapter`` be defined at import time
    without gepa installed; methods that actually touch gepa internals guard
    themselves with ``_require_gepa()`` first.
    """
    adapter_cls = optional("gepa", "GEPAAdapter")
    if adapter_cls is not None:
        return adapter_cls[Example, "TraceData", Prediction]
    return object


class LoggerAdapter:
def __init__(self, logger: logging.Logger):
self.logger = logger
Expand Down Expand Up @@ -74,7 +91,7 @@ def __call__(
...


class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]):
class DspyAdapter(_get_gepa_adapter_base()):
def __init__(
self,
student_module,
Expand Down Expand Up @@ -117,6 +134,8 @@ def propose_new_texts(
components_to_update=components_to_update,
)

from gepa.strategies.instruction_proposal import InstructionProposalSignature

results: dict[str, str] = {}

with dspy.context(lm=reflection_lm):
Expand All @@ -143,6 +162,8 @@ def build_program(self, candidate: dict[str, str]):
return new_prog

def evaluate(self, batch, candidate, capture_traces=False):
from gepa import EvaluationBatch

program = self.build_program(candidate)
callback_metadata = (
{"metric_key": "eval_full"}
Expand Down
15 changes: 11 additions & 4 deletions dspy/teleprompt/gepa/instruction_proposal.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import logging
from typing import Any

from gepa.core.adapter import ProposalFn
from typing import TYPE_CHECKING, Any

import dspy
from dspy.adapters.types.base_type import Type
from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample
from dspy.utils.lazy_import import optional

if TYPE_CHECKING:
from gepa.core.adapter import ProposalFn

logger = logging.getLogger(__name__)


def _get_proposal_fn_base():
    """Return the ProposalFn base class, falling back to ``object`` if gepa is missing."""
    base = optional("gepa.core.adapter", "ProposalFn", default=object)
    return base


class GenerateEnhancedMultimodalInstructionFromFeedback(dspy.Signature):
"""I provided an assistant with instructions to perform a task involving visual content, but the assistant's performance needs improvement based on the examples and feedback below.

Expand Down Expand Up @@ -269,7 +276,7 @@ def _create_multimodal_examples(self, formatted_text: str, image_map: dict[int,
return multimodal_content


class MultiModalInstructionProposer(ProposalFn):
class MultiModalInstructionProposer(_get_proposal_fn_base()):
"""GEPA-compatible multimodal instruction proposer.

This class handles multimodal inputs (like dspy.Image) during GEPA optimization by using
Expand Down
Loading
Loading