diff --git a/other/experiments/test_slakonet.ipynb b/other/experiments/test_slakonet.ipynb new file mode 100644 index 00000000..c83f5f8f --- /dev/null +++ b/other/experiments/test_slakonet.ipynb @@ -0,0 +1,176 @@ +{ + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "python", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8" + }, + "kernelspec": { + "name": "python", + "display_name": "Python (Pyodide)", + "language": "python" + } + }, + "nbformat_minor": 4, + "nbformat": 4, + "cells": [ + { + "cell_type": "code", + "source": "import micropip", + "metadata": { + "trusted": true + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "code", + "source": "for package in [\n \"emfs:/drive/packages/pymatgen-2024.4.13-py3-none-any.whl\",\n \"emfs:/drive/packages/spglib-2.0.2-py3-none-any.whl\",\n \"emfs:/drive/packages/ruamel.yaml-0.17.32-py3-none-any.whl\",\n \"emfs:/drive/packages/pydantic_core-2.18.2-py3-none-any.whl\",\n \"emfs:/drive/packages/pydantic-2.7.1-py3-none-any.whl\",\n \"emfs:/drive/packages/torch-2.1.0a0-cp311-cp311-emscripten_3_1_45_wasm32.whl\",\n \"numpy\",\n \"scipy\",\n \"matplotlib\",\n \"sympy\",\n \"pydantic_settings\",\n \"xmltodict\",\n \"requests\",\n \"urllib3\",\n \"idna\",\n \"certifi\",\n \"scikit-learn\",\n \"tqdm\",\n \"jarvis-tools\",\n \"ase\",\n \"slakonet\"\n]:\n print(package)\n await micropip.install(package, deps=False)", + "metadata": { + "trusted": true + }, + "execution_count": 2, + "outputs": [ + { + "name": "stdout", + "text": "emfs:/drive/packages/pymatgen-2024.4.13-py3-none-any.whl\nemfs:/drive/packages/spglib-2.0.2-py3-none-any.whl\nemfs:/drive/packages/ruamel.yaml-0.17.32-py3-none-any.whl\nemfs:/drive/packages/pydantic_core-2.18.2-py3-none-any.whl\nemfs:/drive/packages/pydantic-2.7.1-py3-none-any.whl\nemfs:/drive/packages/torch-2.1.0a0-cp311-cp311-emscripten_3_1_45_wasm32.whl\nnumpy\nscipy\nmatplotlib\nsympy\npydantic_settings\nxmltodict\nrequests\nurllib3\nidna\ncertifi\nscikit-learn\ntqdm\njarvis-tools\nase\nslakonet\n", + "output_type": "stream" + } + ] + }, + { + "cell_type": "code", + "source": "import torch\nimport numpy as np\nfrom collections import namedtuple\n\n# Define return types to mimic PyTorch's named tuples\nEigRet = namedtuple('linalg_eig', ['eigenvalues', 'eigenvectors'])\nEighRet = namedtuple('linalg_eigh', ['eigenvalues', 'eigenvectors'])\n\ndef _to_np(tensor):\n return tensor.detach().cpu().numpy()\n\ndef _to_torch(array, device, dtype=None):\n if dtype:\n return torch.tensor(array, device=device, dtype=dtype)\n return torch.tensor(array, device=device)\n\n# --- Define the Patches ---\n\ndef patch_solve(A, B, *args, **kwargs):\n return _to_torch(np.linalg.solve(_to_np(A), _to_np(B)), A.device)\n\ndef patch_inv(A, *args, **kwargs):\n return _to_torch(np.linalg.inv(_to_np(A)), A.device)\n\ndef patch_det(A, *args, **kwargs):\n return _to_torch(np.linalg.det(_to_np(A)), A.device)\n\ndef patch_cholesky(A, *args, **kwargs):\n # NumPy's cholesky operates on the lower triangle by default, same as PyTorch\n return _to_torch(np.linalg.cholesky(_to_np(A)), A.device)\n\ndef patch_eig(A, *args, **kwargs):\n vals, vecs = np.linalg.eig(_to_np(A))\n return EigRet(_to_torch(vals, A.device), _to_torch(vecs, A.device))\n\ndef patch_eigh(A, UPLO='L', *args, **kwargs):\n vals, vecs = np.linalg.eigh(_to_np(A), UPLO=UPLO)\n return EighRet(_to_torch(vals, A.device), _to_torch(vecs, A.device))\n\n# --- Apply the Patches to PyTorch ---\n\ntorch.linalg.solve = patch_solve\ntorch.linalg.inv = patch_inv\ntorch.inverse = patch_inv # Alias\ntorch.linalg.det = patch_det\ntorch.det = patch_det # Alias\ntorch.linalg.cholesky = patch_cholesky\ntorch.linalg.eig = patch_eig\ntorch.linalg.eigh = patch_eigh\n\n# Fix numpy\ndef _tensor_array_compat(self, dtype=None):\n \"\"\"Replacement for Tensor.__array__ in Pyodide where tensor.numpy() is unavailable.\"\"\"\n arr = np.array(self.tolist())\n return arr.astype(dtype) if dtype is not None else arr\ntorch.Tensor.__array__ = _tensor_array_compat\ntorch.Tensor.numpy = lambda self: np.array(self.detach().tolist())\n\n\n# Keep the SciPy LU patches we made earlier just in case\nimport scipy.linalg\nLUFactorReturn = namedtuple('LUFactorReturn', ['LU', 'pivots'])\n\ndef patch_lu_factor(A, *args, **kwargs):\n A_np = _to_np(A)\n if A_np.ndim > 2:\n orig_shape = A_np.shape\n A_reshaped = A_np.reshape(-1, orig_shape[-2], orig_shape[-1])\n lu_list, piv_list = [], []\n for mat in A_reshaped:\n lu, piv = scipy.linalg.lu_factor(mat)\n lu_list.append(lu)\n piv_list.append(piv)\n LU_np = np.stack(lu_list).reshape(orig_shape)\n piv_np = np.stack(piv_list).reshape(orig_shape[:-2] + (-1,))\n else:\n LU_np, piv_np = scipy.linalg.lu_factor(A_np)\n return LUFactorReturn(_to_torch(LU_np, A.device, A.dtype), _to_torch(piv_np, A.device, torch.int32))\n\ndef patch_lu_solve(LU, pivots, B, *args, **kwargs):\n LU_np, piv_np, B_np = _to_np(LU), _to_np(pivots), _to_np(B)\n if LU_np.ndim > 2:\n orig_shape_B = B_np.shape\n LU_reshaped = LU_np.reshape(-1, LU_np.shape[-2], LU_np.shape[-1])\n piv_reshaped = piv_np.reshape(-1, piv_np.shape[-1])\n is_vector = (B_np.ndim == LU_np.ndim - 1)\n B_reshaped = B_np.reshape(-1, B_np.shape[-1], 1) if is_vector else B_np.reshape(-1, B_np.shape[-2], B_np.shape[-1])\n X_list = [scipy.linalg.lu_solve((lu, piv), b) for lu, piv, b in zip(LU_reshaped, piv_reshaped, B_reshaped)]\n X_np = np.stack(X_list).reshape(orig_shape_B)\n else:\n X_np = scipy.linalg.lu_solve((LU_np, piv_np), B_np)\n return _to_torch(X_np, B.device, B.dtype)\n\ntorch.linalg.lu_factor = patch_lu_factor\ntorch.linalg.lu_solve = patch_lu_solve\n\nprint(\"All major torch.linalg functions successfully patched to use NumPy/SciPy!\")", + "metadata": { + "trusted": true + }, + "execution_count": 3, + "outputs": [ + { + "name": "stdout", + "text": "All major torch.linalg functions successfully patched to use NumPy/SciPy!\n", + "output_type": "stream" + } + ] + }, + { + "cell_type": "code", + "source": "from slakonet.optim import (\n MultiElementSkfParameterOptimizer,\n get_atoms,\n kpts_to_klines,\n default_model,\n)\nimport torch\nfrom slakonet.atoms import Geometry\nfrom slakonet.main import generate_shell_dict_upto_Z65\n\nmodel = default_model()\n\n# Get structure (example with JARVIS ID)\natoms, opt_gap, mbj_gap = get_atoms(\"JVASP-1002\") \ngeometry = Geometry.from_ase_atoms([atoms.ase_converter()])\nshell_dict = generate_shell_dict_upto_Z65()", + "metadata": { + "trusted": true + }, + "execution_count": 4, + "outputs": [ + { + "name": "stderr", + "text": "/lib/python3.11/site-packages/requests/__init__.py:86: RequestsDependencyWarning: Unable to find acceptable character detection dependency (chardet or charset_normalizer).\n warnings.warn(\n", + "output_type": "stream" + }, + { + "name": "stdout", + "text": "Downloading slakonet_v0 model from Figshare...\n", + "output_type": "stream" + }, + { + "name": "stderr", + "text": "/lib/python3.11/site-packages/slakonet/optim.py:3081: TqdmMonitorWarning: tqdm:disabling monitor support (monitor_interval = 0) due to:\ncan't start new thread\n progress_bar = tqdm(total=total_size_in_bytes, unit=\"iB\", unit_scale=True)\n100%|██████████| 183M/183M [00:01<00:00, 132MiB/s] \n", + "output_type": "stream" + }, + { + "name": "stdout", + "text": "Saved zip file: /home/pyodide/.cache/atomgptlab/slakonet/slakonet_v0/slakonet_v0.zip\nExtracting model file...\nExtracted model to: /home/pyodide/.cache/atomgptlab/slakonet/slakonet_v0/slakonet_v0.pt\n✅ Compact model loaded from: /home/pyodide/.cache/atomgptlab/slakonet/slakonet_v0/slakonet_v0.pt\nTotal time: 32.26s\nObtaining 3D dataset 76k ...\nReference:https://doi.org/10.1016/j.commatsci.2025.114063\nOther versions:https://doi.org/10.6084/m9.figshare.6815699\n", + "output_type": "stream" + }, + { + "name": "stderr", + "text": "100%|██████████| 40.8M/40.8M [00:00<00:00, 128MiB/s]\n", + "output_type": "stream" + }, + { + "name": "stdout", + "text": "Loading the zipfile...\nLoading completed.\n", + "output_type": "stream" + } + ] + }, + { + "cell_type": "code", + "source": "import torch\n\n# Your original waypoints (k_x, k_y, k_z, density)\nwaypoints = [\n [0.000, 0.000, 0.000, 10], # Gamma\n [0.500, 0.000, 0.500, 10], # X\n [0.500, 0.250, 0.750, 10], # W\n [0.375, 0.375, 0.750, 10], # K\n [0.000, 0.000, 0.000, 10], # Gamma\n [0.500, 0.500, 0.500, 10], # L\n [0.625, 0.250, 0.625, 10], # U\n [0.500, 0.250, 0.750, 10], # W\n [0.500, 0.500, 0.500, 10], # L\n [0.625, 0.250, 0.625, 10], # U\n [0.500, 0.000, 0.500, 10], # X\n]\n\n# Convert waypoints into 7-element segments: [start_k, end_k, N]\nformatted_segments = []\nfor i in range(len(waypoints) - 1):\n start = waypoints[i][:3] # Get x, y, z of current point\n end = waypoints[i+1][:3] # Get x, y, z of next point\n num_points = waypoints[i+1][3] # Get the N value from the destination point\n \n formatted_segments.append(start + end + [num_points])\n\n# slakonet expects (batch, num_segments, 7)\nklines_tensor = torch.tensor([formatted_segments], dtype=torch.float32)\n\n# Run the calculation\nwith torch.no_grad():\n properties, success = model.compute_multi_element_properties(\n geometry=geometry,\n shell_dict=shell_dict,\n klines=klines_tensor, \n get_fermi=True,\n )\n\nprint(f\"Calculation successful: {success}\")", + "metadata": { + "trusted": true + }, + "execution_count": 5, + "outputs": [ + { + "name": "stdout", + "text": "Cholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\nCholesky failed: expected m1 and m2 to have the same dtype, but got: c10::complex != c10::complex, falling back to eig\npotential_energy 0\nelectronic_energy tensor(-45.0753)\nCalculation successful: True\n", + "output_type": "stream" + } + ] + }, + { + "cell_type": "code", + "source": "# Access results using .item() to extract the float from the Tensor\nprint(f\"Band gap: {properties['bandgap'].item():.3f} eV\")", + "metadata": { + "trusted": true + }, + "execution_count": 6, + "outputs": [ + { + "name": "stdout", + "text": "Band gap: 1.084 eV\n", + "output_type": "stream" + } + ] + }, + { + "cell_type": "code", + "source": "print(f\"Fermi energy: {properties['fermi_energy'].item():.3f} eV\")", + "metadata": { + "trusted": true + }, + "execution_count": 7, + "outputs": [ + { + "name": "stdout", + "text": "Fermi energy: -2.494 eV\n", + "output_type": "stream" + } + ] + }, + { + "cell_type": "code", + "source": "# Plot band structure and DOS\neigenvalues = properties[\"eigenvalues\"]\ndos_values = properties['dos_values_tensor']\ndos_energies = properties['dos_energy_grid_tensor']", + "metadata": { + "trusted": true + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": "import matplotlib\nmatplotlib.use(\"Agg\") \nimport matplotlib.pyplot as plt\nimport numpy as np\nimport io\nimport base64\n\n# 1. Extract and Clean Data\n# Use .squeeze() but be careful with batch dimensions\nraw_eigenvals = properties[\"eigenvalues\"].detach().cpu().numpy()\nif raw_eigenvals.ndim == 3:\n raw_eigenvals = raw_eigenvals[0] # Take first batch\n\n# CRITICAL: Sort eigenvalues at each k-point to prevent \"zig-zags\" \n# between energy levels that swap indices\neigenvals = np.sort(raw_eigenvals, axis=1)\n\ndos_vals = properties['dos_values_tensor'].detach().cpu().numpy()\ndos_energies = properties['dos_energy_grid_tensor'].detach().cpu().numpy()\ne_fermi = properties['fermi_energy'].item()\n\n# Shift everything relative to Fermi Level (Standard convention)\neigenvals -= e_fermi\ndos_energies -= e_fermi\n\n# 2. Set up the figure\nfig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6), \n gridspec_kw={'width_ratios': [2.5, 1]}, sharey=True)\n\n# ==========================================\n# Plot 1: Band Structure\n# ==========================================\nnum_kpoints = eigenvals.shape[0]\nk_axis = np.arange(num_kpoints)\n\n# Plot each band individually\nfor i in range(eigenvals.shape[1]):\n ax1.plot(k_axis, eigenvals[:, i], color='#1f77b4', linewidth=1.2, alpha=0.8)\n\n# Add a horizontal line at the NEW Fermi Level (0 eV)\nax1.axhline(0, color='r', linestyle='--', label='Fermi Level', linewidth=1.5)\n\nax1.set_ylabel('Energy - $E_f$ (eV)')\nax1.set_xlabel('k-point index')\nax1.set_title('Band Structure (Aligned)')\nax1.set_xlim(0, num_kpoints - 1)\nax1.grid(True, linestyle=':', alpha=0.5)\n\n# ==========================================\n# Plot 2: Density of States (DOS)\n# ==========================================\nax2.plot(dos_vals, dos_energies, color='k', linewidth=1.2)\nax2.axhline(0, color='r', linestyle='--')\n\n# Fill valence states (below Fermi Level)\nax2.fill_betweenx(dos_energies, 0, dos_vals, where=(dos_energies <= 0), color='skyblue', alpha=0.5)\n\nax2.set_xlabel('DOS (states/eV)')\nax2.set_title('DOS')\nax2.grid(True, linestyle=':', alpha=0.5)\n\n# Adjust y-limit to focus on the interesting region around the gap\nax1.set_ylim(-15, 15) \nax2.set_ylim(-15, 15) \n\nplt.tight_layout()\n\n# ==========================================\n# Pyodide Export\n# ==========================================\nbuf = io.BytesIO()\nfig.savefig(buf, format='png', dpi=150)\nbuf.seek(0)\nplt.close(fig) \n\nimg_base64 = base64.b64encode(buf.read()).decode('utf-8')\nimg_html = f''\n\nfrom IPython.display import display, HTML\ndisplay(HTML(img_html))", + "metadata": { + "trusted": true + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": "", + "text/html": "" + }, + "metadata": {} + } + ] + } + ] +} \ No newline at end of file