diff --git a/requirements.txt b/requirements.txt index c5ba0ac..7adccb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,5 +13,4 @@ seaborn>=0.13.2 ipywidgets>=8.1.5 ipython>=8.28.0 python-dateutil>=2.9.0.post0 -tabulate>=0.9.0 scikit-learn>=1.6.1 diff --git a/setup.py b/setup.py index bf526f3..704b279 100644 --- a/setup.py +++ b/setup.py @@ -28,19 +28,33 @@ "PySide6>=6.6.0", "scipy>=1.12.0", "openpyxl>=3.1.0", - "pytest>=8.1.0", "PyYAML>=6.0.0", "prince>=0.15.0", - "dash>=2.18.2", - "plotly>=5.24.1", - "matplotlib>=3.9.0", - "seaborn>=0.13.2", - "ipywidgets>=8.1.5", - "ipython>=8.28.0", + "pydantic>=2.7.0", "python-dateutil>=2.9.0.post0", - "tabulate>=0.9.0", "scikit-learn>=1.6.1", ], + optional_dependencies={ + "ipython": [ + "ipywidgets>=8.1.5", + "ipython>=8.28.0", + "matplotlib>=3.9.0", + "seaborn>=0.13.2", + "plotly>=5.24.1", + ], + "dash": [ + "dash>=2.18.2", + "dash-bootstrap-components>=1.4.1", + "plotly>=5.24.1", + ], + "dev": [ + "pytest>=8.1.0", + "black>=24.9.1", + "flake8>=6.1.0", + "mypy>=1.3.0", + ], + }, + python_requires=">=3.10", entry_points={ "console_scripts": [ diff --git a/src/Untitled.ipynb b/src/Untitled.ipynb index 7bf834d..3b02cf9 100644 --- a/src/Untitled.ipynb +++ b/src/Untitled.ipynb @@ -1561,7 +1561,7 @@ "import importlib\n", "importlib.reload(ExcelLayout)\n", "\n", - "midrc_data = ExcelLayout.DataSource('MIDRC')\n", + "midrc_data = ExcelLayout.DataSourceConfig('MIDRC')\n", "print( midrc_data.sheets['Race'].columns.values() )\n", "print( midrc_data.sheets['Race'].df.columns )\n", "cols_to_use = midrc_data.sheets['Race'].df.columns.intersection(midrc_data.sheets['Race'].columns.values())\n", @@ -1571,7 +1571,7 @@ "midrc_race_data = np.asarray(midrc_data.sheets['Race'].df[cols_to_use].iloc[-1].values,dtype=float)\n", "print(midrc_race_data)\n", "\n", - "cdc_data = ExcelLayout.DataSource('CDC')\n", + "cdc_data = ExcelLayout.DataSourceConfig('CDC')\n", "#print( cdc_data.sheets['Race'].columns )\n", "cols_to_use = cdc_data.sheets['Race'].df.columns.intersection(cdc_data.sheets['Race'].columns.values())\n", "#Remove date column\n", @@ -1579,7 +1579,7 @@ "cdc_race_data = np.asarray(cdc_data.sheets['Race'].df[cols_to_use].iloc[-1].values,dtype=float)\n", "print(cdc_race_data)\n", "\n", - "census_data = ExcelLayout.DataSource('Census')\n", + "census_data = ExcelLayout.DataSourceConfig('Census')\n", "#print( census_data.sheets['Race'].columns.values() )\n", "#print( census_data.sheets['Race'].df.columns )\n", "cols_to_use = census_data.sheets['Race'].df.columns.intersection(census_data.sheets['Race'].columns.values())\n", @@ -1607,17 +1607,17 @@ "\n", "sheet_name = 'Race'\n", "\n", - "midrc_data = ExcelLayout.DataSource('MIDRC')\n", + "midrc_data = ExcelLayout.DataSourceConfig('MIDRC')\n", "cols_to_use = midrc_data.sheets[sheet_name].df.columns.intersection(midrc_data.sheets[sheet_name].columns.values())\n", "cols_to_use = cols_to_use[1:]\n", "midrc_sheet_data = np.asarray(midrc_data.sheets[sheet_name].df[cols_to_use].iloc[-1].values,dtype=float)\n", "\n", - "cdc_data = ExcelLayout.DataSource('CDC')\n", + "cdc_data = ExcelLayout.DataSourceConfig('CDC')\n", "cols_to_use = cdc_data.sheets[sheet_name].df.columns.intersection(cdc_data.sheets[sheet_name].columns.values())\n", "cols_to_use = cols_to_use[1:]\n", "cdc_sheet_data = np.asarray(cdc_data.sheets[sheet_name].df[cols_to_use].iloc[-1].values,dtype=float)\n", "\n", - "census_data = ExcelLayout.DataSource('Census')\n", + "census_data = ExcelLayout.DataSourceConfig('Census')\n", "cols_to_use = census_data.sheets[sheet_name].df.columns.intersection(census_data.sheets[sheet_name].columns.values())\n", "cols_to_use = cols_to_use[1:]\n", "census_sheet_data = np.asarray(census_data.sheets[sheet_name].df[cols_to_use].iloc[-1].values,dtype=float)\n", diff --git a/src/midrc_react/core/aggregate_jsd_calc.py b/src/midrc_react/core/aggregate_jsd_calc.py index e8c1a9c..1b0a0f2 100644 --- a/src/midrc_react/core/aggregate_jsd_calc.py +++ b/src/midrc_react/core/aggregate_jsd_calc.py @@ -70,7 +70,7 @@ def calc_jsd_by_features_combined(combined_df: pd.DataFrame, cols_to_use: list[s # Convert dataset columns to string in case they are integers pivot_table.columns = pivot_table.columns.astype(str) - labels = combined_df[dataset_column].unique().astype(str) + labels = sorted(combined_df[dataset_column].unique().astype(str)) # Create a dictionary to hold counts for each dataset counts_dict = {dataset: pivot_table[dataset].values if dataset in pivot_table else np.zeros(len(pivot_table)) for diff --git a/src/midrc_react/core/excel_layout.py b/src/midrc_react/core/excel_layout.py index 69ac805..c5446f3 100644 --- a/src/midrc_react/core/excel_layout.py +++ b/src/midrc_react/core/excel_layout.py @@ -41,20 +41,20 @@ def __init__(self, data_source, custom_age_ranges=None): data_source (dict): The data source configuration. custom_age_ranges (dict, optional): A dictionary of custom age ranges. """ - self.name = data_source['name'] + self.name = data_source.name self.sheets = {} - self.datatype = data_source['data type'] - self.filename = data_source['filename'] + self.datatype = data_source.data_type + self.filename = data_source.filename self.data_source = data_source self.custom_age_ranges = custom_age_ranges - self._numeric_cols = data_source.get('numeric_cols', {}) # Extract numeric columns from config - self._columns = data_source.get('columns', []) + self._numeric_cols = data_source.numeric_cols # Extract numeric columns from config + self._columns = data_source.columns self.raw_data = None # Load preprocessing plugin if specified self.preprocessor = None - if 'plugin' in data_source and data_source['plugin']: - plugin_name = data_source['plugin'] + if data_source.plugin: + plugin_name = data_source.plugin plugin_path = os.path.join("plugins", f"{plugin_name}.py") self.preprocessor = DataSource.load_plugin(plugin_path) @@ -64,8 +64,8 @@ def __init__(self, data_source, custom_age_ranges=None): self.build_data_frames_from_csv(self.filename) else: self.build_data_frames_from_file(self.filename) - if self.datatype == 'content' and 'content' in data_source: - self.build_data_frames_from_content(data_source['content']) + if self.datatype == 'content' and hasattr(data_source, 'content') and data_source.content is not None: + self.build_data_frames_from_content(data_source.content) def raw_columns_to_use(self): """ @@ -126,9 +126,9 @@ def apply_numeric_column_adjustments(self, df: pd.DataFrame): pd.DataFrame: The DataFrame with numeric column adjustments. """ for str_col, col_dict in self._numeric_cols.items(): - num_col = col_dict['raw column'] if 'raw column' in col_dict else str_col - bins = col_dict['bins'] if 'bins' in col_dict else None - labels = col_dict['labels'] if 'labels' in col_dict else None + num_col = col_dict.raw_column if hasattr(col_dict, 'raw_column') else str_col + bins = col_dict.bins if hasattr(col_dict, 'bins') else None + labels = col_dict.labels if hasattr(col_dict, 'labels') else None if num_col in df.columns: df = bin_dataframe_column(df, num_col, str_col, bins=bins, labels=labels) @@ -139,7 +139,6 @@ def apply_numeric_column_adjustments(self, df: pd.DataFrame): # else: # # Default "N-N" format conversion # df[str_col] = df[num_col].apply(lambda x: f'{int(x)}-{int(x)}' if pd.notna(x) else x) - return df def build_data_frames_from_csv(self, filename: str): @@ -226,7 +225,7 @@ def create_sheets_from_df(self, df: pd.DataFrame): if col in df.columns: df_cumsum = self.calculate_cumulative_sums(df, col) if col in self._numeric_cols: - labels = self._numeric_cols[col].get('labels', None) + labels = self._numeric_cols[col].labels if hasattr(self._numeric_cols[col], 'labels') else None if labels: # The first column (e.g., date) remains at index 0. date_column = df_cumsum.columns[0] @@ -333,25 +332,28 @@ def _process_date_column(self, data_source: dict): """ # This assumes that the first column is either the date column or does not have useful data - if data_source.get('date'): + date_value = getattr(data_source, 'date', None) + if date_value: self._df.drop(self._df.columns[0], axis=1, inplace=True) - self._df.insert(0, 'date', data_source['date'], False) + self._df.insert(0, 'date', date_value, False) self._df['date'] = pd.to_datetime(self._df['date'], errors='coerce') self._columns['date'] = self._df.columns[0] - def _process_columns(self, data_source: dict): + def _process_columns(self, data_source): """ Process and rename columns according to the data source settings. Args: - data_source (dict): The data source object. + data_source (DataSource): The data source object. """ for col in self._df.columns[1:]: col_name = col - if 'remove column name text' in data_source: - for txt in data_source['remove column name text']: + # Access remove_column_name_text from pydantic model + remove_text = getattr(data_source, 'remove_column_name_text', None) + if remove_text: + for txt in remove_text: col_name = col.split(txt)[0] col_name = col_name.rstrip() self._columns[col_name] = col diff --git a/src/midrc_react/core/famd_calc.py b/src/midrc_react/core/famd_calc.py index 5889857..153a25e 100644 --- a/src/midrc_react/core/famd_calc.py +++ b/src/midrc_react/core/famd_calc.py @@ -22,7 +22,6 @@ import numpy as np import pandas as pd import prince -from tabulate import tabulate from midrc_react.core.data_preprocessing import combine_datasets_from_list from midrc_react.core.numeric_distances import calc_distances_via_df, scale_feature @@ -132,7 +131,11 @@ def calc_famd_df(raw_df, cols_to_use, numeric_cols, dataset_column='_dataset_', if len(outlier_df) > 0: outlier_df = outlier_df.sort_values(by=famd_column, ascending=False) print(f"Outliers in FAMD fitting: {outlier_df.shape[0]}") - print(tabulate(outlier_df, headers='keys', tablefmt='psql')) + try: + from tabulate import tabulate + print(tabulate(outlier_df, headers='keys', tablefmt='psql')) + except ImportError: + print(outlier_df) return c_df diff --git a/src/midrc_react/core/jsdconfig.py b/src/midrc_react/core/jsdconfig.py index a28fe49..5dba159 100644 --- a/src/midrc_react/core/jsdconfig.py +++ b/src/midrc_react/core/jsdconfig.py @@ -17,9 +17,11 @@ This module contains the JSDConfig class, which loads and stores data from a YAML file. """ -from dataclasses import dataclass, field import os +from typing import List, Optional, Dict, Union, Any +from pydantic import BaseModel, Field, ValidationError +from pydantic.dataclasses import dataclass from yaml import load try: from yaml import CLoader as Loader @@ -27,7 +29,52 @@ from yaml import Loader -@dataclass +class NumericColumnConfig(BaseModel): + """ + NumericColumnConfig model to represent numeric column configurations in the YAML configuration. + """ + raw_column: str = Field(..., alias='raw column') + bins: List[float] + labels: Optional[List[str]] = None + adjust_outliers: bool = Field(False, alias='adjust outliers') + +class DataSourceConfig(BaseModel): + """ + DataSource model to represent individual data sources in the YAML configuration. + """ + name: str + description: Optional[str] = None + data_type: str = Field(..., alias='data type') + filename: str + columns: Optional[List[str]] = None + numeric_cols: Optional[Dict[str, NumericColumnConfig]] = None + plugin: Optional[str] = None + date: Optional[str] = None + remove_column_name_text: Optional[List[str]] = Field(None, alias='remove column name text') + + content: Optional[Any] = None # Placeholder for loaded content + content_type: Optional[str] = None # Placeholder for content type after loading + + class Config: + validate_by_name = True + extra = 'allow' + +DataSourceConfigList = List[DataSourceConfig] + +class ConfigData(BaseModel): + """ + ConfigData model to represent the structure of the YAML configuration data. + """ + # Define fields based on expected YAML structure + data_sources: DataSourceConfigList = Field(..., alias='data sources') + custom_age_ranges: Optional[Dict[str, List[Union[int, float]]]] = Field(None, alias='custom_age_range') + + class Config: + validate_by_name = True + # accept extra fields in the YAML + extra = 'allow' + + class JSDConfig: """ The JSDConfig class loads and stores data from a YAML file. @@ -38,13 +85,16 @@ class JSDConfig: Methods: __init__(self, filename='jsdconfig.yaml'): Initializes a new instance of JSDConfig. - __post_init__(self): Loads the YAML data from the current filename. + _load_data(self): Loads the YAML data from the current filename. + set_filename(self, new_filename): Sets a new filename and reloads the data. """ - filename: str = 'jsdconfig.yaml' - data: dict = field(init=False) + filename: str + data: Optional[ConfigData] - def __post_init__(self): + def __init__(self, filename: str = 'jsdconfig.yaml'): """Load the YAML data from the current filename.""" + self.filename = filename + self.data = None # os.chdir(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) self._load_data() @@ -53,11 +103,16 @@ def _load_data(self): if not os.path.exists(self.filename): print(f"File {self.filename} does not exist. Skipping load.") print(f"Current working directory: {os.getcwd()}") - self.data = {} + self.data = None return with open(self.filename, 'r', encoding='utf-8') as stream: - self.data = load(stream, Loader=Loader) + raw = load(stream, Loader=Loader) + try: + self.data = ConfigData(**raw) + except ValidationError as e: + self.data = None + raise # print(dump(self.data)) def set_filename(self, new_filename: str): diff --git a/src/midrc_react/core/jsdcontroller.py b/src/midrc_react/core/jsdcontroller.py index 84c0998..5772cd3 100644 --- a/src/midrc_react/core/jsdcontroller.py +++ b/src/midrc_react/core/jsdcontroller.py @@ -179,14 +179,14 @@ def file_changed(self, _, new_category_index=None): return category_info = dataselectiongroupbox.get_category_info() - category_index = category_info['current_index'] + category_index = category_info.current_index if new_category_index is not None: category_index = new_category_index # Compute the intersection of category keys across all file_infos. category_set = None for cbox in file_infos: - ds = self.jsd_model.data_sources[cbox['source_id']] + ds = self.jsd_model.data_sources[cbox.source_id] if category_set is None: category_set = set(ds.sheets.keys()) else: @@ -196,7 +196,7 @@ def file_changed(self, _, new_category_index=None): # Compute has_raw_data using all() with a generator expression. has_raw_data = all( - self.jsd_model.data_sources[cbox['source_id']].raw_data is not None + self.jsd_model.data_sources[cbox.source_id].raw_data is not None for cbox in file_infos ) @@ -206,7 +206,7 @@ def file_changed(self, _, new_category_index=None): category_list.append('Aggregate') category_list.append('FAMD') for category in category_list: - if category in self.jsd_model.data_sources[file_infos[0]['source_id']].numeric_cols: + if category in self.jsd_model.data_sources[file_infos[0].source_id].numeric_cols: category_list.append(f'{category} (ks2)') dataselectiongroupbox.update_category_list(category_list, category_index) @@ -224,7 +224,7 @@ def get_file_sheets_from_index(self, index=0): dict: A dictionary containing the sheets from the selected file. """ try: - current_data = self.jsd_view.dataselectiongroupbox.get_file_infos()[index]['source_id'] + current_data = self.jsd_view.dataselectiongroupbox.get_file_infos()[index].source_id except IndexError as exc: raise IndexError("Index out of range") from exc @@ -252,7 +252,7 @@ def category_changed(self): """ dataselectiongroupbox = self.jsd_view.dataselectiongroupbox file_infos = dataselectiongroupbox.get_file_infos() - category = dataselectiongroupbox.get_category_info()['current_text'] + category = dataselectiongroupbox.get_category_info().current_text # Try to avoid a race condition where the category is changed before the file is changed if not category: @@ -267,8 +267,8 @@ def build_date_list(df1, df2): column_infos = [] for (i, cbox1), (j, cbox2) in itertools.combinations(enumerate(file_infos), 2): - file1 = cbox1['source_id'] - file2 = cbox2['source_id'] + file1 = cbox1.source_id + file2 = cbox2.source_id data_source_1 = self.jsd_model.data_sources[file1] data_source_2 = self.jsd_model.data_sources[file2] @@ -305,7 +305,7 @@ def build_date_list(df1, df2): numeric_cols = [] for str_col, col_info in data_source_1.numeric_cols.items(): cols_to_use.remove(str_col) - num_col = col_info['raw column'] + num_col = col_info.raw_column cols_to_use.append(num_col) numeric_cols.append(num_col) input_data = calc_famd_ks2_at_dates( @@ -321,7 +321,7 @@ def build_date_list(df1, df2): combined_df = combine_datasets_from_list([raw_df1, raw_df2]) date_list = build_date_list(raw_df1, raw_df2) str_col = category[:-6] - num_col = data_source_1.numeric_cols[str_col]['raw column'] + num_col = data_source_1.numeric_cols[str_col].raw_column input_data = [calc_ks2_samp_by_feature(combined_df[combined_df['date'] <= date], num_col)['Dataset 0 vs Dataset 1'] for date in date_list] @@ -394,7 +394,7 @@ def update_file_based_charts(self): sheet_dict = {} file_infos = self.jsd_view.dataselectiongroupbox.get_file_infos() for i, file_info in enumerate(file_infos): - if file_info['checked']: + if file_info.checked: sheet_dict[i] = self.get_file_sheets_from_index(i) spider_plot_values = self.get_spider_plot_values(spider_plot_date) @@ -422,7 +422,7 @@ def update_category_plots(self): sheet_dict = {} file_infos = self.jsd_view.dataselectiongroupbox.get_file_infos() for i, file_info in enumerate(file_infos): - if file_info['checked']: + if file_info.checked: sheet_dict[i] = self.get_file_sheets_from_index(i) try: @@ -448,7 +448,7 @@ def get_cols_to_use_for_jsd_calc(self, source_id, category): """ cols_to_use = self.jsd_model.data_sources[source_id].sheets[category].data_columns - custom_age_ranges = self._config.data.get('custom age ranges', None) + custom_age_ranges = self._config.data.custom_age_ranges if custom_age_ranges and category in custom_age_ranges: cols_to_use = [f'{age_range[0]}-{age_range[1]} Custom' for age_range in custom_age_ranges[category]] + [JSDController.NOT_REPORTED_COLUMN_NAME] @@ -469,11 +469,11 @@ def get_spider_plot_values(self, calc_date=None): calc_date = np.datetime64('today') dataselectiongroupbox = self.jsd_view.dataselectiongroupbox - categories = dataselectiongroupbox.get_category_info()['category_list'] + categories = dataselectiongroupbox.get_category_info().category_list # Determine indexes to use based on checked boxes or default to all file_infos = dataselectiongroupbox.get_file_infos() - indexes_to_use = [i for i, file_info in enumerate(file_infos) if file_info['checked']] + indexes_to_use = [i for i, file_info in enumerate(file_infos) if file_info.checked] if not indexes_to_use: indexes_to_use = list(range(len(file_infos))) if len(indexes_to_use) == 1: @@ -490,8 +490,8 @@ def get_spider_plot_values(self, calc_date=None): index2_candidates = [index2] for idx2 in index2_candidates: - source_id1 = file_infos[index1]['source_id'] - source_id2 = file_infos[idx2]['source_id'] + source_id1 = file_infos[index1].source_id + source_id2 = file_infos[idx2].source_id data_source_1 = self.jsd_model.data_sources[source_id1] data_source_2 = self.jsd_model.data_sources[source_id2] @@ -525,7 +525,7 @@ def get_spider_plot_values(self, calc_date=None): numeric_cols = [] for str_col, col_info in data_source_1.numeric_cols.items(): cols_to_use.remove(str_col) - num_col = col_info['raw column'] + num_col = col_info.raw_column cols_to_use.append(num_col) numeric_cols.append(num_col) jsd_dict[(index1, idx2)][category] = calc_famd_ks2_at_date( @@ -540,7 +540,7 @@ def get_spider_plot_values(self, calc_date=None): raw_df2 = data_source_2.raw_data combined_df = combine_datasets_from_list([raw_df1, raw_df2]) str_col = category[:-6] - num_col = data_source_1.numeric_cols[str_col]['raw column'] + num_col = data_source_1.numeric_cols[str_col].raw_column jsd_dict[(index1, idx2)][category] = calc_ks2_samp_by_feature( combined_df[combined_df['date'] <= calc_date], num_col, diff --git a/src/midrc_react/core/jsdmodel.py b/src/midrc_react/core/jsdmodel.py index 6a046a5..3bbefc7 100644 --- a/src/midrc_react/core/jsdmodel.py +++ b/src/midrc_react/core/jsdmodel.py @@ -24,6 +24,7 @@ from PySide6.QtGui import QColor from midrc_react.core.excel_layout import DataSource +from midrc_react.core.jsdconfig import DataSourceConfigList def convert_to_builtin(val): @@ -51,18 +52,18 @@ class JSDTableModel(QAbstractTableModel): ] data_source_added = Signal() - def __init__(self, data_source_list=None, custom_age_ranges=None): + def __init__(self, data_source_list: DataSourceConfigList=None, custom_age_ranges=None): """ Initialize the JSDTableModel. This method initializes the JSDTableModel by setting up the input data, mapping, and raw data sources. Args: - data_source_list (List[dict], optional): A list of data sources. Each data source is a dictionary with the\ - following keys: + data_source_list (List[DataSource], optional): A list of data sources. Each data source has the following + attributes (plus others): - 'name' (str): The name of the data source. - - 'data type' (str): The type of the data source. + - 'data_type' (str): The type of the data source. - 'filename' (str): The filename of the data source. custom_age_ranges (Any, optional): Custom age ranges for the data sources. @@ -79,9 +80,12 @@ def __init__(self, data_source_list=None, custom_age_ranges=None): self.max_row_count = 0 self.data_sources = {} + if data_source_list is not None: for data_source_dict in data_source_list: self.add_data_source(data_source_dict) + else: + print("No data sources provided to JSDTableModel") def add_data_source(self, data_source_dict): """ @@ -104,7 +108,7 @@ def add_data_source(self, data_source_dict): Returns: None """ - self.data_sources[data_source_dict['name']] = DataSource(data_source_dict, self.custom_age_ranges) + self.data_sources[data_source_dict.name] = DataSource(data_source_dict, self.custom_age_ranges) self.data_source_added.emit() def rowCount(self, parent: QModelIndex = None) -> int: diff --git a/src/midrc_react/core/numeric_distances.py b/src/midrc_react/core/numeric_distances.py index 2edaf04..83194ab 100644 --- a/src/midrc_react/core/numeric_distances.py +++ b/src/midrc_react/core/numeric_distances.py @@ -55,7 +55,7 @@ def calc_numerical_metric_by_feature(df, feature: str, dataset_column: str, metr Returns: dict: A dictionary containing metric results for each dataset combination. """ - dataset_names = df[dataset_column].unique() + dataset_names = sorted(df[dataset_column].unique()) metric_dict = {} # Compare each dataset combination @@ -330,10 +330,10 @@ def calc_distances_via_df(famd_df: pd.DataFrame, feature_column: str, dataset_co # Mapping of distance metrics to their respective functions distance_metric_functions = { 'jsd': {'func': lambda scaling=None: calc_jsd_from_counts_dict( - build_histogram_dict(famd_df, dataset_column, famd_df[dataset_column].unique(), + build_histogram_dict(famd_df, dataset_column, sorted(famd_df[dataset_column].unique()), feature_column, bin_width=jsd_scaled_bin_width, scaling_method=scaling), - famd_df[dataset_column].unique()), + sorted(famd_df[dataset_column].unique())), 'scaling': True}, 'wass': {'func': lambda scaling=None: calc_wasserstein_by_feature(famd_df, feature_column, dataset_column, scaling=scaling), diff --git a/src/midrc_react/gui/common/file_upload.py b/src/midrc_react/gui/common/file_upload.py index c69b671..af7126e 100644 --- a/src/midrc_react/gui/common/file_upload.py +++ b/src/midrc_react/gui/common/file_upload.py @@ -26,9 +26,9 @@ def process_file_upload(view, data_source_dict): Args: view: An instance that implements open_excel_file and holds a data_selection_group_box attribute. - data_source_dict (dict): Contains the uploaded file information. + data_source_dict (DataSource): Contains the uploaded file information. """ - print(f"handle_excel_file_uploaded() triggered with file: {data_source_dict['name']}") + print(f"handle_excel_file_uploaded() triggered with file: {data_source_dict.name}") view.open_excel_file(data_source_dict) print("Excel file loaded, try to update layout") if hasattr(view, 'data_selection_group_box'): diff --git a/src/midrc_react/gui/common/jsdview_base.py b/src/midrc_react/gui/common/jsdview_base.py index 79dcc3d..a3154f2 100644 --- a/src/midrc_react/gui/common/jsdview_base.py +++ b/src/midrc_react/gui/common/jsdview_base.py @@ -17,41 +17,40 @@ This module contains the JsdViewBase class, which serves as a base class for JSD views. """ -from dataclasses import dataclass +from typing import List, Optional, Union +from pydantic import BaseModel, Field from PySide6.QtCore import QObject, Signal from PySide6.QtWidgets import QMainWindow -@dataclass -class GroupBoxData: +class FileInfo(BaseModel): + description: Optional[str] = None + source_id: Optional[str] = None + index: Optional[int] = None + checked: bool = True + +FileInfoList = List[FileInfo] + +class CategoryInfo(BaseModel): + current_text: Optional[str] = None + current_index: Optional[int] = None + category_list: List[str] = Field(default_factory=list) + +class GroupBoxData(BaseModel): """ This class represents a group box widget for data selection. It provides functionality for creating labels and combo boxes for data files and a category combo box. The class has methods for setting up the layout, updating the category combo box, and initializing the widget. Attributes: - _file_infos (list): A list of file information dictionaries. - _category_info (dict): A dictionary containing information about the selected category. + file_infos (list): A list of file information dictionaries. + category_info (dict): A dictionary containing information about the selected category. """ - _file_infos = [] - _category_info = { - 'current_text': None, - 'current_index': None, - 'category_list': [], - } - - @property - def file_infos(self): - """ - Get the file information dictionaries. - - Returns: - list: A list of file information dictionaries. - """ - return self._file_infos + file_infos: FileInfoList = Field(default_factory=list) + category_info: CategoryInfo = Field(default_factory=CategoryInfo) - def get_file_infos(self): + def get_file_infos(self) -> FileInfoList: """ Get the file information dictionaries. @@ -60,37 +59,32 @@ def get_file_infos(self): """ return self._file_infos - def append_file_info(self, file_info: dict): + def append_file_info(self, file_info: Union[dict, FileInfo]): """ Appends a file information dictionary to the list of file information dictionaries. Args: file_info (dict): file information dictionary to append to the list """ - file_info['checked'] = file_info.get('checked', True) - self._file_infos.append(file_info) + if isinstance(file_info, dict): + file_info.setdefault('checked', True) + file_info = FileInfo(**file_info) + elif isinstance(file_info, FileInfo): + # ensure checked is set + file_info.checked = bool(file_info.checked) + self.file_infos.append(file_info) # TODO: Update the category list too? - @property - def category_info(self): - """ - Get the category information dictionary. - - Returns: - dict: A dictionary containing information about the selected category. - """ - return self._category_info - - def get_category_info(self): + def get_category_info(self) -> CategoryInfo: """ Get the category information dictionary. Returns: dict: A dictionary containing information about the selected category. """ - return self._category_info + return self.category_info - def update_category_list(self, categorylist, categoryindex): + def update_category_list(self, categorylist: List[str], categoryindex: int): """ Updates the category information dictionary with the given category list and index. @@ -101,13 +95,13 @@ def update_category_list(self, categorylist, categoryindex): Returns: None """ - self._category_info = { - 'current_text': categorylist[categoryindex], - 'current_index': categoryindex, - 'category_list': categorylist, - } + self.category_info = CategoryInfo( + current_text = categorylist[categoryindex], + current_index = categoryindex, + category_list = categorylist, + ) - def update_category_index(self, categoryindex): + def update_category_index(self, categoryindex: int): """ Updates the category information dictionary with the given category index. @@ -117,10 +111,13 @@ def update_category_index(self, categoryindex): Returns: None """ - self._category_info['current_index'] = categoryindex - self._category_info['current_text'] = self._category_info['category_list'][categoryindex] + self.category_info.current_index = categoryindex + if 0 <= categoryindex < len(self.category_info.category_list): + self.category_info.current_text = self.category_info.category_list[categoryindex] + else: + self.category_info.current_text = None - def update_category_text(self, categorytext): + def update_category_text(self, categorytext: str): """ Updates the category information dictionary with the given category text. @@ -130,10 +127,9 @@ def update_category_text(self, categorytext): Returns: None """ - category_list = self._category_info['category_list'] - if categorytext in category_list: - self._category_info['current_text'] = categorytext - self._category_info['current_index'] = category_list.index(categorytext) + if categorytext in self.category_info.category_list: + self.category_info.current_text = categorytext + self.category_info.current_index = self.category_info.category_list.index(categorytext) class JsdViewBase(QObject): @@ -176,14 +172,14 @@ def open_excel_file(self, data_source_dict): Opens an Excel file and adds it to the data selection group box. Args: - data_source_dict (dict): The data source dictionary. - """ - self._dataselectiongroupbox.append_file_info({ - 'description': data_source_dict['description'], - 'source_id': data_source_dict['name'], - 'index': len(self._dataselectiongroupbox.file_infos), - 'checked': True, - }) + data_source_dict (DataSource): The data source dictionary. + """ + self._dataselectiongroupbox.append_file_info(FileInfo( + description = data_source_dict.description, + source_id = data_source_dict.name, + index = len(self._dataselectiongroupbox.file_infos), + checked = True, + )) def update_pie_chart_dock(self, sheet_dict): """ diff --git a/src/midrc_react/gui/common/utils.py b/src/midrc_react/gui/common/utils.py index 1227af0..040d337 100644 --- a/src/midrc_react/gui/common/utils.py +++ b/src/midrc_react/gui/common/utils.py @@ -17,24 +17,27 @@ This module contains utility functions for file handling and data processing. """ +from jsdview_base import FileInfo +from midrc_react.core.jsdconfig import DataSourceConfig + def create_file_info(data_source, index): """ Create a file info dictionary from a data source. Args: - data_source (dict): Dictionary containing file info. + data_source (DataSourceConfig): Dictionary containing file info. index (int): The index to assign. Returns: dict: A dictionary with description, source_id, index and checked flag. """ - return { - 'description': data_source.get('description'), - 'source_id': data_source.get('name'), - 'index': index, - 'checked': True, - } + return FileInfo( + description = data_source.description, + source_id = data_source.name, + index = index, + checked = True, + ) def get_common_categories(file_infos, jsd_model): @@ -53,11 +56,11 @@ def get_common_categories(file_infos, jsd_model): # Get the initial set of categories from the first file info. cbox0 = file_infos[0] - common_categories = list(jsd_model.data_sources[cbox0['source_id']].sheets.keys()) + common_categories = list(jsd_model.data_sources[cbox0.source_id].sheets.keys()) # Intersect with the categories from subsequent file infos. for cbox in file_infos[1:]: - categorylist = jsd_model.data_sources[cbox['source_id']].sheets.keys() + categorylist = jsd_model.data_sources[cbox.source_id].sheets.keys() common_categories = [value for value in common_categories if value in categorylist] return common_categories @@ -76,10 +79,10 @@ def create_data_source_dict(filename, file_content, data_type='content', content Returns: dict: A dictionary with file details. """ - return { - 'description': filename, - 'name': filename, - 'content': file_content, - 'data type': data_type, - 'content type': content_type, - } + return DataSourceConfig( + description = filename, + name = filename, + content = file_content, + data_type = data_type, + content_type = content_type, + ) diff --git a/src/midrc_react/gui/dash/dataselectiongroupbox.py b/src/midrc_react/gui/dash/dataselectiongroupbox.py index e81769d..9ceff7f 100644 --- a/src/midrc_react/gui/dash/dataselectiongroupbox.py +++ b/src/midrc_react/gui/dash/dataselectiongroupbox.py @@ -277,7 +277,7 @@ def update_category_combobox(self): """ Update the category combobox based on the selected data sources. """ - previous_value = self.get_category_info()['current_text'] + previous_value = self.get_category_info().current_text file_infos = self.get_file_infos() diff --git a/src/midrc_react/gui/dash/jsdview_dash.py b/src/midrc_react/gui/dash/jsdview_dash.py index 1b1be0d..50470b5 100644 --- a/src/midrc_react/gui/dash/jsdview_dash.py +++ b/src/midrc_react/gui/dash/jsdview_dash.py @@ -206,13 +206,13 @@ def run(self): if __name__ == '__main__': # Example usage: my_config = JSDConfig() - my_data_source_list = my_config.data['data sources'] - my_jsd_model = JSDTableModel(my_data_source_list, my_config.data.get('custom age ranges', None)) + my_data_source_list = my_config.data.data_sources + my_jsd_model = JSDTableModel(my_data_source_list, my_config.data.custom_age_ranges) dash_view = JSDViewDash(my_jsd_model, my_config) # Load data sources for my_data_source in my_data_source_list: - print(f"Loading: {my_data_source['description']}...") + print(f"Loading: {my_data_source.description}...") dash_view.open_excel_file(my_data_source) print("Done Loading Files") diff --git a/src/midrc_react/gui/ipython/dataselectiongroupbox.py b/src/midrc_react/gui/ipython/dataselectiongroupbox.py index 3b52163..25f3cc6 100644 --- a/src/midrc_react/gui/ipython/dataselectiongroupbox.py +++ b/src/midrc_react/gui/ipython/dataselectiongroupbox.py @@ -191,7 +191,7 @@ def update_category_combobox(self): """ Update the category combobox based on the selected data sources. """ - previous_value = self.get_category_info()['current_text'] + previous_value = self.get_category_info().current_text file_infos = self.get_file_infos() diff --git a/src/midrc_react/gui/ipython/jsdview_ipython.py b/src/midrc_react/gui/ipython/jsdview_ipython.py index f1827ad..2538d3e 100644 --- a/src/midrc_react/gui/ipython/jsdview_ipython.py +++ b/src/midrc_react/gui/ipython/jsdview_ipython.py @@ -67,7 +67,7 @@ def open_excel_file(self, data_source_dict): Open an Excel file and add it as a data source. Args: - data_source_dict (dict): The data source information used for loading the data. + data_source_dict (DataSource): The data source information used for loading the data. """ super().open_excel_file(data_source_dict) self.add_data_source.emit(data_source_dict) @@ -297,7 +297,7 @@ def update_area_chart(self, category): if self.plot_method == 'interactive_plotly': return self.update_area_chart_interactive_plotly() - category = self.dataselectiongroupbox.get_category_info()['current_text'] + category = self.dataselectiongroupbox.get_category_info().current_text # Set up the figure with multiple subplots _fig, axes = plt.subplots(len(category), 1, figsize=(10, 6 * len(category)), sharex=True) @@ -335,7 +335,7 @@ def update_area_chart(self, category): # Final plot settings for each subplot ax.set_xlabel('Date') ax.set_ylabel(f'{category} Distribution Over Time') - source_id = self.dataselectiongroupbox.file_infos[index]['source_id'] + source_id = self.dataselectiongroupbox.file_infos[index].source_id ax.set_title(f"{source_id} {category} Distribution Over Time") ax.grid(True) ax.legend() @@ -352,7 +352,7 @@ def update_area_chart_interactive_plotly(self): """ Update the area chart using interactive plotting with Plotly. """ - category = self.dataselectiongroupbox.get_category_info()['current_text'] + category = self.dataselectiongroupbox.get_category_info().current_text # Find the global minimum and maximum date global_min_date = min(sheets[category].df['date'].min() for sheets in category.values()) @@ -376,7 +376,7 @@ def update_area_chart_interactive_plotly(self): # Update the layout for each individual figure fig.update_layout( - title=f"{self.dataselectiongroupbox.file_infos[index]['source_id']} {category} Distribution Over Time", + title=f"{self.dataselectiongroupbox.file_infos[index].source_id} {category} Distribution Over Time", xaxis_title="Date", yaxis_title="Percentage (%)", height=400, diff --git a/src/midrc_react/gui/pyside6/dataselectiongroupbox.py b/src/midrc_react/gui/pyside6/dataselectiongroupbox.py index 69f494b..011027c 100644 --- a/src/midrc_react/gui/pyside6/dataselectiongroupbox.py +++ b/src/midrc_react/gui/pyside6/dataselectiongroupbox.py @@ -21,10 +21,11 @@ from PySide6.QtCore import QSignalBlocker, Signal from PySide6.QtWidgets import QCheckBox, QComboBox, QFormLayout, QGroupBox, QHBoxLayout, QLabel -from midrc_react.gui.common.jsdview_base import GroupBoxData +from midrc_react.gui.common.jsdview_base import GroupBoxData, FileInfo +from midrc_react.core.jsdconfig import DataSourceConfigList -class JsdDataSelectionGroupBox(QGroupBox, GroupBoxData): +class JsdDataSelectionGroupBox(QGroupBox): """ This class represents a group box widget for data selection. It provides functionality for creating labels and combo boxes for data files and a category combo box. The class has methods for setting up the layout, @@ -49,10 +50,12 @@ def __init__(self, data_sources): category combo box. Parameters: - data_sources (list): A list of data sources. + data_sources (DataSourceList): A list of data sources. """ super().__init__() + self.data = GroupBoxData() + self.setTitle('Data Selection') self.form_layout = QFormLayout() @@ -62,7 +65,7 @@ def __init__(self, data_sources): self.category_combobox = QComboBox() self.set_layout(data_sources) - def set_layout(self, data_sources): + def set_layout(self, data_sources: DataSourceConfigList): """ Set the layout for the given data sources. @@ -85,7 +88,7 @@ def set_layout(self, data_sources): self.add_file_combobox_to_layout(auto_populate=False) # Add the file comboboxes and labels to the form layout - items = [(d['description'], d['name']) for d in data_sources] + items = [(d.description, d.name) for d in data_sources] for combobox_item in items: self.add_file_to_comboboxes(combobox_item[0], combobox_item[1]) self.file_comboboxes[0].setCurrentIndex(0) @@ -99,17 +102,17 @@ def get_file_infos(self): Get the file information for all files. Returns: - List[dict]: A list of dictionaries containing information about each file. + FileInfoList: A list of dictionaries containing information about each file. """ - self._file_infos = [] + self.data.file_infos: FileInfoList = [] for i, cbox in enumerate(self.file_comboboxes): - self._file_infos.append({ - 'description': cbox.currentText(), - 'source_id': cbox.currentData(), - 'index': i, - 'checked': self.file_checkboxes[i].isChecked(), - }) - return self._file_infos + self.data.append_file_info(FileInfo( + description = cbox.currentText(), + source_id = cbox.currentData(), + index = i, + checked = self.file_checkboxes[i].isChecked(), + )) + return self.data.file_infos def get_category_info(self): """ @@ -118,12 +121,9 @@ def get_category_info(self): Returns: dict: A dictionary containing the category information. """ - self._category_info = { - 'current_text': self.category_combobox.currentText(), - 'current_index': self.category_combobox.currentIndex(), - 'category_list': [self.category_combobox.itemText(i) for i in range(self.category_combobox.count())], - } - return self._category_info + category_list = [self.category_combobox.itemText(i) for i in range(self.category_combobox.count())] + category_index = self.category_combobox.currentIndex() + return self.data.category_info def add_file_combobox_to_layout(self, auto_populate: bool = True): """ @@ -233,3 +233,4 @@ def update_category_list(self, categorylist, categoryindex): self.category_combobox.clear() self.category_combobox.addItems(categorylist) self.category_combobox.setCurrentIndex(categoryindex) + self.data.update_category_list(categorylist, categoryindex) diff --git a/src/midrc_react/gui/pyside6/grabbablewidget.py b/src/midrc_react/gui/pyside6/grabbablewidget.py index 99a82d1..da93022 100644 --- a/src/midrc_react/gui/pyside6/grabbablewidget.py +++ b/src/midrc_react/gui/pyside6/grabbablewidget.py @@ -381,3 +381,26 @@ def save_chart_to_disk(self): """ self.grabbable_mixin.save_to_disk() + @property + def copyable_data(self) -> str: + """ + Get the copyable data for the chart view. + + Returns: + str: The data to be copied to the clipboard when requested. + """ + return self.grabbable_mixin.copyable_data + + @copyable_data.setter + def copyable_data(self, data: str): + """ + Set the copyable data for the chart view. + + Args: + data (str): The data to be copied to the clipboard when requested. + + Returns: + None + """ + self.grabbable_mixin.copyable_data = data + diff --git a/src/midrc_react/gui/pyside6/jsdview.py b/src/midrc_react/gui/pyside6/jsdview.py index cf4d36b..f59e7aa 100644 --- a/src/midrc_react/gui/pyside6/jsdview.py +++ b/src/midrc_react/gui/pyside6/jsdview.py @@ -37,6 +37,7 @@ ) from midrc_react.core.datetimetools import convert_date_to_milliseconds, numpy_datetime64_to_qdate +from midrc_react.core.jsdconfig import DataSourceConfigList from midrc_react.gui.common.jsdview_base import JsdViewBase from midrc_react.gui.pyside6.copyabletableview import CopyableTableView from midrc_react.gui.pyside6.dataselectiongroupbox import JsdDataSelectionGroupBox @@ -69,12 +70,12 @@ class JsdWindow(QMainWindow, JsdViewBase): "#fb9a99", "#e31a1c", "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a"] - def __init__(self, data_sources: Any) -> None: + def __init__(self, data_sources: DataSourceConfigList) -> None: """ Initialize the JsdWindow with provided data sources and set up the GUI. Args: - data_sources (Any): Data sources used to initialize the data selection group box. + data_sources (DataSourceList): Data sources used to initialize the data selection group box. """ super().__init__() # Set up the data selection group box @@ -418,7 +419,7 @@ def _set_spider_chart_copyable_data(self, spider_plot_values_dict: Dict[Any, Dic file2 = self._dataselectiongroupbox.file_comboboxes[series_key[1]].currentText() values = "\t".join(str(series[label]) for label in headers) formatted_text += f"{file1}\t{file2}\t{values}\n" - self.spider_chart_view.grabbable_mixin.copyable_data = formatted_text + self.spider_chart_view.copyable_data = formatted_text def update_spider_chart(self, spider_plot_values_dict: Dict[Any, Dict[str, float]]) -> bool: """ diff --git a/src/midrc_react/gui/pyside6/launch_react.py b/src/midrc_react/gui/pyside6/launch_react.py index af965f6..7ddd82a 100644 --- a/src/midrc_react/gui/pyside6/launch_react.py +++ b/src/midrc_react/gui/pyside6/launch_react.py @@ -108,13 +108,13 @@ def launch_react(): q_app.processEvents() config = JSDConfig() - if 'data sources' not in config.data: + if config.data is None or config.data.data_sources is None or len(config.data.data_sources) == 0: raise ValueError(f"No data sources found in the configuration file. \n" f" Check that the file < {config.filename} > exists and is in the correct format.") - data_source_list = config.data['data sources'] + data_source_list = config.data.data_sources w = JsdWindow(data_source_list) # Note: We should have the controller populate this once the tablemodel is loaded w.jsd_controller = JSDController(w, - JSDTableModel(data_source_list, config.data.get('custom age ranges', None)), + JSDTableModel(data_source_list, config.data.custom_age_ranges), config) # Set the default widget sizes, show the window, then reset the minimum sizes diff --git a/src/react_demo.ipynb b/src/react_demo.ipynb index bd0727d..353a8b1 100644 --- a/src/react_demo.ipynb +++ b/src/react_demo.ipynb @@ -37,11 +37,11 @@ "from midrc_react.gui.ipython.jsdview_ipython import JsdViewIPython\n", "\n", "config = JSDConfig()\n", - "jsd_model = JSDTableModel(None, config.data.get('custom age ranges', None))\n", + "jsd_model = JSDTableModel(None, config.data.custom_age_range)\n", "jsd_view = JsdViewIPython(jsd_model)\n", "jsd_controller = JSDController(jsd_view, jsd_model, config)\n", - "for data_source in config.data['data sources']:\n", - " print(f\"Loading: {data_source['description']}...\")\n", + "for data_source in config.data.data_sources:\n", + " print(f\"Loading: {data_source.description}...\")\n", " jsd_view.open_excel_file(data_source)\n", "\n", "print(\"Done Loading Files\")\n", @@ -210,7 +210,7 @@ "source": [ "#display(jsd_view_base.dataselectiongroupbox.get_file_infos())\n", "#display(jsd_view_base.dataselectiongroupbox.get_category_info())\n", - "#display(jsd_view_base.dataselectiongroupbox.get_category_info()['current_text'])\n", + "#display(jsd_view_base.dataselectiongroupbox.get_category_info().current_text)\n", "#display(jsd_view_base.dataselectiongroupbox.category_combobox.value)\n", "display(jsd_table_model.data_sources['MIDRC'].data_source)\n", "\n",