diff --git a/examples/figures/radial_tree.png b/examples/figures/radial_tree.png index 5063932..de62964 100644 Binary files a/examples/figures/radial_tree.png and b/examples/figures/radial_tree.png differ diff --git a/examples/figures/radial_tree.svg b/examples/figures/radial_tree.svg index b293bd4..3c26233 100644 --- a/examples/figures/radial_tree.svg +++ b/examples/figures/radial_tree.svg @@ -2,164 +2,153 @@ - + - + + + + + + + + + + - + - + - - + - + - + - - + - - + - + - + - + - - + - - + - + - + - + - - + - - + - + - - + - + - + - + - - + - - + - + - - + - + - + - - + - - + - + - - + - - + - - + - + - + - - + - - - -A -B -C -D -E -F -G -H -I -J -K -L -M -N -O -P -Q -R -S -T +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T - - + + - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + 0 @@ -172,6 +161,27 @@ 2.0 2.5 - - + + + + + + +Event + +Gain + +Loss + +Duplication +Transfer Event + +Departure + +Arrival + +Confidence + +0 +1 \ No newline at end of file diff --git a/examples/figures/vertical_tree.png b/examples/figures/vertical_tree.png index 488e0e9..3ff0d1a 100644 Binary files a/examples/figures/vertical_tree.png and b/examples/figures/vertical_tree.png differ diff --git a/examples/figures/vertical_tree.svg b/examples/figures/vertical_tree.svg index 240ed13..218e23d 100644 --- a/examples/figures/vertical_tree.svg +++ b/examples/figures/vertical_tree.svg @@ -2,11 +2,11 @@ - + - + @@ -109,96 +109,96 @@ -A -B -C -D -E -F -G -H -I -J -K -L -M -N -O -P -Q -R -S -T +A +B +C +D +E +F +G +H +I +J +K +L +M +N +O +P +Q +R +S +T - - + + - - - - + + + + -0.0 +0 -0.5 +0.5 -1.0 +1.0 -1.5 +1.5 -2.0 +2.0 -2.5 -Time - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +2.5 +Time + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/notebooks/00.VerticalTrees.ipynb b/notebooks/00.VerticalTrees.ipynb index c8029b0..95f02f8 100644 --- a/notebooks/00.VerticalTrees.ipynb +++ b/notebooks/00.VerticalTrees.ipynb @@ -27,127 +27,127 @@ "\n", "\n", "\n", - "A\n", - "B\n", - "C\n", - "D\n", - "E\n", - "F\n", - "G\n", - "H\n", - "I\n", - "J\n", - "K\n", - "L\n", - "M\n", - "N\n", - "O\n", - "P\n", - "Q\n", - "R\n", - "S\n", - "T\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "A\n", + "B\n", + "C\n", + "D\n", + "E\n", + "F\n", + "G\n", + "H\n", + "I\n", + "J\n", + "K\n", + "L\n", + "M\n", + "N\n", + "O\n", + "P\n", + "Q\n", + "R\n", + "S\n", + "T\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "" ], "text/plain": [ - "" + "" ] }, "execution_count": 2, @@ -162,9 +162,6 @@ "my_style = ph.TreeStyle(\n", " width=600,\n", " height=600,\n", - " leaf_size=0,\n", - " node_size=0,\n", - " branch_size=2,\n", " branch_color=\"black\",\n", " font_size=12,\n", " font_family=\"Arial\",\n", @@ -178,7 +175,7 @@ }, { "cell_type": "markdown", - "id": "cc8bffea-6d35-426a-9efd-47a984fce9f2", + "id": "f0beb8e8-ee65-43fa-a1a8-b650b8862437", "metadata": {}, "source": [ "# Decorating a tree" @@ -187,7 +184,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "51a5a7c7-6040-4a56-897a-d86d821dfadb", + "id": "7fe4bc28-04cc-4b69-b58f-b4127bc1292d", "metadata": {}, "outputs": [ { @@ -197,209 +194,189 @@ "\n", "\n", - "\n", + "\n", "\n", "\n", "\n", - "\n", + "\n", "\n", "\n", "\n", + "\n", + "\n", + "\n", + "\n", "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", - "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", "\n", "\n", - "\n", - "A\n", - "B\n", - "C\n", - "D\n", - "E\n", - "F\n", - "G\n", - "H\n", - "I\n", - "J\n", - "K\n", - "L\n", - "M\n", - "N\n", - "O\n", - "P\n", - "Q\n", - "R\n", - "S\n", - "T\n", + "A\n", + "B\n", + "C\n", + "D\n", + "E\n", + "F\n", + "G\n", + "H\n", + "I\n", + "J\n", + "K\n", + "L\n", + "M\n", + "N\n", + "O\n", + "P\n", + "Q\n", + "R\n", + "S\n", + "T\n", "\n", "\n", "\n", "\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "\n", "\n", "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", "\n", "\n", - "0.0\n", + "0\n", "\n", - "0.5\n", + "0.5\n", "\n", - "1.0\n", + "1.0\n", "\n", - "1.5\n", + "1.5\n", "\n", - "2.0\n", + "2.0\n", "\n", - "2.5\n", - "Time\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "2.5\n", + "Time\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Lineage Status\n", + "\n", + "Ancestral\n", + "\n", + "Target Clade\n", + "Transfer Event\n", + "\n", + "Departure\n", + "\n", + "Arrival\n", + "\n", + "Expression\n", + "\n", + "0\n", + "1\n", "" ], "text/plain": [ - "" + "" ] }, "execution_count": 3, @@ -411,13 +388,12 @@ "with open(\"../examples/data/basic/tree.nwk\") as f:\n", " t = ete3.Tree(f.readline(), format=1)\n", " \n", - " \n", "my_style = ph.TreeStyle(\n", " width=600,\n", " height=600,\n", - " leaf_size=0,\n", - " node_size=0,\n", - " branch_size=2,\n", + " leaf_r=0, \n", + " node_r=0, \n", + " branch_stroke_width=2, \n", " branch_color=\"black\",\n", " font_size=12,\n", " font_family=\"Arial\",\n", @@ -435,21 +411,21 @@ "v.add_leaf_names()\n", "\n", "# Adding shapes\n", - "\n", - "v.add_leaf_shapes(leaves=[\"A\", \"B\", \"C\", \"D\"],\n", + "v.add_leaf_shapes(\n", + " leaves=[\"A\", \"B\", \"C\", \"D\"],\n", " shape=\"triangle\",\n", " fill=\"blue\",\n", - " size=10,\n", + " r=5, # Changed size=10 -> r=5\n", " stroke=\"black\",\n", " stroke_width=1,\n", - " offset=35, # distance from the leaf tip\n", + " offset=35, \n", ")\n", "\n", "v.add_leaf_shapes(\n", " leaves=[\"J\", \"M\"],\n", " shape=\"square\",\n", " fill=\"orange\",\n", - " size=8,\n", + " r=4, # Changed size=8 -> r=4\n", " stroke=\"black\",\n", " stroke_width=1,\n", " offset=35,\n", @@ -457,9 +433,9 @@ ")\n", "\n", "events = [\n", - " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"size\": 14},\n", - " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\",\"size\": 14},\n", - " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"size\": 14},\n", + " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"r\": 7}, # size=14 -> r=7\n", + " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\", \"r\": 7}, # size=14 -> r=7\n", + " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"r\": 7}, # size=14 -> r=7\n", "]\n", "\n", "v.add_branch_shapes(events, orient=None, offset=0)\n", @@ -467,20 +443,20 @@ "target = t.get_common_ancestor(\"A\", \"D\") \n", "\n", "events = [\n", - " {\"branch\": target, \"where\": 0.7, \"shape\": \"circle\", \"fill\": \"purple\", \"size\": 14},\n", + " {\"branch\": target, \"where\": 0.7, \"shape\": \"circle\", \"fill\": \"purple\", \"r\": 7}, # size=14 -> r=7\n", "]\n", "\n", "v.add_branch_shapes(events, orient=None, offset=0)\n", "\n", - "target = t.get_common_ancestor(\"E\", \"G\") # To select inner nodes\n", + "target = t.get_common_ancestor(\"E\", \"G\") \n", "v.highlight_clade(target, color=\"orange\", opacity=0.4)\n", "\n", "target = t.get_common_ancestor(\"P\", \"Q\") \n", - "v.highlight_branch(target, color=\"blue\", size=5)\n", + "v.highlight_branch(target, color=\"blue\", stroke_width=5) # size -> stroke_width\n", "\n", "\n", "target = t.get_common_ancestor(\"H\", \"J\") \n", - "v.gradient_branch(target, colors=(\"purple\", \"red\"), size=6)\n", + "v.gradient_branch(target, colors=(\"purple\", \"red\"), stroke_width=6) # size -> stroke_width\n", "\n", "transfer_data = [\n", " {\"from\": \"E\", \"to\": \"A\", \"freq\": 1.0},\n", @@ -488,82 +464,75 @@ "\n", "v.plot_transfers(\n", " transfer_data,\n", - " curve_type=\"C\", \n", - " stroke_width=3,\n", + " curve_type=\"C\", \n", + " stroke_width=3, # already used stroke_width\n", " opacity=0.6,\n", " gradient_colors=(\"purple\", \"orange\") \n", ")\n", "\n", - "\n", "v.add_time_axis(\n", " ticks=[0, 0.5, 1.0, 1.5, 2.0, 2.5], \n", " label=\"Time\", \n", " y_offset=20 \n", ")\n", "\n", - "\n", - "# --- Column 1: \"Expression\" (Blue) ---\n", - "# Create fake data: {LeafName: Value}\n", + "# Heatmaps use width and border_width (standard)\n", "data_col1 = {leaf.name: random.uniform(0, 1) for leaf in t.get_leaves()}\n", - "\n", "v.add_heatmap(\n", " data_col1,\n", " width=15,\n", - " offset=50, # Distance from tree tips (make room for labels)\n", + " offset=50,\n", " low_color=\"white\",\n", " high_color=\"blue\",\n", - " border_color=\"black\", # optional border\n", + " border_color=\"black\",\n", " border_width=0.5\n", ")\n", "\n", - "# --- Column 2: \"Enrichment\" (Red) ---\n", "data_col2 = {leaf.name: random.uniform(0, 100) for leaf in t.get_leaves()}\n", - "\n", - "# Calculate new offset: Previous Offset (50) + Previous Width (25) + Gap (5)\n", "v.add_heatmap(\n", " data_col2,\n", " width=15,\n", - " offset=70, # Placed to the right of the first column\n", - " low_color=\"#fff5f0\", # light reddish tint\n", - " high_color=\"#67000d\", # dark red\n", + " offset=70,\n", + " low_color=\"#fff5f0\",\n", + " high_color=\"#67000d\",\n", " border_color=\"black\",\n", " border_width=0.5\n", ")\n", "\n", "target1 = t.get_common_ancestor(\"A\", \"D\") \n", "target2 = t.get_common_ancestor(\"H\", \"R\") \n", - "v.add_node_shapes([target1, target2], shape=\"circle\", fill=\"red\", size=10, stroke=\"white\", stroke_width=1)\n", + "v.add_node_shapes([target1, target2], shape=\"circle\", fill=\"red\", r=5, stroke=\"white\", stroke_width=1) # size=10 -> r=5\n", "\n", - "v.d\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0ee4c452-b76b-4422-8138-68e604827c9a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "target1.name = ''\n", - "target2.name = ''\n" - ] - } - ], - "source": [ - "target1 = t.get_common_ancestor(\"E\", \"T\")\n", - "target2 = t.get_common_ancestor(\"H\", \"J\")\n", + "# Explaining the branch/node colors\n", + "v.add_categorical_legend(\n", + " palette={\"Ancestral\": \"blue\", \"Target Clade\": \"green\"}, \n", + " title=\"Lineage Status\",\n", + " x=-280, y=-280 # Top-left area\n", + ")\n", "\n", - "print(\"target1.name =\", repr(target1.name))\n", - "print(\"target2.name =\", repr(target2.name))\n" + "# Explaining the purple-to-orange transfers\n", + "v.add_transfer_legend(\n", + " colors=(\"purple\", \"orange\"),\n", + " x=-280, y=-180 # Positioned below the categorical legend\n", + ")\n", + "\n", + "# Color bar for the first heatmap (Expression)\n", + "v.add_color_bar(\n", + " low_color=\"white\", \n", + " high_color=\"blue\", \n", + " vmin=0, vmax=1, \n", + " title=\"Expression\",\n", + " x=200, # Bottom-left area\n", + ")\n", + "\n", + "\n", + "\n", + "v.d" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 13, "id": "733505c7-a793-4694-bbf6-878117be5256", "metadata": {}, "outputs": [], @@ -574,422 +543,6 @@ "# PNG requires cairosvg (install: pip install \"phylustrator[export]\")\n", "v.save_png(\"../examples/figures/vertical_tree.png\", scale=3.0)\n" ] - }, - { - "cell_type": "markdown", - "id": "48cf7f0a-0814-44a6-a759-c23cbe8eb2cb", - "metadata": {}, - "source": [ - "# Adding tree with transitions" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "9c4d584b-59f1-423c-99c3-2716555342b2", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "2b769b7e-29ae-4e57-8535-b24b23e937ba", - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with open(\"../examples/data/basic/tree2.nwk\") as f:\n", - " t = ete3.Tree(f.readline(), format=1)\n", - " \n", - "my_style = ph.TreeStyle(\n", - " width=800,\n", - " height=800,\n", - " leaf_size=0,\n", - " node_size=0,\n", - " branch_size=4,\n", - " branch_color=\"black\",\n", - ")\n", - "\n", - "data = pd.read_csv(\"../examples/data/basic/traits.tsv\", sep=\"\\t\")\n", - "\n", - "v = ph.VerticalTreeDrawer(t, style=my_style)\n", - "\n", - "v.plot_categorical_trait(\n", - " data=data, \n", - " value_col=\"X\", \n", - " node_col=\"Node\", \n", - " size=4 \n", - ")\n", - "\n", - "v.add_leaf_shapes(\n", - " leaves=[\"J\", \"M\"],\n", - " shape=\"square\",\n", - " fill=\"orange\",\n", - " size=8,\n", - " stroke=\"black\",\n", - " stroke_width=1,\n", - " offset=35,\n", - " rotation=45,\n", - ")\n", - "\n", - "v.d" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1b23fd17-041c-495f-be21-6a0b0b1619ce", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/notebooks/01.RadialTrees.ipynb b/notebooks/01.RadialTrees.ipynb index d868e29..834f2d0 100644 --- a/notebooks/01.RadialTrees.ipynb +++ b/notebooks/01.RadialTrees.ipynb @@ -1,9 +1,17 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "b827bf3b-e23d-42e5-adcb-50b541b835a0", + "metadata": {}, + "source": [ + "# Radial Tree" + ] + }, { "cell_type": "code", - "execution_count": 12, - "id": "fb785990-6ccc-47ff-8648-5b260c23b90b", + "execution_count": 1, + "id": "9876fdef-6062-43e1-b509-71ea8c07f82f", "metadata": {}, "outputs": [ { @@ -13,164 +21,153 @@ "\n", "\n", - "\n", + "\n", "\n", "\n", "\n", - "\n", + "\n", "\n", "\n", "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", + "\n", "\n", - "\n", - "\n", + "\n", "\n", - "\n", - "\n", - "\n", - "A\n", - "B\n", - "C\n", - "D\n", - "E\n", - "F\n", - "G\n", - "H\n", - "I\n", - "J\n", - "K\n", - "L\n", - "M\n", - "N\n", - "O\n", - "P\n", - "Q\n", - "R\n", - "S\n", - "T\n", + "A\n", + "B\n", + "C\n", + "D\n", + "E\n", + "F\n", + "G\n", + "H\n", + "I\n", + "J\n", + "K\n", + "L\n", + "M\n", + "N\n", + "O\n", + "P\n", + "Q\n", + "R\n", + "S\n", + "T\n", "\n", "\n", "\n", "\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", "0\n", "\n", @@ -183,24 +180,44 @@ "2.0\n", "\n", "2.5\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Event\n", + "\n", + "Gain\n", + "\n", + "Loss\n", + "\n", + "Duplication\n", + "Transfer Event\n", + "\n", + "Departure\n", + "\n", + "Arrival\n", + "\n", + "Confidence\n", + "\n", + "0\n", + "1\n", "" ], "text/plain": [ - "" + "" ] }, - "execution_count": 12, + "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# Cell 1: Imports\n", "import phylustrator as ph\n", "import ete3\n", - "\n", + "import random\n", "\n", "with open(\"../examples/data/basic/tree.nwk\") as f:\n", " t = ete3.Tree(f.readline())\n", @@ -208,9 +225,9 @@ "my_style = ph.TreeStyle(\n", " radius=250,\n", " degrees=360,\n", - " leaf_size=0,\n", - " node_size=0,\n", - " branch_size=2,\n", + " leaf_r=0, # Changed leaf_size -> leaf_r\n", + " node_r=0, # Changed node_size -> node_r\n", + " branch_stroke_width=2, # Changed branch_size -> branch_stroke_width\n", " branch_color=\"black\",\n", " font_size=12,\n", " font_family=\"Arial\",\n", @@ -228,51 +245,48 @@ "r.draw(branch2color=node_colors)\n", "r.add_leaf_names()\n", "\n", - "\n", + "# Adding shapes\n", "r.add_leaf_shapes(\n", " leaves=[\"A\", \"B\", \"C\", \"D\"],\n", " shape=\"triangle\",\n", " fill=\"blue\",\n", - " size=10,\n", + " r=5, # size=10 -> r=5 (radius)\n", " stroke=\"black\",\n", " stroke_width=1,\n", " offset=35, \n", - " orient=True\n", ")\n", "\n", "r.add_leaf_shapes(\n", " leaves=[\"J\", \"M\"],\n", " shape=\"square\",\n", " fill=\"orange\",\n", - " size=8,\n", + " r=4, # size=8 -> r=4 (radius)\n", " stroke=\"black\",\n", " stroke_width=1,\n", " offset=35,\n", - " orient=True\n", - "\n", + " rotation=45,\n", ")\n", "\n", "events = [\n", - " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"size\": 14},\n", - " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\",\"size\": 14},\n", - " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"size\": 14},\n", + " {\"branch\": \"K\", \"where\": 0.2, \"shape\": \"circle\", \"fill\": \"red\", \"r\": 7}, # size=14 -> r=7\n", + " {\"branch\": \"K\", \"where\": 0.5, \"shape\": \"circle\", \"fill\": \"orange\",\"r\": 7}, # size=14 -> r=7\n", + " {\"branch\": \"K\", \"where\": 0.8, \"shape\": \"circle\", \"fill\": \"darkgreen\", \"r\": 7}, # size=14 -> r=7\n", "]\n", "r.add_branch_shapes(events)\n", "\n", - "import random\n", "heatmap_vals = {leaf.name: random.uniform(0, 1) for leaf in t.get_leaves()}\n", "\n", - "r.add_ring_heatmap(\n", + "r.add_heatmap( # add_ring_heatmap -> add_heatmap\n", " heatmap_vals,\n", " width=20,\n", - " padding=80, # Push it out further than the previous ring\n", + " offset=80, # padding -> offset\n", " low_color=\"#e0ecf4\",\n", " high_color=\"#8856a7\"\n", ")\n", "\n", "transfer_data = [\n", " {\"from\": \"E\", \"to\": \"A\", \"freq\": 1.0},\n", - " {\"from\": \"P\", \"to\": \"F\", \"freq\": 0.5},\n", + " {\"from\": \"R\", \"to\": \"F\", \"freq\": 0.5},\n", "]\n", "\n", "r.plot_transfers(\n", @@ -281,7 +295,7 @@ " stroke_width=3,\n", " opacity=0.6,\n", " gradient_colors=(\"purple\", \"orange\"),\n", - " arc_intensity=80 # Higher value = deeper arc towards center\n", + " arc_intensity=80 \n", ")\n", "\n", "r.add_time_axis(\n", @@ -289,28 +303,37 @@ " label=\"\",\n", " stroke=\"gray\",\n", " stroke_dasharray=\"3,3\",\n", - "\n", ")\n", "\n", "target_clade = t.get_common_ancestor(\"E\", \"G\")\n", "r.highlight_clade(target_clade, color=\"orange\", opacity=0.3)\n", "\n", + "# Testing newly implemented parity functions\n", + "target_node = t.get_common_ancestor(\"H\", \"J\")\n", + "r.highlight_branch(target_node, color=\"red\", stroke_width=4)\n", + "r.gradient_branch(target_node, colors=(\"yellow\", \"red\"), stroke_width=6)\n", + "r.add_node_names(color=\"blue\", padding=20)\n", "\n", - "r.d" + "# To explain categorical traits\n", + "r.add_categorical_legend({\"Gain\": \"blue\", \"Loss\": \"red\", \"Duplication\":\"green\"}, title=\"Event\")\n", + "\n", + "# To explain your transfers\n", + "r.add_transfer_legend(colors=(\"purple\", \"orange\"), x=250)\n", + "\n", + "# To explain your heatmap\n", + "r.add_color_bar(\"#e0ecf4\", \"#8856a7\", vmin=0, vmax=1, title=\"Confidence\")\n", + "\n", + "r.d\n" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 2, "id": "d4237e6e-b776-4f90-8a5d-18530e4298f6", "metadata": {}, "outputs": [], "source": [ - "# Cell 5: Export\n", - "# SVG always works\n", "r.save_svg(\"../examples/figures/radial_tree.svg\")\n", - "\n", - "# PNG requires cairosvg\n", "r.save_png(\"../examples/figures/radial_tree.png\", scale=3.0)" ] } diff --git a/src/phylustrator/drawing/base.py b/src/phylustrator/drawing/base.py index e35e13d..c42f74b 100644 --- a/src/phylustrator/drawing/base.py +++ b/src/phylustrator/drawing/base.py @@ -2,8 +2,10 @@ from dataclasses import dataclass from pathlib import Path import math -import io import re +import os +import base64 +from ..utils import to_hex, to_rgb, lerp_color, generate_id try: import cairosvg @@ -12,775 +14,243 @@ @dataclass class TreeStyle: + """ + Configuration object for visual styling parameters of the phylogenetic tree. + + Attributes: + width (int): Total width of the drawing canvas in pixels. Defaults to 1000. + height (int): Total height of the drawing canvas in pixels. Defaults to 1000. + radius (int): Radius of the tree layout (for radial trees). Defaults to 400. + degrees (int): Total angular span of the tree in degrees (for radial trees). Defaults to 360. + rotation (int): Global rotation offset in degrees. Defaults to -90 (starting at 12 o'clock). + margin (float): Margin padding around the tree in pixels. Defaults to 100.0. + root_stub_length (float): Length of the root node's "stub" branch. Defaults to 20.0. + leaf_r (float): Radius of the circle drawn at leaf tips. Defaults to 5.0. + leaf_color (str): CSS color string for leaf nodes. Defaults to "black". + branch_stroke_width (float): Thickness of branch lines. Defaults to 2.0. + branch_color (str): CSS color string for branches. Defaults to "black". + node_r (float): Radius of internal nodes. Defaults to 2.0. + font_size (int): Base font size for text elements. Defaults to 12. + font_family (str): Font family for text elements. Defaults to "Arial". + """ width: int = 1000 height: int = 1000 radius: int = 400 degrees: int = 360 rotation: int = -90 - leaf_size: int = 5 + margin: float = 100.0 + root_stub_length: float = 20.0 + leaf_r: float = 5.0 leaf_color: str = "black" - branch_size: int = 2 + branch_stroke_width: float = 2.0 branch_color: str = "black" - node_size: int = 2 + node_r: float = 2.0 font_size: int = 12 font_family: str = "Arial" class BaseDrawer: + """ + Abstract base class providing shared UI, rendering primitives, and export functionality. + """ def __init__(self, tree, style=None): + """ + Initialize the drawer with a tree structure and style configuration. + + Args: + tree (ete3.TreeNode): The tree object to be visualized. + style (TreeStyle, optional): Custom style configuration. If None, default style is used. + """ self.t = tree self.style = style if style else TreeStyle() - self.d = draw.Drawing(self.style.width, self.style.height, origin='center') - self.d.append(draw.Rectangle(-self.style.width/2, -self.style.height/2, - self.style.width, self.style.height, fill="white")) + self.drawing = draw.Drawing(self.style.width, self.style.height, origin='center') + self.drawing.append(draw.Rectangle(-self.style.width/2, -self.style.height/2, + self.style.width, self.style.height, fill="white")) + self.d = self.drawing self.total_tree_depth = 0 self.sf = 1.0 + self._layout_calculated = False + + def _pre_flight_check(self): + """Ensure layout calculations are performed before drawing operations.""" + if not self._layout_calculated: + self._calculate_layout() + self._layout_calculated = True + def _calculate_layout(self): + """ + Abstract method to calculate node coordinates. + Must be implemented by subclasses. + """ + raise NotImplementedError - def _draw_shape_at( - self, - x: float, - y: float, - shape: str, - fill: str, - size: float, - stroke: str | None, - stroke_width: float, - rotation: float = 0.0, - opacity: float = 1.0, - ) -> None: + def _draw_shape_at(self, x, y, shape, fill, r, stroke=None, stroke_width=1.0, rotation=0, opacity=1.0): + """Internal helper to render geometric shapes at specific coordinates.""" common = {"fill": fill, "opacity": opacity} - if stroke is not None: + if stroke: common["stroke"] = stroke common["stroke_width"] = stroke_width - shp = str(shape).lower() - rot = float(rotation) - - # drawsvg uses SVG transforms; rotate around (x,y) - transform = None if rot == 0 else f"rotate({rot},{x},{y})" - + transform = f"rotate({rotation},{x},{y})" if rotation != 0 else None if shp == "circle": - self.d.append(draw.Circle(x, y, float(size) / 2.0, **common)) - return - - if shp == "square": - s = float(size) - self.d.append( - draw.Rectangle(x - s / 2.0, y - s / 2.0, s, s, transform=transform, **common) - ) - return - - if shp == "triangle": - s = float(size) - h = s * math.sqrt(3) / 2.0 - p1 = (x, y - (2.0 / 3.0) * h) - p2 = (x - s / 2.0, y + (1.0 / 3.0) * h) - p3 = (x + s / 2.0, y + (1.0 / 3.0) * h) - - path = draw.Path(transform=transform, **common) - path.M(*p1).L(*p2).L(*p3).Z() - self.d.append(path) - return - - raise ValueError(f"Unknown shape: {shape!r}. Use circle/square/triangle.") - - - def _get_rotated_svg(self, rotation: float) -> str: - """ - Internal helper: Wraps the current drawing in a new SVG - transformed to handle rotation and resize the canvas. - """ - original_svg = self.d.as_svg() - - if rotation == 0: - return original_svg - - # 1. CLEANUP: Strip XML declaration and DOCTYPE from the inner SVG - # because they are only allowed at the very start of a file. - original_svg = re.sub(r'<\?xml.*?\?>', '', original_svg) - original_svg = re.sub(r'', '', original_svg) - - # 2. Calculate new canvas dimensions to prevent clipping - w, h = self.style.width, self.style.height - rad = math.radians(rotation) - new_w = abs(w * math.cos(rad)) + abs(h * math.sin(rad)) - new_h = abs(w * math.sin(rad)) + abs(h * math.cos(rad)) - - # 3. Wrap cleaned SVG in a group with center rotation - return ( - f'\n' - f' \n' - f' {original_svg}\n' - f' \n' - f'' - ) - - def save_svg(self, outpath: str | Path, rotation: float = 0.0) -> None: - outpath = Path(outpath) - outpath.parent.mkdir(parents=True, exist_ok=True) - - svg_content = self._get_rotated_svg(rotation) - - with open(outpath, "w", encoding="utf-8") as f: - f.write(svg_content) - - def save_png(self, outpath: str | Path, scale: float = 1.0, rotation: float = 0.0) -> None: - if cairosvg is None: - raise ImportError( - "PNG export requires 'cairosvg'. Please install it via pip." - ) - - outpath = Path(outpath) - outpath.parent.mkdir(parents=True, exist_ok=True) - - svg_content = self._get_rotated_svg(rotation) - - cairosvg.svg2png( - bytestring=svg_content.encode("utf-8"), - write_to=str(outpath), - scale=scale - ) - - def add_legend( - self, - title: str, - mapping: dict, - position="top-left", - symbol: str = "circle", - text_size: int | None = None, - padding: int = 20, - box_padding: int = 10, - box_fill: str = "white", - box_opacity: float = 0.9, - box_stroke: str = "black", - box_stroke_width: float = 1.0, - symbol_size: int = 10, - row_gap: int = 6, - ): - """Add a simple categorical legend. - - Coordinates follow the library convention: origin at canvas center. - - Parameters - ---------- - title: - Legend title. - mapping: - Dict of {label: color}. - position: - "top-left", "top-right", "bottom-left", "bottom-right", or (x, y) tuple - specifying the *top-left* corner of the legend box. - symbol: - "circle", "square", or "line". + self.drawing.append(draw.Circle(x, y, float(r), **common)) + elif shp == "square": + side = float(r) * 2.0 + self.drawing.append(draw.Rectangle(x - r, y - r, side, side, transform=transform, **common)) + elif shp == "triangle": + side = float(r) * 2.0 + h = side * math.sqrt(3) / 2.0 + p1, p2, p3 = (x, y - h*2/3), (x - side/2, y + h/3), (x + side/2, y + h/3) + path = draw.Path(transform=transform, **common).M(*p1).L(*p2).L(*p3).Z() + self.drawing.append(path) + + def add_title(self, text, font_size=24, position="top", pad=40.0, color="black", weight="bold"): + """ + Adds a title text to the drawing at a fixed cardinal position. + + Args: + text (str): The title text to display. + font_size (int, optional): Font size in pixels. Defaults to 24. + position (str, optional): Position anchor ("top", "bottom", "left", "right"). Defaults to "top". + pad (float, optional): Padding distance from the edge. Defaults to 40.0. + color (str, optional): Text color. Defaults to "black". + weight (str, optional): Font weight (e.g., "bold", "normal"). Defaults to "bold". """ - if not mapping: - return - - font_size = int(text_size) if text_size is not None else int(self.style.font_size) - font_family = getattr(self.style, "font_family", "Arial") - - # --- Size estimation (SVG has no font metrics here, so we approximate) --- - # Typical monospace-ish heuristic: average character ~0.6*font_size - def est_width(s: str) -> float: - return 0.6 * font_size * len(str(s)) - - max_label_w = max([est_width(title)] + [est_width(k) for k in mapping.keys()]) - content_w = symbol_size + 8 + max_label_w - n_rows = len(mapping) - title_h = font_size + 4 - row_h = font_size + row_gap - content_h = title_h + (n_rows * row_h) - box_w = content_w + 2 * box_padding - box_h = content_h + 2 * box_padding - w, h = self.style.width, self.style.height + tx, ty = 0, 0 + if position == "top": ty = -h/2 + pad + elif position == "bottom": ty = h/2 - pad + elif position == "left": tx = -w/2 + pad + elif position == "right": tx = w/2 - pad + self.drawing.append(draw.Text( + text, font_size, tx, ty, fill=color, font_weight=weight, + font_family=self.style.font_family, text_anchor="middle", dominant_baseline="middle" + )) - # --- Anchor (top-left of legend box) --- - if isinstance(position, tuple) and len(position) == 2: - x0, y0 = position - elif position == "top-left": - x0, y0 = -w / 2 + padding, -h / 2 + padding - elif position == "top-right": - x0, y0 = w / 2 - padding - box_w, -h / 2 + padding - elif position == "bottom-left": - x0, y0 = -w / 2 + padding, h / 2 - padding - box_h - elif position == "bottom-right": - x0, y0 = w / 2 - padding - box_w, h / 2 - padding - box_h - else: - x0, y0 = -w / 2 + padding, -h / 2 + padding - - # Background box - self.d.append( - draw.Rectangle( - x0, - y0, - box_w, - box_h, - fill=box_fill, - opacity=box_opacity, - stroke=box_stroke, - stroke_width=box_stroke_width, - ) - ) - - # Title - tx = x0 + box_padding - ty = y0 + box_padding + font_size - self.d.append( - draw.Text( - title, - font_size + 1, - tx, - ty, - font_family=font_family, - font_weight="bold", - fill="black", - ) - ) - - # Rows - y = ty + (font_size + row_gap) - sym_x = x0 + box_padding + symbol_size / 2 - text_x = x0 + box_padding + symbol_size + 8 - for label, color in mapping.items(): - if symbol == "circle": - self.d.append( - draw.Circle(sym_x, y - font_size * 0.35, symbol_size / 2, fill=color, stroke="black", stroke_width=0.5) - ) - elif symbol == "square": - self.d.append( - draw.Rectangle( - sym_x - symbol_size / 2, - y - font_size * 0.85, - symbol_size, - symbol_size, - fill=color, - stroke="black", - stroke_width=0.5, - ) - ) - elif symbol == "line": - self.d.append( - draw.Line( - sym_x - symbol_size / 2, - y - font_size * 0.5, - sym_x + symbol_size / 2, - y - font_size * 0.5, - stroke=color, - stroke_width=2, - ) - ) - - self.d.append(draw.Text(str(label), font_size, text_x, y, font_family=font_family, fill="black")) - y += row_h - - - def add_scale_bar( - self, - length: float, - label: str | None = None, - x: float | None = None, - y: float | None = None, - stroke: str = "black", - stroke_width: float = 2.0, - tick_size: float = 6.0, - font_size: int | None = None, - font_family: str | None = None, - padding: float = 10.0, - ) -> None: - """Add a simple scale bar (tree-length legend). - - Parameters - ---------- - length - Length in *tree units* (same units used to scale branches). - label - Text label. If None, uses the numeric length. - x, y - Anchor position (left end) in drawing coordinates. If omitted, places the - bar near the bottom-left with padding. - """ - px = float(length) * float(self.sf) - if label is None: - label = str(length) - - if x is None: - x = -float(self.style.width) / 2.0 + float(padding) - if y is None: - y = float(self.style.height) / 2.0 - float(padding) - - fs = int(font_size) if font_size is not None else int(self.style.font_size) - ff = font_family if font_family is not None else self.style.font_family - - # main bar - self.d.append(draw.Line(x, y, x + px, y, stroke=stroke, stroke_width=stroke_width)) - # end ticks - self.d.append(draw.Line(x, y - tick_size / 2.0, x, y + tick_size / 2.0, stroke=stroke, stroke_width=stroke_width)) - self.d.append(draw.Line(x + px, y - tick_size / 2.0, x + px, y + tick_size / 2.0, stroke=stroke, stroke_width=stroke_width)) - - # label above the bar - self.d.append(draw.Text(label, fs, x + px / 2.0, y - tick_size - 2, center=True, font_family=ff)) - - def _leaf_xy(self, leaf, offset: float = 0.0) -> tuple[float, float]: - """Return (x, y) coordinates for a leaf. - - Subclasses must implement this. The ``offset`` parameter is in pixels and - should move the position away from the leaf tip (e.g., to the right for - vertical trees, outward for radial trees). - """ - raise NotImplementedError - - - def add_leaf_shapes( - self, - leaves, - shape: str = "circle", - fill: str = "blue", - size: float = 10, - stroke: str | None = None, - stroke_width: float = 1, - offset: float = 0.0, - rotation: float = 0.0, - orient: str | None = None, # NEW: "radial" or "tangent" (mainly for radial) - opacity: float = 1.0, - ): - if leaves is None: - return - - leaf_nodes = [] - for item in leaves: - if isinstance(item, str): - try: - leaf_nodes.append(self.t & item) # ete3 lookup by name - except Exception: - continue - else: - leaf_nodes.append(item) - - # ensure layout exists once if needed - if leaf_nodes: - if not hasattr(leaf_nodes[0], "rad") and not hasattr(leaf_nodes[0], "coordinates"): - if hasattr(self, "_calculate_layout"): - self._calculate_layout() - - for leaf in leaf_nodes: - x, y = self._leaf_xy(leaf, offset=float(offset)) - - rot = float(rotation) - if orient is not None: - o = str(orient).lower().strip() - - # Use the actual rendered vector (x,y) in the drawing coordinate system. - # SVG has y pointing DOWN, so atan2(y, x) yields a clockwise angle. - a = math.degrees(math.atan2(y, x)) - - # Our triangle path is "pointing up" when rotation=0, - # so to point along direction angle 'a' we add +90 degrees. - if o == "radial": - rot = a + 90.0 - elif o == "tangent": - rot = a + 180.0 - - self._draw_shape_at( - x=x, y=y, - shape=shape, - fill=fill, - size=size, - stroke=stroke, - stroke_width=stroke_width, - rotation=rot, - opacity=opacity, - ) - - def _node_xy(self, node) -> tuple[float, float]: - """Subclasses must implement: x,y coordinates of any node.""" - raise NotImplementedError - - def _edge_point(self, child, where: float) -> tuple[float, float, float]: + def add_text(self, text, x, y, font_size=12, color="black", weight="normal", text_anchor="start", dominant_baseline="middle", rotation=0): """ - Default edge interpolation: straight line from parent to child. - Returns (x,y,angle_degrees_along_edge). - Subclasses can override (vertical should). + Adds arbitrary text at specific Cartesian (x, y) coordinates. + + Args: + text (str): The text content. + x (float): X-coordinate. + y (float): Y-coordinate. + font_size (int, optional): Font size. Defaults to 12. + color (str, optional): CSS color string. Defaults to "black". + weight (str, optional): Font weight. Defaults to "normal". + text_anchor (str, optional): Horizontal alignment ("start", "middle", "end"). Defaults to "start". + dominant_baseline (str, optional): Vertical alignment. Defaults to "middle". + rotation (float, optional): Rotation angle in degrees. Defaults to 0. """ - parent = child.up - x0, y0 = self._node_xy(parent) - x1, y1 = self._node_xy(child) - - t = max(0.0, min(1.0, float(where))) - x = x0 + (x1 - x0) * t - y = y0 + (y1 - y0) * t - - ang = math.degrees(math.atan2(y1 - y0, x1 - x0)) - return x, y, ang + transform = f"rotate({rotation}, {x}, {y})" if rotation != 0 else None + self.drawing.append(draw.Text( + text, font_size, x, y, fill=color, font_weight=weight, + font_family=self.style.font_family, text_anchor=text_anchor, + dominant_baseline=dominant_baseline, transform=transform + )) - def _where_from_time(self, node, t: float) -> float: + def save_svg(self, outpath, rotation=0): """ - Map absolute event time t to a [0,1] position along the incoming edge of `node`. - Requires node.up.time_from_origin and node.time_from_origin (Zombi parser provides these). - """ - parent = node.up - if parent is None: - return 0.0 - - t0 = float(getattr(parent, "time_from_origin", 0.0)) - t1 = float(getattr(node, "time_from_origin", t0)) - denom = (t1 - t0) if abs(t1 - t0) > 1e-12 else 1.0 + Exports the current drawing to an SVG file. - w = (float(t) - t0) / denom - if w < 0.0: - return 0.0 - if w > 1.0: - return 1.0 - return w - - def add_title( - self, - text: str, - fontsize: int = 24, - position: str = "top", - pad: float = 40.0, - rotation: float = 0.0, - color: str = "black", - font_weight: str = "bold" - ) -> None: + Args: + outpath (str or Path): The destination file path. + rotation (float, optional): Global rotation to apply to the entire canvas output. Defaults to 0. """ - Adds a title to the canvas relative to the edges. + self._pre_flight_check() + outpath = Path(outpath) + outpath.parent.mkdir(parents=True, exist_ok=True) + svg_content = self._get_rotated_svg_content(rotation) + with open(outpath, "w", encoding="utf-8") as f: + f.write(svg_content) - Parameters - ---------- - text: The string to display. - fontsize: Size of the font. - position: "top", "bottom", "left", or "right". - pad: Padding from the edge of the canvas. - rotation: Degrees to rotate the text. - color: Text color. - font_weight: "normal" or "bold". + def save_png(self, outpath, dpi=300, scale=None, rotation=0): """ - w, h = self.style.width, self.style.height + Exports the current drawing to a PNG file. - # Default center of canvas - tx, ty = 0, 0 + Requires the `cairosvg` library to be installed. - # Calculate anchor based on position - if position == "top": - ty = -h / 2 + pad - elif position == "bottom": - ty = h / 2 - pad - elif position == "left": - tx = -w / 2 + pad - elif position == "right": - tx = w / 2 - pad + Args: + outpath (str or Path): The destination file path. + dpi (int, optional): Dots per inch for resolution. Defaults to 300. + scale (float, optional): Scaling factor. If None, calculated from DPI. Defaults to None. + rotation (float, optional): Global rotation to apply to the output. Defaults to 0. - # Create the text element - # Note: drawsvg origin='center' is used here - title_obj = draw.Text( - text, - fontsize, - tx, - ty, - fill=color, - font_family=getattr(self.style, "font_family", "Arial"), - font_weight=font_weight, - text_anchor="middle", - dominant_baseline="middle", - transform=f"rotate({rotation},{tx},{ty})" if rotation != 0 else None + Raises: + ImportError: If `cairosvg` is not installed. + """ + if cairosvg is None: + raise ImportError("PNG export requires 'cairosvg'.") + self._pre_flight_check() + outpath = Path(outpath) + outpath.parent.mkdir(parents=True, exist_ok=True) + if scale is None: scale = dpi / 72.0 + svg_content = self._get_rotated_svg_content(rotation) + cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=str(outpath), scale=scale) + + def _get_rotated_svg_content(self, rotation): + """Generates the SVG content string, applying a global rotation if needed.""" + if rotation == 0: return self.drawing.as_svg() + original_svg = self.drawing.as_svg() + original_svg = re.sub(r'<\?xml.*?\?>|', '', original_svg) + w, h = self.style.width, self.style.height + rad = math.radians(rotation) + new_w = abs(w * math.cos(rad)) + abs(h * math.sin(rad)) + new_h = abs(w * math.sin(rad)) + abs(h * math.cos(rad)) + return ( + f'\n' + f' \n' + f' {original_svg}\n' + f' \n' + f'' ) - - self.d.append(title_obj) - - def add_branch_shapes( - self, - specs, - default_where: float = 0.5, - orient: str | None = None, - offset: float = 0.0, - **kwargs # <--- Added to handle extra arguments like stroke_color - ) -> None: + def add_categorical_legend(self, palette, title="Legend", x=None, y=None, font_size=14, r=6): """ - Add shapes on branches. - - Each spec dict can include: - branch (str|node), where, shape, fill, size, stroke, stroke_width, rotation, opacity + Adds a categorical legend (key-value pairs of colors) to the canvas. + + Args: + palette (dict): Mapping of {label: color_string}. + title (str, optional): Title for the legend. Defaults to "Legend". + x (float, optional): Top-left X coordinate. Defaults to calculated position. + y (float, optional): Top-left Y coordinate. Defaults to calculated position. + font_size (int, optional): Text size. Defaults to 14. + r (float, optional): Radius of the colored marker dots. Defaults to 6. """ - # Accept either list[dict] or a DataFrame-like object (e.g. pandas.DataFrame) - if hasattr(specs, "to_dict") and hasattr(specs, "columns"): - specs = specs.to_dict(orient="records") - - for s in specs: - br = s.get("branch", None) - if br is None: - continue - - # resolve node - if isinstance(br, str): - try: - child = self.t & br - except Exception: - continue - else: - child = br - - if child.up is None: - continue # no parent edge - - # Determine position along the branch - if "where" in s and s.get("where") is not None: - where = float(s.get("where")) - elif "time" in s and s.get("time") is not None and hasattr(child, "time_from_origin") and hasattr(child.up, "time_from_origin"): - where = float(self._where_from_time(child, float(s.get("time")))) - else: - where = float(default_where) - - x, y, edge_ang = self._edge_point(child, where=where) - - # optional perpendicular offset - if offset != 0.0: - perp = edge_ang + 90.0 - x += float(offset) * math.cos(math.radians(perp)) - y += float(offset) * math.sin(math.radians(perp)) - - # Handle rotation - rot = float(s.get("rotation", 0.0)) - if orient is not None: - o = str(orient).lower().strip() - if o == "along": - rot = edge_ang - elif o == "perp": - rot = edge_ang + 90.0 - - # Draw the shape using internal helper - self._draw_shape_at( - x=x, y=y, - shape=s.get("shape", "circle"), - fill=s.get("fill", "blue"), - size=float(s.get("size", 10)), - stroke=s.get("stroke", None), - stroke_width=float(s.get("stroke_width", 1)), - rotation=rot, - opacity=float(s.get("opacity", 1.0)), - ) - - def add_colorbar( - self, - vmin: float, - vmax: float, - low_color: str = "#f7fbff", - high_color: str = "#08306b", - label: str | None = None, - ticks: list[float] | None = None, - n_steps: int = 60, - bar_width: float = 18.0, - bar_height: float = 180.0, - margin: float = 18.0, - x_offset: float = 0.0, - y_offset: float = 0.0, - label_pad: float = 10.0, - tick_pad: float = 12.0, - font_size: int | None = None, - font_family: str | None = None, - stroke: str = "black", - stroke_width: float = 1.0, - ): - """Add a vertical colorbar to the drawing (top-right by default). - - You can reposition with x_offset / y_offset (in px), relative to top-right anchor. - Coordinates assume drawsvg Drawing(origin='center'). + if x is None: x = -self.style.width / 2 + 30 + if y is None: y = -self.style.height / 2 + 30 + self.drawing.append(draw.Text(title, font_size + 2, x, y, font_weight="bold", + font_family=self.style.font_family, text_anchor="start")) + curr_y = y + font_size * 1.5 + for label, color in palette.items(): + self.drawing.append(draw.Circle(x + r, curr_y, r, fill=color)) + self.drawing.append(draw.Text(str(label), font_size, x + r*2.5, curr_y, + font_family=self.style.font_family, text_anchor="start", dominant_baseline="middle")) + curr_y += font_size * 1.4 + + def add_color_bar(self, low_color, high_color, vmin, vmax, title="", x=None, y=None, width=100, height=15, font_size=12): """ - def _hex_to_rgb(h: str) -> tuple[int, int, int]: - h = h.lstrip("#") - return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) - - def _rgb_to_hex(rgb: tuple[int, int, int]) -> str: - return "#{:02x}{:02x}{:02x}".format(*rgb) - - def _lerp(a: float, b: float, t: float) -> float: - return a + (b - a) * t - - c0 = _hex_to_rgb(low_color) - c1 = _hex_to_rgb(high_color) - - def _lerp_color(t: float) -> str: - t = 0.0 if t < 0.0 else (1.0 if t > 1.0 else t) - r = int(_lerp(c0[0], c1[0], t)) - g = int(_lerp(c0[1], c1[1], t)) - b = int(_lerp(c0[2], c1[2], t)) - return _rgb_to_hex((r, g, b)) - - fs = int(font_size) if font_size is not None else int(self.style.font_size) - ff = font_family if font_family is not None else self.style.font_family - - # anchor top-right (origin is center) - x0 = float(self.style.width) / 2.0 - float(margin) - float(bar_width) + float(x_offset) - y_top = -float(self.style.height) / 2.0 + float(margin) + float(y_offset) - - # outline box - self.d.append(draw.Rectangle( - x0, y_top, bar_width, bar_height, - fill="none", stroke=stroke, stroke_width=stroke_width - )) - - # gradient fill as stacked rectangles - steps = max(2, int(n_steps)) - step_h = float(bar_height) / steps - for i in range(steps): - t = i / (steps - 1) - fill = _lerp_color(1.0 - t) # high at top - y = y_top + i * step_h - self.d.append(draw.Rectangle(x0, y, bar_width, step_h + 0.5, fill=fill, stroke="none")) - - # default ticks - if ticks is None: - ticks = [vmin, (vmin + vmax) / 2.0, vmax] - - # tick marks + labels (right side) - for tv in ticks: - frac = 0.0 if vmax == vmin else (float(tv) - float(vmin)) / (float(vmax) - float(vmin)) - frac = 0.0 if frac < 0.0 else (1.0 if frac > 1.0 else frac) - - y = y_top + (1.0 - frac) * float(bar_height) - self.d.append(draw.Line( - x0 + bar_width, y, - x0 + bar_width + 6, y, - stroke=stroke, stroke_width=stroke_width - )) - self.d.append(draw.Text( - str(tv), fs, - x0 + bar_width + float(tick_pad), y + fs * 0.35, - font_family=ff - )) - - # label (above) - if label: - self.d.append(draw.Text( - label, fs, - x0 + bar_width / 2.0, y_top - float(label_pad), - center=True, font_family=ff - )) - - def add_node_shapes( - self, - nodes, - shape: str = "circle", - fill: str = "red", - size: float = 10.0, - stroke: str | None = None, - stroke_width: float = 1.0, - rotation: float = 0.0, - dx: float = 0.0, - dy: float = 0.0, - missing: str = "ignore", # "ignore" or "raise" - ) -> None: - """Draw shapes centered on (ancestral) nodes. - - Parameters - ---------- - nodes - Either: - - list of node names (str) or ete3 Node objects - - OR a list of dict specs with keys: - {"node": , "shape": ..., "fill": ..., "size": ..., "stroke": ..., "stroke_width": ..., "rotation": ..., "dx": ..., "dy": ...} - shape - "circle", "square", or "triangle" (ignored when using dict specs per-node). - rotation - Degrees. For triangles/squares; circle ignores it. - dx, dy - Pixel offsets to nudge the marker. - missing - What to do if a node name is not found: "ignore" or "raise". + Adds a continuous linear gradient color bar to the canvas. + + Args: + low_color (str): Color corresponding to the minimum value. + high_color (str): Color corresponding to the maximum value. + vmin (float): The numeric value corresponding to low_color. + vmax (float): The numeric value corresponding to high_color. + title (str, optional): Title text above the bar. Defaults to "". + x (float, optional): Top-left X coordinate. Defaults to calculated position. + y (float, optional): Top-left Y coordinate. Defaults to calculated position. + width (float, optional): Width of the color bar. Defaults to 100. + height (float, optional): Height of the color bar. Defaults to 15. + font_size (int, optional): Font size for labels. Defaults to 12. """ - # allow list[dict] specs - if isinstance(nodes, list) and nodes and isinstance(nodes[0], dict): - for s in nodes: - n = s.get("node", None) - if n is None: - continue - self._add_one_node_shape( - node=n, - shape=str(s.get("shape", shape)), - fill=str(s.get("fill", fill)), - size=float(s.get("size", size)), - stroke=s.get("stroke", stroke), - stroke_width=float(s.get("stroke_width", stroke_width)), - rotation=float(s.get("rotation", rotation)), - dx=float(s.get("dx", dx)), - dy=float(s.get("dy", dy)), - missing=missing, - ) - return - - # uniform style for all nodes - for n in nodes: - self._add_one_node_shape( - node=n, - shape=shape, - fill=fill, - size=size, - stroke=stroke, - stroke_width=stroke_width, - rotation=rotation, - dx=dx, - dy=dy, - missing=missing, - ) - - - def _add_one_node_shape( - self, - node, - shape: str, - fill: str, - size: float, - stroke: str | None, - stroke_width: float, - rotation: float, - dx: float, - dy: float, - missing: str, - ) -> None: - """Internal helper: draw one node marker centered at node coordinates.""" - # resolve node - n = node - if isinstance(node, str): - hits = self.t.search_nodes(name=node) - if not hits: - if missing == "raise": - raise ValueError(f"Node '{node}' not found in tree.") - return - n = hits[0] - - x, y = self._node_xy(n) - x += float(dx) - y += float(dy) - - shp = str(shape).lower() - half = float(size) / 2.0 - - common = {"fill": fill} - if stroke is not None: - common["stroke"] = stroke - common["stroke_width"] = float(stroke_width) - - if shp == "circle": - self.d.append(draw.Circle(x, y, half, **common)) - return - - if shp == "square": - rect = draw.Rectangle(x - half, y - half, float(size), float(size), **common) - if rotation: - rect.args["transform"] = f"rotate({float(rotation)},{x},{y})" - self.d.append(rect) - return - - if shp == "triangle": - # Up-pointing triangle centered at (x,y) - p = draw.Path(**common) - p.M(x, y - half).L(x - half, y + half).L(x + half, y + half).Z() - if rotation: - p.args["transform"] = f"rotate({float(rotation)},{x},{y})" - self.d.append(p) - return - - raise ValueError(f"Unknown shape '{shape}'. Use: circle, square, triangle.") + if x is None: x = -self.style.width / 2 + 30 + if y is None: y = self.style.height / 2 - 60 + gid = generate_id("cb_grad") + grad = draw.LinearGradient(x, y, x + width, y, id=gid) + grad.add_stop(0, low_color); grad.add_stop(1, high_color) + self.drawing.append(grad) + if title: + self.drawing.append(draw.Text(title, font_size, x, y - 10, font_weight="bold", text_anchor="start")) + self.drawing.append(draw.Rectangle(x, y, width, height, fill=grad, stroke="black", stroke_width=0.5)) + self.drawing.append(draw.Text(f"{vmin:.2g}", font_size - 2, x, y + height + 12, text_anchor="start")) + self.drawing.append(draw.Text(f"{vmax:.2g}", font_size - 2, x + width, y + height + 12, text_anchor="end")) diff --git a/src/phylustrator/drawing/radial.py b/src/phylustrator/drawing/radial.py index 7971da8..022f362 100644 --- a/src/phylustrator/drawing/radial.py +++ b/src/phylustrator/drawing/radial.py @@ -1,574 +1,556 @@ -from platform import node import drawsvg as draw import math -import random +import os +import base64 from .base import BaseDrawer - -def radial_converter(degree, radius, rotation=0): - theta = math.radians(degree + rotation) - return radius * math.cos(theta), radius * math.sin(theta) +from ..utils import polar_to_cartesian, generate_id, lerp_color class RadialTreeDrawer(BaseDrawer): + """ + Drawer class for rendering phylogenetic trees in a circular (radial) layout. + + Inherits from BaseDrawer and implements radial-specific geometry calculations + where nodes are positioned by angle and radius. + """ def __init__(self, tree, style=None): + """ + Initializes the RadialTreeDrawer and performs the initial layout calculation. + + Args: + tree (ete3.TreeNode): The tree object to be visualized. + style (TreeStyle, optional): Custom style configuration. + """ super().__init__(tree, style) self._calculate_layout() def _rot_ang(self, ang_deg: float) -> float: - # layout angle + style rotation (degrees) - return float(ang_deg) + float(getattr(self.style, "rotation", 0.0)) - + """Applies the global rotation offset to a given angle.""" + return float(ang_deg) + float(self.style.rotation) - def _node_xy(self, node): + def _node_xy(self, node) -> tuple[float, float]: + """Calculates Cartesian (x, y) coordinates for a node in polar space.""" if not (hasattr(node, "rad") and hasattr(node, "angle")): self._calculate_layout() - r = float(node.rad) - ang = self._rot_ang(node.angle) - th = math.radians(ang) - return (r * math.cos(th), r * math.sin(th)) + return polar_to_cartesian(self._rot_ang(node.angle), node.rad) - - def _leaf_xy(self, leaf, offset: float = 0.0): + def _leaf_xy(self, leaf, offset: float = 0.0) -> tuple[float, float]: + """Calculates Cartesian coordinates for a leaf tip with a radial offset.""" if not (hasattr(leaf, "rad") and hasattr(leaf, "angle")): self._calculate_layout() - r = float(leaf.rad) + float(offset) - ang = self._rot_ang(leaf.angle) - th = math.radians(ang) - return (r * math.cos(th), r * math.sin(th)) - + return polar_to_cartesian(self._rot_ang(leaf.angle), leaf.rad + float(offset)) - def _edge_point(self, child, where: float): - """ - Place a point along the drawn branch centerline (radial segment at CHILD angle). - where=0 => (r=parent.rad, angle=child.angle) - where=1 => (r=child.rad, angle=child.angle) - """ + def _edge_point(self, child, where: float) -> tuple[float, float, float]: + """Finds a point along the radial branch leading to a child node.""" parent = child.up if parent is None: - x, y = self._node_xy(child) - return x, y, 0.0 - - if not (hasattr(parent, "rad") and hasattr(child, "rad") and hasattr(child, "angle")): - self._calculate_layout() - - r0 = float(parent.rad) - r1 = float(child.rad) + return (*self._node_xy(child), self._rot_ang(child.angle)) + r_p, r_c = float(parent.rad), float(child.rad) t = max(0.0, min(1.0, float(where))) - r = r0 + (r1 - r0) * t - + r = r_p + (r_c - r_p) * t ang = self._rot_ang(child.angle) - th = math.radians(ang) - x = r * math.cos(th) - y = r * math.sin(th) - - # orientation "along" the branch direction - edge_ang = ang if (r1 - r0) >= 0 else (ang + 180.0) - return x, y, edge_ang + x, y = polar_to_cartesian(ang, r) + return x, y, ang def _calculate_layout(self): - max_dist = 0 + """ + Computes polar coordinates (radius and angle) for all nodes in the tree. + + This method is called automatically during initialization or before drawing. + """ + # 1. Radial Scaling (Distances) + max_dist = 0.0 for n in self.t.traverse("preorder"): n.dist_to_root = n.up.dist_to_root + n.dist if not n.is_root() else getattr(n, "dist", 0.0) max_dist = max(max_dist, n.dist_to_root) self.total_tree_depth = max_dist - self.sf = self.style.radius / max_dist if max_dist > 0 else 1 + self.sf = float(self.style.radius) / max_dist if max_dist > 0 else 1.0 + + for n in self.t.traverse(): + n.rad = n.dist_to_root * self.sf + + # 2. Angular Scaling (Leaves) leaves = self.t.get_leaves() - self.angle_step = self.style.degrees / len(leaves) + span = float(self.style.degrees) + angle_step = span / max(len(leaves), 1) for i, leaf in enumerate(leaves): - leaf.angle = i * self.angle_step + leaf.angle = i * angle_step + + # 3. Internal Centerings for n in self.t.traverse("postorder"): if not n.is_leaf(): - n.angle = (n.children[0].angle + n.children[-1].angle) / 2 - n.rad = n.dist_to_root * self.sf - n.xy = radial_converter(n.angle, n.rad, self.style.rotation) + if n.children: + n.angle = sum(c.angle for c in n.children) / len(n.children) + else: + n.angle = 0.0 + self._layout_calculated = True - def draw(self, branch2color=None, hide_radial_lines=None): - """ - Standard radial drawing loop. - :param branch2color: Dictionary {node_object: color_string} - :param hide_radial_lines: List of nodes to skip parent connections + def draw(self, branch2color=None): """ - hidden_set = set(hide_radial_lines) if hide_radial_lines else set() + Draws the main tree skeleton. + Connectors are drawn as circular arcs and branches as radial lines. + + Args: + branch2color (dict, optional): Dictionary mapping `ete3.TreeNode` objects to + CSS color strings. Used to color specific branches. + """ + self._pre_flight_check() for n in self.t.traverse("postorder"): - # Resolve Color - b_color = self.style.branch_color - if branch2color and n in branch2color: - b_color = branch2color[n] - if b_color == "None": continue + color = branch2color.get(n, self.style.branch_color) if branch2color else self.style.branch_color + x, y = self._node_xy(n) + if not n.is_root(): + parent = n.up + px, py = self._node_xy(parent) + start_ang, end_ang = self._rot_ang(parent.angle), self._rot_ang(n.angle) + if abs(start_ang - end_ang) > 0.001: + ax, ay = polar_to_cartesian(end_ang, parent.rad) + path = draw.Path(stroke=color, stroke_width=self.style.branch_stroke_width, fill="none") + sweep = 1 if end_ang > start_ang else 0 + path.M(px, py).A(parent.rad, parent.rad, 0, 0, sweep, ax, ay) + self.drawing.append(path) + self.drawing.append(draw.Line(ax, ay, x, y, stroke=color, stroke_width=self.style.branch_stroke_width)) + else: + self.drawing.append(draw.Line(px, py, x, y, stroke=color, stroke_width=self.style.branch_stroke_width)) + + r_v = self.style.leaf_r if n.is_leaf() else self.style.node_r + if r_v > 0: + fill = self.style.leaf_color if n.is_leaf() else color + self.drawing.append(draw.Circle(x, y, r_v, fill=fill)) - x, y = n.xy + def highlight_clade(self, node, color="lightblue", opacity=0.3, padding=10): + """ + Draws a shaded "donut wedge" background behind a specific clade. + + Args: + node (ete3.TreeNode): The root node of the clade to highlight. + color (str, optional): Fill color. Defaults to "lightblue". + opacity (float, optional): Fill opacity (0.0 to 1.0). Defaults to 0.3. + padding (float, optional): Radial padding in pixels. Defaults to 10. + """ + self._pre_flight_check() + leaves = node.get_leaves() + angles = [l.angle for l in leaves] + start_ang, end_ang = self._rot_ang(min(angles)), self._rot_ang(max(angles)) + r_i, r_o = node.rad - padding, max(l.rad for l in leaves) + padding + + p = draw.Path(fill=color, fill_opacity=opacity) + p1x, p1y = polar_to_cartesian(start_ang, r_i) + p2x, p2y = polar_to_cartesian(end_ang, r_i) + p3x, p3y = polar_to_cartesian(end_ang, r_o) + p4x, p4y = polar_to_cartesian(start_ang, r_o) + sweep = 1 if end_ang > start_ang else 0 + p.M(p1x, p1y).A(r_i, r_i, 0, 0, sweep, p2x, p2y).L(p3x, p3y).A(r_o, r_o, 0, 0, 0 if sweep == 1 else 1, p4x, p4y).Z() + self.drawing.append(p) + + def highlight_branch(self, node, color="red", stroke_width=None): + """ + Overlays a thicker or colored arc/line on the branch leading to the specific node. - # 1. Radial Line to Parent - if not n.is_root() and n not in hidden_set: - px, py = radial_converter(n.angle, n.up.rad, self.style.rotation) - self.d.append(draw.Line(px, py, x, y, stroke=b_color, stroke_width=self.style.branch_size)) + Args: + node (ete3.TreeNode): The target node. + color (str, optional): CSS color string. Defaults to "red". + stroke_width (float, optional): Thickness of the highlight. Defaults to 2x standard width. + """ + if node.is_root(): return + sw = stroke_width or self.style.branch_stroke_width * 2 + px, py = self._node_xy(node.up) + x, y = self._node_xy(node) + start_ang, end_ang = self._rot_ang(node.up.angle), self._rot_ang(node.angle) + + if abs(start_ang - end_ang) > 0.001: + ax, ay = polar_to_cartesian(end_ang, node.up.rad) + path = draw.Path(stroke=color, stroke_width=sw, fill="none", stroke_linecap="round") + sweep = 1 if end_ang > start_ang else 0 + path.M(px, py).A(node.up.rad, node.up.rad, 0, 0, sweep, ax, ay) + self.drawing.append(path) + self.drawing.append(draw.Line(ax, ay, x, y, stroke=color, stroke_width=sw, stroke_linecap="round")) + else: + self.drawing.append(draw.Line(px, py, x, y, stroke=color, stroke_width=sw, stroke_linecap="round")) + + def gradient_branch(self, node, colors=("red", "blue"), stroke_width=None): + """ + Applies a linear color gradient along a radial branch segment. - # 2. Connector Arc and Nodes - if not n.is_leaf(): - a1, a2 = n.children[0].angle, n.children[-1].angle - p = draw.Path(stroke=b_color, stroke_width=self.style.branch_size, fill="none") - sx, sy = radial_converter(a1, n.rad, self.style.rotation) - ex, ey = radial_converter(a2, n.rad, self.style.rotation) - p.M(sx, sy).A(n.rad, n.rad, 0, 0, 1, ex, ey) - self.d.append(p) - - if not n.is_root(): - self.d.append(draw.Circle(x, y, self.style.node_size, fill=b_color)) - else: - self.d.append(draw.Circle(x, y, self.style.leaf_size, fill=self.style.leaf_color)) + Args: + node (ete3.TreeNode): The target node. + colors (tuple, optional): Tuple of (start_color, end_color). Defaults to ("red", "blue"). + stroke_width (float, optional): Thickness of the branch. Defaults to style default. + """ + if node.is_root(): return + sw = stroke_width or self.style.branch_stroke_width + px, py = self._node_xy(node.up) + x, y = self._node_xy(node) + gid = generate_id("grad") + grad = draw.LinearGradient(px, py, x, y, id=gid) + grad.add_stop(0, colors[0]) + grad.add_stop(1, colors[1]) + self.drawing.append(grad) + self.highlight_branch(node, color=grad, stroke_width=sw) + + def add_leaf_names(self, font_size=None, color="black", padding=10): + """ + Adds text labels to leaf tips, automatically rotated to match the radial angle. - def add_ring(self, mapping, width=20, padding=10): - r_in = self.style.radius + padding - r_out = r_in + width - for l in self.t.get_leaves(): - if l.name not in mapping: continue - a1, a2 = l.angle - self.angle_step/2, l.angle + self.angle_step/2 - p = draw.Path(fill=mapping[l.name]) - s_i_x, s_i_y = radial_converter(a1, r_in, self.style.rotation) - e_i_x, e_i_y = radial_converter(a2, r_in, self.style.rotation) - s_o_x, s_o_y = radial_converter(a1, r_out, self.style.rotation) - e_o_x, e_o_y = radial_converter(a2, r_out, self.style.rotation) - p.M(s_o_x, s_o_y).A(r_out, r_out, 0, 0, 1, e_o_x, e_o_y).L(e_i_x, e_i_y).A(r_in, r_in, 0, 0, 0, s_i_x, s_i_y).Z() - self.d.append(p) - - def add_leaf_names(self, padding=15): + Args: + font_size (int, optional): Font size in pixels. Defaults to style default. + color (str, optional): Text color. Defaults to "black". + padding (float, optional): Distance from leaf tip to text start. Defaults to 10. + """ + fs = font_size or self.style.font_size for l in self.t.get_leaves(): - angle = l.angle + self.style.rotation - ang_mod = angle % 360.0 - - # SVG y-axis is downward, so the "right side" includes angles near 360 as well - right_side = (ang_mod <= 90.0) or (ang_mod >= 270.0) - - # place at leaf tip radius, not at a fixed style.radius - x, y = radial_converter(l.angle, float(l.rad) + float(padding), self.style.rotation) - - rot = angle if right_side else (angle - 180.0) - anchor = "start" if right_side else "end" - - self.d.append( - draw.Text( - l.name, - self.style.font_size, - x, - y, - transform=f"rotate({rot},{x},{y})", - text_anchor=anchor, - dominant_baseline="middle", - font_family=self.style.font_family, - ) - ) - - - def plot_transfers( - self, - transfers, - mode="midpoint", - curve_type="C", - filter_below=0.1, - use_gradient=True, - gradient_colors=("purple", "orange"), - color="orange", - use_thickness=True, - stroke_width=5, - arc_intensity=40, - opacity=0.6, - ): - """ - Plot transfers on a radial tree. - - mode="time" places endpoints at the correct event time along each endpoint branch - using node.time_from_origin (from Zombi parser). Otherwise uses midpoint logic. - """ - # Accept either list[dict] or a DataFrame-like object (e.g. pandas.DataFrame) - if hasattr(transfers, "to_dict") and hasattr(transfers, "columns"): - transfers = transfers.to_dict(orient="records") - - # Ensure layout exists - any_node = next(self.t.traverse("preorder")) - if not hasattr(any_node, "rad") or not hasattr(any_node, "angle"): - self._calculate_layout() + ang = (self._rot_ang(l.angle)) % 360 + x, y = polar_to_cartesian(ang, l.rad + padding) + text_rot, anchor = ang, "start" + if 90 < ang < 270: + text_rot += 180 + anchor = "end" + self.drawing.append(draw.Text(l.name, fs, x, y, fill=color, font_family=self.style.font_family, + transform=f"rotate({text_rot}, {x}, {y})", text_anchor=anchor, dominant_baseline="middle")) + + def add_node_names(self, font_size=None, color="gray", padding=5): + """ + Adds text labels to internal node positions. - name_to_node = {n.name: n for n in self.t.traverse()} - - def where_from_time(node, tt: float) -> float: - parent = node.up - if parent is None: - return 0.0 - t0 = float(getattr(parent, "time_from_origin", 0.0)) - t1 = float(getattr(node, "time_from_origin", t0)) - denom = (t1 - t0) if abs(t1 - t0) > 1e-12 else 1.0 - w = (float(tt) - t0) / denom - if w < 0.0: - return 0.0 - if w > 1.0: - return 1.0 - return w + Args: + font_size (int, optional): Font size. Defaults to style default * 0.8. + color (str, optional): Text color. Defaults to "gray". + padding (float, optional): Radial offset. Defaults to 5. + """ + fs = font_size or self.style.font_size * 0.8 + for n in self.t.traverse(): + if not n.is_leaf() and n.name: + ang = (self._rot_ang(n.angle)) % 360 + x, y = polar_to_cartesian(ang, n.rad + padding) + text_rot = ang + (180 if 90 < ang < 270 else 0) + self.drawing.append(draw.Text(n.name, fs, x, y, fill=color, font_family=self.style.font_family, + transform=f"rotate({text_rot}, {x}, {y})", text_anchor="middle", dominant_baseline="middle")) + + def add_leaf_shapes(self, leaves, shape="circle", fill="blue", r=5.0, stroke=None, stroke_width=1.0, offset=0.0, rotation=0.0, opacity=1.0, orient=False): + """ + Adds geometric markers next to specific leaf tips. + + Args: + leaves (list): List of node names (str) or objects to mark. + shape (str, optional): Shape type ("circle", "square", "triangle"). Defaults to "circle". + fill (str, optional): Fill color. Defaults to "blue". + r (float, optional): Radius/Size of the shape. Defaults to 5.0. + stroke (str, optional): Border color. Defaults to None. + stroke_width (float, optional): Border width. Defaults to 1.0. + offset (float, optional): Radial distance offset from the leaf tip. Defaults to 0.0. + rotation (float, optional): Additional rotation for the shape. Defaults to 0.0. + opacity (float, optional): Opacity (0.0 to 1.0). Defaults to 1.0. + orient (bool, optional): If True, rotates shape to match the branch angle. Defaults to False. + """ + self._pre_flight_check() + for item in leaves: + try: + node = self.t & item if isinstance(item, str) else item + ang = self._rot_ang(node.angle) + x, y = polar_to_cartesian(ang, node.rad + float(offset)) + rot = rotation + (ang if orient else 0.0) + self._draw_shape_at(x, y, shape, fill, r, stroke, stroke_width, rot, opacity) + except: continue + + def add_node_shapes(self, nodes, shape="circle", fill="red", r=5.0, stroke=None, stroke_width=1.0, rotation=0, dx=0, dy=0, orient=False): + """ + Adds geometric markers centered on specific node positions. + + Args: + nodes (list): List of node names/objects OR list of dicts with specific style per node. + shape (str, optional): Default shape. Defaults to "circle". + fill (str, optional): Default fill color. Defaults to "red". + r (float, optional): Default radius. Defaults to 5.0. + stroke (str, optional): Default stroke color. Defaults to None. + stroke_width (float, optional): Default stroke width. Defaults to 1.0. + rotation (float, optional): Default rotation. Defaults to 0. + dx (float, optional): Cartesian X offset. Defaults to 0. + dy (float, optional): Cartesian Y offset. Defaults to 0. + orient (bool, optional): If True, rotates shape to match the branch angle. Defaults to False. + """ + self._pre_flight_check() + if isinstance(nodes, list) and nodes and isinstance(nodes[0], dict): + for s in nodes: + self.add_node_shapes([s.get("node")], s.get("shape", shape), s.get("fill", fill), s.get("r", r), + s.get("stroke", stroke), s.get("stroke_width", stroke_width), s.get("rotation", rotation), orient=orient) + return + for item in nodes: + try: + node = self.t.search_nodes(name=item)[0] if isinstance(item, str) else item + ang = self._rot_ang(node.angle) + x, y = polar_to_cartesian(ang, node.rad) + rot = rotation + (ang if orient else 0.0) + self._draw_shape_at(x + dx, y + dy, shape, fill, r, stroke, stroke_width, rot) + except: continue + + def add_branch_shapes(self, specs, default_where=0.5, offset=0.0, orient=None): + """ + Adds geometric markers along branches (e.g., to visualize events). + Args: + specs (list or DataFrame): Data definitions. Must contain 'branch' key/column. + default_where (float, optional): Position along branch (0.0 to 1.0). Defaults to 0.5. + offset (float, optional): Perpendicular offset from the branch line. Defaults to 0.0. + orient (str, optional): Rotation mode: "along" (matches branch), "perp" (90 deg to branch), or None. + """ + self._pre_flight_check() + if hasattr(specs, "to_dict"): specs = specs.to_dict(orient="records") + for s in specs: + br = s.get("branch") + if not br: continue + try: + node = self.t & br if isinstance(br, str) else br + where = s.get("where", default_where) + x, y, ang = self._edge_point(node, where=where) + if offset != 0: + perp = math.radians(ang + 90) + x += offset * math.cos(perp) + y += offset * math.sin(perp) + rot = s.get("rotation", 0.0) + if orient == "along": rot = ang + elif orient == "perp": rot = ang + 90 + r_val = float(s.get("r", s.get("size", 10.0) / 2.0)) + self._draw_shape_at(x, y, s.get("shape", "circle"), s.get("fill", "blue"), r_val, + s.get("stroke"), s.get("stroke_width", 1.0), rot, s.get("opacity", 1.0)) + except: continue + + def plot_transfers(self, transfers, mode="midpoint", curve_type="C", filter_below=0.0, use_gradient=True, + gradient_colors=("purple", "orange"), color="orange", stroke_width=5.0, arc_intensity=50.0, opacity=0.6): + """ + Plots curved HGT (Horizontal Gene Transfer) links between lineages in radial space. + + Args: + transfers (list or DataFrame): List of dicts with 'from', 'to', 'freq' keys. + mode (str, optional): "time" (uses 'time' key for position) or "midpoint". Defaults to "midpoint". + curve_type (str, optional): Bezier curve type ("C" or "S"). Defaults to "C". + filter_below (float, optional): Minimum frequency to plot. Defaults to 0.0. + use_gradient (bool, optional): If True, colors fade from source to dest. Defaults to True. + gradient_colors (tuple, optional): (Start color, End color). Defaults to ("purple", "orange"). + color (str, optional): Solid color if gradients are disabled. Defaults to "orange". + stroke_width (float, optional): Base thickness of the transfer lines. Defaults to 5.0. + arc_intensity (float, optional): Curvature strength (control point distance). Defaults to 50.0. + opacity (float, optional): Opacity of the transfer lines. Defaults to 0.6. + """ + if hasattr(transfers, "to_dict"): transfers = transfers.to_dict(orient="records") + name2node = {n.name: n for n in self.t.traverse()} + self._pre_flight_check() + def get_where(node, t_ev): + p = node.up + if not p: return 0.0 + t0, t1 = float(getattr(p, "time_from_origin", 0.0)), float(getattr(node, "time_from_origin", 0.0)) + return max(0.0, min(1.0, (float(t_ev) - t0) / (t1 - t0))) if abs(t1 - t0) > 1e-12 else 0.5 + for tr in transfers: freq = float(tr.get("freq", 1.0)) - if freq < filter_below: - continue - - src = name_to_node.get(tr.get("from")) - dst = name_to_node.get(tr.get("to")) - if src is None or dst is None: - continue - - src_angle = float(src.angle) - dst_angle = float(dst.angle) - - # Compute endpoints - if ( - mode == "time" - and tr.get("time") is not None - and hasattr(src, "time_from_origin") - and hasattr(dst, "time_from_origin") - ): - tt = float(tr["time"]) - w_src = where_from_time(src, tt) - w_dst = where_from_time(dst, tt) - sx, sy, _ = self._edge_point(src, w_src) - ex, ey, _ = self._edge_point(dst, w_dst) - - # radii for control points - src_r_mid = (sx * sx + sy * sy) ** 0.5 - dst_r_mid = (ex * ex + ey * ey) ** 0.5 - + if freq < filter_below: continue + src, dst = name2node.get(tr.get("from")), name2node.get(tr.get("to")) + if not src or not dst: continue + if mode == "time" and "time" in tr: + x_s, y_s, _ = self._edge_point(src, get_where(src, tr["time"])) + x_e, y_e, _ = self._edge_point(dst, get_where(dst, tr["time"])) else: - # midpoint fallback (old behavior) - src_r = float(src.rad) - dst_r = float(dst.rad) - src_parent_r = float(src.up.rad) if src.up else (src_r - 20.0) - dst_parent_r = float(dst.up.rad) if dst.up else (dst_r - 20.0) - - src_r_mid = (src_r + src_parent_r) / 2.0 - dst_r_mid = (dst_r + dst_parent_r) / 2.0 - - sx, sy = radial_converter(src_angle, src_r_mid, self.style.rotation) - ex, ey = radial_converter(dst_angle, dst_r_mid, self.style.rotation) - - width = (stroke_width * freq) if use_thickness else stroke_width - path = draw.Path(stroke_width=width, fill="none", stroke_opacity=opacity) - + x_s, y_s, _ = self._edge_point(src, 0.5); x_e, y_e, _ = self._edge_point(dst, 0.5) + + path = draw.Path(stroke_width=stroke_width * freq, fill="none", stroke_opacity=opacity) if use_gradient: - grad_id = f"tr_grad_{random.randint(0, 999999)}" - grad = draw.LinearGradient(sx, sy, ex, ey, id=grad_id) - grad.add_stop(0, gradient_colors[0]) - grad.add_stop(1, gradient_colors[1]) - self.d.append(grad) - path.args["stroke"] = grad - else: - path.args["stroke"] = color - - path.M(sx, sy) - - if curve_type.upper() == "S": - c1x, c1y = radial_converter(src_angle, src_r_mid + arc_intensity, self.style.rotation) - c2x, c2y = radial_converter(dst_angle, dst_r_mid - arc_intensity, self.style.rotation) - path.C(c1x, c1y, c2x, c2y, ex, ey) - else: - c1x, c1y = radial_converter(src_angle, src_r_mid - arc_intensity, self.style.rotation) - c2x, c2y = radial_converter(dst_angle, dst_r_mid - arc_intensity, self.style.rotation) - path.C(c1x, c1y, c2x, c2y, ex, ey) - - self.d.append(path) - - - def add_transfer_legend( - self, - title="Transfer Frequency", - colors=("purple", "orange"), - low=0.1, - high=1.0, - source_label="Source", - arrival_label="Arrival", - show_frequency=False, - show_direction=True, - margin=20, - ): - """Add a transfer legend. - - By default this shows *direction* (two solid colors): a "Source" swatch and an - "Arrival" swatch. If you also want a frequency scale, set - ``show_frequency=True``. - - Parameters - ---------- - colors: - Tuple (source_color, arrival_color). These match the gradient endpoints - used by ``plot_transfers(..., gradient_colors=...)``. - show_frequency: - Draws a gradient bar + numeric low/high labels. - show_direction: - Draws two solid color swatches labelled source/arrival. - """ - if not (show_frequency or show_direction): - return - - font_size = 11 - num_font_size = 9 - sw = 14 - gap = 6 - pad_x = 10 - top_pad = 10 - bottom_pad = 10 - row_h = 18 - bar_h = 12 - bar_w = 110 - - # Estimate legend width from label lengths (drawsvg doesn't expose text metrics). - max_label_len = 0 - if show_direction: - max_label_len = max(len(str(source_label)), len(str(arrival_label))) - if show_frequency: - max_label_len = max(max_label_len, len(str(title))) - est_text_w = max_label_len * font_size * 0.60 - w = int(pad_x + sw + gap + est_text_w + pad_x) - if show_frequency: - w = max(w, pad_x + bar_w + pad_x) - - # Height: SVG y-axis increases downward. - content_h = top_pad - if show_frequency: - # title + bar + low/high labels + spacing - content_h += (font_size + 4) + bar_h + (num_font_size + 10) + 6 - if show_direction: - content_h += (2 * row_h) - content_h += bottom_pad - box_h = content_h - - x = -self.style.width / 2 + 30 - y = self.style.height / 2 - margin - box_h - - # Background - self.d.append( - draw.Rectangle(x, y, w, box_h, fill="white", stroke="black", stroke_width=1, opacity=0.9) - ) - - cursor_y = y + top_pad + 2 - - # Optional frequency scale - if show_frequency: - self.d.append(draw.Text(title, font_size, x + 10, cursor_y, font_family="sans-serif", font_weight="bold")) - cursor_y += 10 - - grad_id = f"legend_transfer_grad_{random.randint(0, 999999)}" - grad = draw.LinearGradient(x + 10, cursor_y + bar_h / 2, x + 10 + bar_w, cursor_y + bar_h / 2, id=grad_id) - grad.add_stop(0, colors[0]) - grad.add_stop(1, colors[1]) - self.d.append(grad) - self.d.append(draw.Rectangle(x + 10, cursor_y, bar_w, bar_h, fill=grad)) - - # low/high numeric labels - self.d.append(draw.Text(f"{low}", num_font_size, x + 10, cursor_y + bar_h + 12, font_family="sans-serif")) - self.d.append(draw.Text(f"{high}", num_font_size, x + 10 + bar_w - 15, cursor_y + bar_h + 12, font_family="sans-serif")) - cursor_y += bar_h + 24 - - # Direction swatches - if show_direction: - sw = 14 - self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[0])) - self.d.append(draw.Text(source_label, font_size, x + 30, cursor_y + 11, font_family="sans-serif")) - cursor_y += row_h - - self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[1])) - self.d.append(draw.Text(arrival_label, font_size, x + 30, cursor_y + 11, font_family="sans-serif")) - - def add_ring_heatmap( - self, - values, - width: float = 20.0, - padding: float = 10.0, - vmin: float | None = None, - vmax: float | None = None, - low_color: str = "#f7fbff", - high_color: str = "#08306b", - missing_color: str = "white", - ): - """Add a numeric heatmap ring around the radial tree. - - Parameters - ---------- - values: - Mapping leaf_name -> numeric value, OR a pandas Series, OR a DataFrame with columns - ["leaf", "value"] (any column names are fine if you pass a dict instead). - width: - Ring thickness in px. - padding: - Distance outward from the tree radius. - vmin, vmax: - Normalization bounds. If None, computed from provided values. - low_color, high_color: - Gradient endpoints (hex colors). - missing_color: - Fill used if a leaf is missing from the values. - """ - # ---- normalize inputs to dict[str, float] ---- - if hasattr(values, "to_dict") and not isinstance(values, dict): - # pandas Series - values = values.to_dict() - - if hasattr(values, "columns") and hasattr(values, "to_dict") and not isinstance(values, dict): - # DataFrame-like: try common formats - cols = list(values.columns) - if len(cols) >= 2: - leaf_col, val_col = cols[0], cols[1] - values = dict(zip(values[leaf_col].astype(str), values[val_col].astype(float))) - else: - values = {} - - if not isinstance(values, dict): - raise TypeError("values must be a dict, pandas Series, or a DataFrame with 2 columns.") - - # ---- helpers for color interpolation ---- - def _hex_to_rgb(h: str) -> tuple[int, int, int]: - h = h.lstrip("#") - return int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16) - - def _rgb_to_hex(rgb: tuple[int, int, int]) -> str: - return "#{:02x}{:02x}{:02x}".format(*rgb) - - def _lerp(a: float, b: float, t: float) -> float: - return a + (b - a) * t - - c0 = _hex_to_rgb(low_color) - c1 = _hex_to_rgb(high_color) - - def _lerp_color(t: float) -> str: - t = 0.0 if t < 0.0 else (1.0 if t > 1.0 else t) - r = int(_lerp(c0[0], c1[0], t)) - g = int(_lerp(c0[1], c1[1], t)) - b = int(_lerp(c0[2], c1[2], t)) - return _rgb_to_hex((r, g, b)) + gid = generate_id("tr_grad") + grad = draw.LinearGradient(x_s, y_s, x_e, y_e, id=gid) + grad.add_stop(0, gradient_colors[0]); grad.add_stop(1, gradient_colors[1]) + self.drawing.append(grad); path.args["stroke"] = grad + else: path.args["stroke"] = color + path.M(x_s, y_s) + mx, my = (x_s + x_e) / 2.0, (y_s + y_e) / 2.0 + dist = math.sqrt(mx**2 + my**2) + pull = max(0, dist - arc_intensity) if dist > 0 else 0 + cx = (mx/dist)*pull if dist > 0 else 0 + cy = (my/dist)*pull if dist > 0 else 0 + path.Q(cx, cy, x_e, y_e); self.drawing.append(path) + + def add_time_axis(self, ticks, label="", tick_labels=None, stroke="gray", stroke_width=1.0, stroke_dasharray="3,3", font_size=10, label_angle=90): + """ + Adds concentric rings representing evolutionary time or distance steps. + + Args: + ticks (list): List of radial distances to draw rings at. + label (str, optional): Unused in radial currently, but kept for interface consistency. + tick_labels (list, optional): Custom text for each tick. Defaults to numerical values. + stroke (str, optional): Color of rings. Defaults to "gray". + stroke_width (float, optional): Thickness of rings. Defaults to 1.0. + stroke_dasharray (str, optional): Dash pattern. Defaults to "3,3". + font_size (int, optional): Size of tick labels. Defaults to 10. + label_angle (float, optional): Angle (degrees) to place labels at. Defaults to 90. + """ + self._pre_flight_check() + for i, t in enumerate(ticks): + r = t * self.sf + self.drawing.append(draw.Circle(0, 0, r, fill="none", stroke=stroke, stroke_width=stroke_width, stroke_dasharray=stroke_dasharray)) + lx, ly = polar_to_cartesian(label_angle, r) + + # Use custom label if provided + display_text = str(tick_labels[i]) if tick_labels is not None and i < len(tick_labels) else str(t) + self.drawing.append(draw.Text(display_text, font_size, lx, ly, fill="black", stroke="white", stroke_width=0.5, + paint_order="stroke", text_anchor="middle", dominant_baseline="middle", font_family="Arial")) - # ---- determine vmin/vmax from provided numeric values ---- - vals = [] - for k, v in values.items(): + def add_heatmap(self, values, width=15.0, offset=10.0, low_color="#f7fbff", high_color="#08306b", border_color="none", border_width=0.5): + """ + Adds a circular heatmap ring surrounding the leaf tips. + + Args: + values (dict): Mapping of {node_name: numeric_value}. + width (float, optional): Thickness of the heatmap ring. Defaults to 15.0. + offset (float, optional): Radial offset from leaf tips. Defaults to 10.0. + low_color (str, optional): Color for minimum value. Defaults to "#f7fbff". + high_color (str, optional): Color for maximum value. Defaults to "#08306b". + border_color (str, optional): Border color for cells. Defaults to "none". + border_width (float, optional): Border thickness. Defaults to 0.5. + """ + if hasattr(values, "to_dict"): values = values.to_dict() + vals = [float(v) for v in values.values() if isinstance(v, (int, float))] + if not vals: return + vmin, vmax = min(vals), max(vals) + 1e-12 + angle_span = float(self.style.degrees) / len(self.t.get_leaves()) + for l in self.t.get_leaves(): + val = values.get(l.name) + fill = lerp_color(low_color, high_color, (float(val) - vmin) / (vmax - vmin)) if val is not None else "white" + start_ang, end_ang = self._rot_ang(l.angle - angle_span/2), self._rot_ang(l.angle + angle_span/2) + r_i, r_o = l.rad + offset, l.rad + offset + width + p = draw.Path(fill=fill, stroke=border_color, stroke_width=border_width) + p1x, p1y = polar_to_cartesian(start_ang, r_i); p2x, p2y = polar_to_cartesian(end_ang, r_i) + p3x, p3y = polar_to_cartesian(end_ang, r_o); p4x, p4y = polar_to_cartesian(start_ang, r_o) + sweep = 1 if end_ang > start_ang else 0 + p.M(p1x, p1y).A(r_i, r_i, 0, 0, sweep, p2x, p2y).L(p3x, p3y).A(r_o, r_o, 0, 0, 0 if sweep == 1 else 1, p4x, p4y).Z() + self.drawing.append(p) + + def add_clade_labels(self, labels, offset=40.0, stroke_width=1.5, color="black", font_size=None): + """ + Adds circular arcs outside the tree to group and label specific clades. + + Args: + labels (dict): Mapping of {node_name_or_object: label_text}. + offset (float, optional): Radial offset from the outermost leaf. Defaults to 40.0. + stroke_width (float, optional): Thickness of the bracket arc. Defaults to 1.5. + color (str, optional): Color of the arc and text. Defaults to "black". + font_size (int, optional): Font size. Defaults to style default. + """ + self._pre_flight_check() + fs = font_size or self.style.font_size + max_rad = max(l.rad for l in self.t.get_leaves()) + arc_rad = max_rad + offset + for target, text in labels.items(): try: - vals.append(float(v)) - except Exception: - continue - - if not vals: - return - - if vmin is None: - vmin = min(vals) - if vmax is None: - vmax = max(vals) - if float(vmax) == float(vmin): - vmax = float(vmin) + 1e-12 - - vmin = float(vmin) - vmax = float(vmax) - - # ---- ring geometry (same as add_ring) ---- - r_in = float(self.style.radius) + float(padding) - r_out = r_in + float(width) + node = self.t.search_nodes(name=target)[0] if isinstance(target, str) else target + leaves = node.get_leaves() + angles = [l.angle for l in leaves] + start_ang, end_ang = self._rot_ang(min(angles)), self._rot_ang(max(angles)) + mid_ang = (start_ang + end_ang) / 2.0 + p1x, p1y = polar_to_cartesian(start_ang, arc_rad); p2x, p2y = polar_to_cartesian(end_ang, arc_rad) + sweep = 1 if end_ang > start_ang else 0 + p = draw.Path(stroke=color, stroke_width=stroke_width, fill="none") + p.M(p1x, p1y).A(arc_rad, arc_rad, 0, 0, sweep, p2x, p2y) + self.drawing.append(p) + tx, ty = polar_to_cartesian(mid_ang, arc_rad + 8) + text_rot, anchor = mid_ang, "start" + if 90 < (mid_ang % 360) < 270: + text_rot += 180; anchor = "end" + self.drawing.append(draw.Text(text, fs, tx, ty, fill=color, font_family=self.style.font_family, + transform=f"rotate({text_rot}, {tx}, {ty})", text_anchor=anchor, dominant_baseline="middle")) + except: continue + + def plot_continuous_variable(self, node_to_rgb, stroke_width=None): + """ + Maps a continuous trait to branch colors using a gradient interpolation. - for l in self.t.get_leaves(): - name = str(l.name) - if name not in values: - fill = missing_color - else: - try: - v = float(values[name]) - tnorm = (v - vmin) / (vmax - vmin) - fill = _lerp_color(tnorm) - except Exception: - fill = missing_color + Args: + node_to_rgb (dict): Mapping of {node: (r, g, b)} tuples (values 0-255). + stroke_width (float, optional): Thickness of the branches. Defaults to style default. + """ + def _to_hex(rgb): return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2])) + sw = stroke_width or self.style.branch_stroke_width + for node in self.t.traverse(): + if node.is_root(): continue + c_c = node_to_rgb.get(node) or node_to_rgb.get(node.name) + c_p = node_to_rgb.get(node.up) or node_to_rgb.get(node.up.name) + if c_c and c_p: self.gradient_branch(node, stroke_width=sw, colors=(_to_hex(c_p), _to_hex(c_c))) + else: self.highlight_branch(node, stroke_width=sw) + for n in self.t.traverse("postorder"): + if not n.is_leaf() and (col := (node_to_rgb.get(n) or node_to_rgb.get(n.name))): + x, y = self._node_xy(n); self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=_to_hex(col))) - a1, a2 = l.angle - self.angle_step / 2.0, l.angle + self.angle_step / 2.0 - p = draw.Path(fill=fill, stroke="none") + def plot_categorical_trait(self, data, value_col, node_col="Node", palette=None, stroke_width=None, default_color="black"): + """ + Maps categorical traits to branch and node colors. - s_i_x, s_i_y = radial_converter(a1, r_in, self.style.rotation) - e_i_x, e_i_y = radial_converter(a2, r_in, self.style.rotation) - s_o_x, s_o_y = radial_converter(a1, r_out, self.style.rotation) - e_o_x, e_o_y = radial_converter(a2, r_out, self.style.rotation) + If a parent and child have different categories, the branch color transitions (gradient). - p.M(s_o_x, s_o_y)\ - .A(r_out, r_out, 0, 0, 1, e_o_x, e_o_y)\ - .L(e_i_x, e_i_y)\ - .A(r_in, r_in, 0, 0, 0, s_i_x, s_i_y)\ - .Z() + Args: + data (DataFrame or dict): Data containing trait values per node. + value_col (str): Column name for the trait values (if DataFrame). + node_col (str, optional): Column name for node names. Defaults to "Node". + palette (dict, optional): Color map {value: color}. Defaults to auto-generated. + stroke_width (float, optional): Branch thickness. Defaults to style default. + default_color (str, optional): Fallback color. Defaults to "black". + """ + if hasattr(data, "to_dict"): mapping = dict(zip(data[node_col].astype(str), data[value_col])) + else: mapping = data + if palette is None: + unique_vals = sorted(list(set(mapping.values()))) + defaults = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"] + palette = {val: defaults[i % len(defaults)] for i, val in enumerate(unique_vals)} + sw = stroke_width or self.style.branch_stroke_width + def get_color(n): return palette.get(mapping.get(n.name), default_color) + for node in self.t.traverse(): + if node.is_root(): continue + c_n, c_p = get_color(node), get_color(node.up) + if c_n != c_p: self.gradient_branch(node, colors=(c_p, c_n), stroke_width=sw) + else: self.highlight_branch(node, color=c_n, stroke_width=sw) + for n in self.t.traverse("postorder"): + if not n.is_leaf(): + x, y = self._node_xy(n); self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=get_color(n))) - self.d.append(p) + def add_transfer_legend(self, colors=("purple", "orange"), labels=("Departure", "Arrival"), x=None, y=None, font_size=14): + """ + Adds a legend specifically for Horizontal Gene Transfers. + """ + palette = {labels[0]: colors[0], labels[1]: colors[1]} + self.add_categorical_legend(palette, title="Transfer Event", x=x, y=y, font_size=font_size) + def add_leaf_images(self, image_dir, extension=".png", width=40, height=40, offset=10): + """ + Places images next to leaf tips in the radial layout. + """ + self._pre_flight_check() + for leaf in self.t.get_leaves(): + lx, ly = self._leaf_xy(leaf, offset=offset) + path = os.path.join(image_dir, f"{leaf.name}{extension}") + if os.path.exists(path): + with open(path, "rb") as f: + uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}" + self.drawing.append(draw.Image(lx - width/2, ly - height/2, width, height, path=uri)) + + def add_ancestral_images(self, image_dir, extension=".png", width=40, height=40, omit=None): + """ + Places images at internal node positions in the radial layout. + """ + self._pre_flight_check() + for node in self.t.traverse(): + if not node.is_leaf(): + if omit and node.name in omit: continue + + nx, ny = self._node_xy(node) + path = os.path.join(image_dir, f"{node.name}{extension}") + if os.path.exists(path): + with open(path, "rb") as f: + uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}" + self.drawing.append(draw.Image(nx - width/2, ny - height/2, width, height, path=uri)) + + def add_scale_bar(self, length, label=None, x=None, y=None, stroke="black", stroke_width=2.0): + """ + Adds a physical scale bar representing a distance value. + """ + self._pre_flight_check() + px = float(length) * self.sf + label = label or str(length) + x = x if x is not None else -self.style.width/2 + 20 + y = y if y is not None else self.style.height/2 - 20 + self.drawing.append(draw.Line(x, y, x + px, y, stroke=stroke, stroke_width=stroke_width)) + self.drawing.append(draw.Text(label, self.style.font_size, x + px/2, y - 8, center=True)) - def highlight_clade(self, node, color="lightblue", opacity=0.3, padding=10): - """Draws a shaded sector (pie slice) behind a specific clade.""" - if node.is_leaf(): return - - leaves = node.get_leaves() - angles = [l.angle for l in leaves] - - # Define Angular bounds - # We extend by half a step to cover the "gap" between this clade and neighbors - min_ang = min(angles) - (self.angle_step / 2.0) - max_ang = max(angles) + (self.angle_step / 2.0) - - # Define Radial bounds - # Inner: from the node itself (or root 0 if you want full pie) - # Outer: furthest leaf + padding - r_inner = float(node.rad) - r_outer = max(float(l.rad) for l in leaves) + padding - - # Convert to SVG Arc Path - # Move to Inner Start -> Line to Outer Start -> Arc to Outer End -> Line to Inner End -> Arc to Inner Start - sx_in, sy_in = radial_converter(min_ang, r_inner, self.style.rotation) - ex_in, ey_in = radial_converter(max_ang, r_inner, self.style.rotation) - - sx_out, sy_out = radial_converter(min_ang, r_outer, self.style.rotation) - ex_out, ey_out = radial_converter(max_ang, r_outer, self.style.rotation) - - # Large arc flag (0 or 1) depending on if angle > 180 - angle_diff = max_ang - min_ang - large_arc = 1 if angle_diff > 180 else 0 - - path = draw.Path(fill=color, fill_opacity=opacity, stroke="none") - path.M(sx_in, sy_in)\ - .L(sx_out, sy_out)\ - .A(r_outer, r_outer, 0, large_arc, 1, ex_out, ey_out)\ - .L(ex_in, ey_in)\ - .A(r_inner, r_inner, 0, large_arc, 0, sx_in, sy_in)\ - .Z() - - self.d.append(path) - - - def add_time_axis( - self, - ticks: list[float], - label: str = "Time", - tick_size: float = 0, # Usually 0 for radial rings - label_angle: float = 90, # Angle where text labels appear - stroke: str = "#ccc", - stroke_width: float = 1.0, - stroke_dasharray: str = "4,2", # Dashed lines look better for grids - font_size: int = 10, - ): - """Add concentric rings representing time/distance.""" - - for t in ticks: - r = t * self.sf - - # 1. Draw the ring (circle) - # If tree is full 360, use Circle. If partial, use Arc (omitted for brevity, assuming 360 usually) - self.d.append(draw.Circle(0, 0, r, fill="none", stroke=stroke, - stroke_width=stroke_width, stroke_dasharray=stroke_dasharray)) - - # 2. Draw the label - # We place the label at 'label_angle' - lx, ly = radial_converter(label_angle, r, 0) # rotation=0 so angle is absolute - - # Simple background rect for readability? (Optional) - self.d.append(draw.Text(str(t), font_size, lx, ly, - fill="black", stroke="white", stroke_width=0.5, paint_order="stroke", - text_anchor="middle", dominant_baseline="middle", font_family="Arial")) - - # Axis Title? (Optional, maybe placed at the outermost tick) - if ticks: - max_r = max(ticks) * self.sf - lx, ly = radial_converter(label_angle, max_r + 20, 0) - self.d.append(draw.Text(label, font_size + 2, lx, ly, - font_weight="bold", text_anchor="middle", dominant_baseline="middle")) \ No newline at end of file diff --git a/src/phylustrator/drawing/vertical.py b/src/phylustrator/drawing/vertical.py index 8a36830..082e2ef 100644 --- a/src/phylustrator/drawing/vertical.py +++ b/src/phylustrator/drawing/vertical.py @@ -1,937 +1,661 @@ import drawsvg as draw from .base import BaseDrawer +from ..utils import to_hex, lerp_color, generate_id +import math import random +import os +import base64 class VerticalTreeDrawer(BaseDrawer): + """ + Drawer class for rendering phylogenetic trees in a rectangular vertical layout. + + Nodes are positioned using Cartesian coordinates where Y represents vertical + positioning (tips) and X represents distance/time. + """ def __init__(self, tree, style=None): + """ + Initializes the VerticalTreeDrawer and calculates the rectangular layout. + + Args: + tree (ete3.TreeNode): The tree object to be visualized. + style (TreeStyle, optional): Custom style configuration. + """ super().__init__(tree, style) self._calculate_layout() - def _node_xy(self, node): + def _node_xy(self, node) -> tuple[float, float]: + """Calculates Cartesian (x, y) coordinates for a node based on current scaling.""" if not hasattr(node, "coordinates"): self._calculate_layout() - x, y = node.coordinates # coordinates is (x,y) + x, y = node.coordinates return float(x), float(y) - def _leaf_xy(self, leaf, offset: float = 0.0): + def _leaf_xy(self, leaf, offset: float = 0.0) -> tuple[float, float]: + """Calculates coordinates for a leaf tip, with an optional horizontal offset.""" x, y = self._node_xy(leaf) return (x + float(offset), y) - def _edge_point(self, child, where: float): - """ - Place markers along the *horizontal* part of the rectangular edge: - (x_parent, y_child) -> (x_child, y_child) - - This keeps y constant, so where=0 is at the elbow (NOT on the shared vertical), - and where=1 is at the child tip. - """ + def _edge_point(self, child, where: float) -> tuple[float, float, float]: + """Finds a point along the horizontal branch segment leading to a child.""" parent = child.up if parent is None: - x, y = self._node_xy(child) - return x, y, 0.0 - - # Ensure layout exists - if not hasattr(parent, "coordinates") or not hasattr(child, "coordinates"): - self._calculate_layout() - - x_parent, _y_parent = self._node_xy(parent) - x_child, y_child = self._node_xy(child) - + return (*self._node_xy(child), 0.0) + x_p, _ = self._node_xy(parent) + x_c, y_c = self._node_xy(child) t = max(0.0, min(1.0, float(where))) - x = x_parent + (x_child - x_parent) * t - y = y_child + return x_p + (x_c - x_p) * t, y_c, (0.0 if (x_c - x_p) >= 0 else 180.0) - # Horizontal direction only (no weird rotations) - edge_ang = 0.0 if (x_child - x_parent) >= 0 else 180.0 - return x, y, edge_ang - - def _calculate_layout(self, max_width=None): + def _calculate_layout(self, max_width: float | None = None): """ - Calculates tree coordinates. - :param max_width: If provided, scales the tree to this width instead of the style width. + Computes Cartesian coordinates for all nodes in rectangular space. + + Args: + max_width (float, optional): Force a specific drawing width. """ - # 1. Calculate Distances - max_dist = 0 + # 1. Horizontal Scaling (Distances) + max_dist = 0.0 for n in self.t.traverse("preorder"): n.dist_to_root = n.up.dist_to_root + n.dist if not n.is_root() else getattr(n, "dist", 0.0) - if n.dist_to_root > max_dist: max_dist = n.dist_to_root + max_dist = max(max_dist, n.dist_to_root) self.total_tree_depth = max_dist - - # 2. Handle Scaling and Margins - horizontal_padding = 100 - target_width = max_width if max_width is not None else self.style.width - self.sf = (target_width - (horizontal_padding * 2)) / max_dist if max_dist > 0 else 1 - - # Center the root X based on padding - self.root_x = -self.style.width / 2 + horizontal_padding + pad = self.style.margin + target_w = max_width if max_width is not None else self.style.width + self.sf = (target_w - (pad * 2)) / max_dist if max_dist > 0 else 1.0 + self.root_x = -self.style.width / 2 + pad - # 3. Calculate Vertical Positions + # 2. Vertical Scaling (Leaves) leaves = self.t.get_leaves() - y_padding = 100 - y_step = (self.style.height - (y_padding * 2)) / max(len(leaves)-1, 1) - start_y = -self.style.height / 2 + y_padding + y_step = (self.style.height - (pad * 2)) / max(len(leaves)-1, 1) + start_y = -self.style.height / 2 + pad for i, leaf in enumerate(leaves): leaf.y_coord = start_y + (i * y_step) leaf.coordinates = (self.root_x + (leaf.dist_to_root * self.sf), leaf.y_coord) + # 3. Internal Centering for n in self.t.traverse("postorder"): if not n.is_leaf(): n.y_coord = sum(c.y_coord for c in n.children) / len(n.children) n.coordinates = (self.root_x + (n.dist_to_root * self.sf), n.y_coord) + self._layout_calculated = True def draw(self, branch2color=None, right_margin=0): """ - Draws the tree while reserving space on the right for images. + Draws the tree skeleton using rectangular elbows. + + Args: + branch2color (dict, optional): Mapping of ete3.TreeNode to color strings. + right_margin (float, optional): Extra space (in pixels) to reserve on the right edge + (e.g., for heatmaps or labels). Defaults to 0. """ - # 1. Re-run layout calculation with the reserved right margin - self._calculate_layout(max_width=self.style.width - right_margin) + self._pre_flight_check() + if right_margin > 0: + self._calculate_layout(max_width=self.style.width - right_margin) for n in self.t.traverse("postorder"): x, y = n.coordinates - - # Resolve Color - color = self.style.branch_color - if branch2color and n in branch2color: - color = branch2color[n] + color = branch2color.get(n, self.style.branch_color) if branch2color else self.style.branch_color - # 2. Horizontal branch if not n.is_root(): px, py = n.up.coordinates - self.d.append(draw.Line(px, y, x, y, stroke=color, - stroke_width=self.style.branch_size, stroke_linecap="round")) + # Horizontal segment + self.drawing.append(draw.Line(px, y, x, y, stroke=color, + stroke_width=self.style.branch_stroke_width, stroke_linecap="round")) else: - # Root "handle" - self.d.append(draw.Line(x - 20, y, x, y, stroke=color, stroke_width=self.style.branch_size)) + # Root stub + self.drawing.append(draw.Line(x - self.style.root_stub_length, y, x, y, + stroke=color, stroke_width=self.style.branch_stroke_width)) - # 3. Vertical connector if not n.is_leaf(): + # Vertical connector (elbow) y_min = min(c.y_coord for c in n.children) y_max = max(c.y_coord for c in n.children) - self.d.append(draw.Line(x, y_min, x, y_max, stroke=color, - stroke_width=self.style.branch_size, stroke_linecap="round")) - self.d.append(draw.Circle(x, y, self.style.node_size, fill=color)) - else: - self.d.append(draw.Circle(x, y, self.style.leaf_size, fill=self.style.leaf_color)) + self.drawing.append(draw.Line(x, y_min, x, y_max, stroke=color, + stroke_width=self.style.branch_stroke_width, stroke_linecap="round")) + if self.style.node_r > 0: + self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=color)) + elif self.style.leaf_r > 0: + self.drawing.append(draw.Circle(x, y, self.style.leaf_r, fill=self.style.leaf_color)) def highlight_clade(self, node, color="lightblue", opacity=0.3, padding=10): """ - Draws a shaded rectangle behind a specific clade. + Draws a shaded rectangular background behind a specific clade. + + Args: + node (ete3.TreeNode): The root of the clade. + color (str, optional): Fill color. Defaults to "lightblue". + opacity (float, optional): Fill opacity. Defaults to 0.3. + padding (float, optional): Padding around the clade box. Defaults to 10. """ + self._pre_flight_check() leaves = node.get_leaves() - - # Calculate bounds based on node and leaf coordinates x_start, _ = node.coordinates - - # X range: From node to the furthest tip in this clade x_max = max(l.coordinates[0] for l in leaves) - - # Y range: Spanning all leaves in the clade y_min = min(l.y_coord for l in leaves) y_max = max(l.y_coord for l in leaves) - - # Geometry: - # width is (x_max - x_start) + a small padding for the tips - # height is (y_max - y_min) + padding on top/bottom - rect_x = x_start - (padding / 2) - rect_y = y_min - padding - rect_w = (x_max - x_start) + padding # Reduced the "60" extension to just padding - rect_h = (y_max - y_min) + (2 * padding) - - self.d.append(draw.Rectangle( - rect_x, rect_y, rect_w, rect_h, - fill=color, - fill_opacity=opacity, - stroke="none" + self.drawing.append(draw.Rectangle( + x_start - (padding / 2), y_min - padding, + (x_max - x_start) + padding, (y_max - y_min) + (2 * padding), + fill=color, fill_opacity=opacity, stroke="none" )) - def highlight_branch(self, node, color="red", size=None): + def highlight_branch(self, node, color="red", stroke_width=None): + """ + Overlays a thicker or colored line on a specific branch. + + Args: + node (ete3.TreeNode): The target node. + color (str, optional): CSS color string. Defaults to "red". + stroke_width (float, optional): Thickness. Defaults to 2x style default. + """ if node.is_root(): return - s_width = size if size else self.style.branch_size * 2 + sw = stroke_width if stroke_width is not None else self.style.branch_stroke_width * 2 x, y = node.coordinates - px, py = node.up.coordinates - self.d.append(draw.Line(px, y, x, y, stroke=color, stroke_width=s_width, stroke_linecap="round")) - - def gradient_branch(self, node, colors=("red", "blue"), size=None): - if node.is_root(): - return - - s_width = size if size else self.style.branch_size - x, y = node.coordinates - px, _ = node.up.coordinates - - # FIX: Use the node's memory address (id(node)) to ensure a - # truly unique ID for this specific branch segment. - grad_id = f"grad_{id(node)}" - - grad = draw.LinearGradient(px, y, x, y, id=grad_id) - grad.add_stop(0, colors[0]) - grad.add_stop(1, colors[1]) - - self.d.append(grad) - self.d.append(draw.Line(px, y, x, y, stroke=grad, stroke_width=s_width)) + px, _ = node.up.coordinates + self.drawing.append(draw.Line(px, y, x, y, stroke=color, stroke_width=sw, stroke_linecap="round")) + def gradient_branch(self, node, colors=("red", "blue"), stroke_width=None): + """ + Applies a linear color gradient along a branch segment. - def add_leaf_names( - self, - font_size=None, - color=None, - font_family=None, - rotation=0, - padding=10 - ): + Args: + node (ete3.TreeNode): The target node. + colors (tuple, optional): (start_color, end_color). Defaults to ("red", "blue"). + stroke_width (float, optional): Thickness. Defaults to style default. """ - Adds leaf labels where the center of the text is aligned with the leaf tip. + if node.is_root(): return + sw = stroke_width if stroke_width is not None else self.style.branch_stroke_width + x, y, (px, _) = *node.coordinates, node.up.coordinates + gid = generate_id("grad") + grad = draw.LinearGradient(px, y, x, y, id=gid) + grad.add_stop(0, colors[0]) + grad.add_stop(1, colors[1]) + self.drawing.append(grad) + self.drawing.append(draw.Line(px, y, x, y, stroke=grad, stroke_width=sw)) + + def add_leaf_names(self, font_size=None, color="black", rotation=0, padding=10): """ - fs = font_size if font_size is not None else self.style.font_size - ff = font_family if font_family is not None else self.style.font_family - text_color = color if color is not None else "black" + Adds text labels to the leaf tips. + Args: + font_size (int, optional): Font size. Defaults to style default. + color (str, optional): Text color. Defaults to "black". + rotation (float, optional): Rotation angle. Defaults to 0. + padding (float, optional): Horizontal padding. Defaults to 10. + """ + fs = font_size or self.style.font_size for l in self.t.get_leaves(): x, y = l.coordinates - - # The anchor point is the leaf tip + padding - tx = x + padding - ty = y - - # Rotate around the center (tx, ty) + tx, ty = x + padding, y transform = f"rotate({rotation}, {tx}, {ty})" if rotation != 0 else "" + self.drawing.append(draw.Text(l.name, fs, tx, ty, fill=color, font_family=self.style.font_family, + transform=transform, text_anchor="start", dominant_baseline="middle")) - self.d.append(draw.Text( - l.name, - fs, - tx, - ty, - fill=text_color, - font_family=ff, - transform=transform, - text_anchor="middle", # Centers horizontally - dominant_baseline="middle" # Centers vertically - )) - - def add_node_names( - self, - font_size=None, - color="gray", - font_family=None, - x_offset=-15, - y_offset=-10, - rotation=0 - ): + def add_node_names(self, font_size=None, color="gray", x_offset=-15, y_offset=-10, rotation=0): """ - Adds labels to internal nodes centered on the offset coordinate. + Adds text labels to internal nodes. + + Args: + font_size (int, optional): Font size. Defaults to style default * 0.8. + color (str, optional): Text color. Defaults to "gray". + x_offset (float, optional): Horizontal offset. Defaults to -15. + y_offset (float, optional): Vertical offset. Defaults to -10. + rotation (float, optional): Rotation angle. Defaults to 0. """ - fs = font_size if font_size is not None else self.style.font_size * 0.8 - ff = font_family if font_family is not None else self.style.font_family - + fs = font_size or self.style.font_size * 0.8 for n in self.t.traverse(): - if not n.is_leaf(): - if not n.name: - continue - + if not n.is_leaf() and n.name: x, y = n.coordinates - - # Apply offsets to the center point - tx = x + x_offset - ty = y + y_offset - + tx, ty = x + x_offset, y + y_offset transform = f"rotate({rotation}, {tx}, {ty})" if rotation != 0 else "" - - self.d.append(draw.Text( - n.name, - fs, - tx, - ty, - fill=color, - font_family=ff, - transform=transform, - text_anchor="middle", # Centers horizontally - dominant_baseline="middle" # Centers vertically + self.drawing.append(draw.Text( + n.name, fs, tx, ty, fill=color, font_family=self.style.font_family, + transform=transform, text_anchor="middle", dominant_baseline="middle" )) - - def add_time_axis( - self, - ticks: list[float], - label: str = "Time", - tick_size: float = 6.0, - padding: float = 20.0, - y_offset: float = 0.0, - root_stub: float = 20.0, - stroke: str = "black", - stroke_width: float = 2.0, - font_size: int | None = None, - font_family: str | None = None, - # --- grid options --- - grid: bool = False, - grid_stroke: str = "#cccccc", - grid_stroke_width: float = 1.0, - grid_opacity: float = 0.5, - # --- NEW: custom tick labels --- - tick_labels: dict[float, str] | None = None, - ) -> None: - """ - Draw a horizontal time axis. Optionally draw vertical grid lines at each tick - across the tree. - - tick_labels lets you display labels that differ from the numeric tick positions, - e.g. show "-2" at tick position 2.0 (backwards time). - """ - if not ticks: + def add_leaf_shapes(self, leaves, shape="circle", fill="blue", r=5.0, stroke=None, stroke_width=1.0, offset=0.0, rotation=0.0, opacity=1.0): + """ + Adds geometric markers next to leaf tips. + + Args: + leaves (list): List of node names or objects. + shape (str, optional): Shape type. Defaults to "circle". + fill (str, optional): Fill color. Defaults to "blue". + r (float, optional): Radius/size. Defaults to 5.0. + stroke (str, optional): Stroke color. Defaults to None. + stroke_width (float, optional): Stroke width. Defaults to 1.0. + offset (float, optional): Horizontal offset. Defaults to 0.0. + rotation (float, optional): Rotation angle. Defaults to 0.0. + opacity (float, optional): Opacity. Defaults to 1.0. + """ + self._pre_flight_check() + for item in leaves: + try: + node = self.t & item if isinstance(item, str) else item + x, y = self._leaf_xy(node, offset=float(offset)) + self._draw_shape_at(x, y, shape, fill, r, stroke, stroke_width, rotation, opacity) + except: continue + + def add_node_shapes(self, nodes, shape="circle", fill="red", r=5.0, stroke=None, stroke_width=1.0, rotation=0, dx=0, dy=0): + """ + Adds geometric markers at node positions. + + Args: + nodes (list): List of node names/objects or style dicts. + shape (str, optional): Default shape. Defaults to "circle". + fill (str, optional): Default fill color. Defaults to "red". + r (float, optional): Default radius. Defaults to 5.0. + stroke (str, optional): Default stroke color. Defaults to None. + stroke_width (float, optional): Default stroke width. Defaults to 1.0. + rotation (float, optional): Default rotation. Defaults to 0. + dx (float, optional): X offset. Defaults to 0. + dy (float, optional): Y offset. Defaults to 0. + """ + self._pre_flight_check() + if isinstance(nodes, list) and nodes and isinstance(nodes[0], dict): + for s in nodes: + self.add_node_shapes([s.get("node")], s.get("shape", shape), s.get("fill", fill), s.get("r", r), + s.get("stroke", stroke), s.get("stroke_width", stroke_width), s.get("rotation", rotation)) return + for item in nodes: + try: + node = self.t.search_nodes(name=item)[0] if isinstance(item, str) else item + x, y = self._node_xy(node) + self._draw_shape_at(x + dx, y + dy, shape, fill, r, stroke, stroke_width, rotation) + except: continue + + def add_branch_shapes(self, specs, default_where=0.5, offset=0.0, orient=None): + """ + Adds geometric markers along branches (useful for modeling events). - # Ensure layout exists - any_node = next(self.t.traverse("preorder")) - if not hasattr(any_node, "coordinates") or not hasattr(any_node, "y_coord"): - self._calculate_layout() - - sf = float(self.sf) - - # Font defaults - fs = font_size if font_size is not None else self.style.font_size - ff = font_family if font_family is not None else self.style.font_family - - # Tree vertical extent + Args: + specs (list): List of dicts describing shapes and positions. + default_where (float, optional): Position along branch (0-1). Defaults to 0.5. + offset (float, optional): Perpendicular offset. Defaults to 0.0. + orient (str, optional): "along", "perp", or None. + """ + self._pre_flight_check() + if hasattr(specs, "to_dict"): specs = specs.to_dict(orient="records") + for s in specs: + br = s.get("branch") + if not br: continue + try: + node = self.t & br if isinstance(br, str) else br + where = s.get("where", default_where) + x, y, ang = self._edge_point(node, where=where) + if offset != 0: + perp = math.radians(ang + 90) + x += offset * math.cos(perp) + y += offset * math.sin(perp) + rot = s.get("rotation", 0.0) + if orient == "along": rot = ang + elif orient == "perp": rot = ang + 90 + r_val = float(s.get("r", s.get("size", 10.0) / 2.0)) + self._draw_shape_at(x, y, s.get("shape", "circle"), s.get("fill", "blue"), r_val, + s.get("stroke"), s.get("stroke_width", 1.0), rot, s.get("opacity", 1.0)) + except: continue + + def add_time_axis(self, ticks, label="Time", tick_labels=None, tick_size=6.0, padding=20.0, y_offset=0.0, stroke_width=2.0, grid=False): + """ + Adds a linear axis at the bottom to represent time or distance. + + Args: + ticks (list[float]): Numerical positions for the ticks. + label (str): Label for the axis itself. Defaults to "Time". + tick_labels (list[str], optional): Custom strings for the tick values. + tick_size (float, optional): Length of the tick marks. Defaults to 6.0. + padding (float, optional): Distance from the tree tips. Defaults to 20.0. + y_offset (float, optional): Extra manual offset on the Y axis. Defaults to 0.0. + stroke_width (float, optional): Thickness of the axis line. Defaults to 2.0. + grid (bool, optional): Whether to draw vertical grid lines. Defaults to False. + """ + self._pre_flight_check() leaves = self.t.get_leaves() - min_y = min(float(l.y_coord) for l in leaves) - max_y = max(float(l.y_coord) for l in leaves) - - # Place axis below the tree - y_axis = max_y + float(padding) + float(y_offset) - - # Axis X range - x_left = float(self.root_x) - float(root_stub) - x_right = float(self.root_x) + (max(float(t) for t in ticks) * sf) - - # Baseline - self.d.append(draw.Line( - x_left, y_axis, - x_right, y_axis, - stroke=stroke, - stroke_width=stroke_width, - )) - - tick_labels = tick_labels or {} - - # Ticks (+ optional vertical grid lines) - for tt in ticks: - tt = float(tt) - x = float(self.root_x) + tt * sf - + min_y, max_y = min(l.y_coord for l in leaves), max(l.y_coord for l in leaves) + y_axis = max_y + padding + y_offset + x_left, x_right = self.root_x - self.style.root_stub_length, self.root_x + (max(ticks) * self.sf) + self.drawing.append(draw.Line(x_left, y_axis, x_right, y_axis, stroke="black", stroke_width=stroke_width)) + for i, tt in enumerate(ticks): + x = self.root_x + tt * self.sf if grid: - self.d.append(draw.Line( - x, min_y, - x, max_y, - stroke=grid_stroke, - stroke_width=grid_stroke_width, - stroke_opacity=grid_opacity, - )) - - self.d.append(draw.Line( - x, y_axis, - x, y_axis + float(tick_size), - stroke=stroke, - stroke_width=stroke_width, - )) - - text = tick_labels.get(tt, str(tt)) - self.d.append(draw.Text( - text, - fs, - x, - y_axis + float(tick_size) + fs, - center=True, - font_family=ff, - )) - - # Axis label - self.d.append(draw.Text( - label, - fs, - (x_left + x_right) / 2.0, - y_axis + float(tick_size) + fs * 2.2, - center=True, - font_family=ff, - )) - - - - def plot_transfers( - self, - transfers, - mode="midpoint", - curve_type="C", - filter_below=0.0, - use_gradient=True, - gradient_colors=("purple", "orange"), - color="orange", - use_thickness=True, - stroke_width=5, - arc_intensity=40, - opacity=0.6, - ): - """ - Plot horizontal gene transfers as curved lines. - - Parameters - ---------- - transfers - Either a list of dicts OR a pandas.DataFrame with at least: - 'from', 'to', 'time' (optional), 'freq' (optional) - mode - "midpoint" (default) or "time". - - midpoint: attach curves to mid-branch positions (old behavior) - - time: attach curves at the event time along each endpoint branch using - node.time_from_origin (Zombi parser provides this) - curve_type - "C" (default) or "S" - """ - # Accept either list[dict] or a DataFrame-like object (e.g. pandas.DataFrame) + self.drawing.append(draw.Line(x, min_y, x, max_y, stroke="#ccc", stroke_width=1.0, stroke_opacity=0.5)) + self.drawing.append(draw.Line(x, y_axis, x, y_axis + tick_size, stroke="black", stroke_width=stroke_width)) + + # Use custom label if provided, else use the numerical value + display_text = str(tick_labels[i]) if tick_labels is not None and i < len(tick_labels) else str(tt) + self.drawing.append(draw.Text(display_text, self.style.font_size, x, y_axis + tick_size + self.style.font_size, + text_anchor="middle", font_family=self.style.font_family)) + self.drawing.append(draw.Text(label, self.style.font_size, (x_left + x_right) / 2.0, y_axis + tick_size + self.style.font_size * 2.5, + text_anchor="middle", font_family=self.style.font_family)) + + def plot_transfers(self, transfers, mode="midpoint", curve_type="C", filter_below=0.0, use_gradient=True, + gradient_colors=("purple", "orange"), color="orange", stroke_width=5.0, arc_intensity=40.0, opacity=0.6): + """ + Plots curved arrows representing Horizontal Gene Transfer (HGT) events. + + Args: + transfers (list): List of transfer event dicts or DataFrame. + mode (str, optional): "midpoint" or "time". Defaults to "midpoint". + curve_type (str, optional): "C" for C-shape, "S" for S-shape. Defaults to "C". + filter_below (float, optional): Minimum frequency filter. Defaults to 0.0. + use_gradient (bool, optional): Use color gradient. Defaults to True. + gradient_colors (tuple, optional): Gradient colors. Defaults to ("purple", "orange"). + color (str, optional): Fallback color. Defaults to "orange". + stroke_width (float, optional): Stroke width. Defaults to 5.0. + arc_intensity (float, optional): Curve height. Defaults to 40.0. + opacity (float, optional): Opacity. Defaults to 0.6. + """ if hasattr(transfers, "to_dict") and hasattr(transfers, "columns"): transfers = transfers.to_dict(orient="records") - name2node = {n.name: n for n in self.t.traverse()} + self._pre_flight_check() - # Ensure layout exists - any_node = next(self.t.traverse("preorder")) - if not hasattr(any_node, "coordinates") or not hasattr(any_node, "y_coord"): - self._calculate_layout() - - def where_from_time(node, tt: float) -> float: - parent = node.up - if parent is None: - return 0.0 - t0 = float(getattr(parent, "time_from_origin", 0.0)) - t1 = float(getattr(node, "time_from_origin", t0)) - denom = (t1 - t0) if abs(t1 - t0) > 1e-12 else 1.0 - w = (float(tt) - t0) / denom - if w < 0.0: - return 0.0 - if w > 1.0: - return 1.0 - return w + def get_where(node, t_ev): + p = node.up + if not p: return 0.0 + t0, t1 = float(getattr(p, "time_from_origin", 0.0)), float(getattr(node, "time_from_origin", 0.0)) + if abs(t1 - t0) > 1e-12: + return max(0.0, min(1.0, (float(t_ev) - t0) / (t1 - t0))) + return 0.5 for tr in transfers: freq = float(tr.get("freq", 1.0)) - if freq < filter_below: - continue - - src = name2node.get(tr.get("from")) - dst = name2node.get(tr.get("to")) - if src is None or dst is None: - continue - - # Compute endpoints - if ( - mode == "time" - and tr.get("time") is not None - and hasattr(src, "time_from_origin") - and hasattr(dst, "time_from_origin") - ): - tt = float(tr["time"]) - w_src = where_from_time(src, tt) - w_dst = where_from_time(dst, tt) - x_start, y_start, _ = self._edge_point(src, w_src) - x_end, y_end, _ = self._edge_point(dst, w_dst) + if freq < filter_below: continue + src, dst = name2node.get(tr.get("from")), name2node.get(tr.get("to")) + if not src or not dst: continue + if mode == "time" and "time" in tr: + x_s, y_s, _ = self._edge_point(src, get_where(src, tr["time"])) + x_e, y_e, _ = self._edge_point(dst, get_where(dst, tr["time"])) else: - # midpoint fallback (old behavior) - y_start, y_end = float(src.y_coord), float(dst.y_coord) - src_px = src.up.coordinates[0] if src.up else (src.coordinates[0] - 20) - dst_px = dst.up.coordinates[0] if dst.up else (dst.coordinates[0] - 20) - x_start = (float(src_px) + float(src.coordinates[0])) / 2.0 - x_end = (float(dst_px) + float(dst.coordinates[0])) / 2.0 - - # Styling - width = (stroke_width * freq) if use_thickness else stroke_width - path = draw.Path(stroke_width=width, fill="none", stroke_opacity=opacity) - + x_s, y_s, _ = self._edge_point(src, 0.5) + x_e, y_e, _ = self._edge_point(dst, 0.5) + + path = draw.Path(stroke_width=stroke_width * freq, fill="none", stroke_opacity=opacity) if use_gradient: - grad_id = f"tr_grad_{random.randint(0, 999999)}" - grad = draw.LinearGradient(x_start, y_start, x_end, y_end, id=grad_id) + gid = generate_id("tr_grad") + grad = draw.LinearGradient(x_s, y_s, x_e, y_e, id=gid) grad.add_stop(0, gradient_colors[0]) grad.add_stop(1, gradient_colors[1]) - self.d.append(grad) + self.drawing.append(grad) path.args["stroke"] = grad else: path.args["stroke"] = color - - # Geometry - path.M(x_start, y_start) - - if curve_type.upper() == "S": - dx = x_end - x_start - sgn = 1 if dx >= 0 else -1 - arc = abs(arc_intensity) - cp1x = x_start + (sgn * arc) - cp2x = x_end - (sgn * arc) - path.C(cp1x, y_start, cp2x, y_end, x_end, y_end) - else: - # C-curve - path.C( - x_start - arc_intensity, y_start, - x_end - arc_intensity, y_end, - x_end, y_end - ) - - self.d.append(path) - - - def add_transfer_legend( - self, - title="Transfer Frequency", - colors=("purple", "orange"), - low=0.1, - high=1.0, - source_label="Source", - arrival_label="Arrival", - show_frequency=False, - show_direction=True, - margin=20, - ): - """Add a transfer legend. - - By default this shows *direction* (two solid colors): a "Source" swatch and an - "Arrival" swatch. If you also want a frequency scale, set - ``show_frequency=True``. - - Parameters - ---------- - colors: - Tuple (source_color, arrival_color). These match the gradient endpoints - used by ``plot_transfers(..., gradient_colors=...)``. - show_frequency: - Draws a gradient bar + numeric low/high labels. - show_direction: - Draws two solid color swatches labelled source/arrival. - """ - if not (show_frequency or show_direction): - return - - font_size = 11 - num_font_size = 9 - sw = 14 - gap = 6 - pad_x = 10 - top_pad = 10 - bottom_pad = 10 - row_h = 18 - bar_h = 12 - bar_w = 110 - - # Estimate legend width from label lengths (drawsvg doesn't expose text metrics). - max_label_len = 0 - if show_direction: - max_label_len = max(len(str(source_label)), len(str(arrival_label))) - if show_frequency: - max_label_len = max(max_label_len, len(str(title))) - est_text_w = max_label_len * font_size * 0.60 - w = int(pad_x + sw + gap + est_text_w + pad_x) - if show_frequency: - w = max(w, pad_x + bar_w + pad_x) - - # Height: SVG y-axis increases downward. - content_h = top_pad - if show_frequency: - # title + bar + low/high labels + spacing - content_h += (font_size + 4) + bar_h + (num_font_size + 10) + 6 - if show_direction: - content_h += (2 * row_h) - content_h += bottom_pad - box_h = content_h - - x = -self.style.width / 2 + 30 - y = self.style.height / 2 - margin - box_h - - self.d.append(draw.Rectangle(x, y, w, box_h, fill="white", stroke="black", stroke_width=1, opacity=0.9)) - - cursor_y = y + top_pad + 2 - - if show_frequency: - self.d.append(draw.Text(title, font_size, x + 10, cursor_y, font_family="sans-serif", font_weight="bold")) - cursor_y += 10 - - grad_id = f"legend_transfer_grad_{random.randint(0, 999999)}" - grad = draw.LinearGradient(x + 10, cursor_y + bar_h / 2, x + 10 + bar_w, cursor_y + bar_h / 2, id=grad_id) - grad.add_stop(0, colors[0]) - grad.add_stop(1, colors[1]) - self.d.append(grad) - self.d.append(draw.Rectangle(x + 10, cursor_y, bar_w, bar_h, fill=grad)) - - self.d.append(draw.Text(f"{low}", num_font_size, x + 10, cursor_y + bar_h + 12, font_family="sans-serif")) - self.d.append(draw.Text(f"{high}", num_font_size, x + 10 + bar_w - 15, cursor_y + bar_h + 12, font_family="sans-serif")) - cursor_y += bar_h + 24 - - if show_direction: - sw = 14 - self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[0])) - self.d.append(draw.Text(source_label, font_size, x + 30, cursor_y + 11, font_family="sans-serif")) - cursor_y += row_h - - self.d.append(draw.Rectangle(x + 10, cursor_y, sw, sw, fill=colors[1])) - self.d.append(draw.Text(arrival_label, 11, x + 30, cursor_y + 11, font_family="sans-serif")) - - - def add_heatmap( - self, - values, - width: float = 20.0, - offset: float = 10.0, - vmin: float | None = None, - vmax: float | None = None, - low_color: str = "#f7fbff", - high_color: str = "#08306b", - missing_color: str = "white", - border_color: str = "none", - border_width: float = 0.5 - ): - """Add a vertical heatmap strip next to the tree tips.""" - - # 1. Normalize values to dict - if hasattr(values, "to_dict") and not isinstance(values, dict): - values = values.to_dict() - if hasattr(values, "columns") and hasattr(values, "to_dict") and not isinstance(values, dict): - cols = list(values.columns) - if len(cols) >= 2: - values = dict(zip(values[cols[0]].astype(str), values[cols[1]].astype(float))) - else: - values = {} - - # 2. Helper: Parse Color (Hex or Name) to RGB Tuple - def _to_rgb(color_str): - color_str = str(color_str).strip() - - # Handle Hex (e.g., #FF0000 or #F00) - if color_str.startswith("#"): - h = color_str.lstrip("#") - if len(h) == 3: # Expand shorthand #fff -> #ffffff - h = "".join([c*2 for c in h]) - return tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) - - # Handle Basic Names - # (Simple lookup to avoid dependencies like matplotlib) - common_names = { - "white": (255, 255, 255), "black": (0, 0, 0), - "red": (255, 0, 0), "green": (0, 128, 0), - "blue": (0, 0, 255), "orange": (255, 165, 0), - "purple": (128, 0, 128), "yellow": (255, 255, 0), - "gray": (128, 128, 128), "grey": (128, 128, 128), - "cyan": (0, 255, 255), "magenta": (255, 0, 255), - "lime": (0, 255, 0), "darkgreen": (0, 100, 0), - "navy": (0, 0, 128), "teal": (0, 128, 128) - } - if color_str.lower() in common_names: - return common_names[color_str.lower()] - - # Fallback if unknown name - raise ValueError(f"Heatmap interpolation needs Hex codes (e.g. '#ff0000'). Unknown name: '{color_str}'") - - def _rgb_to_hex(rgb): - return "#{:02x}{:02x}{:02x}".format(*rgb) - - # 3. Prepare data - vals = [float(v) for v in values.values() if isinstance(v, (int, float))] - if not vals: return - - vmin = vmin if vmin is not None else min(vals) - vmax = vmax if vmax is not None else max(vals) - if vmax == vmin: vmax = vmin + 1e-12 - # Convert start/end colors to RGB for math - c0 = _to_rgb(low_color) - c1 = _to_rgb(high_color) - - # 4. Draw Rectangles - # Calculate X position - max_x = max(l.coordinates[0] for l in self.t.get_leaves()) - start_x = max_x + offset - - leaves = self.t.get_leaves() - # Calculate height of each block (distance between leaves) - if len(leaves) > 1: - y_coords = sorted([l.y_coord for l in leaves]) - # Estimate step size from the smallest difference (handles irregular trees better) - diffs = [y_coords[i+1] - y_coords[i] for i in range(len(y_coords)-1)] - y_step = min(diffs) if diffs else 20 - # If step is tiny (overlap), use average - if y_step < 1: - y_step = abs(leaves[1].y_coord - leaves[0].y_coord) + path.M(x_s, y_s) + if curve_type.upper() == "S": + sgn = 1 if (x_e - x_s) >= 0 else -1 + path.C(x_s + (sgn * arc_intensity), y_s, x_e - (sgn * arc_intensity), y_e, x_e, y_e) else: - y_step = 20 - - for l in leaves: - name = str(l.name) - fill = missing_color - - if name in values: - try: - val = float(values[name]) - # Interpolate - t = (val - vmin) / (vmax - vmin) - t = max(0.0, min(1.0, t)) - - r = int(c0[0] + (c1[0]-c0[0]) * t) - g = int(c0[1] + (c1[1]-c0[1]) * t) - b = int(c0[2] + (c1[2]-c0[2]) * t) - fill = _rgb_to_hex((r,g,b)) - except Exception: - pass # Keep missing_color - - # Center rectangle on leaf Y - self.d.append(draw.Rectangle( - start_x, - l.y_coord - y_step/2, - width, - y_step, - fill=fill, - stroke=border_color, - stroke_width=border_width - )) - + path.C(x_s - arc_intensity, y_s, x_e - arc_intensity, y_e, x_e, y_e) + self.drawing.append(path) - def add_leaf_images(self, image_dir, extension=".png", width=40, height=40, offset=10, rotation=0): + def add_heatmap(self, values, width=15.0, offset=10.0, low_color="#f7fbff", high_color="#08306b", border_color="none", border_width=0.5): """ - Adds PNG images to the right of each leaf node with a rotation option. - :param rotation: Degrees to rotate the image (clockwise). + Adds a column heatmap strip next to the leaf names. + + Args: + values (dict): Mapping of {node_name: numeric_value}. + width (float, optional): Width of each heatmap cell. Defaults to 15.0. + offset (float, optional): Horizontal offset from tree. Defaults to 10.0. + low_color (str, optional): Color for min value. Defaults to "#f7fbff". + high_color (str, optional): Color for max value. Defaults to "#08306b". + border_color (str, optional): Cell border color. Defaults to "none". + border_width (float, optional): Cell border width. Defaults to 0.5. """ - import os - import base64 - - # 1. Coordinate check: Ensure we have the latest positions - if not hasattr(self.t, 'coordinates'): - self._calculate_layout() + if hasattr(values, "to_dict"): values = values.to_dict() + vals = [float(v) for v in values.values() if isinstance(v, (int, float))] + if not vals: return + vmin, vmax = min(vals), max(vals) + 1e-12 + max_x = max(l.coordinates[0] for l in self.t.get_leaves()) + y_step = (self.style.height - (self.style.margin * 2)) / max(len(self.t.get_leaves())-1, 1) + for l in self.t.get_leaves(): + val = values.get(l.name) + fill = lerp_color(low_color, high_color, (float(val) - vmin) / (vmax - vmin)) if val is not None else "white" + self.drawing.append(draw.Rectangle(max_x + offset, l.y_coord - y_step/2, width, y_step, + fill=fill, stroke=border_color, stroke_width=border_width)) - for leaf in self.t.get_leaves(): - lx, ly = leaf.coordinates - - # 2. Coordinate Math - img_x = lx + offset - img_y = ly - (height / 2) + def add_clade_labels(self, labels, offset=40.0, stroke_width=1.5, color="black", font_size=None): + """ + Adds square brackets '[' to group lineages with text labels. + + Args: + labels (dict): Mapping {node: label_text}. + offset (float, optional): Horizontal offset. Defaults to 40.0. + stroke_width (float, optional): Bracket thickness. Defaults to 1.5. + color (str, optional): Color. Defaults to "black". + font_size (int, optional): Font size. Defaults to style default. + """ + self._pre_flight_check() + fs = font_size or self.style.font_size + max_x = max(l.coordinates[0] for l in self.t.get_leaves()) + bracket_x = max_x + offset + for target, text in labels.items(): + try: + node = self.t.search_nodes(name=target)[0] if isinstance(target, str) else target + leaves = node.get_leaves() + y_min = min(l.y_coord for l in leaves) + y_max = max(l.y_coord for l in leaves) + p = draw.Path(stroke=color, stroke_width=stroke_width, fill="none") + p.M(bracket_x - 5, y_min).L(bracket_x, y_min).L(bracket_x, y_max).L(bracket_x - 5, y_max) + self.drawing.append(p) + self.drawing.append(draw.Text(text, fs, bracket_x + 8, (y_min + y_max) / 2, + fill=color, font_family=self.style.font_family, + text_anchor="start", dominant_baseline="middle")) + except: continue + + def plot_continuous_variable(self, node_to_rgb, stroke_width=None): + """ + Colors branches based on RGB mappings of a continuous trait. - img_path = os.path.join(image_dir, f"{leaf.name}{extension}") - - if os.path.exists(img_path): - # 3. Manual Embedding - with open(img_path, "rb") as img_file: - encoded_string = base64.b64encode(img_file.read()).decode('utf-8') - data_uri = f"data:image/png;base64,{encoded_string}" - - # 4. Rotation Logic - # We rotate around the center of the image (img_x + width/2, img_y + height/2) - transform_str = "" - if rotation != 0: - center_x = img_x + (width / 2) - center_y = img_y + (height / 2) - transform_str = f"rotate({rotation}, {center_x}, {center_y})" - - # Create the image object - img_obj = draw.Image( - img_x, img_y, - width, height, - path=data_uri, - transform=transform_str # Apply the rotation transform - ) - - self.d.append(img_obj) - else: - print(f"MISSING IMAGE: No file at {img_path}") - - - def add_ancestral_images(self, image_dir, extension=".png", width=40, height=40, rotation=0, omit=None): - """ - Adds PNG images centered on ancestral (internal) nodes. - :param rotation: Degrees to rotate the image (clockwise). - """ - import os - import base64 - - # 1. Coordinate check: Ensure we have the latest positions - if not hasattr(self.t, 'coordinates'): - self._calculate_layout() - - # Traverse all nodes but filter for internal ones - for node in self.t.traverse(): - if not node.is_leaf(): - - if omit and node.name in omit: - continue # Skip omitted names - - nx, ny = node.coordinates # - - # 2. Centering Math - # To make the center of the image coincide with the node, - # we subtract half the width/height from the top-left anchor. - img_x = nx - (width / 2) - img_y = ny - (height / 2) - - img_path = os.path.join(image_dir, f"{node.name}{extension}") - - if os.path.exists(img_path): - # 3. Manual Embedding - with open(img_path, "rb") as img_file: - encoded_string = base64.b64encode(img_file.read()).decode('utf-8') - data_uri = f"data:image/png;base64,{encoded_string}" - - # 4. Rotation Logic (Centered) - transform_str = "" - if rotation != 0: - transform_str = f"rotate({rotation}, {nx}, {ny})" - - # Create the image object centered on (nx, ny) - img_obj = draw.Image( - img_x, img_y, - width, height, - path=data_uri, - transform=transform_str - ) - - self.d.append(img_obj) - else: - # Optional: Print warning for missing ancestral names - # print(f"MISSING ANCESTRAL IMAGE: {img_path}") - pass - - def plot_continuous_variable(self, node_to_rgb, size=None): - """ - Colors branches using a gradient based on RGB values at each node. - Also recolors vertical connectors so they are not left black. - node_to_rgb can map node.name OR node objects to (r,g,b). - """ - def _to_hex(rgb): - return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2])) - - s_width = size if size is not None else self.style.branch_size - - # 1) Gradient on horizontal segments (parent -> child) + Args: + node_to_rgb (dict): Mapping {node: (r,g,b)}. + stroke_width (float, optional): Branch thickness. Defaults to style default. + """ + def _to_hex(rgb): return '#%02x%02x%02x' % (int(rgb[0]), int(rgb[1]), int(rgb[2])) + sw = stroke_width or self.style.branch_stroke_width for node in self.t.traverse(): - if node.is_root(): - continue - - parent = node.up - c_child = node_to_rgb.get(node) or node_to_rgb.get(node.name) - c_parent = node_to_rgb.get(parent) or node_to_rgb.get(parent.name) - - if (c_child is not None) and (c_parent is not None): - self.gradient_branch( - node, - colors=(_to_hex(c_parent), _to_hex(c_child)), - size=s_width, - ) + if node.is_root(): continue + c_c = node_to_rgb.get(node) or node_to_rgb.get(node.name) + c_p = node_to_rgb.get(node.up) or node_to_rgb.get(node.up.name) + if c_c and c_p: + self.gradient_branch(node, stroke_width=sw, colors=(_to_hex(c_p), _to_hex(c_c))) else: - self.highlight_branch(node, color=self.style.branch_color, size=s_width) - - # 2) Re-draw vertical connectors using the internal node's color - # This covers the "elbows" that are currently appearing as black lines + self.highlight_branch(node, stroke_width=sw) for n in self.t.traverse("postorder"): - if n.is_leaf(): - continue - - # Get the color for the internal node - col = node_to_rgb.get(n) or node_to_rgb.get(n.name) - if col is None: - continue - - x, y = n.coordinates - # Vertical span of all children - y_min = min(c.y_coord for c in n.children) - y_max = max(c.y_coord for c in n.children) - - self.d.append(draw.Line( - x, y_min, x, y_max, - stroke=_to_hex(col), - stroke_width=s_width, - stroke_linecap="round" - )) - - # Optional: Overwrite the internal-node dot with the same color - self.d.append(draw.Circle(x, y, self.style.node_size, fill=_to_hex(col))) - - def plot_categorical_trait( - self, - data, - value_col, - node_col="Node", - palette=None, - size=None, - default_color="black" - ): - """ - Colors branches based on categorical data using GRADIENTS for transitions. - - :param data: pandas DataFrame containing node names and values. - :param value_col: The name of the column with the category (e.g., "X"). - :param node_col: The name of the column with node names (default "Node"). - :param palette: Dictionary mapping values -> hex colors. - Example: {0: "red", 1: "blue"} - :param size: Thickness of the branches. - """ - # 1. Parse Data into a dictionary: { "NodeName": value } - if hasattr(data, "to_dict"): - # Map Node Name -> Category Value + if not n.is_leaf(): + col = node_to_rgb.get(n) or node_to_rgb.get(n.name) + if col: + x, y = n.coordinates + y_min = min(c.y_coord for c in n.children) + y_max = max(c.y_coord for c in n.children) + self.drawing.append(draw.Line(x, y_min, x, y_max, stroke=_to_hex(col), stroke_width=sw, stroke_linecap="round")) + self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=_to_hex(col))) + + def plot_categorical_trait(self, data, value_col, node_col="Node", palette=None, stroke_width=None, default_color="black"): + """ + Colors branches and nodes based on categorical lineage traits. + + Args: + data (DataFrame or dict): Trait data. + value_col (str): Column for values. + node_col (str, optional): Column for node names. Defaults to "Node". + palette (dict, optional): Color map. Defaults to auto. + stroke_width (float, optional): Branch thickness. Defaults to style default. + default_color (str, optional): Fallback color. Defaults to "black". + """ + if hasattr(data, "to_dict"): mapping = dict(zip(data[node_col].astype(str), data[value_col])) - else: - mapping = data - - # 2. Define Default Palette if none provided + else: mapping = data if palette is None: - defaults = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"] unique_vals = sorted(list(set(mapping.values()))) + defaults = ["#E41A1C", "#377EB8", "#4DAF4A", "#984EA3", "#FF7F00", "#FFFF33"] palette = {val: defaults[i % len(defaults)] for i, val in enumerate(unique_vals)} - - s_width = size if size is not None else self.style.branch_size - - # Helper to safely get color hex string - def get_color(n): - val = mapping.get(n.name) - return palette.get(val, default_color) - - # 3. Draw Branches + sw = stroke_width or self.style.branch_stroke_width + def get_color(n): return palette.get(mapping.get(n.name), default_color) for node in self.t.traverse(): - if node.is_root(): - continue - - c_node = get_color(node) - c_parent = get_color(node.up) - - # --- HORIZONTAL BRANCH --- - if c_node != c_parent: - # Different colors? Use Gradient! - # We reuse your existing gradient_branch function - self.gradient_branch( - node, - colors=(c_parent, c_node), - size=s_width - ) + if node.is_root(): continue + c_n, c_p = get_color(node), get_color(node.up) + if c_n != c_p: + self.gradient_branch(node, colors=(c_p, c_n), stroke_width=sw) else: - # Same color? Solid line. x, y = node.coordinates - px, _ = node.up.coordinates # parent X, child Y (rectangular elbow) - self.d.append(draw.Line( - px, y, x, y, - stroke=c_node, - stroke_width=s_width, - stroke_linecap="round" - )) - - # 4. Draw Vertical Connectors (Elbows) & Nodes - # These take the color of the internal node itself + px, _ = node.up.coordinates + self.drawing.append(draw.Line(px, y, x, y, stroke=c_n, stroke_width=sw, stroke_linecap="round")) for n in self.t.traverse("postorder"): - if n.is_leaf(): - continue + if not n.is_leaf(): + color = get_color(n) + x, y = n.coordinates + y_min = min(c.y_coord for c in n.children) + y_max = max(c.y_coord for c in n.children) + self.drawing.append(draw.Line(x, y_min, x, y_max, stroke=color, stroke_width=sw, stroke_linecap="round")) + self.drawing.append(draw.Circle(x, y, self.style.node_r, fill=color)) - color = get_color(n) - - x, y = n.coordinates - y_min = min(c.y_coord for c in n.children) - y_max = max(c.y_coord for c in n.children) - - # Vertical Line - self.d.append(draw.Line( - x, y_min, x, y_max, - stroke=color, - stroke_width=s_width, - stroke_linecap="round" - )) - - # Node Circle (covers the join) - self.d.append(draw.Circle(x, y, self.style.node_size, fill=color)) \ No newline at end of file + def add_categorical_legend(self, palette, title="Legend", x=None, y=None, font_size=14, r=6): + """ + Adds a categorical legend (colored circles with labels). + + Args: + palette (dict): Map of {label: color}. + title (str, optional): Legend title. Defaults to "Legend". + x (float, optional): X position. Defaults to auto. + y (float, optional): Y position. Defaults to auto. + font_size (int, optional): Font size. Defaults to 14. + r (float, optional): Marker radius. Defaults to 6. + """ + if x is None: x = -self.style.width / 2 + 30 + if y is None: y = -self.style.height / 2 + 30 + self.drawing.append(draw.Text(title, font_size + 2, x, y, font_weight="bold", + font_family=self.style.font_family, text_anchor="start")) + curr_y = y + font_size * 1.5 + for label, color in palette.items(): + self.drawing.append(draw.Circle(x + r, curr_y, r, fill=color)) + self.drawing.append(draw.Text(str(label), font_size, x + r*2.5, curr_y, + font_family=self.style.font_family, text_anchor="start", dominant_baseline="middle")) + curr_y += font_size * 1.4 + + def add_transfer_legend(self, colors=("purple", "orange"), labels=("Departure", "Arrival"), x=None, y=None, font_size=14): + """ + Adds a legend specifically for Horizontal Gene Transfers. + + Args: + colors (tuple, optional): (start_color, end_color). Defaults to ("purple", "orange"). + labels (tuple, optional): (start_label, end_label). Defaults to ("Departure", "Arrival"). + x (float, optional): X position. + y (float, optional): Y position. + font_size (int, optional): Font size. Defaults to 14. + """ + palette = {labels[0]: colors[0], labels[1]: colors[1]} + self.add_categorical_legend(palette, title="Transfer Event", x=x, y=y, font_size=font_size) + + def add_color_bar(self, low_color, high_color, vmin, vmax, title="", x=None, y=None, width=100, height=15, font_size=12): + """ + Adds a continuous color bar gradient legend. + + Args: + low_color (str): Color for min value. + high_color (str): Color for max value. + vmin (float): Min value. + vmax (float): Max value. + title (str, optional): Title. + x (float, optional): X position. + y (float, optional): Y position. + width (float, optional): Bar width. Defaults to 100. + height (float, optional): Bar height. Defaults to 15. + font_size (int, optional): Font size. Defaults to 12. + """ + if x is None: x = -self.style.width / 2 + 30 + if y is None: y = self.style.height / 2 - 60 + gid = generate_id("cb_grad") + grad = draw.LinearGradient(x, y, x + width, y, id=gid) + grad.add_stop(0, low_color); grad.add_stop(1, high_color) + self.drawing.append(grad) + if title: + self.drawing.append(draw.Text(title, font_size, x, y - 10, font_weight="bold", text_anchor="start")) + self.drawing.append(draw.Rectangle(x, y, width, height, fill=grad, stroke="black", stroke_width=0.5)) + self.drawing.append(draw.Text(f"{vmin:.2g}", font_size - 2, x, y + height + 12, text_anchor="start")) + self.drawing.append(draw.Text(f"{vmax:.2g}", font_size - 2, x + width, y + height + 12, text_anchor="end")) + + def add_leaf_images(self, image_dir, extension=".png", width=40, height=40, offset=10): + """ + Places images next to leaf tips. + + Args: + image_dir (str): Directory containing image files. + extension (str, optional): File extension. Defaults to ".png". + width (float, optional): Image width. Defaults to 40. + height (float, optional): Image height. Defaults to 40. + offset (float, optional): Horizontal offset. Defaults to 10. + """ + for leaf in self.t.get_leaves(): + lx, ly = self._leaf_xy(leaf, offset=offset) + path = os.path.join(image_dir, f"{leaf.name}{extension}") + if os.path.exists(path): + with open(path, "rb") as f: + uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}" + self.drawing.append(draw.Image(lx - width/2, ly - height/2, width, height, path=uri)) + + def add_ancestral_images(self, image_dir, extension=".png", width=40, height=40, omit=None): + """ + Places images at internal node positions. + + Args: + image_dir (str): Directory containing image files. + extension (str, optional): File extension. Defaults to ".png". + width (float, optional): Image width. Defaults to 40. + height (float, optional): Image height. Defaults to 40. + omit (list, optional): List of node names to skip. Defaults to None. + """ + for node in self.t.traverse(): + if not node.is_leaf(): + if omit and node.name in omit: continue + nx, ny = self._node_xy(node) + path = os.path.join(image_dir, f"{node.name}{extension}") + if os.path.exists(path): + with open(path, "rb") as f: + uri = f"data:image/png;base64,{base64.b64encode(f.read()).decode()}" + self.drawing.append(draw.Image(nx - width/2, ny - height/2, width, height, path=uri)) + + def add_title(self, text, font_size=24, position="top", pad=40.0, color="black", weight="bold"): + """ + Adds a title text to the drawing. + + Args: + text (str): Title text. + font_size (int, optional): Font size. Defaults to 24. + position (str, optional): "top", "bottom", "left", "right". Defaults to "top". + pad (float, optional): Padding. Defaults to 40.0. + color (str, optional): Color. Defaults to "black". + weight (str, optional): Font weight. Defaults to "bold". + """ + w, h = self.style.width, self.style.height + tx, ty = 0, 0 + if position == "top": ty = -h/2 + pad + elif position == "bottom": ty = h/2 - pad + elif position == "left": tx = -w/2 + pad + elif position == "right": tx = w/2 - pad + self.drawing.append(draw.Text( + text, font_size, tx, ty, fill=color, font_weight=weight, + font_family=self.style.font_family, text_anchor="middle", dominant_baseline="middle" + )) + + def add_scale_bar(self, length, label=None, x=None, y=None, stroke="black", stroke_width=2.0): + """ + Adds a physical scale bar representing a distance value. + + Args: + length (float): The distance value the bar represents. + label (str, optional): Text label. Defaults to str(length). + x (float, optional): X position. Defaults to auto. + y (float, optional): Y position. Defaults to auto. + stroke (str, optional): Color. Defaults to "black". + stroke_width (float, optional): Thickness. Defaults to 2.0. + """ + self._pre_flight_check() + px = float(length) * self.sf + label = label or str(length) + x = x if x is not None else -self.style.width/2 + 20 + y = y if y is not None else self.style.height/2 - 20 + self.drawing.append(draw.Line(x, y, x + px, y, stroke=stroke, stroke_width=stroke_width)) + self.drawing.append(draw.Text(label, self.style.font_size, x + px/2, y - 8, center=True)) diff --git a/src/phylustrator/utils.py b/src/phylustrator/utils.py index c596ce6..bae88b7 100644 --- a/src/phylustrator/utils.py +++ b/src/phylustrator/utils.py @@ -1,24 +1,53 @@ from __future__ import annotations - import ete3 +import math +import random +import string +def generate_id(prefix: str = "id", length: int = 6) -> str: + """Generates a unique ID for SVG elements like gradients.""" + suffix = ''.join(random.choices(string.ascii_lowercase + string.digits, k=length)) + return f"{prefix}_{suffix}" -def add_origin_if_root_has_dist(tree: ete3.Tree, origin_name: str = "Origin") -> ete3.Tree: - """If `tree.dist` is non-zero, interpret it as a stem and add an explicit origin node. +def to_rgb(color_str: str) -> tuple[int, int, int]: + """Parses hex, common names, or RGB tuples into a standard RGB tuple.""" + color_str = str(color_str).strip().lower() + if color_str.startswith("#"): + h = color_str.lstrip("#") + if len(h) == 3: h = "".join([c*2 for c in h]) + return tuple(int(h[i:i+2], 16) for i in (0, 2, 4)) + common_names = { + "white": (255, 255, 255), "black": (0, 0, 0), "red": (255, 0, 0), + "green": (0, 128, 0), "blue": (0, 0, 255), "orange": (255, 165, 0), + "purple": (128, 0, 128), "yellow": (255, 255, 0), "gray": (128, 128, 128) + } + return common_names.get(color_str, (0, 0, 0)) + +def to_hex(rgb: tuple[int, int, int]) -> str: + """Converts an RGB tuple to a hex string.""" + return "#{:02x}{:02x}{:02x}".format(*[int(max(0, min(255, x))) for x in rgb]) + +def lerp_color(low_hex: str, high_hex: str, t: float) -> str: + """Interpolates between two colors.""" + t = max(0.0, min(1.0, t)) + c1 = to_rgb(low_hex) + c2 = to_rgb(high_hex) + return to_hex(tuple(c1[i] + (c2[i] - c1[i]) * t for i in range(3))) - This avoids layout shifts when a rooted Newick encodes a stem length as the root's `dist`. +def polar_to_cartesian(degree: float, radius: float, rotation: float = 0) -> tuple[float, float]: + """Converts polar coordinates (degree, radius) to (x, y).""" + theta = math.radians(degree + rotation) + return radius * math.cos(theta), radius * math.sin(theta) - Returns the (possibly new) root tree. - """ - stem = float(getattr(tree, "dist", 0.0) or 0.0) +def add_origin_if_root_has_dist(tree: ete3.Tree, origin_name: str = "Origin") -> ete3.Tree: + """Standardizes trees by adding an explicit origin node if the root has a distance.""" + stem = float(tree.dist or 0.0) if stem <= 0.0: tree.dist = 0.0 return tree - origin = ete3.Tree() origin.name = origin_name origin.dist = 0.0 - tree.dist = stem origin.add_child(tree) return origin diff --git a/tests/tests_drawings.py b/tests/tests_drawings.py new file mode 100644 index 0000000..0938772 --- /dev/null +++ b/tests/tests_drawings.py @@ -0,0 +1,116 @@ +import pytest +import ete3 +from phylustrator.drawing import VerticalTreeDrawer, RadialTreeDrawer, TreeStyle + +@pytest.fixture +def simple_tree(): + # Simple tree: (A:1, B:1); + return ete3.Tree("(A:1, B:1);") + +@pytest.fixture +def transfer_data(): + return [{"from": "A", "to": "B", "freq": 1.0}] + +@pytest.fixture +def trait_data(): + return {"A": 1.0, "B": 2.0} + +def test_vertical_layout_init(simple_tree): + style = TreeStyle(width=500, height=500) + drawer = VerticalTreeDrawer(simple_tree, style=style) + + # Check if layout was calculated + for node in simple_tree.traverse(): + assert hasattr(node, "coordinates") + assert len(node.coordinates) == 2 + +def test_radial_layout_init(simple_tree): + style = TreeStyle(radius=200) + drawer = RadialTreeDrawer(simple_tree, style=style) + + # Check if layout was calculated + for node in simple_tree.traverse(): + assert hasattr(node, "rad") + assert hasattr(node, "angle") + + # Check bounds + root = simple_tree.get_tree_root() + assert root.rad == 0 + leaf_a = simple_tree.search_nodes(name="A")[0] + assert leaf_a.rad == 200 + +def test_pre_flight_check(simple_tree): + drawer = VerticalTreeDrawer(simple_tree) + # Reset flag manually + drawer._layout_calculated = False + # Calling a method that triggers check + drawer.add_title("Test") + assert drawer._layout_calculated is True + +@pytest.mark.parametrize("drawer_class", [VerticalTreeDrawer, RadialTreeDrawer]) +def test_method_existence(drawer_class, simple_tree): + """ + Comprehensive existence check for all public API methods to prevent + regressions during refactoring. + """ + drawer = drawer_class(simple_tree) + + required_methods = [ + "draw", + "highlight_clade", + "highlight_branch", + "gradient_branch", + "add_leaf_names", + "add_node_names", + "add_leaf_shapes", + "add_node_shapes", + "add_branch_shapes", + "plot_transfers", + "add_time_axis", + "add_heatmap", + "add_clade_labels", + "plot_continuous_variable", + "plot_categorical_trait", + "add_categorical_legend", + "add_transfer_legend", + "add_color_bar", + "add_leaf_images", + "add_ancestral_images", + "add_title", + "add_scale_bar", + "save_svg", + "save_png" + ] + + for method in required_methods: + assert hasattr(drawer, method), f"{drawer_class.__name__} is missing required method: {method}" + +@pytest.mark.parametrize("drawer_class", [VerticalTreeDrawer, RadialTreeDrawer]) +def test_smoke_execution(drawer_class, simple_tree, transfer_data, trait_data): + """ + Executes core methods with dummy data to ensure no internal crashes/SyntaxErrors. + """ + drawer = drawer_class(simple_tree) + + # Core Drawing + drawer.draw() + + # Overlays + drawer.highlight_clade(simple_tree, color="red") + drawer.add_leaf_names() + drawer.add_leaf_shapes(["A"], r=5) + drawer.add_branch_shapes([{"branch": "A", "where": 0.5, "shape": "circle"}]) + drawer.plot_transfers(transfer_data) + + # Labels & Legends + drawer.add_clade_labels({"A": "Label"}) + drawer.add_categorical_legend({"Trait": "blue"}) + drawer.add_color_bar("white", "blue", 0, 1) + + # Traits + drawer.plot_categorical_trait(trait_data, value_col="trait") + + # Title + drawer.add_title("Smoke Test") + + assert drawer.drawing is not None diff --git a/tests/tests_utils.py b/tests/tests_utils.py new file mode 100644 index 0000000..11d7ad6 --- /dev/null +++ b/tests/tests_utils.py @@ -0,0 +1,33 @@ +import pytest +from phylustrator.utils import to_rgb, to_hex, lerp_color, polar_to_cartesian +import math + +def test_to_rgb(): + assert to_rgb("#ffffff") == (255, 255, 255) + assert to_rgb("#000") == (0, 0, 0) + assert to_rgb("red") == (255, 0, 0) + assert to_rgb("invalid") == (0, 0, 0) + +def test_to_hex(): + assert to_hex((255, 255, 255)) == "#ffffff" + assert to_hex((0, 0, 0)) == "#000000" + # Test clipping + assert to_hex((300, -10, 100)) == "#ff0064" + +def test_lerp_color(): + # Midpoint between black and white + assert lerp_color("#000000", "#ffffff", 0.5) == "#7f7f7f" + # Bound checks + assert lerp_color("#000000", "#ffffff", -1) == "#000000" + assert lerp_color("#000000", "#ffffff", 2) == "#ffffff" + +def test_polar_to_cartesian(): + # 0 degrees, radius 100 should be (100, 0) + x, y = polar_to_cartesian(0, 100) + assert pytest.approx(x) == 100 + assert pytest.approx(y) == 0 + + # 90 degrees, radius 100 should be (0, 100) + x, y = polar_to_cartesian(90, 100) + assert pytest.approx(x) == 0 + assert pytest.approx(y) == 100