Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import TYPE_CHECKING, Optional
from threading import Thread
from .basic import Basic
Expand All @@ -8,6 +9,9 @@
from pieces._vendor.pieces_os_client.wrapper.client import PiecesClient
from pieces._vendor.pieces_os_client.models.user_profile import UserProfile


logger = logging.getLogger(__name__)

class BasicUser(Basic):
"""
A class to represent a basic user and manage their connection to the cloud.
Expand Down Expand Up @@ -55,26 +59,54 @@ def _on_login_connect(self):
"""
self.connect()

def login(self, connect_after_login=True, timeout=120):
def _finalize_login(self, user: Optional["UserProfile"], connect_after_login=True):
self.user_profile = user
if connect_after_login and user:
self._on_login_connect()
return user

def _complete_login(self, connect_after_login=True):
user = self.pieces_client.os_api.sign_into_os()
return self._finalize_login(user, connect_after_login)

def login(self, connect_after_login=True, timeout=120, async_req=False) -> Optional["UserProfile"] | Thread:
"""
Logs the user into the OS and optionally connects to the cloud.

Args:
connect_after_login: A flag indicating if the user should connect to the cloud after login (default is True).
timeout: The maximum time to wait for the login process (default is 120 seconds).
timeout: The maximum time to wait for the sign-in process (default is 120 seconds).
async_req: Start the login flow in the background without waiting for it to finish.

Returns:
The logged-in user profile for synchronous calls, or the background thread when async_req is True.
"""
result = {}
error = {}

def target():
result['user'] = self.pieces_client.os_api.sign_into_os()

thread = Thread(target=target)
try:
if async_req:
result['user'] = self._complete_login(connect_after_login)
else:
result['user'] = self.pieces_client.os_api.sign_into_os()
except Exception as exc:
error['exception'] = exc
if async_req:
logger.exception("PiecesOS login failed in background")

thread = Thread(target=target, daemon=True)
thread.start()
if async_req:
return thread

thread.join(timeout)
if thread.is_alive():
raise TimeoutError(f"Login did not complete within {timeout} seconds")
if 'exception' in error:
raise error['exception']

if connect_after_login:
self.user_profile = result.get('user')
self._on_login_connect()
return self._finalize_login(result.get('user'), connect_after_login)

