-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_feature_utils.py
More file actions
164 lines (126 loc) · 6.29 KB
/
test_feature_utils.py
File metadata and controls
164 lines (126 loc) · 6.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Unit tests for ML feature utility functions.
"""
import sys
from pathlib import Path
import pytest
# The repository root sits one directory above this test module's folder;
# prepending it to sys.path lets the ``src`` package be imported below.
current_dir = Path(__file__).parent
project_root = current_dir.parent
sys.path[:0] = [str(project_root)]  # same effect as sys.path.insert(0, ...)
from src.ml.feature_utils import validate_feature_schema
# --- Tests for validate_feature_schema ---
def test_validate_feature_schema_match():
    """Test validation passes when schemas match.

    The call is made directly instead of being wrapped in
    ``try``/``except``/``pytest.fail``: any unexpected exception already fails
    the test, and pytest then reports the real exception type and traceback
    rather than a generic failure message.
    """
    input_features = {"feat_a": 1.0, "feat_b": 2.0}
    expected_features = ["feat_a", "feat_b"]
    # No exception should be raised.
    validate_feature_schema(input_features, expected_features)
def test_validate_feature_schema_missing_input():
    """A feature listed in the schema but absent from the input is reported as missing."""
    schema = ["feat_a", "feat_b"]
    provided = {"feat_a": 1.0}  # 'feat_b' is deliberately left out
    expected_error = r"Missing features in input:.*'feat_b'"
    with pytest.raises(ValueError, match=expected_error):
        validate_feature_schema(provided, schema)
def test_validate_feature_schema_extra_input():
    """An input key that the schema does not list is reported as extra."""
    schema = ["feat_a", "feat_b"]
    provided = dict(feat_a=1.0, feat_b=2.0, feat_c=3.0)  # 'feat_c' is unexpected
    with pytest.raises(ValueError, match=r"Extra features in input:.*'feat_c'"):
        validate_feature_schema(provided, schema)
def test_validate_feature_schema_both_missing_and_extra():
    """A single error reports both the missing and the extra feature."""
    provided = {"feat_a": 1.0, "feat_c": 3.0}  # lacks 'feat_b', adds 'feat_c'
    schema = ["feat_a", "feat_b"]
    with pytest.raises(ValueError) as excinfo:
        validate_feature_schema(provided, schema)
    # Inspect the message only after the ``with`` block has captured the
    # exception; every fragment must appear somewhere in it.
    message = str(excinfo.value)
    for fragment in (
        "Missing features in input:",
        "'feat_b'",
        "Extra features in input:",
        "'feat_c'",
    ):
        assert fragment in message
def test_validate_feature_schema_empty_input():
    """Test validation fails when input is empty but features are expected.

    Besides the error category, both expected feature names are checked, so
    the test also catches an implementation that reports only some of the
    missing features. The names are checked independently because the order
    in which they are listed is not fixed by this test.
    """
    input_features = {}
    expected_features = ["feat_a", "feat_b"]
    with pytest.raises(ValueError, match="Missing features in input:") as excinfo:
        validate_feature_schema(input_features, expected_features)
    message = str(excinfo.value)
    assert "'feat_a'" in message
    assert "'feat_b'" in message
def test_validate_feature_schema_empty_expected():
    """Test validation fails when features are provided but none are expected.

    The pattern also pins the name of the unexpected feature, not just the
    error category, so an implementation that reports the wrong feature
    still fails.
    """
    input_features = {"feat_a": 1.0}
    expected_features = []
    with pytest.raises(ValueError, match=r"Extra features in input:.*'feat_a'"):
        validate_feature_schema(input_features, expected_features)
def test_validate_feature_schema_both_empty():
    """Test validation passes when both input and expected are empty.

    Called directly (no ``try``/``pytest.fail`` wrapper) so that any
    unexpected exception surfaces with its real type and traceback.
    """
    input_features = {}
    expected_features = []
    # No exception should be raised.
    validate_feature_schema(input_features, expected_features)
# --- Add tests for extract_sentiment_features if needed ---
# (Example usages already exist in its module's `if __name__ == "__main__":` block.)
# --- Integration Test for Schema Consistency ---
# Note: This test relies on MLflow logging locally and might be slow.
# It also assumes the .env file is configured correctly for SentimentAnalyzer loading.
@pytest.mark.integration  # Mark as integration test
def test_real_world_schema_consistency():
    """Integration test ensuring ModelTrainer and SentimentAnalyzer use same features.

    Steps:
      1. Train a dummy model and obtain its best run ID and logged features.
      2. Point ``MLFLOW_SENTIMENT_MODEL_URI`` at that run.
      3. Reload ``src.ai.sentiment_analyzer`` and build a ``SentimentAnalyzer``.
      4. Check that it loaded the same feature columns the trainer logged
         (compared as sets: order and duplicates are ignored).

    The environment variable is modified only inside the ``try`` block, so the
    ``finally`` clause always restores it, even if the import or reload
    fails. Mutating ``os.environ`` still makes this test unsafe to run in
    parallel with other tests that read the same variable.
    """
    import importlib
    import logging
    import os

    logger = logging.getLogger(__name__)  # Use local logger for test output
    logger.info("Running schema consistency integration test...")

    # 1. Train dummy model and get logged features/run_id.
    # Imported here to avoid circular dependencies if utils are imported elsewhere.
    from src.ml.model_trainer import _run_dummy_training_and_get_features

    best_run_id, expected_features = _run_dummy_training_and_get_features(
        asset_symbol="BTC_schema_test"
    )
    assert best_run_id is not None, "Dummy training failed to produce a best run ID."
    assert (
        expected_features is not None
    ), "Dummy training failed to log feature columns."

    # 2. Configure SentimentAnalyzer to load from this specific run.
    env_var = "MLFLOW_SENTIMENT_MODEL_URI"
    original_uri = os.getenv(env_var)
    test_model_uri = f"runs:/{best_run_id}/model"
    try:
        os.environ[env_var] = test_model_uri
        logger.info("Set MLFLOW_SENTIMENT_MODEL_URI to: %s", test_model_uri)

        # 3. Load features via SentimentAnalyzer. The module is reloaded so
        # that anything it reads from the environment at import time picks
        # up the URI set above.
        import src.ai.sentiment_analyzer

        importlib.reload(src.ai.sentiment_analyzer)
        from src.ai.sentiment_analyzer import SentimentAnalyzer

        analyzer = SentimentAnalyzer()  # Should now load using the env var
        loaded_features = analyzer.feature_columns

        # 4. Assert match.
        assert (
            loaded_features is not None
        ), "SentimentAnalyzer failed to load feature columns."
        assert set(loaded_features) == set(expected_features), (
            f"Feature mismatch! Trainer logged: {expected_features}, "
            f"Analyzer loaded: {loaded_features}"
        )
        logger.info("Schema consistency test passed!")
    finally:
        # Restore the original value, or remove the variable if it was unset.
        if original_uri is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = original_uri
        logger.info("Restored original MLFLOW_SENTIMENT_MODEL_URI.")
        # NOTE(review): src.ai.sentiment_analyzer is not reloaded again after
        # the restore, so any import-time state derived from the test URI may
        # persist into later tests — confirm whether that matters.