Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ are conflicting, and averaging them gives an update direction that is detrimenta
objective. Note that in this picture, the dual cone, represented in green, is the set of vectors
that have a non-negative inner product with both $g_1$ and $g_2$.

![image](docs/source/_static/direction_upgrad_mean.svg)
![image](docs/source/_static/gradients_cone_projections_upgrad_mean.svg)

With Jacobian descent, $g_1$ and $g_2$ are computed individually and carefully aggregated using an
aggregator $\mathcal A$. In this example, the aggregator is the Unconflicting Projection of
Expand Down
1 change: 0 additions & 1 deletion docs/source/_static/direction_upgrad_mean.svg

This file was deleted.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,8 @@ test = [
]

plot = [
"plotly>=5.19.0", # Recent version to avoid problems, could be relaxed
"plotly[kaleido]>=5.19.0", # Recent version to avoid problems, could be relaxed
"dash>=2.16.0", # Recent version to avoid problems, could be relaxed
"kaleido==0.2.1", # Only works with locked version
"matplotlib>=3.10.0", # Recent version to avoid problems, could be relaxed
]
# Dependency group allowing to easily resolve version of the core dependencies to the lower bound.
Expand Down
32 changes: 32 additions & 0 deletions tests/plots/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,35 @@ def angle_to_coord(angle: float, r: float = 1.0) -> tuple[float, float]:
x = r * np.cos(angle)
y = r * np.sin(angle)
return x, y


def compute_2d_non_conflicting_cone(matrix: np.ndarray) -> tuple[float, float]:
    """
    Computes the boundary of the non-conflicting cone of a matrix whose rows are 2-dimensional.

    The cone is described by a pair ``(start angle, opening)``: the start angle lies in [0, 2pi[
    and the opening angle is <= pi. The opening can be negative, which means that the cone is
    empty.

    This function currently does not handle the case where the cone is a straight line passing
    through the origin (when matrix is for instance [[1, 0], [-1, 0]]).

    :param matrix: Any real-valued [m, 2] matrix.
    """

    quarter_turn = np.pi / 2
    full_turn = 2 * np.pi

    # Each row defines a non-conflicting half-space that starts a quarter turn before the angle
    # of the row. Only the first element of what coord_to_angle returns is used.
    half_space_starts = []
    for row in matrix:
        row_angle = coord_to_angle(*row)[0]
        half_space_starts.append((row_angle - quarter_turn) % full_turn)

    # Begin with the half-space of the first row (whose opening is pi), then narrow it down with
    # the half-space of every remaining row.
    start, opening = half_space_starts[0], np.pi
    for other_start in half_space_starts[1:]:
        start, opening = combine_bounds(start, opening, other_start)

    return start, opening


def project(vector: torch.Tensor, onto: torch.Tensor) -> torch.Tensor:
    """
    Returns the projection of ``vector`` onto the direction spanned by ``onto``.

    :param vector: The tensor to project.
    :param onto: The tensor giving the direction to project onto. It is normalized before use, so
        only its direction matters.
    """

    unit_direction = onto / torch.linalg.norm(onto)
    # Length of the component of `vector` along the unit direction.
    coefficient = vector @ unit_direction
    return coefficient * unit_direction
294 changes: 294 additions & 0 deletions tests/plots/static_plotter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
import os

import torch
from plotly import graph_objects as go

from plots._utils import (
angle_to_coord,
compute_2d_non_conflicting_cone,
make_cone_scatter,
make_polygon_scatter,
make_right_angle,
make_segment_scatter,
make_vector_scatter,
project,
)
from torchjd.aggregation import (
MGDA,
DualProj,
Mean,
UPGrad,
)

# Value passed as `size` to every `make_right_angle` call of this script, so that all right-angle
# markers are drawn with the same size.
RIGHT_ANGLE_SIZE = 0.07


_AGGREGATION_COLOR = "rgb(0, 0, 215)"


def _add_aggregation_vector(fig, result) -> None:
    """Adds to ``fig`` the (unlabeled) vector trace representing the aggregation ``result``."""

    aggregation_scatter = make_vector_scatter(
        result,
        color=_AGGREGATION_COLOR,
        label="",  # Label will be added as text element at the end
        textposition="top center",
        showlegend=False,
        dash=False,
        text_size=32,
        marker_size=22,
        line_width=4,
    )
    fig.add_trace(aggregation_scatter)


