-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
254 lines (200 loc) · 8.88 KB
/
main.py
File metadata and controls
254 lines (200 loc) · 8.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
#!/usr/bin/env python
"""
Image Forgery Detection System
Main entry point for the forgery detection application.
Provides CLI interface for preprocessing, training, evaluation, and prediction.
"""
import argparse
import sys
import os
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.utils.config import config
def train_command(args):
    """Build a ``ForgeryDetectorTrainer`` from CLI options and train it.

    Args:
        args: parsed CLI namespace. Reads ``model_type``, ``base_model``,
            ``mesonet_variant`` and ``no_ela`` to configure the trainer,
            ``epochs`` and ``model_name`` for the training run, and
            ``evaluate`` to decide whether to run the test-set evaluation
            afterwards.
    """
    from src.training.train import ForgeryDetectorTrainer

    # ``--no_ela`` is a negative flag; the trainer takes the positive form.
    trainer_options = dict(
        model_type=args.model_type,
        base_model_name=args.base_model,
        mesonet_variant=args.mesonet_variant,
        use_ela=not args.no_ela,
    )
    trainer = ForgeryDetectorTrainer(**trainer_options)
    trainer.train(epochs=args.epochs, model_name=args.model_name)

    if not args.evaluate:
        return
    trainer.evaluate_on_test()
def evaluate_command(args):
    """Evaluate the model stored at ``args.model_path`` with ``ModelEvaluator``.

    Args:
        args: parsed CLI namespace. Reads ``model_path`` and ``no_ela`` to
            build the evaluator, then ``no_save`` and ``output_dir`` to
            control where (and whether) results are saved.
    """
    from src.training.evaluate import ModelEvaluator

    model_path = args.model_path
    use_ela = not args.no_ela  # CLI exposes the negative flag
    evaluator = ModelEvaluator(model_path=model_path, use_ela=use_ela)

    should_save = not args.no_save
    evaluator.evaluate(save_results=should_save, output_dir=args.output_dir)
def predict_command(args):
    """Run forgery predictions on a single image or on a directory of images.

    ``args.image`` takes precedence over ``args.directory`` when both are set.
    Results are printed. When ``args.output`` is set, they are also written as
    JSON of the form ``{'results': [...]}``. Directory mode additionally
    includes the ``'summary'`` returned by ``predict_directory``.

    Args:
        args: parsed CLI namespace with ``model_path``, ``no_ela``, ``image``,
            ``directory``, ``recursive``, ``threshold`` and ``output``.

    Raises:
        SystemExit: if neither ``args.image`` nor ``args.directory`` is given.
    """
    # Validate before importing/loading the model. Previously, a missing input
    # loaded the model and then silently did nothing.
    if not args.image and not args.directory:
        sys.exit("Error: predict requires --image or --directory")

    from src.inference.predict import ForgeryPredictor
    import json

    predictor = ForgeryPredictor(
        model_path=args.model_path,
        use_ela=not args.no_ela
    )

    if args.image:
        # Single image
        result = predictor.predict(args.image, args.threshold)
        print("\n" + "=" * 50)
        print("PREDICTION RESULT")
        print("=" * 50)
        predictor.print_result(result)
        print(f"\nRaw score: {result['raw_score']:.4f}")
        # Same 'results' list schema as directory mode, so --output is no
        # longer silently ignored for single images.
        payload = {'results': [result]}
    else:
        # Directory of images
        results, summary = predictor.predict_directory(
            args.directory,
            threshold=args.threshold,
            recursive=args.recursive
        )
        print("\nResults:")
        print("-" * 50)
        for result in results:
            predictor.print_result(result)
        predictor.print_summary(summary)
        payload = {'results': results, 'summary': summary}

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)
        print(f"\nResults saved: {args.output}")
def info_command(args):
    """Print the active configuration and the model families available for training.

    Args:
        args: parsed CLI namespace; only ``args.dataset_stats`` is read. When
            truthy, ``print_dataset_statistics`` is also run on
            ``config.RAW_DATA_DIR``.
    """
    from src.models.cnn_model import get_model_info  # NOTE(review): unused here
    from src.models.transfer_learning import list_available_models
    from src.models.mesonet import MESONET_VARIANTS
    from src.utils.helpers import print_dataset_statistics

    def show_section(heading, fields):
        # ``fields`` pairs a display label with a ``config`` attribute name.
        # Each value is read right before its line is printed.
        print(heading)
        for label, attr in fields:
            print(f" {label}: {getattr(config, attr)}")

    def show_bullets(names):
        for name in names:
            print(f" - {name}")

    banner = "=" * 60
    print(banner)
    print("IMAGE FORGERY / AI DETECTION SYSTEM")
    print(banner)

    show_section("\nConfiguration:", (
        ("Image size", "IMG_SIZE"),
        ("Input shape", "INPUT_SHAPE"),
        ("Batch size", "BATCH_SIZE"),
        ("Epochs", "EPOCHS"),
        ("Learning rate", "LEARNING_RATE"),
        ("Class names", "CLASS_NAMES"),
    ))
    show_section("\nDirectories:", (
        ("Base", "BASE_DIR"),
        ("Raw data", "RAW_DATA_DIR"),
        ("Models", "SAVED_MODELS_DIR"),
    ))

    print("\nAvailable models:")
    print(" MesoNet variants:")
    show_bullets(MESONET_VARIANTS)
    print(" Transfer learning models:")
    show_bullets(list_available_models())
    print(" Custom CNN: custom")

    if args.dataset_stats:
        print_dataset_statistics(config.RAW_DATA_DIR)
