diff --git a/examples/figures/radial_tree.png b/examples/figures/radial_tree.png
index 5063932..de62964 100644
Binary files a/examples/figures/radial_tree.png and b/examples/figures/radial_tree.png differ
diff --git a/examples/figures/radial_tree.svg b/examples/figures/radial_tree.svg
index b293bd4..3c26233 100644
--- a/examples/figures/radial_tree.svg
+++ b/examples/figures/radial_tree.svg
@@ -2,164 +2,153 @@
\ No newline at end of file
diff --git a/examples/figures/vertical_tree.png b/examples/figures/vertical_tree.png
index 488e0e9..3ff0d1a 100644
Binary files a/examples/figures/vertical_tree.png and b/examples/figures/vertical_tree.png differ
diff --git a/examples/figures/vertical_tree.svg b/examples/figures/vertical_tree.svg
index 240ed13..218e23d 100644
--- a/examples/figures/vertical_tree.svg
+++ b/examples/figures/vertical_tree.svg
@@ -2,11 +2,11 @@
\ No newline at end of file
diff --git a/notebooks/00.VerticalTrees.ipynb b/notebooks/00.VerticalTrees.ipynb
index c8029b0..95f02f8 100644
--- a/notebooks/00.VerticalTrees.ipynb
+++ b/notebooks/00.VerticalTrees.ipynb
@@ -27,127 +27,127 @@
"\n",
"\n",
"\n",
- "A\n",
- "B\n",
- "C\n",
- "D\n",
- "E\n",
- "F\n",
- "G\n",
- "H\n",
- "I\n",
- "J\n",
- "K\n",
- "L\n",
- "M\n",
- "N\n",
- "O\n",
- "P\n",
- "Q\n",
- "R\n",
- "S\n",
- "T\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
- "\n",
+ "A\n",
+ "B\n",
+ "C\n",
+ "D\n",
+ "E\n",
+ "F\n",
+ "G\n",
+ "H\n",
+ "I\n",
+ "J\n",
+ "K\n",
+ "L\n",
+ "M\n",
+ "N\n",
+ "O\n",
+ "P\n",
+ "Q\n",
+ "R\n",
+ "S\n",
+ "T\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
""
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 2,
@@ -162,9 +162,6 @@
"my_style = ph.TreeStyle(\n",
" width=600,\n",
" height=600,\n",
- " leaf_size=0,\n",
- " node_size=0,\n",
- " branch_size=2,\n",
" branch_color=\"black\",\n",
" font_size=12,\n",
" font_family=\"Arial\",\n",
@@ -178,7 +175,7 @@
},
{
"cell_type": "markdown",
- "id": "cc8bffea-6d35-426a-9efd-47a984fce9f2",
+ "id": "f0beb8e8-ee65-43fa-a1a8-b650b8862437",
"metadata": {},
"source": [
"# Decorating a tree"
@@ -187,7 +184,7 @@
{
"cell_type": "code",
"execution_count": 3,
- "id": "51a5a7c7-6040-4a56-897a-d86d821dfadb",
+ "id": "7fe4bc28-04cc-4b69-b58f-b4127bc1292d",
"metadata": {},
"outputs": [
{
@@ -197,209 +194,189 @@
""
],
"text/plain": [
- ""
+ ""
]
},
"execution_count": 3,
@@ -411,13 +388,12 @@
"with open(\"../examples/data/basic/tree.nwk\") as f:\n",
" t = ete3.Tree(f.readline(), format=1)\n",
" \n",
- " \n",
"my_style = ph.TreeStyle(\n",
" width=600,\n",
" height=600,\n",
- " leaf_size=0,\n",
- " node_size=0,\n",
- " branch_size=2,\n",
+ " leaf_r=0, \n",
+ " node_r=0, \n",
+ " branch_stroke_width=2, \n",
" branch_color=\"black\",\n",
" font_size=12,\n",
" font_family=\"Arial\",\n",
@@ -435,21 +411,21 @@
"v.add_leaf_names()\n",
"\n",
"# Adding shapes\n",
- "\n",
- "v.add_leaf_shapes(leaves=[\"A\", \"B\", \"C\", \"D\"],\n",
+ "v.add_leaf_shapes(\n",
+ " leaves=[\"A\", \"B\", \"C\", \"D\"],\n",
" shape=\"triangle\",\n",
" fill=\"blue\",\n",
- " size=10,\n",
+ " r=5, # Changed size=10 -> r=5\n",
" stroke=\"black\",\n",
" stroke_width=1,\n",
- " offset=35, # distance from the leaf tip\n",
+ " offset=35, \n",
")\n",
"\n",
"v.add_leaf_shapes(\n",
" leaves=[\"J\", \"M\"],\n",
" shape=\"square\",\n",
" fill=\"orange\",\n",
- " size=8,\n",
+ " r=4, # Changed size=8 -> r=4\n",
" stroke=\"black\",\n",
" stroke_width=1,\n",
" offset=35,\n",
@@ -457,9 +433,9 @@
")\n",
"\n",
"events = [\n",
- " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"size\": 14},\n",
- " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\",\"size\": 14},\n",
- " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"size\": 14},\n",
+ " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"r\": 7}, # size=14 -> r=7\n",
+ " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\", \"r\": 7}, # size=14 -> r=7\n",
+ " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"r\": 7}, # size=14 -> r=7\n",
"]\n",
"\n",
"v.add_branch_shapes(events, orient=None, offset=0)\n",
@@ -467,20 +443,20 @@
"target = t.get_common_ancestor(\"A\", \"D\") \n",
"\n",
"events = [\n",
- " {\"branch\": target, \"where\": 0.7, \"shape\": \"circle\", \"fill\": \"purple\", \"size\": 14},\n",
+ " {\"branch\": target, \"where\": 0.7, \"shape\": \"circle\", \"fill\": \"purple\", \"r\": 7}, # size=14 -> r=7\n",
"]\n",
"\n",
"v.add_branch_shapes(events, orient=None, offset=0)\n",
"\n",
- "target = t.get_common_ancestor(\"E\", \"G\") # To select inner nodes\n",
+ "target = t.get_common_ancestor(\"E\", \"G\") \n",
"v.highlight_clade(target, color=\"orange\", opacity=0.4)\n",
"\n",
"target = t.get_common_ancestor(\"P\", \"Q\") \n",
- "v.highlight_branch(target, color=\"blue\", size=5)\n",
+ "v.highlight_branch(target, color=\"blue\", stroke_width=5) # size -> stroke_width\n",
"\n",
"\n",
"target = t.get_common_ancestor(\"H\", \"J\") \n",
- "v.gradient_branch(target, colors=(\"purple\", \"red\"), size=6)\n",
+ "v.gradient_branch(target, colors=(\"purple\", \"red\"), stroke_width=6) # size -> stroke_width\n",
"\n",
"transfer_data = [\n",
" {\"from\": \"E\", \"to\": \"A\", \"freq\": 1.0},\n",
@@ -488,82 +464,75 @@
"\n",
"v.plot_transfers(\n",
" transfer_data,\n",
- " curve_type=\"C\", \n",
- " stroke_width=3,\n",
+ " curve_type=\"C\", \n",
+ " stroke_width=3, # already used stroke_width\n",
" opacity=0.6,\n",
" gradient_colors=(\"purple\", \"orange\") \n",
")\n",
"\n",
- "\n",
"v.add_time_axis(\n",
" ticks=[0, 0.5, 1.0, 1.5, 2.0, 2.5], \n",
" label=\"Time\", \n",
" y_offset=20 \n",
")\n",
"\n",
- "\n",
- "# --- Column 1: \"Expression\" (Blue) ---\n",
- "# Create fake data: {LeafName: Value}\n",
+ "# Heatmaps use width and border_width (standard)\n",
"data_col1 = {leaf.name: random.uniform(0, 1) for leaf in t.get_leaves()}\n",
- "\n",
"v.add_heatmap(\n",
" data_col1,\n",
" width=15,\n",
- " offset=50, # Distance from tree tips (make room for labels)\n",
+ " offset=50,\n",
" low_color=\"white\",\n",
" high_color=\"blue\",\n",
- " border_color=\"black\", # optional border\n",
+ " border_color=\"black\",\n",
" border_width=0.5\n",
")\n",
"\n",
- "# --- Column 2: \"Enrichment\" (Red) ---\n",
"data_col2 = {leaf.name: random.uniform(0, 100) for leaf in t.get_leaves()}\n",
- "\n",
- "# Calculate new offset: Previous Offset (50) + Previous Width (25) + Gap (5)\n",
"v.add_heatmap(\n",
" data_col2,\n",
" width=15,\n",
- " offset=70, # Placed to the right of the first column\n",
- " low_color=\"#fff5f0\", # light reddish tint\n",
- " high_color=\"#67000d\", # dark red\n",
+ " offset=70,\n",
+ " low_color=\"#fff5f0\",\n",
+ " high_color=\"#67000d\",\n",
" border_color=\"black\",\n",
" border_width=0.5\n",
")\n",
"\n",
"target1 = t.get_common_ancestor(\"A\", \"D\") \n",
"target2 = t.get_common_ancestor(\"H\", \"R\") \n",
- "v.add_node_shapes([target1, target2], shape=\"circle\", fill=\"red\", size=10, stroke=\"white\", stroke_width=1)\n",
+ "v.add_node_shapes([target1, target2], shape=\"circle\", fill=\"red\", r=5, stroke=\"white\", stroke_width=1) # size=10 -> r=5\n",
"\n",
- "v.d\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "0ee4c452-b76b-4422-8138-68e604827c9a",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "target1.name = ''\n",
- "target2.name = ''\n"
- ]
- }
- ],
- "source": [
- "target1 = t.get_common_ancestor(\"E\", \"T\")\n",
- "target2 = t.get_common_ancestor(\"H\", \"J\")\n",
+ "# Explaining the branch/node colors\n",
+ "v.add_categorical_legend(\n",
+ " palette={\"Ancestral\": \"blue\", \"Target Clade\": \"green\"}, \n",
+ " title=\"Lineage Status\",\n",
+ " x=-280, y=-280 # Top-left area\n",
+ ")\n",
"\n",
- "print(\"target1.name =\", repr(target1.name))\n",
- "print(\"target2.name =\", repr(target2.name))\n"
+ "# Explaining the purple-to-orange transfers\n",
+ "v.add_transfer_legend(\n",
+ " colors=(\"purple\", \"orange\"),\n",
+ " x=-280, y=-180 # Positioned below the categorical legend\n",
+ ")\n",
+ "\n",
+ "# Color bar for the first heatmap (Expression)\n",
+ "v.add_color_bar(\n",
+ " low_color=\"white\", \n",
+ " high_color=\"blue\", \n",
+ " vmin=0, vmax=1, \n",
+ " title=\"Expression\",\n",
+ " x=200, # Bottom-left area\n",
+ ")\n",
+ "\n",
+ "\n",
+ "\n",
+ "v.d"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 13,
"id": "733505c7-a793-4694-bbf6-878117be5256",
"metadata": {},
"outputs": [],
@@ -574,422 +543,6 @@
"# PNG requires cairosvg (install: pip install \"phylustrator[export]\")\n",
"v.save_png(\"../examples/figures/vertical_tree.png\", scale=3.0)\n"
]
- },
- {
- "cell_type": "markdown",
- "id": "48cf7f0a-0814-44a6-a759-c23cbe8eb2cb",
- "metadata": {},
- "source": [
- "# Adding tree with transitions"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "9c4d584b-59f1-423c-99c3-2716555342b2",
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "2b769b7e-29ae-4e57-8535-b24b23e937ba",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "image/svg+xml": [
- "\n",
- ""
- ],
- "text/plain": [
- ""
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "with open(\"../examples/data/basic/tree2.nwk\") as f:\n",
- " t = ete3.Tree(f.readline(), format=1)\n",
- " \n",
- "my_style = ph.TreeStyle(\n",
- " width=800,\n",
- " height=800,\n",
- " leaf_size=0,\n",
- " node_size=0,\n",
- " branch_size=4,\n",
- " branch_color=\"black\",\n",
- ")\n",
- "\n",
- "data = pd.read_csv(\"../examples/data/basic/traits.tsv\", sep=\"\\t\")\n",
- "\n",
- "v = ph.VerticalTreeDrawer(t, style=my_style)\n",
- "\n",
- "v.plot_categorical_trait(\n",
- " data=data, \n",
- " value_col=\"X\", \n",
- " node_col=\"Node\", \n",
- " size=4 \n",
- ")\n",
- "\n",
- "v.add_leaf_shapes(\n",
- " leaves=[\"J\", \"M\"],\n",
- " shape=\"square\",\n",
- " fill=\"orange\",\n",
- " size=8,\n",
- " stroke=\"black\",\n",
- " stroke_width=1,\n",
- " offset=35,\n",
- " rotation=45,\n",
- ")\n",
- "\n",
- "v.d"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "1b23fd17-041c-495f-be21-6a0b0b1619ce",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/notebooks/01.RadialTrees.ipynb b/notebooks/01.RadialTrees.ipynb
index d868e29..834f2d0 100644
--- a/notebooks/01.RadialTrees.ipynb
+++ b/notebooks/01.RadialTrees.ipynb
@@ -1,9 +1,17 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "id": "b827bf3b-e23d-42e5-adcb-50b541b835a0",
+ "metadata": {},
+ "source": [
+ "# Radial Tree"
+ ]
+ },
{
"cell_type": "code",
- "execution_count": 12,
- "id": "fb785990-6ccc-47ff-8648-5b260c23b90b",
+ "execution_count": 1,
+ "id": "9876fdef-6062-43e1-b509-71ea8c07f82f",
"metadata": {},
"outputs": [
{
@@ -13,164 +21,153 @@
""
],
"text/plain": [
- ""
+ ""
]
},
- "execution_count": 12,
+ "execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "# Cell 1: Imports\n",
"import phylustrator as ph\n",
"import ete3\n",
- "\n",
+ "import random\n",
"\n",
"with open(\"../examples/data/basic/tree.nwk\") as f:\n",
" t = ete3.Tree(f.readline())\n",
@@ -208,9 +225,9 @@
"my_style = ph.TreeStyle(\n",
" radius=250,\n",
" degrees=360,\n",
- " leaf_size=0,\n",
- " node_size=0,\n",
- " branch_size=2,\n",
+ " leaf_r=0, # Changed leaf_size -> leaf_r\n",
+ " node_r=0, # Changed node_size -> node_r\n",
+ " branch_stroke_width=2, # Changed branch_size -> branch_stroke_width\n",
" branch_color=\"black\",\n",
" font_size=12,\n",
" font_family=\"Arial\",\n",
@@ -228,51 +245,48 @@
"r.draw(branch2color=node_colors)\n",
"r.add_leaf_names()\n",
"\n",
- "\n",
+ "# Adding shapes\n",
"r.add_leaf_shapes(\n",
" leaves=[\"A\", \"B\", \"C\", \"D\"],\n",
" shape=\"triangle\",\n",
" fill=\"blue\",\n",
- " size=10,\n",
+ " r=5, # size=10 -> r=5 (radius)\n",
" stroke=\"black\",\n",
" stroke_width=1,\n",
" offset=35, \n",
- " orient=True\n",
")\n",
"\n",
"r.add_leaf_shapes(\n",
" leaves=[\"J\", \"M\"],\n",
" shape=\"square\",\n",
" fill=\"orange\",\n",
- " size=8,\n",
+ " r=4, # size=8 -> r=4 (radius)\n",
" stroke=\"black\",\n",
" stroke_width=1,\n",
" offset=35,\n",
- " orient=True\n",
- "\n",
+ " rotation=45,\n",
")\n",
"\n",
"events = [\n",
- " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"size\": 14},\n",
- " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\",\"size\": 14},\n",
- " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"size\": 14},\n",
+ " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"r\": 7}, # size=14 -> r=7\n",
+ " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\",\"r\": 7}, # size=14 -> r=7\n",
+ " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"r\": 7}, # size=14 -> r=7\n",
"]\n",
"r.add_branch_shapes(events)\n",
"\n",
- "import random\n",
"heatmap_vals = {leaf.name: random.uniform(0, 1) for leaf in t.get_leaves()}\n",
"\n",
- "r.add_ring_heatmap(\n",
+ "r.add_heatmap( # add_ring_heatmap -> add_heatmap\n",
" heatmap_vals,\n",
" width=20,\n",
- " padding=80, # Push it out further than the previous ring\n",
+ " offset=80, # padding -> offset\n",
" low_color=\"#e0ecf4\",\n",
" high_color=\"#8856a7\"\n",
")\n",
"\n",
"transfer_data = [\n",
" {\"from\": \"E\", \"to\": \"A\", \"freq\": 1.0},\n",
- " {\"from\": \"P\", \"to\": \"F\", \"freq\": 0.5},\n",
+ " {\"from\": \"R\", \"to\": \"F\", \"freq\": 0.5},\n",
"]\n",
"\n",
"r.plot_transfers(\n",
@@ -281,7 +295,7 @@
" stroke_width=3,\n",
" opacity=0.6,\n",
" gradient_colors=(\"purple\", \"orange\"),\n",
- " arc_intensity=80 # Higher value = deeper arc towards center\n",
+ " arc_intensity=80 \n",
")\n",
"\n",
"r.add_time_axis(\n",
@@ -289,28 +303,37 @@
" label=\"\",\n",
" stroke=\"gray\",\n",
" stroke_dasharray=\"3,3\",\n",
- "\n",
")\n",
"\n",
"target_clade = t.get_common_ancestor(\"E\", \"G\")\n",
"r.highlight_clade(target_clade, color=\"orange\", opacity=0.3)\n",
"\n",
+ "# Testing newly implemented parity functions\n",
+ "target_node = t.get_common_ancestor(\"H\", \"J\")\n",
+ "r.highlight_branch(target_node, color=\"red\", stroke_width=4)\n",
+ "r.gradient_branch(target_node, colors=(\"yellow\", \"red\"), stroke_width=6)\n",
+ "r.add_node_names(color=\"blue\", padding=20)\n",
"\n",
- "r.d"
+ "# To explain categorical traits\n",
+ "r.add_categorical_legend({\"Gain\": \"blue\", \"Loss\": \"red\", \"Duplication\":\"green\"}, title=\"Event\")\n",
+ "\n",
+ "# To explain your transfers\n",
+ "r.add_transfer_legend(colors=(\"purple\", \"orange\"), x=250)\n",
+ "\n",
+ "# To explain your heatmap\n",
+ "r.add_color_bar(\"#e0ecf4\", \"#8856a7\", vmin=0, vmax=1, title=\"Confidence\")\n",
+ "\n",
+ "r.d\n"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 2,
"id": "d4237e6e-b776-4f90-8a5d-18530e4298f6",
"metadata": {},
"outputs": [],
"source": [
- "# Cell 5: Export\n",
- "# SVG always works\n",
"r.save_svg(\"../examples/figures/radial_tree.svg\")\n",
- "\n",
- "# PNG requires cairosvg\n",
"r.save_png(\"../examples/figures/radial_tree.png\", scale=3.0)"
]
}
diff --git a/src/phylustrator/drawing/base.py b/src/phylustrator/drawing/base.py
index e35e13d..c42f74b 100644
--- a/src/phylustrator/drawing/base.py
+++ b/src/phylustrator/drawing/base.py
@@ -2,8 +2,10 @@
from dataclasses import dataclass
from pathlib import Path
import math
-import io
import re
+import os
+import base64
+from ..utils import to_hex, to_rgb, lerp_color, generate_id
try:
import cairosvg
@@ -12,775 +14,243 @@
@dataclass
class TreeStyle:
+ """
+ Configuration object for visual styling parameters of the phylogenetic tree.
+
+ Attributes:
+ width (int): Total width of the drawing canvas in pixels. Defaults to 1000.
+ height (int): Total height of the drawing canvas in pixels. Defaults to 1000.
+ radius (int): Radius of the tree layout (for radial trees). Defaults to 400.
+ degrees (int): Total angular span of the tree in degrees (for radial trees). Defaults to 360.
+ rotation (int): Global rotation offset in degrees. Defaults to -90 (starting at 12 o'clock).
+ margin (float): Margin padding around the tree in pixels. Defaults to 100.0.
+ root_stub_length (float): Length of the root node's "stub" branch. Defaults to 20.0.
+ leaf_r (float): Radius of the circle drawn at leaf tips. Defaults to 5.0.
+ leaf_color (str): CSS color string for leaf nodes. Defaults to "black".
+ branch_stroke_width (float): Thickness of branch lines. Defaults to 2.0.
+ branch_color (str): CSS color string for branches. Defaults to "black".
+ node_r (float): Radius of internal nodes. Defaults to 2.0.
+ font_size (int): Base font size for text elements. Defaults to 12.
+ font_family (str): Font family for text elements. Defaults to "Arial".
+ """
width: int = 1000
height: int = 1000
radius: int = 400
degrees: int = 360
rotation: int = -90
- leaf_size: int = 5
+ margin: float = 100.0
+ root_stub_length: float = 20.0
+ leaf_r: float = 5.0
leaf_color: str = "black"
- branch_size: int = 2
+ branch_stroke_width: float = 2.0
branch_color: str = "black"
- node_size: int = 2
+ node_r: float = 2.0
font_size: int = 12
font_family: str = "Arial"
class BaseDrawer:
+ """
+ Abstract base class providing shared UI, rendering primitives, and export functionality.
+ """
def __init__(self, tree, style=None):
+ """
+ Initialize the drawer with a tree structure and style configuration.
+
+ Args:
+ tree (ete3.TreeNode): The tree object to be visualized.
+ style (TreeStyle, optional): Custom style configuration. If None, default style is used.
+ """
self.t = tree
self.style = style if style else TreeStyle()
- self.d = draw.Drawing(self.style.width, self.style.height, origin='center')
- self.d.append(draw.Rectangle(-self.style.width/2, -self.style.height/2,
- self.style.width, self.style.height, fill="white"))
+ self.drawing = draw.Drawing(self.style.width, self.style.height, origin='center')
+ self.drawing.append(draw.Rectangle(-self.style.width/2, -self.style.height/2,
+ self.style.width, self.style.height, fill="white"))
+ self.d = self.drawing
self.total_tree_depth = 0
self.sf = 1.0
+ self._layout_calculated = False
+
+ def _pre_flight_check(self):
+ """Ensure layout calculations are performed before drawing operations."""
+ if not self._layout_calculated:
+ self._calculate_layout()
+ self._layout_calculated = True
+ def _calculate_layout(self):
+ """
+ Abstract method to calculate node coordinates.
+ Must be implemented by subclasses.
+ """
+ raise NotImplementedError
- def _draw_shape_at(
- self,
- x: float,
- y: float,
- shape: str,
- fill: str,
- size: float,
- stroke: str | None,
- stroke_width: float,
- rotation: float = 0.0,
- opacity: float = 1.0,
- ) -> None:
+ def _draw_shape_at(self, x, y, shape, fill, r, stroke=None, stroke_width=1.0, rotation=0, opacity=1.0):
+ """Internal helper to render geometric shapes at specific coordinates."""
common = {"fill": fill, "opacity": opacity}
- if stroke is not None:
+ if stroke:
common["stroke"] = stroke
common["stroke_width"] = stroke_width
-
shp = str(shape).lower()
- rot = float(rotation)
-
- # drawsvg uses SVG transforms; rotate around (x,y)
- transform = None if rot == 0 else f"rotate({rot},{x},{y})"
-
+ transform = f"rotate({rotation},{x},{y})" if rotation != 0 else None
if shp == "circle":
- self.d.append(draw.Circle(x, y, float(size) / 2.0, **common))
- return
-
- if shp == "square":
- s = float(size)
- self.d.append(
- draw.Rectangle(x - s / 2.0, y - s / 2.0, s, s, transform=transform, **common)
- )
- return
-
- if shp == "triangle":
- s = float(size)
- h = s * math.sqrt(3) / 2.0
- p1 = (x, y - (2.0 / 3.0) * h)
- p2 = (x - s / 2.0, y + (1.0 / 3.0) * h)
- p3 = (x + s / 2.0, y + (1.0 / 3.0) * h)
-
- path = draw.Path(transform=transform, **common)
- path.M(*p1).L(*p2).L(*p3).Z()
- self.d.append(path)
- return
-
- raise ValueError(f"Unknown shape: {shape!r}. Use circle/square/triangle.")
-
-
- def _get_rotated_svg(self, rotation: float) -> str:
- """
- Internal helper: Wraps the current drawing in a new SVG
- transformed to handle rotation and resize the canvas.
- """
- original_svg = self.d.as_svg()
-
- if rotation == 0:
- return original_svg
-
- # 1. CLEANUP: Strip XML declaration and DOCTYPE from the inner SVG
- # because they are only allowed at the very start of a file.
- original_svg = re.sub(r'<\?xml.*?\?>', '', original_svg)
- original_svg = re.sub(r'', '', original_svg)
-
- # 2. Calculate new canvas dimensions to prevent clipping
- w, h = self.style.width, self.style.height
- rad = math.radians(rotation)
- new_w = abs(w * math.cos(rad)) + abs(h * math.sin(rad))
- new_h = abs(w * math.sin(rad)) + abs(h * math.cos(rad))
-
- # 3. Wrap cleaned SVG in a group with center rotation
- return (
- f''
- )
-
- def save_svg(self, outpath: str | Path, rotation: float = 0.0) -> None:
- outpath = Path(outpath)
- outpath.parent.mkdir(parents=True, exist_ok=True)
-
- svg_content = self._get_rotated_svg(rotation)
-
- with open(outpath, "w", encoding="utf-8") as f:
- f.write(svg_content)
-
- def save_png(self, outpath: str | Path, scale: float = 1.0, rotation: float = 0.0) -> None:
- if cairosvg is None:
- raise ImportError(
- "PNG export requires 'cairosvg'. Please install it via pip."
- )
-
- outpath = Path(outpath)
- outpath.parent.mkdir(parents=True, exist_ok=True)
-
- svg_content = self._get_rotated_svg(rotation)
-
- cairosvg.svg2png(
- bytestring=svg_content.encode("utf-8"),
- write_to=str(outpath),
- scale=scale
- )
-
- def add_legend(
- self,
- title: str,
- mapping: dict,
- position="top-left",
- symbol: str = "circle",
- text_size: int | None = None,
- padding: int = 20,
- box_padding: int = 10,
- box_fill: str = "white",
- box_opacity: float = 0.9,
- box_stroke: str = "black",
- box_stroke_width: float = 1.0,
- symbol_size: int = 10,
- row_gap: int = 6,
- ):
- """Add a simple categorical legend.
-
- Coordinates follow the library convention: origin at canvas center.
-
- Parameters
- ----------
- title:
- Legend title.
- mapping:
- Dict of {label: color}.
- position:
- "top-left", "top-right", "bottom-left", "bottom-right", or (x, y) tuple
- specifying the *top-left* corner of the legend box.
- symbol:
- "circle", "square", or "line".
+ self.drawing.append(draw.Circle(x, y, float(r), **common))
+ elif shp == "square":
+ side = float(r) * 2.0
+ self.drawing.append(draw.Rectangle(x - r, y - r, side, side, transform=transform, **common))
+ elif shp == "triangle":
+ side = float(r) * 2.0
+ h = side * math.sqrt(3) / 2.0
+ p1, p2, p3 = (x, y - h*2/3), (x - side/2, y + h/3), (x + side/2, y + h/3)
+ path = draw.Path(transform=transform, **common).M(*p1).L(*p2).L(*p3).Z()
+ self.drawing.append(path)
+
+ def add_title(self, text, font_size=24, position="top", pad=40.0, color="black", weight="bold"):
+ """
+ Adds a title text to the drawing at a fixed cardinal position.
+
+ Args:
+ text (str): The title text to display.
+ font_size (int, optional): Font size in pixels. Defaults to 24.
+ position (str, optional): Position anchor ("top", "bottom", "left", "right"). Defaults to "top".
+ pad (float, optional): Padding distance from the edge. Defaults to 40.0.
+ color (str, optional): Text color. Defaults to "black".
+ weight (str, optional): Font weight (e.g., "bold", "normal"). Defaults to "bold".
"""
- if not mapping:
- return
-
- font_size = int(text_size) if text_size is not None else int(self.style.font_size)
- font_family = getattr(self.style, "font_family", "Arial")
-
- # --- Size estimation (SVG has no font metrics here, so we approximate) ---
- # Typical monospace-ish heuristic: average character ~0.6*font_size
- def est_width(s: str) -> float:
- return 0.6 * font_size * len(str(s))
-
- max_label_w = max([est_width(title)] + [est_width(k) for k in mapping.keys()])
- content_w = symbol_size + 8 + max_label_w
- n_rows = len(mapping)
- title_h = font_size + 4
- row_h = font_size + row_gap
- content_h = title_h + (n_rows * row_h)
- box_w = content_w + 2 * box_padding
- box_h = content_h + 2 * box_padding
-
w, h = self.style.width, self.style.height
+ tx, ty = 0, 0
+ if position == "top": ty = -h/2 + pad
+ elif position == "bottom": ty = h/2 - pad
+ elif position == "left": tx = -w/2 + pad
+ elif position == "right": tx = w/2 - pad
+ self.drawing.append(draw.Text(
+ text, font_size, tx, ty, fill=color, font_weight=weight,
+ font_family=self.style.font_family, text_anchor="middle", dominant_baseline="middle"
+ ))
- # --- Anchor (top-left of legend box) ---
- if isinstance(position, tuple) and len(position) == 2:
- x0, y0 = position
- elif position == "top-left":
- x0, y0 = -w / 2 + padding, -h / 2 + padding
- elif position == "top-right":
- x0, y0 = w / 2 - padding - box_w, -h / 2 + padding
- elif position == "bottom-left":
- x0, y0 = -w / 2 + padding, h / 2 - padding - box_h
- elif position == "bottom-right":
- x0, y0 = w / 2 - padding - box_w, h / 2 - padding - box_h
- else:
- x0, y0 = -w / 2 + padding, -h / 2 + padding
-
- # Background box
- self.d.append(
- draw.Rectangle(
- x0,
- y0,
- box_w,
- box_h,
- fill=box_fill,
- opacity=box_opacity,
- stroke=box_stroke,
- stroke_width=box_stroke_width,
- )
- )
-
- # Title
- tx = x0 + box_padding
- ty = y0 + box_padding + font_size
- self.d.append(
- draw.Text(
- title,
- font_size + 1,
- tx,
- ty,
- font_family=font_family,
- font_weight="bold",
- fill="black",
- )
- )
-
- # Rows
- y = ty + (font_size + row_gap)
- sym_x = x0 + box_padding + symbol_size / 2
- text_x = x0 + box_padding + symbol_size + 8
- for label, color in mapping.items():
- if symbol == "circle":
- self.d.append(
- draw.Circle(sym_x, y - font_size * 0.35, symbol_size / 2, fill=color, stroke="black", stroke_width=0.5)
- )
- elif symbol == "square":
- self.d.append(
- draw.Rectangle(
- sym_x - symbol_size / 2,
- y - font_size * 0.85,
- symbol_size,
- symbol_size,
- fill=color,
- stroke="black",
- stroke_width=0.5,
- )
- )
- elif symbol == "line":
- self.d.append(
- draw.Line(
- sym_x - symbol_size / 2,
- y - font_size * 0.5,
- sym_x + symbol_size / 2,
- y - font_size * 0.5,
- stroke=color,
- stroke_width=2,
- )
- )
-
- self.d.append(draw.Text(str(label), font_size, text_x, y, font_family=font_family, fill="black"))
- y += row_h
-
-
- def add_scale_bar(
- self,
- length: float,
- label: str | None = None,
- x: float | None = None,
- y: float | None = None,
- stroke: str = "black",
- stroke_width: float = 2.0,
- tick_size: float = 6.0,
- font_size: int | None = None,
- font_family: str | None = None,
- padding: float = 10.0,
- ) -> None:
- """Add a simple scale bar (tree-length legend).
-
- Parameters
- ----------
- length
- Length in *tree units* (same units used to scale branches).
- label
- Text label. If None, uses the numeric length.
- x, y
- Anchor position (left end) in drawing coordinates. If omitted, places the
- bar near the bottom-left with padding.
- """
- px = float(length) * float(self.sf)
- if label is None:
- label = str(length)
-
- if x is None:
- x = -float(self.style.width) / 2.0 + float(padding)
- if y is None:
- y = float(self.style.height) / 2.0 - float(padding)
-
- fs = int(font_size) if font_size is not None else int(self.style.font_size)
- ff = font_family if font_family is not None else self.style.font_family
-
- # main bar
- self.d.append(draw.Line(x, y, x + px, y, stroke=stroke, stroke_width=stroke_width))
- # end ticks
- self.d.append(draw.Line(x, y - tick_size / 2.0, x, y + tick_size / 2.0, stroke=stroke, stroke_width=stroke_width))
- self.d.append(draw.Line(x + px, y - tick_size / 2.0, x + px, y + tick_size / 2.0, stroke=stroke, stroke_width=stroke_width))
-
- # label above the bar
- self.d.append(draw.Text(label, fs, x + px / 2.0, y - tick_size - 2, center=True, font_family=ff))
-
- def _leaf_xy(self, leaf, offset: float = 0.0) -> tuple[float, float]:
- """Return (x, y) coordinates for a leaf.
-
- Subclasses must implement this. The ``offset`` parameter is in pixels and
- should move the position away from the leaf tip (e.g., to the right for
- vertical trees, outward for radial trees).
- """
- raise NotImplementedError
-
-
- def add_leaf_shapes(
- self,
- leaves,
- shape: str = "circle",
- fill: str = "blue",
- size: float = 10,
- stroke: str | None = None,
- stroke_width: float = 1,
- offset: float = 0.0,
- rotation: float = 0.0,
- orient: str | None = None, # NEW: "radial" or "tangent" (mainly for radial)
- opacity: float = 1.0,
- ):
- if leaves is None:
- return
-
- leaf_nodes = []
- for item in leaves:
- if isinstance(item, str):
- try:
- leaf_nodes.append(self.t & item) # ete3 lookup by name
- except Exception:
- continue
- else:
- leaf_nodes.append(item)
-
- # ensure layout exists once if needed
- if leaf_nodes:
- if not hasattr(leaf_nodes[0], "rad") and not hasattr(leaf_nodes[0], "coordinates"):
- if hasattr(self, "_calculate_layout"):
- self._calculate_layout()
-
- for leaf in leaf_nodes:
- x, y = self._leaf_xy(leaf, offset=float(offset))
-
- rot = float(rotation)
- if orient is not None:
- o = str(orient).lower().strip()
-
- # Use the actual rendered vector (x,y) in the drawing coordinate system.
- # SVG has y pointing DOWN, so atan2(y, x) yields a clockwise angle.
- a = math.degrees(math.atan2(y, x))
-
- # Our triangle path is "pointing up" when rotation=0,
- # so to point along direction angle 'a' we add +90 degrees.
- if o == "radial":
- rot = a + 90.0
- elif o == "tangent":
- rot = a + 180.0
-
- self._draw_shape_at(
- x=x, y=y,
- shape=shape,
- fill=fill,
- size=size,
- stroke=stroke,
- stroke_width=stroke_width,
- rotation=rot,
- opacity=opacity,
- )
-
- def _node_xy(self, node) -> tuple[float, float]:
- """Subclasses must implement: x,y coordinates of any node."""
- raise NotImplementedError
-
- def _edge_point(self, child, where: float) -> tuple[float, float, float]:
+ def add_text(self, text, x, y, font_size=12, color="black", weight="normal", text_anchor="start", dominant_baseline="middle", rotation=0):
"""
- Default edge interpolation: straight line from parent to child.
- Returns (x,y,angle_degrees_along_edge).
- Subclasses can override (vertical should).
+ Adds arbitrary text at specific Cartesian (x, y) coordinates.
+
+ Args:
+ text (str): The text content.
+ x (float): X-coordinate.
+ y (float): Y-coordinate.
+ font_size (int, optional): Font size. Defaults to 12.
+ color (str, optional): CSS color string. Defaults to "black".
+ weight (str, optional): Font weight. Defaults to "normal".
+ text_anchor (str, optional): Horizontal alignment ("start", "middle", "end"). Defaults to "start".
+ dominant_baseline (str, optional): Vertical alignment. Defaults to "middle".
+ rotation (float, optional): Rotation angle in degrees. Defaults to 0.
"""
- parent = child.up
- x0, y0 = self._node_xy(parent)
- x1, y1 = self._node_xy(child)
-
- t = max(0.0, min(1.0, float(where)))
- x = x0 + (x1 - x0) * t
- y = y0 + (y1 - y0) * t
-
- ang = math.degrees(math.atan2(y1 - y0, x1 - x0))
- return x, y, ang
+ transform = f"rotate({rotation}, {x}, {y})" if rotation != 0 else None
+ self.drawing.append(draw.Text(
+ text, font_size, x, y, fill=color, font_weight=weight,
+ font_family=self.style.font_family, text_anchor=text_anchor,
+ dominant_baseline=dominant_baseline, transform=transform
+ ))
- def _where_from_time(self, node, t: float) -> float:
+ def save_svg(self, outpath, rotation=0):
"""
- Map absolute event time t to a [0,1] position along the incoming edge of `node`.
- Requires node.up.time_from_origin and node.time_from_origin (Zombi parser provides these).
- """
- parent = node.up
- if parent is None:
- return 0.0
-
- t0 = float(getattr(parent, "time_from_origin", 0.0))
- t1 = float(getattr(node, "time_from_origin", t0))
- denom = (t1 - t0) if abs(t1 - t0) > 1e-12 else 1.0
+ Exports the current drawing to an SVG file.
- w = (float(t) - t0) / denom
- if w < 0.0:
- return 0.0
- if w > 1.0:
- return 1.0
- return w
-
- def add_title(
- self,
- text: str,
- fontsize: int = 24,
- position: str = "top",
- pad: float = 40.0,
- rotation: float = 0.0,
- color: str = "black",
- font_weight: str = "bold"
- ) -> None:
+ Args:
+ outpath (str or Path): The destination file path.
+ rotation (float, optional): Global rotation to apply to the entire canvas output. Defaults to 0.
"""
- Adds a title to the canvas relative to the edges.
+ self._pre_flight_check()
+ outpath = Path(outpath)
+ outpath.parent.mkdir(parents=True, exist_ok=True)
+ svg_content = self._get_rotated_svg_content(rotation)
+ with open(outpath, "w", encoding="utf-8") as f:
+ f.write(svg_content)
- Parameters
- ----------
- text: The string to display.
- fontsize: Size of the font.
- position: "top", "bottom", "left", or "right".
- pad: Padding from the edge of the canvas.
- rotation: Degrees to rotate the text.
- color: Text color.
- font_weight: "normal" or "bold".
+ def save_png(self, outpath, dpi=300, scale=None, rotation=0):
"""
- w, h = self.style.width, self.style.height
+ Exports the current drawing to a PNG file.
- # Default center of canvas
- tx, ty = 0, 0
+ Requires the `cairosvg` library to be installed.
- # Calculate anchor based on position
- if position == "top":
- ty = -h / 2 + pad
- elif position == "bottom":
- ty = h / 2 - pad
- elif position == "left":
- tx = -w / 2 + pad
- elif position == "right":
- tx = w / 2 - pad
+ Args:
+ outpath (str or Path): The destination file path.
+ dpi (int, optional): Dots per inch for resolution. Defaults to 300.
+ scale (float, optional): Scaling factor. If None, calculated from DPI. Defaults to None.
+ rotation (float, optional): Global rotation to apply to the output. Defaults to 0.
- # Create the text element
- # Note: drawsvg origin='center' is used here
- title_obj = draw.Text(
- text,
- fontsize,
- tx,
- ty,
- fill=color,
- font_family=getattr(self.style, "font_family", "Arial"),
- font_weight=font_weight,
- text_anchor="middle",
- dominant_baseline="middle",
- transform=f"rotate({rotation},{tx},{ty})" if rotation != 0 else None
+ Raises:
+ ImportError: If `cairosvg` is not installed.
+ """
+ if cairosvg is None:
+ raise ImportError("PNG export requires 'cairosvg'.")
+ self._pre_flight_check()
+ outpath = Path(outpath)
+ outpath.parent.mkdir(parents=True, exist_ok=True)
+ if scale is None: scale = dpi / 72.0
+ svg_content = self._get_rotated_svg_content(rotation)
+ cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=str(outpath), scale=scale)
+
+ def _get_rotated_svg_content(self, rotation):
+ """Generates the SVG content string, applying a global rotation if needed."""
+ if rotation == 0: return self.drawing.as_svg()
+ original_svg = self.drawing.as_svg()
+ original_svg = re.sub(r'<\?xml.*?\?>|', '', original_svg)
+ w, h = self.style.width, self.style.height
+ rad = math.radians(rotation)
+ new_w = abs(w * math.cos(rad)) + abs(h * math.sin(rad))
+ new_h = abs(w * math.sin(rad)) + abs(h * math.cos(rad))
+ return (
+ f''
)
-
- self.d.append(title_obj)
-
- def add_branch_shapes(
- self,
- specs,
- default_where: float = 0.5,
- orient: str | None = None,
- offset: float = 0.0,
- **kwargs # <--- Added to handle extra arguments like stroke_color
- ) -> None:
+ def add_categorical_legend(self, palette, title="Legend", x=None, y=None, font_size=14, r=6):
"""
- Add shapes on branches.
-
- Each spec dict can include:
- branch (str|node), where, shape, fill, size, stroke, stroke_width, rotation, opacity
+ Adds a categorical legend (key-value pairs of colors) to the canvas.
+
+ Args:
+ palette (dict): Mapping of {label: color_string}.
+ title (str, optional): Title for the legend. Defaults to "Legend".
+ x (float, optional): Top-left X coordinate. Defaults to calculated position.
+ y (float, optional): Top-left Y coordinate. Defaults to calculated position.
+ font_size (int, optional): Text size. Defaults to 14.
+ r (float, optional): Radius of the colored marker dots. Defaults to 6.
"""
- # Accept either list[dict] or a DataFrame-like object (e.g. pandas.DataFrame)
- if hasattr(specs, "to_dict") and hasattr(specs, "columns"):
- specs = specs.to_dict(orient="records")
-
- for s in specs:
- br = s.get("branch", None)
- if br is None:
- continue
-
- # resolve node
- if isinstance(br, str):
- try:
- child = self.t & br
- except Exception:
- continue
- else:
- child = br
-
- if child.up is None:
- continue # no parent edge
-
- # Determine position along the branch
- if "where" in s and s.get("where") is not None:
- where = float(s.get("where"))
- elif "time" in s and s.get("time") is not None and hasattr(child, "time_from_origin") and hasattr(child.up, "time_from_origin"):
- where = float(self._where_from_time(child, float(s.get("time"))))
- else:
- where = float(default_where)
-
- x, y, edge_ang = self._edge_point(child, where=where)
-
- # optional perpendicular offset
- if offset != 0.0:
- perp = edge_ang + 90.0
- x += float(offset) * math.cos(math.radians(perp))
- y += float(offset) * math.sin(math.radians(perp))
-
- # Handle rotation
- rot = float(s.get("rotation", 0.0))
- if orient is not None:
- o = str(orient).lower().strip()
- if o == "along":
- rot = edge_ang
- elif o == "perp":
- rot = edge_ang + 90.0
-
- # Draw the shape using internal helper
- self._draw_shape_at(
- x=x, y=y,
- shape=s.get("shape", "circle"),
- fill=s.get("fill", "blue"),
- size=float(s.get("size", 10)),
- stroke=s.get("stroke", None),
- stroke_width=float(s.get("stroke_width", 1)),
- rotation=rot,
- opacity=float(s.get("opacity", 1.0)),
- )
-
- def add_colorbar(
- self,
- vmin: float,
- vmax: float,
- low_color: str = "#f7fbff",
- high_color: str = "#08306b",
- label: str | None = None,
- ticks: list[float] | None = None,
- n_steps: int = 60,
- bar_width: float = 18.0,
- bar_height: float = 180.0,
- margin: float = 18.0,
- x_offset: float = 0.0,
- y_offset: float = 0.0,
- label_pad: float = 10.0,
- tick_pad: float = 12.0,
- font_size: int | None = None,
- font_family: str | None = None,
- stroke: str = "black",
- stroke_width: float = 1.0,
- ):
- """Add a vertical colorbar to the drawing (top-right by default).
-
- You can reposition with x_offset / y_offset (in px), relative to top-right anchor.
- Coordinates assume drawsvg Drawing(origin='center').
+ if x is None: x = -self.style.width / 2 + 30
+ if y is None: y = -self.style.height / 2 + 30
+ self.drawing.append(draw.Text(title, font_size + 2, x, y, font_weight="bold",
+ font_family=self.style.font_family, text_anchor="start"))
+ curr_y = y + font_size * 1.5
+ for label, color in palette.items():
+ self.drawing.append(draw.Circle(x + r, curr_y, r, fill=color))
+ self.drawing.append(draw.Text(str(label), font_size, x + r*2.5, curr_y,
+ font_family=self.style.font_family, text_anchor="start", dominant_baseline="middle"))
+ curr_y += font_size * 1.4
+
+ def add_color_bar(self, low_color, high_color, vmin, vmax, title="", x=None, y=None, width=100, height=15, font_size=12):
"""
- def _hex_to_rgb(h: str) -> tuple[int, int, int]:
- h = h.lstrip("#")
- return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
-
- def _rgb_to_hex(rgb: tuple[int, int, int]) -> str:
- return "#{:02x}{:02x}{:02x}".format(*rgb)
-
- def _lerp(a: float, b: float, t: float) -> float:
- return a + (b - a) * t
-
- c0 = _hex_to_rgb(low_color)
- c1 = _hex_to_rgb(high_color)
-
- def _lerp_color(t: float) -> str:
- t = 0.0 if t < 0.0 else (1.0 if t > 1.0 else t)
- r = int(_lerp(c0[0], c1[0], t))
- g = int(_lerp(c0[1], c1[1], t))
- b = int(_lerp(c0[2], c1[2], t))
- return _rgb_to_hex((r, g, b))
-
- fs = int(font_size) if font_size is not None else int(self.style.font_size)
- ff = font_family if font_family is not None else self.style.font_family
-
- # anchor top-right (origin is center)
- x0 = float(self.style.width) / 2.0 - float(margin) - float(bar_width) + float(x_offset)
- y_top = -float(self.style.height) / 2.0 + float(margin) + float(y_offset)
-
- # outline box
- self.d.append(draw.Rectangle(
- x0, y_top, bar_width, bar_height,
- fill="none", stroke=stroke, stroke_width=stroke_width
- ))
-
- # gradient fill as stacked rectangles
- steps = max(2, int(n_steps))
- step_h = float(bar_height) / steps
- for i in range(steps):
- t = i / (steps - 1)
- fill = _lerp_color(1.0 - t) # high at top
- y = y_top + i * step_h
- self.d.append(draw.Rectangle(x0, y, bar_width, step_h + 0.5, fill=fill, stroke="none"))
-
- # default ticks
- if ticks is None:
- ticks = [vmin, (vmin + vmax) / 2.0, vmax]
-
- # tick marks + labels (right side)
- for tv in ticks:
- frac = 0.0 if vmax == vmin else (float(tv) - float(vmin)) / (float(vmax) - float(vmin))
- frac = 0.0 if frac < 0.0 else (1.0 if frac > 1.0 else frac)
-
- y = y_top + (1.0 - frac) * float(bar_height)
- self.d.append(draw.Line(
- x0 + bar_width, y,
- x0 + bar_width + 6, y,
- stroke=stroke, stroke_width=stroke_width
- ))
- self.d.append(draw.Text(
- str(tv), fs,
- x0 + bar_width + float(tick_pad), y + fs * 0.35,
- font_family=ff
- ))
-
- # label (above)
- if label:
- self.d.append(draw.Text(
- label, fs,
- x0 + bar_width / 2.0, y_top - float(label_pad),
- center=True, font_family=ff
- ))
-
- def add_node_shapes(
- self,
- nodes,
- shape: str = "circle",
- fill: str = "red",
- size: float = 10.0,
- stroke: str | None = None,
- stroke_width: float = 1.0,
- rotation: float = 0.0,
- dx: float = 0.0,
- dy: float = 0.0,
- missing: str = "ignore", # "ignore" or "raise"
- ) -> None:
- """Draw shapes centered on (ancestral) nodes.
-
- Parameters
- ----------
- nodes
- Either:
- - list of node names (str) or ete3 Node objects
- - OR a list of dict specs with keys:
- {"node": , "shape": ..., "fill": ..., "size": ..., "stroke": ..., "stroke_width": ..., "rotation": ..., "dx": ..., "dy": ...}
- shape
- "circle", "square", or "triangle" (ignored when using dict specs per-node).
- rotation
- Degrees. For triangles/squares; circle ignores it.
- dx, dy
- Pixel offsets to nudge the marker.
- missing
- What to do if a node name is not found: "ignore" or "raise".
+ Adds a continuous linear gradient color bar to the canvas.
+
+ Args:
+ low_color (str): Color corresponding to the minimum value.
+ high_color (str): Color corresponding to the maximum value.
+ vmin (float): The numeric value corresponding to low_color.
+ vmax (float): The numeric value corresponding to high_color.
+ title (str, optional): Title text above the bar. Defaults to "".
+ x (float, optional): Top-left X coordinate. Defaults to calculated position.
+ y (float, optional): Top-left Y coordinate. Defaults to calculated position.
+ width (float, optional): Width of the color bar. Defaults to 100.
+ height (float, optional): Height of the color bar. Defaults to 15.
+ font_size (int, optional): Font size for labels. Defaults to 12.
"""
- # allow list[dict] specs
- if isinstance(nodes, list) and nodes and isinstance(nodes[0], dict):
- for s in nodes:
- n = s.get("node", None)
- if n is None:
- continue
- self._add_one_node_shape(
- node=n,
- shape=str(s.get("shape", shape)),
- fill=str(s.get("fill", fill)),
- size=float(s.get("size", size)),
- stroke=s.get("stroke", stroke),
- stroke_width=float(s.get("stroke_width", stroke_width)),
- rotation=float(s.get("rotation", rotation)),
- dx=float(s.get("dx", dx)),
- dy=float(s.get("dy", dy)),
- missing=missing,
- )
- return
-
- # uniform style for all nodes
- for n in nodes:
- self._add_one_node_shape(
- node=n,
- shape=shape,
- fill=fill,
- size=size,
- stroke=stroke,
- stroke_width=stroke_width,
- rotation=rotation,
- dx=dx,
- dy=dy,
- missing=missing,
- )
-
-
- def _add_one_node_shape(
- self,
- node,
- shape: str,
- fill: str,
- size: float,
- stroke: str | None,
- stroke_width: float,
- rotation: float,
- dx: float,
- dy: float,
- missing: str,
- ) -> None:
- """Internal helper: draw one node marker centered at node coordinates."""
- # resolve node
- n = node
- if isinstance(node, str):
- hits = self.t.search_nodes(name=node)
- if not hits:
- if missing == "raise":
- raise ValueError(f"Node '{node}' not found in tree.")
- return
- n = hits[0]
-
- x, y = self._node_xy(n)
- x += float(dx)
- y += float(dy)
-
- shp = str(shape).lower()
- half = float(size) / 2.0
-
- common = {"fill": fill}
- if stroke is not None:
- common["stroke"] = stroke
- common["stroke_width"] = float(stroke_width)
-
- if shp == "circle":
- self.d.append(draw.Circle(x, y, half, **common))
- return
-
- if shp == "square":
- rect = draw.Rectangle(x - half, y - half, float(size), float(size), **common)
- if rotation:
- rect.args["transform"] = f"rotate({float(rotation)},{x},{y})"
- self.d.append(rect)
- return
-
- if shp == "triangle":
- # Up-pointing triangle centered at (x,y)
- p = draw.Path(**common)
- p.M(x, y - half).L(x - half, y + half).L(x + half, y + half).Z()
- if rotation:
- p.args["transform"] = f"rotate({float(rotation)},{x},{y})"
- self.d.append(p)
- return
-
- raise ValueError(f"Unknown shape '{shape}'. Use: circle, square, triangle.")
+ if x is None: x = -self.style.width / 2 + 30
+ if y is None: y = self.style.height / 2 - 60
+ gid = generate_id("cb_grad")
+ grad = draw.LinearGradient(x, y, x + width, y, id=gid)
+ grad.add_stop(0, low_color); grad.add_stop(1, high_color)
+ self.drawing.append(grad)
+ if title:
+ self.drawing.append(draw.Text(title, font_size, x, y - 10, font_weight="bold", text_anchor="start"))
+ self.drawing.append(draw.Rectangle(x, y, width, height, fill=grad, stroke="black", stroke_width=0.5))
+ self.drawing.append(draw.Text(f"{vmin:.2g}", font_size - 2, x, y + height + 12, text_anchor="start"))
+ self.drawing.append(draw.Text(f"{vmax:.2g}", font_size - 2, x + width, y + height + 12, text_anchor="end"))
diff --git a/src/phylustrator/drawing/radial.py b/src/phylustrator/drawing/radial.py
index 7971da8..022f362 100644
--- a/src/phylustrator/drawing/radial.py
+++ b/src/phylustrator/drawing/radial.py
@@ -1,574 +1,556 @@
-from platform import node
import drawsvg as draw
import math
-import random
+import os
+import base64
from .base import BaseDrawer
-
-def radial_converter(degree, radius, rotation=0):
- theta = math.radians(degree + rotation)
- return radius * math.cos(theta), radius * math.sin(theta)
+from ..utils import polar_to_cartesian, generate_id, lerp_color
class RadialTreeDrawer(BaseDrawer):
+ """
+ Drawer class for rendering phylogenetic trees in a circular (radial) layout.
+
+ Inherits from BaseDrawer and implements radial-specific geometry calculations
+ where nodes are positioned by angle and radius.
+ """
def __init__(self, tree, style=None):
+ """
+ Initializes the RadialTreeDrawer and performs the initial layout calculation.
+
+ Args:
+ tree (ete3.TreeNode): The tree object to be visualized.
+ style (TreeStyle, optional): Custom style configuration.
+ """
super().__init__(tree, style)
self._calculate_layout()
def _rot_ang(self, ang_deg: float) -> float:
- # layout angle + style rotation (degrees)
- return float(ang_deg) + float(getattr(self.style, "rotation", 0.0))
-
+ """Applies the global rotation offset to a given angle."""
+ return float(ang_deg) + float(self.style.rotation)
- def _node_xy(self, node):
+ def _node_xy(self, node) -> tuple[float, float]:
+ """Calculates Cartesian (x, y) coordinates for a node in polar space."""
if not (hasattr(node, "rad") and hasattr(node, "angle")):
self._calculate_layout()
- r = float(node.rad)
- ang = self._rot_ang(node.angle)
- th = math.radians(ang)
- return (r * math.cos(th), r * math.sin(th))
+ return polar_to_cartesian(self._rot_ang(node.angle), node.rad)
-
- def _leaf_xy(self, leaf, offset: float = 0.0):
+ def _leaf_xy(self, leaf, offset: float = 0.0) -> tuple[float, float]:
+ """Calculates Cartesian coordinates for a leaf tip with a radial offset."""
if not (hasattr(leaf, "rad") and hasattr(leaf, "angle")):
self._calculate_layout()
- r = float(leaf.rad) + float(offset)
- ang = self._rot_ang(leaf.angle)
- th = math.radians(ang)
- return (r * math.cos(th), r * math.sin(th))
-
+ return polar_to_cartesian(self._rot_ang(leaf.angle), leaf.rad + float(offset))
- def _edge_point(self, child, where: float):
- """
- Place a point along the drawn branch centerline (radial segment at CHILD angle).
- where=0 => (r=parent.rad, angle=child.angle)
- where=1 => (r=child.rad, angle=child.angle)
- """
+ def _edge_point(self, child, where: float) -> tuple[float, float, float]:
+ """Finds a point along the radial branch leading to a child node."""
parent = child.up
if parent is None:
- x, y = self._node_xy(child)
- return x, y, 0.0
-
- if not (hasattr(parent, "rad") and hasattr(child, "rad") and hasattr(child, "angle")):
- self._calculate_layout()
-
- r0 = float(parent.rad)
- r1 = float(child.rad)
+ return (*self._node_xy(child), self._rot_ang(child.angle))
+ r_p, r_c = float(parent.rad), float(child.rad)
t = max(0.0, min(1.0, float(where)))
- r = r0 + (r1 - r0) * t
-
+ r = r_p + (r_c - r_p) * t
ang = self._rot_ang(child.angle)
- th = math.radians(ang)
- x = r * math.cos(th)
- y = r * math.sin(th)
-
- # orientation "along" the branch direction
- edge_ang = ang if (r1 - r0) >= 0 else (ang + 180.0)
- return x, y, edge_ang
+ x, y = polar_to_cartesian(ang, r)
+ return x, y, ang
def _calculate_layout(self):
- max_dist = 0
+ """
+ Computes polar coordinates (radius and angle) for all nodes in the tree.
+
+ This method is called automatically during initialization or before drawing.
+ """
+ # 1. Radial Scaling (Distances)
+ max_dist = 0.0
for n in self.t.traverse("preorder"):
n.dist_to_root = n.up.dist_to_root + n.dist if not n.is_root() else getattr(n, "dist", 0.0)
max_dist = max(max_dist, n.dist_to_root)
self.total_tree_depth = max_dist
- self.sf = self.style.radius / max_dist if max_dist > 0 else 1
+ self.sf = float(self.style.radius) / max_dist if max_dist > 0 else 1.0
+
+ for n in self.t.traverse():
+ n.rad = n.dist_to_root * self.sf
+
+ # 2. Angular Scaling (Leaves)
leaves = self.t.get_leaves()
- self.angle_step = self.style.degrees / len(leaves)
+ span = float(self.style.degrees)
+ angle_step = span / max(len(leaves), 1)
for i, leaf in enumerate(leaves):
- leaf.angle = i * self.angle_step
+ leaf.angle = i * angle_step
+
+ # 3. Internal Centerings
for n in self.t.traverse("postorder"):
if not n.is_leaf():
- n.angle = (n.children[0].angle + n.children[-1].angle) / 2
- n.rad = n.dist_to_root * self.sf
- n.xy = radial_converter(n.angle, n.rad, self.style.rotation)
+ if n.children:
+ n.angle = sum(c.angle for c in n.children) / len(n.children)
+ else:
+ n.angle = 0.0
+ self._layout_calculated = True
- def draw(self, branch2color=None, hide_radial_lines=None):
- """
- Standard radial drawing loop.
- :param branch2color: Dictionary {node_object: color_string}
- :param hide_radial_lines: List of nodes to skip parent connections
+ def draw(self, branch2color=None):
"""
- hidden_set = set(hide_radial_lines) if hide_radial_lines else set()
+ Draws the main tree skeleton.
+ Connectors are drawn as circular arcs and branches as radial lines.
+
+ Args:
+ branch2color (dict, optional): Dictionary mapping `ete3.TreeNode` objects to
+ CSS color strings. Used to color specific branches.
+ """
+ self._pre_flight_check()
for n in self.t.traverse("postorder"):
- # Resolve Color
- b_color = self.style.branch_color
- if branch2color and n in branch2color:
- b_color = branch2color[n]
- if b_color == "None": continue
+ color = branch2color.get(n, self.style.branch_color) if branch2color else self.style.branch_color
+ x, y = self._node_xy(n)
+ if not n.is_root():
+ parent = n.up
+ px, py = self._node_xy(parent)
+ start_ang, end_ang = self._rot_ang(parent.angle), self._rot_ang(n.angle)
+ if abs(start_ang - end_ang) > 0.001:
+ ax, ay = polar_to_cartesian(end_ang, parent.rad)
+ path = draw.Path(stroke=color, stroke_width=self.style.branch_stroke_width, fill="none")
+ sweep = 1 if end_ang > start_ang else 0
+ path.M(px, py).A(parent.rad, parent.rad, 0, 0, sweep, ax, ay)
+ self.drawing.append(path)
+ self.drawing.append(draw.Line(ax, ay, x, y, stroke=color, stroke_width=self.style.branch_stroke_width))
+ else:
+ self.drawing.append(draw.Line(px, py, x, y, stroke=color, stroke_width=self.style.branch_stroke_width))
+
+ r_v = self.style.leaf_r if n.is_leaf() else self.style.node_r
+ if r_v > 0:
+ fill = self.style.leaf_color if n.is_leaf() else color
+ self.drawing.append(draw.Circle(x, y, r_v, fill=fill))
- x, y = n.xy
+ def highlight_clade(self, node, color="lightblue", opacity=0.3, padding=10):
+ """
+ Draws a shaded "donut wedge" background behind a specific clade.
+
+ Args:
+ node (ete3.TreeNode): The root node of the clade to highlight.
+ color (str, optional): Fill color. Defaults to "lightblue".
+ opacity (float, optional): Fill opacity (0.0 to 1.0). Defaults to 0.3.
+ padding (float, optional): Radial padding in pixels. Defaults to 10.
+ """
+ self._pre_flight_check()
+ leaves = node.get_leaves()
+ angles = [l.angle for l in leaves]
+ start_ang, end_ang = self._rot_ang(min(angles)), self._rot_ang(max(angles))
+ r_i, r_o = node.rad - padding, max(l.rad for l in leaves) + padding
+
+ p = draw.Path(fill=color, fill_opacity=opacity)
+ p1x, p1y = polar_to_cartesian(start_ang, r_i)
+ p2x, p2y = polar_to_cartesian(end_ang, r_i)
+ p3x, p3y = polar_to_cartesian(end_ang, r_o)
+ p4x, p4y = polar_to_cartesian(start_ang, r_o)
+ sweep = 1 if end_ang > start_ang else 0
+ p.M(p1x, p1y).A(r_i, r_i, 0, 0, sweep, p2x, p2y).L(p3x, p3y).A(r_o, r_o, 0, 0, 0 if sweep == 1 else 1, p4x, p4y).Z()
+ self.drawing.append(p)
+
+ def highlight_branch(self, node, color="red", stroke_width=None):
+ """
+ Overlays a thicker or colored arc/line on the branch leading to the specific node.
- # 1. Radial Line to Parent
- if not n.is_root() and n not in hidden_set:
- px, py = radial_converter(n.angle, n.up.rad, self.style.rotation)
- self.d.append(draw.Line(px, py, x, y, stroke=b_color, stroke_width=self.style.branch_size))
+ Args:
+ node (ete3.TreeNode): The target node.
+ color (str, optional): CSS color string. Defaults to "red".
+ stroke_width (float, optional): Thickness of the highlight. Defaults to 2x standard width.
+ """
+ if node.is_root(): return
+ sw = stroke_width or self.style.branch_stroke_width * 2
+ px, py = self._node_xy(node.up)
+ x, y = self._node_xy(node)
+ start_ang, end_ang = self._rot_ang(node.up.angle), self._rot_ang(node.angle)
+
+ if abs(start_ang - end_ang) > 0.001:
+ ax, ay = polar_to_cartesian(end_ang, node.up.rad)
+ path = draw.Path(stroke=color, stroke_width=sw, fill="none", stroke_linecap="round")
+ sweep = 1 if end_ang > start_ang else 0
+ path.M(px, py).A(node.up.rad, node.up.rad, 0, 0, sweep, ax, ay)
+ self.drawing.append(path)
+ self.drawing.append(draw.Line(ax, ay, x, y, stroke=color, stroke_width=sw, stroke_linecap="round"))
+ else:
+ self.drawing.append(draw.Line(px, py, x, y, stroke=color, stroke_width=sw, stroke_linecap="round"))
+
+ def gradient_branch(self, node, colors=("red", "blue"), stroke_width=None):
+ """
+ Applies a linear color gradient along a radial branch segment.
- # 2. Connector Arc and Nodes
- if not n.is_leaf():
- a1, a2 = n.children[0].angle, n.children[-1].angle
- p = draw.Path(stroke=b_color, stroke_width=self.style.branch_size, fill="none")
- sx, sy = radial_converter(a1, n.rad, self.style.rotation)
- ex, ey = radial_converter(a2, n.rad, self.style.rotation)
- p.M(sx, sy).A(n.rad, n.rad, 0, 0, 1, ex, ey)
- self.d.append(p)
-
- if not n.is_root():
- self.d.append(draw.Circle(x, y, self.style.node_size, fill=b_color))
- else:
- self.d.append(draw.Circle(x, y, self.style.leaf_size, fill=self.style.leaf_color))
+ Args:
+ node (ete3.TreeNode): The target node.
+ colors (tuple, optional): Tuple of (start_color, end_color). Defaults to ("red", "blue").
+ stroke_width (float, optional): Thickness of the branch. Defaults to style default.
+ """
+ if node.is_root(): return
+ sw = stroke_width or self.style.branch_stroke_width
+ px, py = self._node_xy(node.up)
+ x, y = self._node_xy(node)
+ gid = generate_id("grad")
+ grad = draw.LinearGradient(px, py, x, y, id=gid)
+ grad.add_stop(0, colors[0])
+ grad.add_stop(1, colors[1])
+ self.drawing.append(grad)
+ self.highlight_branch(node, color=grad, stroke_width=sw)
+
+ def add_leaf_names(self, font_size=None, color="black", padding=10):
+ """
+ Adds text labels to leaf tips, automatically rotated to match the radial angle.
- def add_ring(self, mapping, width=20, padding=10):
- r_in = self.style.radius + padding
- r_out = r_in + width
- for l in self.t.get_leaves():
- if l.name not in mapping: continue
- a1, a2 = l.angle - self.angle_step/2, l.angle + self.angle_step/2
- p = draw.Path(fill=mapping[l.name])
- s_i_x, s_i_y = radial_converter(a1, r_in, self.style.rotation)
- e_i_x, e_i_y = radial_converter(a2, r_in, self.style.rotation)
- s_o_x, s_o_y = radial_converter(a1, r_out, self.style.rotation)
- e_o_x, e_o_y = radial_converter(a2, r_out, self.style.rotation)
- p.M(s_o_x, s_o_y).A(r_out, r_out, 0, 0, 1, e_o_x, e_o_y).L(e_i_x, e_i_y).A(r_in, r_in, 0, 0, 0, s_i_x, s_i_y).Z()
- self.d.append(p)
-
- def add_leaf_names(self, padding=15):
+ Args:
+ font_size (int, optional): Font size in pixels. Defaults to style default.
+ color (str, optional): Text color. Defaults to "black".
+ padding (float, optional): Distance from leaf tip to text start. Defaults to 10.
+ """
+ fs = font_size or self.style.font_size
for l in self.t.get_leaves():
- angle = l.angle + self.style.rotation
- ang_mod = angle % 360.0
-
- # SVG y-axis is downward, so the "right side" includes angles near 360 as well
- right_side = (ang_mod <= 90.0) or (ang_mod >= 270.0)
-
- # place at leaf tip radius, not at a fixed style.radius
- x, y = radial_converter(l.angle, float(l.rad) + float(padding), self.style.rotation)
-
- rot = angle if right_side else (angle - 180.0)
- anchor = "start" if right_side else "end"
-
- self.d.append(
- draw.Text(
- l.name,
- self.style.font_size,
- x,
- y,
- transform=f"rotate({rot},{x},{y})",
- text_anchor=anchor,
- dominant_baseline="middle",
- font_family=self.style.font_family,
- )
- )
-
-
- def plot_transfers(
- self,
- transfers,
- mode="midpoint",
- curve_type="C",
- filter_below=0.1,
- use_gradient=True,
- gradient_colors=("purple", "orange"),
- color="orange",
- use_thickness=True,
- stroke_width=5,
- arc_intensity=40,
- opacity=0.6,
- ):
- """
- Plot transfers on a radial tree.
-
- mode="time" places endpoints at the correct event time along each endpoint branch
- using node.time_from_origin (from Zombi parser). Otherwise uses midpoint logic.
- """
- # Accept either list[dict] or a DataFrame-like object (e.g. pandas.DataFrame)
- if hasattr(transfers, "to_dict") and hasattr(transfers, "columns"):
- transfers = transfers.to_dict(orient="records")
-
- # Ensure layout exists
- any_node = next(self.t.traverse("preorder"))
- if not hasattr(any_node, "rad") or not hasattr(any_node, "angle"):
- self._calculate_layout()
+ ang = (self._rot_ang(l.angle)) % 360
+ x, y = polar_to_cartesian(ang, l.rad + padding)
+ text_rot, anchor = ang, "start"
+ if 90 < ang < 270:
+ text_rot += 180
+ anchor = "end"
+ self.drawing.append(draw.Text(l.name, fs, x, y, fill=color, font_family=self.style.font_family,
+ transform=f"rotate({text_rot}, {x}, {y})", text_anchor=anchor, dominant_baseline="middle"))
+
+ def add_node_names(self, font_size=None, color="gray", padding=5):
+ """
+ Adds text labels to internal node positions.
- name_to_node = {n.name: n for n in self.t.traverse()}
-
- def where_from_time(node, tt: float) -> float:
- parent = node.up
- if parent is None:
- return 0.0
- t0 = float(getattr(parent, "time_from_origin", 0.0))
- t1 = float(getattr(node, "time_from_origin", t0))
- denom = (t1 - t0) if abs(t1 - t0) > 1e-12 else 1.0
- w = (float(tt) - t0) / denom
- if w < 0.0:
- return 0.0
- if w > 1.0:
- return 1.0
- return w
+ Args:
+ font_size (int, optional): Font size. Defaults to style default * 0.8.
+ color (str, optional): Text color. Defaults to "gray".
+ padding (float, optional): Radial offset. Defaults to 5.
+ """
+ fs = font_size or self.style.font_size * 0.8
+ for n in self.t.traverse():
+ if not n.is_leaf() and n.name:
+ ang = (self._rot_ang(n.angle)) % 360
+ x, y = polar_to_cartesian(ang, n.rad + padding)
+ text_rot = ang + (180 if 90 < ang < 270 else 0)
+ self.drawing.append(draw.Text(n.name, fs, x, y, fill=color, font_family=self.style.font_family,
+ transform=f"rotate({text_rot}, {x}, {y})", text_anchor="middle", dominant_baseline="middle"))
+
+ def add_leaf_shapes(self, leaves, shape="circle", fill="blue", r=5.0, stroke=None, stroke_width=1.0, offset=0.0, rotation=0.0, opacity=1.0, orient=False):
+ """
+ Adds geometric markers next to specific leaf tips.
+
+ Args:
+ leaves (list): List of node names (str) or objects to mark.
+ shape (str, optional): Shape type ("circle", "square", "triangle"). Defaults to "circle".
+ fill (str, optional): Fill color. Defaults to "blue".
+ r (float, optional): Radius/Size of the shape. Defaults to 5.0.
+ stroke (str, optional): Border color. Defaults to None.
+ stroke_width (float, optional): Border width. Defaults to 1.0.
+ offset (float, optional): Radial distance offset from the leaf tip. Defaults to 0.0.
+ rotation (float, optional): Additional rotation for the shape. Defaults to 0.0.
+ opacity (float, optional): Opacity (0.0 to 1.0). Defaults to 1.0.
+ orient (bool, optional): If True, rotates shape to match the branch angle. Defaults to False.
+ """
+ self._pre_flight_check()
+ for item in leaves:
+ try:
+ node = self.t & item if isinstance(item, str) else item
+ ang = self._rot_ang(node.angle)
+ x, y = polar_to_cartesian(ang, node.rad + float(offset))
+ rot = rotation + (ang if orient else 0.0)
+ self._draw_shape_at(x, y, shape, fill, r, stroke, stroke_width, rot, opacity)
+ except: continue
+
+ def add_node_shapes(self, nodes, shape="circle", fill="red", r=5.0, stroke=None, stroke_width=1.0, rotation=0, dx=0, dy=0, orient=False):
+ """
+ Adds geometric markers centered on specific node positions.
+
+ Args:
+ nodes (list): List of node names/objects OR list of dicts with specific style per node.
+ shape (str, optional): Default shape. Defaults to "circle".
+ fill (str, optional): Default fill color. Defaults to "red".
+ r (float, optional): Default radius. Defaults to 5.0.
+ stroke (str, optional): Default stroke color. Defaults to None.
+ stroke_width (float, optional): Default stroke width. Defaults to 1.0.
+ rotation (float, optional): Default rotation. Defaults to 0.
+ dx (float, optional): Cartesian X offset. Defaults to 0.
+ dy (float, optional): Cartesian Y offset. Defaults to 0.
+ orient (bool, optional): If True, rotates shape to match the branch angle. Defaults to False.
+ """
+ self._pre_flight_check()
+ if isinstance(nodes, list) and nodes and isinstance(nodes[0], dict):
+ for s in nodes:
+ self.add_node_shapes([s.get("node")], s.get("shape", shape), s.get("fill", fill), s.get("r", r),
+ s.get("stroke", stroke), s.get("stroke_width", stroke_width), s.get("rotation", rotation), orient=orient)
+ return
+ for item in nodes:
+ try:
+ node = self.t.search_nodes(name=item)[0] if isinstance(item, str) else item
+ ang = self._rot_ang(node.angle)
+ x, y = polar_to_cartesian(ang, node.rad)
+ rot = rotation + (ang if orient else 0.0)
+ self._draw_shape_at(x + dx, y + dy, shape, fill, r, stroke, stroke_width, rot)
+ except: continue
+
+ def add_branch_shapes(self, specs, default_where=0.5, offset=0.0, orient=None):
+ """
+ Adds geometric markers along branches (e.g., to visualize events).
+ Args:
+ specs (list or DataFrame): Data definitions. Must contain 'branch' key/column.
+ default_where (float, optional): Position along branch (0.0 to 1.0). Defaults to 0.5.
+ offset (float, optional): Perpendicular offset from the branch line. Defaults to 0.0.
+ orient (str, optional): Rotation mode: "along" (matches branch), "perp" (90 deg to branch), or None.
+ """
+ self._pre_flight_check()
+ if hasattr(specs, "to_dict"): specs = specs.to_dict(orient="records")
+ for s in specs:
+ br = s.get("branch")
+ if not br: continue
+ try:
+ node = self.t & br if isinstance(br, str) else br
+ where = s.get("where", default_where)
+ x, y, ang = self._edge_point(node, where=where)
+ if offset != 0:
+ perp = math.radians(ang + 90)
+ x += offset * math.cos(perp)
+ y += offset * math.sin(perp)
+ rot = s.get("rotation", 0.0)
+ if orient == "along": rot = ang
+ elif orient == "perp": rot = ang + 90
+ r_val = float(s.get("r", s.get("size", 10.0) / 2.0))
+ self._draw_shape_at(x, y, s.get("shape", "circle"), s.get("fill", "blue"), r_val,
+ s.get("stroke"), s.get("stroke_width", 1.0), rot, s.get("opacity", 1.0))
+ except: continue
+
+ def plot_transfers(self, transfers, mode="midpoint", curve_type="C", filter_below=0.0, use_gradient=True,
+ gradient_colors=("purple", "orange"), color="orange", stroke_width=5.0, arc_intensity=50.0, opacity=0.6):
+ """
+ Plots curved HGT (Horizontal Gene Transfer) links between lineages in radial space.
+
+ Args:
+ transfers (list or DataFrame): List of dicts with 'from', 'to', 'freq' keys.
+ mode (str, optional): "time" (uses 'time' key for position) or "midpoint". Defaults to "midpoint".
+ curve_type (str, optional): Bezier curve type ("C" or "S"). Defaults to "C".
+ filter_below (float, optional): Minimum frequency to plot. Defaults to 0.0.
+ use_gradient (bool, optional): If True, colors fade from source to dest. Defaults to True.
+ gradient_colors (tuple, optional): (Start color, End color). Defaults to ("purple", "orange").
+ color (str, optional): Solid color if gradients are disabled. Defaults to "orange".
+ stroke_width (float, optional): Base thickness of the transfer lines. Defaults to 5.0.
+ arc_intensity (float, optional): Curvature strength (control point distance). Defaults to 50.0.
+ opacity (float, optional): Opacity of the transfer lines. Defaults to 0.6.
+ """
+ if hasattr(transfers, "to_dict"): transfers = transfers.to_dict(orient="records")
+ name2node = {n.name: n for n in self.t.traverse()}
+ self._pre_flight_check()
+ def get_where(node, t_ev):
+ p = node.up
+ if not p: return 0.0
+ t0, t1 = float(getattr(p, "time_from_origin", 0.0)), float(getattr(node, "time_from_origin", 0.0))
+ return max(0.0, min(1.0, (float(t_ev) - t0) / (t1 - t0))) if abs(t1 - t0) > 1e-12 else 0.5
+
for tr in transfers:
freq = float(tr.get("freq", 1.0))
- if freq < filter_below:
- continue
-
- src = name_to_node.get(tr.get("from"))
- dst = name_to_node.get(tr.get("to"))
- if src is None or dst is None:
- continue
-
- src_angle = float(src.angle)
- dst_angle = float(dst.angle)
-
- # Compute endpoints
- if (
- mode == "time"
- and tr.get("time") is not None
- and hasattr(src, "time_from_origin")
- and hasattr(dst, "time_from_origin")
- ):
- tt = float(tr["time"])
- w_src = where_from_time(src, tt)
- w_dst = where_from_time(dst, tt)
- sx, sy, _ = self._edge_point(src, w_src)
- ex, ey, _ = self._edge_point(dst, w_dst)
-
- # radii for control points
- src_r_mid = (sx * sx + sy * sy) ** 0.5
- dst_r_mid = (ex * ex + ey * ey) ** 0.5
-
+ if freq < filter_below: continue
+ src, dst = name2node.get(tr.get("from")), name2node.get(tr.get("to"))
+ if not src or not dst: continue
+ if mode == "time" and "time" in tr:
+ x_s, y_s, _ = self._edge_point(src, get_where(src, tr["time"]))
+ x_e, y_e, _ = self._edge_point(dst, get_where(dst, tr["time"]))
else:
- # midpoint fallback (old behavior)
- src_r = float(src.rad)
- dst_r = float(dst.rad)
- src_parent_r = float(src.up.rad) if src.up else (src_r - 20.0)
- dst_parent_r = float(dst.up.rad) if dst.up else (dst_r - 20.0)
-
- src_r_mid = (src_r + src_parent_r) / 2.0
- dst_r_mid = (dst_r + dst_parent_r) / 2.0
-
- sx, sy = radial_converter(src_angle, src_r_mid, self.style.rotation)
- ex, ey = radial_converter(dst_angle, dst_r_mid, self.style.rotation)
-
- width = (stroke_width * freq) if use_thickness else stroke_width
- path = draw.Path(stroke_width=width, fill="none", stroke_opacity=opacity)
-
+ x_s, y_s, _ = self._edge_point(src, 0.5); x_e, y_e, _ = self._edge_point(dst, 0.5)
+
+ path = draw.Path(stroke_width=stroke_width * freq, fill="none", stroke_opacity=opacity)
if use_gradient:
- grad_id = f"tr_grad_{random.randint(0, 999999)}"
- grad = draw.LinearGradient(sx, sy, ex, ey, id=grad_id)
- grad.add_stop(0, gradient_colors[0])
- grad.add_stop(1, gradient_colors[1])
- self.d.append(grad)
- path.args["stroke"] = grad
- else:
- path.args["stroke"] = color
-
- path.M(sx, sy)
-
- if curve_type.upper() == "S":
- c1x, c1y = radial_converter(src_angle, src_r_mid + arc_intensity, self.style.rotation)
- c2x, c2y = radial_converter(dst_angle, dst_r_mid - arc_intensity, self.style.rotation)
- path.C(c1x, c1y, c2x, c2y, ex, ey)
- else:
- c1x, c1y = radial_converter(src_angle, src_r_mid - arc_intensity, self.style.rotation)
- c2x, c2y = radial_converter(dst_angle, dst_r_mid - arc_intensity, self.style.rotation)
- path.C(c1x, c1y, c2x, c2y, ex, ey)
-
- self.d.append(path)
-
-
- def add_transfer_legend(
- self,
- title="Transfer Frequency",
- colors=("purple", "orange"),
- low=0.1,
- high=1.0,
- source_label="Source",
- arrival_label="Arrival",
- show_frequency=False,
- show_direction=True,
- margin=20,
- ):
- """Add a transfer legend.
-
- By default this shows *direction* (two solid colors): a "Source" swatch and an
- "Arrival" swatch. If you also want a frequency scale, set
- ``show_frequency=True``.
-
- Parameters
- ----------
- colors:
- Tuple (source_color, arrival_color). These match the gradient endpoints
- used by ``plot_transfers(..., gradient_colors=...)``.
- show_frequency:
- Draws a gradient bar + numeric low/high labels.
- show_direction:
- Draws two solid color swatches labelled source/arrival.
- """
- if not (show_frequency or show_direction):
- return
-
- font_size = 11
- num_font_size = 9
- sw = 14
- gap = 6
- pad_x = 10
- top_pad = 10
- bottom_pad = 10
- row_h = 18
- bar_h = 12
- bar_w = 110
-
- # Estimate legend width from label lengths (drawsvg doesn't expose text metrics).
- max_label_len = 0
- if show_direction:
- max_label_len = max(len(str(source_label)), len(str(arrival_label)))
- if show_frequency:
- max_label_len = max(max_label_len, len(str(title)))
- est_text_w = max_label_len * font_size * 0.60
- w = int(pad_x + sw + gap + est_text_w + pad_x)
- if show_frequency:
- w = max(w, pad_x + bar_w + pad_x)
-
- # Height: SVG y-axis increases downward.
- content_h = top_pad
- if show_frequency:
- # title + bar + low/high labels + spacing
- content_h += (font_size + 4) + bar_h + (num_font_size + 10) + 6
- if show_direction:
- content_h += (2 * row_h)
- content_h += bottom_pad
- box_h = content_h
-
- x = -self.style.width / 2 + 30
- y = self.style.height / 2 - margin - box_h
-
- # Background
- self.d.append(
- draw.Rectangle(x, y, w, box_h, fill="white", stroke="black", stroke_width=1, opacity=0.9)
- )
-
- cursor_y = y + top_pad + 2
-
- # Optional frequency scale
- if show_frequency:
- self.d.append(draw.Text(title, font_size, x + 10, cursor_y, font_family="sans-serif", font_weight="bold"))
- cursor_y += 10
-
- grad_id = f"legend_transfer_grad_{random.randint(0, 999999)}"
- grad = draw.LinearGradient(x + 10, cursor_y + bar_h / 2, x + 10 + bar_w, cursor_y + bar_h / 2, id=grad_id)
- grad.add_stop(0, colors[0])
- grad.add_stop(1, colors[1])
- self.d.append(grad)
- self.d.append(draw.Rectangle(x + 10, cursor_y, bar_w, bar_h, fill=grad))
-
- # low/high numeric labels
- self.d.append(draw.Text(f"{low}", num_font_size, x + 10, cursor_y + bar_h + 12, font_family="sans-serif"))
- self.d.append(draw.Text(f"{high}", num_font_size, x + 10 + bar_w - 15, cursor_y + bar_h + 12, font_family="sans-serif"))
- cursor_y += bar_h + 24
-
- # Direction swatches
- if show_direction:
- sw = 14
- self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[0]))
- self.d.append(draw.Text(source_label, font_size, x + 30, cursor_y + 11, font_family="sans-serif"))
- cursor_y += row_h
-
- self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[1]))
- self.d.append(draw.Text(arrival_label, font_size, x + 30, cursor_y + 11, font_family="sans-serif"))
-
- def add_ring_heatmap(
- self,
- values,
- width: float = 20.0,
- padding: float = 10.0,
- vmin: float | None = None,
- vmax: float | None = None,
- low_color: str = "#f7fbff",
- high_color: str = "#08306b",
- missing_color: str = "white",
- ):
- """Add a numeric heatmap ring around the radial tree.
-
- Parameters
- ----------
- values:
- Mapping leaf_name -> numeric value, OR a pandas Series, OR a DataFrame with columns
- ["leaf", "value"] (any column names are fine if you pass a dict instead).
- width:
- Ring thickness in px.
- padding:
- Distance outward from the tree radius.
- vmin, vmax:
- Normalization bounds. If None, computed from provided values.
- low_color, high_color:
- Gradient endpoints (hex colors).
- missing_color:
- Fill used if a leaf is missing from the values.
- """
- # ---- normalize inputs to dict[str, float] ----
- if hasattr(values, "to_dict") and not isinstance(values, dict):
- # pandas Series
- values = values.to_dict()
-
- if hasattr(values, "columns") and hasattr(values, "to_dict") and not isinstance(values, dict):
- # DataFrame-like: try common formats
- cols = list(values.columns)
- if len(cols) >= 2:
- leaf_col, val_col = cols[0], cols[1]
- values = dict(zip(values[leaf_col].astype(str), values[val_col].astype(float)))
- else:
- values = {}
-
- if not isinstance(values, dict):
- raise TypeError("values must be a dict, pandas Series, or a DataFrame with 2 columns.")
-
- # ---- helpers for color interpolation ----
- def _hex_to_rgb(h: str) -> tuple[int, int, int]:
- h = h.lstrip("#")
- return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16)
-
- def _rgb_to_hex(rgb: tuple[int, int, int]) -> str:
- return "#{:02x}{:02x}{:02x}".format(*rgb)
-
- def _lerp(a: float, b: float, t: float) -> float:
- return a + (b - a) * t
-
- c0 = _hex_to_rgb(low_color)
- c1 = _hex_to_rgb(high_color)
-
- def _lerp_color(t: float) -> str:
- t = 0.0 if t < 0.0 else (1.0 if t > 1.0 else t)
- r = int(_lerp(c0[0], c1[0], t))
- g = int(_lerp(c0[1], c1[1], t))
- b = int(_lerp(c0[2], c1[2], t))
- return _rgb_to_hex((r, g, b))
+ gid = generate_id("tr_grad")
+ grad = draw.LinearGradient(x_s, y_s, x_e, y_e, id=gid)
+ grad.add_stop(0, gradient_colors[0]); grad.add_stop(1, gradient_colors[1])
+ self.drawing.append(grad); path.args["stroke"] = grad
+ else: path.args["stroke"] = color
+ path.M(x_s, y_s)
+ mx, my = (x_s + x_e) / 2.0, (y_s + y_e) / 2.0
+ dist = math.sqrt(mx**2 + my**2)
+ pull = max(0, dist - arc_intensity) if dist > 0 else 0
+ cx = (mx/dist)*pull if dist > 0 else 0
+ cy = (my/dist)*pull if dist > 0 else 0
+ path.Q(cx, cy, x_e, y_e); self.drawing.append(path)
+
+ def add_time_axis(self, ticks, label="", tick_labels=None, stroke="gray", stroke_width=1.0, stroke_dasharray="3,3", font_size=10, label_angle=90):
+ """
+ Adds concentric rings representing evolutionary time or distance steps.
+
+ Args:
+ ticks (list): List of radial distances to draw rings at.
+ label (str, optional): Unused in radial currently, but kept for interface consistency.
+ tick_labels (list, optional): Custom text for each tick. Defaults to numerical values.
+ stroke (str, optional): Color of rings. Defaults to "gray".
+ stroke_width (float, optional): Thickness of rings. Defaults to 1.0.
+ stroke_dasharray (str, optional): Dash pattern. Defaults to "3,3".
+ font_size (int, optional): Size of tick labels. Defaults to 10.
+ label_angle (float, optional): Angle (degrees) to place labels at. Defaults to 90.
+ """
+ self._pre_flight_check()
+ for i, t in enumerate(ticks):
+ r = t * self.sf
+ self.drawing.append(draw.Circle(0, 0, r, fill="none", stroke=stroke, stroke_width=stroke_width, stroke_dasharray=stroke_dasharray))
+ lx, ly = polar_to_cartesian(label_angle, r)
+
+ # Use custom label if provided
+ display_text = str(tick_labels[i]) if tick_labels is not None and i < len(tick_labels) else str(t)
+ self.drawing.append(draw.Text(display_text, font_size, lx, ly, fill="black", stroke="white", stroke_width=0.5,
+ paint_order="stroke", text_anchor="middle", dominant_baseline="middle", font_family="Arial"))
- # ---- determine vmin/vmax from provided numeric values ----
- vals = []
- for k, v in values.items():
+ def add_heatmap(self, values, width=15.0, offset=10.0, low_color="#f7fbff", high_color="#08306b", border_color="none", border_width=0.5):
+ """
+ Adds a circular heatmap ring surrounding the leaf tips.
+
+ Args:
+ values (dict): Mapping of {node_name: numeric_value}.
+ width (float, optional): Thickness of the heatmap ring. Defaults to 15.0.
+ offset (float, optional): Radial offset from leaf tips. Defaults to 10.0.
+ low_color (str, optional): Color for minimum value. Defaults to "#f7fbff".
+ high_color (str, optional): Color for maximum value. Defaults to "#08306b".
+ border_color (str, optional): Border color for cells. Defaults to "none".
+ border_width (float, optional): Border thickness. Defaults to 0.5.
+ """
+ if hasattr(values, "to_dict"): values = values.to_dict()
+ vals = [float(v) for v in values.values() if isinstance(v, (int, float))]
+ if not vals: return
+ vmin, vmax = min(vals), max(vals) + 1e-12
+ angle_span = float(self.style.degrees) / len(self.t.get_leaves())
+ for l in self.t.get_leaves():
+ val = values.get(l.name)
+ fill = lerp_color(low_color, high_color, (float(val) - vmin) / (vmax - vmin)) if val is not None else "white"
+ start_ang, end_ang = self._rot_ang(l.angle - angle_span/2), self._rot_ang(l.angle + angle_span/2)
+ r_i, r_o = l.rad + offset, l.rad + offset + width
+ p = draw.Path(fill=fill, stroke=border_color, stroke_width=border_width)
+ p1x, p1y = polar_to_cartesian(start_ang, r_i); p2x, p2y = polar_to_cartesian(end_ang, r_i)
+ p3x, p3y = polar_to_cartesian(end_ang, r_o); p4x, p4y = polar_to_cartesian(start_ang, r_o)
+ sweep = 1 if end_ang > start_ang else 0
+ p.M(p1x, p1y).A(r_i, r_i, 0, 0, sweep, p2x, p2y).L(p3x, p3y).A(r_o, r_o, 0, 0, 0 if sweep == 1 else 1, p4x, p4y).Z()
+ self.drawing.append(p)
+
+ def add_clade_labels(self, labels, offset=40.0, stroke_width=1.5, color="black", font_size=None):
+ """
+ Adds circular arcs outside the tree to group and label specific clades.
+
+ Args:
+ labels (dict): Mapping of {node_name_or_object: label_text}.
+ offset (float, optional): Radial offset from the outermost leaf. Defaults to 40.0.
+ stroke_width (float, optional): Thickness of the bracket arc. Defaults to 1.5.
+ color (str, optional): Color of the arc and text. Defaults to "black".
+ font_size (int, optional): Font size. Defaults to style default.
+ """
+ self._pre_flight_check()
+ fs = font_size or self.style.font_size
+ max_rad = max(l.rad for l in self.t.get_leaves())
+ arc_rad = max_rad + offset
+ for target, text in labels.items():
try:
- vals.append(float(v))
- except Exception:
- continue
-
- if not vals:
- return
-
- if vmin is None:
- vmin = min(vals)
- if vmax is None:
- vmax = max(vals)
- if float(vmax) == float(vmin):
- vmax = float(vmin) + 1e-12
-
- vmin = float(vmin)
- vmax = float(vmax)
-
- # ---- ring geometry (same as add_ring) ----
- r_in = float(self.style.radius) + float(padding)
- r_out = r_in + float(width)
+ node = self.t.search_nodes(name=target)[0] if isinstance(target, str) else target
+ leaves = node.get_leaves()
+ angles = [l.angle for l in leaves]
+ start_ang, end_ang = self._rot_ang(min(angles)), self._rot_ang(max(angles))
+ mid_ang = (start_ang + end_ang) / 2.0
+ p1x, p1y = polar_to_cartesian(start_ang, arc_rad); p2x, p2y = polar_to_cartesian(end_ang, arc_rad)
+ sweep = 1 if end_ang > start_ang else 0
+ p = draw.Path(stroke=color, stroke_width=stroke_width, fill="none")
+ p.M(p1x, p1y).A(arc_rad, arc_rad, 0, 0, sweep, p2x, p2y)
+ self.drawing.append(p)
+ tx, ty = polar_to_cartesian(mid_ang, arc_rad + 8)
+ text_rot, anchor = mid_ang, "start"
+ if 90 < (mid_ang % 360) < 270:
+ text_rot += 180; anchor = "end"
+ self.drawing.append(draw.Text(text, fs, tx, ty, fill=color, font_family=self.style.font_family,
+ transform=f"rotate({text_rot}, {tx}, {ty})", text_anchor=anchor, dominant_baseline="middle"))
+ except: continue
+
+ def plot_continuous_variable(self, node_to_rgb, stroke_width=None):
+ """
+ Maps a continuous trait to branch colors using a gradient interpolation.
- for l in self.t.get_leaves():
- name = str(l.name)
- if name not in values:
- fill = missing_color
- else:
- try:
- v = float(values[name])
- tnorm = (v - vmin) / (vmax - vmin)
- fill = _lerp_color(tnorm)
- except Exception:
- fill = missing_color
+ Args:
+ node_to_rgb (dict): Mapping of {node: (r, g, b)} tuples (values 0-255).
+ stroke_width (float, optional): Thickness of the branches. Defaults to style default.
+ """
+ def _to_hex(rgb): return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))
+ sw = stroke_width or self.style.branch_stroke_width
+ for node in self.t.traverse():
+ if node.is_root(): continue
+ c_c = node_to_rgb.get(node) or node_to_rgb.get(node.name)
+ c_p = node_to_rgb.get(node.up) or node_to_rgb.get(node.up.name)
+ if c_c and c_p: self.gradient_branch(node, stroke_width=sw, colors=(_to_hex(c_p), _to_hex(c_c)))
+ else: self.highlight_branch(node, stroke_width=sw)
+ for n in self.t.traverse("postorder"):
+ if not n.is_leaf() and (col := (node_to_rgb.get(n) or node_to_rgb.get(n.name))):
+ x, y = self._node_xy(n); self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=_to_hex(col)))
- a1, a2 = l.angle - self.angle_step / 2.0, l.angle + self.angle_step / 2.0
- p = draw.Path(fill=fill, stroke="none")
+ def plot_categorical_trait(self, data, value_col, node_col="Node", palette=None, stroke_width=None, default_color="black"):
+ """
+ Maps categorical traits to branch and node colors.
- s_i_x, s_i_y = radial_converter(a1, r_in, self.style.rotation)
- e_i_x, e_i_y = radial_converter(a2, r_in, self.style.rotation)
- s_o_x, s_o_y = radial_converter(a1, r_out, self.style.rotation)
- e_o_x, e_o_y = radial_converter(a2, r_out, self.style.rotation)
+ If a parent and child have different categories, the branch color transitions (gradient).
- p.M(s_o_x, s_o_y)\
- .A(r_out, r_out, 0, 0, 1, e_o_x, e_o_y)\
- .L(e_i_x, e_i_y)\
- .A(r_in, r_in, 0, 0, 0, s_i_x, s_i_y)\
- .Z()
+ Args:
+ data (DataFrame or dict): Data containing trait values per node.
+ value_col (str): Column name for the trait values (if DataFrame).
+ node_col (str, optional): Column name for node names. Defaults to "Node".
+ palette (dict, optional): Color map {value: color}. Defaults to auto-generated.
+ stroke_width (float, optional): Branch thickness. Defaults to style default.
+ default_color (str, optional): Fallback color. Defaults to "black".
+ """
+ if hasattr(data, "to_dict"): mapping = dict(zip(data[node_col].astype(str), data[value_col]))
+ else: mapping = data
+ if palette is None:
+ unique_vals = sorted(list(set(mapping.values())))
+ defaults = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"]
+ palette = {val: defaults[i % len(defaults)] for i, val in enumerate(unique_vals)}
+ sw = stroke_width or self.style.branch_stroke_width
+ def get_color(n): return palette.get(mapping.get(n.name), default_color)
+ for node in self.t.traverse():
+ if node.is_root(): continue
+ c_n, c_p = get_color(node), get_color(node.up)
+ if c_n != c_p: self.gradient_branch(node, colors=(c_p, c_n), stroke_width=sw)
+ else: self.highlight_branch(node, color=c_n, stroke_width=sw)
+ for n in self.t.traverse("postorder"):
+ if not n.is_leaf():
+ x, y = self._node_xy(n); self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=get_color(n)))
- self.d.append(p)
+ def add_transfer_legend(self, colors=("purple", "orange"), labels=("Departure", "Arrival"), x=None, y=None, font_size=14):
+ """
+ Adds a legend specifically for Horizontal Gene Transfers.
+ """
+ palette = {labels[0]: colors[0], labels[1]: colors[1]}
+ self.add_categorical_legend(palette, title="Transfer Event", x=x, y=y, font_size=font_size)
+ def add_leaf_images(self, image_dir, extension=".png", width=40, height=40, offset=10):
+ """
+ Places images next to leaf tips in the radial layout.
+ """
+ self._pre_flight_check()
+ for leaf in self.t.get_leaves():
+ lx, ly = self._leaf_xy(leaf, offset=offset)
+ path = os.path.join(image_dir, f"{leaf.name}{extension}")
+ if os.path.exists(path):
+ with open(path, "rb") as f:
+ uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}"
+ self.drawing.append(draw.Image(lx - width/2, ly - height/2, width, height, path=uri))
+
+ def add_ancestral_images(self, image_dir, extension=".png", width=40, height=40, omit=None):
+ """
+ Places images at internal node positions in the radial layout.
+ """
+ self._pre_flight_check()
+ for node in self.t.traverse():
+ if not node.is_leaf():
+ if omit and node.name in omit: continue
+
+ nx, ny = self._node_xy(node)
+ path = os.path.join(image_dir, f"{node.name}{extension}")
+ if os.path.exists(path):
+ with open(path, "rb") as f:
+ uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}"
+ self.drawing.append(draw.Image(nx - width/2, ny - height/2, width, height, path=uri))
+
+ def add_scale_bar(self, length, label=None, x=None, y=None, stroke="black", stroke_width=2.0):
+ """
+ Adds a physical scale bar representing a distance value.
+ """
+ self._pre_flight_check()
+ px = float(length) * self.sf
+ label = label or str(length)
+ x = x if x is not None else -self.style.width/2 + 20
+ y = y if y is not None else self.style.height/2 - 20
+ self.drawing.append(draw.Line(x, y, x + px, y, stroke=stroke, stroke_width=stroke_width))
+ self.drawing.append(draw.Text(label, self.style.font_size, x + px/2, y - 8, center=True))
- def highlight_clade(self, node, color="lightblue", opacity=0.3, padding=10):
- """Draws a shaded sector (pie slice) behind a specific clade."""
- if node.is_leaf(): return
-
- leaves = node.get_leaves()
- angles = [l.angle for l in leaves]
-
- # Define Angular bounds
- # We extend by half a step to cover the "gap" between this clade and neighbors
- min_ang = min(angles) - (self.angle_step / 2.0)
- max_ang = max(angles) + (self.angle_step / 2.0)
-
- # Define Radial bounds
- # Inner: from the node itself (or root 0 if you want full pie)
- # Outer: furthest leaf + padding
- r_inner = float(node.rad)
- r_outer = max(float(l.rad) for l in leaves) + padding
-
- # Convert to SVG Arc Path
- # Move to Inner Start -> Line to Outer Start -> Arc to Outer End -> Line to Inner End -> Arc to Inner Start
- sx_in, sy_in = radial_converter(min_ang, r_inner, self.style.rotation)
- ex_in, ey_in = radial_converter(max_ang, r_inner, self.style.rotation)
-
- sx_out, sy_out = radial_converter(min_ang, r_outer, self.style.rotation)
- ex_out, ey_out = radial_converter(max_ang, r_outer, self.style.rotation)
-
- # Large arc flag (0 or 1) depending on if angle > 180
- angle_diff = max_ang - min_ang
- large_arc = 1 if angle_diff > 180 else 0
-
- path = draw.Path(fill=color, fill_opacity=opacity, stroke="none")
- path.M(sx_in, sy_in)\
- .L(sx_out, sy_out)\
- .A(r_outer, r_outer, 0, large_arc, 1, ex_out, ey_out)\
- .L(ex_in, ey_in)\
- .A(r_inner, r_inner, 0, large_arc, 0, sx_in, sy_in)\
- .Z()
-
- self.d.append(path)
-
-
- def add_time_axis(
- self,
- ticks: list[float],
- label: str = "Time",
- tick_size: float = 0, # Usually 0 for radial rings
- label_angle: float = 90, # Angle where text labels appear
- stroke: str = "#ccc",
- stroke_width: float = 1.0,
- stroke_dasharray: str = "4,2", # Dashed lines look better for grids
- font_size: int = 10,
- ):
- """Add concentric rings representing time/distance."""
-
- for t in ticks:
- r = t * self.sf
-
- # 1. Draw the ring (circle)
- # If tree is full 360, use Circle. If partial, use Arc (omitted for brevity, assuming 360 usually)
- self.d.append(draw.Circle(0, 0, r, fill="none", stroke=stroke,
- stroke_width=stroke_width, stroke_dasharray=stroke_dasharray))
-
- # 2. Draw the label
- # We place the label at 'label_angle'
- lx, ly = radial_converter(label_angle, r, 0) # rotation=0 so angle is absolute
-
- # Simple background rect for readability? (Optional)
- self.d.append(draw.Text(str(t), font_size, lx, ly,
- fill="black", stroke="white", stroke_width=0.5, paint_order="stroke",
- text_anchor="middle", dominant_baseline="middle", font_family="Arial"))
-
- # Axis Title? (Optional, maybe placed at the outermost tick)
- if ticks:
- max_r = max(ticks) * self.sf
- lx, ly = radial_converter(label_angle, max_r + 20, 0)
- self.d.append(draw.Text(label, font_size + 2, lx, ly,
- font_weight="bold", text_anchor="middle", dominant_baseline="middle"))
\ No newline at end of file
diff --git a/src/phylustrator/drawing/vertical.py b/src/phylustrator/drawing/vertical.py
index 8a36830..082e2ef 100644
--- a/src/phylustrator/drawing/vertical.py
+++ b/src/phylustrator/drawing/vertical.py
@@ -1,937 +1,661 @@
import drawsvg as draw
from .base import BaseDrawer
+from ..utils import to_hex, lerp_color, generate_id
+import math
import random
+import os
+import base64
class VerticalTreeDrawer(BaseDrawer):
+ """
+ Drawer class for rendering phylogenetic trees in a rectangular vertical layout.
+
+ Nodes are positioned using Cartesian coordinates where Y represents vertical
+ positioning (tips) and X represents distance/time.
+ """
def __init__(self, tree, style=None):
+ """
+ Initializes the VerticalTreeDrawer and calculates the rectangular layout.
+
+ Args:
+ tree (ete3.TreeNode): The tree object to be visualized.
+ style (TreeStyle, optional): Custom style configuration.
+ """
super().__init__(tree, style)
self._calculate_layout()
- def _node_xy(self, node):
+ def _node_xy(self, node) -> tuple[float, float]:
+ """Calculates Cartesian (x, y) coordinates for a node based on current scaling."""
if not hasattr(node, "coordinates"):
self._calculate_layout()
- x, y = node.coordinates # coordinates is (x,y)
+ x, y = node.coordinates
return float(x), float(y)
- def _leaf_xy(self, leaf, offset: float = 0.0):
+ def _leaf_xy(self, leaf, offset: float = 0.0) -> tuple[float, float]:
+ """Calculates coordinates for a leaf tip, with an optional horizontal offset."""
x, y = self._node_xy(leaf)
return (x + float(offset), y)
- def _edge_point(self, child, where: float):
- """
- Place markers along the *horizontal* part of the rectangular edge:
- (x_parent, y_child) -> (x_child, y_child)
-
- This keeps y constant, so where=0 is at the elbow (NOT on the shared vertical),
- and where=1 is at the child tip.
- """
+ def _edge_point(self, child, where: float) -> tuple[float, float, float]:
+ """Finds a point along the horizontal branch segment leading to a child."""
parent = child.up
if parent is None:
- x, y = self._node_xy(child)
- return x, y, 0.0
-
- # Ensure layout exists
- if not hasattr(parent, "coordinates") or not hasattr(child, "coordinates"):
- self._calculate_layout()
-
- x_parent, _y_parent = self._node_xy(parent)
- x_child, y_child = self._node_xy(child)
-
+ return (*self._node_xy(child), 0.0)
+ x_p, _ = self._node_xy(parent)
+ x_c, y_c = self._node_xy(child)
t = max(0.0, min(1.0, float(where)))
- x = x_parent + (x_child - x_parent) * t
- y = y_child
+ return x_p + (x_c - x_p) * t, y_c, (0.0 if (x_c - x_p) >= 0 else 180.0)
- # Horizontal direction only (no weird rotations)
- edge_ang = 0.0 if (x_child - x_parent) >= 0 else 180.0
- return x, y, edge_ang
-
- def _calculate_layout(self, max_width=None):
+ def _calculate_layout(self, max_width: float | None = None):
"""
- Calculates tree coordinates.
- :param max_width: If provided, scales the tree to this width instead of the style width.
+ Computes Cartesian coordinates for all nodes in rectangular space.
+
+ Args:
+ max_width (float, optional): Force a specific drawing width.
"""
- # 1. Calculate Distances
- max_dist = 0
+ # 1. Horizontal Scaling (Distances)
+ max_dist = 0.0
for n in self.t.traverse("preorder"):
n.dist_to_root = n.up.dist_to_root + n.dist if not n.is_root() else getattr(n, "dist", 0.0)
- if n.dist_to_root > max_dist: max_dist = n.dist_to_root
+ max_dist = max(max_dist, n.dist_to_root)
self.total_tree_depth = max_dist
-
- # 2. Handle Scaling and Margins
- horizontal_padding = 100
- target_width = max_width if max_width is not None else self.style.width
- self.sf = (target_width - (horizontal_padding * 2)) / max_dist if max_dist > 0 else 1
-
- # Center the root X based on padding
- self.root_x = -self.style.width / 2 + horizontal_padding
+ pad = self.style.margin
+ target_w = max_width if max_width is not None else self.style.width
+ self.sf = (target_w - (pad * 2)) / max_dist if max_dist > 0 else 1.0
+ self.root_x = -self.style.width / 2 + pad
- # 3. Calculate Vertical Positions
+ # 2. Vertical Scaling (Leaves)
leaves = self.t.get_leaves()
- y_padding = 100
- y_step = (self.style.height - (y_padding * 2)) / max(len(leaves)-1, 1)
- start_y = -self.style.height / 2 + y_padding
+ y_step = (self.style.height - (pad * 2)) / max(len(leaves)-1, 1)
+ start_y = -self.style.height / 2 + pad
for i, leaf in enumerate(leaves):
leaf.y_coord = start_y + (i * y_step)
leaf.coordinates = (self.root_x + (leaf.dist_to_root * self.sf), leaf.y_coord)
+ # 3. Internal Centering
for n in self.t.traverse("postorder"):
if not n.is_leaf():
n.y_coord = sum(c.y_coord for c in n.children) / len(n.children)
n.coordinates = (self.root_x + (n.dist_to_root * self.sf), n.y_coord)
+ self._layout_calculated = True
def draw(self, branch2color=None, right_margin=0):
"""
- Draws the tree while reserving space on the right for images.
+ Draws the tree skeleton using rectangular elbows.
+
+ Args:
+ branch2color (dict, optional): Mapping of ete3.TreeNode to color strings.
+ right_margin (float, optional): Extra space (in pixels) to reserve on the right edge
+ (e.g., for heatmaps or labels). Defaults to 0.
"""
- # 1. Re-run layout calculation with the reserved right margin
- self._calculate_layout(max_width=self.style.width - right_margin)
+ self._pre_flight_check()
+ if right_margin > 0:
+ self._calculate_layout(max_width=self.style.width - right_margin)
for n in self.t.traverse("postorder"):
x, y = n.coordinates
-
- # Resolve Color
- color = self.style.branch_color
- if branch2color and n in branch2color:
- color = branch2color[n]
+ color = branch2color.get(n, self.style.branch_color) if branch2color else self.style.branch_color
- # 2. Horizontal branch
if not n.is_root():
px, py = n.up.coordinates
- self.d.append(draw.Line(px, y, x, y, stroke=color,
- stroke_width=self.style.branch_size, stroke_linecap="round"))
+ # Horizontal segment
+ self.drawing.append(draw.Line(px, y, x, y, stroke=color,
+ stroke_width=self.style.branch_stroke_width, stroke_linecap="round"))
else:
- # Root "handle"
- self.d.append(draw.Line(x - 20, y, x, y, stroke=color, stroke_width=self.style.branch_size))
+ # Root stub
+ self.drawing.append(draw.Line(x - self.style.root_stub_length, y, x, y,
+ stroke=color, stroke_width=self.style.branch_stroke_width))
- # 3. Vertical connector
if not n.is_leaf():
+ # Vertical connector (elbow)
y_min = min(c.y_coord for c in n.children)
y_max = max(c.y_coord for c in n.children)
- self.d.append(draw.Line(x, y_min, x, y_max, stroke=color,
- stroke_width=self.style.branch_size, stroke_linecap="round"))
- self.d.append(draw.Circle(x, y, self.style.node_size, fill=color))
- else:
- self.d.append(draw.Circle(x, y, self.style.leaf_size, fill=self.style.leaf_color))
+ self.drawing.append(draw.Line(x, y_min, x, y_max, stroke=color,
+ stroke_width=self.style.branch_stroke_width, stroke_linecap="round"))
+ if self.style.node_r > 0:
+ self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=color))
+ elif self.style.leaf_r > 0:
+ self.drawing.append(draw.Circle(x, y, self.style.leaf_r, fill=self.style.leaf_color))
def highlight_clade(self, node, color="lightblue", opacity=0.3, padding=10):
"""
- Draws a shaded rectangle behind a specific clade.
+ Draws a shaded rectangular background behind a specific clade.
+
+ Args:
+ node (ete3.TreeNode): The root of the clade.
+ color (str, optional): Fill color. Defaults to "lightblue".
+ opacity (float, optional): Fill opacity. Defaults to 0.3.
+ padding (float, optional): Padding around the clade box. Defaults to 10.
"""
+ self._pre_flight_check()
leaves = node.get_leaves()
-
- # Calculate bounds based on node and leaf coordinates
x_start, _ = node.coordinates
-
- # X range: From node to the furthest tip in this clade
x_max = max(l.coordinates[0] for l in leaves)
-
- # Y range: Spanning all leaves in the clade
y_min = min(l.y_coord for l in leaves)
y_max = max(l.y_coord for l in leaves)
-
- # Geometry:
- # width is (x_max - x_start) + a small padding for the tips
- # height is (y_max - y_min) + padding on top/bottom
- rect_x = x_start - (padding / 2)
- rect_y = y_min - padding
- rect_w = (x_max - x_start) + padding # Reduced the "60" extension to just padding
- rect_h = (y_max - y_min) + (2 * padding)
-
- self.d.append(draw.Rectangle(
- rect_x, rect_y, rect_w, rect_h,
- fill=color,
- fill_opacity=opacity,
- stroke="none"
+ self.drawing.append(draw.Rectangle(
+ x_start - (padding / 2), y_min - padding,
+ (x_max - x_start) + padding, (y_max - y_min) + (2 * padding),
+ fill=color, fill_opacity=opacity, stroke="none"
))
- def highlight_branch(self, node, color="red", size=None):
+ def highlight_branch(self, node, color="red", stroke_width=None):
+ """
+ Overlays a thicker or colored line on a specific branch.
+
+ Args:
+ node (ete3.TreeNode): The target node.
+ color (str, optional): CSS color string. Defaults to "red".
+ stroke_width (float, optional): Thickness. Defaults to 2x style default.
+ """
if node.is_root(): return
- s_width = size if size else self.style.branch_size * 2
+ sw = stroke_width if stroke_width is not None else self.style.branch_stroke_width * 2
x, y = node.coordinates
- px, py = node.up.coordinates
- self.d.append(draw.Line(px, y, x, y, stroke=color, stroke_width=s_width, stroke_linecap="round"))
-
- def gradient_branch(self, node, colors=("red", "blue"), size=None):
- if node.is_root():
- return
-
- s_width = size if size else self.style.branch_size
- x, y = node.coordinates
- px, _ = node.up.coordinates
-
- # FIX: Use the node's memory address (id(node)) to ensure a
- # truly unique ID for this specific branch segment.
- grad_id = f"grad_{id(node)}"
-
- grad = draw.LinearGradient(px, y, x, y, id=grad_id)
- grad.add_stop(0, colors[0])
- grad.add_stop(1, colors[1])
-
- self.d.append(grad)
- self.d.append(draw.Line(px, y, x, y, stroke=grad, stroke_width=s_width))
+ px, _ = node.up.coordinates
+ self.drawing.append(draw.Line(px, y, x, y, stroke=color, stroke_width=sw, stroke_linecap="round"))
+ def gradient_branch(self, node, colors=("red", "blue"), stroke_width=None):
+ """
+ Applies a linear color gradient along a branch segment.
- def add_leaf_names(
- self,
- font_size=None,
- color=None,
- font_family=None,
- rotation=0,
- padding=10
- ):
+ Args:
+ node (ete3.TreeNode): The target node.
+ colors (tuple, optional): (start_color, end_color). Defaults to ("red", "blue").
+ stroke_width (float, optional): Thickness. Defaults to style default.
"""
- Adds leaf labels where the center of the text is aligned with the leaf tip.
+ if node.is_root(): return
+ sw = stroke_width if stroke_width is not None else self.style.branch_stroke_width
+ x, y, (px, _) = *node.coordinates, node.up.coordinates
+ gid = generate_id("grad")
+ grad = draw.LinearGradient(px, y, x, y, id=gid)
+ grad.add_stop(0, colors[0])
+ grad.add_stop(1, colors[1])
+ self.drawing.append(grad)
+ self.drawing.append(draw.Line(px, y, x, y, stroke=grad, stroke_width=sw))
+
+ def add_leaf_names(self, font_size=None, color="black", rotation=0, padding=10):
"""
- fs = font_size if font_size is not None else self.style.font_size
- ff = font_family if font_family is not None else self.style.font_family
- text_color = color if color is not None else "black"
+ Adds text labels to the leaf tips.
+ Args:
+ font_size (int, optional): Font size. Defaults to style default.
+ color (str, optional): Text color. Defaults to "black".
+ rotation (float, optional): Rotation angle. Defaults to 0.
+ padding (float, optional): Horizontal padding. Defaults to 10.
+ """
+ fs = font_size or self.style.font_size
for l in self.t.get_leaves():
x, y = l.coordinates
-
- # The anchor point is the leaf tip + padding
- tx = x + padding
- ty = y
-
- # Rotate around the center (tx, ty)
+ tx, ty = x + padding, y
transform = f"rotate({rotation}, {tx}, {ty})" if rotation != 0 else ""
+ self.drawing.append(draw.Text(l.name, fs, tx, ty, fill=color, font_family=self.style.font_family,
+ transform=transform, text_anchor="start", dominant_baseline="middle"))
- self.d.append(draw.Text(
- l.name,
- fs,
- tx,
- ty,
- fill=text_color,
- font_family=ff,
- transform=transform,
- text_anchor="middle", # Centers horizontally
- dominant_baseline="middle" # Centers vertically
- ))
-
- def add_node_names(
- self,
- font_size=None,
- color="gray",
- font_family=None,
- x_offset=-15,
- y_offset=-10,
- rotation=0
- ):
+ def add_node_names(self, font_size=None, color="gray", x_offset=-15, y_offset=-10, rotation=0):
"""
- Adds labels to internal nodes centered on the offset coordinate.
+ Adds text labels to internal nodes.
+
+ Args:
+ font_size (int, optional): Font size. Defaults to style default * 0.8.
+ color (str, optional): Text color. Defaults to "gray".
+ x_offset (float, optional): Horizontal offset. Defaults to -15.
+ y_offset (float, optional): Vertical offset. Defaults to -10.
+ rotation (float, optional): Rotation angle. Defaults to 0.
"""
- fs = font_size if font_size is not None else self.style.font_size * 0.8
- ff = font_family if font_family is not None else self.style.font_family
-
+ fs = font_size or self.style.font_size * 0.8
for n in self.t.traverse():
- if not n.is_leaf():
- if not n.name:
- continue
-
+ if not n.is_leaf() and n.name:
x, y = n.coordinates
-
- # Apply offsets to the center point
- tx = x + x_offset
- ty = y + y_offset
-
+ tx, ty = x + x_offset, y + y_offset
transform = f"rotate({rotation}, {tx}, {ty})" if rotation != 0 else ""
-
- self.d.append(draw.Text(
- n.name,
- fs,
- tx,
- ty,
- fill=color,
- font_family=ff,
- transform=transform,
- text_anchor="middle", # Centers horizontally
- dominant_baseline="middle" # Centers vertically
+ self.drawing.append(draw.Text(
+ n.name, fs, tx, ty, fill=color, font_family=self.style.font_family,
+ transform=transform, text_anchor="middle", dominant_baseline="middle"
))
-
- def add_time_axis(
- self,
- ticks: list[float],
- label: str = "Time",
- tick_size: float = 6.0,
- padding: float = 20.0,
- y_offset: float = 0.0,
- root_stub: float = 20.0,
- stroke: str = "black",
- stroke_width: float = 2.0,
- font_size: int | None = None,
- font_family: str | None = None,
- # --- grid options ---
- grid: bool = False,
- grid_stroke: str = "#cccccc",
- grid_stroke_width: float = 1.0,
- grid_opacity: float = 0.5,
- # --- NEW: custom tick labels ---
- tick_labels: dict[float, str] | None = None,
- ) -> None:
- """
- Draw a horizontal time axis. Optionally draw vertical grid lines at each tick
- across the tree.
-
- tick_labels lets you display labels that differ from the numeric tick positions,
- e.g. show "-2" at tick position 2.0 (backwards time).
- """
- if not ticks:
+ def add_leaf_shapes(self, leaves, shape="circle", fill="blue", r=5.0, stroke=None, stroke_width=1.0, offset=0.0, rotation=0.0, opacity=1.0):
+ """
+ Adds geometric markers next to leaf tips.
+
+ Args:
+ leaves (list): List of node names or objects.
+ shape (str, optional): Shape type. Defaults to "circle".
+ fill (str, optional): Fill color. Defaults to "blue".
+ r (float, optional): Radius/size. Defaults to 5.0.
+ stroke (str, optional): Stroke color. Defaults to None.
+ stroke_width (float, optional): Stroke width. Defaults to 1.0.
+ offset (float, optional): Horizontal offset. Defaults to 0.0.
+ rotation (float, optional): Rotation angle. Defaults to 0.0.
+ opacity (float, optional): Opacity. Defaults to 1.0.
+ """
+ self._pre_flight_check()
+ for item in leaves:
+ try:
+ node = self.t & item if isinstance(item, str) else item
+ x, y = self._leaf_xy(node, offset=float(offset))
+ self._draw_shape_at(x, y, shape, fill, r, stroke, stroke_width, rotation, opacity)
+ except: continue
+
+ def add_node_shapes(self, nodes, shape="circle", fill="red", r=5.0, stroke=None, stroke_width=1.0, rotation=0, dx=0, dy=0):
+ """
+ Adds geometric markers at node positions.
+
+ Args:
+ nodes (list): List of node names/objects or style dicts.
+ shape (str, optional): Default shape. Defaults to "circle".
+ fill (str, optional): Default fill color. Defaults to "red".
+ r (float, optional): Default radius. Defaults to 5.0.
+ stroke (str, optional): Default stroke color. Defaults to None.
+ stroke_width (float, optional): Default stroke width. Defaults to 1.0.
+ rotation (float, optional): Default rotation. Defaults to 0.
+ dx (float, optional): X offset. Defaults to 0.
+ dy (float, optional): Y offset. Defaults to 0.
+ """
+ self._pre_flight_check()
+ if isinstance(nodes, list) and nodes and isinstance(nodes[0], dict):
+ for s in nodes:
+ self.add_node_shapes([s.get("node")], s.get("shape", shape), s.get("fill", fill), s.get("r", r),
+ s.get("stroke", stroke), s.get("stroke_width", stroke_width), s.get("rotation", rotation))
return
+ for item in nodes:
+ try:
+ node = self.t.search_nodes(name=item)[0] if isinstance(item, str) else item
+ x, y = self._node_xy(node)
+ self._draw_shape_at(x + dx, y + dy, shape, fill, r, stroke, stroke_width, rotation)
+ except: continue
+
+ def add_branch_shapes(self, specs, default_where=0.5, offset=0.0, orient=None):
+ """
+ Adds geometric markers along branches (useful for modeling events).
- # Ensure layout exists
- any_node = next(self.t.traverse("preorder"))
- if not hasattr(any_node, "coordinates") or not hasattr(any_node, "y_coord"):
- self._calculate_layout()
-
- sf = float(self.sf)
-
- # Font defaults
- fs = font_size if font_size is not None else self.style.font_size
- ff = font_family if font_family is not None else self.style.font_family
-
- # Tree vertical extent
+ Args:
+ specs (list): List of dicts describing shapes and positions.
+ default_where (float, optional): Position along branch (0-1). Defaults to 0.5.
+ offset (float, optional): Perpendicular offset. Defaults to 0.0.
+ orient (str, optional): "along", "perp", or None.
+ """
+ self._pre_flight_check()
+ if hasattr(specs, "to_dict"): specs = specs.to_dict(orient="records")
+ for s in specs:
+ br = s.get("branch")
+ if not br: continue
+ try:
+ node = self.t & br if isinstance(br, str) else br
+ where = s.get("where", default_where)
+ x, y, ang = self._edge_point(node, where=where)
+ if offset != 0:
+ perp = math.radians(ang + 90)
+ x += offset * math.cos(perp)
+ y += offset * math.sin(perp)
+ rot = s.get("rotation", 0.0)
+ if orient == "along": rot = ang
+ elif orient == "perp": rot = ang + 90
+ r_val = float(s.get("r", s.get("size", 10.0) / 2.0))
+ self._draw_shape_at(x, y, s.get("shape", "circle"), s.get("fill", "blue"), r_val,
+ s.get("stroke"), s.get("stroke_width", 1.0), rot, s.get("opacity", 1.0))
+ except: continue
+
+ def add_time_axis(self, ticks, label="Time", tick_labels=None, tick_size=6.0, padding=20.0, y_offset=0.0, stroke_width=2.0, grid=False):
+ """
+ Adds a linear axis at the bottom to represent time or distance.
+
+ Args:
+ ticks (list[float]): Numerical positions for the ticks.
+ label (str): Label for the axis itself. Defaults to "Time".
+ tick_labels (list[str], optional): Custom strings for the tick values.
+ tick_size (float, optional): Length of the tick marks. Defaults to 6.0.
+ padding (float, optional): Distance from the tree tips. Defaults to 20.0.
+ y_offset (float, optional): Extra manual offset on the Y axis. Defaults to 0.0.
+ stroke_width (float, optional): Thickness of the axis line. Defaults to 2.0.
+ grid (bool, optional): Whether to draw vertical grid lines. Defaults to False.
+ """
+ self._pre_flight_check()
leaves = self.t.get_leaves()
- min_y = min(float(l.y_coord) for l in leaves)
- max_y = max(float(l.y_coord) for l in leaves)
-
- # Place axis below the tree
- y_axis = max_y + float(padding) + float(y_offset)
-
- # Axis X range
- x_left = float(self.root_x) - float(root_stub)
- x_right = float(self.root_x) + (max(float(t) for t in ticks) * sf)
-
- # Baseline
- self.d.append(draw.Line(
- x_left, y_axis,
- x_right, y_axis,
- stroke=stroke,
- stroke_width=stroke_width,
- ))
-
- tick_labels = tick_labels or {}
-
- # Ticks (+ optional vertical grid lines)
- for tt in ticks:
- tt = float(tt)
- x = float(self.root_x) + tt * sf
-
+ min_y, max_y = min(l.y_coord for l in leaves), max(l.y_coord for l in leaves)
+ y_axis = max_y + padding + y_offset
+ x_left, x_right = self.root_x - self.style.root_stub_length, self.root_x + (max(ticks) * self.sf)
+ self.drawing.append(draw.Line(x_left, y_axis, x_right, y_axis, stroke="black", stroke_width=stroke_width))
+ for i, tt in enumerate(ticks):
+ x = self.root_x + tt * self.sf
if grid:
- self.d.append(draw.Line(
- x, min_y,
- x, max_y,
- stroke=grid_stroke,
- stroke_width=grid_stroke_width,
- stroke_opacity=grid_opacity,
- ))
-
- self.d.append(draw.Line(
- x, y_axis,
- x, y_axis + float(tick_size),
- stroke=stroke,
- stroke_width=stroke_width,
- ))
-
- text = tick_labels.get(tt, str(tt))
- self.d.append(draw.Text(
- text,
- fs,
- x,
- y_axis + float(tick_size) + fs,
- center=True,
- font_family=ff,
- ))
-
- # Axis label
- self.d.append(draw.Text(
- label,
- fs,
- (x_left + x_right) / 2.0,
- y_axis + float(tick_size) + fs * 2.2,
- center=True,
- font_family=ff,
- ))
-
-
-
- def plot_transfers(
- self,
- transfers,
- mode="midpoint",
- curve_type="C",
- filter_below=0.0,
- use_gradient=True,
- gradient_colors=("purple", "orange"),
- color="orange",
- use_thickness=True,
- stroke_width=5,
- arc_intensity=40,
- opacity=0.6,
- ):
- """
- Plot horizontal gene transfers as curved lines.
-
- Parameters
- ----------
- transfers
- Either a list of dicts OR a pandas.DataFrame with at least:
- 'from', 'to', 'time' (optional), 'freq' (optional)
- mode
- "midpoint" (default) or "time".
- - midpoint: attach curves to mid-branch positions (old behavior)
- - time: attach curves at the event time along each endpoint branch using
- node.time_from_origin (Zombi parser provides this)
- curve_type
- "C" (default) or "S"
- """
- # Accept either list[dict] or a DataFrame-like object (e.g. pandas.DataFrame)
+ self.drawing.append(draw.Line(x, min_y, x, max_y, stroke="#ccc", stroke_width=1.0, stroke_opacity=0.5))
+ self.drawing.append(draw.Line(x, y_axis, x, y_axis + tick_size, stroke="black", stroke_width=stroke_width))
+
+ # Use custom label if provided, else use the numerical value
+ display_text = str(tick_labels[i]) if tick_labels is not None and i < len(tick_labels) else str(tt)
+ self.drawing.append(draw.Text(display_text, self.style.font_size, x, y_axis + tick_size + self.style.font_size,
+ text_anchor="middle", font_family=self.style.font_family))
+ self.drawing.append(draw.Text(label, self.style.font_size, (x_left + x_right) / 2.0, y_axis + tick_size + self.style.font_size * 2.5,
+ text_anchor="middle", font_family=self.style.font_family))
+
+ def plot_transfers(self, transfers, mode="midpoint", curve_type="C", filter_below=0.0, use_gradient=True,
+ gradient_colors=("purple", "orange"), color="orange", stroke_width=5.0, arc_intensity=40.0, opacity=0.6):
+ """
+ Plots curved arrows representing Horizontal Gene Transfer (HGT) events.
+
+ Args:
+ transfers (list): List of transfer event dicts or DataFrame.
+ mode (str, optional): "midpoint" or "time". Defaults to "midpoint".
+ curve_type (str, optional): "C" for C-shape, "S" for S-shape. Defaults to "C".
+ filter_below (float, optional): Minimum frequency filter. Defaults to 0.0.
+ use_gradient (bool, optional): Use color gradient. Defaults to True.
+ gradient_colors (tuple, optional): Gradient colors. Defaults to ("purple", "orange").
+ color (str, optional): Fallback color. Defaults to "orange".
+ stroke_width (float, optional): Stroke width. Defaults to 5.0.
+ arc_intensity (float, optional): Curve height. Defaults to 40.0.
+ opacity (float, optional): Opacity. Defaults to 0.6.
+ """
if hasattr(transfers, "to_dict") and hasattr(transfers, "columns"):
transfers = transfers.to_dict(orient="records")
-
name2node = {n.name: n for n in self.t.traverse()}
+ self._pre_flight_check()
- # Ensure layout exists
- any_node = next(self.t.traverse("preorder"))
- if not hasattr(any_node, "coordinates") or not hasattr(any_node, "y_coord"):
- self._calculate_layout()
-
- def where_from_time(node, tt: float) -> float:
- parent = node.up
- if parent is None:
- return 0.0
- t0 = float(getattr(parent, "time_from_origin", 0.0))
- t1 = float(getattr(node, "time_from_origin", t0))
- denom = (t1 - t0) if abs(t1 - t0) > 1e-12 else 1.0
- w = (float(tt) - t0) / denom
- if w < 0.0:
- return 0.0
- if w > 1.0:
- return 1.0
- return w
+ def get_where(node, t_ev):
+ p = node.up
+ if not p: return 0.0
+ t0, t1 = float(getattr(p, "time_from_origin", 0.0)), float(getattr(node, "time_from_origin", 0.0))
+ if abs(t1 - t0) > 1e-12:
+ return max(0.0, min(1.0, (float(t_ev) - t0) / (t1 - t0)))
+ return 0.5
for tr in transfers:
freq = float(tr.get("freq", 1.0))
- if freq < filter_below:
- continue
-
- src = name2node.get(tr.get("from"))
- dst = name2node.get(tr.get("to"))
- if src is None or dst is None:
- continue
-
- # Compute endpoints
- if (
- mode == "time"
- and tr.get("time") is not None
- and hasattr(src, "time_from_origin")
- and hasattr(dst, "time_from_origin")
- ):
- tt = float(tr["time"])
- w_src = where_from_time(src, tt)
- w_dst = where_from_time(dst, tt)
- x_start, y_start, _ = self._edge_point(src, w_src)
- x_end, y_end, _ = self._edge_point(dst, w_dst)
+ if freq < filter_below: continue
+ src, dst = name2node.get(tr.get("from")), name2node.get(tr.get("to"))
+ if not src or not dst: continue
+ if mode == "time" and "time" in tr:
+ x_s, y_s, _ = self._edge_point(src, get_where(src, tr["time"]))
+ x_e, y_e, _ = self._edge_point(dst, get_where(dst, tr["time"]))
else:
- # midpoint fallback (old behavior)
- y_start, y_end = float(src.y_coord), float(dst.y_coord)
- src_px = src.up.coordinates[0] if src.up else (src.coordinates[0] - 20)
- dst_px = dst.up.coordinates[0] if dst.up else (dst.coordinates[0] - 20)
- x_start = (float(src_px) + float(src.coordinates[0])) / 2.0
- x_end = (float(dst_px) + float(dst.coordinates[0])) / 2.0
-
- # Styling
- width = (stroke_width * freq) if use_thickness else stroke_width
- path = draw.Path(stroke_width=width, fill="none", stroke_opacity=opacity)
-
+ x_s, y_s, _ = self._edge_point(src, 0.5)
+ x_e, y_e, _ = self._edge_point(dst, 0.5)
+
+ path = draw.Path(stroke_width=stroke_width * freq, fill="none", stroke_opacity=opacity)
if use_gradient:
- grad_id = f"tr_grad_{random.randint(0, 999999)}"
- grad = draw.LinearGradient(x_start, y_start, x_end, y_end, id=grad_id)
+ gid = generate_id("tr_grad")
+ grad = draw.LinearGradient(x_s, y_s, x_e, y_e, id=gid)
grad.add_stop(0, gradient_colors[0])
grad.add_stop(1, gradient_colors[1])
- self.d.append(grad)
+ self.drawing.append(grad)
path.args["stroke"] = grad
else:
path.args["stroke"] = color
-
- # Geometry
- path.M(x_start, y_start)
-
- if curve_type.upper() == "S":
- dx = x_end - x_start
- sgn = 1 if dx >= 0 else -1
- arc = abs(arc_intensity)
- cp1x = x_start + (sgn * arc)
- cp2x = x_end - (sgn * arc)
- path.C(cp1x, y_start, cp2x, y_end, x_end, y_end)
- else:
- # C-curve
- path.C(
- x_start - arc_intensity, y_start,
- x_end - arc_intensity, y_end,
- x_end, y_end
- )
-
- self.d.append(path)
-
-
- def add_transfer_legend(
- self,
- title="Transfer Frequency",
- colors=("purple", "orange"),
- low=0.1,
- high=1.0,
- source_label="Source",
- arrival_label="Arrival",
- show_frequency=False,
- show_direction=True,
- margin=20,
- ):
- """Add a transfer legend.
-
- By default this shows *direction* (two solid colors): a "Source" swatch and an
- "Arrival" swatch. If you also want a frequency scale, set
- ``show_frequency=True``.
-
- Parameters
- ----------
- colors:
- Tuple (source_color, arrival_color). These match the gradient endpoints
- used by ``plot_transfers(..., gradient_colors=...)``.
- show_frequency:
- Draws a gradient bar + numeric low/high labels.
- show_direction:
- Draws two solid color swatches labelled source/arrival.
- """
- if not (show_frequency or show_direction):
- return
-
- font_size = 11
- num_font_size = 9
- sw = 14
- gap = 6
- pad_x = 10
- top_pad = 10
- bottom_pad = 10
- row_h = 18
- bar_h = 12
- bar_w = 110
-
- # Estimate legend width from label lengths (drawsvg doesn't expose text metrics).
- max_label_len = 0
- if show_direction:
- max_label_len = max(len(str(source_label)), len(str(arrival_label)))
- if show_frequency:
- max_label_len = max(max_label_len, len(str(title)))
- est_text_w = max_label_len * font_size * 0.60
- w = int(pad_x + sw + gap + est_text_w + pad_x)
- if show_frequency:
- w = max(w, pad_x + bar_w + pad_x)
-
- # Height: SVG y-axis increases downward.
- content_h = top_pad
- if show_frequency:
- # title + bar + low/high labels + spacing
- content_h += (font_size + 4) + bar_h + (num_font_size + 10) + 6
- if show_direction:
- content_h += (2 * row_h)
- content_h += bottom_pad
- box_h = content_h
-
- x = -self.style.width / 2 + 30
- y = self.style.height / 2 - margin - box_h
-
- self.d.append(draw.Rectangle(x, y, w, box_h, fill="white", stroke="black", stroke_width=1, opacity=0.9))
-
- cursor_y = y + top_pad + 2
-
- if show_frequency:
- self.d.append(draw.Text(title, font_size, x + 10, cursor_y, font_family="sans-serif", font_weight="bold"))
- cursor_y += 10
-
- grad_id = f"legend_transfer_grad_{random.randint(0, 999999)}"
- grad = draw.LinearGradient(x + 10, cursor_y + bar_h / 2, x + 10 + bar_w, cursor_y + bar_h / 2, id=grad_id)
- grad.add_stop(0, colors[0])
- grad.add_stop(1, colors[1])
- self.d.append(grad)
- self.d.append(draw.Rectangle(x + 10, cursor_y, bar_w, bar_h, fill=grad))
-
- self.d.append(draw.Text(f"{low}", num_font_size, x + 10, cursor_y + bar_h + 12, font_family="sans-serif"))
- self.d.append(draw.Text(f"{high}", num_font_size, x + 10 + bar_w - 15, cursor_y + bar_h + 12, font_family="sans-serif"))
- cursor_y += bar_h + 24
-
- if show_direction:
- sw = 14
- self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[0]))
- self.d.append(draw.Text(source_label, font_size, x + 30, cursor_y + 11, font_family="sans-serif"))
- cursor_y += row_h
-
- self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[1]))
- self.d.append(draw.Text(arrival_label, 11, x + 30, cursor_y + 11, font_family="sans-serif"))
-
-
- def add_heatmap(
- self,
- values,
- width: float = 20.0,
- offset: float = 10.0,
- vmin: float | None = None,
- vmax: float | None = None,
- low_color: str = "#f7fbff",
- high_color: str = "#08306b",
- missing_color: str = "white",
- border_color: str = "none",
- border_width: float = 0.5
- ):
- """Add a vertical heatmap strip next to the tree tips."""
-
- # 1. Normalize values to dict
- if hasattr(values, "to_dict") and not isinstance(values, dict):
- values = values.to_dict()
- if hasattr(values, "columns") and hasattr(values, "to_dict") and not isinstance(values, dict):
- cols = list(values.columns)
- if len(cols) >= 2:
- values = dict(zip(values[cols[0]].astype(str), values[cols[1]].astype(float)))
- else:
- values = {}
-
- # 2. Helper: Parse Color (Hex or Name) to RGB Tuple
- def _to_rgb(color_str):
- color_str = str(color_str).strip()
-
- # Handle Hex (e.g., #FF0000 or #F00)
- if color_str.startswith("#"):
- h = color_str.lstrip("#")
- if len(h) == 3: # Expand shorthand #fff -> #ffffff
- h = "".join([c*2 for c in h])
- return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
-
- # Handle Basic Names
- # (Simple lookup to avoid dependencies like matplotlib)
- common_names = {
- "white": (255, 255, 255), "black": (0, 0, 0),
- "red": (255, 0, 0), "green": (0, 128, 0),
- "blue": (0, 0, 255), "orange": (255, 165, 0),
- "purple": (128, 0, 128), "yellow": (255, 255, 0),
- "gray": (128, 128, 128), "grey": (128, 128, 128),
- "cyan": (0, 255, 255), "magenta": (255, 0, 255),
- "lime": (0, 255, 0), "darkgreen": (0, 100, 0),
- "navy": (0, 0, 128), "teal": (0, 128, 128)
- }
- if color_str.lower() in common_names:
- return common_names[color_str.lower()]
-
- # Fallback if unknown name
- raise ValueError(f"Heatmap interpolation needs Hex codes (e.g. '#ff0000'). Unknown name: '{color_str}'")
-
- def _rgb_to_hex(rgb):
- return "#{:02x}{:02x}{:02x}".format(*rgb)
-
- # 3. Prepare data
- vals = [float(v) for v in values.values() if isinstance(v, (int, float))]
- if not vals: return
-
- vmin = vmin if vmin is not None else min(vals)
- vmax = vmax if vmax is not None else max(vals)
- if vmax == vmin: vmax = vmin + 1e-12
- # Convert start/end colors to RGB for math
- c0 = _to_rgb(low_color)
- c1 = _to_rgb(high_color)
-
- # 4. Draw Rectangles
- # Calculate X position
- max_x = max(l.coordinates[0] for l in self.t.get_leaves())
- start_x = max_x + offset
-
- leaves = self.t.get_leaves()
- # Calculate height of each block (distance between leaves)
- if len(leaves) > 1:
- y_coords = sorted([l.y_coord for l in leaves])
- # Estimate step size from the smallest difference (handles irregular trees better)
- diffs = [y_coords[i+1] - y_coords[i] for i in range(len(y_coords)-1)]
- y_step = min(diffs) if diffs else 20
- # If step is tiny (overlap), use average
- if y_step < 1:
- y_step = abs(leaves[1].y_coord - leaves[0].y_coord)
+ path.M(x_s, y_s)
+ if curve_type.upper() == "S":
+ sgn = 1 if (x_e - x_s) >= 0 else -1
+ path.C(x_s + (sgn * arc_intensity), y_s, x_e - (sgn * arc_intensity), y_e, x_e, y_e)
else:
- y_step = 20
-
- for l in leaves:
- name = str(l.name)
- fill = missing_color
-
- if name in values:
- try:
- val = float(values[name])
- # Interpolate
- t = (val - vmin) / (vmax - vmin)
- t = max(0.0, min(1.0, t))
-
- r = int(c0[0] + (c1[0]-c0[0]) * t)
- g = int(c0[1] + (c1[1]-c0[1]) * t)
- b = int(c0[2] + (c1[2]-c0[2]) * t)
- fill = _rgb_to_hex((r,g,b))
- except Exception:
- pass # Keep missing_color
-
- # Center rectangle on leaf Y
- self.d.append(draw.Rectangle(
- start_x,
- l.y_coord - y_step/2,
- width,
- y_step,
- fill=fill,
- stroke=border_color,
- stroke_width=border_width
- ))
-
+ path.C(x_s - arc_intensity, y_s, x_e - arc_intensity, y_e, x_e, y_e)
+ self.drawing.append(path)
- def add_leaf_images(self, image_dir, extension=".png", width=40, height=40, offset=10, rotation=0):
+ def add_heatmap(self, values, width=15.0, offset=10.0, low_color="#f7fbff", high_color="#08306b", border_color="none", border_width=0.5):
"""
- Adds PNG images to the right of each leaf node with a rotation option.
- :param rotation: Degrees to rotate the image (clockwise).
+ Adds a column heatmap strip next to the leaf names.
+
+ Args:
+ values (dict): Mapping of {node_name: numeric_value}.
+ width (float, optional): Width of each heatmap cell. Defaults to 15.0.
+ offset (float, optional): Horizontal offset from tree. Defaults to 10.0.
+ low_color (str, optional): Color for min value. Defaults to "#f7fbff".
+ high_color (str, optional): Color for max value. Defaults to "#08306b".
+ border_color (str, optional): Cell border color. Defaults to "none".
+ border_width (float, optional): Cell border width. Defaults to 0.5.
"""
- import os
- import base64
-
- # 1. Coordinate check: Ensure we have the latest positions
- if not hasattr(self.t, 'coordinates'):
- self._calculate_layout()
+ if hasattr(values, "to_dict"): values = values.to_dict()
+ vals = [float(v) for v in values.values() if isinstance(v, (int, float))]
+ if not vals: return
+ vmin, vmax = min(vals), max(vals) + 1e-12
+ max_x = max(l.coordinates[0] for l in self.t.get_leaves())
+ y_step = (self.style.height - (self.style.margin * 2)) / max(len(self.t.get_leaves())-1, 1)
+ for l in self.t.get_leaves():
+ val = values.get(l.name)
+ fill = lerp_color(low_color, high_color, (float(val) - vmin) / (vmax - vmin)) if val is not None else "white"
+ self.drawing.append(draw.Rectangle(max_x + offset, l.y_coord - y_step/2, width, y_step,
+ fill=fill, stroke=border_color, stroke_width=border_width))
- for leaf in self.t.get_leaves():
- lx, ly = leaf.coordinates
-
- # 2. Coordinate Math
- img_x = lx + offset
- img_y = ly - (height / 2)
+ def add_clade_labels(self, labels, offset=40.0, stroke_width=1.5, color="black", font_size=None):
+ """
+ Adds square brackets '[' to group lineages with text labels.
+
+ Args:
+ labels (dict): Mapping {node: label_text}.
+ offset (float, optional): Horizontal offset. Defaults to 40.0.
+ stroke_width (float, optional): Bracket thickness. Defaults to 1.5.
+ color (str, optional): Color. Defaults to "black".
+ font_size (int, optional): Font size. Defaults to style default.
+ """
+ self._pre_flight_check()
+ fs = font_size or self.style.font_size
+ max_x = max(l.coordinates[0] for l in self.t.get_leaves())
+ bracket_x = max_x + offset
+ for target, text in labels.items():
+ try:
+ node = self.t.search_nodes(name=target)[0] if isinstance(target, str) else target
+ leaves = node.get_leaves()
+ y_min = min(l.y_coord for l in leaves)
+ y_max = max(l.y_coord for l in leaves)
+ p = draw.Path(stroke=color, stroke_width=stroke_width, fill="none")
+ p.M(bracket_x - 5, y_min).L(bracket_x, y_min).L(bracket_x, y_max).L(bracket_x - 5, y_max)
+ self.drawing.append(p)
+ self.drawing.append(draw.Text(text, fs, bracket_x + 8, (y_min + y_max) / 2,
+ fill=color, font_family=self.style.font_family,
+ text_anchor="start", dominant_baseline="middle"))
+ except: continue
+
+ def plot_continuous_variable(self, node_to_rgb, stroke_width=None):
+ """
+ Colors branches based on RGB mappings of a continuous trait.
- img_path = os.path.join(image_dir, f"{leaf.name}{extension}")
-
- if os.path.exists(img_path):
- # 3. Manual Embedding
- with open(img_path, "rb") as img_file:
- encoded_string = base64.b64encode(img_file.read()).decode('utf-8')
- data_uri = f"data:image/png;base64,{encoded_string}"
-
- # 4. Rotation Logic
- # We rotate around the center of the image (img_x + width/2, img_y + height/2)
- transform_str = ""
- if rotation != 0:
- center_x = img_x + (width / 2)
- center_y = img_y + (height / 2)
- transform_str = f"rotate({rotation}, {center_x}, {center_y})"
-
- # Create the image object
- img_obj = draw.Image(
- img_x, img_y,
- width, height,
- path=data_uri,
- transform=transform_str # Apply the rotation transform
- )
-
- self.d.append(img_obj)
- else:
- print(f"MISSING IMAGE: No file at {img_path}")
-
-
- def add_ancestral_images(self, image_dir, extension=".png", width=40, height=40, rotation=0, omit=None):
- """
- Adds PNG images centered on ancestral (internal) nodes.
- :param rotation: Degrees to rotate the image (clockwise).
- """
- import os
- import base64
-
- # 1. Coordinate check: Ensure we have the latest positions
- if not hasattr(self.t, 'coordinates'):
- self._calculate_layout()
-
- # Traverse all nodes but filter for internal ones
- for node in self.t.traverse():
- if not node.is_leaf():
-
- if omit and node.name in omit:
- continue # Skip omitted names
-
- nx, ny = node.coordinates #
-
- # 2. Centering Math
- # To make the center of the image coincide with the node,
- # we subtract half the width/height from the top-left anchor.
- img_x = nx - (width / 2)
- img_y = ny - (height / 2)
-
- img_path = os.path.join(image_dir, f"{node.name}{extension}")
-
- if os.path.exists(img_path):
- # 3. Manual Embedding
- with open(img_path, "rb") as img_file:
- encoded_string = base64.b64encode(img_file.read()).decode('utf-8')
- data_uri = f"data:image/png;base64,{encoded_string}"
-
- # 4. Rotation Logic (Centered)
- transform_str = ""
- if rotation != 0:
- transform_str = f"rotate({rotation}, {nx}, {ny})"
-
- # Create the image object centered on (nx, ny)
- img_obj = draw.Image(
- img_x, img_y,
- width, height,
- path=data_uri,
- transform=transform_str
- )
-
- self.d.append(img_obj)
- else:
- # Optional: Print warning for missing ancestral names
- # print(f"MISSING ANCESTRAL IMAGE: {img_path}")
- pass
-
- def plot_continuous_variable(self, node_to_rgb, size=None):
- """
- Colors branches using a gradient based on RGB values at each node.
- Also recolors vertical connectors so they are not left black.
- node_to_rgb can map node.name OR node objects to (r,g,b).
- """
- def _to_hex(rgb):
- return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))
-
- s_width = size if size is not None else self.style.branch_size
-
- # 1) Gradient on horizontal segments (parent -> child)
+ Args:
+ node_to_rgb (dict): Mapping {node: (r,g,b)}.
+ stroke_width (float, optional): Branch thickness. Defaults to style default.
+ """
+ def _to_hex(rgb): return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2]))
+ sw = stroke_width or self.style.branch_stroke_width
for node in self.t.traverse():
- if node.is_root():
- continue
-
- parent = node.up
- c_child = node_to_rgb.get(node) or node_to_rgb.get(node.name)
- c_parent = node_to_rgb.get(parent) or node_to_rgb.get(parent.name)
-
- if (c_child is not None) and (c_parent is not None):
- self.gradient_branch(
- node,
- colors=(_to_hex(c_parent), _to_hex(c_child)),
- size=s_width,
- )
+ if node.is_root(): continue
+ c_c = node_to_rgb.get(node) or node_to_rgb.get(node.name)
+ c_p = node_to_rgb.get(node.up) or node_to_rgb.get(node.up.name)
+ if c_c and c_p:
+ self.gradient_branch(node, stroke_width=sw, colors=(_to_hex(c_p), _to_hex(c_c)))
else:
- self.highlight_branch(node, color=self.style.branch_color, size=s_width)
-
- # 2) Re-draw vertical connectors using the internal node's color
- # This covers the "elbows" that are currently appearing as black lines
+ self.highlight_branch(node, stroke_width=sw)
for n in self.t.traverse("postorder"):
- if n.is_leaf():
- continue
-
- # Get the color for the internal node
- col = node_to_rgb.get(n) or node_to_rgb.get(n.name)
- if col is None:
- continue
-
- x, y = n.coordinates
- # Vertical span of all children
- y_min = min(c.y_coord for c in n.children)
- y_max = max(c.y_coord for c in n.children)
-
- self.d.append(draw.Line(
- x, y_min, x, y_max,
- stroke=_to_hex(col),
- stroke_width=s_width,
- stroke_linecap="round"
- ))
-
- # Optional: Overwrite the internal-node dot with the same color
- self.d.append(draw.Circle(x, y, self.style.node_size, fill=_to_hex(col)))
-
- def plot_categorical_trait(
- self,
- data,
- value_col,
- node_col="Node",
- palette=None,
- size=None,
- default_color="black"
- ):
- """
- Colors branches based on categorical data using GRADIENTS for transitions.
-
- :param data: pandas DataFrame containing node names and values.
- :param value_col: The name of the column with the category (e.g., "X").
- :param node_col: The name of the column with node names (default "Node").
- :param palette: Dictionary mapping values -> hex colors.
- Example: {0: "red", 1: "blue"}
- :param size: Thickness of the branches.
- """
- # 1. Parse Data into a dictionary: { "NodeName": value }
- if hasattr(data, "to_dict"):
- # Map Node Name -> Category Value
+ if not n.is_leaf():
+ col = node_to_rgb.get(n) or node_to_rgb.get(n.name)
+ if col:
+ x, y = n.coordinates
+ y_min = min(c.y_coord for c in n.children)
+ y_max = max(c.y_coord for c in n.children)
+ self.drawing.append(draw.Line(x, y_min, x, y_max, stroke=_to_hex(col), stroke_width=sw, stroke_linecap="round"))
+ self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=_to_hex(col)))
+
+ def plot_categorical_trait(self, data, value_col, node_col="Node", palette=None, stroke_width=None, default_color="black"):
+ """
+ Colors branches and nodes based on categorical lineage traits.
+
+ Args:
+ data (DataFrame or dict): Trait data.
+ value_col (str): Column for values.
+ node_col (str, optional): Column for node names. Defaults to "Node".
+ palette (dict, optional): Color map. Defaults to auto.
+ stroke_width (float, optional): Branch thickness. Defaults to style default.
+ default_color (str, optional): Fallback color. Defaults to "black".
+ """
+ if hasattr(data, "to_dict"):
mapping = dict(zip(data[node_col].astype(str), data[value_col]))
- else:
- mapping = data
-
- # 2. Define Default Palette if none provided
+ else: mapping = data
if palette is None:
- defaults = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"]
unique_vals = sorted(list(set(mapping.values())))
+ defaults = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"]
palette = {val: defaults[i % len(defaults)] for i, val in enumerate(unique_vals)}
-
- s_width = size if size is not None else self.style.branch_size
-
- # Helper to safely get color hex string
- def get_color(n):
- val = mapping.get(n.name)
- return palette.get(val, default_color)
-
- # 3. Draw Branches
+ sw = stroke_width or self.style.branch_stroke_width
+ def get_color(n): return palette.get(mapping.get(n.name), default_color)
for node in self.t.traverse():
- if node.is_root():
- continue
-
- c_node = get_color(node)
- c_parent = get_color(node.up)
-
- # --- HORIZONTAL BRANCH ---
- if c_node != c_parent:
- # Different colors? Use Gradient!
- # We reuse your existing gradient_branch function
- self.gradient_branch(
- node,
- colors=(c_parent, c_node),
- size=s_width
- )
+ if node.is_root(): continue
+ c_n, c_p = get_color(node), get_color(node.up)
+ if c_n != c_p:
+ self.gradient_branch(node, colors=(c_p, c_n), stroke_width=sw)
else:
- # Same color? Solid line.
x, y = node.coordinates
- px, _ = node.up.coordinates # parent X, child Y (rectangular elbow)
- self.d.append(draw.Line(
- px, y, x, y,
- stroke=c_node,
- stroke_width=s_width,
- stroke_linecap="round"
- ))
-
- # 4. Draw Vertical Connectors (Elbows) & Nodes
- # These take the color of the internal node itself
+ px, _ = node.up.coordinates
+ self.drawing.append(draw.Line(px, y, x, y, stroke=c_n, stroke_width=sw, stroke_linecap="round"))
for n in self.t.traverse("postorder"):
- if n.is_leaf():
- continue
+ if not n.is_leaf():
+ color = get_color(n)
+ x, y = n.coordinates
+ y_min = min(c.y_coord for c in n.children)
+ y_max = max(c.y_coord for c in n.children)
+ self.drawing.append(draw.Line(x, y_min, x, y_max, stroke=color, stroke_width=sw, stroke_linecap="round"))
+ self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=color))
- color = get_color(n)
-
- x, y = n.coordinates
- y_min = min(c.y_coord for c in n.children)
- y_max = max(c.y_coord for c in n.children)
-
- # Vertical Line
- self.d.append(draw.Line(
- x, y_min, x, y_max,
- stroke=color,
- stroke_width=s_width,
- stroke_linecap="round"
- ))
-
- # Node Circle (covers the join)
- self.d.append(draw.Circle(x, y, self.style.node_size, fill=color))
\ No newline at end of file
+ def add_categorical_legend(self, palette, title="Legend", x=None, y=None, font_size=14, r=6):
+ """
+ Adds a categorical legend (colored circles with labels).
+
+ Args:
+ palette (dict): Map of {label: color}.
+ title (str, optional): Legend title. Defaults to "Legend".
+ x (float, optional): X position. Defaults to auto.
+ y (float, optional): Y position. Defaults to auto.
+ font_size (int, optional): Font size. Defaults to 14.
+ r (float, optional): Marker radius. Defaults to 6.
+ """
+ if x is None: x = -self.style.width / 2 + 30
+ if y is None: y = -self.style.height / 2 + 30
+ self.drawing.append(draw.Text(title, font_size + 2, x, y, font_weight="bold",
+ font_family=self.style.font_family, text_anchor="start"))
+ curr_y = y + font_size * 1.5
+ for label, color in palette.items():
+ self.drawing.append(draw.Circle(x + r, curr_y, r, fill=color))
+ self.drawing.append(draw.Text(str(label), font_size, x + r*2.5, curr_y,
+ font_family=self.style.font_family, text_anchor="start", dominant_baseline="middle"))
+ curr_y += font_size * 1.4
+
+ def add_transfer_legend(self, colors=("purple", "orange"), labels=("Departure", "Arrival"), x=None, y=None, font_size=14):
+ """
+ Adds a legend specifically for Horizontal Gene Transfers.
+
+ Args:
+ colors (tuple, optional): (start_color, end_color). Defaults to ("purple", "orange").
+ labels (tuple, optional): (start_label, end_label). Defaults to ("Departure", "Arrival").
+ x (float, optional): X position.
+ y (float, optional): Y position.
+ font_size (int, optional): Font size. Defaults to 14.
+ """
+ palette = {labels[0]: colors[0], labels[1]: colors[1]}
+ self.add_categorical_legend(palette, title="Transfer Event", x=x, y=y, font_size=font_size)
+
+ def add_color_bar(self, low_color, high_color, vmin, vmax, title="", x=None, y=None, width=100, height=15, font_size=12):
+ """
+ Adds a continuous color bar gradient legend.
+
+ Args:
+ low_color (str): Color for min value.
+ high_color (str): Color for max value.
+ vmin (float): Min value.
+ vmax (float): Max value.
+ title (str, optional): Title.
+ x (float, optional): X position.
+ y (float, optional): Y position.
+ width (float, optional): Bar width. Defaults to 100.
+ height (float, optional): Bar height. Defaults to 15.
+ font_size (int, optional): Font size. Defaults to 12.
+ """
+ if x is None: x = -self.style.width / 2 + 30
+ if y is None: y = self.style.height / 2 - 60
+ gid = generate_id("cb_grad")
+ grad = draw.LinearGradient(x, y, x + width, y, id=gid)
+ grad.add_stop(0, low_color); grad.add_stop(1, high_color)
+ self.drawing.append(grad)
+ if title:
+ self.drawing.append(draw.Text(title, font_size, x, y - 10, font_weight="bold", text_anchor="start"))
+ self.drawing.append(draw.Rectangle(x, y, width, height, fill=grad, stroke="black", stroke_width=0.5))
+ self.drawing.append(draw.Text(f"{vmin:.2g}", font_size - 2, x, y + height + 12, text_anchor="start"))
+ self.drawing.append(draw.Text(f"{vmax:.2g}", font_size - 2, x + width, y + height + 12, text_anchor="end"))
+
+ def add_leaf_images(self, image_dir, extension=".png", width=40, height=40, offset=10):
+ """
+ Places images next to leaf tips.
+
+ Args:
+ image_dir (str): Directory containing image files.
+ extension (str, optional): File extension. Defaults to ".png".
+ width (float, optional): Image width. Defaults to 40.
+ height (float, optional): Image height. Defaults to 40.
+ offset (float, optional): Horizontal offset. Defaults to 10.
+ """
+ for leaf in self.t.get_leaves():
+ lx, ly = self._leaf_xy(leaf, offset=offset)
+ path = os.path.join(image_dir, f"{leaf.name}{extension}")
+ if os.path.exists(path):
+ with open(path, "rb") as f:
+ uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}"
+ self.drawing.append(draw.Image(lx - width/2, ly - height/2, width, height, path=uri))
+
+ def add_ancestral_images(self, image_dir, extension=".png", width=40, height=40, omit=None):
+ """
+ Places images at internal node positions.
+
+ Args:
+ image_dir (str): Directory containing image files.
+ extension (str, optional): File extension. Defaults to ".png".
+ width (float, optional): Image width. Defaults to 40.
+ height (float, optional): Image height. Defaults to 40.
+ omit (list, optional): List of node names to skip. Defaults to None.
+ """
+ for node in self.t.traverse():
+ if not node.is_leaf():
+ if omit and node.name in omit: continue
+ nx, ny = self._node_xy(node)
+ path = os.path.join(image_dir, f"{node.name}{extension}")
+ if os.path.exists(path):
+ with open(path, "rb") as f:
+ uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}"
+ self.drawing.append(draw.Image(nx - width/2, ny - height/2, width, height, path=uri))
+
+ def add_title(self, text, font_size=24, position="top", pad=40.0, color="black", weight="bold"):
+ """
+ Adds a title text to the drawing.
+
+ Args:
+ text (str): Title text.
+ font_size (int, optional): Font size. Defaults to 24.
+ position (str, optional): "top", "bottom", "left", "right". Defaults to "top".
+ pad (float, optional): Padding. Defaults to 40.0.
+ color (str, optional): Color. Defaults to "black".
+ weight (str, optional): Font weight. Defaults to "bold".
+ """
+ w, h = self.style.width, self.style.height
+ tx, ty = 0, 0
+ if position == "top": ty = -h/2 + pad
+ elif position == "bottom": ty = h/2 - pad
+ elif position == "left": tx = -w/2 + pad
+ elif position == "right": tx = w/2 - pad
+ self.drawing.append(draw.Text(
+ text, font_size, tx, ty, fill=color, font_weight=weight,
+ font_family=self.style.font_family, text_anchor="middle", dominant_baseline="middle"
+ ))
+
+ def add_scale_bar(self, length, label=None, x=None, y=None, stroke="black", stroke_width=2.0):
+ """
+ Adds a physical scale bar representing a distance value.
+
+ Args:
+ length (float): The distance value the bar represents.
+ label (str, optional): Text label. Defaults to str(length).
+ x (float, optional): X position. Defaults to auto.
+ y (float, optional): Y position. Defaults to auto.
+ stroke (str, optional): Color. Defaults to "black".
+ stroke_width (float, optional): Thickness. Defaults to 2.0.
+ """
+ self._pre_flight_check()
+ px = float(length) * self.sf
+ label = label or str(length)
+ x = x if x is not None else -self.style.width/2 + 20
+ y = y if y is not None else self.style.height/2 - 20
+ self.drawing.append(draw.Line(x, y, x + px, y, stroke=stroke, stroke_width=stroke_width))
+ self.drawing.append(draw.Text(label, self.style.font_size, x + px/2, y - 8, center=True))
diff --git a/src/phylustrator/utils.py b/src/phylustrator/utils.py
index c596ce6..bae88b7 100644
--- a/src/phylustrator/utils.py
+++ b/src/phylustrator/utils.py
@@ -1,24 +1,53 @@
from __future__ import annotations
-
import ete3
+import math
+import random
+import string
+def generate_id(prefix: str = "id", length: int = 6) -> str:
+ """Generates a unique ID for SVG elements like gradients."""
+ suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=length))
+ return f"{prefix}_{suffix}"
-def add_origin_if_root_has_dist(tree: ete3.Tree, origin_name: str = "Origin") -> ete3.Tree:
- """If `tree.dist` is non-zero, interpret it as a stem and add an explicit origin node.
+def to_rgb(color_str: str) -> tuple[int, int, int]:
+ """Parses hex, common names, or RGB tuples into a standard RGB tuple."""
+ color_str = str(color_str).strip().lower()
+ if color_str.startswith("#"):
+ h = color_str.lstrip("#")
+ if len(h) == 3: h = "".join([c*2 for c in h])
+ return tuple(int(h[i:i+2], 16) for i in (0, 2, 4))
+ common_names = {
+ "white": (255, 255, 255), "black": (0, 0, 0), "red": (255, 0, 0),
+ "green": (0, 128, 0), "blue": (0, 0, 255), "orange": (255, 165, 0),
+ "purple": (128, 0, 128), "yellow": (255, 255, 0), "gray": (128, 128, 128)
+ }
+ return common_names.get(color_str, (0, 0, 0))
+
+def to_hex(rgb: tuple[int, int, int]) -> str:
+ """Converts an RGB tuple to a hex string."""
+ return "#{:02x}{:02x}{:02x}".format(*[int(max(0, min(255, x))) for x in rgb])
+
+def lerp_color(low_hex: str, high_hex: str, t: float) -> str:
+ """Interpolates between two colors."""
+ t = max(0.0, min(1.0, t))
+ c1 = to_rgb(low_hex)
+ c2 = to_rgb(high_hex)
+ return to_hex(tuple(c1[i] + (c2[i] - c1[i]) * t for i in range(3)))
- This avoids layout shifts when a rooted Newick encodes a stem length as the root's `dist`.
+def polar_to_cartesian(degree: float, radius: float, rotation: float = 0) -> tuple[float, float]:
+ """Converts polar coordinates (degree, radius) to (x, y)."""
+ theta = math.radians(degree + rotation)
+ return radius * math.cos(theta), radius * math.sin(theta)
- Returns the (possibly new) root tree.
- """
- stem = float(getattr(tree, "dist", 0.0) or 0.0)
+def add_origin_if_root_has_dist(tree: ete3.Tree, origin_name: str = "Origin") -> ete3.Tree:
+ """Standardizes trees by adding an explicit origin node if the root has a distance."""
+ stem = float(tree.dist or 0.0)
if stem <= 0.0:
tree.dist = 0.0
return tree
-
origin = ete3.Tree()
origin.name = origin_name
origin.dist = 0.0
-
tree.dist = stem
origin.add_child(tree)
return origin
diff --git a/tests/tests_drawings.py b/tests/tests_drawings.py
new file mode 100644
index 0000000..0938772
--- /dev/null
+++ b/tests/tests_drawings.py
@@ -0,0 +1,116 @@
+import pytest
+import ete3
+from phylustrator.drawing import VerticalTreeDrawer, RadialTreeDrawer, TreeStyle
+
+@pytest.fixture
+def simple_tree():
+ # Simple tree: (A:1, B:1);
+ return ete3.Tree("(A:1, B:1);")
+
+@pytest.fixture
+def transfer_data():
+ return [{"from": "A", "to": "B", "freq": 1.0}]
+
+@pytest.fixture
+def trait_data():
+ return {"A": 1.0, "B": 2.0}
+
+def test_vertical_layout_init(simple_tree):
+ style = TreeStyle(width=500, height=500)
+ drawer = VerticalTreeDrawer(simple_tree, style=style)
+
+ # Check if layout was calculated
+ for node in simple_tree.traverse():
+ assert hasattr(node, "coordinates")
+ assert len(node.coordinates) == 2
+
+def test_radial_layout_init(simple_tree):
+ style = TreeStyle(radius=200)
+ drawer = RadialTreeDrawer(simple_tree, style=style)
+
+ # Check if layout was calculated
+ for node in simple_tree.traverse():
+ assert hasattr(node, "rad")
+ assert hasattr(node, "angle")
+
+ # Check bounds
+ root = simple_tree.get_tree_root()
+ assert root.rad == 0
+ leaf_a = simple_tree.search_nodes(name="A")[0]
+ assert leaf_a.rad == 200
+
+def test_pre_flight_check(simple_tree):
+ drawer = VerticalTreeDrawer(simple_tree)
+ # Reset flag manually
+ drawer._layout_calculated = False
+ # Calling a method that triggers check
+ drawer.add_title("Test")
+ assert drawer._layout_calculated is True
+
+@pytest.mark.parametrize("drawer_class", [VerticalTreeDrawer, RadialTreeDrawer])
+def test_method_existence(drawer_class, simple_tree):
+ """
+ Comprehensive existence check for all public API methods to prevent
+ regressions during refactoring.
+ """
+ drawer = drawer_class(simple_tree)
+
+ required_methods = [
+ "draw",
+ "highlight_clade",
+ "highlight_branch",
+ "gradient_branch",
+ "add_leaf_names",
+ "add_node_names",
+ "add_leaf_shapes",
+ "add_node_shapes",
+ "add_branch_shapes",
+ "plot_transfers",
+ "add_time_axis",
+ "add_heatmap",
+ "add_clade_labels",
+ "plot_continuous_variable",
+ "plot_categorical_trait",
+ "add_categorical_legend",
+ "add_transfer_legend",
+ "add_color_bar",
+ "add_leaf_images",
+ "add_ancestral_images",
+ "add_title",
+ "add_scale_bar",
+ "save_svg",
+ "save_png"
+ ]
+
+ for method in required_methods:
+ assert hasattr(drawer, method), f"{drawer_class.__name__} is missing required method: {method}"
+
+@pytest.mark.parametrize("drawer_class", [VerticalTreeDrawer, RadialTreeDrawer])
+def test_smoke_execution(drawer_class, simple_tree, transfer_data, trait_data):
+ """
+ Executes core methods with dummy data to ensure no internal crashes/SyntaxErrors.
+ """
+ drawer = drawer_class(simple_tree)
+
+ # Core Drawing
+ drawer.draw()
+
+ # Overlays
+ drawer.highlight_clade(simple_tree, color="red")
+ drawer.add_leaf_names()
+ drawer.add_leaf_shapes(["A"], r=5)
+ drawer.add_branch_shapes([{"branch": "A", "where": 0.5, "shape": "circle"}])
+ drawer.plot_transfers(transfer_data)
+
+ # Labels & Legends
+ drawer.add_clade_labels({"A": "Label"})
+ drawer.add_categorical_legend({"Trait": "blue"})
+ drawer.add_color_bar("white", "blue", 0, 1)
+
+ # Traits
+ drawer.plot_categorical_trait(trait_data, value_col="trait")
+
+ # Title
+ drawer.add_title("Smoke Test")
+
+ assert drawer.drawing is not None
diff --git a/tests/tests_utils.py b/tests/tests_utils.py
new file mode 100644
index 0000000..11d7ad6
--- /dev/null
+++ b/tests/tests_utils.py
@@ -0,0 +1,33 @@
+import pytest
+from phylustrator.utils import to_rgb, to_hex, lerp_color, polar_to_cartesian
+import math
+
+def test_to_rgb():
+ assert to_rgb("#ffffff") == (255, 255, 255)
+ assert to_rgb("#000") == (0, 0, 0)
+ assert to_rgb("red") == (255, 0, 0)
+ assert to_rgb("invalid") == (0, 0, 0)
+
+def test_to_hex():
+ assert to_hex((255, 255, 255)) == "#ffffff"
+ assert to_hex((0, 0, 0)) == "#000000"
+ # Test clipping
+ assert to_hex((300, -10, 100)) == "#ff0064"
+
+def test_lerp_color():
+ # Midpoint between black and white
+ assert lerp_color("#000000", "#ffffff", 0.5) == "#7f7f7f"
+ # Bound checks
+ assert lerp_color("#000000", "#ffffff", -1) == "#000000"
+ assert lerp_color("#000000", "#ffffff", 2) == "#ffffff"
+
+def test_polar_to_cartesian():
+ # 0 degrees, radius 100 should be (100, 0)
+ x, y = polar_to_cartesian(0, 100)
+ assert pytest.approx(x) == 100
+ assert pytest.approx(y) == 0
+
+ # 90 degrees, radius 100 should be (0, 100)
+ x, y = polar_to_cartesian(90, 100)
+ assert pytest.approx(x) == 0
+ assert pytest.approx(y) == 100