Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions notebooks/automl_pipeline_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@
" inner_max_epochs=10, # increase at least to 128\n",
" inner_max_time=\"00:00:01:00\", # increase at least to \"00:00:10:00\"\n",
" automl_overwrite_fit=True,\n",
" accelerator=\"cpu\",\n",
" **pipeline_args\n",
")"
],
Expand Down
1 change: 1 addition & 0 deletions notebooks/ligthning_pipeline_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@
" max_epochs=10,\n",
" max_time=\"00:00:01:00\", # DD:HH:MM:SS\n",
" overwrite_fit=True,\n",
" accelerator=\"cpu\",\n",
" verbose=True,\n",
" **model_args,\n",
")"
Expand Down
30 changes: 16 additions & 14 deletions notebooks/nif_deep_learning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"import torch\n",
"from ax import RangeParameterConfig\n",
"from matchcake import NonInteractingFermionicDevice\n",
"from matchcake.operations import SptmAngleEmbedding, SptmfRxRx, SptmFHH\n",
"from matchcake.operations import SptmAngleEmbedding, CompRxRx, CompHH\n",
"\n",
"from matchcake_opt.datamodules.datamodule import DataModule\n",
"from matchcake_opt.modules.classification_model import ClassificationModel\n",
Expand Down Expand Up @@ -114,8 +114,8 @@
" SptmAngleEmbedding(inputs, wires=range(self.n_qubits))\n",
" for i in range(self.n_layers):\n",
" for j in range(self.n_qubits - 1):\n",
" SptmfRxRx(weights[i, j*2 : j*2+2], wires=[j, j+1])\n",
" SptmFHH(wires=[j, j+1])\n",
" CompRxRx(weights[i, j*2 : j*2+2], wires=[j, j+1])\n",
" CompHH(wires=[j, j+1])\n",
" return [qml.expval(qml.PauliZ(wires=i)) for i in range(self.n_qubits)]\n",
"\n",
" def forward(self, x) -> Any:\n",
Expand Down Expand Up @@ -198,8 +198,6 @@
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"model_args = dict(\n",
" n_qubits=16,\n",
Expand All @@ -214,16 +212,17 @@
" max_time=\"00:00:01:00\", # DD:HH:MM:SS\n",
" overwrite_fit=True,\n",
" verbose=True,\n",
" accelerator=\"cpu\",\n",
" **model_args,\n",
")"
],
"id": "8832b1d06113ac48"
"id": "8832b1d06113ac48",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"start_time = time.perf_counter()\n",
"metrics = lightning_pipeline.run()\n",
Expand All @@ -234,7 +233,9 @@
"test_metrics = lightning_pipeline.run_test()\n",
"print(\"⚡\" * 20, \"\\nTest Metrics:\\n\", test_metrics, \"\\n\", \"⚡\" * 20)"
],
"id": "7d8a499f05426a0d"
"id": "7d8a499f05426a0d",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
Expand All @@ -245,21 +246,22 @@
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"automl_pipeline = AutoMLPipeline(\n",
" model_cls=model_cls,\n",
" datamodule=datamodule,\n",
" checkpoint_folder=checkpoint_folder,\n",
" automl_iterations=5, # increase at least to 32\n",
" automl_iterations=2, # increase at least to 32\n",
" inner_max_epochs=10, # increase at least to 128\n",
" inner_max_time=\"00:00:01:00\", # increase at least to \"00:00:10:00\"\n",
" inner_max_time=\"00:00:00:10\", # increase at least to \"00:00:10:00\"\n",
" automl_overwrite_fit=True,\n",
" accelerator=\"cpu\",\n",
" **pipeline_args\n",
")"
],
"id": "16001dd8071bc8b0"
"id": "16001dd8071bc8b0",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"tensorboard (>=2.19.0,<3.0.0)",
"ax-platform[mysql] (>=1.0.0,<2.0.0)",
"torcheval (>=0.0.7,<0.0.8)",
"matchcake (>=0.0.4)",
"matchcake>=0.1.2",
"autoray (<=0.7.2)",
"medmnist (>=3.0.2,<4.0.0)",
"torch-geometric>=2.7.0",
Expand All @@ -52,6 +52,8 @@ dev = [
"isort>=6.0.1,<7",
"types-networkx>=3.5.0.20251001",
"pip>=25.3",
"jupyter>=1.1.1",
"notebook>=7.5.0",
]
docs = [
"sphinx>=6.2.1,<6.3.0",
Expand Down
Loading