def _build_parser():
    """Build the CLI parser with one sub-parser per command.

    Returns:
        argparse.ArgumentParser: the configured top-level parser. After
        ``parse_args()`` the selected sub-command is in ``args.command``
        (``None`` when no sub-command was given).
    """
    parser = argparse.ArgumentParser(
        description='Image Forgery Detection System',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Commands:
extract Extract CIFAKE dataset from archive.zip
train Train the forgery detection model (MesoNet-4 by default)
evaluate Evaluate a trained model
predict Run predictions on images
info Display system information
Examples:
python main.py extract
python main.py train --model_type mesonet --epochs 50
python main.py train --model_type transfer --base_model VGG16
python main.py evaluate --model_path models/saved_models/model.h5
python main.py predict --model_path model.h5 --image test.jpg
python main.py predict --model_path model.h5 --directory ./images
python main.py info
"""
    )
    subparsers = parser.add_subparsers(dest='command', help='Available commands')

    # Extract command (for CIFAKE dataset)
    extract_parser = subparsers.add_parser('extract',
                                           help='Extract CIFAKE dataset from archive.zip')
    extract_parser.add_argument('--archive', type=str, default='archive.zip',
                                help='Path to archive.zip file')

    # Train command
    train_parser = subparsers.add_parser('train', help='Train the model')
    train_parser.add_argument('--model_type', type=str, default='mesonet',
                              choices=['mesonet', 'custom', 'transfer'],
                              help='Model type: mesonet (default), custom CNN, or transfer learning')
    train_parser.add_argument('--mesonet_variant', type=str, default='meso4',
                              choices=['meso4', 'mesoinception4'],
                              help='MesoNet variant')
    train_parser.add_argument('--base_model', type=str, default='VGG16',
                              choices=['VGG16', 'ResNet50', 'EfficientNetB0', 'MobileNetV2'],
                              help='Base model for transfer learning')
    train_parser.add_argument('--epochs', type=int, default=config.EPOCHS,
                              help='Number of training epochs')
    train_parser.add_argument('--model_name', type=str, default='ai_detector',
                              help='Name for the saved model')
    train_parser.add_argument('--no_ela', action='store_true',
                              help='Use original images instead of ELA (default for CIFAKE)')
    train_parser.add_argument('--evaluate', action='store_true',
                              help='Evaluate on test set after training')

    # Evaluate command
    eval_parser = subparsers.add_parser('evaluate', help='Evaluate trained model')
    eval_parser.add_argument('--model_path', type=str, required=True,
                             help='Path to trained model')
    eval_parser.add_argument('--no_ela', action='store_true',
                             help='Use original images instead of ELA')
    eval_parser.add_argument('--output_dir', type=str, default=None,
                             help='Directory for saving results')
    eval_parser.add_argument('--no_save', action='store_true',
                             help='Do not save results')

    # Predict command
    predict_parser = subparsers.add_parser('predict', help='Run predictions')
    predict_parser.add_argument('--model_path', type=str, required=True,
                                help='Path to trained model')
    # Exactly one input source is required. Previously, giving neither loaded
    # the model and then did nothing, and giving both silently ignored
    # --directory.
    predict_source = predict_parser.add_mutually_exclusive_group(required=True)
    predict_source.add_argument('--image', type=str, default=None,
                                help='Path to single image')
    predict_source.add_argument('--directory', type=str, default=None,
                                help='Path to directory of images')
    predict_parser.add_argument('--recursive', action='store_true',
                                help='Search directory recursively')
    predict_parser.add_argument('--threshold', type=float,
                                default=config.PREDICTION_THRESHOLD,
                                help='Classification threshold')
    predict_parser.add_argument('--no_ela', action='store_true',
                                help='Disable ELA preprocessing')
    predict_parser.add_argument('--output', type=str, default=None,
                                help='Output JSON file for results')

    # Info command
    info_parser = subparsers.add_parser('info', help='Display system information')
    info_parser.add_argument('--dataset_stats', action='store_true',
                             help='Show detailed dataset statistics')

    return parser


def main():
    """Parse the command line and dispatch to the selected command.

    Prints help and exits with status 0 when no command is given. Otherwise,
    calls ``config.ensure_directories()``, then takes one of two paths:

    - ``extract``: runs ``extract_dataset.py`` (located next to this file)
      in a subprocess, then exits with that script's return code.
    - Any other command: calls the matching ``*_command`` handler.
    """
    parser = _build_parser()
    args = parser.parse_args()

    if args.command is None:
        parser.print_help()
        sys.exit(0)

    # Ensure directories exist
    config.ensure_directories()

    # Handle extract command separately (runs external script)
    if args.command == 'extract':
        print("=" * 60)
        print("EXTRACTING CIFAKE DATASET")
        print("=" * 60)
        import subprocess
        # NOTE(review): --archive (args.archive) is parsed but never passed to
        # extract_dataset.py, so the option currently has no effect. Forward it
        # once that script's CLI is confirmed. Because the subprocess runs with
        # cwd set to this file's directory, a relative path would also need
        # resolving with os.path.abspath first.
        result = subprocess.run([sys.executable, 'extract_dataset.py'],
                                cwd=os.path.dirname(os.path.abspath(__file__)))
        sys.exit(result.returncode)

    # Execute other commands
    commands = {
        'train': train_command,
        'evaluate': evaluate_command,
        'predict': predict_command,
        'info': info_command
    }
    commands[args.command](args)
if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()