diff --git a/buffalogs/impossible_travel/management/commands/setup_config.py b/buffalogs/impossible_travel/management/commands/setup_config.py index 3ef2b6f3..8edbd8b5 100644 --- a/buffalogs/impossible_travel/management/commands/setup_config.py +++ b/buffalogs/impossible_travel/management/commands/setup_config.py @@ -6,6 +6,7 @@ from django.core.exceptions import ValidationError from django.core.management.base import CommandError from django.db.models.fields import Field +from impossible_travel.constants import AlertDetectionType from impossible_travel.management.commands.base_command import TaskLoggingCommand from impossible_travel.models import Config @@ -31,7 +32,18 @@ def _cast_value(val: str) -> Any: def parse_field_value(item: str) -> Tuple[str, Any]: - """Parse a string of the form FIELD=VALUE or FIELD=[val1,val2]""" + """Parse a string of the form FIELD=VALUE or FIELD=[val1,val2] + + Supports multiple formats: + - FIELD=value (single value) + - FIELD=[val1,val2,val3] (list without spaces in values) + - FIELD=['val 1','val 2'] (list with quoted values containing spaces) + - FIELD=["val 1","val 2"] (list with double-quoted values) + - FIELD=[val 1, val 2] (list with spaces, no quotes - legacy support) + + IMPORTANT: When brackets [...] are present, ALWAYS returns a list, + even for single elements. This is required for ArrayField validation. + """ if "=" not in item: raise CommandError(f"Invalid syntax '{item}': must be FIELD=VALUE") @@ -39,12 +51,52 @@ def parse_field_value(item: str) -> Tuple[str, Any]: value = value.strip() if value.startswith("[") and value.endswith("]"): + # This is a list - must ALWAYS return a list type inner = value[1:-1].strip() - parsed = [_cast_value(v) for v in inner.split(",") if v.strip()] + if not inner: + # Empty list case: [] + parsed = [] + else: + # Check if the input has any quotes + has_quotes = ("'" in inner) or ('"' in inner) + + if has_quotes: + # Manual parsing for quoted values (handles all cases reliably) + parsed_values = [] + current = "" + in_quotes = False + quote_char = None + + for char in inner: + if char in ('"', "'") and (not in_quotes or char == quote_char): + in_quotes = not in_quotes + quote_char = char if in_quotes else None + current += char + elif char == "," and not in_quotes: + # Found a comma outside quotes - this is a separator + if current.strip(): + parsed_values.append(current.strip()) + current = "" + else: + current += char + + # Don't forget the last value + if current.strip(): + parsed_values.append(current.strip()) + + # Now cast each value (this also strips quotes) + parsed = [_cast_value(v) for v in parsed_values if v.strip()] + else: + # No quotes - split by comma (handles both "a,b,c" and "a, b, c") + parsed = [_cast_value(v.strip()) for v in inner.split(",") if v.strip()] + + # CRITICAL: Always return a list when brackets are present + # This ensures ArrayField validators receive the correct type + return field.strip(), parsed else: + # Single value without brackets - can be a non-list type parsed = _cast_value(value) - - return field.strip(), parsed + return field.strip(), parsed class Command(TaskLoggingCommand): @@ -148,20 +200,22 @@ def handle(self, *args, **options): if is_list and not isinstance(value, list): value = [value] - # Validate values - values_to_validate = value if is_list else [value] - for val in values_to_validate: - for validator in getattr(field_obj, "validators", []): - try: - validator(val) - except ValidationError as e: - raise CommandError(f"Validation error on field '{field}' with value '{val}': {e}") + # Validate the value (validators expect the full value, not individual elements) + for validator in getattr(field_obj, "validators", []): + try: + validator(value) + except ValidationError as e: + # Extract detailed error messages from ValidationError + error_details = "; ".join(e.messages) if hasattr(e, "messages") else str(e) + raise CommandError(f"Validation error on field '{field}' with value '{value}': {error_details}") # Apply changes if is_list: current = current or [] if mode == "append": - current += value + # Only append values that don't already exist (make it idempotent) + new_values = [v for v in value if v not in current] + current += new_values elif mode == "override": current = value elif mode == "remove": @@ -173,5 +227,13 @@ def handle(self, *args, **options): setattr(config, field, current) + # Validate filtered_alerts_types before saving + if hasattr(config, "filtered_alerts_types") and config.filtered_alerts_types: + valid_choices = [choice[0] for choice in AlertDetectionType.choices] + invalid_values = [val for val in config.filtered_alerts_types if val not in valid_choices] + + if invalid_values: + raise CommandError(f"Invalid values in 'filtered_alerts_types': {invalid_values}. " f"Valid choices are: {valid_choices}") + config.save() self.stdout.write(self.style.SUCCESS("Config updated successfully.")) diff --git a/buffalogs/impossible_travel/tests/task/test_management_commands.py b/buffalogs/impossible_travel/tests/task/test_management_commands.py index 1fc23832..c74c7111 100644 --- a/buffalogs/impossible_travel/tests/task/test_management_commands.py +++ b/buffalogs/impossible_travel/tests/task/test_management_commands.py @@ -260,6 +260,82 @@ def test_parse_field_value_numeric(self): self.assertEqual(field_float, "vel_accepted") self.assertEqual(value_float, 55.7) + def test_parse_field_value_list_with_single_quoted_values_with_spaces(self): + # Testing the parse_field_value function with single-quoted values containing spaces + field, value = parse_field_value("filtered_alerts_types=['New Device','User Risk Threshold','Anonymous IP Login']") + self.assertEqual(field, "filtered_alerts_types") + self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"]) + + def test_parse_field_value_list_with_double_quoted_values_with_spaces(self): + # Testing the parse_field_value function with double-quoted values containing spaces + field, value = parse_field_value('filtered_alerts_types=["New Device","User Risk Threshold","Anonymous IP Login"]') + self.assertEqual(field, "filtered_alerts_types") + self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"]) + + def test_parse_field_value_list_with_mixed_quotes(self): + # Testing the parse_field_value function with mixed single and double quotes + field, value = parse_field_value("filtered_alerts_types=['New Device',\"User Risk Threshold\",'Anonymous IP Login']") + self.assertEqual(field, "filtered_alerts_types") + self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"]) + + def test_parse_field_value_list_with_quoted_and_unquoted_mixed(self): + # Testing the parse_field_value function with a mix of quoted and unquoted values + field, value = parse_field_value("allowed_countries=['United States',Italy,'United Kingdom',France]") + self.assertEqual(field, "allowed_countries") + self.assertListEqual(value, ["United States", "Italy", "United Kingdom", "France"]) + + def test_parse_field_value_list_with_spaces_around_quoted_values(self): + # Testing the parse_field_value function with spaces around quoted values + field, value = parse_field_value("filtered_alerts_types=[ 'New Device' , 'User Risk Threshold' , 'Anonymous IP Login' ]") + self.assertEqual(field, "filtered_alerts_types") + self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"]) + + def test_setup_config_command_with_quoted_list_values(self): + # Integration test: verify the entire command works with quoted values containing spaces + Config.objects.all().delete() + config = Config.objects.create(id=1, filtered_alerts_types=[]) + + # Test append mode with single-quoted values + call_command("setup_config", "-a", "filtered_alerts_types=['New Device','User Risk Threshold']") + config.refresh_from_db() + self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"]) + + # Test append mode adding more values + call_command("setup_config", "-a", "filtered_alerts_types=['Anonymous IP Login']") + config.refresh_from_db() + self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold", "Anonymous IP Login"]) + + # Test override mode with double-quoted values + call_command("setup_config", "-o", 'filtered_alerts_types=["Imp Travel","New Country"]') + config.refresh_from_db() + self.assertListEqual(config.filtered_alerts_types, ["Imp Travel", "New Country"]) + + # Test remove mode with quoted values + call_command("setup_config", "-r", "filtered_alerts_types=['New Country']") + config.refresh_from_db() + self.assertListEqual(config.filtered_alerts_types, ["Imp Travel"]) + + def test_setup_config_append_idempotent(self): + """Test that running append multiple times doesn't create duplicates (idempotent behavior)""" + Config.objects.all().delete() + config = Config.objects.create(id=1, filtered_alerts_types=[]) + + # Run append command first time + call_command("setup_config", "-a", "filtered_alerts_types=['New Device','User Risk Threshold']") + config.refresh_from_db() + self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"]) + + # Run the SAME command again - should NOT create duplicates + call_command("setup_config", "-a", "filtered_alerts_types=['New Device','User Risk Threshold']") + config.refresh_from_db() + self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"]) + + # Append with mix of existing and new values + call_command("setup_config", "-a", "filtered_alerts_types=['User Risk Threshold','Anonymous IP Login']") + config.refresh_from_db() + # Should only add 'Anonymous IP Login', not duplicate 'User Risk Threshold' + self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold", "Anonymous IP Login"]) + class ResetUserRiskScoreCommandTests(TestCase): def setUp(self):