diff --git a/relink.py b/relink.py index 4778ad7..6ff19df 100644 --- a/relink.py +++ b/relink.py @@ -52,21 +52,20 @@ def find_owned_files_scandir(directory, user_uid): with os.scandir(directory) as entries: for entry in entries: try: - # Check if it's a file (not following symlinks) - if entry.is_file(follow_symlinks=False): - # Get stat info (cached by scandir, very efficient) - stat_info = entry.stat(follow_symlinks=False) - - if stat_info.st_uid == user_uid: - yield entry.path - # Recursively process directories (not following symlinks) - elif entry.is_dir(follow_symlinks=False): + if entry.is_dir(follow_symlinks=False): yield from find_owned_files_scandir(entry.path, user_uid) - # Skip symlinks - elif entry.is_symlink(): - logger.info("Skipping symlink: %s", entry.path) + # Is this owned by the user? + elif entry.stat(follow_symlinks=False).st_uid == user_uid: + + # Return if it's a file (not following symlinks) + if entry.is_file(follow_symlinks=False): + yield entry.path + + # Skip symlinks + elif entry.is_symlink(): + logger.debug("Skipping symlink: %s", entry.path) except (OSError, PermissionError) as e: logger.debug("Error accessing %s: %s. Skipping.", entry.path, e) diff --git a/tests/relink/test_find_owned_files_scandir.py b/tests/relink/test_find_owned_files_scandir.py index d0718a3..f26fe68 100644 --- a/tests/relink/test_find_owned_files_scandir.py +++ b/tests/relink/test_find_owned_files_scandir.py @@ -6,6 +6,8 @@ import sys import tempfile import logging +from unittest.mock import patch +from contextlib import contextmanager # Add parent directory to path to import relink module sys.path.insert( @@ -15,6 +17,77 @@ import relink # noqa: E402 +class MockDirEntry: + """Wrapper for DirEntry that allows mocking stat() for specific files.""" + + # pylint: disable=missing-function-docstring + + def __init__(self, entry, uid_override=None): + """ + Initialize MockDirEntry. + + Args: + entry: The original DirEntry object. + uid_override: Dict mapping filename to UID to override in stat results. + """ + self._entry = entry + self._uid_override = uid_override or {} + + def __getattr__(self, name): + return getattr(self._entry, name) + + def stat(self, *args, **kwargs): + stat_result = self._entry.stat(*args, **kwargs) + if self._entry.name in self._uid_override: + # Create a modified stat result with different UID + modified_stat = os.stat_result( + ( + stat_result.st_mode, + stat_result.st_ino, + stat_result.st_dev, + stat_result.st_nlink, + self._uid_override[self._entry.name], # Override UID + stat_result.st_gid, + stat_result.st_size, + stat_result.st_atime, + stat_result.st_mtime, + stat_result.st_ctime, + ) + ) + return modified_stat + return stat_result + + def is_file(self, *args, **kwargs): + return self._entry.is_file(*args, **kwargs) + + def is_dir(self, *args, **kwargs): + return self._entry.is_dir(*args, **kwargs) + + def is_symlink(self): + return self._entry.is_symlink() + + +def create_mock_scandir(uid_override=None): + """ + Create a mock scandir function that wraps entries with MockDirEntry. + + Args: + uid_override: Dict mapping filename to UID to override in stat results. + + Returns: + A context manager function that can be used with patch. + """ + original_scandir = os.scandir + + @contextmanager + def mock_scandir(path): + with original_scandir(path) as entries: + wrapped_entries = [MockDirEntry(entry, uid_override) for entry in entries] + yield iter(wrapped_entries) + + return mock_scandir + + def test_find_owned_files_basic(temp_dirs): """Test basic functionality: find files owned by user.""" source_dir, _ = temp_dirs @@ -82,7 +155,7 @@ def test_skip_symlinks(temp_dirs, caplog): os.symlink(dummy_target, symlink_path) # Find owned files with logging - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.DEBUG): found_files = list(relink.find_owned_files_scandir(source_dir, user_uid)) # Verify only regular file was found @@ -95,6 +168,48 @@ def test_skip_symlinks(temp_dirs, caplog): assert symlink_path in caplog.text +def test_skip_symlinks_owned_by_different_user(temp_dirs, caplog): + """Test that symlinks owned by different users are not logged. + + Since find_owned_files_scandir filters by UID first, symlinks owned + by other users should never reach the symlink check and thus should + not generate a "Skipping symlink" log message. + """ + source_dir, _ = temp_dirs + user_uid = os.stat(source_dir).st_uid + + # Use a different UID + different_uid = user_uid + 1000 + + # Create a regular file owned by current user + regular_file = os.path.join(source_dir, "regular.txt") + with open(regular_file, "w", encoding="utf-8") as f: + f.write("content") + + # Create a symlink + symlink_path = os.path.join(source_dir, "other_user_link.txt") + dummy_target = os.path.join(tempfile.gettempdir(), "somewhere") + os.symlink(dummy_target, symlink_path) + + # Mock DirEntry.stat to return different UID for the symlink + uid_override = {"other_user_link.txt": different_uid} + mock_scandir = create_mock_scandir(uid_override) + + with patch("os.scandir", side_effect=mock_scandir): + with caplog.at_level(logging.INFO): + found_files = list(relink.find_owned_files_scandir(source_dir, user_uid)) + + # Verify only regular file was found + assert len(found_files) == 1 + assert regular_file in found_files + assert symlink_path not in found_files + + # Check that "Skipping symlink" message was NOT logged for the other user's symlink + # (it should be filtered out by UID check before reaching symlink check) + if "Skipping symlink:" in caplog.text: + assert symlink_path not in caplog.text + + def test_empty_directory(temp_dirs): """Test with empty directory.""" source_dir, _ = temp_dirs diff --git a/tests/relink/test_replace_files_with_symlinks.py b/tests/relink/test_replace_files_with_symlinks.py index 3d670fe..5123c99 100644 --- a/tests/relink/test_replace_files_with_symlinks.py +++ b/tests/relink/test_replace_files_with_symlinks.py @@ -88,7 +88,7 @@ def test_skip_existing_symlinks(temp_dirs, current_user, caplog): stat_before = os.lstat(source_link) # Run the function - with caplog.at_level(logging.INFO): + with caplog.at_level(logging.DEBUG): relink.replace_files_with_symlinks(source_dir, target_dir, username) # Verify the symlink is unchanged (same inode means it wasn't deleted/recreated)