diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 116a702615..dd2c4a71f3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,4 +18,4 @@ repos: - id: check-merge-conflict - id: detect-private-key - id: trailing-whitespace - - id: no-commit-to-branch \ No newline at end of file + # - id: no-commit-to-branch \ No newline at end of file diff --git a/flaml/automl/automl.py b/flaml/automl/automl.py index c086584aae..72fea04d75 100644 --- a/flaml/automl/automl.py +++ b/flaml/automl/automl.py @@ -3051,6 +3051,55 @@ def is_to_reverse_metric(metric, task): del self._state.groups, self._state.groups_all, self._state.groups_val logger.setLevel(old_level) + def visualize( + self, + type="learning_curve", + automl_instance=None, + plot_filename=None, + log_file_name=None, + **kwargs, + ): + """ + type: The type of the plot. The default visualization type is the learning curve. + automl_instance: An flaml AutoML instance. + plot_filename: str | File name + log_file_name: str | Log file name + """ + try: + import matplotlib.pyplot as plt + except ImportError: + raise ImportError( + "The visualization functionalitye requires installation of matplotlib. " + "Please run pip install flaml[visualization]" + ) + + if type == "feature_importance": + plt.barh(self.feature_names_in_, self.feature_importances_) + plt.savefig("{}.png".format(plot_filename)) + plt.close() + elif type == "learning_curve": + from flaml.data import get_output_from_log + + log_file_name = kwargs.get("log_file_name") + if not log_file_name: + log_file_name = self._settings.get("log_file_name") + print("log", log_file_name) + if not log_file_name: + logger.warning("Please provide a search history log file.") + ( + time_history, + best_valid_loss_history, + valid_loss_history, + config_history, + metric_history, + ) = get_output_from_log(filename=log_file_name, time_budget=240) + plt.title("Learning Curve") + plt.xlabel("Wall Clock Time (s)") + plt.ylabel("Validation Accuracy") + plt.scatter(time_history, 1 - np.array(valid_loss_history)) + plt.step(time_history, 1 - np.array(best_valid_loss_history), where="post") + plt.savefig("{}".format(plot_filename)) + def _search_parallel(self): if self._use_ray is not False: try: diff --git a/setup.py b/setup.py index 82ca7c5018..b798035c29 100644 --- a/setup.py +++ b/setup.py @@ -69,6 +69,7 @@ "rouge_score", "hcrystalball==0.1.10", "seqeval", + "matplotlib", "pytorch-forecasting>=0.9.0,<=0.10.1", "mlflow", "pyspark>=3.0.0", @@ -110,6 +111,7 @@ "hcrystalball==0.1.10", "pytorch-forecasting>=0.9.0", ], + "visualization": ["matplotlib"], "benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3"], }, classifiers=[ diff --git a/test/automl/test_visualization.py b/test/automl/test_visualization.py new file mode 100644 index 0000000000..97190951af --- /dev/null +++ b/test/automl/test_visualization.py @@ -0,0 +1,24 @@ +from flaml import AutoML +from flaml.data import load_openml_dataset + + +def test_fi_lc(): + X_train, X_test, y_train, y_test = load_openml_dataset( + dataset_id=1169, data_dir="./" + ) + settings = { + "time_budget": 10, # total running time in seconds + "metric": "accuracy", # can be: 'r2', 'rmse', 'mae', 'mse', 'accuracy', 'roc_auc', 'roc_auc_ovr', + # 'roc_auc_ovo', 'log_loss', 'mape', 'f1', 'ap', 'ndcg', 'micro_f1', 'macro_f1' + "task": "classification", # task type + "log_file_name": "airlines_experiment.log", # flaml log file + "seed": 7654321, # random seed + } + automl = AutoML(**settings) + automl.fit(X_train=X_train, y_train=y_train) + automl.visualize(type="feature_importance", plot_filename="feature_importance") + automl.visualize(type="learning_curve", plot_filename="learning_curve") + + +if __name__ == "__main__": + test_fi_lc()