diff --git a/src/pieces/_vendor/pieces_os_client/wrapper/basic_identifier/user.py b/src/pieces/_vendor/pieces_os_client/wrapper/basic_identifier/user.py index 310af380..2f8cffef 100644 --- a/src/pieces/_vendor/pieces_os_client/wrapper/basic_identifier/user.py +++ b/src/pieces/_vendor/pieces_os_client/wrapper/basic_identifier/user.py @@ -1,3 +1,4 @@ +import logging from typing import TYPE_CHECKING, Optional from threading import Thread from .basic import Basic @@ -8,6 +9,9 @@ from pieces._vendor.pieces_os_client.wrapper.client import PiecesClient from pieces._vendor.pieces_os_client.models.user_profile import UserProfile + +logger = logging.getLogger(__name__) + class BasicUser(Basic): """ A class to represent a basic user and manage their connection to the cloud. @@ -55,26 +59,54 @@ def _on_login_connect(self): """ self.connect() - def login(self, connect_after_login=True, timeout=120): + def _finalize_login(self, user: Optional["UserProfile"], connect_after_login=True): + self.user_profile = user + if connect_after_login and user: + self._on_login_connect() + return user + + def _complete_login(self, connect_after_login=True): + user = self.pieces_client.os_api.sign_into_os() + return self._finalize_login(user, connect_after_login) + + def login(self, connect_after_login=True, timeout=120, async_req=False) -> Optional["UserProfile"] | Thread: """ Logs the user into the OS and optionally connects to the cloud. Args: connect_after_login: A flag indicating if the user should connect to the cloud after login (default is True). - timeout: The maximum time to wait for the login process (default is 120 seconds). + timeout: The maximum time to wait for the sign-in process (default is 120 seconds). + async_req: Start the login flow in the background without waiting for it to finish. + + Returns: + The logged-in user profile for synchronous calls, or the background thread when async_req is True. """ result = {} + error = {} def target(): - result['user'] = self.pieces_client.os_api.sign_into_os() - - thread = Thread(target=target) + try: + if async_req: + result['user'] = self._complete_login(connect_after_login) + else: + result['user'] = self.pieces_client.os_api.sign_into_os() + except Exception as exc: + error['exception'] = exc + if async_req: + logger.exception("PiecesOS login failed in background") + + thread = Thread(target=target, daemon=True) thread.start() + if async_req: + return thread + thread.join(timeout) + if thread.is_alive(): + raise TimeoutError(f"Login did not complete within {timeout} seconds") + if 'exception' in error: + raise error['exception'] - if connect_after_login: - self.user_profile = result.get('user') - self._on_login_connect() + return self._finalize_login(result.get('user'), connect_after_login) def logout(self): """ @@ -95,11 +127,17 @@ def connect(self, async_req = False): self.user_profile, True ) # Set the connecting to cloud bool to true if async_req: - thread = Thread( - target=self.pieces_client.allocations_api.allocations_connect_new_cloud, - args=(self.user_profile,), - ) + def target(): + try: + self.pieces_client.allocations_api.allocations_connect_new_cloud( + self.user_profile + ) + except Exception: + logger.exception("Pieces Cloud connection failed in background") + + thread = Thread(target=target, daemon=True) thread.start() + return thread else: self.pieces_client.allocations_api.allocations_connect_new_cloud( self.user_profile diff --git a/src/pieces/command_interface/auth_commands.py b/src/pieces/command_interface/auth_commands.py index d8d3ae1e..f485fd4b 100644 --- a/src/pieces/command_interface/auth_commands.py +++ b/src/pieces/command_interface/auth_commands.py @@ -49,10 +49,19 @@ def execute(self, **kwargs) -> int: ) if status == AllocationStatusEnum.DISCONNECTED: Settings.logger.print("Connecting to the Pieces Cloud...") - Settings.pieces_client.user.connect() + if Settings.run_in_loop: + Settings.pieces_client.user.connect(async_req=True) + else: + Settings.pieces_client.user.connect() return 0 try: - Settings.pieces_client.user.login(True) + if Settings.run_in_loop: + Settings.pieces_client.user.login(True, async_req=True) + Settings.logger.print( + "Sign-in opened in your browser. You can keep using `pieces run` while it completes." + ) + else: + Settings.pieces_client.user.login(True) except Exception as e: Settings.logger.error(f"Sign in failed: {e}") return 0 diff --git a/tests/auth_commands_test.py b/tests/auth_commands_test.py index 2d150806..1482dc9b 100644 --- a/tests/auth_commands_test.py +++ b/tests/auth_commands_test.py @@ -1,4 +1,5 @@ import pytest +import threading from unittest.mock import Mock, patch from pieces.command_interface.auth_commands import LoginCommand, LogoutCommand from pieces.settings import Settings @@ -6,6 +7,7 @@ AllocationStatusEnum, ) from pieces._vendor.pieces_os_client.models.user_profile import UserProfile +from pieces._vendor.pieces_os_client.wrapper.basic_identifier.user import BasicUser class TestLoginCommand: @@ -107,6 +109,48 @@ def test_execute_not_logged_in( assert result == 0 mock_user.login.assert_called_once_with(True) + @patch.object(Settings, "run_in_loop", True) + @patch.object(Settings, "logger") + @patch.object(Settings, "pieces_client") + def test_execute_not_logged_in_in_run_loop_uses_async_login( + self, mock_pieces_client, mock_logger, login_command + ): + """Test login in run mode uses the non-blocking login path.""" + mock_user = Mock() + mock_user.user_profile = None + mock_user.login = Mock() + + mock_pieces_client.user = mock_user + mock_pieces_client.user_api.user_snapshot.return_value.user = None + + result = login_command.execute() + + assert result == 0 + mock_user.login.assert_called_once_with(True, async_req=True) + mock_logger.print.assert_called_once() + assert "browser" in mock_logger.print.call_args[0][0].lower() + + @patch.object(Settings, "run_in_loop", True) + @patch.object(Settings, "logger") + @patch.object(Settings, "pieces_client") + def test_execute_logged_in_but_disconnected_in_run_loop_uses_async_connect( + self, mock_pieces_client, mock_logger, login_command, mock_user_profile + ): + mock_user = Mock() + mock_user.user_profile = mock_user_profile + mock_user.name = "Test User" + mock_user.email = "test@example.com" + mock_user.cloud_status = AllocationStatusEnum.DISCONNECTED + + mock_pieces_client.user = mock_user + mock_pieces_client.user_api.user_snapshot.return_value.user = mock_user_profile + + result = login_command.execute() + + assert result == 0 + mock_user.connect.assert_called_once_with(async_req=True) + assert mock_logger.print.call_count == 2 + @patch.object(Settings, "logger") @patch.object(Settings, "pieces_client") def test_execute_login_exception( @@ -317,6 +361,61 @@ def test_execute_multiple_logout_calls( mock_logger.error.assert_not_called() +class TestBasicUserLogin: + def test_login_async_returns_background_thread_without_waiting(self): + pieces_client = Mock() + started = threading.Event() + release = threading.Event() + + def delayed_login(): + started.set() + release.wait(timeout=5) + return "user-profile" + + pieces_client.os_api.sign_into_os.side_effect = delayed_login + + user = BasicUser(pieces_client) + + thread = user.login(connect_after_login=False, async_req=True) + + assert started.wait(timeout=1) + assert thread.is_alive() + + release.set() + thread.join(timeout=1) + assert not thread.is_alive() + + def test_login_sync_propagates_connect_failures(self): + pieces_client = Mock() + user_profile = Mock(spec=UserProfile) + pieces_client.os_api.sign_into_os.return_value = user_profile + pieces_client.allocations_api.allocations_connect_new_cloud.side_effect = RuntimeError( + "Cloud connection failed" + ) + + user = BasicUser(pieces_client) + + with pytest.raises(RuntimeError, match="Cloud connection failed"): + user.login() + + def test_login_sync_raises_timeout_when_sign_in_does_not_finish(self): + pieces_client = Mock() + release = threading.Event() + + def delayed_login(): + release.wait(timeout=5) + return "user-profile" + + pieces_client.os_api.sign_into_os.side_effect = delayed_login + + user = BasicUser(pieces_client) + + with pytest.raises(TimeoutError, match="Login did not complete within 0.01 seconds"): + user.login(connect_after_login=False, timeout=0.01) + + release.set() + + class TestLoginLogoutIntegration: """Integration tests for login and logout commands working together."""