Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 75 additions & 13 deletions buffalogs/impossible_travel/management/commands/setup_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from django.core.exceptions import ValidationError
from django.core.management.base import CommandError
from django.db.models.fields import Field
from impossible_travel.constants import AlertDetectionType
from impossible_travel.management.commands.base_command import TaskLoggingCommand
from impossible_travel.models import Config

Expand All @@ -31,20 +32,71 @@ def _cast_value(val: str) -> Any:


def parse_field_value(item: str) -> Tuple[str, Any]:
"""Parse a string of the form FIELD=VALUE or FIELD=[val1,val2]"""
"""Parse a string of the form FIELD=VALUE or FIELD=[val1,val2]

Supports multiple formats:
- FIELD=value (single value)
- FIELD=[val1,val2,val3] (list without spaces in values)
- FIELD=['val 1','val 2'] (list with quoted values containing spaces)
- FIELD=["val 1","val 2"] (list with double-quoted values)
- FIELD=[val 1, val 2] (list with spaces, no quotes - legacy support)

IMPORTANT: When brackets [...] are present, ALWAYS returns a list,
even for single elements. This is required for ArrayField validation.
"""
if "=" not in item:
raise CommandError(f"Invalid syntax '{item}': must be FIELD=VALUE")

field, value = item.split("=", 1)
value = value.strip()

if value.startswith("[") and value.endswith("]"):
# This is a list - must ALWAYS return a list type
inner = value[1:-1].strip()
parsed = [_cast_value(v) for v in inner.split(",") if v.strip()]
if not inner:
# Empty list case: []
parsed = []
else:
# Check if the input has any quotes
has_quotes = ("'" in inner) or ('"' in inner)

if has_quotes:
# Manual parsing for quoted values (handles all cases reliably)
parsed_values = []
current = ""
in_quotes = False
quote_char = None

for char in inner:
if char in ('"', "'") and (not in_quotes or char == quote_char):
in_quotes = not in_quotes
quote_char = char if in_quotes else None
current += char
elif char == "," and not in_quotes:
# Found a comma outside quotes - this is a separator
if current.strip():
parsed_values.append(current.strip())
current = ""
else:
current += char

# Don't forget the last value
if current.strip():
parsed_values.append(current.strip())

# Now cast each value (this also strips quotes)
parsed = [_cast_value(v) for v in parsed_values if v.strip()]
else:
# No quotes - split by comma (handles both "a,b,c" and "a, b, c")
parsed = [_cast_value(v.strip()) for v in inner.split(",") if v.strip()]

# CRITICAL: Always return a list when brackets are present
# This ensures ArrayField validators receive the correct type
return field.strip(), parsed
else:
# Single value without brackets - can be a non-list type
parsed = _cast_value(value)

return field.strip(), parsed
return field.strip(), parsed


class Command(TaskLoggingCommand):
Expand Down Expand Up @@ -148,20 +200,22 @@ def handle(self, *args, **options):
if is_list and not isinstance(value, list):
value = [value]

# Validate values
values_to_validate = value if is_list else [value]
for val in values_to_validate:
for validator in getattr(field_obj, "validators", []):
try:
validator(val)
except ValidationError as e:
raise CommandError(f"Validation error on field '{field}' with value '{val}': {e}")
# Validate the value (validators expect the full value, not individual elements)
for validator in getattr(field_obj, "validators", []):
try:
validator(value)
except ValidationError as e:
# Extract detailed error messages from ValidationError
error_details = "; ".join(e.messages) if hasattr(e, "messages") else str(e)
raise CommandError(f"Validation error on field '{field}' with value '{value}': {error_details}")

# Apply changes
if is_list:
current = current or []
if mode == "append":
current += value
# Only append values that don't already exist (make it idempotent)
new_values = [v for v in value if v not in current]
current += new_values
elif mode == "override":
current = value
elif mode == "remove":
Expand All @@ -173,5 +227,13 @@ def handle(self, *args, **options):

setattr(config, field, current)

# Validate filtered_alerts_types before saving
if hasattr(config, "filtered_alerts_types") and config.filtered_alerts_types:
valid_choices = [choice[0] for choice in AlertDetectionType.choices]
invalid_values = [val for val in config.filtered_alerts_types if val not in valid_choices]

