Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file modified examples/figures/radial_tree.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
224 changes: 117 additions & 107 deletions examples/figures/radial_tree.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/figures/vertical_tree.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
154 changes: 77 additions & 77 deletions examples/figures/vertical_tree.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
979 changes: 266 additions & 713 deletions notebooks/00.VerticalTrees.ipynb

Large diffs are not rendered by default.

297 changes: 160 additions & 137 deletions notebooks/01.RadialTrees.ipynb

Large diffs are not rendered by default.

930 changes: 200 additions & 730 deletions src/phylustrator/drawing/base.py

Large diffs are not rendered by default.

1,020 changes: 501 additions & 519 deletions src/phylustrator/drawing/radial.py

Large diffs are not rendered by default.

1,362 changes: 543 additions & 819 deletions src/phylustrator/drawing/vertical.py

Large diffs are not rendered by default.

47 changes: 38 additions & 9 deletions src/phylustrator/utils.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,53 @@
from __future__ import annotations

import ete3
import math
import random
import string

def generate_id(prefix: str = "id", length: int = 6) -> str:
"""Generates a unique ID for SVG elements like gradients."""
suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=length))
return f"{prefix}_{suffix}"

def add_origin_if_root_has_dist(tree: ete3.Tree, origin_name: str = "Origin") -> ete3.Tree:
"""If `tree.dist` is non-zero, interpret it as a stem and add an explicit origin node.
def to_rgb(color_str: str) -> tuple[int, int, int]:
"""Parses hex, common names, or RGB tuples into a standard RGB tuple."""
color_str = str(color_str).strip().lower()
if color_str.startswith("#"):
h = color_str.lstrip("#")
if len(h) == 3: h = "".join([c*2 for c in h])
return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
common_names = {
"white": (255, 255, 255), "black": (0, 0, 0), "red": (255, 0, 0),
"green": (0, 128, 0), "blue": (0, 0, 255), "orange": (255, 165, 0),
"purple": (128, 0, 128), "yellow": (255, 255, 0), "gray": (128, 128, 128)
}
return common_names.get(color_str, (0, 0, 0))

def to_hex(rgb: tuple[int, int, int]) -> str:
"""Converts an RGB tuple to a hex string."""
return "#{:02x}{:02x}{:02x}".format(*[int(max(0, min(255, x))) for x in rgb])

def lerp_color(low_hex: str, high_hex: str, t: float) -> str:
"""Interpolates between two colors."""
t = max(0.0, min(1.0, t))
c1 = to_rgb(low_hex)
c2 = to_rgb(high_hex)
return to_hex(tuple(c1[i] + (c2[i] - c1[i]) * t for i in range(3)))

This avoids layout shifts when a rooted Newick encodes a stem length as the root's `dist`.
def polar_to_cartesian(degree: float, radius: float, rotation: float = 0) -> tuple[float, float]:
"""Converts polar coordinates (degree, radius) to (x, y)."""
theta = math.radians(degree + rotation)
return radius * math.cos(theta), radius * math.sin(theta)

Returns the (possibly new) root tree.
"""
stem = float(getattr(tree, "dist", 0.0) or 0.0)
def add_origin_if_root_has_dist(tree: ete3.Tree, origin_name: str = "Origin") -> ete3.Tree:
"""Standardizes trees by adding an explicit origin node if the root has a distance."""
stem = float(tree.dist or 0.0)
if stem <= 0.0:
tree.dist = 0.0
return tree

origin = ete3.Tree()
origin.name = origin_name
origin.dist = 0.0

tree.dist = stem
origin.add_child(tree)
return origin
116 changes: 116 additions & 0 deletions tests/tests_drawings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import pytest
import ete3
from phylustrator.drawing import VerticalTreeDrawer, RadialTreeDrawer, TreeStyle

@pytest.fixture
def simple_tree():
# Simple tree: (A:1, B:1);
return ete3.Tree("(A:1, B:1);")

@pytest.fixture
def transfer_data():
return [{"from": "A", "to": "B", "freq": 1.0}]

@pytest.fixture
def trait_data():
return {"A": 1.0, "B": 2.0}

def test_vertical_layout_init(simple_tree):
style = TreeStyle(width=500, height=500)
drawer = VerticalTreeDrawer(simple_tree, style=style)

