# Patch summary: review notes added ahead of the first "diff --git" header.
# The patch text that follows these notes is unchanged.
#
# NOTE(review): the line structure of this patch appears to have been
# collapsed (diff lines are separated by single spaces instead of newlines),
# so it presumably will not apply as-is -- TODO restore from the original
# commit before use.
#
# cmonge/models/nn.py
#   * ConditionalPerturbationNetwork gains three fields:
#     attention_pooling (default False), num_heads (default 4) and
#     dropout_rate (default 0.1).
#   * __call__ gains `deterministic: bool = True`; the value is forwarded to
#     each nn.Dropout introduced by this patch.
#   * When attention_pooling is set, the per-context embeddings are stacked
#     on axis 1, passed through dropout, scored by a Dense(num_heads) layer
#     named "AttentionScores", softmax-ed over axis 1, passed through dropout
#     again, pooled with einsum "bnd,bnh->bhd", flattened to (B, H*D),
#     projected to dim_cond_map[0] by a Dense layer named "AttentionOutput",
#     and passed through a final dropout. Otherwise the previous mean over
#     the stacked embeddings (axis 0) is kept.
#   * create_train_state splits `rng` and passes separate "params" and
#     "dropout" keys to self.init.
#
# cmonge/tests/models/__init__.py
#   * New empty file.
#
# cmonge/tests/models/test_attention_pooling.py
#   * New test module with three tests, each parametrized over three
#     context-bond configurations: output shape / non-zero output with
#     attention pooling, equal output shape for mean vs. attention pooling,
#     and deterministic=True reproducibility vs. deterministic=False
#     variation across different dropout keys (dropout_rate=0.5).
#   * NOTE(review): in test_attention_pooling_forward_pass and
#     test_dropout_deterministic_vs_stochastic the same `rng` is handed to
#     _make_inputs (which calls jax.random.split(rng)) and then split again
#     with jax.random.split(rng), so the params/dropout keys equal the keys
#     used to draw x and c. Consider splitting into three independent keys.
#
# cmonge/trainers/conditional_monge_trainer.py
#   * train() splits self.key before calling step_fn and passes the new key
#     as `dropout_key`.
#   * loss_fn gains `dropout_key: Optional[jnp.ndarray] = None`; only when it
#     is not None are `deterministic=False` and `rngs={"dropout": dropout_key}`
#     forwarded to apply_fn.
#   * step_fn gains a `dropout_key` parameter directly after `train_batch`;
#     static_argnums moves from [4, 5, 6, 7] to [5, 6, 7, 8] to account for
#     the inserted parameter, and the key is forwarded to the training-loss
#     call.
#   * NOTE(review): presumably every network whose apply_fn is used with this
#     trainer accepts a `deterministic` keyword -- confirm, since loss_fn now
#     forwards it whenever a key is supplied.
#
# pyproject.toml
#   * Version bumped from 0.1.2 to 0.1.3.
#   * NOTE(review): the `authors` entries in the context line end with a
#     trailing space inside the quotes; this looks like e-mail addresses were
#     stripped during extraction -- confirm against the repository.
#
diff --git a/cmonge/models/nn.py b/cmonge/models/nn.py index aba771c..60460a3 100644 --- a/cmonge/models/nn.py +++ b/cmonge/models/nn.py @@ -314,6 +314,9 @@ class ConditionalPerturbationNetwork(BasePotential): embed_cond_equal: bool = ( False # Whether all context variables should be treated as set or not ) + attention_pooling: bool = False + num_heads: int = 4 + dropout_rate: float = 0.1 context_entity_bonds: Iterable[Tuple[int, int]] = ( (0, 10), (0, 11), @@ -321,7 +324,11 @@ class ConditionalPerturbationNetwork(BasePotential): @nn.compact def __call__( - self, x: jnp.ndarray, c: jnp.ndarray, num_contexts: int = 2 + self, + x: jnp.ndarray, + c: jnp.ndarray, + num_contexts: int = 2, + deterministic: bool = True, ) -> jnp.ndarray: # noqa: D102 """ Args: @@ -379,8 +386,45 @@ def __call__( ) layer = nn.Dense(dim_cond_map[0], use_bias=True) embeddings = [self.act_fn(layer(context)) for context in contexts] - # Average along stacked dimension (alternatives like summing are possible) - cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0) + + if self.attention_pooling: + stacked_embeddings = jnp.stack(embeddings, axis=1) # (Batch, N, Dim) + + # Input Dropout + stacked_embeddings = nn.Dropout( + rate=self.dropout_rate, deterministic=deterministic + )(stacked_embeddings) + + # Multi-Head Attention Scores + att_layer = nn.Dense( + self.num_heads, use_bias=True, name="AttentionScores" + ) + scores = att_layer(stacked_embeddings) # (Batch, N, Heads) + weights = jax.nn.softmax(scores, axis=1) + + # Attention Weights Dropout + weights = nn.Dropout( + rate=self.dropout_rate, deterministic=deterministic + )(weights) + + # Weighted Pooling: (B, N, D), (B, N, H) -> (B, H, D) + weighted_sum = jnp.einsum("bnd,bnh->bhd", stacked_embeddings, weights) + + # Flatten and Project + cond_embedding = weighted_sum.reshape( + weighted_sum.shape[0], -1 + ) # (B, H*D) + cond_embedding = nn.Dense( + dim_cond_map[0], use_bias=True, name="AttentionOutput" + )(cond_embedding) + + # Output 
Dropout + cond_embedding = nn.Dropout( + rate=self.dropout_rate, deterministic=deterministic + )(cond_embedding) + else: + # Average along stacked dimension (alternatives like summing are possible) + cond_embedding = jnp.mean(jnp.stack(embeddings), axis=0) z = jnp.concatenate((x, cond_embedding), axis=1) if self.layer_norm: @@ -403,7 +447,12 @@ def create_train_state( """Create initial `TrainState`.""" c = jnp.ones((1, self.dim_cond)) # (n_batch, embed_dim) x = jnp.ones((1, self.dim_data)) # (n_batch, data_dim) - params = self.init(rng, x=x, c=c)["params"] + + # Split rng for dropout keys during init + rng, rng_dropout = jax.random.split(rng) + init_rngs = {"params": rng, "dropout": rng_dropout} + + params = self.init(init_rngs, x=x, c=c)["params"] return PotentialTrainState.create( apply_fn=self.apply, params=params, diff --git a/cmonge/tests/models/__init__.py b/cmonge/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cmonge/tests/models/test_attention_pooling.py b/cmonge/tests/models/test_attention_pooling.py new file mode 100644 index 0000000..b5723b4 --- /dev/null +++ b/cmonge/tests/models/test_attention_pooling.py @@ -0,0 +1,148 @@ +import jax +import jax.numpy as jnp +import pytest + +from cmonge.models.nn import ConditionalPerturbationNetwork + +# (context_bonds, dim_cond, num_contexts) +CONTEXT_BOND_CONFIGS = [ + pytest.param( + ((0, 10), (10, 20)), + 20, + 2, + id="non_overlapping_2_modalities", + ), + pytest.param( + ((0, 10), (0, 10)), + 10, + 2, + id="overlapping_2_modalities", + ), + pytest.param( + ((0, 10), (10, 20), (20, 30)), + 30, + 3, + id="non_overlapping_3_modalities", + ), +] + +DIM_DATA = 16 +DIM_HIDDEN = [32, 32] +DIM_COND_MAP = (8,) +BATCH_SIZE = 4 + + +def _make_model(context_bonds, attention_pooling, dropout_rate=0.1): + return ConditionalPerturbationNetwork( + dim_hidden=DIM_HIDDEN, + dim_data=DIM_DATA, + dim_cond=max(stop for _, stop in context_bonds), + dim_cond_map=DIM_COND_MAP, + 
embed_cond_equal=True, + attention_pooling=attention_pooling, + num_heads=4, + dropout_rate=dropout_rate, + context_entity_bonds=context_bonds, + ) + + +def _make_inputs(rng, dim_cond): + rng_x, rng_c = jax.random.split(rng) + x = jax.random.normal(rng_x, (BATCH_SIZE, DIM_DATA)) + c = jax.random.normal(rng_c, (BATCH_SIZE, dim_cond)) + return x, c + + +class TestAttentionPooling: + """Tests for attention pooling in ConditionalPerturbationNetwork.""" + + @pytest.mark.parametrize( + "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS + ) + def test_attention_pooling_forward_pass( + self, context_bonds, dim_cond, num_contexts + ): + """Test that attention pooling produces correct output shape.""" + model = _make_model(context_bonds, attention_pooling=True) + rng = jax.random.PRNGKey(0) + x, c = _make_inputs(rng, dim_cond) + + rng_params, rng_dropout = jax.random.split(rng) + params = model.init({"params": rng_params, "dropout": rng_dropout}, x=x, c=c)[ + "params" + ] + + out = model.apply({"params": params}, x, c, num_contexts) + assert out.shape == (BATCH_SIZE, DIM_DATA) + assert not jnp.allclose(out, 0.0) + + @pytest.mark.parametrize( + "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS + ) + def test_both_pooling_modes_same_output_shape( + self, context_bonds, dim_cond, num_contexts + ): + """Test that mean and attention pooling produce the same output shape.""" + rng = jax.random.PRNGKey(42) + x, c = _make_inputs(rng, dim_cond) + + model_mean = _make_model(context_bonds, attention_pooling=False) + rng_p1, rng_d1, rng_p2, rng_d2 = jax.random.split(rng, 4) + params_mean = model_mean.init({"params": rng_p1, "dropout": rng_d1}, x=x, c=c)[ + "params" + ] + out_mean = model_mean.apply({"params": params_mean}, x, c, num_contexts) + + model_attn = _make_model(context_bonds, attention_pooling=True) + params_attn = model_attn.init({"params": rng_p2, "dropout": rng_d2}, x=x, c=c)[ + "params" + ] + out_attn = model_attn.apply({"params": params_attn}, x, c, 
num_contexts) + + assert out_mean.shape == out_attn.shape == (BATCH_SIZE, DIM_DATA) + + @pytest.mark.parametrize( + "context_bonds,dim_cond,num_contexts", CONTEXT_BOND_CONFIGS + ) + def test_dropout_deterministic_vs_stochastic( + self, context_bonds, dim_cond, num_contexts + ): + """Test that deterministic=False produces different outputs across runs + while deterministic=True is consistent.""" + model = _make_model(context_bonds, attention_pooling=True, dropout_rate=0.5) + rng = jax.random.PRNGKey(7) + x, c = _make_inputs(rng, dim_cond) + + rng_params, rng_dropout = jax.random.split(rng) + params = model.init({"params": rng_params, "dropout": rng_dropout}, x=x, c=c)[ + "params" + ] + + # Deterministic mode: two calls should be identical + out_eval_1 = model.apply( + {"params": params}, x, c, num_contexts, deterministic=True + ) + out_eval_2 = model.apply( + {"params": params}, x, c, num_contexts, deterministic=True + ) + assert jnp.allclose(out_eval_1, out_eval_2) + + # Stochastic mode: two calls with different dropout keys should differ + key1, key2 = jax.random.split(jax.random.PRNGKey(99)) + out_train_1 = model.apply( + {"params": params}, + x, + c, + num_contexts, + deterministic=False, + rngs={"dropout": key1}, + ) + out_train_2 = model.apply( + {"params": params}, + x, + c, + num_contexts, + deterministic=False, + rngs={"dropout": key2}, + ) + assert not jnp.allclose(out_train_1, out_train_2) diff --git a/cmonge/trainers/conditional_monge_trainer.py b/cmonge/trainers/conditional_monge_trainer.py index 272158a..c4a3e07 100644 --- a/cmonge/trainers/conditional_monge_trainer.py +++ b/cmonge/trainers/conditional_monge_trainer.py @@ -149,10 +149,13 @@ def train(self, datamodule: ConditionalDataModule): else self.generate_batch(datamodule, "valid") ) + self.key, step_key = jax.random.split(self.key) + self.state_neural_net, grads, current_logs = self.step_fn( self.state_neural_net, grads=grads, train_batch=train_batch, + dropout_key=step_key, 
valid_batch=valid_batch, is_logging_step=is_logging_step, is_gradient_acc_step=is_gradient_acc_step, @@ -176,14 +179,20 @@ def loss_fn( apply_fn: Callable, batch: Dict[str, jnp.ndarray], n_contexts: int, + dropout_key: Optional[jnp.ndarray] = None, ) -> Tuple[float, Dict[str, float]]: """Loss function.""" # map samples with the fitted map + kwargs = {} + if dropout_key is not None: + kwargs = {"deterministic": False, "rngs": {"dropout": dropout_key}} + mapped_samples = apply_fn( {"params": params}, batch["source"], batch["condition"], n_contexts, + **kwargs, ) # compute the loss @@ -200,11 +209,12 @@ def loss_fn( return val_tot_loss, loss_logs - @functools.partial(jax.jit, static_argnums=[4, 5, 6, 7]) + @functools.partial(jax.jit, static_argnums=[5, 6, 7, 8]) def step_fn( state_neural_net: train_state.TrainState, grads: frozen_dict.FrozenDict, train_batch: Dict[str, jnp.ndarray], + dropout_key: jnp.ndarray, valid_batch: Optional[Dict[str, jnp.ndarray]] = None, is_logging_step: bool = False, is_gradient_acc_step: bool = False, @@ -219,6 +229,7 @@ def step_fn( state_neural_net.apply_fn, train_batch, n_train_contexts, + dropout_key, ) # Accumulate gradients grads = tree_map(lambda g, step_g: g + step_g, grads, step_grads) diff --git a/pyproject.toml b/pyproject.toml index 7855115..0ed439c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cmonge" -version = "0.1.2" +version = "0.1.3" description = "Extension of the Monge Gap to learn conditional optimal transport maps" authors = ["Alice Driessen ", "Benedek Harsanyi ", "Jannis Born "] readme = "README.md"