Closed
207 commits
5a95cab
apo initial
wensun Jun 18, 2025
d89cd64
.
wensun Jun 18, 2025
1d925d2
vllm fix
jdchang1 Jun 18, 2025
e9dbad9
critic free
jdchang1 Jun 18, 2025
934e864
cleanup
jdchang1 Jun 20, 2025
b7cb34e
local run
jdchang1 Jun 20, 2025
28014a9
added timeout in rlvr utils
wensun Jun 21, 2025
ed4e554
added timeout in rlvr utils
wensun Jun 21, 2025
de6400c
.
wensun Jun 21, 2025
6b4c385
fix some comments issues in local yaml
wensun Jun 23, 2025
b2ae469
.
wensun Jun 23, 2025
4b51437
Update compose_rl/algorithms/online/callback.py
jdchang1 Jun 23, 2025
d4ffc59
clean comments
jdchang1 Jun 23, 2025
1fd7e2d
comment cleanup
jdchang1 Jun 23, 2025
c769975
cleanup of callback
jdchang1 Jun 23, 2025
0d275ec
undo rlvr update
jdchang1 Jun 23, 2025
d6c24d0
Merge branch 'main' into wensun/apo
jdchang1 Jun 23, 2025
202b0a2
fix
jdchang1 Jun 23, 2025
fc2f835
Merge branch 'wensun/apo' of https://github.com/databricks/compose-rl…
jdchang1 Jun 23, 2025
d6d81ca
apo offline initial implementaiton
wensun Jun 29, 2025
dd04bff
apo offline initial implementaiton
wensun Jun 29, 2025
01d28c7
updates
wensun Jun 30, 2025
1c62fdb
add vstar entries
wensun Jun 30, 2025
a5fa391
.
wensun Jun 30, 2025
725e3a2
.
wensun Jun 30, 2025
a0b0209
.
wensun Jun 30, 2025
39377ae
.
wensun Jun 30, 2025
ce6d18a
added ope estimator in apo
wensun Jun 30, 2025
4f0456e
.
wensun Jun 30, 2025
06ceb47
.
wensun Jun 30, 2025
ea680a5
add multimodal support for models
jdchang1 Jul 1, 2025
a7e86cc
linting
jdchang1 Jul 1, 2025
08952f7
multimodal handling for gemma3
jdchang1 Jul 1, 2025
bcc273b
lint
jdchang1 Jul 1, 2025
949a8bd
bce loss
wensun Jul 2, 2025
0368483
bce
wensun Jul 2, 2025
634e151
fix multimodal preference loading
jdchang1 Jul 3, 2025
d35ed90
fix collator
jdchang1 Jul 3, 2025
8bea0e3
debug
jdchang1 Jul 3, 2025
b305fde
debug
jdchang1 Jul 3, 2025
b631ccc
debug
jdchang1 Jul 3, 2025
fabb4a8
change pixel values from being bytes to ndarray or pil
jdchang1 Jul 3, 2025
e2880b6
.
wensun Jul 4, 2025
b040daa
.
wensun Jul 4, 2025
62769ed
.
wensun Jul 4, 2025
649a113
.
wensun Jul 4, 2025
d49d2a2
.
wensun Jul 4, 2025
e7be405
.
wensun Jul 4, 2025
9a2cc92
.
wensun Jul 5, 2025
ece883a
.
wensun Jul 5, 2025
0701099
merge main
jdchang1 Jul 7, 2025
4e469e2
added single offline data track
jdchang1 Jul 7, 2025
3059bc4
cleanup
jdchang1 Jul 7, 2025
a3a2e71
bce fix
jdchang1 Jul 7, 2025
dd578ec
merge conflict
jdchang1 Jul 7, 2025
9ecb7d0
linkt
jdchang1 Jul 7, 2025
829b732
update
jdchang1 Jul 8, 2025
d687b02
update
jdchang1 Jul 8, 2025
22ae655
bug fix for attention mask
jdchang1 Jul 8, 2025
a112dc3
sequence len typo fix
jdchang1 Jul 8, 2025
d6bd975
update
jdchang1 Jul 8, 2025
b38736a
update
jdchang1 Jul 8, 2025
15aabfd
fix
jdchang1 Jul 8, 2025
6672e5d
shape fix
jdchang1 Jul 8, 2025
3d57b0a
fix
jdchang1 Jul 8, 2025
7b04254
fix
jdchang1 Jul 8, 2025
11df0ad
fix
jdchang1 Jul 8, 2025
36d0120
pre-commit
jdchang1 Jul 8, 2025
3a33026
precommit isort
jdchang1 Jul 8, 2025
b06871b
Merge branch 'main' into wensun/offline_apo
jdchang1 Jul 8, 2025
d44610e
support ndarray typing
jdchang1 Jul 8, 2025
aed3e3a
Merge branch 'wensun/offline_apo' of https://github.com/databricks/co…
jdchang1 Jul 8, 2025
e0e015a
support ndarray typing
jdchang1 Jul 8, 2025
2c1c0d4
PIL image support
jdchang1 Jul 8, 2025
15ca2c6
numpy support bug fix
jdchang1 Jul 8, 2025
375d257
pixel values into lists
jdchang1 Jul 8, 2025
ebf43fa
logging fix
jdchang1 Jul 8, 2025
daae699
fix
jdchang1 Jul 8, 2025
82e2f1d
change back to tensor
jdchang1 Jul 8, 2025
5a1891d
Update offline_data.py
jdchang1 Jul 8, 2025
6af3c99
Update offline_data.py
jdchang1 Jul 8, 2025
dbeec36
nd array for pixel_values
jdchang1 Jul 8, 2025
e521e72
fix
jdchang1 Jul 8, 2025
83d1e3b
fix
jdchang1 Jul 8, 2025
b85e4f4
Update compose_rl/algorithms/offline/model_methods.py
jdchang1 Jul 8, 2025
26c0ab8
Update compose_rl/data/offline_data.py
jdchang1 Jul 8, 2025
7650f55
Merge branch 'main' into jchang/multimodal
jdchang1 Jul 9, 2025
99165b0
remove vstar from preference dat
jdchang1 Jul 9, 2025
7cd408e
merge multimodal into
jdchang1 Jul 9, 2025
b68ca4b
add pixel values to forward pass
jdchang1 Jul 9, 2025
ba9e453
offline single stream multimodal support
jdchang1 Jul 9, 2025
06d6a2f
temperature scaling
jdchang1 Jul 9, 2025
c2da2f6
add processor to Dataset to ensure proper HF checkpointing
jdchang1 Jul 9, 2025
952a976
quick test for shape
wensun Jul 10, 2025
4f5fb8b
convert it back
wensun Jul 10, 2025
ddd0ad7
add another metric to track batch advantage
wensun Jul 11, 2025
7eb0951
.
wensun Jul 11, 2025
0a3b8cd
.
wensun Jul 11, 2025
5d0c53f
quick fix
jdchang1 Jul 11, 2025
9b4b9b9
Merge branch 'wensun/offline_apo' of https://github.com/databricks/co…
jdchang1 Jul 11, 2025
144d80e
added bce
wensun Jul 12, 2025
ba2d9e8
temporally just set bce to be true
wensun Jul 13, 2025
b94b594
.
wensun Jul 13, 2025
94359e6
.
wensun Jul 13, 2025
844470d
computation
jdchang1 Jul 13, 2025
8e25f81
model compatibility
jdchang1 Jul 13, 2025
171b0f1
multistep computation
jdchang1 Jul 13, 2025
26ec9a4
multimodal fix
jdchang1 Jul 14, 2025
331ae42
batch advantage computation fix
jdchang1 Jul 14, 2025
2997533
qrpo
jdchang1 Jul 14, 2025
aaef5f8
Update model_methods.py
jdchang1 Jul 14, 2025
204df50
fix
jdchang1 Jul 14, 2025
bc61d5b
Merge branch 'wensun/offline_apo' of https://github.com/databricks/co…
jdchang1 Jul 14, 2025
2e50aff
fix
jdchang1 Jul 14, 2025
fdf8ec2
remove the need for preprocessed image inputs
jdchang1 Jul 17, 2025
43b4324
add token type ids back in
jdchang1 Jul 17, 2025
da38dee
initial set up for dealing with multi turn dataformat
wensun Jul 20, 2025
51912d3
.
wensun Jul 20, 2025
a48305b
.
wensun Jul 20, 2025
1e9b855
.
wensun Jul 20, 2025
0ac2d29
.
wensun Jul 20, 2025
00304a0
.
wensun Jul 21, 2025
2078427
.
wensun Jul 21, 2025
1019e74
.
wensun Jul 21, 2025
254ca00
.
wensun Jul 21, 2025
a449286
.
wensun Jul 21, 2025
6637844
.
wensun Jul 21, 2025
dddc4b3
working version of optimizing tool call in traj level
wensun Jul 21, 2025
9e4395d
exclude gold
jdchang1 Jul 21, 2025
c6d9c64
fixed a seq len bug
wensun Jul 22, 2025
3f08ed1
.
wensun Jul 22, 2025
7292533
.
wensun Jul 22, 2025
49bcae8
testing new collator
wensun Jul 22, 2025
3942181
testing new collator
wensun Jul 22, 2025
0842b91
changed it back to the original collator, test for the new one looks …
wensun Jul 22, 2025
74a8842
.
wensun Jul 22, 2025
76df519
.
wensun Jul 22, 2025
076994f
.
wensun Jul 23, 2025
3e9c9b0
.
wensun Jul 23, 2025
9656354
.
wensun Jul 23, 2025
5cce120
.
wensun Jul 23, 2025
0812e79
tested adding bonus
wensun Jul 23, 2025
5b803e4
.
wensun Jul 23, 2025
9739fc2
.
wensun Jul 23, 2025
a099f47
.
wensun Jul 23, 2025
6e7befe
.
wensun Jul 23, 2025
096bb76
.
wensun Jul 23, 2025
32931f5
.
wensun Jul 23, 2025
f28bea7
.
wensun Jul 23, 2025
7561a25
.
wensun Jul 23, 2025
f593881
.
wensun Jul 23, 2025
19531c6
.
wensun Jul 23, 2025
091a3ce
.
wensun Jul 23, 2025
0dfccdd
.
wensun Jul 23, 2025
e8d04b4
.
wensun Jul 23, 2025
8f5063c
.
wensun Jul 23, 2025
917723c
.
wensun Jul 23, 2025
bf40318
.
wensun Jul 23, 2025
f40974b
rgb
jdchang1 Jul 24, 2025
64ca7a7
tracking sequence entropies in offline rl forward
abaheti95 Jul 27, 2025
5e9c33b
excluding gold
jdchang1 Jul 28, 2025
9268ede
merging branch
jdchang1 Jul 28, 2025
21695dd
.
wensun Aug 20, 2025
b67fbe6
add sequence id
jdchang1 Aug 26, 2025
de597f4
pad token id fix
wensun Aug 28, 2025
e9e3378
Merge branch 'wensun/offline_apo' of github.com:databricks/compose-rl…
wensun Aug 28, 2025
517d2b9
add reference model loading
wensun Sep 3, 2025
4093fbb
.
wensun Sep 3, 2025
f246ea9
.
wensun Sep 3, 2025
5d14693
.
wensun Sep 3, 2025
306dbd9
first version of a unified dataloader
wensun Sep 5, 2025
d0d7c95
dataloader
wensun Sep 5, 2025
d522ea6
tokenizer
wensun Sep 5, 2025
fca744d
.
wensun Sep 5, 2025
e9470b4
test tool loading and message loading
wensun Sep 6, 2025
0511355
.
wensun Sep 6, 2025
23576c4
.
wensun Sep 6, 2025
acda258
.
wensun Sep 6, 2025
ffce2f2
.
wensun Sep 6, 2025
2bfe9c4
.
wensun Sep 6, 2025
09c3ae4
.
wensun Sep 6, 2025
230f5f4
.
wensun Sep 6, 2025
dddbdd4
.
wensun Sep 6, 2025
969d182
.
wensun Sep 6, 2025
2f96310
new data formt
wensun Sep 6, 2025
d6f31a7
test tool
wensun Sep 6, 2025
6cdb619
test tool
wensun Sep 6, 2025
2382630
test tool
wensun Sep 6, 2025
42cf662
message working
wensun Sep 6, 2025
2642698
implemented flatten messages
wensun Sep 6, 2025
efe8965
test non flatten message
wensun Sep 6, 2025
46a74ce
test non flatten message
wensun Sep 6, 2025
0819de1
first version of value function integration using flatten_message
wensun Sep 9, 2025
8cbbc40
testing value function integration
wensun Sep 9, 2025
11efde5
.
wensun Sep 9, 2025
2bc485c
.
wensun Sep 9, 2025
003ec73
.
wensun Sep 9, 2025
91914db
.
wensun Sep 9, 2025
f5e99c3
testing apo critic with reward and vstar
wensun Sep 9, 2025
f176ec5
repupose the apo critic code for the standard apo code just for testing
wensun Sep 9, 2025
485abf8
.
wensun Sep 9, 2025
983fc88
.
wensun Sep 9, 2025
bf1e617
.
wensun Sep 9, 2025
8172e57
.
wensun Sep 10, 2025
6566e0e
use reward at the last step
wensun Sep 10, 2025
03af246
use vstar at the beginning
wensun Sep 11, 2025
0379786
Add value learning into offline RL (#151)
Owen-Oertell Sep 11, 2025
10 changes: 9 additions & 1 deletion compose_rl/algorithms/offline/__init__.py
@@ -1,14 +1,22 @@
 # Copyright 2024 MosaicML ComposeRL authors
 # SPDX-License-Identifier: Apache-2.0
 
-from compose_rl.algorithms.offline.callback import ReferencePolicyCallback
+from compose_rl.algorithms.offline.callback import (
+    PairwiseReferencePolicyCallback,
+    ReferencePolicyCallback,
+)
 from compose_rl.algorithms.offline.model import (
+    ComposerHFOfflinePolicyLM,
     ComposerHFPairwiseOfflinePolicyLM,
+    ComposerMPTOfflinePolicyLM,
     ComposerMPTPairwiseOfflinePolicyLM,
 )
 
 __all__ = [
+    'ComposerHFOfflinePolicyLM',
+    'ComposerMPTOfflinePolicyLM',
     'ComposerMPTPairwiseOfflinePolicyLM',
     'ComposerHFPairwiseOfflinePolicyLM',
+    'PairwiseReferencePolicyCallback',
     'ReferencePolicyCallback',
 ]
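The expanded `__all__` above gates what `from compose_rl.algorithms.offline import *` re-exports. A minimal stdlib sketch of that mechanism (the `offline_demo` module and its classes are illustrative stand-ins, not Compose RL code):

```python
import sys
import types

# Build a tiny stand-in module to show how __all__ gates star-imports,
# mirroring the expanded export list in the __init__.py diff above.
mod = types.ModuleType("offline_demo")
exec(
    "class ComposerHFOfflinePolicyLM: pass\n"
    "class _PrivateHelper: pass\n"
    "__all__ = ['ComposerHFOfflinePolicyLM']\n",
    mod.__dict__,
)
sys.modules["offline_demo"] = mod

ns = {}
exec("from offline_demo import *", ns)

# Only names listed in __all__ are re-exported; _PrivateHelper stays hidden.
print(sorted(k for k in ns if not k.startswith("__")))  # ['ComposerHFOfflinePolicyLM']
```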
66 changes: 65 additions & 1 deletion compose_rl/algorithms/offline/callback.py
@@ -31,14 +31,20 @@ def __init__(
     ):
         self.train_config = copy.deepcopy(train_config)
         self.reference_model = None
+        self.auxiliary_model = None  # Add auxiliary model
 
     def after_load(self, state: State, logger: Logger) -> None:
-        model_config = self.train_config['model']
+        #model_config = self.train_config['model']
+        model_config = self.train_config['variables']['reference_model']
         init_context = process_init_device(
             model_config,
             self.train_config.get('fsdp_config'),
         )
         name = model_config.pop('name')
+        print("################################################")
+        print("reference model config:")
+        print(model_config)
+        print("################################################")
         self.reference_model = build_composer_model(
             name=name,
             cfg=model_config,
@@ -75,6 +81,64 @@ def after_load(self, state: State, logger: Logger) -> None:
             callbacks=load_checkpoint_callbacks,
         )
 
+        # Load auxiliary model following the same pattern
+        if 'auxiliary_model' in self.train_config.get('variables', {}):
+            aux_model_config = self.train_config['variables']['auxiliary_model']
+            aux_init_context = process_init_device(
+                aux_model_config,
+                self.train_config.get('fsdp_config'),
+            )
+            aux_name = aux_model_config.pop('name')
+            print("################################################")
+            print("auxiliary model config:")
+            print(aux_model_config)
+            print("################################################")
+            self.auxiliary_model = build_composer_model(
+                name=aux_name,
+                cfg=aux_model_config,
+                tokenizer=state.model.tokenizer,  # type: ignore
+                init_context=aux_init_context,
+                master_weights_dtype=aux_model_config.get('master_weights_dtype', None),
+            )
+
+            # Load auxiliary model with same checkpoint loading procedure
+            _ = Trainer(
+                model=self.auxiliary_model,
+                parallelism_config={'fsdp': state.fsdp_config},
+                precision=state.precision,
+                load_weights_only=True,
+                load_strict_model_weights=False,
+                load_path=original_load_path,
+                callbacks=load_checkpoint_callbacks,
+            )
+
+    def before_forward(self, state: State, logger: Logger) -> Optional[int]:
+        # Before every batch we need to do a forwards pass over the reference model
+        with get_precision_context(state.precision):
+            with torch.no_grad():
+                assert self.reference_model is not None
+                reference_outputs = self.reference_model(state.batch)
+                state.batch.update({
+                    'ref_logp': reference_outputs['policy_logp'],
+                    'ref_token_policy_logps': reference_outputs['token_policy_logps'],
+                })
+
+                # Add auxiliary model forward pass if available
+                if self.auxiliary_model is not None:
+                    auxiliary_outputs = self.auxiliary_model(state.batch)
+                    state.batch.update({
+                        'aux_first_num_bins_logits': auxiliary_outputs['first_num_bins_logits'],
+                    })
+
+
+class PairwiseReferencePolicyCallback(ReferencePolicyCallback):
+    """Callback to run reference policy in pairwise offline RL.
+
+    Args:
+        train_config (dict): Training config passed to callback via foundry train.py as
+            callback is registered under callbacks_with_config registry.
+    """
+
     def before_forward(self, state: State, logger: Logger) -> Optional[int]:
         # Before every batch we need to do a forwards pass over the reference model
         with get_precision_context(state.precision):
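The `before_forward` hook in `callback.py` runs the frozen reference (and optional auxiliary) model under `no_grad` and stashes the outputs into the batch, so the policy's loss can read `ref_logp` without re-running the reference. A dependency-free sketch of that pattern (the toy `reference_model` and the `token_probs` field are illustrative stand-ins, not Compose RL APIs):

```python
import math

def reference_model(batch):
    # Stand-in for the frozen reference model built in after_load: turn
    # per-token probabilities into per-token and summed sequence log-probs.
    token_logps = [math.log(p) for p in batch["token_probs"]]
    return {"policy_logp": sum(token_logps),
            "token_policy_logps": token_logps}

def before_forward(batch):
    """Mimic ReferencePolicyCallback.before_forward: enrich the batch in
    place so the downstream loss can consume the reference outputs."""
    ref_out = reference_model(batch)
    batch.update({
        "ref_logp": ref_out["policy_logp"],
        "ref_token_policy_logps": ref_out["token_policy_logps"],
    })
    return batch

batch = {"token_probs": [0.5, 0.25]}
before_forward(batch)
print(round(batch["ref_logp"], 4))  # log(0.5) + log(0.25) = log(0.125) -> -2.0794
```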
130 changes: 129 additions & 1 deletion compose_rl/algorithms/offline/model.py
@@ -1,7 +1,7 @@
 # Copyright 2024 MosaicML ComposeRL authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Pairwise Offline RL Composer Implementation."""
+"""Offline RL Composer Implementation."""
 
 from __future__ import annotations
 
@@ -12,9 +12,19 @@
 from llmfoundry.models import ComposerHFCausalLM, ComposerMPTCausalLM
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from transformers.modeling_outputs import CausalLMOutputWithPast
+from compose_rl.metrics.offline_learning_metrics import (
+    TestEstimatedRewardLossMetric,
+    TestImplicitRewardsLossMetric,
+    TestKLDivergenceLossMetric,
+    TestSequenceEntropiesLossMetric,
+    TestTotalLossMetric,
+)
 
 from compose_rl.algorithms.offline.model_methods import (
+    RegressionOfflineEnum,
     PairwiseOfflineEnum,
+    offline_forward,
+    offline_loss,
     pairwise_offline_forward,
     pairwise_offline_loss,
 )
@@ -24,6 +34,124 @@
 log = logging.getLogger(__name__)
 
 
+class ComposerMPTOfflinePolicyLM(ComposerMPTCausalLM):
+    """MPT model wrapper for offline rl model."""
+
+    def __init__(
+        self,
+        loss_type: str = 'apo',
+        beta1: float = 0.5,
+        beta2: float = 0.1,
+        eta: float = 0.5,
+        multistep: bool = False,
+        average_log_prob: bool = False,
+        temperature: float = 1.0,
+        **kwargs: Any,
+    ):
+        self.loss_type = RegressionOfflineEnum(loss_type)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eta = eta
+        self.multistep = multistep
+        self.average_log_prob = average_log_prob
+        self.temperature = temperature
+
+        super().__init__(**kwargs)
+        self.train_metrics = None  # DPOLM does not support eval_forward
+
+    def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
+        assert self.tokenizer is not None
+        return offline_forward(
+            model=self.model,
+            batch=batch,
+            average_log_prob=self.average_log_prob,
+            policy_model_config=self.config,
+        )
+
+    def eval_forward(
+        self,
+        batch: MutableMapping,
+        outputs: CausalLMOutputWithPast,
+    ) -> None:
+        raise ValueError('Eval forward is not implemented for ComposerDPOLM.')
+
+    def loss(self, outputs: CausalLMOutputWithPast,
+             batch: Mapping) -> dict[str, torch.Tensor]:
+        return offline_loss(
+            outputs=outputs,
+            batch=batch,
+            loss_type=self.loss_type,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            eta=self.eta,
+            multistep=self.multistep,
+        )
+
+
+class ComposerHFOfflinePolicyLM(ComposerHFCausalLM):
+    """HF class wrapper for offline rl model."""
+
+    def __init__(
+        self,
+        loss_type: str = 'apo',
+        beta1: float = 0.5,
+        beta2: float = 0.1,
+        eta: float = 0.5,
+        multistep: bool = False,
+        average_log_prob: bool = False,
+        temperature: float = 1.0,
+        num_bins: int = 1,  # Add num_bins parameter
+        distributional_value_learning: bool = True,
+        **kwargs: Any,
+    ):
+        self.loss_type = RegressionOfflineEnum(loss_type)
+        self.beta1 = beta1
+        self.beta2 = beta2
+        self.eta = eta
+        self.multistep = multistep
+        self.average_log_prob = average_log_prob
+        self.temperature = temperature
+        self.num_bins = num_bins  # Store num_bins
+        self.distributional_value_learning = distributional_value_learning
+
+        super().__init__(**kwargs)
+        self.train_metrics = None  # DPOLM does not support eval_forward
+        self.val_metrics = {
+            metric.__class__.__name__: metric for metric in [
+                TestEstimatedRewardLossMetric(),
+                TestImplicitRewardsLossMetric(),
+                TestKLDivergenceLossMetric(),
+                TestSequenceEntropiesLossMetric(),
+                TestTotalLossMetric(),
+            ]
+        }
+
+    def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
+        assert self.tokenizer is not None
+        return offline_forward(
+            model=self.model,
+            batch=batch,
+            average_log_prob=self.average_log_prob,
+            temperature=self.temperature,
+            num_bins=self.num_bins,  # Pass num_bins to offline_forward
+        )
+
+    def eval_forward(
+        self,
+        batch: MutableMapping,
+        outputs: CausalLMOutputWithPast | None = None,
+    ) -> dict[str, torch.Tensor]:
+        with torch.no_grad():
+            fwd = self.forward(batch)
+            loss = self.loss(fwd, batch)
+
+        return loss
+
+    def loss(self, outputs: CausalLMOutputWithPast,
+             batch: Mapping) -> dict[str, torch.Tensor]:
+        return offline_loss(
+            outputs=outputs,
+            batch=batch,
+            loss_type=self.loss_type,
+            beta1=self.beta1,
+            beta2=self.beta2,
+            eta=self.eta,
+            multistep=self.multistep,
+            distributional_value_learning=self.distributional_value_learning,
+        )
+
+
 class ComposerMPTPairwiseOfflinePolicyLM(ComposerMPTCausalLM):
     """MPT model wrapper for DPO model."""
 
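The actual `offline_loss` lives in `model_methods.py` and is not part of this diff, so the following is only a hypothetical illustration of the quantities the new metric names suggest (implicit rewards, KL divergence) and how `ref_logp` from the callback would feed them; the function names and the squared-error form are assumptions, not Compose RL's implementation:

```python
def implicit_reward(policy_logp: float, ref_logp: float, beta: float) -> float:
    # DPO-style implicit reward: a scaled log-ratio of policy vs. reference.
    return beta * (policy_logp - ref_logp)

def k1_kl_estimate(token_logps: list, ref_token_logps: list) -> float:
    # Per-sequence KL(policy || ref) via the naive k1 estimator:
    # the mean of per-token log-ratios.
    diffs = [p - r for p, r in zip(token_logps, ref_token_logps)]
    return sum(diffs) / len(diffs)

def regression_loss(policy_logp: float, ref_logp: float,
                    reward: float, vstar: float, beta1: float) -> float:
    # Hypothetical regression-style objective: pull the implicit reward
    # toward the advantage (reward - vstar), echoing the commit trail
    # ("add vstar entries", "added bce"). The squared error is an assumption.
    return (implicit_reward(policy_logp, ref_logp, beta1) - (reward - vstar)) ** 2

print(implicit_reward(-1.0, -2.0, 0.5))  # 0.5 * 1.0 = 0.5
print(regression_loss(-1.0, -2.0, reward=1.0, vstar=0.75, beta1=0.5))  # (0.5 - 0.25)**2 = 0.0625
```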