Source code for TFilterPy.base_estimator

class BaseEstimator:
    """
    Base class for all estimators in the TFilterPy package.

    Provides common functionality such as parameter handling and validation.
    """

    def __init__(self, name=None):
        """
        Initialize the BaseEstimator.

        Args:
            name (str): Optional name for the estimator. Any falsy value
                (``None`` or ``""``) falls back to the class name.
        """
        self.name = name or self.__class__.__name__

    def get_params(self, deep=True):
        """
        Get parameters of the estimator.

        Every entry of the instance ``__dict__`` is reported, including
        private (underscore-prefixed) attributes.

        Args:
            deep (bool): If True, additionally report the parameters of
                nested estimator instances as ``"<attr>__<param>"`` entries.
                The nested object itself is always reported under ``"<attr>"``.

        Returns:
            dict: A dictionary of parameter names mapped to their values.
        """
        params = {}
        for key, value in self.__dict__.items():
            params[key] = value
            # Only recurse into estimator *instances*: an attribute holding an
            # estimator class also exposes ``get_params``, but calling it
            # unbound would fail for lack of ``self``.
            if deep and hasattr(value, "get_params") and not isinstance(value, type):
                params.update(
                    {
                        f"{key}__{sub_key}": sub_value
                        for sub_key, sub_value in value.get_params().items()
                    }
                )
        return params

    def set_params(self, **params):
        """
        Set parameters of the estimator.

        A key naming an existing attribute is assigned directly. A key of
        the form ``"<attr>__<param>"`` (as produced by
        ``get_params(deep=True)``) is forwarded to the ``set_params`` of the
        nested estimator stored in ``<attr>``. Nested updates are applied
        after all direct assignments, so replacing a nested estimator and
        tuning it in the same call targets the new object.

        Args:
            **params: Arbitrary keyword arguments of parameters to set.

        Returns:
            self: Returns the instance itself.

        Raises:
            ValueError: If a key matches neither an attribute nor a nested
                estimator parameter.
        """
        nested = {}
        for key, value in params.items():
            # Direct attributes win, even if their name contains "__".
            if hasattr(self, key):
                setattr(self, key, value)
                continue
            prefix, sep, sub_key = key.partition("__")
            if not (sep and prefix and sub_key):
                raise ValueError(f"Invalid parameter: {key}")
            nested.setdefault(prefix, {})[sub_key] = value

        for prefix, sub_params in nested.items():
            target = getattr(self, prefix, None)
            if isinstance(target, type) or not hasattr(target, "set_params"):
                first_sub = next(iter(sub_params))
                raise ValueError(f"Invalid parameter: {prefix}__{first_sub}")
            target.set_params(**sub_params)
        return self

    def validate_matrices(self, matrices):
        """
        Validate that every matrix is a NumPy or Dask array.

        Only the container type of each value is checked; shapes are NOT
        compared against each other, so callers that need dimensional
        consistency must verify it separately.

        Args:
            matrices (dict): A dictionary of matrix names and their values.

        Raises:
            ValueError: If any value is not a ``np.ndarray`` or ``da.Array``.
        """
        # Assumes ``np`` (numpy) and ``da`` (dask.array) are imported at module
        # level; the tuple is built per item, so an empty dict needs neither.
        for name, matrix in matrices.items():
            if not isinstance(matrix, (np.ndarray, da.Array)):
                raise ValueError(f"{name} must be a NumPy or Dask array.")

    def __repr__(self):
        """
        String representation of the estimator.

        Returns:
            str: ``"<name>(<shallow params dict>)"``.
        """
        return f"{self.name}({self.get_params(deep=False)})"