if invalid_values:
raise CommandError(f"Invalid values in 'filtered_alerts_types': {invalid_values}. " f"Valid choices are: {valid_choices}")

config.save()
self.stdout.write(self.style.SUCCESS("Config updated successfully."))
76 changes: 76 additions & 0 deletions buffalogs/impossible_travel/tests/task/test_management_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,82 @@ def test_parse_field_value_numeric(self):
self.assertEqual(field_float, "vel_accepted")
self.assertEqual(value_float, 55.7)

def test_parse_field_value_list_with_single_quoted_values_with_spaces(self):
# Testing the parse_field_value function with single-quoted values containing spaces
field, value = parse_field_value("filtered_alerts_types=['New Device','User Risk Threshold','Anonymous IP Login']")
self.assertEqual(field, "filtered_alerts_types")
self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_field_value_list_with_double_quoted_values_with_spaces(self):
# Testing the parse_field_value function with double-quoted values containing spaces
field, value = parse_field_value('filtered_alerts_types=["New Device","User Risk Threshold","Anonymous IP Login"]')
self.assertEqual(field, "filtered_alerts_types")
self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_field_value_list_with_mixed_quotes(self):
# Testing the parse_field_value function with mixed single and double quotes
field, value = parse_field_value("filtered_alerts_types=['New Device',\"User Risk Threshold\",'Anonymous IP Login']")
self.assertEqual(field, "filtered_alerts_types")
self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_parse_field_value_list_with_quoted_and_unquoted_mixed(self):
# Testing the parse_field_value function with a mix of quoted and unquoted values
field, value = parse_field_value("allowed_countries=['United States',Italy,'United Kingdom',France]")
self.assertEqual(field, "allowed_countries")
self.assertListEqual(value, ["United States", "Italy", "United Kingdom", "France"])

def test_parse_field_value_list_with_spaces_around_quoted_values(self):
# Testing the parse_field_value function with spaces around quoted values
field, value = parse_field_value("filtered_alerts_types=[ 'New Device' , 'User Risk Threshold' , 'Anonymous IP Login' ]")
self.assertEqual(field, "filtered_alerts_types")
self.assertListEqual(value, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

def test_setup_config_command_with_quoted_list_values(self):
# Integration test: verify the entire command works with quoted values containing spaces
Config.objects.all().delete()
config = Config.objects.create(id=1, filtered_alerts_types=[])

# Test append mode with single-quoted values
call_command("setup_config", "-a", "filtered_alerts_types=['New Device','User Risk Threshold']")
config.refresh_from_db()
self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"])

# Test append mode adding more values
call_command("setup_config", "-a", "filtered_alerts_types=['Anonymous IP Login']")
config.refresh_from_db()
self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold", "Anonymous IP Login"])

# Test override mode with double-quoted values
call_command("setup_config", "-o", 'filtered_alerts_types=["Imp Travel","New Country"]')
config.refresh_from_db()
self.assertListEqual(config.filtered_alerts_types, ["Imp Travel", "New Country"])

# Test remove mode with quoted values
call_command("setup_config", "-r", "filtered_alerts_types=['New Country']")
config.refresh_from_db()
self.assertListEqual(config.filtered_alerts_types, ["Imp Travel"])

def test_setup_config_append_idempotent(self):
"""Test that running append multiple times doesn't create duplicates (idempotent behavior)"""
Config.objects.all().delete()
config = Config.objects.create(id=1, filtered_alerts_types=[])

# Run append command first time
call_command("setup_config", "-a", "filtered_alerts_types=['New Device','User Risk Threshold']")
config.refresh_from_db()
self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"])

# Run the SAME command again - should NOT create duplicates
call_command("setup_config", "-a", "filtered_alerts_types=['New Device','User Risk Threshold']")
config.refresh_from_db()
self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold"])

# Append with mix of existing and new values
call_command("setup_config", "-a", "filtered_alerts_types=['User Risk Threshold','Anonymous IP Login']")
config.refresh_from_db()
# Should only add 'Anonymous IP Login', not duplicate 'User Risk Threshold'
self.assertListEqual(config.filtered_alerts_types, ["New Device", "User Risk Threshold", "Anonymous IP Login"])


class ResetUserRiskScoreCommandTests(TestCase):
def setUp(self):
Expand Down
Loading