# Check if layout was calculated
for node in simple_tree.traverse():
assert hasattr(node, "coordinates")
assert len(node.coordinates) == 2

def test_radial_layout_init(simple_tree):
style = TreeStyle(radius=200)
drawer = RadialTreeDrawer(simple_tree, style=style)

# Check if layout was calculated
for node in simple_tree.traverse():
assert hasattr(node, "rad")
assert hasattr(node, "angle")

# Check bounds
root = simple_tree.get_tree_root()
assert root.rad == 0
leaf_a = simple_tree.search_nodes(name="A")[0]
assert leaf_a.rad == 200

def test_pre_flight_check(simple_tree):
drawer = VerticalTreeDrawer(simple_tree)
# Reset flag manually
drawer._layout_calculated = False
# Calling a method that triggers check
drawer.add_title("Test")
assert drawer._layout_calculated is True

@pytest.mark.parametrize("drawer_class", [VerticalTreeDrawer, RadialTreeDrawer])
def test_method_existence(drawer_class, simple_tree):
"""
Comprehensive existence check for all public API methods to prevent
regressions during refactoring.
"""
drawer = drawer_class(simple_tree)

required_methods = [
"draw",
"highlight_clade",
"highlight_branch",
"gradient_branch",
"add_leaf_names",
"add_node_names",
"add_leaf_shapes",
"add_node_shapes",
"add_branch_shapes",
"plot_transfers",
"add_time_axis",
"add_heatmap",
"add_clade_labels",
"plot_continuous_variable",
"plot_categorical_trait",
"add_categorical_legend",
"add_transfer_legend",
"add_color_bar",
"add_leaf_images",
"add_ancestral_images",
"add_title",
"add_scale_bar",
"save_svg",
"save_png"
]

for method in required_methods:
assert hasattr(drawer, method), f"{drawer_class.__name__} is missing required method: {method}"

@pytest.mark.parametrize("drawer_class", [VerticalTreeDrawer, RadialTreeDrawer])
def test_smoke_execution(drawer_class, simple_tree, transfer_data, trait_data):
"""
Executes core methods with dummy data to ensure no internal crashes/SyntaxErrors.
"""
drawer = drawer_class(simple_tree)

# Core Drawing
drawer.draw()

# Overlays
drawer.highlight_clade(simple_tree, color="red")
drawer.add_leaf_names()
drawer.add_leaf_shapes(["A"], r=5)
drawer.add_branch_shapes([{"branch": "A", "where": 0.5, "shape": "circle"}])
drawer.plot_transfers(transfer_data)

# Labels & Legends
drawer.add_clade_labels({"A": "Label"})
drawer.add_categorical_legend({"Trait": "blue"})
drawer.add_color_bar("white", "blue", 0, 1)

# Traits
drawer.plot_categorical_trait(trait_data, value_col="trait")

# Title
drawer.add_title("Smoke Test")

assert drawer.drawing is not None
33 changes: 33 additions & 0 deletions tests/tests_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from phylustrator.utils import to_rgb, to_hex, lerp_color, polar_to_cartesian
import math

def test_to_rgb():
assert to_rgb("#ffffff") == (255, 255, 255)
assert to_rgb("#000") == (0, 0, 0)
assert to_rgb("red") == (255, 0, 0)
assert to_rgb("invalid") == (0, 0, 0)

def test_to_hex():
assert to_hex((255, 255, 255)) == "#ffffff"
assert to_hex((0, 0, 0)) == "#000000"
# Test clipping
assert to_hex((300, -10, 100)) == "#ff0064"

def test_lerp_color():
# Midpoint between black and white
assert lerp_color("#000000", "#ffffff", 0.5) == "#7f7f7f"
# Bound checks
assert lerp_color("#000000", "#ffffff", -1) == "#000000"
assert lerp_color("#000000", "#ffffff", 2) == "#ffffff"

def test_polar_to_cartesian():
# 0 degrees, radius 100 should be (100, 0)
x, y = polar_to_cartesian(0, 100)
assert pytest.approx(x) == 100
assert pytest.approx(y) == 0

# 90 degrees, radius 100 should be (0, 100)
x, y = polar_to_cartesian(90, 100)
assert pytest.approx(x) == 0
assert pytest.approx(y) == 100