diff --git a/pyproject.toml b/pyproject.toml index 93a6166..0b7872f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,7 +94,7 @@ push = true [tool.ty.src] include = ["swvo"] -exclude = ["tests", "swvo/io/RBMDataSet"] +exclude = ["tests"] [[tool.ty.overrides]] include = ["swvo"] diff --git a/swvo/io/RBMDataSet/RBMDataSet.py b/swvo/io/RBMDataSet/RBMDataSet.py index 53381e1..91de5e7 100644 --- a/swvo/io/RBMDataSet/RBMDataSet.py +++ b/swvo/io/RBMDataSet/RBMDataSet.py @@ -236,7 +236,7 @@ def find_similar_variable(self, name: str) -> tuple[None | VariableEnum, dict[st sat_variable = var break else: - dist = distance.levenshtein(name, var.var_name) + dist = distance.levenshtein(name, var.var_name) # ty:ignore[possibly-missing-attribute] if name.lower() in var.var_name.lower(): dist = 1 @@ -406,7 +406,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: datetimes = typing.cast( NDArray[np.object_], np.asarray([matlab2python(t) for t in file_content["time"]]), - ) # type: ignore + ) file_content["datetime"] = datetimes # limit in time @@ -433,7 +433,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: else: joined_value = var_arr - loaded_var_arrs[key] = joined_value + loaded_var_arrs[key] = joined_value # ty:ignore[invalid-assignment] if key not in var_names_storred: var_names_storred.append(key) @@ -456,7 +456,7 @@ def get_loaded_variables(self) -> list[str]: loaded_vars.append(var.var_name) return loaded_vars - def __eq__(self, other: RBMDataSet) -> bool: + def __eq__(self, other: RBMDataSet) -> bool: # type :ignore[override] if ( self._file_loading_mode != other._file_loading_mode or self._satellite != other._satellite @@ -501,7 +501,7 @@ def get_different_variables(self, rbm_other: RBMDataSet) -> list[str]: return different_vars - from .bin_and_interpolate_to_model_grid import bin_and_interpolate_to_model_grid + from .bin_and_interpolate_to_model_grid import bin_and_interpolate_to_model_grid # noqa: I001 from .identify_orbits import identify_orbits from .interp_functions import interp_flux, interp_psd from .linearize_trajectories import linearize_trajectories diff --git a/swvo/io/RBMDataSet/RBMDataSetManager.py b/swvo/io/RBMDataSet/RBMDataSetManager.py index 7533a48..cdb1301 100644 --- a/swvo/io/RBMDataSet/RBMDataSetManager.py +++ b/swvo/io/RBMDataSet/RBMDataSetManager.py @@ -6,7 +6,7 @@ from datetime import datetime from pathlib import Path -from typing import Iterable, overload +from typing import Iterable, Literal, overload from swvo.io.RBMDataSet.custom_enums import ( FolderTypeEnum, @@ -94,7 +94,7 @@ def load( folder_type: FolderTypeEnum = FolderTypeEnum.DataServer, *, verbose: bool = True, - preferred_extension: str = "pickle", + preferred_extension: Literal["mat", "pickle"] = "pickle", ) -> RBMDataSet | list[RBMDataSet]: """Loads an RBMDataSet or a list of RBMDataSets based on the provided parameters. @@ -152,7 +152,7 @@ def load( return_list.append(cls._instance.data_set_dict[key_tuple]) else: cls._instance.data_set_dict[key_tuple] = RBMDataSet( - satellite=sat, + satellite=sat, # ty:ignore[invalid-argument-type] instrument=instrument, mfm=mfm, start_time=start_time, diff --git a/swvo/io/RBMDataSet/RBMNcDataSet.py b/swvo/io/RBMDataSet/RBMNcDataSet.py index f3aa22c..a172ab4 100644 --- a/swvo/io/RBMDataSet/RBMNcDataSet.py +++ b/swvo/io/RBMDataSet/RBMNcDataSet.py @@ -159,7 +159,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: np.asarray( [dt.datetime.fromtimestamp(t.astype(np.int64), tz=dt.timezone.utc) for t in datasets["time"]] ), - ) # type: ignore + ) datasets["datetime"] = datetimes # limit in time @@ -183,7 +183,7 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: else: joined_value = var_arr - loaded_var_arrs[key] = joined_value + loaded_var_arrs[key] = joined_value # ty:ignore[invalid-assignment] if key not in var_names_stored: var_names_stored.append(key) @@ -194,9 +194,9 @@ def _load_variable(self, var: Variable | VariableEnum) -> None: for var_name in var_names_stored: if var_name == "datetime": - loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # type: ignore + loaded_var_arrs[var_name] = list(loaded_var_arrs[var_name]) # ty:ignore[invalid-assignment] - rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) + rbm_var_name = RBMNcDataSet._get_rbm_name(var_name, self._mfm.mfm_name) # ty:ignore[invalid-argument-type] if rbm_var_name is not None: setattr(self, rbm_var_name, loaded_var_arrs[var_name]) diff --git a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py index 932ce54..1777a29 100644 --- a/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py +++ b/swvo/io/RBMDataSet/bin_and_interpolate_to_model_grid.py @@ -76,7 +76,7 @@ def bin_and_interpolate_to_model_grid( raise (ValueError(msg)) # 3. Bin in time - psd_binned_in_time = _bin_in_time(self.datetime, sim_time, psd_binned_in_space) + psd_binned_in_time = _bin_in_time(self.datetime, sim_time, psd_binned_in_space) # ty:ignore[invalid-argument-type] # sanity check if np.min(target_var_init) > np.min(psd_binned_in_time) or np.max(target_var_init) < np.max(psd_binned_in_time): msg = "Found inconsitency in time binning. Aborting..." @@ -87,7 +87,7 @@ def bin_and_interpolate_to_model_grid( plot_debug_figures( self, psd_binned_in_time, - sim_time, + sim_time, # ty:ignore[invalid-argument-type] grid_P, grid_R, grid_mu_V, @@ -99,7 +99,7 @@ def bin_and_interpolate_to_model_grid( plot_debug_figures_plasmasphere( self, psd_binned_in_time, - sim_time, + sim_time, # ty:ignore[invalid-argument-type] grid_P, grid_R, debug_plot_settings, @@ -257,12 +257,12 @@ def _interpolate_in_V_K( rs = p.map_async(func, range(psd_in.shape[0])) # display progress bar if verbose - total_elements = rs._number_left + total_elements = rs._number_left # ty:ignore[unresolved-attribute] with tqdm(total=total_elements) as t: while True: if rs.ready(): break - t.n = total_elements - rs._number_left + t.n = total_elements - rs._number_left # ty:ignore[unresolved-attribute] t.refresh() time.sleep(1) @@ -417,11 +417,11 @@ def plot_debug_figures_plasmasphere( ) ax0.set_ylim(1, 6.6) ax0.set_title("Orbit") - ax0.set_rlim([0, 6.6]) - ax0.set_theta_offset(np.pi) + ax0.set_rlim([0, 6.6]) # ty:ignore[unresolved-attribute] + ax0.set_theta_offset(np.pi) # ty:ignore[unresolved-attribute] - grid_X = grid_R[:, :, 0, 0] * np.cos(grid_P[:, :, 0, 0]) - grid_Y = grid_R[:, :, 0, 0] * np.sin(grid_P[:, :, 0, 0]) + grid_X = grid_R[:, :, 0, 0] * np.cos(grid_P[:, :, 0, 0]) # ty:ignore[non-subscriptable] # ty:ignore[ignore-comment-unknown-rule, not-subscriptable] + grid_Y = grid_R[:, :, 0, 0] * np.sin(grid_P[:, :, 0, 0]) # ty:ignore[non-subscriptable] # ty:ignore[ignore-comment-unknown-rule, not-subscriptable] pc = ax1.pcolormesh( grid_X, @@ -478,8 +478,8 @@ def plot_debug_figures( R_idx = np.argwhere(np.abs(grid_R[0, :, 0, 0] - R_or_Lstar_arr[sat_time_idx])) - K_idx = np.argmin(np.abs(grid_K[0, R_idx, 0, :] - debug_plot_settings.target_K)) - V_idx = np.argmin(np.abs(grid_V[0, R_idx, :, K_idx] - debug_plot_settings.target_V)) + K_idx = np.argmin(np.abs(grid_K[0, R_idx, 0, :] - debug_plot_settings.target_K)) # ty:ignore[unsupported-operator] + V_idx = np.argmin(np.abs(grid_V[0, R_idx, :, K_idx] - debug_plot_settings.target_V)) # ty:ignore[unsupported-operator] V_lim_min = np.log10(0.9 * np.min([np.nanmin(data_set_V_or_Mu), np.min(grid_V)])) V_lim_max = np.log10(1.1 * np.max([np.nanmax(data_set_V_or_Mu), np.max(grid_V)])) @@ -497,7 +497,7 @@ def plot_debug_figures( ax0.scatter(data_set.P[sat_time_idx], R_or_Lstar_arr[sat_time_idx], c="k", marker="D") ax0.set_ylim(1, 6.6) ax0.set_title("Orbit") - ax0.set_theta_offset(np.pi) + ax0.set_theta_offset(np.pi) # ty:ignore[unresolved-attribute] ax1.vlines( [np.log10(np.min(grid_V)), np.log10(np.max(grid_V))], diff --git a/swvo/io/RBMDataSet/identify_orbits.py b/swvo/io/RBMDataSet/identify_orbits.py index 7de79a0..c9db2e5 100644 --- a/swvo/io/RBMDataSet/identify_orbits.py +++ b/swvo/io/RBMDataSet/identify_orbits.py @@ -12,8 +12,8 @@ import numpy as np import pandas as pd from numpy.typing import NDArray -from scipy.interpolate import make_splrep # type: ignore[reportUnknownVariableType] -from scipy.signal import find_peaks # type: ignore[reportUnknownVariableType] +from scipy.interpolate import make_splrep +from scipy.signal import find_peaks if typing.TYPE_CHECKING: from swvo.io.RBMDataSet import RBMDataSet, RBMNcDataSet @@ -32,12 +32,12 @@ def _identify_orbits( if apply_smoothing: timestamps = [t.timestamp() for t in time] - distance_filled = make_splrep(timestamps, distance_filled, s=0)(timestamps) # type: ignore[reportUnknownVariableType] + distance_filled = make_splrep(timestamps, distance_filled, s=0)(timestamps) distance_filled = typing.cast("NDArray[np.floating]", distance_filled) - peaks, _ = find_peaks(distance_filled, distance=minimal_distance) # type: ignore[reportUnknownVariableType] - troughs, _ = find_peaks(-distance_filled, distance=minimal_distance) # type: ignore[reportUnknownVariableType] - extrema = np.sort(np.concatenate((peaks, troughs))) # type: ignore[reportUnknownVariableType] + peaks, _ = find_peaks(distance_filled, distance=minimal_distance) + troughs, _ = find_peaks(-distance_filled, distance=minimal_distance) + extrema = np.sort(np.concatenate((peaks, troughs))) extrema = typing.cast("NDArray[np.int32]", extrema) diffs = np.diff(distance_filled) diff --git a/swvo/io/RBMDataSet/interp_functions.py b/swvo/io/RBMDataSet/interp_functions.py index d3c467c..912a13b 100644 --- a/swvo/io/RBMDataSet/interp_functions.py +++ b/swvo/io/RBMDataSet/interp_functions.py @@ -10,6 +10,7 @@ from enum import Enum from functools import partial from multiprocessing import Pool +from typing import Literal import numpy as np from numpy.typing import NDArray @@ -119,14 +120,14 @@ def interp_flux( target_type = TargetType[target_type] if target_type == TargetType.TargetPairs: - assert len(target_en) == len( - target_al + assert len(target_en) == len( # ty:ignore[invalid-argument-type] + target_al # ty:ignore[invalid-argument-type] ), "For TargetType.Pairs, the target vectors must have the same size!" - result_arr = np.empty((len(self.time), len(target_en))) + result_arr = np.empty((len(self.time), len(target_en))) # ty:ignore[invalid-argument-type] targets = list(zip(target_en, target_al)) else: - result_arr = np.empty((len(self.time), len(target_en), len(target_al))) + result_arr = np.empty((len(self.time), len(target_en), len(target_al))) # ty:ignore[invalid-argument-type] targets = list(itertools.product(target_en, target_al)) func = partial( @@ -142,12 +143,12 @@ def interp_flux( # display progress bar if verbose if self._verbose: - total_elements = rs._number_left + total_elements = rs._number_left # ty:ignore[unresolved-attribute] with tqdm(total=total_elements) as t: while True: if rs.ready(): break - t.n = total_elements - rs._number_left + t.n = total_elements - rs._number_left # ty:ignore[unresolved-attribute] t.refresh() time.sleep(1) else: @@ -164,9 +165,9 @@ def interp_flux( result_arr[i, t] = parallel_results[i][t] else: for ie, ia in itertools.product( - range(len(target_en)), range(len(target_al)) + range(len(target_en)), range(len(target_al)) # ty:ignore[invalid-argument-type] ): - result_arr[i, ie, ia] = parallel_results[i][ie * len(target_al) + ia] + result_arr[i, ie, ia] = parallel_results[i][ie * len(target_al) + ia] # ty:ignore[invalid-argument-type] return result_arr @@ -296,11 +297,11 @@ def interp_psd(self: RBMDataSet, if target_type == TargetType.TargetPairs: assert len(target_mu) == len(target_K), \ - "For TargetType.Pairs, mu and K vectors must have the same size!" - result_arr = np.empty((len(self.time), len(target_mu))) + "For TargetType.Pairs, mu and K vectors must have the same size!" # ty:ignore[invalid-argument-type] + result_arr = np.empty((len(self.time), len(target_mu))) # ty:ignore[invalid-argument-type] targets = list(zip(target_mu, target_K)) else: - result_arr = np.empty((len(self.time), len(target_mu), len(target_K))) + result_arr = np.empty((len(self.time), len(target_mu), len(target_K))) # ty:ignore[invalid-argument-type] targets = list(itertools.product(target_mu, target_K)) # ensure needed fields are loaded (triggers lazy loader if any) @@ -313,11 +314,11 @@ def interp_psd(self: RBMDataSet, rs = p.map_async(func, range(len(self.time))) if self._verbose: - total_elements = rs._number_left + total_elements = rs._number_left # ty:ignore[unresolved-attribute] with tqdm(total=total_elements) as t: while True: if rs.ready(): break - t.n = (total_elements - rs._number_left) + t.n = (total_elements - rs._number_left) # ty:ignore[unresolved-attribute] t.refresh() time.sleep(1) else: @@ -333,7 +334,7 @@ def interp_psd(self: RBMDataSet, for t, _ in enumerate(targets): result_arr[i, t] = parallel_results[i][t] else: - n_mu, n_K = len(target_mu), len(target_K) + n_mu, n_K = len(target_mu), len(target_K) # ty:ignore[invalid-argument-type] for i in range(result_arr.shape[0]): for im, iK in itertools.product(range(n_mu), range(n_K)): result_arr[i, im, iK] = parallel_results[i][im * n_K + iK] diff --git a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py index 7914d77..c8b2081 100644 --- a/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py +++ b/swvo/io/RBMDataSet/scripts/create_RBSP_line_data.py @@ -97,7 +97,7 @@ def create_RBSP_line_data( target_type = TargetType[target_type] if target_type == TargetType.TargetPairs: - assert len(target_en) == len(target_al), "For TargetType.Pairs, the target vectors must have the same size!" + assert len(target_en) == len(target_al), "For TargetType.Pairs, the target vectors must have the same size!" # ty:ignore[invalid-argument-type] result_arr = [] list_instruments_used = [] @@ -115,18 +115,18 @@ def create_RBSP_line_data( instrument, mfm, verbose=verbose, - ) + ) # ty:ignore[no-matching-overload] ) # strip of time dimention if rbm_data[i].energy_channels.shape[0] == len(rbm_data[i].time): - rbm_data[i].energy_channels_no_time = np.nanmean(rbm_data[i].energy_channels, axis=0) + rbm_data[i].energy_channels_no_time = np.nanmean(rbm_data[i].energy_channels, axis=0) # ty:ignore[unresolved-attribute] else: - rbm_data[i].energy_channels_no_time = rbm_data[i].energy_channels + rbm_data[i].energy_channels_no_time = rbm_data[i].energy_channels # ty:ignore[unresolved-attribute] if rbm_data[i].alpha_local.shape[0] == len(rbm_data[i].time): - rbm_data[i].alpha_local_no_time = np.nanmean(rbm_data[i].alpha_local, axis=0) + rbm_data[i].alpha_local_no_time = np.nanmean(rbm_data[i].alpha_local, axis=0) # ty:ignore[unresolved-attribute] else: - rbm_data[i].alpha_local_no_time = rbm_data[i].alpha_local + rbm_data[i].alpha_local_no_time = rbm_data[i].alpha_local # ty:ignore[unresolved-attribute] for e, target_en_single in enumerate(target_en): if verbose: @@ -150,19 +150,19 @@ def create_RBSP_line_data( rbm_data_set_result = deepcopy(rbm_data[i]) if target_type == TargetType.TargetPairs: - rbm_data_set_result.line_data_flux = np.empty((len(rbm_data_set_result.time), len(target_en))) - rbm_data_set_result.line_data_energy = np.empty((len(target_en),)) - rbm_data_set_result.line_data_alpha_local = np.empty((len(target_al),)) + rbm_data_set_result.line_data_flux = np.empty((len(rbm_data_set_result.time), len(target_en))) # ty:ignore[invalid-argument-type, unresolved-attribute] + rbm_data_set_result.line_data_energy = np.empty((len(target_en),)) # ty:ignore[invalid-argument-type, unresolved-attribute] + rbm_data_set_result.line_data_alpha_local = np.empty((len(target_al),)) # ty:ignore[invalid-argument-type, unresolved-attribute] elif target_type == TargetType.TargetMeshGrid: - rbm_data_set_result.line_data_flux = np.empty( + rbm_data_set_result.line_data_flux = np.empty( # ty:ignore[unresolved-attribute] ( len(rbm_data_set_result.time), - len(target_en), - len(target_al), + len(target_en), # ty:ignore[invalid-argument-type] + len(target_al), # ty:ignore[invalid-argument-type] ) ) - rbm_data_set_result.line_data_energy = np.empty((len(target_en),)) - rbm_data_set_result.line_data_alpha_local = np.empty((len(target_al),)) + rbm_data_set_result.line_data_energy = np.empty((len(target_en),)) # ty:ignore[invalid-argument-type, unresolved-attribute] + rbm_data_set_result.line_data_alpha_local = np.empty((len(target_al),)) # ty:ignore[invalid-argument-type, unresolved-attribute] energy_offsets_relative = energy_offsets / target_en_single if np.all(np.abs(energy_offsets_relative) > energy_offset_threshold): @@ -187,7 +187,7 @@ def create_RBSP_line_data( if target_type == TargetType.TargetPairs: closest_al_idx = np.nanargmin( - np.abs(rbm_data[min_offset_instrument].alpha_local_no_time - target_al[e]) + np.abs(rbm_data[min_offset_instrument].alpha_local_no_time - target_al[e]) # ty:ignore[not-subscriptable] ) rbm_data_set_result.line_data_alpha_local[e] = rbm_data[min_offset_instrument].alpha_local_no_time[ closest_al_idx @@ -200,7 +200,9 @@ def create_RBSP_line_data( else: rbm_data_set_result.line_data_flux[:, e] = np.squeeze( rbm_data[min_offset_instrument].interp_flux( - target_en_single, target_al[e], TargetType.TargetPairs + target_en_single, + target_al[e], + TargetType.TargetPairs, # ty:ignore[not-subscriptable] ) ) diff --git a/swvo/io/RBMDataSet/utils.py b/swvo/io/RBMDataSet/utils.py index 06bbd61..251982b 100644 --- a/swvo/io/RBMDataSet/utils.py +++ b/swvo/io/RBMDataSet/utils.py @@ -109,13 +109,13 @@ def matlab2python(datenum: float | Iterable[float]) -> Iterable[datetime] | date warnings.filterwarnings("ignore", message="Discarding nonzero nanoseconds in conversion") datenum = np.asarray(datenum, dtype=float) - datenum = pd.to_datetime(datenum - 719529, unit="D", origin=pd.Timestamp("1970-01-01")).to_pydatetime() + datenum = pd.to_datetime(datenum - 719529, unit="D", origin=pd.Timestamp("1970-01-01")).to_pydatetime() # ty:ignore[unresolved-attribute] if isinstance(datenum, Iterable): - datenum = enforce_utc_timezone(list(datenum)) - datenum = [round_seconds(x) for x in datenum] + datenum = enforce_utc_timezone(list(datenum)) # ty:ignore[invalid-assignment] + datenum = [round_seconds(x) for x in datenum] # ty:ignore[invalid-argument-type, invalid-assignment, not-iterable] else: - datenum = round_seconds(enforce_utc_timezone(datenum)) + datenum = round_seconds(enforce_utc_timezone(datenum)) # ty:ignore[invalid-assignment] return datenum