def main(
    *,
    gradients: bool = False,
    cone: bool = False,
    projections: bool = False,
    upgrad: bool = False,
    mean: bool = False,
    dual_proj: bool = False,
    mgda: bool = False,
) -> None:
    """
    Builds a static figure from two fixed 2-dimensional gradients and saves it as a pdf.

    Each flag enables one layer of the figure. The output file is ``images/<name>.pdf``, where
    ``<name>`` is made of the names of the enabled layers, joined by underscores, in the order of
    the parameters below (with ``mean`` coming after ``upgrad``).

    :param gradients: Whether to plot the two gradients g_1 and g_2.
    :param cone: Whether to plot the non-conflicting cone of the two gradients.
    :param projections: Whether to plot, for each gradient, its component that is orthogonal to
        the other gradient, together with the construction segments and right-angle markers.
    :param upgrad: Whether to plot the result of the UPGrad aggregator.
    :param mean: Whether to plot the result of the Mean aggregator.
    :param dual_proj: Whether to plot the result of the DualProj aggregator.
    :param mgda: Whether to plot the result of the MGDA aggregator.
    """

    angle1 = 2.6
    angle2 = 0.3277
    norm1 = 0.9
    norm2 = 2.8
    g1 = torch.tensor(angle_to_coord(angle1, norm1))
    g2 = torch.tensor(angle_to_coord(angle2, norm2))
    matrix = torch.stack([g1, g2])

    # Component of each gradient that is orthogonal to the other gradient.
    g1_proj = g1 - project(g1, onto=g2)
    g2_proj = g2 - project(g2, onto=g1)

    aggregators = {
        "UPGrad": UPGrad(),
        "Mean": Mean(),
        "DualProj": DualProj(),
        "MGDA": MGDA(),
    }
    results = {name: aggregator(matrix) for name, aggregator in aggregators.items()}

    fig = go.Figure()
    filename_parts = []  # Names of the enabled layers, joined with "_" to form the file name
    aggregation_labels = []  # Collect aggregator names to add labels as text elements at the end

    if gradients:
        filename_parts.append("gradients")
        for index, gradient in enumerate(matrix, start=1):
            label = r"$\huge{" + f"g_{index}" + r"}$"

            gradient_scatter = make_vector_scatter(
                gradient,
                color="rgb(40, 40, 40)",
                label=label,
                showlegend=False,
                dash=False,
                textposition="bottom center",
                text_size=32,
                marker_size=22,
                line_width=4,
            )
            fig.add_trace(gradient_scatter)

    if cone:
        filename_parts.append("cone")
        start_angle, opening = compute_2d_non_conflicting_cone(matrix.numpy())
        # Named differently from the `cone` flag to avoid rebinding the parameter.
        cone_scatter = make_cone_scatter(
            start_angle,
            opening,
            label="Non-conflicting cone",
            printable=False,
        )
        fig.add_trace(cone_scatter)

    if projections:
        filename_parts.append("projections")
        g1_proj_segment = make_segment_scatter(g1, g1_proj)
        g2_proj_segment = make_segment_scatter(g2, g2_proj)
        origin_g1_proj_vector = make_vector_scatter(
            g1_proj,
            color="rgb(100, 100, 100)",
            label=r"$\huge{" + r"\pi_J(g_1)" + r"}$",
            line_width=3,
            marker_size=16,
            textposition="top left",
        )
        origin_g2_proj_vector = make_vector_scatter(
            g2_proj,
            color="rgb(100, 100, 100)",
            label=r"$\huge{" + r"\pi_J(g_2)" + r"}$",
            line_width=3,
            marker_size=16,
            textposition="top right",
        )

        g1_proj_right_angle = make_polygon_scatter(
            make_right_angle(g1_proj, size=RIGHT_ANGLE_SIZE, positive_para=False),
        )
        g2_proj_right_angle = make_polygon_scatter(
            make_right_angle(
                g2_proj,
                size=RIGHT_ANGLE_SIZE,
                positive_orth=False,
                positive_para=False,
            ),
        )

        # The traces are added in this order on purpose: segments first, then right angles, then
        # the projected vectors.
        fig.add_trace(g1_proj_segment)
        fig.add_trace(g2_proj_segment)

        fig.add_trace(g1_proj_right_angle)
        fig.add_trace(g2_proj_right_angle)

        fig.add_trace(origin_g1_proj_vector)
        fig.add_trace(origin_g2_proj_vector)

    if upgrad:
        filename_parts.append("upgrad")
        g1_proj_upgrad_segment = make_segment_scatter(g1_proj, results["UPGrad"])
        g2_proj_upgrad_segment = make_segment_scatter(g2_proj, results["UPGrad"])

        fig.add_trace(g1_proj_upgrad_segment)
        fig.add_trace(g2_proj_upgrad_segment)

        _add_aggregation_vector(fig, results["UPGrad"])
        aggregation_labels.append("UPGrad")

    if mean:
        filename_parts.append("mean")
        g1_g2_segment = make_segment_scatter(g1, g2)

        fig.add_trace(g1_g2_segment)

        _add_aggregation_vector(fig, results["Mean"])
        aggregation_labels.append("Mean")

    if dual_proj:
        filename_parts.append("dual_proj")
        dual_proj_segment = make_segment_scatter(results["Mean"], results["DualProj"])

        dual_proj_right_angle = make_polygon_scatter(
            make_right_angle(
                results["DualProj"],
                size=RIGHT_ANGLE_SIZE,
                positive_orth=False,
                positive_para=False,
            ),
        )

        fig.add_trace(dual_proj_segment)
        fig.add_trace(dual_proj_right_angle)

        _add_aggregation_vector(fig, results["DualProj"])
        aggregation_labels.append("DualProj")

    if mgda:
        filename_parts.append("mgda")
        if not mean:  # Otherwise the segment between g1 and g2 is already plotted
            g1_g2_segment = make_segment_scatter(g1, g2)
            fig.add_trace(g1_g2_segment)

        mgda_right_angle = make_polygon_scatter(
            make_right_angle(
                results["MGDA"],
                size=RIGHT_ANGLE_SIZE,
                positive_para=False,
                positive_orth=False,
            ),
        )
        fig.add_trace(mgda_right_angle)

        _add_aggregation_vector(fig, results["MGDA"])
        aggregation_labels.append("MGDA")

    # Add aggregation labels as text elements at the end so they appear on top
    for name in aggregation_labels:
        result = results[name]
        label_text = r"$\huge{\mathcal{A}_{\mathrm{" + name + r"}}(J)}$"
        fig.add_annotation(
            x=result[0].item(),
            y=result[1].item(),
            text=label_text,
            showarrow=False,
            font={"size": 32, "color": _AGGREGATION_COLOR},
            yanchor="bottom",
            xanchor="center",
        )

    fig.update_layout(
        hovermode=False,
        width=912,
        height=528,
        plot_bgcolor="white",
        showlegend=False,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
    )
    fig.update_xaxes(
        scaleanchor="y",
        scaleratio=1,
        range=[-0.95, 2.85],
        showgrid=False,
        zeroline=False,
        visible=False,
    )
    fig.update_yaxes(range=[-0.1, 2.1], showgrid=False, zeroline=False, visible=False)

    # Joining the parts (rather than appending "_<part>" to a string) avoids a leading underscore
    # in the file name when `gradients` is False.
    filename = "_".join(filename_parts)

    os.makedirs("images/", exist_ok=True)
    fig.write_image(f"images/{filename}.pdf")
    # Alternative: use .svg here and then convert to pdf using rsvg-convert. Install
    # [rsvg-convert](https://manpages.ubuntu.com/manpages/bionic/man1/rsvg-convert.1.html) and run:
    # `rsvg-convert -f pdf -o filename.pdf filename.svg`
    # To do that on all files at once, run:
    # ```
    # for file in images/*.svg; do rsvg-convert -f pdf -o "${file%.svg}.pdf" "$file"; done
    # ```


if __name__ == "__main__":
    # Each entry lists the flags of `main` to enable for one generated figure. The figures are
    # generated in the order of this list.
    figure_specs = [
        # Step-by-step construction of UPGrad for the presentation
        ("gradients",),
        ("gradients", "mean"),
        ("gradients", "mean", "cone"),
        ("gradients", "mean", "cone", "projections"),
        ("gradients", "mean", "cone", "projections", "upgrad"),
        # Plot with UPGrad only
        ("gradients", "cone", "projections", "upgrad"),
        # Plot with Mean, DualProj and MGDA
        ("gradients", "mean", "cone", "dual_proj", "mgda"),
    ]

    for enabled_flags in figure_specs:
        main(**{flag: True for flag in enabled_flags})