4 changes: 2 additions & 2 deletions .github/workflows/semgrep.yaml
@@ -1,5 +1,5 @@
---
-name: 🚨 Semgrep Analysis
+name: Semgrep Analysis
on:
merge_group:
pull_request:
@@ -28,7 +28,7 @@ permissions:

jobs:
semgrep:
-name: 🚨 Semgrep Analysis
+name: Semgrep Analysis
runs-on: ubuntu-latest
container:
image: returntocorp/semgrep
55 changes: 55 additions & 0 deletions .github/workflows/test.yaml
@@ -0,0 +1,55 @@
name: Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
python:
name: Python - Lint, Typecheck, Test

strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]

runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683

- name: Setup Python ${{ matrix.python-version }}
uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b
with:
python-version: ${{ matrix.python-version }}

- name: Install Poetry
uses: abatilo/actions-poetry@e78f54a89cb052fff327414dd9ff010b5d2b4dbd

- name: Configure Poetry
run: |
poetry config virtualenvs.create true --local
poetry config virtualenvs.in-project true --local

- name: Cache dependencies
uses: actions/cache@0c907a75c2c80ebcb7f088228285e798b750cf8f # v4.2.1
with:
path: ./.venv
key: venv-${{ runner.os }}-py${{ matrix.python-version }}-${{ hashFiles('poetry.lock') }}
restore-keys: |
venv-${{ runner.os }}-py${{ matrix.python-version }}-

- name: Install package
run: poetry install --all-extras

- name: Lint
run: poetry run ruff check --output-format=github .

- name: Typecheck
run: poetry run mypy .

- name: Test
run: poetry run pytest
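
The workflow above is plain Poetry, so the same gate can be reproduced before pushing. A minimal local sketch, assuming Poetry is already installed; the loop body mirrors the four CI steps verbatim:

```python
import subprocess

# Hypothetical local runner mirroring the CI steps above; assumes Poetry is on PATH.
for cmd in (
    ["poetry", "install", "--all-extras"],
    ["poetry", "run", "ruff", "check", "--output-format=github", "."],
    ["poetry", "run", "mypy", "."],
    ["poetry", "run", "pytest"],
):
    subprocess.run(cmd, check=True)  # stop at the first failing step, like a failed CI job
```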
58 changes: 0 additions & 58 deletions .github/workflows/tests.yaml

This file was deleted.

4 changes: 2 additions & 2 deletions dreadnode/artifact/merger.py
@@ -562,9 +562,9 @@ def _update_directory_hash(self, dir_node: DirectoryNode) -> str:

for child in dir_node["children"]:
if child["type"] == "file":
-child_hashes.append(cast(FileNode, child)["hash"]) # noqa: TC006
+child_hashes.append(cast("FileNode", child)["hash"])
else:
-child_hash = self._update_directory_hash(cast(DirectoryNode, child)) # noqa: TC006
+child_hash = self._update_directory_hash(cast("DirectoryNode", child))
child_hashes.append(child_hash)

child_hashes.sort() # Ensure consistent hash regardless of order
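
These cast changes quote the cast target so the type expression is never evaluated at runtime, which is what ruff's TC006 rule asks for and why the `noqa` suppressions can go. A small sketch of the pattern; the import path for `FileNode` is illustrative, not taken from this diff:

```python
import typing as t

if t.TYPE_CHECKING:
    # Illustrative path: wherever FileNode is actually defined in this package.
    from dreadnode.artifact.merger import FileNode

def file_hash(child: dict[str, t.Any]) -> str:
    # With a quoted target, the annotation stays a string at runtime,
    # so the FileNode import can live under TYPE_CHECKING.
    return t.cast("FileNode", child)["hash"]
```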
6 changes: 3 additions & 3 deletions dreadnode/artifact/tree_builder.py
@@ -262,7 +262,7 @@ def _build_tree_structure(
}
dir_structure[root_dir_path] = root_node

-for file_path in file_nodes_by_path: # noqa: PLC0206
+for file_path, file_node in file_nodes_by_path.items():
try:
rel_path = file_path.relative_to(base_dir)
parts = rel_path.parts
@@ -272,7 +272,7 @@

# File in the root directory
if len(parts) == 1:
root_node["children"].append(file_nodes_by_path[file_path])
root_node["children"].append(file_node)
continue

# Create parent directories
@@ -295,7 +295,7 @@
# Now add the file to its parent directory
parent_dir_str = file_path.parent.resolve().as_posix()
if parent_dir_str in dir_structure:
-dir_structure[parent_dir_str]["children"].append(file_nodes_by_path[file_path])
+dir_structure[parent_dir_str]["children"].append(file_node)
self._compute_directory_hashes(dir_structure)

return root_node
12 changes: 4 additions & 8 deletions dreadnode/integrations/transformers.py
@@ -5,12 +5,8 @@

import typing as t

-from transformers.trainer_callback import ( # type: ignore [import-untyped]
-TrainerCallback,
-TrainerControl,
-TrainerState,
-TrainingArguments,
-)
+from transformers.trainer_callback import TrainerCallback, TrainerControl, TrainerState
+from transformers.training_args import TrainingArguments

import dreadnode as dn

@@ -28,7 +24,7 @@ def _clean_keys(data: dict[str, t.Any]) -> dict[str, t.Any]:
return cleaned


-class DreadnodeCallback(TrainerCallback): # type: ignore [misc]
+class DreadnodeCallback(TrainerCallback):
"""
An implementation of the `TrainerCallback` interface for Dreadnode.

@@ -124,7 +120,7 @@ def on_epoch_begin(
control: TrainerControl,
**kwargs: t.Any,
) -> None:
-if self._run is None:
+if self._run is None or state.epoch is None:
return

dn.log_metric("epoch", state.epoch)
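
For context, a sketch of how the callback attaches to a `Trainer`. The model and dataset are caller-supplied placeholders, and default construction of `DreadnodeCallback` is an assumption, not something this diff shows:

```python
from transformers import Trainer, TrainingArguments

from dreadnode.integrations.transformers import DreadnodeCallback

def build_trainer(model, train_dataset):
    # model and train_dataset are placeholders supplied by the caller;
    # DreadnodeCallback() with no arguments is assumed, not shown in this diff.
    return Trainer(
        model=model,
        args=TrainingArguments(output_dir="out"),
        train_dataset=train_dataset,
        callbacks=[DreadnodeCallback()],  # epoch/step metrics flow through the hooks above
    )
```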
24 changes: 21 additions & 3 deletions dreadnode/main.py
@@ -32,7 +32,7 @@
ENV_SERVER,
ENV_SERVER_URL,
)
-from dreadnode.metric import Metric, Scorer, ScorerCallable, T
+from dreadnode.metric import Metric, MetricMode, Scorer, ScorerCallable, T
from dreadnode.task import P, R, Task
from dreadnode.tracing.exporters import (
FileExportConfig,
@@ -757,6 +757,7 @@ def log_metric(
step: int = 0,
origin: t.Any | None = None,
timestamp: datetime | None = None,
mode: MetricMode = "direct",
to: ToObject = "task-or-run",
) -> None:
"""
@@ -778,6 +779,14 @@
origin: The origin of the metric - can be provided any object which was logged
as an input or output anywhere in the run.
timestamp: The timestamp of the metric - defaults to the current time.
mode: The aggregation mode to use for the metric. Helpful when you want to let
the library take care of translating your raw values into better representations.
- direct: do not modify the value at all (default)
- min: the lowest observed value reported for this metric
- max: the highest observed value reported for this metric
- avg: the average of all reported values for this metric
- sum: the cumulative sum of all reported values for this metric
- count: increment every time this metric is logged - disregard value
to: The target object to log the metric to. Can be "task-or-run" or "run".
Defaults to "task-or-run". If "task-or-run", the metric will be logged
to the current task or run, whichever is the nearest ancestor.
@@ -790,6 +799,7 @@
value: Metric,
*,
origin: t.Any | None = None,
mode: MetricMode = "direct",
to: ToObject = "task-or-run",
) -> None:
"""
@@ -809,11 +819,18 @@
value: The metric object.
origin: The origin of the metric - can be provided any object which was logged
as an input or output anywhere in the run.
mode: The aggregation mode to use for the metric. Helpful when you want to let
the library take care of translating your raw values into better representations.
- direct: do not modify the value at all (default)
- min: always report the lowest observed value for this metric
- max: always report the highest observed value for this metric
- avg: report the average of all values for this metric
- sum: report a rolling sum of all values for this metric
- count: report the number of times this metric has been logged
to: The target object to log the metric to. Can be "task-or-run" or "run".
Defaults to "task-or-run". If "task-or-run", the metric will be logged
to the current task or run, whichever is the nearest ancestor.
"""
... # noqa: PIE790

@handle_internal_errors()
def log_metric(
@@ -824,6 +841,7 @@
step: int = 0,
origin: t.Any | None = None,
timestamp: datetime | None = None,
mode: MetricMode = "direct",
to: ToObject = "task-or-run",
) -> None:
task = current_task_span.get()
@@ -838,7 +856,7 @@
if isinstance(value, Metric)
else Metric(float(value), step, timestamp or datetime.now(timezone.utc))
)
-target.log_metric(key, metric, origin=origin)
+target.log_metric(key, metric, origin=origin, mode=mode)

@handle_internal_errors()
def log_artifact(
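
A usage sketch of the new `mode` parameter, assuming an active run context (opening a run is outside this diff):

```python
import dreadnode as dn

# Hedged sketch: each call assumes a run is already active, as in the
# transformers integration above.
dn.log_metric("loss", 0.9, mode="min")      # first report: 0.9
dn.log_metric("loss", 0.4, mode="min")      # stored value drops to 0.4, the lowest seen
dn.log_metric("loss", 0.7, mode="min")      # stays 0.4
dn.log_metric("requests", 1, mode="count")  # increments per call; the value is ignored
dn.log_metric("reward", 0.5, mode="avg")    # running average across all reports
```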
44 changes: 43 additions & 1 deletion dreadnode/metric.py
@@ -10,6 +10,8 @@

T = t.TypeVar("T")

MetricMode = t.Literal["direct", "avg", "sum", "min", "max", "count"]


@dataclass
class Metric:
@@ -55,6 +57,46 @@ def from_many(
score_attributes = {name: value for name, value, _ in values}
return cls(value=total / weight, step=step, attributes={**attributes, **score_attributes})

def apply_mode(self, mode: MetricMode, others: "list[Metric]") -> "Metric":
"""
Apply an aggregation mode to the metric.
This will modify the metric in place.

Args:
mode: The mode to apply. One of "direct", "avg", "sum", "min", "max", or "count".
others: A list of other metrics to apply the mode to.

Returns:
self
"""
previous_mode = next((m.attributes.get("mode") for m in others), mode) or "direct"
if mode != previous_mode:
raise ValueError(
f"Cannot mix metric modes {mode} != {previous_mode}",
)

if mode == "direct":
return self

self.attributes["original"] = self.value
self.attributes["mode"] = mode

prior_values = [m.value for m in sorted(others, key=lambda m: m.timestamp)]

if mode == "sum":
self.value += max(prior_values)
elif mode == "min":
self.value = min([self.value, *prior_values])
elif mode == "max":
self.value = max([self.value, *prior_values])
elif mode == "count":
self.value = len(others) + 1
elif mode == "avg" and prior_values:
current_avg = prior_values[-1]
self.value = current_avg + (self.value - current_avg) / (len(prior_values) + 1)

return self


MetricDict = dict[str, list[Metric]]

@@ -83,7 +125,7 @@ class Scorer(t.Generic[T]):
def from_callable(
cls,
tracer: Tracer,
-func: ScorerCallable[T] | "Scorer[T]", # noqa: TC010
+func: "ScorerCallable[T] | Scorer[T]",
*,
name: str | None = None,
tags: t.Sequence[str] | None = None,
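
A worked sketch of the running-average mode; the positional `Metric(value, step, timestamp)` construction matches its use in `main.py` above, and everything else is taken from `apply_mode` itself:

```python
from datetime import datetime, timezone

from dreadnode.metric import Metric

# First observation: no priors, so the value is unchanged and the mode is recorded.
first = Metric(2.0, 0, datetime.now(timezone.utc))
first.apply_mode("avg", [])

# Second observation: incremental mean of [2.0, 4.0].
second = Metric(4.0, 1, datetime.now(timezone.utc))
second.apply_mode("avg", [first])
assert second.value == 3.0
```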
15 changes: 11 additions & 4 deletions dreadnode/task.py
@@ -83,7 +83,7 @@ def top_n(
"""
sorted_ = self.sorted(reverse=reverse)[:n]
return (
-t.cast(list[R], [span.output for span in sorted_]) # noqa: TC006
+t.cast("list[R]", [span.output for span in sorted_])
if as_outputs
else TaskSpanList(sorted_)
)
@@ -246,6 +246,8 @@ async def run(self, *args: P.args, **kwargs: P.kwargs) -> TaskSpan[R]:
run_id=run.run_id,
tracer=self.tracer,
) as span:
span.run.log_metric(f"{self.label}.exec.count", 1, mode="count")

for name, value in params_to_log.items():
span.log_param(name, value)

@@ -254,10 +256,15 @@
for name, value in inputs_to_log.items()
]

-output = t.cast(R | t.Awaitable[R], self.func(*args, **kwargs)) # noqa: TC006
-if inspect.isawaitable(output):
-output = await output
+try:
+output = t.cast("R | t.Awaitable[R]", self.func(*args, **kwargs))
+if inspect.isawaitable(output):
+output = await output
+except Exception:
+span.run.log_metric(f"{self.label}.exec.success_rate", 0, mode="avg")
+raise

span.run.log_metric(f"{self.label}.exec.success_rate", 1, mode="avg")
span.output = output

if self.log_output:
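
The `success_rate` metric works because a running average of 0/1 outcomes is exactly successes over attempts:

```python
# Running mean of Bernoulli outcomes equals successes / attempts, so logging
# 0 on failure and 1 on success with mode="avg" reads directly as a success rate.
outcomes = [1, 1, 0, 1]  # hypothetical task results: 1 = success, 0 = failure
rate = sum(outcomes) / len(outcomes)
assert rate == 0.75
```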