def logout(self):
"""
Expand All @@ -95,11 +127,17 @@ def connect(self, async_req = False):
self.user_profile, True
) # Set the connecting to cloud bool to true
if async_req:
thread = Thread(
target=self.pieces_client.allocations_api.allocations_connect_new_cloud,
args=(self.user_profile,),
)
def target():
try:
self.pieces_client.allocations_api.allocations_connect_new_cloud(
self.user_profile
)
except Exception:
logger.exception("Pieces Cloud connection failed in background")

thread = Thread(target=target, daemon=True)
thread.start()
return thread
else:
self.pieces_client.allocations_api.allocations_connect_new_cloud(
self.user_profile
Expand Down
13 changes: 11 additions & 2 deletions src/pieces/command_interface/auth_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,19 @@ def execute(self, **kwargs) -> int:
)
if status == AllocationStatusEnum.DISCONNECTED:
Settings.logger.print("Connecting to the Pieces Cloud...")
Settings.pieces_client.user.connect()
if Settings.run_in_loop:
Settings.pieces_client.user.connect(async_req=True)
else:
Settings.pieces_client.user.connect()
return 0
try:
Settings.pieces_client.user.login(True)
if Settings.run_in_loop:
Settings.pieces_client.user.login(True, async_req=True)
Settings.logger.print(
"Sign-in opened in your browser. You can keep using `pieces run` while it completes."
)
else:
Settings.pieces_client.user.login(True)
except Exception as e:
Settings.logger.error(f"Sign in failed: {e}")
return 0
Expand Down
99 changes: 99 additions & 0 deletions tests/auth_commands_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import pytest
import threading
from unittest.mock import Mock, patch
from pieces.command_interface.auth_commands import LoginCommand, LogoutCommand
from pieces.settings import Settings
from pieces._vendor.pieces_os_client.models.allocation_status_enum import (
AllocationStatusEnum,
)
from pieces._vendor.pieces_os_client.models.user_profile import UserProfile
from pieces._vendor.pieces_os_client.wrapper.basic_identifier.user import BasicUser


class TestLoginCommand:
Expand Down Expand Up @@ -107,6 +109,48 @@ def test_execute_not_logged_in(
assert result == 0
mock_user.login.assert_called_once_with(True)

@patch.object(Settings, "run_in_loop", True)
@patch.object(Settings, "logger")
@patch.object(Settings, "pieces_client")
def test_execute_not_logged_in_in_run_loop_uses_async_login(
self, mock_pieces_client, mock_logger, login_command
):
"""Test login in run mode uses the non-blocking login path."""
mock_user = Mock()
mock_user.user_profile = None
mock_user.login = Mock()

mock_pieces_client.user = mock_user
mock_pieces_client.user_api.user_snapshot.return_value.user = None

result = login_command.execute()

assert result == 0
mock_user.login.assert_called_once_with(True, async_req=True)
mock_logger.print.assert_called_once()
assert "browser" in mock_logger.print.call_args[0][0].lower()

@patch.object(Settings, "run_in_loop", True)
@patch.object(Settings, "logger")
@patch.object(Settings, "pieces_client")
def test_execute_logged_in_but_disconnected_in_run_loop_uses_async_connect(
self, mock_pieces_client, mock_logger, login_command, mock_user_profile
):
mock_user = Mock()
mock_user.user_profile = mock_user_profile
mock_user.name = "Test User"
mock_user.email = "test@example.com"
mock_user.cloud_status = AllocationStatusEnum.DISCONNECTED

mock_pieces_client.user = mock_user
mock_pieces_client.user_api.user_snapshot.return_value.user = mock_user_profile

result = login_command.execute()

assert result == 0
mock_user.connect.assert_called_once_with(async_req=True)
assert mock_logger.print.call_count == 2

@patch.object(Settings, "logger")
@patch.object(Settings, "pieces_client")
def test_execute_login_exception(
Expand Down Expand Up @@ -317,6 +361,61 @@ def test_execute_multiple_logout_calls(
mock_logger.error.assert_not_called()


class TestBasicUserLogin:
def test_login_async_returns_background_thread_without_waiting(self):
pieces_client = Mock()
started = threading.Event()
release = threading.Event()

def delayed_login():
started.set()
release.wait(timeout=5)
return "user-profile"

pieces_client.os_api.sign_into_os.side_effect = delayed_login

user = BasicUser(pieces_client)

thread = user.login(connect_after_login=False, async_req=True)

assert started.wait(timeout=1)
assert thread.is_alive()

release.set()
thread.join(timeout=1)
assert not thread.is_alive()

def test_login_sync_propagates_connect_failures(self):
pieces_client = Mock()
user_profile = Mock(spec=UserProfile)
pieces_client.os_api.sign_into_os.return_value = user_profile
pieces_client.allocations_api.allocations_connect_new_cloud.side_effect = RuntimeError(
"Cloud connection failed"
)

user = BasicUser(pieces_client)

with pytest.raises(RuntimeError, match="Cloud connection failed"):
user.login()

def test_login_sync_raises_timeout_when_sign_in_does_not_finish(self):
pieces_client = Mock()
release = threading.Event()

def delayed_login():
release.wait(timeout=5)
return "user-profile"

pieces_client.os_api.sign_into_os.side_effect = delayed_login

user = BasicUser(pieces_client)

with pytest.raises(TimeoutError, match="Login did not complete within 0.01 seconds"):
user.login(connect_after_login=False, timeout=0.01)

release.set()


class TestLoginLogoutIntegration:
"""Integration tests for login and logout commands working together."""

Expand Down