From 76156918b74ad61ba1fe49bd94945d4753fa1a38 Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Thu, 24 Apr 2025 11:11:20 +0200 Subject: [PATCH 1/6] test absolute value and derivative at harmonisation and convergence time --- .../convergence.py | 18 ++--- src/gradient_aware_harmonisation/utils.py | 18 ++--- ...t_harmonise_splines_cosine_weight_decay.py | 74 +++++++++++++++++++ tests/unit/test_get_cosine_decay_spline.py | 4 +- 4 files changed, 94 insertions(+), 20 deletions(-) diff --git a/src/gradient_aware_harmonisation/convergence.py b/src/gradient_aware_harmonisation/convergence.py index 611a50b..3eff7f6 100644 --- a/src/gradient_aware_harmonisation/convergence.py +++ b/src/gradient_aware_harmonisation/convergence.py @@ -272,8 +272,8 @@ def antiderivative(self) -> CosineDecaySplineHelperDerivative: def get_cosine_decay_harmonised_spline( harmonisation_time: Union[int, float], convergence_time: Union[int, float], - harmonised_spline_no_convergence: Spline, - convergence_spline: Spline, + diverge_from: Spline, + harmonisee: Spline, ) -> SumOfSplines: """ Generate the harmonised spline based on a cosine-decay @@ -284,18 +284,18 @@ def get_cosine_decay_harmonised_spline( Harmonisation time This is the time at and before which - the solution should be equal to `harmonised_spline_no_convergence`. + the solution should be equal to `diverge_from`. convergence_time Convergence time This is the time at and after which - the solution should be equal to `convergence_spline`. + the solution should be equal to `harmonisee`. - harmonised_spline_no_convergence + diverge_from Harmonised spline that does not consider convergence - convergence_spline + harmonisee The spline to which the result should converge Returns @@ -308,7 +308,7 @@ def get_cosine_decay_harmonised_spline( # first order derivative). Then we use a decay function to let # the harmonised spline converge to the convergence-spline. # This decay function has the form of a weighted sum: - # weight * harmonised_spline + (1-weight) * convergence_spline + # weight * diverge_from + (1-weight) * harmonisee # With weights decaying from 1 to 0 whereby the decay trajectory # is determined by the cosine decay. return SumOfSplines( @@ -318,7 +318,7 @@ def get_cosine_decay_harmonised_spline( final_time=convergence_time, apply_to_convergence=False, ), - harmonised_spline_no_convergence, + diverge_from, ), ProductOfSplines( CosineDecaySplineHelper( @@ -326,6 +326,6 @@ def get_cosine_decay_harmonised_spline( final_time=convergence_time, apply_to_convergence=True, ), - convergence_spline, + harmonisee, ), ) diff --git a/src/gradient_aware_harmonisation/utils.py b/src/gradient_aware_harmonisation/utils.py index c33eca9..bce6805 100644 --- a/src/gradient_aware_harmonisation/utils.py +++ b/src/gradient_aware_harmonisation/utils.py @@ -33,8 +33,8 @@ def __call__( self, harmonisation_time: Union[int, float], convergence_time: Union[int, float], - harmonised_spline_no_convergence: Spline, - convergence_spline: Spline, + diverge_from: Spline, + harmonisee: Spline, ) -> Spline: """ Generate the harmonised spline @@ -45,18 +45,18 @@ def __call__( Harmonisation time This is the time at and before which - the solution should be equal to `harmonised_spline_no_convergence`. + the solution should be equal to `diverge_from`. convergence_time Convergence time This is the time at and after which - the solution should be equal to `convergence_spline`. + the solution should be equal to `harmonisee`. - harmonised_spline_no_convergence + diverge_from Harmonised spline that does not consider convergence - convergence_spline + harmonisee The spline to which the result should converge Returns @@ -133,15 +133,15 @@ def harmonise_splines( # noqa: PLR0913 harmonised_spline_first_derivative_only(harmonisation_time), ) - harmonised_spline_no_convergence = add_constant_to_spline( + diverge_from = add_constant_to_spline( in_spline=harmonised_spline_first_derivative_only, constant=diff_spline ) harmonised_spline = get_harmonised_spline( harmonisation_time=harmonisation_time, convergence_time=convergence_time, - harmonised_spline_no_convergence=harmonised_spline_no_convergence, - convergence_spline=converge_to, + diverge_from=diverge_from, + harmonisee=converge_to, ) return harmonised_spline diff --git a/tests/integration/test_harmonise_splines_cosine_weight_decay.py b/tests/integration/test_harmonise_splines_cosine_weight_decay.py index e69de29..b61d816 100644 --- a/tests/integration/test_harmonise_splines_cosine_weight_decay.py +++ b/tests/integration/test_harmonise_splines_cosine_weight_decay.py @@ -0,0 +1,74 @@ +import numpy as np +import pytest + +from gradient_aware_harmonisation.convergence import get_cosine_decay_harmonised_spline +from gradient_aware_harmonisation.spline import Spline, SplineScipy + + +def check_expected_continuity( + solution: Spline, + diverge_from: Spline, + harmonisee: Spline, + harmonisation_time: float, + convergence_time: float, +) -> None: + np.testing.assert_allclose( + solution(harmonisation_time), + diverge_from(harmonisation_time), + err_msg=( + "Difference in absolute value of solution and diverge_from " + "at harmonisation_time" + ), + ) + + +@pytest.mark.parametrize( + "harmonisation_time, convergence_time", + ( + pytest.param(0.0, 1.0), + pytest.param(0.0, 1.7), + pytest.param(3.0, 8.0), + pytest.param(-3.0, 0.0), + pytest.param(-3.0, 8.0), + pytest.param(-3.0, -1.0), + pytest.param(3.0, 1.0, id="backwards_harmonisation_positive_times"), + pytest.param( + 3.0, -1.0, id="backwards_harmonisation_positive_and_negative_time" + ), + pytest.param(-30.0, -10.0, id="backwards_harmonisation_negative_times"), + ), +) +def test_harmonisation_convergence_times(harmonisation_time, convergence_time): + """ + Test over a variety of harmonisation and convergence times + """ + scipy = pytest.importorskip("scipy") + + diverge_from = SplineScipy( + scipy.interpolate.PPoly( + c=[[2.75], [1.2]], + x=[-100, 100], + ) + ) + + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[[2.3], [0.5]], + x=[-100, 100], + ) + ) + + res = get_cosine_decay_harmonised_spline( + diverge_from=diverge_from, + convergence_spline=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + check_expected_continuity( + solution=res, + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) diff --git a/tests/unit/test_get_cosine_decay_spline.py b/tests/unit/test_get_cosine_decay_spline.py index 3a70570..089fb20 100644 --- a/tests/unit/test_get_cosine_decay_spline.py +++ b/tests/unit/test_get_cosine_decay_spline.py @@ -54,14 +54,14 @@ def test_get_cosine_decay(harmonisation_time, convergence_time): harmonised_spline_first_derivative_only(harmonisation_time), ) - harmonised_spline_no_convergence = add_constant_to_spline( + diverge_from = add_constant_to_spline( in_spline=harmonised_spline_first_derivative_only, constant=diff_spline ) harmonised_spline_convergence = get_cosine_decay_harmonised_spline( harmonisation_time=harmonisation_time, convergence_time=convergence_time, - harmonised_spline_no_convergence=harmonised_spline_no_convergence, + diverge_from=diverge_from, convergence_spline=harmonisee_spline, ) From e6d8771725aac7d154011eed885516d7a3addfc5 Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Thu, 24 Apr 2025 11:17:31 +0200 Subject: [PATCH 2/6] rename arguments of cosine_weight_decay function to diverge_from and harmonisee --- .../test_convergence_integration.py | 29 +++++++------------ ...t_harmonise_splines_cosine_weight_decay.py | 2 +- tests/unit/test_get_cosine_decay_spline.py | 2 +- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/tests/integration/test_convergence_integration.py b/tests/integration/test_convergence_integration.py index 9aabb09..a41024b 100644 --- a/tests/integration/test_convergence_integration.py +++ b/tests/integration/test_convergence_integration.py @@ -184,13 +184,13 @@ def test_get_cosine_decay_harmonised_spline(): x_up_to_harmonisation_time = np.linspace(x_min, harmonisation_time, 50) x_after_convergence_time = np.linspace(convergence_time, x_max, 50) - harmonised_spline_no_convergence = SplineScipy( + diverge_from = SplineScipy( scipy.interpolate.PPoly( x=[x_min, x_max], c=[[1], [0], [0], [0]], # y=x^3 ) ) - convergence_spline = SplineScipy( + harmonisee = SplineScipy( scipy.interpolate.PPoly( x=[x_min, x_max], c=[[-1], [1], [2]], # y=-x^2 + x + 2 @@ -200,17 +200,17 @@ def test_get_cosine_decay_harmonised_spline(): res = get_cosine_decay_harmonised_spline( harmonisation_time=harmonisation_time, convergence_time=convergence_time, - harmonised_spline_no_convergence=harmonised_spline_no_convergence, - convergence_spline=convergence_spline, + diverge_from=diverge_from, + harmonisee=harmonisee, ) np.testing.assert_equal( - harmonised_spline_no_convergence(x_up_to_harmonisation_time), + diverge_from(x_up_to_harmonisation_time), res(x_up_to_harmonisation_time), ) np.testing.assert_equal( - convergence_spline(x_after_convergence_time), + harmonisee(x_after_convergence_time), res(x_after_convergence_time), ) @@ -218,18 +218,11 @@ def test_get_cosine_decay_harmonised_spline(): np.testing.assert_equal( np.array( [ - 0.5 - * (1.0 + np.cos(np.pi * 0.5 / 6.0)) - * harmonised_spline_no_convergence(3.0) - + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 0.5 / 6.0))) - * convergence_spline(3.0), - 0.5 * harmonised_spline_no_convergence(5.5) - + 0.5 * convergence_spline(5.5), - 0.5 - * (1.0 + np.cos(np.pi * 3.5 / 6.0)) - * harmonised_spline_no_convergence(6.0) - + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 3.5 / 6.0))) - * convergence_spline(6.0), + 0.5 * (1.0 + np.cos(np.pi * 0.5 / 6.0)) * diverge_from(3.0) + + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 0.5 / 6.0))) * harmonisee(3.0), + 0.5 * diverge_from(5.5) + 0.5 * harmonisee(5.5), + 0.5 * (1.0 + np.cos(np.pi * 3.5 / 6.0)) * diverge_from(6.0) + + (1.0 - 0.5 * (1.0 + np.cos(np.pi * 3.5 / 6.0))) * harmonisee(6.0), ] ), res(np.array([3.0, 5.5, 6.0])), diff --git a/tests/integration/test_harmonise_splines_cosine_weight_decay.py b/tests/integration/test_harmonise_splines_cosine_weight_decay.py index b61d816..d7a76ce 100644 --- a/tests/integration/test_harmonise_splines_cosine_weight_decay.py +++ b/tests/integration/test_harmonise_splines_cosine_weight_decay.py @@ -60,7 +60,7 @@ def test_harmonisation_convergence_times(harmonisation_time, convergence_time): res = get_cosine_decay_harmonised_spline( diverge_from=diverge_from, - convergence_spline=harmonisee, + harmonisee=harmonisee, harmonisation_time=harmonisation_time, convergence_time=convergence_time, ) diff --git a/tests/unit/test_get_cosine_decay_spline.py b/tests/unit/test_get_cosine_decay_spline.py index 089fb20..1cc3ca1 100644 --- a/tests/unit/test_get_cosine_decay_spline.py +++ b/tests/unit/test_get_cosine_decay_spline.py @@ -62,7 +62,7 @@ def test_get_cosine_decay(harmonisation_time, convergence_time): harmonisation_time=harmonisation_time, convergence_time=convergence_time, diverge_from=diverge_from, - convergence_spline=harmonisee_spline, + harmonisee=harmonisee_spline, ) np.testing.assert_allclose( From 0d75aa7bd0630ee3aa2111e1ed398ac79caaa506 Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:17:01 +0200 Subject: [PATCH 3/6] add test for harmonisation time > convergence time --- docs/how-to-guides/cosine_decay.py | 274 ++++++++++++++++++ ...t_harmonise_splines_cosine_weight_decay.py | 44 +++ 2 files changed, 318 insertions(+) create mode 100644 docs/how-to-guides/cosine_decay.py diff --git a/docs/how-to-guides/cosine_decay.py b/docs/how-to-guides/cosine_decay.py new file mode 100644 index 0000000..ef82e20 --- /dev/null +++ b/docs/how-to-guides/cosine_decay.py @@ -0,0 +1,274 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.16.6 +# kernelspec: +# display_name: Python 3 (ipykernel) +# language: python +# name: python3 +# --- + +# %% [markdown] +# # How to use a cubic-spline as harmonisation of two functions? +# In this tutorial, we present use cases for applying a cubic-spline +# to harmonise two functions which we will call in the following +# `diverge_from` and `harmonisee`. +# The `cubic-spline` interpolates between `diverge_from` and `harmonisee`. + + +# %% +# import relevant libraries +from __future__ import annotations + +import matplotlib.pyplot as plt +import numpy as np +import scipy.interpolate + +from gradient_aware_harmonisation.add_cubic import ( + harmonise_splines_add_cubic, +) +from gradient_aware_harmonisation.spline import SplineScipy + +# %% [markdown] + +# We start by defining the spline `diverge_from` as a linear +# function with intercept=1.0 and slope=2.5. + +# %% +diverge_from_gradient = 2.5 +diverge_from_y_intercept = 1.0 + +diverge_from = SplineScipy( + scipy.interpolate.PPoly( + c=[[diverge_from_gradient], [diverge_from_y_intercept]], + x=[0, 1e8], + ) +) + +# %% [markdown] +# ## Scenarios +# ### Harmonisation time < convergence time +# In the following, we consider nine scenarios in which the +# `harmonisee` spline differs from the `diverge_from` spline +# due to varying shifts in the intercept ([0.0, -1.2, 1.2]) +# and slope ([1.0, 0.7, 1.4]). +# In all of these scenarios we consider harmonisation time +# (=0) < convergence time (=3.2). + +# %% +harmonisation_time = 0.0 +convergence_time = 3.2 + + +# %% +def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 + ax.plot( + x, + spline(x), + label=label, + ) + + if gradient: + ax.set_title("Gradient") + else: + ax.set_title("Function") + + +# %% +i = 0 +for y_intercept_shift in [0.0, -1.2, 1.2]: + for gradient_factor in [1.0, 0.7, 1.4]: + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[ + [diverge_from_gradient * gradient_factor], + [diverge_from_y_intercept + y_intercept_shift], + ], + x=[0, 1e8], + ) + ) + + res = harmonise_splines_add_cubic( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + fig, axes = plt.subplots(ncols=2, figsize=(12, 4)) + + plot_spline( + diverge_from, np.linspace(-1.0, 3.0, 101), ax=axes[0], label="diverge_from" + ) + plot_spline( + harmonisee, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="harmonisee", + ) + plot_spline( + res, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="res", + ) + + plot_spline( + diverge_from.derivative(), + np.linspace(-1.0, 3.0, 101), + ax=axes[1], + label="diverge_from", + gradient=True, + ) + plot_spline( + harmonisee.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="harmonisee_gradien", + gradient=True, + ) + plot_spline( + res.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="cubic-spline", + gradient=True, + ) + + for ax in axes: + ax.axvline( + harmonisation_time, + label="harmonisation_time", + color="gray", + linestyle=":", + ) + ax.axvline( + convergence_time, label="convergence_time", color="gray", linestyle="--" + ) + for ax in axes[1::2]: + ax.legend(handlelength=1.1, loc="center right", fontsize="small") + + fig.suptitle( + f"Scenario {i+1} (intercept shift: {y_intercept_shift}," + + f" slope factor: {gradient_factor})" + ) + plt.show() + i = i + 1 + +# %% +diverge_from_gradient = 2.5 +diverge_from_y_intercept = 1.0 + +# TODO: from left-edge or something here +diverge_from = SplineScipy( + scipy.interpolate.PPoly( + c=[ + [diverge_from_gradient], + [diverge_from_y_intercept - 10.0 * diverge_from_gradient], + ], + x=[-10.0, 10.0], + ) +) + +# %% [markdown] +# ### Harmonisation time > convergence time +# In the following, we consider the same nine scenarios as +# above in which the `harmonisee` spline differs +# from the `diverge_from` spline due to varying shifts in the +# intercept ([0.0, -1.2, 1.2]) and slope ([1.0, 0.7, 1.4]). +# However, this time we consider in all upcoming scenarios +# harmonisation time (=1.0) > convergence time (=-1.0). + +# %% +harmonisation_time = 1.0 +convergence_time = -1.0 + +# %% +# Backwards along x harmonisation +i = 0 +for y_intercept_shift in [0.0, -1.2, 1.2]: + for gradient_factor in [1.0, 0.7, 1.4]: + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[ + [diverge_from_gradient * gradient_factor], + [ + diverge_from_y_intercept + - 10.0 * diverge_from_gradient + + y_intercept_shift + ], + ], + x=[-10.0, 10.0], + ) + ) + + res = harmonise_splines_add_cubic( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + fig, axes = plt.subplots(ncols=2, figsize=(12, 4)) + + plot_spline( + diverge_from, np.linspace(-1.0, 3.0, 101), ax=axes[0], label="diverge_from" + ) + plot_spline( + harmonisee, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="harmonisee", + ) + plot_spline( + res, + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[0], + label="res", + ) + + plot_spline( + diverge_from.derivative(), + np.linspace(-1.0, 3.0, 101), + ax=axes[1], + label="diverge_from", + gradient=True, + ) + plot_spline( + harmonisee.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="harmonisee", + gradient=True, + ) + plot_spline( + res.derivative(), + np.linspace(harmonisation_time, 2 * convergence_time, 101), + ax=axes[1], + label="cubic-spline", + gradient=True, + ) + + for ax in axes: + ax.axvline( + harmonisation_time, + label="harmonisation_time", + color="gray", + linestyle=":", + ) + ax.axvline( + convergence_time, label="convergence_time", color="gray", linestyle="--" + ) + for ax in axes[1::2]: + ax.legend(handlelength=1.1, loc="center right", fontsize="small") + + fig.suptitle( + f"Scenario {i+1} (intercept shift: {y_intercept_shift}," + + f" slope factor: {gradient_factor})" + ) + plt.show() + i = i + 1 diff --git a/tests/integration/test_harmonise_splines_cosine_weight_decay.py b/tests/integration/test_harmonise_splines_cosine_weight_decay.py index d7a76ce..1e5760a 100644 --- a/tests/integration/test_harmonise_splines_cosine_weight_decay.py +++ b/tests/integration/test_harmonise_splines_cosine_weight_decay.py @@ -72,3 +72,47 @@ def test_harmonisation_convergence_times(harmonisation_time, convergence_time): harmonisation_time=harmonisation_time, convergence_time=convergence_time, ) + + +def test_harmonisation_time_greater_than_convergence_time(): + scipy = pytest.importorskip("scipy") + + harmonisation_time = 1.0 + convergence_time = -1.0 + + # y = x + # TODO: from left-edge or something here + diverge_from = SplineScipy( + scipy.interpolate.PPoly( + # These are the constants you need given how PPoly is defined + # (it's basically y = f(x - x_le), + # where x_le is the left-edge of the boundary) + c=[[1.0], [-10.0]], + x=[-10.0, 10.0], + ) + ) + assert diverge_from(harmonisation_time) == 1.0 + + # y = 0.5x - 1 + harmonisee = SplineScipy( + scipy.interpolate.PPoly( + c=[[0.5], [-6.0]], + x=[-10.0, 10.0], + ) + ) + assert harmonisee(convergence_time) == -1.5 + + res = get_cosine_decay_harmonised_spline( + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) + + check_expected_continuity( + solution=res, + diverge_from=diverge_from, + harmonisee=harmonisee, + harmonisation_time=harmonisation_time, + convergence_time=convergence_time, + ) From 248577edf4d96a2eb2d2a8467637483ff5d50a8d Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:18:18 +0200 Subject: [PATCH 4/6] adapt Spline-weight-decay for dealing with harmonisation time > convergence time --- .../convergence.py | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/src/gradient_aware_harmonisation/convergence.py b/src/gradient_aware_harmonisation/convergence.py index 3eff7f6..37dc819 100644 --- a/src/gradient_aware_harmonisation/convergence.py +++ b/src/gradient_aware_harmonisation/convergence.py @@ -84,7 +84,9 @@ def calc_gamma( """Get cosine-decay derivative""" # compute weight (here: gamma) according to a cosine-decay angle = ( - np.pi * (x - self.initial_time) / (self.final_time - self.initial_time) + np.pi + * (x - self.initial_time) + / abs(self.final_time - self.initial_time) ) gamma_decaying = 0.5 * (1 + np.cos(angle)) @@ -92,9 +94,16 @@ def calc_gamma( return gamma_decaying if not isinstance(x, np.ndarray): - if x <= self.initial_time: + if self.initial_time <= self.final_time: + if x <= self.initial_time: + gamma: float | NP_FLOAT_OR_INT | NP_ARRAY_OF_FLOAT_OR_INT = 1.0 + elif x >= self.final_time: + gamma = 0.0 + else: + gamma = calc_gamma(x) + elif x >= self.initial_time: gamma: float | NP_FLOAT_OR_INT | NP_ARRAY_OF_FLOAT_OR_INT = 1.0 - elif x >= self.final_time: + elif x <= self.final_time: gamma = 0.0 else: gamma = calc_gamma(x) @@ -109,10 +118,14 @@ def calc_gamma( return gamma - # apply decay function only to values that lie between harmonisation - # time and convergence-time - x_gte_final_time = np.where(x >= self.final_time) - x_decay = np.logical_and(x >= self.initial_time, x < self.final_time) + # apply decay function only to values that lie between + # harmonisation time and convergence-time + if self.initial_time <= self.final_time: + x_gte_final_time = np.where(x >= self.final_time) + x_decay = np.logical_and(x >= self.initial_time, x < self.final_time) + else: + x_gte_final_time = np.where(x <= self.final_time) + x_decay = np.logical_and(x <= self.initial_time, x > self.final_time) gamma = np.ones_like(x, dtype=np.floating) gamma[x_gte_final_time] = 0.0 gamma[x_decay] = calc_gamma(x[x_decay]) @@ -210,14 +223,19 @@ def calc_gamma_rising_derivative( """Get cosine-decay derivative""" # compute derivative of gamma according to a cosine-decay angle = ( - np.pi * (x - self.initial_time) / (self.final_time - self.initial_time) + np.pi + * (x - self.initial_time) + / abs(self.final_time - self.initial_time) ) gamma_decaying_derivative = -0.5 * np.sin(angle) return gamma_decaying_derivative if not isinstance(x, np.ndarray): - if x <= self.initial_time or x >= self.final_time: + if self.initial_time <= self.final_time: + if x <= self.initial_time or x >= self.final_time: + return 0.0 + elif x >= self.initial_time or x <= self.final_time: return 0.0 gamma_rising_derivative = calc_gamma_rising_derivative(x) @@ -234,7 +252,15 @@ def calc_gamma_rising_derivative( # apply decay function only to values that lie between harmonisation # time and convergence-time - x_decay = np.where(np.logical_and(x > self.initial_time, x < self.final_time)) + if self.initial_time <= self.final_time: + x_decay = np.where( + np.logical_and(x > self.initial_time, x < self.final_time) + ) + else: + x_decay = np.where( + np.logical_and(x < self.initial_time, x > self.final_time) + ) + gamma_rising_derivative = np.zeros_like(x, dtype=np.floating) gamma_rising_derivative[x_decay] = calc_gamma_rising_derivative(x[x_decay]) @@ -311,6 +337,7 @@ def get_cosine_decay_harmonised_spline( # weight * diverge_from + (1-weight) * harmonisee # With weights decaying from 1 to 0 whereby the decay trajectory # is determined by the cosine decay. + return SumOfSplines( ProductOfSplines( CosineDecaySplineHelper( From 7e60360159ac66a57d61c4c4c9deaf4b168a4d95 Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:18:53 +0200 Subject: [PATCH 5/6] add how-to guide for cosine weight-decay --- docs/NAVIGATION.md | 1 + docs/how-to-guides/cosine_decay.py | 38 ++++++++++++++---------------- docs/how-to-guides/index.md | 1 + 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/docs/NAVIGATION.md b/docs/NAVIGATION.md index b263140..ff07bc7 100644 --- a/docs/NAVIGATION.md +++ b/docs/NAVIGATION.md @@ -9,6 +9,7 @@ See https://oprypin.github.io/mkdocs-literate-nav/ - [How-to guides](how-to-guides/index.md) - [Do a basic calculation](how-to-guides/basic-calculation.md) - [Use a cubic spline for harmonisation](how-to-guides/cubic_spline.py) + - [Use a cosine-weight decay for harmonisation](how-to-guides/cosine_decay.py) - [Tutorials](tutorials/index.md) - [Getting Started](tutorials/tutorial.py) - [Further background](further-background/index.md) diff --git a/docs/how-to-guides/cosine_decay.py b/docs/how-to-guides/cosine_decay.py index ef82e20..979768d 100644 --- a/docs/how-to-guides/cosine_decay.py +++ b/docs/how-to-guides/cosine_decay.py @@ -13,11 +13,11 @@ # --- # %% [markdown] -# # How to use a cubic-spline as harmonisation of two functions? -# In this tutorial, we present use cases for applying a cubic-spline +# # How to use cosine weight-decay as harmonisation of two functions? +# In this tutorial, we present use cases for applying cosine-weight decay # to harmonise two functions which we will call in the following # `diverge_from` and `harmonisee`. -# The `cubic-spline` interpolates between `diverge_from` and `harmonisee`. +# The `cosine-weight-decay` interpolates between `diverge_from` and `harmonisee`. # %% @@ -28,9 +28,7 @@ import numpy as np import scipy.interpolate -from gradient_aware_harmonisation.add_cubic import ( - harmonise_splines_add_cubic, -) +from gradient_aware_harmonisation.convergence import get_cosine_decay_harmonised_spline from gradient_aware_harmonisation.spline import SplineScipy # %% [markdown] @@ -92,7 +90,7 @@ def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 ) ) - res = harmonise_splines_add_cubic( + res = get_cosine_decay_harmonised_spline( diverge_from=diverge_from, harmonisee=harmonisee, harmonisation_time=harmonisation_time, @@ -128,14 +126,14 @@ def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 harmonisee.derivative(), np.linspace(harmonisation_time, 2 * convergence_time, 101), ax=axes[1], - label="harmonisee_gradien", + label="harmonisee", gradient=True, ) plot_spline( res.derivative(), np.linspace(harmonisation_time, 2 * convergence_time, 101), ax=axes[1], - label="cubic-spline", + label="cosine_weight_decay", gradient=True, ) @@ -159,6 +157,15 @@ def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 plt.show() i = i + 1 +# %% [markdown] +# ### Harmonisation time > convergence time +# In the following, we consider the same nine scenarios as +# above in which the `harmonisee` spline differs +# from the `diverge_from` spline due to varying shifts in the +# intercept ([0.0, -1.2, 1.2]) and slope ([1.0, 0.7, 1.4]). +# However, this time we consider in all upcoming scenarios +# harmonisation time (=1.0) > convergence time (=-1.0). + # %% diverge_from_gradient = 2.5 diverge_from_y_intercept = 1.0 @@ -174,15 +181,6 @@ def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 ) ) -# %% [markdown] -# ### Harmonisation time > convergence time -# In the following, we consider the same nine scenarios as -# above in which the `harmonisee` spline differs -# from the `diverge_from` spline due to varying shifts in the -# intercept ([0.0, -1.2, 1.2]) and slope ([1.0, 0.7, 1.4]). -# However, this time we consider in all upcoming scenarios -# harmonisation time (=1.0) > convergence time (=-1.0). - # %% harmonisation_time = 1.0 convergence_time = -1.0 @@ -206,7 +204,7 @@ def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 ) ) - res = harmonise_splines_add_cubic( + res = get_cosine_decay_harmonised_spline( diverge_from=diverge_from, harmonisee=harmonisee, harmonisation_time=harmonisation_time, @@ -249,7 +247,7 @@ def plot_spline(spline, x, ax, label, gradient=False): # noqa: D103 res.derivative(), np.linspace(harmonisation_time, 2 * convergence_time, 101), ax=axes[1], - label="cubic-spline", + label="cosine_weight_decay", gradient=True, ) diff --git a/docs/how-to-guides/index.md b/docs/how-to-guides/index.md index b0e9469..e76823c 100644 --- a/docs/how-to-guides/index.md +++ b/docs/how-to-guides/index.md @@ -5,3 +5,4 @@ focuses on a **problem-oriented** approach. We'll go over how to solve common tasks. + [How can I use a cubic spline for harmonisation?](cubic_spline) ++ [How can I use a cosine-weight-decay for harmonisation?](cosine_decay) From 8c3971e7b99923bda384bcc8195879cb04844556 Mon Sep 17 00:00:00 2001 From: Florence Bockting <48919471+florence-bockting@users.noreply.github.com> Date: Thu, 24 Apr 2025 13:26:19 +0200 Subject: [PATCH 6/6] add changelogs --- changelog/21.docs.md | 2 ++ changelog/21.improvement.md | 3 +++ 2 files changed, 5 insertions(+) create mode 100644 changelog/21.docs.md create mode 100644 changelog/21.improvement.md diff --git a/changelog/21.docs.md b/changelog/21.docs.md new file mode 100644 index 0000000..d23de37 --- /dev/null +++ b/changelog/21.docs.md @@ -0,0 +1,2 @@ ++ add how-to-guide for cosine weight-decay in documentation ++ guide presents scenarios for both cases: `harmonisation-time < convergence-time` and vice versa. diff --git a/changelog/21.improvement.md b/changelog/21.improvement.md new file mode 100644 index 0000000..a81b9f8 --- /dev/null +++ b/changelog/21.improvement.md @@ -0,0 +1,3 @@ ++ add tests for cosine weight-decay that explicitly checks the function value and gradient value at the harmonisation time and convergence time ++ add test for cosine weight-decay that checks the case `harmonisation time > convergence time` ++ adapt CosineDecaySplineHelper to support the case `harmonisation-time > convergence-time`