Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions libs/solvers/lib/operators/molecule/drivers/pyscf_driver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@

using namespace cudaqx;

namespace {
bool hasAvailableStatus(const nlohmann::json &metadata) {
return metadata.contains("status") && metadata["status"].is_string() &&
metadata["status"].get<std::string>() == "available";
}
} // namespace

namespace cudaq::solvers {

// Create a tear down service
Expand Down Expand Up @@ -50,13 +57,10 @@ class RESTPySCFDriver : public MoleculePackageDriver {
std::map<std::string, std::string> headers;
try {
auto res = client.get("localhost:8000/", "status", headers);
if (res.contains("status") &&
res["status"].get<std::string>() == "available")
return true;
} catch (std::exception &e) {
return hasAvailableStatus(res);
} catch (const std::exception &) {
return false;
}
return true;
}

std::unique_ptr<tear_down>
Expand All @@ -70,7 +74,6 @@ class RESTPySCFDriver : public MoleculePackageDriver {
auto argString = cudaqPySCFTool.string() + " --server-mode";
if (!python_path.empty())
argString = python_path + " " + argString;
int a0, a1;
auto [ret, msg] = cudaqx::launchProcess(argString.c_str());
if (ret == -1)
return nullptr;
Expand All @@ -86,28 +89,26 @@ class RESTPySCFDriver : public MoleculePackageDriver {

cudaq::RestClient client;
using namespace std::chrono_literals;
static constexpr std::size_t timeoutMs = 5000;
static constexpr std::size_t pollIntervalMs = 100;
std::size_t ticker = 0;
std::map<std::string, std::string> headers{
{"Content-Type", "application/json"}};
while (true) {
while (ticker < timeoutMs) {
std::this_thread::sleep_for(100ms);
ticker += pollIntervalMs;

nlohmann::json metadata;
try {
metadata = client.get("localhost:8000/", "status", headers);
if (metadata.count("status"))
break;
if (hasAvailableStatus(metadata))
return std::make_unique<PySCFTearDown>(ret);
} catch (...) {
continue;
// Keep polling until the startup timeout expires.
}

if (ticker > 5000)
return nullptr;

ticker += 100;
}

return std::make_unique<PySCFTearDown>(ret);
return nullptr;
}

/// @brief Create the molecular hamiltonian
Expand Down