From d15143036a67a099aac40c203f27ed391ec6d387 Mon Sep 17 00:00:00 2001 From: jordantensor Date: Sun, 2 Nov 2025 17:46:21 +0000 Subject: [PATCH 1/2] Save tree structute over time [untested] --- pyproject.toml | 1 + .../evolve_and_branch_finite.py | 601 +++++++++++++++++- 2 files changed, 590 insertions(+), 12 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ff288e4..3fa1f52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "pydantic", "pytest", "ipykernel", + "jupyter", "torch", "quimb", "numpy", diff --git a/wavefunction_branching/evolve_and_branch_finite.py b/wavefunction_branching/evolve_and_branch_finite.py index 0272f2a..0e36ac0 100644 --- a/wavefunction_branching/evolve_and_branch_finite.py +++ b/wavefunction_branching/evolve_and_branch_finite.py @@ -119,6 +119,9 @@ def __init__(self): self.branch_values: list[dict] = [] # This is what gets updated when we add measurements self.df_branch_values: pd.DataFrame | None = None # A dataframe computed from branch_values self.df_combined_values: pd.DataFrame | None = None # prob-weigted average measurments + # Track branching events including unsampled branches + self.branching_events: list[dict] = [] # Records of each branching event + self.df_branching_events: pd.DataFrame | None = None # DataFrame of branching events def add_measurements_tebd(self, engine, extra_measurements=None, **kwargs): """Update self.branch_values with measurements from a TEBD engine (measures strictly more than add_measurements)""" @@ -175,9 +178,152 @@ def combine_measurements(self): self.df_combined_values = pd.DataFrame(combined) return self.df_combined_values + def add_branching_event( + self, + time, + n_candidate_branches, + n_sampled_branches, + candidate_probs, + sampled_indices, + parent_prob, + site, + trace_distances=None + ): + """ + Record a branching event with information about both sampled and unsampled branches. + + Parameters: + ----------- + time : float + Time at which branching occurred + n_candidate_branches : int + Total number of candidate branches before sampling + n_sampled_branches : int + Number of branches that survived sampling + candidate_probs : np.ndarray + Array of probabilities for all candidate branches + sampled_indices : np.ndarray + Indices of branches that were sampled + parent_prob : float + Probability of the parent branch before splitting + site : int + Site index where branching occurred + trace_distances : dict, optional + Dictionary of trace distances and other quality metrics + """ + event = { + 'time': time, + 'site': site, + 'n_candidate_branches': n_candidate_branches, + 'n_sampled_branches': n_sampled_branches, + 'n_discarded_branches': n_candidate_branches - n_sampled_branches, + 'parent_prob': parent_prob, + 'total_candidate_prob': np.sum(candidate_probs), + 'sampled_prob': np.sum(candidate_probs[sampled_indices]), + 'discarded_prob': np.sum(candidate_probs) - np.sum(candidate_probs[sampled_indices]), + 'candidate_probs_mean': np.mean(candidate_probs), + 'candidate_probs_std': np.std(candidate_probs), + 'candidate_probs_min': np.min(candidate_probs), + 'candidate_probs_max': np.max(candidate_probs), + } + + if trace_distances is not None: + event.update(trace_distances) + + self.branching_events.append(event) + + def branching_events_to_dataframe(self): + """Convert branching events list to DataFrame""" + if len(self.branching_events) > 0: + self.df_branching_events = pd.DataFrame.from_records(self.branching_events) + return self.df_branching_events + + def get_cumulative_branch_counts(self): + """ + Calculate cumulative number of branches over time. + + Returns: + -------- + pd.DataFrame with columns: + - time: measurement times + - n_sampled_branches_cumulative: total sampled branches existing at each time + - n_candidate_branches_cumulative: total candidate branches that existed at each time + """ + if self.df_branching_events is None: + self.branching_events_to_dataframe() + + if self.df_combined_values is None: + self.combine_measurements() + + # Ensure branch_values_to_dataframe has been called + if self.df_branch_values is None: + self.branch_values_to_dataframe() + + if self.df_branching_events is None or len(self.df_branching_events) == 0: + # If no branching events, return dataframe based on measurements + if self.df_combined_values is None or len(self.df_combined_values) == 0: + return None + if self.df_branch_values is None or len(self.df_branch_values) == 0: + return None + times = sorted(self.df_combined_values['time'].unique()) + cumulative_data = [] + for t in times: + n_actual = len(self.df_branch_values[self.df_branch_values['time'] == t]) + cumulative_data.append({ + 'time': t, + 'n_sampled_branches_actual': n_actual, + 'n_sampled_branches_cumulative': n_actual, + 'n_candidate_branches_cumulative': n_actual, + 'n_discarded_branches_cumulative': 0, + }) + return pd.DataFrame(cumulative_data) + + # Get all unique times from measurements + if self.df_branch_values is None or len(self.df_branch_values) == 0: + return None + times = sorted(self.df_combined_values['time'].unique()) + + cumulative_data = [] + + for t in times: + # Count actual branches (from measurements) at this time + n_actual = len(self.df_branch_values[self.df_branch_values['time'] == t]) + + # Count total candidate branches up to this time + events_before = self.df_branching_events[self.df_branching_events['time'] <= t] + + if len(events_before) > 0: + # Calculate cumulative branches + # Start with 1 (initial branch) + # Each branching event: creates n_candidate_branches new branches, replaces 1 parent + # So net change per event: -1 + n_sampled_branches (only counted if we sample) + # But for candidates: -1 + n_candidate_branches + n_candidates_total = 1 # Start with 1 branch + n_sampled_total = 1 + for _, event in events_before.iterrows(): + n_candidates_total = n_candidates_total - 1 + event['n_candidate_branches'] + n_sampled_total = n_sampled_total - 1 + event['n_sampled_branches'] + n_discarded_total = events_before['n_discarded_branches'].sum() + else: + n_candidates_total = 1 # Start with 1 branch + n_sampled_total = 1 + n_discarded_total = 0 + + cumulative_data.append({ + 'time': t, + 'n_sampled_branches_actual': n_actual, + 'n_sampled_branches_cumulative': n_sampled_total, + 'n_candidate_branches_cumulative': n_candidates_total, + 'n_discarded_branches_cumulative': n_discarded_total, + }) + + return pd.DataFrame(cumulative_data) + def merge_with_other(self, other): self.branch_values += other.branch_values + self.branching_events += other.branching_events # NEW self.branch_values_to_dataframe() + self.branching_events_to_dataframe() # NEW self.combine_measurements() return self @@ -190,6 +336,203 @@ def _repr_html_(self): return self.df_branch_values._repr_html_() # type: ignore +def plot_branch_counts_over_time(branch_values, name="", outfolder=None, save=True): + """ + Plot the number of branches over time, including both sampled and unsampled. + + Parameters: + ----------- + branch_values : BranchValues + The BranchValues object containing branching event data + name : str + Name for the plot title and filename + outfolder : Path or str + Directory to save the plot + save : bool + Whether to save the plot to file + """ + # Get cumulative branch counts + df_cumulative = branch_values.get_cumulative_branch_counts() + + if df_cumulative is None or len(df_cumulative) == 0: + print("No branching event data available to plot") + return + + # Create figure + plt.figure(figsize=(12, 6), dpi=150) + + # Plot cumulative candidate branches (total that ever existed) + plt.plot( + df_cumulative['time'], + df_cumulative['n_candidate_branches_cumulative'], + label='Total candidate branches (including unsampled)', + color='#d62728', + linewidth=2, + linestyle='--', + alpha=0.8 + ) + + # Plot cumulative sampled branches + plt.plot( + df_cumulative['time'], + df_cumulative['n_sampled_branches_cumulative'], + label='Sampled branches (kept)', + color='#2ca02c', + linewidth=2.5 + ) + + # Plot actual active branches (from measurements) + plt.plot( + df_cumulative['time'], + df_cumulative['n_sampled_branches_actual'], + label='Active branches at time t', + color='#1f77b4', + linewidth=2, + marker='o', + markersize=3, + alpha=0.7 + ) + + # Plot discarded branches (as shaded region) + plt.fill_between( + df_cumulative['time'], + df_cumulative['n_sampled_branches_cumulative'], + df_cumulative['n_candidate_branches_cumulative'], + alpha=0.3, + color='#ff7f0e', + label='Discarded branches (unsampled)' + ) + + plt.xlabel('Time', fontsize=12) + plt.ylabel('Number of Branches', fontsize=12) + plt.title(f'Branch Counts Over Time: {name}', fontsize=14) + plt.legend(loc='best', fontsize=10) + plt.grid(True, alpha=0.3) + plt.tight_layout() + + if save and outfolder is not None: + from pathlib import Path + outfolder = Path(outfolder) + outfolder.mkdir(exist_ok=True, parents=True) + + plt.savefig(outfolder / f"{NOW}_{name}_branch_counts.pdf") + plt.savefig(outfolder / f"{NOW}_{name}_branch_counts.png") + print(f"Saved branch counts plot to {outfolder}") + + plt.show() + + +def plot_branching_events_detail(branch_values, name="", outfolder=None, save=True): + """ + Plot detailed information about each branching event. + + Shows when branching occurred, how many branches were created vs sampled. + """ + branch_values.branching_events_to_dataframe() + df_events = branch_values.df_branching_events + + if df_events is None or len(df_events) == 0: + print("No branching event data available") + return + + fig, axes = plt.subplots(2, 1, figsize=(12, 8), dpi=150, sharex=True) + + # Top plot: Number of branches at each event + ax1 = axes[0] + width = 0.35 + x = np.arange(len(df_events)) + + ax1.bar(x - width/2, df_events['n_candidate_branches'], width, + label='Candidate branches', color='#d62728', alpha=0.7) + ax1.bar(x + width/2, df_events['n_sampled_branches'], width, + label='Sampled branches', color='#2ca02c', alpha=0.7) + + ax1.set_ylabel('Number of Branches', fontsize=11) + ax1.set_title(f'Branching Events Detail: {name}', fontsize=13) + ax1.legend() + ax1.grid(True, alpha=0.3, axis='y') + + # Bottom plot: Probability distribution + ax2 = axes[1] + ax2.bar(x, df_events['total_candidate_prob'], width*1.5, + label='Total candidate prob', color='#ff7f0e', alpha=0.5) + ax2.bar(x, df_events['sampled_prob'], width*1.5, + label='Sampled prob', color='#1f77b4', alpha=0.7) + + ax2.set_xlabel('Branching Event', fontsize=11) + ax2.set_ylabel('Probability', fontsize=11) + ax2.legend() + ax2.grid(True, alpha=0.3, axis='y') + + # Set x-axis labels with times + ax2.set_xticks(x) + ax2.set_xticklabels([f"{t:.2f}" for t in df_events['time']], rotation=45) + + plt.tight_layout() + + if save and outfolder is not None: + from pathlib import Path + outfolder = Path(outfolder) + outfolder.mkdir(exist_ok=True, parents=True) + + plt.savefig(outfolder / f"{NOW}_{name}_branching_events.pdf") + plt.savefig(outfolder / f"{NOW}_{name}_branching_events.png") + print(f"Saved branching events plot to {outfolder}") + + plt.show() + + +def plot_sampling_efficiency(branch_values, name="", outfolder=None, save=True): + """ + Plot the sampling efficiency over time. + Shows what fraction of candidate branches are being kept. + """ + branch_values.branching_events_to_dataframe() + df_events = branch_values.df_branching_events + + if df_events is None or len(df_events) == 0: + print("No branching event data available") + return + + # Calculate sampling efficiency + df_events['sampling_efficiency'] = ( + df_events['n_sampled_branches'] / df_events['n_candidate_branches'] + ) + df_events['prob_efficiency'] = ( + df_events['sampled_prob'] / df_events['total_candidate_prob'] + ) + + plt.figure(figsize=(12, 6), dpi=150) + + plt.plot(df_events['time'], df_events['sampling_efficiency'], + marker='o', label='Branch count efficiency', + linewidth=2, markersize=8, color='#1f77b4') + plt.plot(df_events['time'], df_events['prob_efficiency'], + marker='s', label='Probability efficiency', + linewidth=2, markersize=8, color='#2ca02c') + + plt.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='No sampling (100%)') + + plt.xlabel('Time', fontsize=12) + plt.ylabel('Sampling Efficiency (sampled / candidates)', fontsize=12) + plt.title(f'Sampling Efficiency Over Time: {name}', fontsize=14) + plt.legend(loc='best', fontsize=10) + plt.grid(True, alpha=0.3) + plt.ylim([0, 1.1]) + plt.tight_layout() + + if save and outfolder is not None: + from pathlib import Path + outfolder = Path(outfolder) + outfolder.mkdir(exist_ok=True, parents=True) + + plt.savefig(outfolder / f"{NOW}_{name}_sampling_efficiency.pdf") + plt.savefig(outfolder / f"{NOW}_{name}_sampling_efficiency.png") + print(f"Saved sampling efficiency plot to {outfolder}") + + plt.show() + + def bring_into_theta_form(theta, formL, formR, sL, sR): return einsum((sL) ** (1.0 - formL), theta, (sR) ** (1.0 - formR), "l, b p l r, r -> b p l r") @@ -264,8 +607,9 @@ class BranchingMPS: def __init__( self, tebd_engine: tenpy.TEBDEngine - | ExpMPOEvolution, # The TEBD engine to use for time evolution - cfg: BranchingMPSConfig, # The configuration for splitting the wavefunction into branches + | ExpMPOEvolution + | None = None, # The TEBD engine to use for time evolution (None for unsampled branches) + cfg: BranchingMPSConfig | None = None, # The configuration for splitting the wavefunction into branches branch_values: BranchValues | None = None, # The structure for storing the measurements of all the branches over time branch_function: Callable @@ -281,10 +625,34 @@ def __init__( name="", wandb_project: str | None = None, info={}, + # New parameters for unsampled branches + sampled: bool = True, # Whether this branch was sampled (False for discarded branches) + created_time: float | None = None, # Time when branch was created (evolved_time at creation) + branching_site: int | None = None, # Site where branching occurred (for child nodes) + prob_override: float | None = None, # Probability override (for unsampled branches without tebd_engine) + norm_override: float | None = None, # Norm override (for unsampled branches without tebd_engine) ): - self.tebd_engine = tebd_engine # The TEBD engine to use for time evolution - self.norm = self.tebd_engine.psi.norm - self.prob = abs(self.norm**2) + self.sampled = sampled # Track whether this branch was sampled or discarded + self.created_time = created_time # Time when branch was created + self.branching_site = branching_site # Site where branching occurred + + # For unsampled branches, tebd_engine may be None + if tebd_engine is None: + # This is an unsampled branch - use override values + self.tebd_engine = None + if prob_override is not None: + self.prob = prob_override + self.norm = np.sqrt(prob_override) if norm_override is None else norm_override + elif norm_override is not None: + self.norm = norm_override + self.prob = abs(norm_override**2) + else: + self.prob = 0.0 + self.norm = 0.0 + else: + self.tebd_engine = tebd_engine # The TEBD engine to use for time evolution + self.norm = self.tebd_engine.psi.norm + self.prob = abs(self.norm**2) self.cfg = cfg # The configuration for splitting the wavefunction into branches self.pickle_file = pickle_file self.outfolder = outfolder @@ -296,6 +664,14 @@ def __init__( self.branch_values = BranchValues() else: self.branch_values = branch_values + + # Track root creation walltime for relative time calculations + # This needs to be set before we use it, so handle it after parent is assigned + # (We'll set it at the end of __init__) + pass + + # Track final walltime (updated when branch finishes or at save time) + self.final_walltime: datetime | None = None # a BrancingMPS from which we split (or None) self.parent = parent @@ -306,8 +682,14 @@ def __init__( else: self.children = children + # List of unsampled (discarded) branches from accepted decompositions + self.unsampled_children: list[BranchingMPS] = [] + if max_children is None: - self.max_children = int(self.cfg.max_branches) + if cfg is not None: + self.max_children = int(cfg.max_branches) + else: + self.max_children = 8 # Default else: self.max_children = max_children @@ -316,19 +698,30 @@ def __init__( print(f"Starting name = {self.name}, ID = {self.ID}") self.trunc_err = tenpy.linalg.truncation.TruncationError(eps=0.0, ov=1.0) - self.dt = tebd_engine.options["dt"] if "dt" in tebd_engine.options else 1 + if self.tebd_engine is not None: + self.dt = tebd_engine.options["dt"] if "dt" in tebd_engine.options else 1 + else: + self.dt = 1 # Default for unsampled branches if self.parent is None: self.costFun_LM_MR_trace_distance = 0.0 self.global_reconstruction_error_trace_distance = 0.0 - self.t_last_attempted_branching_sites = np.zeros(len(self.tebd_engine.psi.chi)) + if self.tebd_engine is not None: + self.t_last_attempted_branching_sites = np.zeros(len(self.tebd_engine.psi.chi)) + else: + # For unsampled root branch, create empty array (will be initialized later if needed) + self.t_last_attempted_branching_sites = np.array([]) self.site_last_attempted_branching: int | None = None + if self.tebd_engine is not None: + num_sites = len(self.tebd_engine.psi.chi) + else: + num_sites = 0 # For unsampled branches, will be set if needed self.last_attempted_branching_trunc_bond_dims_sites: list[ None | tuple[int, int, int] - ] = [None] * len(self.tebd_engine.psi.chi) + ] = [None] * num_sites self.last_attempted_branching_trunc_trace_distance_sites: list[None | float] = [ None - ] * len(self.tebd_engine.psi.chi) + ] * num_sites self.trace_distances = {} self.trace_distances["estimated_interference_error"] = 0.0 self.trace_distances["global_reconstruction_error_trace_distance"] = ( @@ -394,9 +787,26 @@ def __init__( self.depth = self.parent.depth + 1 self.branching_attempts = self.parent.branching_attempts - self.evolved_time = float(abs(self.tebd_engine.evolved_time)) + if self.tebd_engine is not None: + self.evolved_time = float(abs(self.tebd_engine.evolved_time)) + else: + # For unsampled branches, use created_time if available + self.evolved_time = self.created_time if self.created_time is not None else 0.0 + self.t_last_attempted_branching = self.evolved_time self.finished = False + + # Set created_time if not provided (for sampled branches) + if self.created_time is None: + self.created_time = self.evolved_time + + # Set root_created_walltime (after parent is assigned) + if self.parent is None: + # This is the root branch - store its creation time as reference + self.root_created_walltime = self.created_walltime + else: + # Inherit root creation time from parent + self.root_created_walltime = self.parent.root_created_walltime def branch_and_sample( self, @@ -721,6 +1131,8 @@ def branch_and_sample( print(f"{self.ID}No further branch_indices were selected: len(branch_indices) = 0") print(f"{self.ID}TERMINATING.") self.finished = True + # Update final walltime when branch terminates + self.final_walltime = datetime.now() self.branch_values.add_measurements_tebd( self.tebd_engine, extra_measurements=self.trace_distances ) @@ -746,6 +1158,23 @@ def branch_and_sample( self.depth += 1 elif len(branch_indices) > 1: print(f"{self.ID}Creating {num_kept_branches} children nodes.") + + # Record the branching event including unsampled branches + if hasattr(self, 'branch_values') and self.branch_values is not None: + self.branch_values.add_branching_event( + time=self.evolved_time, + n_candidate_branches=num_candidates, + n_sampled_branches=len(survivor_indices), + candidate_probs=branch_probs, + sampled_indices=survivor_indices, + parent_prob=abs(self.norm ** 2), + site=coarsegrain_from, + trace_distances={ + 'costFun_LM_MR_trace_distance': costFun_LM_MR_trace_distance, + 'global_reconstruction_error_trace_distance': global_reconstruction_error_trace_distance, + } + ) + print(f"{self.ID}Branching summary: Total candidates={num_candidates}, Sampled={len(survivor_indices)}, Discarded={num_candidates - len(survivor_indices)}") if np.isclose(total_prob_survived, 0.0): print( @@ -792,12 +1221,45 @@ def branch_and_sample( ID=self.ID + f"{branch_indices[i]}|", n_times_saved=self.n_times_saved, name=self.name, + sampled=True, # Explicitly mark as sampled + created_time=self.evolved_time, + branching_site=coarsegrain_from, ) ) print( f"{self.ID} Child {i} (orig index {survivor_indices[i]}): prob={prob:.4f}, prob={child_prob:.6f}, max_children={child_max_children}" ) + # Create unsampled branch nodes for discarded branches + discarded_indices = np.setdiff1d(candidate_indices, survivor_indices) + for discarded_idx in discarded_indices: + discarded_prob = branch_probs[discarded_idx] + # Calculate probability similar to sampled branches + if sampling_occurred: + discarded_child_prob = discarded_prob / total_prob_survived + else: + discarded_child_prob = discarded_prob + + unsampled_child = BranchingMPS( + tebd_engine=None, # No engine for unsampled branches + cfg=self.cfg, + branch_values=self.branch_values, + branch_function=self.branch_function, + parent=self, + max_children=0, # Unsampled branches don't get budget + ID=self.ID + f"{discarded_idx}|", + n_times_saved=self.n_times_saved, + name=self.name, + sampled=False, # Mark as unsampled + created_time=self.evolved_time, + branching_site=coarsegrain_from, + prob_override=discarded_child_prob * abs(self.norm ** 2), # Parent prob * child prob + ) + self.unsampled_children.append(unsampled_child) + print( + f"{self.ID} Unsampled child (orig index {discarded_idx}): prob={discarded_prob:.4f}, child_prob={discarded_child_prob:.6f}" + ) + # Verify prob conservation child_prob_sum = sum(abs(c.norm**2) for c in self.children) print( @@ -923,11 +1385,79 @@ def count_leaves(self): else: return sum([child.count_leaves() for child in self.children]) + def to_tree_dict(self): + """ + Convert BranchingMPS tree structure to a dictionary, stripping out tensors and engines. + This preserves the tree structure with all metadata for visualization and analysis. + + Returns: + -------- + dict: Dictionary representation of the tree node + """ + # Calculate relative times in seconds from root creation + root_created = self.root_created_walltime if hasattr(self, 'root_created_walltime') else self.created_walltime + + created_walltime_rel = (self.created_walltime - root_created).total_seconds() if self.created_walltime else None + + if self.final_walltime is not None: + final_walltime_rel = (self.final_walltime - root_created).total_seconds() + else: + final_walltime_rel = None + + tree_dict = { + 'ID': self.ID, + 'sampled': self.sampled, + 'created_time': self.created_time, + 'evolved_time': self.evolved_time, + 'created_walltime': self.created_walltime.isoformat() if self.created_walltime else None, + 'created_walltime_rel_seconds': created_walltime_rel, + 'final_walltime': self.final_walltime.isoformat() if self.final_walltime else None, + 'final_walltime_rel_seconds': final_walltime_rel, + 'prob': float(self.prob), + 'norm': float(self.norm), + 'depth': self.depth, + 'max_children': self.max_children, + 'finished': self.finished, + 'synchronized': self.synchronized if hasattr(self, 'synchronized') else False, + 'branching_site': self.branching_site, + 'costFun_LM_MR_trace_distance': float(self.costFun_LM_MR_trace_distance) if hasattr(self, 'costFun_LM_MR_trace_distance') else 0.0, + 'global_reconstruction_error_trace_distance': float(self.global_reconstruction_error_trace_distance) if hasattr(self, 'global_reconstruction_error_trace_distance') else 0.0, + 'has_tebd_engine': self.tebd_engine is not None, + 'n_children': len(self.children), + 'n_unsampled_children': len(self.unsampled_children) if hasattr(self, 'unsampled_children') else 0, + 'children': [child.to_tree_dict() for child in self.children], + 'unsampled_children': [child.to_tree_dict() for child in (self.unsampled_children if hasattr(self, 'unsampled_children') else [])], + } + return tree_dict + + def _update_final_walltime_recursive(self, current_time: datetime): + """ + Recursively update final_walltime for all branches in the tree. + If a branch is finished, keep its final_walltime. Otherwise, update to current_time. + """ + if self.finished and self.final_walltime is not None: + # Keep existing final_walltime if branch already finished + pass + else: + # Update to current time (at save time) + self.final_walltime = current_time + + # Update children and unsampled children + for child in self.children: + child._update_final_walltime_recursive(current_time) + if hasattr(self, 'unsampled_children'): + for child in self.unsampled_children: + child._update_final_walltime_recursive(current_time) + def save(self, final=False): if self.parent is not None: self.parent.save() else: t0 = time.time() + # Update final_walltime for all branches at save time + current_walltime = datetime.now() + self._update_final_walltime_recursive(current_walltime) + if self.pickle_file is not None: branchvals_file = str(self.pickle_file).split(".pkl")[0] + "_branch_values.pkl" branchvals_file = ( @@ -938,6 +1468,24 @@ def save(self, final=False): with open(branchvals_file, "wb") as f: pickle.dump(self.branch_values, f) f.close() + + # Save tree structure (without tensors) to JSON + tree_file = str(self.pickle_file).split(".pkl")[0] + "_tree.json" + tree_file = ( + tree_file + if (self.n_times_saved % 2 == 0 or final) + else tree_file + "tmp" + ) + try: + tree_dict = self.to_tree_dict() + with open(tree_file, "w") as f: + json.dump(tree_dict, f, indent=2, default=str) + print(f"{self.ID}Saved tree structure to {tree_file}") + except Exception as e: + print(f"{self.ID}Warning: Failed to save tree structure: {e}") + import traceback + traceback.print_exc() + if self.cfg.save_full_state: pickle_file = ( self.pickle_file @@ -951,7 +1499,7 @@ def save(self, final=False): if final: print(f"{self.ID}Removing temp files as this is the final save.") - for tmp_file in [branchvals_file + "tmp", str(self.pickle_file) + "tmp"]: + for tmp_file in [branchvals_file + "tmp", str(self.pickle_file) + "tmp", tree_file + "tmp"]: Path(tmp_file).unlink(missing_ok=True) t1 = time.time() print(f"{self.ID}Saved in {t1 - t0} seconds to {self.pickle_file}") @@ -1178,6 +1726,33 @@ def trace_distance_colors(key): plt.clf() plt.cla() + # NEW PLOTS: Branch counts including unsampled branches + try: + plot_branch_counts_over_time( + self.branch_values, + name=self.name, + outfolder=plots_dir, + save=True + ) + + plot_branching_events_detail( + self.branch_values, + name=self.name, + outfolder=plots_dir, + save=True + ) + + plot_sampling_efficiency( + self.branch_values, + name=self.name, + outfolder=plots_dir, + save=True + ) + except Exception as e: + print(f"Error plotting branch counts: {e}") + import traceback + traceback.print_exc() + # Log to wandb if self.wandb_project is not None: # Select just the most recent combined values @@ -1519,6 +2094,8 @@ def evolve_and_branch_leaf(self, stop_before_branching=False, t_evo=None, **kwar if self.evolved_time >= self.cfg.t_evo: print(f"{self.ID}Finished.") self.finished = True + # Update final walltime when branch finishes + self.final_walltime = datetime.now() # Measure self.branch_values.add_measurements_tebd( self.tebd_engine, extra_measurements=self.trace_distances From fa6394f741214f0891b666cbaab6881428f39a44 Mon Sep 17 00:00:00 2001 From: jordantensor Date: Sun, 2 Nov 2025 17:47:06 +0000 Subject: [PATCH 2/2] ruff format --- .../evolve_and_branch_finite.py | 487 ++++++++++-------- 1 file changed, 274 insertions(+), 213 deletions(-) diff --git a/wavefunction_branching/evolve_and_branch_finite.py b/wavefunction_branching/evolve_and_branch_finite.py index 0e36ac0..87c1ef4 100644 --- a/wavefunction_branching/evolve_and_branch_finite.py +++ b/wavefunction_branching/evolve_and_branch_finite.py @@ -179,7 +179,7 @@ def combine_measurements(self): return self.df_combined_values def add_branching_event( - self, + self, time, n_candidate_branches, n_sampled_branches, @@ -187,11 +187,11 @@ def add_branching_event( sampled_indices, parent_prob, site, - trace_distances=None + trace_distances=None, ): """ Record a branching event with information about both sampled and unsampled branches. - + Parameters: ----------- time : float @@ -212,36 +212,36 @@ def add_branching_event( Dictionary of trace distances and other quality metrics """ event = { - 'time': time, - 'site': site, - 'n_candidate_branches': n_candidate_branches, - 'n_sampled_branches': n_sampled_branches, - 'n_discarded_branches': n_candidate_branches - n_sampled_branches, - 'parent_prob': parent_prob, - 'total_candidate_prob': np.sum(candidate_probs), - 'sampled_prob': np.sum(candidate_probs[sampled_indices]), - 'discarded_prob': np.sum(candidate_probs) - np.sum(candidate_probs[sampled_indices]), - 'candidate_probs_mean': np.mean(candidate_probs), - 'candidate_probs_std': np.std(candidate_probs), - 'candidate_probs_min': np.min(candidate_probs), - 'candidate_probs_max': np.max(candidate_probs), + "time": time, + "site": site, + "n_candidate_branches": n_candidate_branches, + "n_sampled_branches": n_sampled_branches, + "n_discarded_branches": n_candidate_branches - n_sampled_branches, + "parent_prob": parent_prob, + "total_candidate_prob": np.sum(candidate_probs), + "sampled_prob": np.sum(candidate_probs[sampled_indices]), + "discarded_prob": np.sum(candidate_probs) - np.sum(candidate_probs[sampled_indices]), + "candidate_probs_mean": np.mean(candidate_probs), + "candidate_probs_std": np.std(candidate_probs), + "candidate_probs_min": np.min(candidate_probs), + "candidate_probs_max": np.max(candidate_probs), } - + if trace_distances is not None: event.update(trace_distances) - + self.branching_events.append(event) - + def branching_events_to_dataframe(self): """Convert branching events list to DataFrame""" if len(self.branching_events) > 0: self.df_branching_events = pd.DataFrame.from_records(self.branching_events) return self.df_branching_events - + def get_cumulative_branch_counts(self): """ Calculate cumulative number of branches over time. - + Returns: -------- pd.DataFrame with columns: @@ -251,47 +251,49 @@ def get_cumulative_branch_counts(self): """ if self.df_branching_events is None: self.branching_events_to_dataframe() - + if self.df_combined_values is None: self.combine_measurements() - + # Ensure branch_values_to_dataframe has been called if self.df_branch_values is None: self.branch_values_to_dataframe() - + if self.df_branching_events is None or len(self.df_branching_events) == 0: # If no branching events, return dataframe based on measurements if self.df_combined_values is None or len(self.df_combined_values) == 0: return None if self.df_branch_values is None or len(self.df_branch_values) == 0: return None - times = sorted(self.df_combined_values['time'].unique()) + times = sorted(self.df_combined_values["time"].unique()) cumulative_data = [] for t in times: - n_actual = len(self.df_branch_values[self.df_branch_values['time'] == t]) - cumulative_data.append({ - 'time': t, - 'n_sampled_branches_actual': n_actual, - 'n_sampled_branches_cumulative': n_actual, - 'n_candidate_branches_cumulative': n_actual, - 'n_discarded_branches_cumulative': 0, - }) + n_actual = len(self.df_branch_values[self.df_branch_values["time"] == t]) + cumulative_data.append( + { + "time": t, + "n_sampled_branches_actual": n_actual, + "n_sampled_branches_cumulative": n_actual, + "n_candidate_branches_cumulative": n_actual, + "n_discarded_branches_cumulative": 0, + } + ) return pd.DataFrame(cumulative_data) - + # Get all unique times from measurements if self.df_branch_values is None or len(self.df_branch_values) == 0: return None - times = sorted(self.df_combined_values['time'].unique()) - + times = sorted(self.df_combined_values["time"].unique()) + cumulative_data = [] - + for t in times: # Count actual branches (from measurements) at this time - n_actual = len(self.df_branch_values[self.df_branch_values['time'] == t]) - + n_actual = len(self.df_branch_values[self.df_branch_values["time"] == t]) + # Count total candidate branches up to this time - events_before = self.df_branching_events[self.df_branching_events['time'] <= t] - + events_before = self.df_branching_events[self.df_branching_events["time"] <= t] + if len(events_before) > 0: # Calculate cumulative branches # Start with 1 (initial branch) @@ -301,22 +303,24 @@ def get_cumulative_branch_counts(self): n_candidates_total = 1 # Start with 1 branch n_sampled_total = 1 for _, event in events_before.iterrows(): - n_candidates_total = n_candidates_total - 1 + event['n_candidate_branches'] - n_sampled_total = n_sampled_total - 1 + event['n_sampled_branches'] - n_discarded_total = events_before['n_discarded_branches'].sum() + n_candidates_total = n_candidates_total - 1 + event["n_candidate_branches"] + n_sampled_total = n_sampled_total - 1 + event["n_sampled_branches"] + n_discarded_total = events_before["n_discarded_branches"].sum() else: n_candidates_total = 1 # Start with 1 branch n_sampled_total = 1 n_discarded_total = 0 - - cumulative_data.append({ - 'time': t, - 'n_sampled_branches_actual': n_actual, - 'n_sampled_branches_cumulative': n_sampled_total, - 'n_candidate_branches_cumulative': n_candidates_total, - 'n_discarded_branches_cumulative': n_discarded_total, - }) - + + cumulative_data.append( + { + "time": t, + "n_sampled_branches_actual": n_actual, + "n_sampled_branches_cumulative": n_sampled_total, + "n_candidate_branches_cumulative": n_candidates_total, + "n_discarded_branches_cumulative": n_discarded_total, + } + ) + return pd.DataFrame(cumulative_data) def merge_with_other(self, other): @@ -339,7 +343,7 @@ def _repr_html_(self): def plot_branch_counts_over_time(branch_values, name="", outfolder=None, save=True): """ Plot the number of branches over time, including both sampled and unsampled. - + Parameters: ----------- branch_values : BranchValues @@ -353,132 +357,153 @@ def plot_branch_counts_over_time(branch_values, name="", outfolder=None, save=Tr """ # Get cumulative branch counts df_cumulative = branch_values.get_cumulative_branch_counts() - + if df_cumulative is None or len(df_cumulative) == 0: print("No branching event data available to plot") return - + # Create figure plt.figure(figsize=(12, 6), dpi=150) - + # Plot cumulative candidate branches (total that ever existed) plt.plot( - df_cumulative['time'], - df_cumulative['n_candidate_branches_cumulative'], - label='Total candidate branches (including unsampled)', - color='#d62728', + df_cumulative["time"], + df_cumulative["n_candidate_branches_cumulative"], + label="Total candidate branches (including unsampled)", + color="#d62728", linewidth=2, - linestyle='--', - alpha=0.8 + linestyle="--", + alpha=0.8, ) - + # Plot cumulative sampled branches plt.plot( - df_cumulative['time'], - df_cumulative['n_sampled_branches_cumulative'], - label='Sampled branches (kept)', - color='#2ca02c', - linewidth=2.5 + df_cumulative["time"], + df_cumulative["n_sampled_branches_cumulative"], + label="Sampled branches (kept)", + color="#2ca02c", + linewidth=2.5, ) - + # Plot actual active branches (from measurements) plt.plot( - df_cumulative['time'], - df_cumulative['n_sampled_branches_actual'], - label='Active branches at time t', - color='#1f77b4', + df_cumulative["time"], + df_cumulative["n_sampled_branches_actual"], + label="Active branches at time t", + color="#1f77b4", linewidth=2, - marker='o', + marker="o", markersize=3, - alpha=0.7 + alpha=0.7, ) - + # Plot discarded branches (as shaded region) plt.fill_between( - df_cumulative['time'], - df_cumulative['n_sampled_branches_cumulative'], - df_cumulative['n_candidate_branches_cumulative'], + df_cumulative["time"], + df_cumulative["n_sampled_branches_cumulative"], + df_cumulative["n_candidate_branches_cumulative"], alpha=0.3, - color='#ff7f0e', - label='Discarded branches (unsampled)' + color="#ff7f0e", + label="Discarded branches (unsampled)", ) - - plt.xlabel('Time', fontsize=12) - plt.ylabel('Number of Branches', fontsize=12) - plt.title(f'Branch Counts Over Time: {name}', fontsize=14) - plt.legend(loc='best', fontsize=10) + + plt.xlabel("Time", fontsize=12) + plt.ylabel("Number of Branches", fontsize=12) + plt.title(f"Branch Counts Over Time: {name}", fontsize=14) + plt.legend(loc="best", fontsize=10) plt.grid(True, alpha=0.3) plt.tight_layout() - + if save and outfolder is not None: from pathlib import Path + outfolder = Path(outfolder) outfolder.mkdir(exist_ok=True, parents=True) - + plt.savefig(outfolder / f"{NOW}_{name}_branch_counts.pdf") plt.savefig(outfolder / f"{NOW}_{name}_branch_counts.png") print(f"Saved branch counts plot to {outfolder}") - + plt.show() def plot_branching_events_detail(branch_values, name="", outfolder=None, save=True): """ Plot detailed information about each branching event. - + Shows when branching occurred, how many branches were created vs sampled. """ branch_values.branching_events_to_dataframe() df_events = branch_values.df_branching_events - + if df_events is None or len(df_events) == 0: print("No branching event data available") return - + fig, axes = plt.subplots(2, 1, figsize=(12, 8), dpi=150, sharex=True) - + # Top plot: Number of branches at each event ax1 = axes[0] width = 0.35 x = np.arange(len(df_events)) - - ax1.bar(x - width/2, df_events['n_candidate_branches'], width, - label='Candidate branches', color='#d62728', alpha=0.7) - ax1.bar(x + width/2, df_events['n_sampled_branches'], width, - label='Sampled branches', color='#2ca02c', alpha=0.7) - - ax1.set_ylabel('Number of Branches', fontsize=11) - ax1.set_title(f'Branching Events Detail: {name}', fontsize=13) + + ax1.bar( + x - width / 2, + df_events["n_candidate_branches"], + width, + label="Candidate branches", + color="#d62728", + alpha=0.7, + ) + ax1.bar( + x + width / 2, + df_events["n_sampled_branches"], + width, + label="Sampled branches", + color="#2ca02c", + alpha=0.7, + ) + + ax1.set_ylabel("Number of Branches", fontsize=11) + ax1.set_title(f"Branching Events Detail: {name}", fontsize=13) ax1.legend() - ax1.grid(True, alpha=0.3, axis='y') - + ax1.grid(True, alpha=0.3, axis="y") + # Bottom plot: Probability distribution ax2 = axes[1] - ax2.bar(x, df_events['total_candidate_prob'], width*1.5, - label='Total candidate prob', color='#ff7f0e', alpha=0.5) - ax2.bar(x, df_events['sampled_prob'], width*1.5, - label='Sampled prob', color='#1f77b4', alpha=0.7) - - ax2.set_xlabel('Branching Event', fontsize=11) - ax2.set_ylabel('Probability', fontsize=11) + ax2.bar( + x, + df_events["total_candidate_prob"], + width * 1.5, + label="Total candidate prob", + color="#ff7f0e", + alpha=0.5, + ) + ax2.bar( + x, df_events["sampled_prob"], width * 1.5, label="Sampled prob", color="#1f77b4", alpha=0.7 + ) + + ax2.set_xlabel("Branching Event", fontsize=11) + ax2.set_ylabel("Probability", fontsize=11) ax2.legend() - ax2.grid(True, alpha=0.3, axis='y') - + ax2.grid(True, alpha=0.3, axis="y") + # Set x-axis labels with times ax2.set_xticks(x) - ax2.set_xticklabels([f"{t:.2f}" for t in df_events['time']], rotation=45) - + ax2.set_xticklabels([f"{t:.2f}" for t in df_events["time"]], rotation=45) + plt.tight_layout() - + if save and outfolder is not None: from pathlib import Path + outfolder = Path(outfolder) outfolder.mkdir(exist_ok=True, parents=True) - + plt.savefig(outfolder / f"{NOW}_{name}_branching_events.pdf") plt.savefig(outfolder / f"{NOW}_{name}_branching_events.png") print(f"Saved branching events plot to {outfolder}") - + plt.show() @@ -489,47 +514,58 @@ def plot_sampling_efficiency(branch_values, name="", outfolder=None, save=True): """ branch_values.branching_events_to_dataframe() df_events = branch_values.df_branching_events - + if df_events is None or len(df_events) == 0: print("No branching event data available") return - + # Calculate sampling efficiency - df_events['sampling_efficiency'] = ( - df_events['n_sampled_branches'] / df_events['n_candidate_branches'] - ) - df_events['prob_efficiency'] = ( - df_events['sampled_prob'] / df_events['total_candidate_prob'] + df_events["sampling_efficiency"] = ( + df_events["n_sampled_branches"] / df_events["n_candidate_branches"] ) - + df_events["prob_efficiency"] = df_events["sampled_prob"] / df_events["total_candidate_prob"] + plt.figure(figsize=(12, 6), dpi=150) - - plt.plot(df_events['time'], df_events['sampling_efficiency'], - marker='o', label='Branch count efficiency', - linewidth=2, markersize=8, color='#1f77b4') - plt.plot(df_events['time'], df_events['prob_efficiency'], - marker='s', label='Probability efficiency', - linewidth=2, markersize=8, color='#2ca02c') - - plt.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, label='No sampling (100%)') - - plt.xlabel('Time', fontsize=12) - plt.ylabel('Sampling Efficiency (sampled / candidates)', fontsize=12) - plt.title(f'Sampling Efficiency Over Time: {name}', fontsize=14) - plt.legend(loc='best', fontsize=10) + + plt.plot( + df_events["time"], + df_events["sampling_efficiency"], + marker="o", + label="Branch count efficiency", + linewidth=2, + markersize=8, + color="#1f77b4", + ) + plt.plot( + df_events["time"], + df_events["prob_efficiency"], + marker="s", + label="Probability efficiency", + linewidth=2, + markersize=8, + color="#2ca02c", + ) + + plt.axhline(y=1.0, color="gray", linestyle="--", alpha=0.5, label="No sampling (100%)") + + plt.xlabel("Time", fontsize=12) + plt.ylabel("Sampling Efficiency (sampled / candidates)", fontsize=12) + plt.title(f"Sampling Efficiency Over Time: {name}", fontsize=14) + plt.legend(loc="best", fontsize=10) plt.grid(True, alpha=0.3) plt.ylim([0, 1.1]) plt.tight_layout() - + if save and outfolder is not None: from pathlib import Path + outfolder = Path(outfolder) outfolder.mkdir(exist_ok=True, parents=True) - + plt.savefig(outfolder / f"{NOW}_{name}_sampling_efficiency.pdf") plt.savefig(outfolder / f"{NOW}_{name}_sampling_efficiency.png") print(f"Saved sampling efficiency plot to {outfolder}") - + plt.show() @@ -609,7 +645,8 @@ def __init__( tebd_engine: tenpy.TEBDEngine | ExpMPOEvolution | None = None, # The TEBD engine to use for time evolution (None for unsampled branches) - cfg: BranchingMPSConfig | None = None, # The configuration for splitting the wavefunction into branches + cfg: BranchingMPSConfig + | None = None, # The configuration for splitting the wavefunction into branches branch_values: BranchValues | None = None, # The structure for storing the measurements of all the branches over time branch_function: Callable @@ -627,15 +664,18 @@ def __init__( info={}, # New parameters for unsampled branches sampled: bool = True, # Whether this branch was sampled (False for discarded branches) - created_time: float | None = None, # Time when branch was created (evolved_time at creation) + created_time: float + | None = None, # Time when branch was created (evolved_time at creation) branching_site: int | None = None, # Site where branching occurred (for child nodes) - prob_override: float | None = None, # Probability override (for unsampled branches without tebd_engine) - norm_override: float | None = None, # Norm override (for unsampled branches without tebd_engine) + prob_override: float + | None = None, # Probability override (for unsampled branches without tebd_engine) + norm_override: float + | None = None, # Norm override (for unsampled branches without tebd_engine) ): self.sampled = sampled # Track whether this branch was sampled or discarded self.created_time = created_time # Time when branch was created self.branching_site = branching_site # Site where branching occurred - + # For unsampled branches, tebd_engine may be None if tebd_engine is None: # This is an unsampled branch - use override values @@ -664,12 +704,12 @@ def __init__( self.branch_values = BranchValues() else: self.branch_values = branch_values - + # Track root creation walltime for relative time calculations # This needs to be set before we use it, so handle it after parent is assigned # (We'll set it at the end of __init__) pass - + # Track final walltime (updated when branch finishes or at save time) self.final_walltime: datetime | None = None @@ -792,14 +832,14 @@ def __init__( else: # For unsampled branches, use created_time if available self.evolved_time = self.created_time if self.created_time is not None else 0.0 - + self.t_last_attempted_branching = self.evolved_time self.finished = False - + # Set created_time if not provided (for sampled branches) if self.created_time is None: self.created_time = self.evolved_time - + # Set root_created_walltime (after parent is assigned) if self.parent is None: # This is the root branch - store its creation time as reference @@ -1158,23 +1198,25 @@ def branch_and_sample( self.depth += 1 elif len(branch_indices) > 1: print(f"{self.ID}Creating {num_kept_branches} children nodes.") - + # Record the branching event including unsampled branches - if hasattr(self, 'branch_values') and self.branch_values is not None: + if hasattr(self, "branch_values") and self.branch_values is not None: self.branch_values.add_branching_event( time=self.evolved_time, n_candidate_branches=num_candidates, n_sampled_branches=len(survivor_indices), candidate_probs=branch_probs, sampled_indices=survivor_indices, - parent_prob=abs(self.norm ** 2), + parent_prob=abs(self.norm**2), site=coarsegrain_from, trace_distances={ - 'costFun_LM_MR_trace_distance': costFun_LM_MR_trace_distance, - 'global_reconstruction_error_trace_distance': global_reconstruction_error_trace_distance, - } + "costFun_LM_MR_trace_distance": costFun_LM_MR_trace_distance, + "global_reconstruction_error_trace_distance": global_reconstruction_error_trace_distance, + }, + ) + print( + f"{self.ID}Branching summary: Total candidates={num_candidates}, Sampled={len(survivor_indices)}, Discarded={num_candidates - len(survivor_indices)}" ) - print(f"{self.ID}Branching summary: Total candidates={num_candidates}, Sampled={len(survivor_indices)}, Discarded={num_candidates - len(survivor_indices)}") if np.isclose(total_prob_survived, 0.0): print( @@ -1239,7 +1281,7 @@ def branch_and_sample( discarded_child_prob = discarded_prob / total_prob_survived else: discarded_child_prob = discarded_prob - + unsampled_child = BranchingMPS( tebd_engine=None, # No engine for unsampled branches cfg=self.cfg, @@ -1253,7 +1295,8 @@ def branch_and_sample( sampled=False, # Mark as unsampled created_time=self.evolved_time, branching_site=coarsegrain_from, - prob_override=discarded_child_prob * abs(self.norm ** 2), # Parent prob * child prob + prob_override=discarded_child_prob + * abs(self.norm**2), # Parent prob * child prob ) self.unsampled_children.append(unsampled_child) print( @@ -1389,44 +1432,67 @@ def to_tree_dict(self): """ Convert BranchingMPS tree structure to a dictionary, stripping out tensors and engines. This preserves the tree structure with all metadata for visualization and analysis. - + Returns: -------- dict: Dictionary representation of the tree node """ # Calculate relative times in seconds from root creation - root_created = self.root_created_walltime if hasattr(self, 'root_created_walltime') else self.created_walltime - - created_walltime_rel = (self.created_walltime - root_created).total_seconds() if self.created_walltime else None - + root_created = ( + self.root_created_walltime + if hasattr(self, "root_created_walltime") + else self.created_walltime + ) + + created_walltime_rel = ( + (self.created_walltime - root_created).total_seconds() + if self.created_walltime + else None + ) + if self.final_walltime is not None: final_walltime_rel = (self.final_walltime - root_created).total_seconds() else: final_walltime_rel = None - + tree_dict = { - 'ID': self.ID, - 'sampled': self.sampled, - 'created_time': self.created_time, - 'evolved_time': self.evolved_time, - 'created_walltime': self.created_walltime.isoformat() if self.created_walltime else None, - 'created_walltime_rel_seconds': created_walltime_rel, - 'final_walltime': self.final_walltime.isoformat() if self.final_walltime else None, - 'final_walltime_rel_seconds': final_walltime_rel, - 'prob': float(self.prob), - 'norm': float(self.norm), - 'depth': self.depth, - 'max_children': self.max_children, - 'finished': self.finished, - 'synchronized': self.synchronized if hasattr(self, 'synchronized') else False, - 'branching_site': self.branching_site, - 'costFun_LM_MR_trace_distance': float(self.costFun_LM_MR_trace_distance) if hasattr(self, 'costFun_LM_MR_trace_distance') else 0.0, - 'global_reconstruction_error_trace_distance': float(self.global_reconstruction_error_trace_distance) if hasattr(self, 'global_reconstruction_error_trace_distance') else 0.0, - 'has_tebd_engine': self.tebd_engine is not None, - 'n_children': len(self.children), - 'n_unsampled_children': len(self.unsampled_children) if hasattr(self, 'unsampled_children') else 0, - 'children': [child.to_tree_dict() for child in self.children], - 'unsampled_children': [child.to_tree_dict() for child in (self.unsampled_children if hasattr(self, 'unsampled_children') else [])], + "ID": self.ID, + "sampled": self.sampled, + "created_time": self.created_time, + "evolved_time": self.evolved_time, + "created_walltime": self.created_walltime.isoformat() + if self.created_walltime + else None, + "created_walltime_rel_seconds": created_walltime_rel, + "final_walltime": self.final_walltime.isoformat() if self.final_walltime else None, + "final_walltime_rel_seconds": final_walltime_rel, + "prob": float(self.prob), + "norm": float(self.norm), + "depth": self.depth, + "max_children": self.max_children, + "finished": self.finished, + "synchronized": self.synchronized if hasattr(self, "synchronized") else False, + "branching_site": self.branching_site, + "costFun_LM_MR_trace_distance": float(self.costFun_LM_MR_trace_distance) + if hasattr(self, "costFun_LM_MR_trace_distance") + else 0.0, + "global_reconstruction_error_trace_distance": float( + self.global_reconstruction_error_trace_distance + ) + if hasattr(self, "global_reconstruction_error_trace_distance") + else 0.0, + "has_tebd_engine": self.tebd_engine is not None, + "n_children": len(self.children), + "n_unsampled_children": len(self.unsampled_children) + if hasattr(self, "unsampled_children") + else 0, + "children": [child.to_tree_dict() for child in self.children], + "unsampled_children": [ + child.to_tree_dict() + for child in ( + self.unsampled_children if hasattr(self, "unsampled_children") else [] + ) + ], } return tree_dict @@ -1441,11 +1507,11 @@ def _update_final_walltime_recursive(self, current_time: datetime): else: # Update to current time (at save time) self.final_walltime = current_time - + # Update children and unsampled children for child in self.children: child._update_final_walltime_recursive(current_time) - if hasattr(self, 'unsampled_children'): + if hasattr(self, "unsampled_children"): for child in self.unsampled_children: child._update_final_walltime_recursive(current_time) @@ -1457,7 +1523,7 @@ def save(self, final=False): # Update final_walltime for all branches at save time current_walltime = datetime.now() self._update_final_walltime_recursive(current_walltime) - + if self.pickle_file is not None: branchvals_file = str(self.pickle_file).split(".pkl")[0] + "_branch_values.pkl" branchvals_file = ( @@ -1468,13 +1534,11 @@ def save(self, final=False): with open(branchvals_file, "wb") as f: pickle.dump(self.branch_values, f) f.close() - + # Save tree structure (without tensors) to JSON tree_file = str(self.pickle_file).split(".pkl")[0] + "_tree.json" tree_file = ( - tree_file - if (self.n_times_saved % 2 == 0 or final) - else tree_file + "tmp" + tree_file if (self.n_times_saved % 2 == 0 or final) else tree_file + "tmp" ) try: tree_dict = self.to_tree_dict() @@ -1484,8 +1548,9 @@ def save(self, final=False): except Exception as e: print(f"{self.ID}Warning: Failed to save tree structure: {e}") import traceback + traceback.print_exc() - + if self.cfg.save_full_state: pickle_file = ( self.pickle_file @@ -1499,7 +1564,11 @@ def save(self, final=False): if final: print(f"{self.ID}Removing temp files as this is the final save.") - for tmp_file in [branchvals_file + "tmp", str(self.pickle_file) + "tmp", tree_file + "tmp"]: + for tmp_file in [ + branchvals_file + "tmp", + str(self.pickle_file) + "tmp", + tree_file + "tmp", + ]: Path(tmp_file).unlink(missing_ok=True) t1 = time.time() print(f"{self.ID}Saved in {t1 - t0} seconds to {self.pickle_file}") @@ -1729,28 +1798,20 @@ def trace_distance_colors(key): # NEW PLOTS: Branch counts including unsampled branches try: plot_branch_counts_over_time( - self.branch_values, - name=self.name, - outfolder=plots_dir, - save=True + self.branch_values, name=self.name, outfolder=plots_dir, save=True ) - + plot_branching_events_detail( - self.branch_values, - name=self.name, - outfolder=plots_dir, - save=True + self.branch_values, name=self.name, outfolder=plots_dir, save=True ) - + plot_sampling_efficiency( - self.branch_values, - name=self.name, - outfolder=plots_dir, - save=True + self.branch_values, name=self.name, outfolder=plots_dir, save=True ) except Exception as e: print(f"Error plotting branch counts: {e}") import traceback + traceback.print_exc() # Log to wandb