
API Reference

This page provides auto-generated API documentation from Python docstrings. All classes and modules are documented using mkdocstrings, which extracts comprehensive information directly from the source code including method signatures, parameters, return types, and detailed descriptions.

Synthesizers

The synthesizer module provides multiple approaches for generating synthetic omics data, all implementing a unified interface through the BaseSynthesizer class.

BaseSynthesizer

Base class providing a consistent pipeline: preprocess -> fit -> sample -> postprocess -> save.

Subclasses should override:
  • preprocess (if they need data transformations before fitting, e.g., Gaussian Copula)
  • fit (required)
  • sample (required)
  • postprocess (optional to extend/override base behavior)
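For illustration, a minimal (hypothetical) subclass might look like the sketch below. The import path is inferred from the source location shown on this page, and the row-resampling "model" is a placeholder rather than anything provided by the library.

import pandas as pd

from synomicsbench.synthesizer.BaseSynthesizer import BaseSynthesizer  # assumed import path


class ResamplingSynthesizer(BaseSynthesizer):
    """Toy subclass: 'fits' by storing the data and 'samples' by resampling rows."""

    def fit(self, data: pd.DataFrame, seed=None, **kwargs) -> None:
        # A real subclass would train a generative model here.
        self._train_data = data

    def sample(self, n_samples: int = 10, seed=None, **kwargs) -> pd.DataFrame:
        # Draw rows with replacement as stand-in "synthetic" samples.
        return (self._train_data
                .sample(n=n_samples, replace=True, random_state=seed)
                .reset_index(drop=True))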
Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
class BaseSynthesizer:
    """
    Base class providing a consistent pipeline: preprocess -> fit -> sample -> postprocess -> save.

    Subclasses should override:
      - preprocess (if they need data transformations before fitting, e.g., Gaussian Copula)
      - fit (required)
      - sample (required)
      - postprocess (optional to extend/override base behavior)
    """

    def __init__(self, output_path: str, metadata: Optional[Dict[str, str]] = None) -> None:
        """
        Initialize the BaseSynthesizer.

        Args:
            output_path (str): Path to save outputs and logs.
            metadata (dict, optional): Metadata dictionary containing column type information.
        """
        self.output_path = output_path
        os.makedirs(self.output_path, exist_ok=True)
        logger_name = f"{self.__class__.__name__}_{id(self)}"
        log_file_name = f"{self.__class__.__name__}_{id(self)}.log"
        self.emissions_file_name = f"{self.__class__.__name__}_{id(self)}.csv"
        self.logger = set_logger(logger_name, self.output_path, log_file_name)
        self.model: Any = None
        self.metadata: Optional[Dict[str, str]] = metadata

        # Run-start info
        self.logger.info("========== Synthesizer Initialized ==========")
        self.logger.info(f"Class: {self.__class__.__name__}")
        self.logger.info(f"Output path: {self.output_path}")
        self.logger.info(f"Log file: {os.path.join(self.output_path, log_file_name)}")
        if self.metadata is not None:
            self.logger.info(f"Metadata provided with {len(self.metadata)} columns.")
        else:
            self.logger.info("No metadata dictionary provided at initialization.")
        self.logger.info("============================================")

    def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess input data for synthesizer.

        Args:
            data (pd.DataFrame): Input data.

        Returns:
            pd.DataFrame: Preprocessed data (default: returns input unchanged).
        """
        return data  # Override in subclass if needed (e.g., Gaussian Copula)

    def set_metadata(self, metadata: Dict[str, str]) -> None:
        """
        Set or update the metadata dictionary.

        Args:
            metadata (dict): Metadata dictionary containing column type information.

        Returns:
            None
        """
        self.metadata = metadata
        self.logger.info(f"Metadata updated with {len(metadata)} columns.")

    # Column detection utils use self.metadata directly
    def detect_discrete_columns(self, data: pd.DataFrame) -> List[str]:
        """
        Detect discrete columns from data and self.metadata.

        Args:
            data (pd.DataFrame): Input data.

        Returns:
            list: List of discrete column names.

        Raises:
            ValueError: If metadata is not set.
        """
        if self.metadata is None:
            raise ValueError("Metadata dictionary must be set before detecting discrete columns.")
        return _detect_discrete_columns(data, self.metadata)

    def detect_numerical_columns(self, data: pd.DataFrame) -> List[str]:
        """
        Detect numerical columns from data and self.metadata.

        Args:
            data (pd.DataFrame): Input data.

        Returns:
            list: List of numerical column names.

        Raises:
            ValueError: If metadata is not set.
        """
        if self.metadata is None:
            raise ValueError("Metadata dictionary must be set before detecting numerical columns.")
        return _detect_numerical_columns(data, self.metadata)

    # Thin wrappers to centralize calls (keeps subclasses tidy)
    def apply_rounding(self, data: pd.DataFrame, numerical_columns: Sequence[str], digits: Dict[str, int]) -> pd.DataFrame:
        return apply_rounding(data, numerical_columns, digits)

    def apply_min_max(
        self, data: pd.DataFrame, numerical_columns: Sequence[str], min_vals: Dict[str, float], max_vals: Dict[str, float]
    ) -> pd.DataFrame:
        return apply_min_max(data, numerical_columns, min_vals, max_vals)

    def detect_min_max_values(self, data: pd.DataFrame, numerical_columns: Sequence[str]):
        return _detect_min_max_values(data, numerical_columns)

    def detect_rounding_digits(self, data: pd.DataFrame, numerical_columns: Sequence[str]) -> Dict[str, int]:
        return _detect_rounding_digits(data, numerical_columns)

    def anonymize_ids(self, ids: Sequence[Any], synthetic_data: pd.DataFrame) -> pd.DataFrame:
        return anonymize_ids(ids, synthetic_data, self.output_path)

    def save_synthetic_data(self, synthetic_data: Union[pd.DataFrame, List[pd.DataFrame]], filename: str, index: bool = False) -> None:
        """
        Save synthetic data to CSV. Supports a single DataFrame or a list (e.g., Synthpop m>1).

        Args:
            synthetic_data (DataFrame or list[DataFrame]): Data to save.
            filename (str): Base filename for saving.
            index (bool): Whether to write DataFrame index.
        """
        if isinstance(synthetic_data, list):
            for i, df in enumerate(synthetic_data, 1):
                out = filename if filename.lower().endswith(".csv") else f"{filename}.csv"
                stem, ext = os.path.splitext(out)
                path = os.path.join(self.output_path, f"{stem}_{i}{ext or '.csv'}")
                df.to_csv(path, index=index)
                self.logger.info(f"Saved synthetic dataset {i} to {path}")
        else:
            out = filename if filename.lower().endswith(".csv") else f"{filename}.csv"
            path = os.path.join(self.output_path, out)
            synthetic_data.to_csv(path, index=index)
            self.logger.info(f"Saved synthetic data to {path}")

    # To be implemented by subclasses
    def fit(self, *args, **kwargs) -> None:
        raise NotImplementedError("fit() must be implemented in subclass.")

    def sample(self, *args, **kwargs) -> Union[pd.DataFrame, List[pd.DataFrame]]:
        raise NotImplementedError("sample() must be implemented in subclass.")

    def postprocess(
        self,
        synthetic_data: pd.DataFrame,
        original_data: pd.DataFrame,
        data_ids: Optional[Sequence[Any]] = None,
        enforce_rounding: bool = True,
        enforce_min_max: bool = True,
        masking: bool = False,
    ) -> pd.DataFrame:
        """
        Postprocess synthetic data, including anonymization, rounding, and min-max scaling.

        Args:
            synthetic_data (pd.DataFrame): Generated synthetic data.
            original_data (pd.DataFrame): Original data for reference.
            data_ids (list, optional): IDs to anonymize.
            enforce_rounding (bool, optional): Apply rounding with digits inferred from original_data.
            enforce_min_max (bool, optional): Clip to min/max observed in original_data.
            masking (bool, optional): If True and mask_func is provided, apply it to introduce missingness.

        Returns:
            pd.DataFrame: Postprocessed synthetic data.

        Raises:
            ValueError: If metadata is not set.
        """
        if self.metadata is None:
            raise ValueError("Metadata dictionary must be set before postprocessing.")

        sdf = synthetic_data.copy()

        # Anonymize IDs
        if data_ids:
            self.logger.info("Anonymizing IDs in synthetic data")
            sdf = self.anonymize_ids(data_ids[: len(sdf)], sdf)
        else:
            self.logger.info("No data IDs provided for anonymization. Skipping ID anonymization.")

        # Numeric constraints
        if enforce_rounding or enforce_min_max:
            numerical_columns = self.detect_numerical_columns(original_data)
            if numerical_columns:
                if enforce_rounding:
                    self.logger.info("Applying rounding to synthetic data")
                    rounding_digits = self.detect_rounding_digits(original_data, numerical_columns)
                    sdf = self.apply_rounding(sdf, numerical_columns, rounding_digits)
                if enforce_min_max:
                    self.logger.info("Enforcing min-max on synthetic data")
                    min_values, max_values = self.detect_min_max_values(original_data, numerical_columns)
                    sdf = self.apply_min_max(sdf, numerical_columns, min_values, max_values)

        # Optional masking hook
        if masking:
            self.logger.info("Applying masking function to synthetic data")
            try:
                sdf = post_masking(sdf)
            except Exception as e:
                self.logger.warning(f"Masking function failed: {e}")

        return sdf

    @_monitor_resources
    # @track_emissions(output_dir = self.output_path, output_file = self.emissions_file_name)
    # @profile
    def generate(
        self,
        data: pd.DataFrame,
        data_ids: Optional[Sequence[Any]] = None,
        enforce_rounding: bool = True,
        enforce_min_max: bool = True,
        masking: bool = False,
        seed: int = 42,
        n_samples: int = 10,
        fit_params: Optional[Dict[str, Any]] = None,
        sample_params: Optional[Dict[str, Any]] = None,
        output_filename: str = "synthetic_data.csv",
        save_index: bool = False) -> Union[pd.DataFrame, List[pd.DataFrame]]:

        fit_params = fit_params or {}
        sample_params = sample_params or {}
        try:
            self.logger.info(
                f"Pipeline started for {self.__class__.__name__} "
                f"(n_samples={n_samples}, fit={fit_params}, sample={sample_params}, "
                f"rounding={enforce_rounding}, minmax={enforce_min_max}, masking={masking})"
            )

            processed_data = self.preprocess(data)
            self.logger.info(f"Preprocessed data shape: {processed_data.shape}")

            # Subclass fit() should consume already-preprocessed data (or ignore if not needed)
            self.fit(processed_data, seed, **fit_params)
            self.logger.info("Model fit complete.")

            # sample() may return DataFrame or list[DataFrame] (e.g., Synthpop for m>1)
            sampled = self.sample(n_samples, seed, **sample_params)
            if isinstance(sampled, list):
                self.logger.info(f"Sampled {len(sampled)} dataset(s).")
                # Postprocess each dataset if a list is returned
                post_list = []
                for i, sdf in enumerate(sampled, 1):
                    self.logger.info(f"Postprocessing dataset {i}")
                    post_list.append(
                        self.postprocess(
                            sdf,
                            original_data=data,  # original (unprocessed) data for constraints
                            data_ids=data_ids,
                            enforce_rounding=enforce_rounding,
                            enforce_min_max=enforce_min_max,
                            masking=masking,
                        )
                    )
                self.save_synthetic_data(post_list, output_filename, index=save_index)
                self.logger.info("Pipeline complete.")
                return post_list
            else:
                self.logger.info(f"Sampled {len(sampled)} rows.")
                processed_synthetic = self.postprocess(
                    sampled,
                    original_data=data,
                    data_ids=data_ids,
                    enforce_rounding=enforce_rounding,
                    enforce_min_max=enforce_min_max,
                    masking=masking,
                )
                self.logger.info(f"Postprocessed synthetic data. Final shape: {processed_synthetic.shape}")
                self.save_synthetic_data(processed_synthetic, output_filename, index=save_index)
                self.logger.info("Pipeline complete.")
                return processed_synthetic

        except Exception as e:
            self.logger.error(f"Error in full pipeline: {e}", exc_info=True)
            raise ValueError(f"Error in generating synthetic data: {e}")

__init__(output_path, metadata=None)

Initialize the BaseSynthesizer.

Parameters:

    output_path (str): Path to save outputs and logs. [required]
    metadata (dict, optional): Metadata dictionary containing column type information. [default: None]
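For example, a subclass such as CTGANsynthesizer (documented below) might be constructed like this; the output directory and the column type labels are illustrative assumptions, with labels other than the categorical ones treated as numerical (see the GaussianCopulasynthesizer notes):

metadata = {
    "age": "numerical",                 # assumed label; non-categorical labels are treated as numerical
    "smoking_status": "dummy_categorical",
    "tumor_stage": "ordinal_categorical",
}
synth = CTGANsynthesizer(output_path="results/ctgan_run", metadata=metadata)
# A log file and, if needed, the output directory are created on initialization.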
Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def __init__(self, output_path: str, metadata: Optional[Dict[str, str]] = None) -> None:
    """
    Initialize the BaseSynthesizer.

    Args:
        output_path (str): Path to save outputs and logs.
        metadata (dict, optional): Metadata dictionary containing column type information.
    """
    self.output_path = output_path
    os.makedirs(self.output_path, exist_ok=True)
    logger_name = f"{self.__class__.__name__}_{id(self)}"
    log_file_name = f"{self.__class__.__name__}_{id(self)}.log"
    self.emissions_file_name = f"{self.__class__.__name__}_{id(self)}.csv"
    self.logger = set_logger(logger_name, self.output_path, log_file_name)
    self.model: Any = None
    self.metadata: Optional[Dict[str, str]] = metadata

    # Run-start info
    self.logger.info("========== Synthesizer Initialized ==========")
    self.logger.info(f"Class: {self.__class__.__name__}")
    self.logger.info(f"Output path: {self.output_path}")
    self.logger.info(f"Log file: {os.path.join(self.output_path, log_file_name)}")
    if self.metadata is not None:
        self.logger.info(f"Metadata provided with {len(self.metadata)} columns.")
    else:
        self.logger.info("No metadata dictionary provided at initialization.")
    self.logger.info("============================================")

detect_discrete_columns(data)

Detect discrete columns from data and self.metadata.

Parameters:

    data (pd.DataFrame): Input data. [required]

Returns:

    list (List[str]): List of discrete column names.

Raises:

    ValueError: If metadata is not set.

Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def detect_discrete_columns(self, data: pd.DataFrame) -> List[str]:
    """
    Detect discrete columns from data and self.metadata.

    Args:
        data (pd.DataFrame): Input data.

    Returns:
        list: List of discrete column names.

    Raises:
        ValueError: If metadata is not set.
    """
    if self.metadata is None:
        raise ValueError("Metadata dictionary must be set before detecting discrete columns.")
    return _detect_discrete_columns(data, self.metadata)

detect_numerical_columns(data)

Detect numerical columns from data and self.metadata.

Parameters:

    data (pd.DataFrame): Input data. [required]

Returns:

    list (List[str]): List of numerical column names.

Raises:

    ValueError: If metadata is not set.
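Both detection helpers (discrete and numerical) read the metadata provided at construction time or via set_metadata(); a small sketch with assumed column names and labels:

synth.set_metadata({"age": "numerical", "smoking_status": "dummy_categorical"})
discrete_cols = synth.detect_discrete_columns(train_df)    # e.g. ["smoking_status"]
numeric_cols = synth.detect_numerical_columns(train_df)    # e.g. ["age"]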

Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def detect_numerical_columns(self, data: pd.DataFrame) -> List[str]:
    """
    Detect numerical columns from data and self.metadata.

    Args:
        data (pd.DataFrame): Input data.

    Returns:
        list: List of numerical column names.

    Raises:
        ValueError: If metadata is not set.
    """
    if self.metadata is None:
        raise ValueError("Metadata dictionary must be set before detecting numerical columns.")
    return _detect_numerical_columns(data, self.metadata)

postprocess(synthetic_data, original_data, data_ids=None, enforce_rounding=True, enforce_min_max=True, masking=False)

Postprocess synthetic data, including anonymization, rounding, and min-max scaling.

Parameters:

    synthetic_data (pd.DataFrame): Generated synthetic data. [required]
    original_data (pd.DataFrame): Original data for reference. [required]
    data_ids (list, optional): IDs to anonymize. [default: None]
    enforce_rounding (bool, optional): Apply rounding with digits inferred from original_data. [default: True]
    enforce_min_max (bool, optional): Clip to min/max observed in original_data. [default: True]
    masking (bool, optional): If True and mask_func is provided, apply it to introduce missingness. [default: False]

Returns:

    pd.DataFrame: Postprocessed synthetic data.

Raises:

    ValueError: If metadata is not set.
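generate() calls postprocess() internally, but it can also be used on its own; a sketch (the ID column name is hypothetical):

clean_df = synth.postprocess(
    synthetic_data=raw_samples,
    original_data=train_df,
    data_ids=train_df["sample_id"].tolist(),  # hypothetical ID column to anonymize
    enforce_rounding=True,
    enforce_min_max=True,
    masking=False,
)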

Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def postprocess(
    self,
    synthetic_data: pd.DataFrame,
    original_data: pd.DataFrame,
    data_ids: Optional[Sequence[Any]] = None,
    enforce_rounding: bool = True,
    enforce_min_max: bool = True,
    masking: bool = False,
) -> pd.DataFrame:
    """
    Postprocess synthetic data, including anonymization, rounding, and min-max scaling.

    Args:
        synthetic_data (pd.DataFrame): Generated synthetic data.
        original_data (pd.DataFrame): Original data for reference.
        data_ids (list, optional): IDs to anonymize.
        enforce_rounding (bool, optional): Apply rounding with digits inferred from original_data.
        enforce_min_max (bool, optional): Clip to min/max observed in original_data.
        masking (bool, optional): If True and mask_func is provided, apply it to introduce missingness.

    Returns:
        pd.DataFrame: Postprocessed synthetic data.

    Raises:
        ValueError: If metadata is not set.
    """
    if self.metadata is None:
        raise ValueError("Metadata dictionary must be set before postprocessing.")

    sdf = synthetic_data.copy()

    # Anonymize IDs
    if data_ids:
        self.logger.info("Anonymizing IDs in synthetic data")
        sdf = self.anonymize_ids(data_ids[: len(sdf)], sdf)
    else:
        self.logger.info("No data IDs provided for anonymization. Skipping ID anonymization.")

    # Numeric constraints
    if enforce_rounding or enforce_min_max:
        numerical_columns = self.detect_numerical_columns(original_data)
        if numerical_columns:
            if enforce_rounding:
                self.logger.info("Applying rounding to synthetic data")
                rounding_digits = self.detect_rounding_digits(original_data, numerical_columns)
                sdf = self.apply_rounding(sdf, numerical_columns, rounding_digits)
            if enforce_min_max:
                self.logger.info("Enforcing min-max on synthetic data")
                min_values, max_values = self.detect_min_max_values(original_data, numerical_columns)
                sdf = self.apply_min_max(sdf, numerical_columns, min_values, max_values)

    # Optional masking hook
    if masking:
        self.logger.info("Applying masking function to synthetic data")
        try:
            sdf = post_masking(sdf)
        except Exception as e:
            self.logger.warning(f"Masking function failed: {e}")

    return sdf

preprocess(data)

Preprocess input data for synthesizer.

Parameters:

    data (pd.DataFrame): Input data. [required]

Returns:

    pd.DataFrame: Preprocessed data (default: returns input unchanged).

Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess input data for synthesizer.

    Args:
        data (pd.DataFrame): Input data.

    Returns:
        pd.DataFrame: Preprocessed data (default: returns input unchanged).
    """
    return data  # Override in subclass if needed (e.g., Gaussian Copula)

save_synthetic_data(synthetic_data, filename, index=False)

Save synthetic data to CSV. Supports a single DataFrame or a list (e.g., Synthpop m>1).

Parameters:

    synthetic_data (DataFrame or list[DataFrame]): Data to save. [required]
    filename (str): Base filename for saving. [required]
    index (bool): Whether to write DataFrame index. [default: False]
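Based on the source shown below, filenames are normalized to end in .csv and list inputs get a 1-based suffix, so a call like the following would write the indicated files under output_path:

synth.save_synthetic_data(synthetic_df, "synthetic_data")       # -> synthetic_data.csv
synth.save_synthetic_data([df_1, df_2], "synthetic_data")       # -> synthetic_data_1.csv, synthetic_data_2.csv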
Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def save_synthetic_data(self, synthetic_data: Union[pd.DataFrame, List[pd.DataFrame]], filename: str, index: bool = False) -> None:
    """
    Save synthetic data to CSV. Supports a single DataFrame or a list (e.g., Synthpop m>1).

    Args:
        synthetic_data (DataFrame or list[DataFrame]): Data to save.
        filename (str): Base filename for saving.
        index (bool): Whether to write DataFrame index.
    """
    if isinstance(synthetic_data, list):
        for i, df in enumerate(synthetic_data, 1):
            out = filename if filename.lower().endswith(".csv") else f"{filename}.csv"
            stem, ext = os.path.splitext(out)
            path = os.path.join(self.output_path, f"{stem}_{i}{ext or '.csv'}")
            df.to_csv(path, index=index)
            self.logger.info(f"Saved synthetic dataset {i} to {path}")
    else:
        out = filename if filename.lower().endswith(".csv") else f"{filename}.csv"
        path = os.path.join(self.output_path, out)
        synthetic_data.to_csv(path, index=index)
        self.logger.info(f"Saved synthetic data to {path}")

set_metadata(metadata)

Set or update the metadata dictionary.

Parameters:

    metadata (dict): Metadata dictionary containing column type information. [required]

Returns:

    None

Source code in src/synomicsbench/synthesizer/BaseSynthesizer.py
def set_metadata(self, metadata: Dict[str, str]) -> None:
    """
    Set or update the metadata dictionary.

    Args:
        metadata (dict): Metadata dictionary containing column type information.

    Returns:
        None
    """
    self.metadata = metadata
    self.logger.info(f"Metadata updated with {len(metadata)} columns.")

CTGANsynthesizer

Bases: BaseSynthesizer

CTGANsynthesizer for generating synthetic data using CTGAN.

Parameters:

    output_path (str): Path to save outputs. [required]
    metadata (dict, optional): Metadata dictionary containing column type information. [default: None]

Methods:

    fit: Train the CTGAN model.
    sample: Generate synthetic samples.
    postprocess: Apply postprocessing (anonymize IDs, rounding, scaling).
    generate: Orchestrate data synthesis pipeline.
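An end-to-end sketch of the inherited generate() pipeline; the file paths, column-typing heuristic, and import path are assumptions for illustration:

import pandas as pd

from synomicsbench.synthesizer.CTGANsynthesizer import CTGANsynthesizer  # assumed import path

train_df = pd.read_csv("data/train.csv")  # placeholder dataset
metadata = {
    col: ("dummy_categorical" if train_df[col].dtype == "object" else "numerical")
    for col in train_df.columns
}

synth = CTGANsynthesizer(output_path="results/ctgan", metadata=metadata)
synthetic = synth.generate(
    data=train_df,
    n_samples=len(train_df),
    seed=42,
    fit_params={"epochs": 300},            # forwarded to fit()
    output_filename="ctgan_synthetic.csv",
)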

Source code in src/synomicsbench/synthesizer/CTGANsynthesizer.py
class CTGANsynthesizer(BaseSynthesizer):
    """
    CTGANsynthesizer for generating synthetic data using CTGAN.

    Args:
        output_path (str): Path to save outputs.
        metadata (dict, optional): Metadata dictionary containing column type information.

    Methods:
        fit: Train the CTGAN model.
        sample: Generate synthetic samples.
        postprocess: Apply postprocessing (anonymize IDs, rounding, scaling).
        generate: Orchestrate data synthesis pipeline.
    """
    def _seed_all(self, seed: Optional[int] = None, use_cuda: bool = False) -> None:
        """Seed Python, NumPy, and Torch (if available) to improve reproducibility."""
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        try:
            import torch
            torch.manual_seed(seed)
            if use_cuda and torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        except Exception:
            # Torch not installed or not available; ignore
            pass
    def fit(self, 
            data: pd.DataFrame,
            seed:  Optional[int] = None,
            *,
            epochs: int=100, 
            verbose: bool=True, 
            cuda: bool = True, 
            **kwargs):
        """
        Train the CTGAN model on provided data.

        Args:
            data (pd.DataFrame): DataFrame for training.
            seed (int, optional): Random seed used for reproducible training.
            epochs (int): Number of training epochs.
            verbose (bool): Verbosity flag.
            cuda (bool): Use GPU if True.
            **kwargs: Extra CTGAN parameters.

        Returns:
            None

        Raises:
            ValueError: If metadata is not set, or fitting fails.
        """
        if self.metadata is None:
            raise ValueError("Metadata dictionary must be set before fitting the model.")
        if not isinstance(data, pd.DataFrame) or data.empty:
            raise ValueError("Input 'data' must be a non-empty pandas DataFrame.")
        self.logger.info(f"Starting CTGAN training (epochs={epochs}, cuda={cuda}, verbose={verbose})")    
        discrete_columns = self.detect_discrete_columns(data)
        self.logger.info(f"Detected discrete columns: {discrete_columns}")

        nested = kwargs.pop("kwargs", None)
        if nested is not None:
            if not isinstance(nested, dict):
                self.logger.warning("The 'kwargs' entry in fit_params is not a dict and will be ignored.")
            else:
                # Merge nested kwargs with top-level kwargs. Nested keys override top-level ones.
                kwargs = {**kwargs, **nested}

        self._seed_all(seed, use_cuda=cuda)
        self.model = CTGAN(
            epochs=epochs,
            verbose=verbose,
            **kwargs
        )
        try:
            if seed is not None and hasattr(self.model, "set_random_state"):
                self.model.set_random_state(seed)
        except Exception:
            pass

        self.model.fit(data, discrete_columns=discrete_columns)
        self.logger.info("CTGAN training completed")

        model_path = os.path.join(self.output_path, "CTGAN_model.pkl")
        try:
            self.model.save(model_path)
            self.logger.info(f"Saved CTGAN model to {model_path}")
        except Exception as e:
            self.logger.warning(f"Could not save CTGAN model: {e}")

    def sample(self, n_samples: int = 10, seed:  Optional[int] = None, **kwargs) -> pd.DataFrame:
        """
        Generate synthetic samples from fitted CTGAN model.

        Args:
            n_samples (int): Number of samples to generate.
            seed (int, optional): Random seed used for sampling.
            **kwargs: Extra parameters for model.sample().

        Returns:
            pd.DataFrame: Synthetic samples.

        Raises:
            ValueError: If model not fitted or sampling fails.
        """
        if not hasattr(self, "model") or self.model is None:
            self.logger.error("Model is not fitted. Call fit() first.")
            raise ValueError("Model is not fitted. Call fit() first.")
        # Seed all relevant RNGs
        self._seed_all(seed, use_cuda=getattr(self.model, "cuda", False))

        # Also try to set model RNG if supported
        try:
            if hasattr(self.model, "set_random_state"):
                self.model.set_random_state(seed)
        except Exception:
            pass
        # random.seed(seed)
        # np.random.seed(seed)
        # self.model.set_random_state(seed)
        synthetic_data = self.model.sample(n_samples, **kwargs)
        self.logger.info(f"Generated {n_samples} synthetic samples")
        return synthetic_data

fit(data, seed=None, *, epochs=100, verbose=True, cuda=True, **kwargs)

Train the CTGAN model on provided data.

Parameters:

    data (pd.DataFrame): DataFrame for training. [required]
    seed (int, optional): Random seed used for reproducible training. [default: None]
    epochs (int): Number of training epochs. [default: 100]
    verbose (bool): Verbosity flag. [default: True]
    cuda (bool): Use GPU if True. [default: True]
    **kwargs: Extra CTGAN parameters.

Returns:

    None

Raises:

    ValueError: If metadata is not set, or fitting fails.

Source code in src/synomicsbench/synthesizer/CTGANsynthesizer.py
def fit(self, 
        data: pd.DataFrame,
        seed:  Optional[int] = None,
        *,
        epochs: int=100, 
        verbose: bool=True, 
        cuda: bool = True, 
        **kwargs):
    """
    Train the CTGAN model on provided data.

    Args:
        data (pd.DataFrame): DataFrame for training.
        seed (int, optional): Random seed used for reproducible training.
        epochs (int): Number of training epochs.
        verbose (bool): Verbosity flag.
        cuda (bool): Use GPU if True.
        **kwargs: Extra CTGAN parameters.

    Returns:
        None

    Raises:
        ValueError: If metadata is not set, or fitting fails.
    """
    if self.metadata is None:
        raise ValueError("Metadata dictionary must be set before fitting the model.")
    if not isinstance(data, pd.DataFrame) or data.empty:
        raise ValueError("Input 'data' must be a non-empty pandas DataFrame.")
    self.logger.info(f"Starting CTGAN training (epochs={epochs}, cuda={cuda}, verbose={verbose})")    
    discrete_columns = self.detect_discrete_columns(data)
    self.logger.info(f"Detected discrete columns: {discrete_columns}")

    nested = kwargs.pop("kwargs", None)
    if nested is not None:
        if not isinstance(nested, dict):
            self.logger.warning("The 'kwargs' entry in fit_params is not a dict and will be ignored.")
        else:
            # Merge nested kwargs with top-level kwargs. Nested keys override top-level ones.
            kwargs = {**kwargs, **nested}

    self._seed_all(seed, use_cuda=cuda)
    self.model = CTGAN(
        epochs=epochs,
        verbose=verbose,
        **kwargs
    )
    try:
        if seed is not None and hasattr(self.model, "set_random_state"):
            self.model.set_random_state(seed)
    except Exception:
        pass

    self.model.fit(data, discrete_columns=discrete_columns)
    self.logger.info("CTGAN training completed")

    model_path = os.path.join(self.output_path, "CTGAN_model.pkl")
    try:
        self.model.save(model_path)
        self.logger.info(f"Saved CTGAN model to {model_path}")
    except Exception as e:
        self.logger.warning(f"Could not save CTGAN model: {e}")

sample(n_samples=10, seed=None, **kwargs)

Generate synthetic samples from fitted CTGAN model.

Parameters:

    n_samples (int): Number of samples to generate. [default: 10]
    seed (int, optional): Random seed used for sampling. [default: None]
    **kwargs: Extra parameters for model.sample().

Returns:

    pd.DataFrame: Synthetic samples.

Raises:

    ValueError: If model not fitted or sampling fails.

Source code in src/synomicsbench/synthesizer/CTGANsynthesizer.py
def sample(self, n_samples: int = 10, seed:  Optional[int] = None, **kwargs) -> pd.DataFrame:
    """
    Generate synthetic samples from fitted CTGAN model.

    Args:
        n_samples (int): Number of samples to generate.
        seed (int, optional): Random seed used for sampling.
        **kwargs: Extra parameters for model.sample().

    Returns:
        pd.DataFrame: Synthetic samples.

    Raises:
        ValueError: If model not fitted or sampling fails.
    """
    if not hasattr(self, "model") or self.model is None:
        self.logger.error("Model is not fitted. Call fit() first.")
        raise ValueError("Model is not fitted. Call fit() first.")
    # Seed all relevant RNGs
    self._seed_all(seed, use_cuda=getattr(self.model, "cuda", False))

    # Also try to set model RNG if supported
    try:
        if hasattr(self.model, "set_random_state"):
            self.model.set_random_state(seed)
    except Exception:
        pass
    # random.seed(seed)
    # np.random.seed(seed)
    # self.model.set_random_state(seed)
    synthetic_data = self.model.sample(n_samples, **kwargs)
    self.logger.info(f"Generated {n_samples} synthetic samples")
    return synthetic_data

TVAEsynthesizer

Bases: BaseSynthesizer

TVAEsynthesizer for generating synthetic data using TVAE.

Parameters:

    output_path (str): Path to save outputs. [required]
    metadata (dict, optional): Metadata dictionary containing column type information. [default: None]

Methods:

    fit: Train the TVAE model.
    sample: Generate synthetic samples.
    postprocess: Apply postprocessing (anonymize IDs, rounding, scaling).
    generate: Orchestrate data synthesis pipeline.
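Usage mirrors CTGANsynthesizer; a short sketch (the import path and parameter values are assumptions):

from synomicsbench.synthesizer.TVAEsynthesizer import TVAEsynthesizer  # assumed import path

synth = TVAEsynthesizer(output_path="results/tvae", metadata=metadata)
synthetic = synth.generate(
    data=train_df,
    n_samples=500,
    fit_params={"epochs": 200, "cuda": False},  # keyword-only fit() arguments
)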

Source code in src/synomicsbench/synthesizer/TVAEsynthesizer.py
class TVAEsynthesizer(BaseSynthesizer):
    """
    TVAEsynthesizer for generating synthetic data using TVAE.

    Args:
        output_path (str): Path to save outputs.
        metadata (dict, optional): Metadata dictionary containing column type information.

    Methods:
        fit: Train the TVAE model.
        sample: Generate synthetic samples.
        postprocess: Apply postprocessing (anonymize IDs, rounding, scaling).
        generate: Orchestrate data synthesis pipeline.
    """

    def _seed_all(self, seed: Optional[int] = None, use_cuda: bool = False) -> None:
        """Seed Python, NumPy, and Torch (if available) to improve reproducibility."""
        if seed is None:
            return
        random.seed(seed)
        np.random.seed(seed)
        try:
            import torch
            torch.manual_seed(seed)
            if use_cuda and torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
        except Exception:
            # Torch not installed or not available; ignore
            pass

    def fit(self, 
            data: pd.DataFrame, 
            seed:  Optional[int] = None,
            *,
            epochs: int = 100, 
            verbose: bool = True, 
            cuda: bool = True, 
            **kwargs):
        """
        Train the TVAE model on provided data.

        Args:
            data (pd.DataFrame): DataFrame for training.
            epochs (int): Number of training epochs.
            verbose (bool): Verbosity flag.
            cuda (bool): Use GPU if True.
            seed (int, optional): Random seed for reproducibility.
            **kwargs: Extra TVAE parameters.

        Returns:
            None

        Raises:
            ValueError: If metadata is not set, input is invalid, or fitting fails.
        """
        if self.metadata is None:
            raise ValueError("Metadata dictionary must be set before fitting the model.")
        if not isinstance(data, pd.DataFrame) or data.empty:
            raise ValueError("Input 'data' must be a non-empty pandas DataFrame.")

        discrete_columns = self.detect_discrete_columns(data)
        self.logger.info(f"Detected discrete columns: {discrete_columns}")
        self.logger.info(f"Starting TVAE training (epochs={epochs}, cuda={cuda}, verbose={verbose}, seed={seed})")

        nested = kwargs.pop("kwargs", None)
        if nested is not None:
            if not isinstance(nested, dict):
                self.logger.warning("The 'kwargs' entry in fit_params is not a dict and will be ignored.")
            else:
                # Merge nested kwargs with top-level kwargs. Nested keys override top-level ones.
                kwargs = {**kwargs, **nested}
        # Seed everything we can
        self._seed_all(seed, use_cuda=cuda)

        self.model = TVAE(
            epochs=epochs,
            verbose=verbose,
            cuda=cuda,
            **kwargs
        )

        # Some TVAE builds support setting RNG state before training
        try:
            if seed is not None and hasattr(self.model, "set_random_state"):
                self.model.set_random_state(seed)
        except Exception:
            pass

        self.model.fit(data, discrete_columns=discrete_columns)
        self.logger.info("TVAE training completed")

        model_path = os.path.join(self.output_path, "TVAE_model.pkl")
        try:
            self.model.save(model_path)
            self.logger.info(f"Saved TVAE model to {model_path}")
        except Exception as e:
            self.logger.warning(f"Could not save TVAE model: {e}")

    def sample(self, n_samples: int = 10, 
               seed: Optional[int] = None, 
               **kwargs) -> pd.DataFrame:
        """
        Generate synthetic samples from fitted TVAE model.

        Args:
            n_samples (int): Number of samples to generate.
            seed (int): Random seed for reproducibility during sampling.
            **kwargs: Extra parameters for model.sample().

        Returns:
            pd.DataFrame: Synthetic samples.

        Raises:
            ValueError: If model not fitted or sampling fails.
        """
        if not hasattr(self, "model") or self.model is None:
            self.logger.error("Model is not fitted. Call fit() first.")
            raise ValueError("Model is not fitted. Call fit() first.")

        # Seed all relevant RNGs
        self._seed_all(seed, use_cuda=getattr(self.model, "cuda", False))

        # Also try to set model RNG if supported
        try:
            if hasattr(self.model, "set_random_state"):
                self.model.set_random_state(seed)
        except Exception:
            pass

        synthetic_data = self.model.sample(n_samples, **kwargs)
        self.logger.info(f"Generated {n_samples} synthetic samples")
        return synthetic_data

fit(data, seed=None, *, epochs=100, verbose=True, cuda=True, **kwargs)

Train the TVAE model on provided data.

Parameters:

    data (pd.DataFrame): DataFrame for training. [required]
    epochs (int): Number of training epochs. [default: 100]
    verbose (bool): Verbosity flag. [default: True]
    cuda (bool): Use GPU if True. [default: True]
    seed (int, optional): Random seed for reproducibility. [default: None]
    **kwargs: Extra TVAE parameters.

Returns:

    None

Raises:

    ValueError: If metadata is not set, input is invalid, or fitting fails.

Source code in src/synomicsbench/synthesizer/TVAEsynthesizer.py
def fit(self, 
        data: pd.DataFrame, 
        seed:  Optional[int] = None,
        *,
        epochs: int = 100, 
        verbose: bool = True, 
        cuda: bool = True, 
        **kwargs):
    """
    Train the TVAE model on provided data.

    Args:
        data (pd.DataFrame): DataFrame for training.
        epochs (int): Number of training epochs.
        verbose (bool): Verbosity flag.
        cuda (bool): Use GPU if True.
        seed (int, optional): Random seed for reproducibility.
        **kwargs: Extra TVAE parameters.

    Returns:
        None

    Raises:
        ValueError: If metadata is not set, input is invalid, or fitting fails.
    """
    if self.metadata is None:
        raise ValueError("Metadata dictionary must be set before fitting the model.")
    if not isinstance(data, pd.DataFrame) or data.empty:
        raise ValueError("Input 'data' must be a non-empty pandas DataFrame.")

    discrete_columns = self.detect_discrete_columns(data)
    self.logger.info(f"Detected discrete columns: {discrete_columns}")
    self.logger.info(f"Starting TVAE training (epochs={epochs}, cuda={cuda}, verbose={verbose}, seed={seed})")

    nested = kwargs.pop("kwargs", None)
    if nested is not None:
        if not isinstance(nested, dict):
            self.logger.warning("The 'kwargs' entry in fit_params is not a dict and will be ignored.")
        else:
            # Merge nested kwargs with top-level kwargs. Nested keys override top-level ones.
            kwargs = {**kwargs, **nested}
    # Seed everything we can
    self._seed_all(seed, use_cuda=cuda)

    self.model = TVAE(
        epochs=epochs,
        verbose=verbose,
        cuda=cuda,
        **kwargs
    )

    # Some TVAE builds support setting RNG state before training
    try:
        if seed is not None and hasattr(self.model, "set_random_state"):
            self.model.set_random_state(seed)
    except Exception:
        pass

    self.model.fit(data, discrete_columns=discrete_columns)
    self.logger.info("TVAE training completed")

    model_path = os.path.join(self.output_path, "TVAE_model.pkl")
    try:
        self.model.save(model_path)
        self.logger.info(f"Saved TVAE model to {model_path}")
    except Exception as e:
        self.logger.warning(f"Could not save TVAE model: {e}")

sample(n_samples=10, seed=None, **kwargs)

Generate synthetic samples from fitted TVAE model.

Parameters:

    n_samples (int): Number of samples to generate. [default: 10]
    seed (int, optional): Random seed for reproducibility during sampling. [default: None]
    **kwargs: Extra parameters for model.sample().

Returns:

    pd.DataFrame: Synthetic samples.

Raises:

    ValueError: If model not fitted or sampling fails.

Source code in src/synomicsbench/synthesizer/TVAEsynthesizer.py
def sample(self, n_samples: int = 10, 
           seed: Optional[int] = None, 
           **kwargs) -> pd.DataFrame:
    """
    Generate synthetic samples from fitted TVAE model.

    Args:
        n_samples (int): Number of samples to generate.
        seed (int): Random seed for reproducibility during sampling.
        **kwargs: Extra parameters for model.sample().

    Returns:
        pd.DataFrame: Synthetic samples.

    Raises:
        ValueError: If model not fitted or sampling fails.
    """
    if not hasattr(self, "model") or self.model is None:
        self.logger.error("Model is not fitted. Call fit() first.")
        raise ValueError("Model is not fitted. Call fit() first.")

    # Seed all relevant RNGs
    self._seed_all(seed, use_cuda=getattr(self.model, "cuda", False))

    # Also try to set model RNG if supported
    try:
        if hasattr(self.model, "set_random_state"):
            self.model.set_random_state(seed)
    except Exception:
        pass

    synthetic_data = self.model.sample(n_samples, **kwargs)
    self.logger.info(f"Generated {n_samples} synthetic samples")
    return synthetic_data

GaussianCopulasynthesizer

Bases: BaseSynthesizer

GaussianCopulasynthesizer for generating synthetic data using a Gaussian Copula model.

This refactor aligns the class with BaseSynthesizer:
  • preprocess: encodes categorical variables (dummy + ordinal) and keeps numerical/missing indicators
  • fit: trains GaussianMultivariate_Parallel on preprocessed data
  • sample: draws synthetic samples
  • postprocess: decodes categorical variables, enforces constraints, anonymizes IDs

Notes:
  • Metadata must be provided as a dictionary to the constructor or via set_metadata(): { "col_name": "ordinal_categorical" | "dummy_categorical" | "missing_categorical" | <other for numerical> }
  • Ordinal variables are encoded using sklearn's OrdinalEncoder (0..n_levels-1). During postprocess, synthetic ordinal columns are rounded and clipped to valid ranges derived from encoder categories, then inverse-transformed.
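A sketch of the expected metadata dictionary and a generate() call; the column names are placeholders, and any label other than the three categorical ones is treated as numerical:

metadata = {
    "gene_expr_1": "numerical",               # non-categorical label -> numerical
    "gene_expr_1_missing": "missing_categorical",
    "tumor_grade": "ordinal_categorical",
    "sex": "dummy_categorical",
}
synth = GaussianCopulasynthesizer(output_path="results/copula", metadata=metadata)
synthetic = synth.generate(
    data=train_df,
    n_samples=1000,
    fit_params={"n_jobs": 8, "chunk_size": 20},  # forwarded to fit()
)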
Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
class GaussianCopulasynthesizer(BaseSynthesizer):
    """
    GaussianCopulasynthesizer for generating synthetic data using a Gaussian Copula model.

    This refactor aligns the class with BaseSynthesizer:
      - preprocess: encodes categorical variables (dummy + ordinal) and keeps numerical/missing indicators
      - fit: trains GaussianMultivariate_Parallel on preprocessed data
      - sample: draws synthetic samples
      - postprocess: decodes categorical variables, enforces constraints, anonymizes IDs

    Notes:
      - Metadata must be provided as a dictionary to the constructor or via set_metadata():
          {
              "col_name": "ordinal_categorical" | "dummy_categorical" | "missing_categorical" | <other for numerical>
          }
      - Ordinal variables are encoded using sklearn's OrdinalEncoder (0..n_levels-1).
        During postprocess, synthetic ordinal columns are rounded and clipped to valid ranges
        derived from encoder categories, then inverse-transformed.
    """

    def encode_dummy_cat_features(self,data: pd.DataFrame) -> pd.DataFrame:
        """
        One-hot encode the provided categorical columns (no drop-first, no dummy for NaN).
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrecame")
        try:
            # self.logger.info(f"Dummy encoding {data.shape[1]} features")
            cat_data_encoded = pd.get_dummies(
                data, dtype="int64", dummy_na=False, columns=data.columns
            )
            return cat_data_encoded
        except Exception as e:
            raise ValueError(f"Error encoding categorical features: {e}")

    def encode_ordinal_cat_features(self, data: pd.DataFrame) -> pd.DataFrame:
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            # self.logger.info(f"Ordinal encoding {data.shape[1]} features")
            self._ordinal_encoder = OrdinalEncoder()
            encoded_values = self._ordinal_encoder.fit_transform(data)
            df_ordinal_encoded = pd.DataFrame(encoded_values, columns=data.columns, index = data.index)
            self._ordinal_valid_max = {
                col: len(cats) - 1 for col, cats in zip(data.columns, self._ordinal_encoder.categories_)
            }
            return df_ordinal_encoded
        except Exception as e:
            raise ValueError(f"Error encoding ordinal categorical features: {e}")


    def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess input data for Gaussian Copula synthesizer.

        - Splits columns by type using self.metadata
        - One-hot encodes dummy categorical columns
        - Ordinal-encodes ordinal categorical columns
        - Keeps numerical and missing indicator columns unchanged

        Args:
            data (pd.DataFrame): The input DataFrame.

        Returns:
            pd.DataFrame: Preprocessed DataFrame.

        Raises:
            ValueError: If metadata is not set.
        """
        if self.metadata is None:
            raise ValueError("Metadata dictionary must be set before preprocessing.")

        # Identify columns by type from metadata and present in data
        self._dummy_cat_cols: List[str] = [
            col for col, typ in self.metadata.items() if typ == "dummy_categorical" and col in data.columns
        ]
        self._ordinal_cat_cols: List[str] = [
            col for col, typ in self.metadata.items() if typ == "ordinal_categorical" and col in data.columns
        ]
        self._missing_indicator_cols: List[str] = [
            col for col, typ in self.metadata.items() if typ == "missing_categorical" and col in data.columns
        ]
        # Everything else present in data is considered numerical
        self._num_cols: List[str] = [
            col for col in data.columns
            if col not in set(self._dummy_cat_cols + self._ordinal_cat_cols + self._missing_indicator_cols)
        ]

        # Encode dummy categorical features
        if self._dummy_cat_cols:
            dummy_cat_encoded_df = self.encode_dummy_cat_features(data[self._dummy_cat_cols])
        else:
            dummy_cat_encoded_df = pd.DataFrame(index=data.index)

        # Encode ordinal categorical features
        if self._ordinal_cat_cols:
            ordinal_cat_cols_encoded_df = self.encode_ordinal_cat_features(data[self._ordinal_cat_cols])
        else:
            ordinal_cat_cols_encoded_df = pd.DataFrame(index=data.index)
            self._ordinal_encoder = None
            self._ordinal_valid_max = {}

        # Keep other columns (numerical and missing indicators)
        drop_cols = self._dummy_cat_cols + self._ordinal_cat_cols
        num_missing_indicators_df = data.drop(columns=drop_cols) if drop_cols else data.copy()

        # Concatenate all processed columns
        processed_data = pd.concat(
            [num_missing_indicators_df, dummy_cat_encoded_df, ordinal_cat_cols_encoded_df], axis=1
        )

        self.logger.info(
            f"Preprocess summary: num={len(self._num_cols)}, dummy={len(self._dummy_cat_cols)}, "
            f"ordinal={len(self._ordinal_cat_cols)}, missing_ind={len(self._missing_indicator_cols)}. "
            f"Processed shape: {processed_data.shape}"
        )
        return processed_data

    def fit(self, 
            data: pd.DataFrame,
            seed: Optional[int] = None,
            n_jobs: int = 8, 
            chunk_size: int = 20, 
            **kwargs) -> None:
        """
        Train the Gaussian Copula model on preprocessed data.

        Args:
            data (pd.DataFrame): Preprocessed DataFrame (output of preprocess()).
            n_jobs (int): Number of parallel worker threads.
            chunk_size (int): Chunk size for parallel fitting.
            **kwargs: Extra GaussianMultivariate_Parallel parameters.

        Returns:
            None
        """
        self.logger.info(
            f"Starting fitting Gaussian Copula Synthesizer by {n_jobs} threads with chunk size {chunk_size}"
        )
        nested = kwargs.pop("kwargs", None)
        if nested is not None:
            if not isinstance(nested, dict):
                self.logger.warning("The 'kwargs' entry in fit_params is not a dict and will be ignored.")
            else:
                # Merge nested kwargs with top-level kwargs. Nested keys override top-level ones.
                kwargs = {**kwargs, **nested}

        self.model = GaussianMultivariate_Parallel(n_jobs=n_jobs, chunk_size=chunk_size, **kwargs)
        self.model.fit(data)

        # Save model to disk
        model_path = os.path.join(self.output_path, "GaussianCopula_model.pkl")
        try:
            with open(model_path, "wb") as f:
                pickle.dump(self.model, f)
            self.logger.info(f"Saved Gaussian Copula model to {model_path}")
        except Exception as e:
            self.logger.warning(f"Could not save Gaussian Copula model: {e}")

    def sample(self, 
               n_samples: int = 10, 
               seed: Optional[int] = None, 
               **kwargs) -> pd.DataFrame:
        """
        Generate synthetic samples from fitted Gaussian Copula model.

        Args:
            n_samples (int): Number of samples to generate.
            seed (int): Random seed for reproducibility.
            **kwargs: Extra parameters for model.sample().

        Returns:
            pd.DataFrame: Synthetic samples.

        Raises:
            ValueError: If model is not fitted.
        """
        if not hasattr(self, "model") or self.model is None:
            self.logger.error("Model is not fitted. Call fit() first.")
            raise ValueError("Model is not fitted. Call fit() first.")
        try:
            self.model.set_random_state(seed)
        except Exception:
            pass
        synthetic_data = self.model.sample(num_rows=n_samples, **kwargs)
        self.logger.info(f"Generated {n_samples} synthetic samples with Gaussian Copula")
        return synthetic_data

    def inverse_dummy_cat_features(self, data: pd.DataFrame, dummy_cat_columns: list) -> pd.DataFrame:
        """
        Inverse-transform one-hot encoded dummy categorical columns back to single categorical columns.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input data must be a pandas DataFrame")
        if not isinstance(dummy_cat_columns, list) or not dummy_cat_columns:
            raise TypeError("dummy_cat_columns must be a non-empty list of column names")
        try:
            # self.logger.debug("Inverse encoding dummy categorical features")
            out = data.copy()
            cat_features_dict = {}
            for cat_feature in dummy_cat_columns:
                for column in data.columns:
                    if column.startswith(cat_feature):
                        cat_features_dict[cat_feature] = cat_features_dict.get(
                            cat_feature, []
                        ) + [column]
            for feature, cols in cat_features_dict.items():
                out[feature] = np.array(cols)[out[cols].to_numpy().argmax(axis=1)]
                out[feature] = out[feature].str.replace(f"{feature}_", "")
                out.drop(columns=cols, inplace=True)
            out = out[dummy_cat_columns].astype("category")
            return out
        except Exception as e:
            raise ValueError(f"Error in inversing encoding categorical features: {e}")

    def inverse_ordinal_cat_features(self, data: pd.DataFrame, ordinal_cat_columns: list) -> pd.DataFrame:
        """
        Perform inverse ordinal encoding on categorical features.
        """
        if not hasattr(self, "_ordinal_encoder") or self._ordinal_encoder is None:
            raise AttributeError("Ordinal encoder not initialized. Run encode_ordinal_cat_features first")
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input data must be a pandas DataFrame")
        try:
            # self.logger.debug("Inverse encoding ordinal categorical features")
            df = data.copy()
            for col in ordinal_cat_columns:
                df[col] = np.round(df[col]).astype(int)
                # Clip to valid range [0, n_levels-1] based on fitted encoder
                vmax = self._ordinal_valid_max.get(col, None)
                if vmax is not None:
                    df[col] = df[col].clip(lower=0, upper=vmax)

            inv = pd.DataFrame(
                self._ordinal_encoder.inverse_transform(df[ordinal_cat_columns]),
                columns=ordinal_cat_columns,
                index=df.index,
            )
            return inv
        except Exception as e:
            raise ValueError(f"Error in inversing ordinal encoding: {e}")

    def inverse_missing_indicators(self, data: pd.DataFrame, missing_indicators: List[str]) -> pd.DataFrame:
        """
        Threshold missing indicator columns at 0.5 and cast to float.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input data must be a pandas DataFrame")
        if not isinstance(missing_indicators, list):
            raise TypeError("missing_indicators must be a list of column names")
        try:
            out = data.copy()
            for col in missing_indicators:
                if col in out.columns:
                    out[col] = (out[col] > 0.5).astype(float)
            return out
        except Exception as e:
            raise ValueError(f"Error in inversing missing indicators: {e}")

    def postprocess(
        self,
        synthetic_data: pd.DataFrame,
        original_data: pd.DataFrame,
        data_ids: Optional[Sequence[Any]] = None,
        enforce_rounding: bool = True,
        enforce_min_max: bool = True,
        masking: bool = False,
    ) -> pd.DataFrame:
        """
        Inverse the encodings back to original space, then delegate anonymization, rounding,
        min-max clipping, and optional masking to BaseSynthesizer.postprocess.
        """
        try:
            sdf = synthetic_data.copy()

            # 1) Inverse dummy categorical features
            if getattr(self, "_dummy_cat_cols", []):
                dummy_prefixes = [f"{c}_" for c in self._dummy_cat_cols]
                dummy_cols_present = [c for c in sdf.columns if any(c.startswith(p) for p in dummy_prefixes)]
                inv_dummy_df = (
                    self.inverse_dummy_cat_features(sdf[dummy_cols_present], self._dummy_cat_cols)
                    if dummy_cols_present
                    else pd.DataFrame(index=sdf.index)
                )
            else:
                inv_dummy_df = pd.DataFrame(index=sdf.index)

            # 2) Inverse ordinal categorical features
            if getattr(self, "_ordinal_cat_cols", []) and getattr(self, "_ordinal_encoder", None) is not None:
                ord_slice = (
                    sdf[self._ordinal_cat_cols].copy()
                    if set(self._ordinal_cat_cols).issubset(sdf.columns)
                    else pd.DataFrame(index=sdf.index)
                )
                ord_slice = ord_slice.reindex(columns=self._ordinal_cat_cols, fill_value=0)
                inv_ord_df = self.inverse_ordinal_cat_features(ord_slice, self._ordinal_cat_cols)
            else:
                inv_ord_df = pd.DataFrame(index=sdf.index)

            # 3) Inverse missing indicators
            if getattr(self, "_missing_indicator_cols", []):
                miss_slice = (
                    sdf[self._missing_indicator_cols].copy()
                    if set(self._missing_indicator_cols).issubset(sdf.columns)
                    else pd.DataFrame(index=sdf.index)
                )
                miss_slice = miss_slice.reindex(columns=self._missing_indicator_cols, fill_value=0.0)
                inv_miss_df = self.inverse_missing_indicators(miss_slice, self._missing_indicator_cols)
            else:
                inv_miss_df = pd.DataFrame(index=sdf.index)

            # 4) Numeric slice
            num_slice_cols = [c for c in getattr(self, "_num_cols", []) if c in sdf.columns]
            num_slice = sdf[num_slice_cols].copy() if num_slice_cols else pd.DataFrame(index=sdf.index)

            # 5) Concatenate back to "original space"
            inverse_full = pd.concat([num_slice, inv_miss_df, inv_ord_df, inv_dummy_df], axis=1)

            # Reorder columns to match original_data first, then append any extras
            common_cols = [col for col in original_data.columns if col in inverse_full.columns]
            inverse_full = inverse_full.reindex(
                columns=common_cols + [c for c in inverse_full.columns if c not in common_cols]
            )

            # Delegate rounding/min-max/anonymization/masking to the base class
            processed = super().postprocess(
                synthetic_data=inverse_full,
                original_data=original_data,
                data_ids=data_ids,
                enforce_rounding=enforce_rounding,
                enforce_min_max=enforce_min_max,
                masking=masking,
            )
            return processed
        except Exception as e:
            raise ValueError(f"Error in post-processing synthetic data: {e}")

encode_dummy_cat_features(data)

One-hot encode the provided categorical columns (no drop-first, no dummy for NaN).

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def encode_dummy_cat_features(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    One-hot encode the provided categorical columns (no drop-first, no dummy for NaN).
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrecame")
    try:
        # self.logger.info(f"Dummy encoding {data.shape[1]} features")
        cat_data_encoded = pd.get_dummies(
            data, dtype="int64", dummy_na=False, columns=data.columns
        )
        return cat_data_encoded
    except Exception as e:
        raise ValueError(f"Error encoding categorical features: {e}")

fit(data, seed=None, n_jobs=8, chunk_size=20, **kwargs)

Train the Gaussian Copula model on preprocessed data.

Parameters:

  • data (DataFrame): Preprocessed DataFrame (output of preprocess()). Required.
  • n_jobs (int): Number of parallel worker threads. Default: 8
  • chunk_size (int): Chunk size for parallel fitting. Default: 20
  • **kwargs: Extra GaussianMultivariate_Parallel parameters.

Returns:

  • None

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def fit(self, 
        data: pd.DataFrame,
        seed: Optional[int] = None,
        n_jobs: int = 8, 
        chunk_size: int = 20, 
        **kwargs) -> None:
    """
    Train the Gaussian Copula model on preprocessed data.

    Args:
        data (pd.DataFrame): Preprocessed DataFrame (output of preprocess()).
        n_jobs (int): Number of parallel worker threads.
        chunk_size (int): Chunk size for parallel fitting.
        **kwargs: Extra GaussianMultivariate_Parallel parameters.

    Returns:
        None
    """
    self.logger.info(
        f"Starting fitting Gaussian Copula Synthesizer by {n_jobs} threads with chunk size {chunk_size}"
    )
    nested = kwargs.pop("kwargs", None)
    if nested is not None:
        if not isinstance(nested, dict):
            self.logger.warning("The 'kwargs' entry in fit_params is not a dict and will be ignored.")
        else:
            # Merge nested kwargs with top-level kwargs. Nested keys override top-level ones.
            kwargs = {**kwargs, **nested}

    self.model = GaussianMultivariate_Parallel(n_jobs=n_jobs, chunk_size=chunk_size, **kwargs)
    self.model.fit(data)

    # Save model to disk
    model_path = os.path.join(self.output_path, "GaussianCopula_model.pkl")
    try:
        with open(model_path, "wb") as f:
            pickle.dump(self.model, f)
        self.logger.info(f"Saved Gaussian Copula model to {model_path}")
    except Exception as e:
        self.logger.warning(f"Could not save Gaussian Copula model: {e}")

inverse_dummy_cat_features(data, dummy_cat_columns)

Inverse-transform one-hot encoded dummy categorical columns back to single categorical columns.

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def inverse_dummy_cat_features(self, data: pd.DataFrame, dummy_cat_columns: list) -> pd.DataFrame:
    """
    Inverse-transform one-hot encoded dummy categorical columns back to single categorical columns.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input data must be a pandas DataFrame")
    if not isinstance(dummy_cat_columns, list) or not dummy_cat_columns:
        raise TypeError("dummy_cat_columns must be a non-empty list of column names")
    try:
        # self.logger.debug("Inverse encoding dummy categorical features")
        out = data.copy()
        cat_features_dict = {}
        for cat_feature in dummy_cat_columns:
            for column in data.columns:
                if column.startswith(cat_feature):
                    cat_features_dict[cat_feature] = cat_features_dict.get(
                        cat_feature, []
                    ) + [column]
        for feature, cols in cat_features_dict.items():
            out[feature] = np.array(cols)[out[cols].to_numpy().argmax(axis=1)]
            out[feature] = out[feature].str.replace(f"{feature}_", "")
            out.drop(columns=cols, inplace=True)
        out = out[dummy_cat_columns].astype("category")
        return out
    except Exception as e:
        raise ValueError(f"Error in inversing encoding categorical features: {e}")

inverse_missing_indicators(data, missing_indicators)

Threshold missing indicator columns at 0.5 and cast to float.

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def inverse_missing_indicators(self, data: pd.DataFrame, missing_indicators: List[str]) -> pd.DataFrame:
    """
    Threshold missing indicator columns at 0.5 and cast to float.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input data must be a pandas DataFrame")
    if not isinstance(missing_indicators, list):
        raise TypeError("missing_indicators must be a list of column names")
    try:
        out = data.copy()
        for col in missing_indicators:
            if col in out.columns:
                out[col] = (out[col] > 0.5).astype(float)
        return out
    except Exception as e:
        raise ValueError(f"Error in inversing missing indicators: {e}")

inverse_ordinal_cat_features(data, ordinal_cat_columns)

Perform inverse ordinal encoding on categorical features.

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def inverse_ordinal_cat_features(self, data: pd.DataFrame, ordinal_cat_columns: list) -> pd.DataFrame:
    """
    Perform inverse ordinal encoding on categorical features.
    """
    if not hasattr(self, "_ordinal_encoder") or self._ordinal_encoder is None:
        raise AttributeError("Ordinal encoder not initialized. Run encode_ordinal_cat_features first")
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input data must be a pandas DataFrame")
    try:
        # self.logger.debug("Inverse encoding ordinal categorical features")
        df = data.copy()
        for col in ordinal_cat_columns:
            df[col] = np.round(df[col]).astype(int)
            # Clip to valid range [0, n_levels-1] based on fitted encoder
            vmax = self._ordinal_valid_max.get(col, None)
            if vmax is not None:
                df[col] = df[col].clip(lower=0, upper=vmax)

        inv = pd.DataFrame(
            self._ordinal_encoder.inverse_transform(df[ordinal_cat_columns]),
            columns=ordinal_cat_columns,
            index=df.index,
        )
        return inv
    except Exception as e:
        raise ValueError(f"Error in inversing ordinal encoding: {e}")

postprocess(synthetic_data, original_data, data_ids=None, enforce_rounding=True, enforce_min_max=True, masking=False)

Inverse the encodings back to original space, then delegate anonymization, rounding, min-max clipping, and optional masking to BaseSynthesizer.postprocess.

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def postprocess(
    self,
    synthetic_data: pd.DataFrame,
    original_data: pd.DataFrame,
    data_ids: Optional[Sequence[Any]] = None,
    enforce_rounding: bool = True,
    enforce_min_max: bool = True,
    masking: bool = False,
) -> pd.DataFrame:
    """
    Inverse the encodings back to original space, then delegate anonymization, rounding,
    min-max clipping, and optional masking to BaseSynthesizer.postprocess.
    """
    try:
        sdf = synthetic_data.copy()

        # 1) Inverse dummy categorical features
        if getattr(self, "_dummy_cat_cols", []):
            dummy_prefixes = [f"{c}_" for c in self._dummy_cat_cols]
            dummy_cols_present = [c for c in sdf.columns if any(c.startswith(p) for p in dummy_prefixes)]
            inv_dummy_df = (
                self.inverse_dummy_cat_features(sdf[dummy_cols_present], self._dummy_cat_cols)
                if dummy_cols_present
                else pd.DataFrame(index=sdf.index)
            )
        else:
            inv_dummy_df = pd.DataFrame(index=sdf.index)

        # 2) Inverse ordinal categorical features
        if getattr(self, "_ordinal_cat_cols", []) and getattr(self, "_ordinal_encoder", None) is not None:
            ord_slice = (
                sdf[self._ordinal_cat_cols].copy()
                if set(self._ordinal_cat_cols).issubset(sdf.columns)
                else pd.DataFrame(index=sdf.index)
            )
            ord_slice = ord_slice.reindex(columns=self._ordinal_cat_cols, fill_value=0)
            inv_ord_df = self.inverse_ordinal_cat_features(ord_slice, self._ordinal_cat_cols)
        else:
            inv_ord_df = pd.DataFrame(index=sdf.index)

        # 3) Inverse missing indicators
        if getattr(self, "_missing_indicator_cols", []):
            miss_slice = (
                sdf[self._missing_indicator_cols].copy()
                if set(self._missing_indicator_cols).issubset(sdf.columns)
                else pd.DataFrame(index=sdf.index)
            )
            miss_slice = miss_slice.reindex(columns=self._missing_indicator_cols, fill_value=0.0)
            inv_miss_df = self.inverse_missing_indicators(miss_slice, self._missing_indicator_cols)
        else:
            inv_miss_df = pd.DataFrame(index=sdf.index)

        # 4) Numeric slice
        num_slice_cols = [c for c in getattr(self, "_num_cols", []) if c in sdf.columns]
        num_slice = sdf[num_slice_cols].copy() if num_slice_cols else pd.DataFrame(index=sdf.index)

        # 5) Concatenate back to "original space"
        inverse_full = pd.concat([num_slice, inv_miss_df, inv_ord_df, inv_dummy_df], axis=1)

        # Reorder columns to match original_data first, then append any extras
        common_cols = [col for col in original_data.columns if col in inverse_full.columns]
        inverse_full = inverse_full.reindex(
            columns=common_cols + [c for c in inverse_full.columns if c not in common_cols]
        )

        # Delegate rounding/min-max/anonymization/masking to the base class
        processed = super().postprocess(
            synthetic_data=inverse_full,
            original_data=original_data,
            data_ids=data_ids,
            enforce_rounding=enforce_rounding,
            enforce_min_max=enforce_min_max,
            masking=masking,
        )
        return processed
    except Exception as e:
        raise ValueError(f"Error in post-processing synthetic data: {e}")

preprocess(data)

Preprocess input data for Gaussian Copula synthesizer.

  • Splits columns by type using self.metadata
  • One-hot encodes dummy categorical columns
  • Ordinal-encodes ordinal categorical columns
  • Keeps numerical and missing indicator columns unchanged

Parameters:

  • data (DataFrame): The input DataFrame. Required.

Returns:

  • DataFrame: Preprocessed DataFrame.

Raises:

  • ValueError: If metadata is not set.

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess input data for Gaussian Copula synthesizer.

    - Splits columns by type using self.metadata
    - One-hot encodes dummy categorical columns
    - Ordinal-encodes ordinal categorical columns
    - Keeps numerical and missing indicator columns unchanged

    Args:
        data (pd.DataFrame): The input DataFrame.

    Returns:
        pd.DataFrame: Preprocessed DataFrame.

    Raises:
        ValueError: If metadata is not set.
    """
    if self.metadata is None:
        raise ValueError("Metadata dictionary must be set before preprocessing.")

    # Identify columns by type from metadata and present in data
    self._dummy_cat_cols: List[str] = [
        col for col, typ in self.metadata.items() if typ == "dummy_categorical" and col in data.columns
    ]
    self._ordinal_cat_cols: List[str] = [
        col for col, typ in self.metadata.items() if typ == "ordinal_categorical" and col in data.columns
    ]
    self._missing_indicator_cols: List[str] = [
        col for col, typ in self.metadata.items() if typ == "missing_categorical" and col in data.columns
    ]
    # Everything else present in data is considered numerical
    self._num_cols: List[str] = [
        col for col in data.columns
        if col not in set(self._dummy_cat_cols + self._ordinal_cat_cols + self._missing_indicator_cols)
    ]

    # Encode dummy categorical features
    if self._dummy_cat_cols:
        dummy_cat_encoded_df = self.encode_dummy_cat_features(data[self._dummy_cat_cols])
    else:
        dummy_cat_encoded_df = pd.DataFrame(index=data.index)

    # Encode ordinal categorical features
    if self._ordinal_cat_cols:
        ordinal_cat_cols_encoded_df = self.encode_ordinal_cat_features(data[self._ordinal_cat_cols])
    else:
        ordinal_cat_cols_encoded_df = pd.DataFrame(index=data.index)
        self._ordinal_encoder = None
        self._ordinal_valid_max = {}

    # Keep other columns (numerical and missing indicators)
    drop_cols = self._dummy_cat_cols + self._ordinal_cat_cols
    num_missing_indicators_df = data.drop(columns=drop_cols) if drop_cols else data.copy()

    # Concatenate all processed columns
    processed_data = pd.concat(
        [num_missing_indicators_df, dummy_cat_encoded_df, ordinal_cat_cols_encoded_df], axis=1
    )

    self.logger.info(
        f"Preprocess summary: num={len(self._num_cols)}, dummy={len(self._dummy_cat_cols)}, "
        f"ordinal={len(self._ordinal_cat_cols)}, missing_ind={len(self._missing_indicator_cols)}. "
        f"Processed shape: {processed_data.shape}"
    )
    return processed_data
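
An illustrative metadata mapping (keys are column names; the three categorical tags are the ones preprocess() recognizes, and any untagged column present in the data is treated as numerical):

metadata = {
    "age": "numerical",                          # not one of the three tags -> numerical
    "stage": "dummy_categorical",                # one-hot encoded
    "grade": "ordinal_categorical",              # ordinal-encoded
    "biomarker_missing": "missing_categorical",  # kept as a 0/1 indicator
}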

sample(n_samples=10, seed=None, **kwargs)

Generate synthetic samples from fitted Gaussian Copula model.

Parameters:

  • n_samples (int): Number of samples to generate. Default: 10
  • seed (int): Random seed for reproducibility. Default: None
  • **kwargs: Extra parameters for model.sample().

Returns:

  • DataFrame: Synthetic samples.

Raises:

  • ValueError: If model is not fitted.

Source code in src/synomicsbench/synthesizer/GaussianCopulasynthesizer.py
def sample(self, 
           n_samples: int = 10, 
           seed: Optional[int] = None, 
           **kwargs) -> pd.DataFrame:
    """
    Generate synthetic samples from fitted Gaussian Copula model.

    Args:
        n_samples (int): Number of samples to generate.
        seed (int): Random seed for reproducibility.
        **kwargs: Extra parameters for model.sample().

    Returns:
        pd.DataFrame: Synthetic samples.

    Raises:
        ValueError: If model is not fitted.
    """
    if not hasattr(self, "model") or self.model is None:
        self.logger.error("Model is not fitted. Call fit() first.")
        raise ValueError("Model is not fitted. Call fit() first.")
    try:
        self.model.set_random_state(seed)
    except Exception:
        pass
    synthetic_data = self.model.sample(num_rows=n_samples, **kwargs)
    self.logger.info(f"Generated {n_samples} synthetic samples with Gaussian Copula")
    return synthetic_data
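
The seed is forwarded to the fitted model via set_random_state and extra keyword arguments go straight to model.sample(). A short sketch, continuing from the fit example above:

# `synth` is the fitted synthesizer from the fit sketch above.
draw_a = synth.sample(n_samples=200, seed=7)
draw_b = synth.sample(n_samples=200, seed=7)
# Reproducibility depends on the underlying model exposing set_random_state;
# the call is wrapped in try/except, so unsupported models simply ignore the seed.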

SynthpopSynthesizer

Bases: BaseSynthesizer

SynthpopSynthesizer for generating synthetic data using Synthpop (R) via rpy2.

This implementation aligns with BaseSynthesizer:
  • preprocess: pass-through by default
  • fit: stores the training DataFrame for the R call
  • sample: calls R's synthpop::syn, returns a DataFrame or list of DataFrames (for m > 1)
  • postprocess: anonymize IDs, rounding, min-max; works for single or multiple datasets via BaseSynthesizer.generate

Parameters:

  • output_path (str): Directory where outputs are saved. Required.
  • metadata (dict): Column-type metadata. Keys are column names, values are one of 'dummy_categorical', 'ordinal_categorical', 'missing_categorical', or a numeric type string. Default: None
  • r_home (str): Path to R_HOME. If provided, sets os.environ['R_HOME'] at init. Default: None
  • r_terminal (str): R executable name or path used to configure library search paths. Default: 'R'
Source code in src/synomicsbench/synthesizer/Synthpopsynthesizer.py
class SynthpopSynthesizer(BaseSynthesizer):
    """
    SynthpopSynthesizer for generating synthetic data using Synthpop (R) via rpy2.

    This implementation aligns with BaseSynthesizer:
      - preprocess: pass-through by default
      - fit: stores the training DataFrame for the R call
      - sample: calls R's synthpop::syn, returns a DataFrame or list of DataFrames (for m > 1)
      - postprocess: anonymize IDs, rounding, min-max; works for single or multiple datasets via BaseSynthesizer.generate

    Args:
        output_path (str): Directory where outputs are saved.
        metadata (dict, optional): Column-type metadata. Keys are column names, values are
            one of 'dummy_categorical', 'ordinal_categorical', 'missing_categorical', or a numeric type string.
        r_home (str, optional): Path to R_HOME. If provided, sets os.environ['R_HOME'] at init.
        r_terminal (str, optional): R executable name or path used to configure library search paths. Default 'R'.
    """

    def __init__(
        self,
        output_path: str,
        metadata: Optional[dict] = None,
        r_home: Optional[str] = None,
        r_terminal: str = "R",
    ) -> None:
        super().__init__(output_path=output_path, metadata=metadata)
        if r_home:
            os.environ["R_HOME"] = r_home
            self.logger.info(f"R_HOME set to: {r_home}")
        self.r_terminal = r_terminal
        self._train_df: Optional[pd.DataFrame] = None

    def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess input data for Synthpop (default: pass-through).

        Args:
            data (pd.DataFrame): Input data.

        Returns:
            pd.DataFrame: Preprocessed data (unchanged).
        """
        return data

    def fit(self, data: pd.DataFrame, seed: int = 42, **kwargs) -> None:
        """
        Store the training data for later use in sample().

        Args:
            data (pd.DataFrame): Preprocessed training data.
            **kwargs: Unused; for API compatibility.

        Returns:
            None
        """
        if not isinstance(data, pd.DataFrame) or data.empty:
            raise ValueError("Input 'data' must be a non-empty pandas DataFrame.")
        seed = seed
        self._train_df = data.copy()
        self.logger.info(f"Stored training data for Synthpop with shape {self._train_df.shape}")

    def _ensure_r_libpath(self) -> None:
        """
        Ensure R library path is available in the current session by forwarding .libPaths().
        """
        # If user provides a custom lib path (via R), we can augment here by querying from R.
        # Many environments work without overriding .libPaths(), so we keep this minimal.
        try:
            robjects.r(".libPaths()")  # no-op, ensures R is alive
        except Exception as e:
            self.logger.warning(f"Could not query R library paths: {e}")

    def sample(
        self,
        n_samples: Union[int, str] = "auto",
        seed: int = 42,
        *,
        discrete_columns: Optional[List[str]] = None,
        method: str = "cart",
        minimumlevels: int = 3,
        proper: bool = False,
        n_datasets: int = 1,
        visit_sequence: Optional[List[str]] = None,
        cont_na: Optional[dict] = None,
        verbose: bool = True,
        predictor_matrix: Optional[pd.DataFrame] = None,
        **kwargs,
    ) -> Union[pd.DataFrame, List[pd.DataFrame]]:
        """
        Call R's synthpop::syn to generate synthetic data.

        Args:
            n_samples (int or 'auto'): Number of rows to synthesize; 'auto' uses training size.
            discrete_columns (list[str], optional): categorical columns.
            method (str): Synthpop synthesis method (e.g., 'cart', 'parametric', ...).
            minimumlevels (int): Minimum levels for categorical variables.
            proper (bool): Proper synthesis flag.
            n_datasets (int): Number of synthetic datasets (m).
            visit_sequence (list[str], optional): Variable visit sequence.
            cont_na (dict, optional): Settings for NA handling of continuous variables.
            seed (int): Random seed for reproducibility on the R side.
            verbose (bool): Verbosity for R synthpop.
            predictor_matrix (pd.DataFrame, optional): Square 0/1 matrix restricting predictors.
                Must have the same index and columns as training data columns.


        Returns:
            pd.DataFrame or list[pd.DataFrame]: Synthetic dataset(s).

        Raises:
            RuntimeError: If R synthesis fails.
            ValueError: If fit() has not been called or inputs are invalid.
        """
        if self._train_df is None:
            raise ValueError("Call fit(data) before sample().")
        df = self._train_df.reset_index(drop=True)

        # Resolve n_samples
        if n_samples == "auto":
            n_samples = df.shape[0]
        if not isinstance(n_samples, int) or n_samples <= 0:
            raise ValueError("n_samples must be a positive integer or 'auto'.")

        # Resolve discrete columns
        # disc_cols = self._build_discrete_columns(df, override=discrete_columns)

        # Validate predictor_matrix
        r_predictor_matrix = robjects.NULL
        if predictor_matrix is not None:
            if not isinstance(predictor_matrix, pd.DataFrame):
                raise ValueError("predictor_matrix must be a pandas DataFrame if provided.")
            if list(predictor_matrix.columns) != list(df.columns) or list(predictor_matrix.index) != list(df.columns):
                raise ValueError("predictor_matrix must have the same columns and index as the training DataFrame.")
            # Extract row and column names
            row_names = list(predictor_matrix.index)
            col_names = list(predictor_matrix.columns)

            # Extract matrix values and convert to R-compatible format
            pm_np = predictor_matrix.values.astype(int)
            nr, nc = pm_np.shape
            r_values = robjects.IntVector(pm_np.flatten(order='F'))

            # Create the R matrix
            r_matrix = robjects.r['matrix'](r_values, nrow=nr, ncol=nc)

            # Assign dimnames (row and column names) to the matrix
            r_matrix.do_slot_assign(
                "dimnames",
                robjects.ListVector([
                    robjects.StrVector(row_names),  # First element: Row names
                    robjects.StrVector(col_names)   # Second element: Column names
                ])
            )

            r_predictor_matrix = r_matrix
        else:
            r_predictor_matrix = robjects.NULL

        # R helper for synthesis
        r_code = """
        suppressPackageStartupMessages(library(synthpop))
        library(dplyr)
        create_synthetic_data <- function(data, seed, discrete_columns, minnumlevels, 
                                       method, proper, m, k, visit_sequence, cont_na, verbose, predictorMatrix = NULL) {
                cat_features <- discrete_columns
                cat_features <- trimws(cat_features)
                data[cat_features] <- lapply(data[cat_features], factor)

                if (is.null(visit_sequence)) {
                    visit_seq <- colnames(data)
                } else {
                    visit_seq <- trimws(visit_sequence)
                }

                if (is.null(cont_na)) {
                    cont_na <- NULL
                }

                set.seed(seed)
                synthpop_clinical_res <- syn(
                    data, 
                    minnumlevels = minnumlevels,
                    method = method,
                    proper = proper,
                    m = m,
                    k = if (is.null(k)) nrow(data) else k,
                    visit.sequence = visit_seq,
                    cont.na = cont_na,
                    print.flag = verbose,
                    predictor.matrix = predictorMatrix
                )

                if (m > 1) {
                    synth_data <- synthpop_clinical_res$syn
                } else {
                    synth_data <- synthpop_clinical_res$syn
                }

                # Convert factors to characters before returning to Python
                if (m > 1) {
                    synth_data <- lapply(synth_data, function(df) {
                        df[] <- lapply(df, function(x) {
                            if (is.factor(x)) as.character(x) else x
                        })
                        return(df)
                    })
                } else {
                    synth_data[] <- lapply(synth_data, function(x) {
                        if (is.factor(x)) as.character(x) else x
                    })
                }
                return(synth_data)
            }
            """

        try:
            # Ensure R lib path callable (no-op if not needed)
            self._ensure_r_libpath()

            # Load helper into R
            robjects.r(r_code)
            r_func = robjects.globalenv["create_synthetic_data"]

            # Prepare R args
            r_discrete_columns = robjects.StrVector(discrete_columns)
            r_visit_sequence = robjects.StrVector(visit_sequence) if visit_sequence else robjects.NULL
            r_cont_na = robjects.ListVector(cont_na) if cont_na else robjects.NULL

            with conversion.localconverter(default_converter + pandas2ri_converter):
                r_data = conversion.py2rpy(df)

            r_method = robjects.StrVector([method])
            r_proper = robjects.BoolVector([proper])
            r_verbose = robjects.BoolVector([verbose])

            self.logger.info(
                f"Running Synthpop (method={method}, proper={proper}, m={n_datasets}, k={n_samples}, "
                f"minnumlevels={minimumlevels}, seed={seed})"
            )

            r_out = r_func(
                r_data,
                seed=seed,
                discrete_columns=r_discrete_columns,
                minnumlevels=minimumlevels,
                method=r_method,
                proper=r_proper,
                m=n_datasets,
                k=n_samples,
                visit_sequence=r_visit_sequence,
                cont_na=r_cont_na,
                verbose=r_verbose,
                predictorMatrix=r_predictor_matrix,
            )

            if n_datasets == 1:
                with conversion.localconverter(default_converter + pandas2ri_converter):
                    synthetic_df = conversion.rpy2py(r_out)
                self.logger.info(f"Synthpop generated 1 dataset with shape {synthetic_df.shape}")
                return synthetic_df
            else:
                out_list: List[pd.DataFrame] = []
                for i, r_df in enumerate(r_out, start=1):
                    with conversion.localconverter(default_converter + pandas2ri_converter):
                        pdf = conversion.rpy2py(r_df)
                    self.logger.info(f"Synthpop dataset {i}/{n_datasets} shape: {pdf.shape}")
                    out_list.append(pdf)
                return out_list

        except Exception as e:
            raise RuntimeError(f"Synthpop synthesis failed: {e}")

fit(data, seed=42, **kwargs)

Store the training data for later use in sample().

Parameters:

  • data (DataFrame): Preprocessed training data. Required.
  • **kwargs: Unused; for API compatibility.

Returns:

  • None

Source code in src/synomicsbench/synthesizer/Synthpopsynthesizer.py
def fit(self, data: pd.DataFrame, seed: int = 42, **kwargs) -> None:
    """
    Store the training data for later use in sample().

    Args:
        data (pd.DataFrame): Preprocessed training data.
        **kwargs: Unused; for API compatibility.

    Returns:
        None
    """
    if not isinstance(data, pd.DataFrame) or data.empty:
        raise ValueError("Input 'data' must be a non-empty pandas DataFrame.")
    seed = seed
    self._train_df = data.copy()
    self.logger.info(f"Stored training data for Synthpop with shape {self._train_df.shape}")

preprocess(data)

Preprocess input data for Synthpop (default: pass-through).

Parameters:

  • data (DataFrame): Input data. Required.

Returns:

  • DataFrame: Preprocessed data (unchanged).

Source code in src/synomicsbench/synthesizer/Synthpopsynthesizer.py
def preprocess(self, data: pd.DataFrame) -> pd.DataFrame:
    """
    Preprocess input data for Synthpop (default: pass-through).

    Args:
        data (pd.DataFrame): Input data.

    Returns:
        pd.DataFrame: Preprocessed data (unchanged).
    """
    return data

sample(n_samples='auto', seed=42, *, discrete_columns=None, method='cart', minimumlevels=3, proper=False, n_datasets=1, visit_sequence=None, cont_na=None, verbose=True, predictor_matrix=None, **kwargs)

Call R's synthpop::syn to generate synthetic data.

Parameters:

  • n_samples (int or 'auto'): Number of rows to synthesize; 'auto' uses the training size. Default: 'auto'
  • discrete_columns (list[str], optional): Categorical columns. Default: None
  • method (str): Synthpop synthesis method (e.g., 'cart', 'parametric', ...). Default: 'cart'
  • minimumlevels (int): Minimum levels for categorical variables. Default: 3
  • proper (bool): Proper synthesis flag. Default: False
  • n_datasets (int): Number of synthetic datasets (m). Default: 1
  • visit_sequence (list[str], optional): Variable visit sequence. Default: None
  • cont_na (dict, optional): Settings for NA handling of continuous variables. Default: None
  • seed (int): Random seed for reproducibility on the R side. Default: 42
  • verbose (bool): Verbosity for R synthpop. Default: True
  • predictor_matrix (DataFrame, optional): Square 0/1 matrix restricting predictors. Must have the same index and columns as the training data columns. Default: None

Returns:

  • DataFrame or list[DataFrame]: Synthetic dataset(s).

Raises:

  • RuntimeError: If R synthesis fails.
  • ValueError: If fit() has not been called or inputs are invalid.

Source code in src/synomicsbench/synthesizer/Synthpopsynthesizer.py
def sample(
    self,
    n_samples: Union[int, str] = "auto",
    seed: int = 42,
    *,
    discrete_columns: Optional[List[str]] = None,
    method: str = "cart",
    minimumlevels: int = 3,
    proper: bool = False,
    n_datasets: int = 1,
    visit_sequence: Optional[List[str]] = None,
    cont_na: Optional[dict] = None,
    verbose: bool = True,
    predictor_matrix: Optional[pd.DataFrame] = None,
    **kwargs,
) -> Union[pd.DataFrame, List[pd.DataFrame]]:
    """
    Call R's synthpop::syn to generate synthetic data.

    Args:
        n_samples (int or 'auto'): Number of rows to synthesize; 'auto' uses training size.
        discrete_columns (list[str], optional): categorical columns.
        method (str): Synthpop synthesis method (e.g., 'cart', 'parametric', ...).
        minimumlevels (int): Minimum levels for categorical variables.
        proper (bool): Proper synthesis flag.
        n_datasets (int): Number of synthetic datasets (m).
        visit_sequence (list[str], optional): Variable visit sequence.
        cont_na (dict, optional): Settings for NA handling of continuous variables.
        seed (int): Random seed for reproducibility on the R side.
        verbose (bool): Verbosity for R synthpop.
        predictor_matrix (pd.DataFrame, optional): Square 0/1 matrix restricting predictors.
            Must have the same index and columns as training data columns.


    Returns:
        pd.DataFrame or list[pd.DataFrame]: Synthetic dataset(s).

    Raises:
        RuntimeError: If R synthesis fails.
        ValueError: If fit() has not been called or inputs are invalid.
    """
    if self._train_df is None:
        raise ValueError("Call fit(data) before sample().")
    df = self._train_df.reset_index(drop=True)

    # Resolve n_samples
    if n_samples == "auto":
        n_samples = df.shape[0]
    if not isinstance(n_samples, int) or n_samples <= 0:
        raise ValueError("n_samples must be a positive integer or 'auto'.")

    # Resolve discrete columns
    # disc_cols = self._build_discrete_columns(df, override=discrete_columns)

    # Validate predictor_matrix
    r_predictor_matrix = robjects.NULL
    if predictor_matrix is not None:
        if not isinstance(predictor_matrix, pd.DataFrame):
            raise ValueError("predictor_matrix must be a pandas DataFrame if provided.")
        if list(predictor_matrix.columns) != list(df.columns) or list(predictor_matrix.index) != list(df.columns):
            raise ValueError("predictor_matrix must have the same columns and index as the training DataFrame.")
        # Extract row and column names
        row_names = list(predictor_matrix.index)
        col_names = list(predictor_matrix.columns)

        # Extract matrix values and convert to R-compatible format
        pm_np = predictor_matrix.values.astype(int)
        nr, nc = pm_np.shape
        r_values = robjects.IntVector(pm_np.flatten(order='F'))

        # Create the R matrix
        r_matrix = robjects.r['matrix'](r_values, nrow=nr, ncol=nc)

        # Assign dimnames (row and column names) to the matrix
        r_matrix.do_slot_assign(
            "dimnames",
            robjects.ListVector([
                robjects.StrVector(row_names),  # First element: Row names
                robjects.StrVector(col_names)   # Second element: Column names
            ])
        )

        r_predictor_matrix = r_matrix
    else:
        r_predictor_matrix = robjects.NULL

    # R helper for synthesis
    r_code = """
    suppressPackageStartupMessages(library(synthpop))
    library(dplyr)
    create_synthetic_data <- function(data, seed, discrete_columns, minnumlevels, 
                                   method, proper, m, k, visit_sequence, cont_na, verbose, predictorMatrix = NULL) {
            cat_features <- discrete_columns
            cat_features <- trimws(cat_features)
            data[cat_features] <- lapply(data[cat_features], factor)

            if (is.null(visit_sequence)) {
                visit_seq <- colnames(data)
            } else {
                visit_seq <- trimws(visit_sequence)
            }

            if (is.null(cont_na)) {
                cont_na <- NULL
            }

            set.seed(seed)
            synthpop_clinical_res <- syn(
                data, 
                minnumlevels = minnumlevels,
                method = method,
                proper = proper,
                m = m,
                k = if (is.null(k)) nrow(data) else k,
                visit.sequence = visit_seq,
                cont.na = cont_na,
                print.flag = verbose,
                predictor.matrix = predictorMatrix
            )

            if (m > 1) {
                synth_data <- synthpop_clinical_res$syn
            } else {
                synth_data <- synthpop_clinical_res$syn
            }

            # Convert factors to characters before returning to Python
            if (m > 1) {
                synth_data <- lapply(synth_data, function(df) {
                    df[] <- lapply(df, function(x) {
                        if (is.factor(x)) as.character(x) else x
                    })
                    return(df)
                })
            } else {
                synth_data[] <- lapply(synth_data, function(x) {
                    if (is.factor(x)) as.character(x) else x
                })
            }
            return(synth_data)
        }
        """

    try:
        # Ensure R lib path callable (no-op if not needed)
        self._ensure_r_libpath()

        # Load helper into R
        robjects.r(r_code)
        r_func = robjects.globalenv["create_synthetic_data"]

        # Prepare R args
        r_discrete_columns = robjects.StrVector(discrete_columns)
        r_visit_sequence = robjects.StrVector(visit_sequence) if visit_sequence else robjects.NULL
        r_cont_na = robjects.ListVector(cont_na) if cont_na else robjects.NULL

        with conversion.localconverter(default_converter + pandas2ri_converter):
            r_data = conversion.py2rpy(df)

        r_method = robjects.StrVector([method])
        r_proper = robjects.BoolVector([proper])
        r_verbose = robjects.BoolVector([verbose])

        self.logger.info(
            f"Running Synthpop (method={method}, proper={proper}, m={n_datasets}, k={n_samples}, "
            f"minnumlevels={minimumlevels}, seed={seed})"
        )

        r_out = r_func(
            r_data,
            seed=seed,
            discrete_columns=r_discrete_columns,
            minnumlevels=minimumlevels,
            method=r_method,
            proper=r_proper,
            m=n_datasets,
            k=n_samples,
            visit_sequence=r_visit_sequence,
            cont_na=r_cont_na,
            verbose=r_verbose,
            predictorMatrix=r_predictor_matrix,
        )

        if n_datasets == 1:
            with conversion.localconverter(default_converter + pandas2ri_converter):
                synthetic_df = conversion.rpy2py(r_out)
            self.logger.info(f"Synthpop generated 1 dataset with shape {synthetic_df.shape}")
            return synthetic_df
        else:
            out_list: List[pd.DataFrame] = []
            for i, r_df in enumerate(r_out, start=1):
                with conversion.localconverter(default_converter + pandas2ri_converter):
                    pdf = conversion.rpy2py(r_df)
                self.logger.info(f"Synthpop dataset {i}/{n_datasets} shape: {pdf.shape}")
                out_list.append(pdf)
            return out_list

    except Exception as e:
        raise RuntimeError(f"Synthpop synthesis failed: {e}")

Processing

The processing module handles data integration, preprocessing, metadata management, and gene-level queries for multi-omics datasets.

DataIntegrationPipeline

End-to-end pipeline for processing and integrating clinical and transcriptomics data.

This pipeline:

  • Cleans data (removes undefined IDs, deduplicates, filters over-missing samples)
  • Performs feature engineering (over-missing feature filtering, type classification, imputation)
  • Optionally maps Ensembl gene IDs to HUGO symbols
  • Integrates processed clinical and transcriptomics data on a common ID
  • Exports a feature metadata JSON and logs detailed progress for easy monitoring

Parameters:

  • output_dir (str): Directory where logs and outputs (e.g., feature_metadata.json) are saved. Required.
  • logger (str): Suffix for the log filename (e.g., Preprocess_{logger}.log). Default: ''

Returns:

  • None

Raises:

  • OSError: If the output directory cannot be created.
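
Illustrative sketch (assumed import path; rna_df is a placeholder raw expression DataFrame with a patient ID column). The steps_config keys shown are the ones read by process_transcriptomics_data:

from synomicsbench.processing.pipeline import DataIntegrationPipeline  # assumed path

pipeline = DataIntegrationPipeline(output_dir="outputs/processing", logger="run1")

steps_config = {
    "remove_undefined": True,              # drop rows with a missing ID
    "remove_duplicates": True,             # drop duplicate columns, then rows
    "remove_overmissing_samples": True,    # drop rows above the missingness threshold
    "remove_low_expression_genes": True,   # drop zero-sum / near-constant genes
    "check_duplicate_genes": True,         # flag genes with identical profiles
}

processed_rna = pipeline.process_transcriptomics_data(
    transcriptomics_data=rna_df,             # placeholder input DataFrame
    transcriptomics_id_column="patient_id",  # placeholder ID column name
    steps_config=steps_config,
    overmissing_samples_threshold=50.0,
    imputer="mice",
)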

Source code in src/synomicsbench/processing/pipeline.py
class DataIntegrationPipeline:
    """
    End-to-end pipeline for processing and integrating clinical and transcriptomics data.

    This pipeline:
    - Cleans data (removes undefined IDs, deduplicates, filters over-missing samples)
    - Performs feature engineering (over-missing feature filtering, type classification, imputation)
    - Optionally maps Ensembl gene IDs to HUGO symbols
    - Integrates processed clinical and transcriptomics data on a common ID
    - Exports a feature metadata JSON and logs detailed progress for easy monitoring

    Args:
        output_dir (str): Directory where logs and outputs (e.g., feature_metadata.json) are saved.
        logger (str): Suffix for the log filename (e.g., Preprocess_{logger}.log).

    Returns:
        None

    Raises:
        OSError: If the output directory cannot be created.
    """

    def __init__(self, output_dir: str, logger: str = ""):
        """
        Initialize the data integration pipeline and configure logging.

        Args:
            output_dir (str): Directory where logs and outputs are stored.
            logger (str): Suffix for the log filename.

        Returns:
            None

        Raises:
            OSError: If the output directory cannot be created.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        logger_name = f"{self.__class__.__name__}_{id(self)}"
        log_file_name = f"Preprocess_{logger}.log"
        self.logger = set_logger(logger_name, self.output_dir, log_file_name)

        # Placeholders set during processing
        self.clinical_num_columns: List[str] = []
        self.clinical_dummy_columns: List[str] = []
        self.transcriptomics_num_columns: List[str] = []

    def process_transcriptomics_data(
        self,
        transcriptomics_data: pd.DataFrame,
        transcriptomics_id_column: str,
        steps_config: Dict,
        overmissing_samples_threshold: float = 50.0,
        overmissing_features_threshold: float = 50.0,
        unique_threshold: int = 10,
        scaler: str = "minmax",
        imputer: str = "mice",
        imputer_params: Optional[Dict[str, Any]] = None,
        low_expression_variance_threshold: float = 0.0005,
        add_indicators: bool = True,
        verbose: bool = True,
    ) -> pd.DataFrame:
        """
        Process raw transcriptomics data with configurable steps.

        Args:
            transcriptomics_data (pd.DataFrame): Raw transcriptomics data.
            transcriptomics_id_column (str): Column name for sample/patient IDs.
            steps_config (dict): Flags controlling which steps to run.
            overmissing_samples_threshold (float): Remove rows with missingness > threshold (0-100).
            overmissing_features_threshold (float): Remove columns with missingness > threshold (0-100).
            unique_threshold (int): Threshold to classify categorical features (not used for transcriptomics).
            scaler (str): Scaler type for numerical features ('minmax', 'standard', 'robust').
            imputer (str): Imputation method to use ('knn' or 'mice').
            imputer_params (dict, optional): Method-specific params. For 'knn': e.g., {'n_neighbors': 5}.
                                            For 'mice': e.g., {'iterations': 20, 'n_estimators': 300}.
            low_expression_variance_threshold (float): Variance cutoff for near-constant genes when
                `remove_low_expression_genes` is enabled.
            add_indicators (bool): Whether to add missingness indicator columns during imputation.
            verbose (bool): Print progress from lower-level utilities.

        Returns:
            pd.DataFrame: Processed transcriptomics data with imputed features and optional indicators.

        Raises:
            KeyError: If transcriptomics_id_column is not found in the input DataFrame.
            TypeError: If transcriptomics_data is not a pandas DataFrame.
            ValueError: On invalid thresholds or processing errors in underlying steps.
        """
        start_total = time.perf_counter()
        if not isinstance(transcriptomics_data, pd.DataFrame):
            raise TypeError("transcriptomics_data must be a pandas DataFrame")
        if transcriptomics_id_column not in transcriptomics_data.columns:
            raise KeyError(
                f"ID column '{transcriptomics_id_column}' not found in transcriptomics data"
            )

        # Initialize gene query helper
        self.gene_querier = GeneQuery(
            fields=["symbol"],
            scopes=["ensemblgene"],
            species=["human"],
            output_dir=self.output_dir,
        )

        data = transcriptomics_data.copy()
        self.logger.info(
            f"[TRANSCRIPTOMICS] Initial shape: rows={data.shape[0]}, cols={data.shape[1]}"
        )

        # Step 1: Remove undefined data (rows with missing IDs)
        if steps_config.get("remove_undefined", True):
            t0 = time.perf_counter()
            before = data.shape[0]
            self.logger.info(
                f"[T1] Removing undefined samples missing ID in '{transcriptomics_id_column}'"
            )
            data = DataProcessor.remove_unknown_entities(
                data=data, id_column=transcriptomics_id_column
            )
            data = data.reset_index(drop=True)
            after = data.shape[0]
            self.logger.info(
                f"[T1] Removed {before - after} rows; new shape: rows={data.shape[0]}, cols={data.shape[1]} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T1] Skipped remove_undefined")

        # Step 2: Remove duplicates (columns then rows)
        if steps_config.get("remove_duplicates", True):
            # Columns
            t0 = time.perf_counter()
            before_cols = data.shape[1]
            self.logger.info("[T2.1] Removing duplicate columns")
            data = DataProcessor.remove_duplications(data=data, axis=1)
            after_cols = data.shape[1]
            self.logger.info(
                f"[T2.1] Removed {before_cols - after_cols} duplicate columns; cols={after_cols} (took {time.perf_counter() - t0:.2f}s)"
            )

            # Rows
            t0 = time.perf_counter()
            before_rows = data.shape[0]
            self.logger.info("[T2.2] Removing duplicate rows")
            data = DataProcessor.remove_duplications(data=data, axis=0)
            data = data.reset_index(drop=True)
            after_rows = data.shape[0]
            self.logger.info(
                f"[T2.2] Removed {before_rows - after_rows} duplicate rows; rows={after_rows} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T2] Skipped remove_duplicates")

        # Step 3: Remove over-missing samples
        if steps_config.get("remove_overmissing_samples", True):
            t0 = time.perf_counter()
            before = data.shape[0]
            self.logger.info(
                f"[T3] Removing samples with missingness > {overmissing_samples_threshold}%"
            )
            data = DataProcessor.remove_overmissing_entities(
                data=data, threshold=overmissing_samples_threshold
            )
            data = data.reset_index(drop=True)
            after = data.shape[0]
            self.logger.info(
                f"[T3] Removed {before - after} rows; new rows={after} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T3] Skipped remove_overmissing_samples")

        # Split into ID and features
        id_data = data[[transcriptomics_id_column]].copy()
        features_data = data.drop(columns=[transcriptomics_id_column])

        # Step 4.0: Remove low-expressed genes (transcriptomics only)
        remove_low_expr = steps_config.get(
            "remove_low_expression_genes",
            steps_config.get("remove_low_expressed_gene", False),
        )
        if remove_low_expr:
            t0 = time.perf_counter()
            before_cols = features_data.shape[1]
            self.logger.info(
                "[T4.0] Removing low-expressed genes (zero-sum or near-zero variance)"
            )
            features_data = DataProcessor.remove_low_expression_genes(
                data=features_data,
                variance_threshold=low_expression_variance_threshold,
            )
            after_cols = features_data.shape[1]
            self.logger.info(
                f"[T4.0] Removed {before_cols - after_cols} gene(s); cols={after_cols} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T4.0] Skipped remove_low_expression_genes")

        # Step 4.1: Check duplicate genes
        if steps_config.get("check_duplicate_genes", True):
            t0 = time.perf_counter()
            self.logger.info(
                "[T4.1] Checking for genes with identical expression profiles"
            )
            self.dup_genes, self.dup_mapped_genes = self.gene_querier.check_duplicates(
                data=features_data
            )
            n_dup = len(self.dup_genes) if self.dup_genes is not None else 0
            self.logger.info(
                f"[T4.1] Found {n_dup} duplicate gene(s) (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T4.1] Skipped check_duplicate_genes")

        # Step 4.2: Map Ensembl IDs to HUGO symbols
        if steps_config.get("mapping_genes", True):
            t0 = time.perf_counter()
            self.logger.info("[T4.2] Mapping Ensembl IDs to HUGO symbols")
            self.gene_info_df = self.gene_querier.convert_genes(data=features_data)
            self.logger.info(
                f"[T4.2] Gene mapping produced shape={getattr(self.gene_info_df, 'shape', None)} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T4.2] Skipped mapping_genes")

        # Step 5: Feature engineering for transcriptomics (all numerical)
        if steps_config.get("feature_engineering", True):
            t0 = time.perf_counter()
            self.logger.info(
                f"[T5] Feature engineering (remove over-missing features > {overmissing_features_threshold}%, impute via {imputer.upper()})"
            )
            features_data, self.transcriptomics_num_columns, _ = (
                DataProcessor.feature_engineering(
                    data=features_data,
                    data_type="transcriptomics",
                    overmissing_threshold=overmissing_features_threshold,
                    ordinal_cat_columns=None,
                    unique_threshold=unique_threshold,
                    scaler=scaler,
                    imputer=imputer,
                    imputer_params=imputer_params or {},
                    add_indicators=add_indicators,
                    verbose=verbose,
                )
            )
            self.logger.info(
                f"[T5] Feature engineering complete; features shape={features_data.shape} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[T5] Skipped feature_engineering")

        processed_data = pd.concat([id_data, features_data], axis=1)
        self.logger.info(
            f"[TRANSCRIPTOMICS] Final shape: rows={processed_data.shape[0]}, cols={processed_data.shape[1]} (total {time.perf_counter() - start_total:.2f}s)"
        )
        return processed_data

    def process_clinical_data(
        self,
        clinical_data: pd.DataFrame,
        clinical_id_column: str,
        steps_config: Dict,
        overmissing_samples_threshold: float = 50.0,
        overmissing_features_threshold: float = 50.0,
        unique_threshold: int = 10,
        scaler: str = "minmax",
        imputer: str = "knn",
        imputer_params: Optional[Dict[str, Any]] = None,
        ordinal_cat_columns: Optional[List[str]] = None,
        add_indicators: bool = True,
        verbose: bool = True,
    ) -> pd.DataFrame:
        """
        Process raw clinical data with configurable steps.

        Args:
            clinical_data (pd.DataFrame): Raw clinical data.
            clinical_id_column (str): Column name for sample/patient IDs.
            steps_config (dict): Flags controlling which steps to run.
            overmissing_samples_threshold (float): Remove rows with missingness > threshold (0-100).
            overmissing_features_threshold (float): Remove columns with missingness > threshold (0-100).
            unique_threshold (int): Unique value threshold to classify categorical features.
            scaler (str): Scaler type for numerical features ('minmax', 'standard', 'robust').
            imputer (str): Imputation method to use ('knn' or 'mice').
            imputer_params (dict, optional): Method-specific params. For 'knn': e.g., {'n_neighbors': 5}.
                                            For 'mice': e.g., {'iterations': 20, 'n_estimators': 300}.
            ordinal_cat_columns (list, optional): Known ordinal categorical columns.
            add_indicators (bool): Whether to add missingness indicator columns during imputation.
            verbose (bool): Print progress from lower-level utilities.

        Returns:
            pd.DataFrame: Processed clinical data with imputed features and indicators.

        Raises:
            KeyError: If clinical_id_column is not found in clinical_data.
            TypeError: If clinical_data is not a pandas DataFrame.
            ValueError: On invalid thresholds or processing errors in underlying steps.
        """
        start_total = time.perf_counter()
        ordinal_cat_columns = ordinal_cat_columns or []

        if not isinstance(clinical_data, pd.DataFrame):
            raise TypeError("clinical_data must be a pandas DataFrame")
        if clinical_id_column not in clinical_data.columns:
            raise KeyError(
                f"ID column '{clinical_id_column}' not found in clinical data"
            )

        data = clinical_data.copy()
        self.logger.info(
            f"[CLINICAL] Initial shape: rows={data.shape[0]}, cols={data.shape[1]}"
        )

        # Step 1: Remove undefined data (rows with missing IDs)
        if steps_config.get("remove_undefined", True):
            t0 = time.perf_counter()
            before = data.shape[0]
            self.logger.info(
                f"[C1] Removing undefined samples missing ID in '{clinical_id_column}'"
            )
            data = DataProcessor.remove_unknown_entities(
                data=data, id_column=clinical_id_column
            )
            data = data.reset_index(drop=True)
            after = data.shape[0]
            self.logger.info(
                f"[C1] Removed {before - after} rows; new shape: rows={after}, cols={data.shape[1]} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[C1] Skipped remove_undefined")

        # Step 2: Remove duplicates (columns then rows)
        if steps_config.get("remove_duplicates", True):
            # Columns
            t0 = time.perf_counter()
            before_cols = data.shape[1]
            self.logger.info("[C2.1] Removing duplicate columns")
            data = DataProcessor.remove_duplications(data=data, axis=1)
            after_cols = data.shape[1]
            self.logger.info(
                f"[C2.1] Removed {before_cols - after_cols} duplicate columns; cols={after_cols} (took {time.perf_counter() - t0:.2f}s)"
            )

            # Rows
            t0 = time.perf_counter()
            before_rows = data.shape[0]
            self.logger.info("[C2.2] Removing duplicate rows")
            data = DataProcessor.remove_duplications(data=data, axis=0)
            data = data.reset_index(drop=True)
            after_rows = data.shape[0]
            self.logger.info(
                f"[C2.2] Removed {before_rows - after_rows} duplicate rows; rows={after_rows} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[C2] Skipped remove_duplicates")

        # Step 3: Remove over-missing samples
        if steps_config.get("remove_overmissing_samples", True):
            t0 = time.perf_counter()
            before = data.shape[0]
            self.logger.info(
                f"[C3] Removing clinical samples with missingness > {overmissing_samples_threshold}%"
            )
            data = DataProcessor.remove_overmissing_entities(
                data=data, threshold=overmissing_samples_threshold
            )
            data = data.reset_index(drop=True)
            after = data.shape[0]
            self.logger.info(
                f"[C3] Removed {before - after} rows; new rows={after} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[C3] Skipped remove_overmissing_samples")

        # Split into ID and features
        id_data = data[[clinical_id_column]].copy()
        features_data = data.drop(columns=[clinical_id_column])

        # Step 4: Feature engineering (classify, remove over-missing features, impute)
        if steps_config.get("feature_engineering", True):
            t0 = time.perf_counter()
            self.logger.info(
                f"[C4] Feature engineering (remove over-missing features > {overmissing_features_threshold}%, classify types, impute via {imputer.upper()})"
            )
            features_data, self.clinical_num_columns, self.clinical_dummy_columns = (
                DataProcessor.feature_engineering(
                    data=features_data,
                    data_type="clinical",
                    overmissing_threshold=overmissing_features_threshold,
                    ordinal_cat_columns=ordinal_cat_columns,
                    unique_threshold=unique_threshold,
                    scaler=scaler,
                    imputer=imputer,
                    imputer_params=imputer_params or {},
                    add_indicators=add_indicators,
                    verbose=verbose,
                )
            )
            self.logger.info(
                f"[C4] Feature engineering complete; features shape={features_data.shape} (took {time.perf_counter() - t0:.2f}s)"
            )
        else:
            self.logger.info("[C4] Skipped feature_engineering")

        processed_data = pd.concat([id_data, features_data], axis=1)
        self.logger.info(
            f"[CLINICAL] Final shape: rows={processed_data.shape[0]}, cols={processed_data.shape[1]} (total {time.perf_counter() - start_total:.2f}s)"
        )
        return processed_data

    def integrate_data(
        self,
        processed_clinical: pd.DataFrame,
        processed_transcriptomics: pd.DataFrame,
        clinical_id_column: str,
        transcriptomics_id_column: str,
        integration_id_column: str,
        steps_config: Dict,
    ) -> Optional[pd.DataFrame]:
        """
        Integrate processed clinical and transcriptomics data on a common ID.

        Args:
            processed_clinical (pd.DataFrame): Processed clinical dataset.
            processed_transcriptomics (pd.DataFrame): Processed transcriptomics dataset.
            clinical_id_column (str): ID column in processed_clinical.
            transcriptomics_id_column (str): ID column in processed_transcriptomics.
            integration_id_column (str): Name for the common ID column after renaming.
            steps_config (dict): Flags controlling whether to integrate.

        Returns:
            pd.DataFrame or None: Integrated dataset if integration is enabled; otherwise None.

        Raises:
            KeyError: If required ID columns are missing in the provided DataFrames.
            TypeError: If inputs are not pandas DataFrames.
        """
        if not isinstance(processed_clinical, pd.DataFrame) or not isinstance(
            processed_transcriptomics, pd.DataFrame
        ):
            raise TypeError(
                "processed_clinical and processed_transcriptomics must be pandas DataFrames"
            )
        if clinical_id_column not in processed_clinical.columns:
            raise KeyError(
                f"ID column '{clinical_id_column}' not found in processed_clinical"
            )
        if transcriptomics_id_column not in processed_transcriptomics.columns:
            raise KeyError(
                f"ID column '{transcriptomics_id_column}' not found in processed_transcriptomics"
            )

        if not steps_config.get("integrate_data", True):
            self.logger.info("[MERGE] Skipping data integration")
            return None

        self.logger.info(
            "[MERGE] Integrating clinical and transcriptomics data by common ID"
        )

        clinical_data = processed_clinical.rename(
            columns={clinical_id_column: integration_id_column}
        )
        transcriptomics_data = processed_transcriptomics.rename(
            columns={transcriptomics_id_column: integration_id_column}
        )

        # Ensure compatible types for join key
        clinical_data[integration_id_column] = clinical_data[
            integration_id_column
        ].astype(str)
        transcriptomics_data[integration_id_column] = transcriptomics_data[
            integration_id_column
        ].astype(str)

        left_count = clinical_data.shape[0]
        right_count = transcriptomics_data.shape[0]
        self.logger.info(f"[MERGE] Left rows={left_count}, Right rows={right_count}")

        integrated_data = pd.merge(
            clinical_data, transcriptomics_data, on=integration_id_column, how="inner"
        )

        matched = integrated_data.shape[0]
        self.logger.info(
            f"[MERGE] Inner join matched rows={matched}; final shape: rows={integrated_data.shape[0]}, cols={integrated_data.shape[1]}"
        )
        return integrated_data

    def run_pipeline(
        self,
        clinical_data: pd.DataFrame,
        transcriptomics_data: pd.DataFrame,
        clinical_id_column: str,
        transcriptomics_id_column: str,
        integration_id_column: str = "PATIENT_ID",
        steps_config: Optional[Dict] = None,
        overmissing_samples_threshold: float = 50.0,
        overmissing_features_threshold: float = 50.0,
        unique_threshold: int = 10,
        scaler: str = "minmax",
        ordinal_cat_columns: Optional[List[str]] = None,
        imputer: str = "mice",
        imputer_params: Optional[Dict[str, Any]] = None,
        low_expression_variance_threshold: float = 0.0005,
        add_indicators: bool = True,
        verbose: bool = True,
    ) -> Dict[str, Optional[pd.DataFrame]]:
        """
        Run the full pipeline across transcriptomics and clinical data and optionally integrate them.

        Args:
            clinical_data (pd.DataFrame): Raw clinical data.
            transcriptomics_data (pd.DataFrame): Raw transcriptomics data.
            clinical_id_column (str): ID column name in clinical_data.
            transcriptomics_id_column (str): ID column name in transcriptomics_data.
            integration_id_column (str): Common ID column name for integration.
            steps_config (dict, optional): Dict specifying which steps to run; defaults enable all steps except `remove_low_expression_genes`.
            overmissing_samples_threshold (float): Remove rows with missingness > threshold (0-100).
            overmissing_features_threshold (float): Remove columns with missingness > threshold (0-100).
            unique_threshold (int): Unique value threshold to classify categorical features (clinical).
            scaler (str): Scaler type for numerical features ('minmax', 'standard', 'robust').
            ordinal_cat_columns (list, optional): Ordinal categorical columns in clinical data.
            imputer (str): Imputation method to use ('knn' or 'mice').
            imputer_params (dict, optional): Method-specific params.
            low_expression_variance_threshold (float): Variance cutoff for near-constant genes when
                `remove_low_expression_genes` is enabled.
            add_indicators (bool): Whether to add missingness indicator columns during imputation.
            verbose (bool): Print progress from lower-level utilities.

        Returns:
            dict: Dictionary containing:
                - 'processed_clinical' (pd.DataFrame): Processed clinical data
                - 'processed_transcriptomics' (pd.DataFrame): Processed transcriptomics data
                - 'integrated_data' (pd.DataFrame or None): Integrated dataset if integration enabled

        Raises:
            TypeError: If inputs are not pandas DataFrames.
            KeyError: If required ID columns are missing.
            ValueError: On processing errors in underlying steps.
        """
        self.logger.info("=== STARTING DATA INTEGRATION PIPELINE ===")
        pipeline_start = time.perf_counter()

        default_steps = {
            "remove_undefined": True,
            "remove_duplicates": True,
            "remove_overmissing_samples": True,
            "remove_low_expression_genes": False,
            "check_duplicate_genes": True,
            "mapping_genes": True,
            "feature_engineering": True,
            "integrate_data": True,
        }
        steps_config = {**default_steps, **(steps_config or {})}
        self.logger.info(f"[CONFIG] Steps: {steps_config}")

        # Process transcriptomics
        processed_transcriptomics = self.process_transcriptomics_data(
            transcriptomics_data=transcriptomics_data,
            transcriptomics_id_column=transcriptomics_id_column,
            steps_config=steps_config,
            overmissing_samples_threshold=overmissing_samples_threshold,
            overmissing_features_threshold=overmissing_features_threshold,
            unique_threshold=unique_threshold,
            scaler=scaler,
            imputer=imputer,
            imputer_params=imputer_params,
            low_expression_variance_threshold=low_expression_variance_threshold,
            add_indicators=add_indicators,
            verbose=verbose,
        )

        # Process clinical
        processed_clinical = self.process_clinical_data(
            clinical_data=clinical_data,
            clinical_id_column=clinical_id_column,
            steps_config=steps_config,
            overmissing_samples_threshold=overmissing_samples_threshold,
            overmissing_features_threshold=overmissing_features_threshold,
            unique_threshold=unique_threshold,
            scaler=scaler,
            imputer=imputer,
            imputer_params=imputer_params,
            ordinal_cat_columns=ordinal_cat_columns,
            add_indicators=add_indicators,
            verbose=verbose,
        )

        # Integrate
        integrated_data = self.integrate_data(
            processed_clinical=processed_clinical,
            processed_transcriptomics=processed_transcriptomics,
            clinical_id_column=clinical_id_column,
            transcriptomics_id_column=transcriptomics_id_column,
            integration_id_column=integration_id_column,
            steps_config=steps_config,
        )

        # Build feature metadata regardless of integration outcome
        self.logger.info("[METADATA] Building feature type metadata")
        clinical_features = processed_clinical.drop(
            columns=[clinical_id_column]
        ).columns.tolist()
        transcriptomics_features = processed_transcriptomics.drop(
            columns=[transcriptomics_id_column]
        ).columns.tolist()

        if integrated_data is not None:
            integrated_features = integrated_data.drop(
                columns=[integration_id_column]
            ).columns.tolist()
        else:
            # If not integrated, use union of features for metadata
            integrated_features = sorted(
                set(clinical_features).union(set(transcriptomics_features))
            )

        ordinal_cat_columns = ordinal_cat_columns or []
        clinical_dummy = getattr(self, "clinical_dummy_columns", []) or []
        clinical_num = getattr(self, "clinical_num_columns", []) or []
        transcript_num = getattr(self, "transcriptomics_num_columns", []) or []

        feature_type_meta: Dict[str, str] = {}
        for col in integrated_features:
            if col in clinical_dummy:
                feature_type_meta[col] = "dummy_categorical"
            elif col in clinical_num or col in transcript_num:
                feature_type_meta[col] = "numerical"
            elif col in ordinal_cat_columns:
                feature_type_meta[col] = "ordinal_categorical"
            elif col.startswith("missingindicator_"):
                feature_type_meta[col] = "missing_categorical"
            else:
                feature_type_meta[col] = "unclassified"

        # Log summary
        counts = {
            "numerical": sum(v == "numerical" for v in feature_type_meta.values()),
            "ordinal_categorical": sum(
                v == "ordinal_categorical" for v in feature_type_meta.values()
            ),
            "dummy_categorical": sum(
                v == "dummy_categorical" for v in feature_type_meta.values()
            ),
            "missing_categorical": sum(
                v == "missing_categorical" for v in feature_type_meta.values()
            ),
            "unclassified": sum(
                v == "unclassified" for v in feature_type_meta.values()
            ),
        }
        self.logger.info(f"[METADATA] Counts: {counts}")
        if counts["unclassified"] > 0:
            self.logger.warning(
                f"[METADATA] {counts['unclassified']} unclassified feature(s). Check your metadata or processing configuration."
            )

        # Save metadata JSON
        json_path = os.path.join(self.output_dir, "feature_metadata.json")
        with open(json_path, "w") as f:
            json.dump(feature_type_meta, f, indent=4)
        self.logger.info(f"[METADATA] Saved feature metadata to: {json_path}")

        self.logger.info(
            f"=== PIPELINE COMPLETED in {time.perf_counter() - pipeline_start:.2f}s ==="
        )
        return {
            "processed_clinical": processed_clinical,
            "processed_transcriptomics": processed_transcriptomics,
            "integrated_data": integrated_data,
        }
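
Example (illustrative): the sketch below shows one way the pipeline might be driven end to end. The class name `DataIntegrationPipeline` and the CSV paths are assumptions for illustration; only the method names and parameters come from the source above.

```python
import pandas as pd

# Assumed class name; the module path matches the "Source code in" note
# (src/synomicsbench/processing/pipeline.py). Adjust to the exported class.
from synomicsbench.processing.pipeline import DataIntegrationPipeline

# Hypothetical input files that share a patient identifier column.
clinical = pd.read_csv("data/clinical.csv")
transcriptomics = pd.read_csv("data/transcriptomics.csv")

pipeline = DataIntegrationPipeline(output_dir="outputs/run_01", logger="demo")

results = pipeline.run_pipeline(
    clinical_data=clinical,
    transcriptomics_data=transcriptomics,
    clinical_id_column="PATIENT_ID",
    transcriptomics_id_column="PATIENT_ID",
    integration_id_column="PATIENT_ID",
    steps_config={"remove_low_expression_genes": True},  # overrides one default flag
    imputer="knn",
    imputer_params={"n_neighbors": 5},
)

integrated = results["integrated_data"]
print(None if integrated is None else integrated.shape)
```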

__init__(output_dir, logger='')

Initialize the data integration pipeline and configure logging.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
output_dir | str | Directory where logs and outputs are stored. | required
logger | str | Suffix for the log filename. | ''

Returns:

Type | Description
---- | -----------
None |

Raises:

Type | Description
---- | -----------
OSError | If the output directory cannot be created.

Source code in src/synomicsbench/processing/pipeline.py
def __init__(self, output_dir: str, logger: str = ""):
    """
    Initialize the data integration pipeline and configure logging.

    Args:
        output_dir (str): Directory where logs and outputs are stored.
        logger (str): Suffix for the log filename.

    Returns:
        None

    Raises:
        OSError: If the output directory cannot be created.
    """
    self.output_dir = output_dir
    os.makedirs(output_dir, exist_ok=True)

    logger_name = f"{self.__class__.__name__}_{id(self)}"
    log_file_name = f"Preprocess_{logger}.log"
    self.logger = set_logger(logger_name, self.output_dir, log_file_name)

    # Placeholders set during processing
    self.clinical_num_columns: List[str] = []
    self.clinical_dummy_columns: List[str] = []
    self.transcriptomics_num_columns: List[str] = []
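
A minimal construction sketch, assuming the same hypothetical class name as in the example above; the log file name follows the `Preprocess_{logger}.log` pattern shown in the source.

```python
# Assumed class name for illustration; substitute the class exported by
# src/synomicsbench/processing/pipeline.py.
from synomicsbench.processing.pipeline import DataIntegrationPipeline

# Creates outputs/preprocess if needed and logs to Preprocess_batch1.log inside it.
pipeline = DataIntegrationPipeline(output_dir="outputs/preprocess", logger="batch1")
```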

integrate_data(processed_clinical, processed_transcriptomics, clinical_id_column, transcriptomics_id_column, integration_id_column, steps_config)

Integrate processed clinical and transcriptomics data on a common ID.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
processed_clinical | DataFrame | Processed clinical dataset. | required
processed_transcriptomics | DataFrame | Processed transcriptomics dataset. | required
clinical_id_column | str | ID column in processed_clinical. | required
transcriptomics_id_column | str | ID column in processed_transcriptomics. | required
integration_id_column | str | Name for the common ID column after renaming. | required
steps_config | dict | Flags controlling whether to integrate. | required

Returns:

Type | Description
---- | -----------
Optional[DataFrame] | Integrated dataset if integration is enabled; otherwise None.

Raises:

Type | Description
---- | -----------
KeyError | If required ID columns are missing in the provided DataFrames.
TypeError | If inputs are not pandas DataFrames.

Source code in src/synomicsbench/processing/pipeline.py
def integrate_data(
    self,
    processed_clinical: pd.DataFrame,
    processed_transcriptomics: pd.DataFrame,
    clinical_id_column: str,
    transcriptomics_id_column: str,
    integration_id_column: str,
    steps_config: Dict,
) -> Optional[pd.DataFrame]:
    """
    Integrate processed clinical and transcriptomics data on a common ID.

    Args:
        processed_clinical (pd.DataFrame): Processed clinical dataset.
        processed_transcriptomics (pd.DataFrame): Processed transcriptomics dataset.
        clinical_id_column (str): ID column in processed_clinical.
        transcriptomics_id_column (str): ID column in processed_transcriptomics.
        integration_id_column (str): Name for the common ID column after renaming.
        steps_config (dict): Flags controlling whether to integrate.

    Returns:
        pd.DataFrame or None: Integrated dataset if integration is enabled; otherwise None.

    Raises:
        KeyError: If required ID columns are missing in the provided DataFrames.
        TypeError: If inputs are not pandas DataFrames.
    """
    if not isinstance(processed_clinical, pd.DataFrame) or not isinstance(
        processed_transcriptomics, pd.DataFrame
    ):
        raise TypeError(
            "processed_clinical and processed_transcriptomics must be pandas DataFrames"
        )
    if clinical_id_column not in processed_clinical.columns:
        raise KeyError(
            f"ID column '{clinical_id_column}' not found in processed_clinical"
        )
    if transcriptomics_id_column not in processed_transcriptomics.columns:
        raise KeyError(
            f"ID column '{transcriptomics_id_column}' not found in processed_transcriptomics"
        )

    if not steps_config.get("integrate_data", True):
        self.logger.info("[MERGE] Skipping data integration")
        return None

    self.logger.info(
        "[MERGE] Integrating clinical and transcriptomics data by common ID"
    )

    clinical_data = processed_clinical.rename(
        columns={clinical_id_column: integration_id_column}
    )
    transcriptomics_data = processed_transcriptomics.rename(
        columns={transcriptomics_id_column: integration_id_column}
    )

    # Ensure compatible types for join key
    clinical_data[integration_id_column] = clinical_data[
        integration_id_column
    ].astype(str)
    transcriptomics_data[integration_id_column] = transcriptomics_data[
        integration_id_column
    ].astype(str)

    left_count = clinical_data.shape[0]
    right_count = transcriptomics_data.shape[0]
    self.logger.info(f"[MERGE] Left rows={left_count}, Right rows={right_count}")

    integrated_data = pd.merge(
        clinical_data, transcriptomics_data, on=integration_id_column, how="inner"
    )

    matched = integrated_data.shape[0]
    self.logger.info(
        f"[MERGE] Inner join matched rows={matched}; final shape: rows={integrated_data.shape[0]}, cols={integrated_data.shape[1]}"
    )
    return integrated_data
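
A small sketch of the join behaviour, assuming `pipeline` was constructed as in the example above; the frames and column names are toy data.

```python
import pandas as pd

# Toy processed frames; in practice these come from process_clinical_data and
# process_transcriptomics_data. Join keys are cast to str before the inner merge.
clin = pd.DataFrame({"SAMPLE_ID": ["P1", "P2", "P3"], "age": [0.2, 0.5, 0.9]})
expr = pd.DataFrame({"RNA_ID": ["P2", "P3", "P4"], "TP53": [0.1, 0.7, 0.4]})

merged = pipeline.integrate_data(
    processed_clinical=clin,
    processed_transcriptomics=expr,
    clinical_id_column="SAMPLE_ID",
    transcriptomics_id_column="RNA_ID",
    integration_id_column="PATIENT_ID",
    steps_config={"integrate_data": True},
)
# The inner join keeps only P2 and P3; columns: PATIENT_ID, age, TP53.
```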

process_clinical_data(clinical_data, clinical_id_column, steps_config, overmissing_samples_threshold=50.0, overmissing_features_threshold=50.0, unique_threshold=10, scaler='minmax', imputer='knn', imputer_params=None, ordinal_cat_columns=None, add_indicators=True, verbose=True)

Process raw clinical data with configurable steps.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
clinical_data | DataFrame | Raw clinical data. | required
clinical_id_column | str | Column name for sample/patient IDs. | required
steps_config | dict | Flags controlling which steps to run. | required
overmissing_samples_threshold | float | Remove rows with missingness > threshold (0-100). | 50.0
overmissing_features_threshold | float | Remove columns with missingness > threshold (0-100). | 50.0
unique_threshold | int | Unique value threshold to classify categorical features. | 10
scaler | str | Scaler type for numerical features ('minmax', 'standard', 'robust'). | 'minmax'
imputer | str | Imputation method to use ('knn' or 'mice'). | 'knn'
imputer_params | dict | Method-specific params. For 'knn': e.g., {'n_neighbors': 5}. For 'mice': e.g., {'iterations': 20, 'n_estimators': 300}. | None
ordinal_cat_columns | list | Known ordinal categorical columns. | None
add_indicators | bool | Whether to add missingness indicator columns during imputation. | True
verbose | bool | Print progress from lower-level utilities. | True

Returns:

Type | Description
---- | -----------
DataFrame | Processed clinical data with imputed features and indicators.

Raises:

Type | Description
---- | -----------
KeyError | If clinical_id_column is not found in clinical_data.
TypeError | If clinical_data is not a pandas DataFrame.
ValueError | On invalid thresholds or processing errors in underlying steps.

Source code in src/synomicsbench/processing/pipeline.py
def process_clinical_data(
    self,
    clinical_data: pd.DataFrame,
    clinical_id_column: str,
    steps_config: Dict,
    overmissing_samples_threshold: float = 50.0,
    overmissing_features_threshold: float = 50.0,
    unique_threshold: int = 10,
    scaler: str = "minmax",
    imputer: str = "knn",
    imputer_params: Optional[Dict[str, Any]] = None,
    ordinal_cat_columns: Optional[List[str]] = None,
    add_indicators: bool = True,
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Process raw clinical data with configurable steps.

    Args:
        clinical_data (pd.DataFrame): Raw clinical data.
        clinical_id_column (str): Column name for sample/patient IDs.
        steps_config (dict): Flags controlling which steps to run.
        overmissing_samples_threshold (float): Remove rows with missingness > threshold (0-100).
        overmissing_features_threshold (float): Remove columns with missingness > threshold (0-100).
        unique_threshold (int): Unique value threshold to classify categorical features.
        scaler (str): Scaler type for numerical features ('minmax', 'standard', 'robust').
        imputer (str): Imputation method to use ('knn' or 'mice').
        imputer_params (dict, optional): Method-specific params. For 'knn': e.g., {'n_neighbors': 5}.
                                        For 'mice': e.g., {'iterations': 20, 'n_estimators': 300}.
        ordinal_cat_columns (list, optional): Known ordinal categorical columns.
        add_indicators (bool): Whether to add missingness indicator columns during imputation.
        verbose (bool): Print progress from lower-level utilities.

    Returns:
        pd.DataFrame: Processed clinical data with imputed features and indicators.

    Raises:
        KeyError: If clinical_id_column is not found in clinical_data.
        TypeError: If clinical_data is not a pandas DataFrame.
        ValueError: On invalid thresholds or processing errors in underlying steps.
    """
    start_total = time.perf_counter()
    ordinal_cat_columns = ordinal_cat_columns or []

    if not isinstance(clinical_data, pd.DataFrame):
        raise TypeError("clinical_data must be a pandas DataFrame")
    if clinical_id_column not in clinical_data.columns:
        raise KeyError(
            f"ID column '{clinical_id_column}' not found in clinical data"
        )

    data = clinical_data.copy()
    self.logger.info(
        f"[CLINICAL] Initial shape: rows={data.shape[0]}, cols={data.shape[1]}"
    )

    # Step 1: Remove undefined data (rows with missing IDs)
    if steps_config.get("remove_undefined", True):
        t0 = time.perf_counter()
        before = data.shape[0]
        self.logger.info(
            f"[C1] Removing undefined samples missing ID in '{clinical_id_column}'"
        )
        data = DataProcessor.remove_unknown_entities(
            data=data, id_column=clinical_id_column
        )
        data = data.reset_index(drop=True)
        after = data.shape[0]
        self.logger.info(
            f"[C1] Removed {before - after} rows; new shape: rows={after}, cols={data.shape[1]} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[C1] Skipped remove_undefined")

    # Step 2: Remove duplicates (columns then rows)
    if steps_config.get("remove_duplicates", True):
        # Columns
        t0 = time.perf_counter()
        before_cols = data.shape[1]
        self.logger.info("[C2.1] Removing duplicate columns")
        data = DataProcessor.remove_duplications(data=data, axis=1)
        after_cols = data.shape[1]
        self.logger.info(
            f"[C2.1] Removed {before_cols - after_cols} duplicate columns; cols={after_cols} (took {time.perf_counter() - t0:.2f}s)"
        )

        # Rows
        t0 = time.perf_counter()
        before_rows = data.shape[0]
        self.logger.info("[C2.2] Removing duplicate rows")
        data = DataProcessor.remove_duplications(data=data, axis=0)
        data = data.reset_index(drop=True)
        after_rows = data.shape[0]
        self.logger.info(
            f"[C2.2] Removed {before_rows - after_rows} duplicate rows; rows={after_rows} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[C2] Skipped remove_duplicates")

    # Step 3: Remove over-missing samples
    if steps_config.get("remove_overmissing_samples", True):
        t0 = time.perf_counter()
        before = data.shape[0]
        self.logger.info(
            f"[C3] Removing clinical samples with missingness > {overmissing_samples_threshold}%"
        )
        data = DataProcessor.remove_overmissing_entities(
            data=data, threshold=overmissing_samples_threshold
        )
        data = data.reset_index(drop=True)
        after = data.shape[0]
        self.logger.info(
            f"[C3] Removed {before - after} rows; new rows={after} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[C3] Skipped remove_overmissing_samples")

    # Split into ID and features
    id_data = data[[clinical_id_column]].copy()
    features_data = data.drop(columns=[clinical_id_column])

    # Step 4: Feature engineering (classify, remove over-missing features, impute)
    if steps_config.get("feature_engineering", True):
        t0 = time.perf_counter()
        self.logger.info(
            f"[C4] Feature engineering (remove over-missing features > {overmissing_features_threshold}%, classify types, impute via {imputer.upper()})"
        )
        features_data, self.clinical_num_columns, self.clinical_dummy_columns = (
            DataProcessor.feature_engineering(
                data=features_data,
                data_type="clinical",
                overmissing_threshold=overmissing_features_threshold,
                ordinal_cat_columns=ordinal_cat_columns,
                unique_threshold=unique_threshold,
                scaler=scaler,
                imputer=imputer,
                imputer_params=imputer_params or {},
                add_indicators=add_indicators,
                verbose=verbose,
            )
        )
        self.logger.info(
            f"[C4] Feature engineering complete; features shape={features_data.shape} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[C4] Skipped feature_engineering")

    processed_data = pd.concat([id_data, features_data], axis=1)
    self.logger.info(
        f"[CLINICAL] Final shape: rows={processed_data.shape[0]}, cols={processed_data.shape[1]} (total {time.perf_counter() - start_total:.2f}s)"
    )
    return processed_data
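
An illustrative call, assuming `pipeline` and `clinical` from the earlier sketches; the threshold values and the ordinal column name are placeholders.

```python
processed_clinical = pipeline.process_clinical_data(
    clinical_data=clinical,
    clinical_id_column="PATIENT_ID",
    steps_config={"remove_duplicates": True, "feature_engineering": True},
    overmissing_samples_threshold=40.0,   # drop samples missing more than 40% of features
    overmissing_features_threshold=30.0,  # drop features missing in more than 30% of samples
    scaler="standard",
    imputer="mice",
    imputer_params={"iterations": 20, "n_estimators": 300},
    ordinal_cat_columns=["tumor_stage"],  # hypothetical ordinal column
)
```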

process_transcriptomics_data(transcriptomics_data, transcriptomics_id_column, steps_config, overmissing_samples_threshold=50.0, overmissing_features_threshold=50.0, unique_threshold=10, scaler='minmax', imputer='mice', imputer_params=None, low_expression_variance_threshold=0.0005, add_indicators=True, verbose=True)

Process raw transcriptomics data with configurable steps.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
transcriptomics_data | DataFrame | Raw transcriptomics data. | required
transcriptomics_id_column | str | Column name for sample/patient IDs. | required
steps_config | dict | Flags controlling which steps to run. | required
overmissing_samples_threshold | float | Remove rows with missingness > threshold (0-100). | 50.0
overmissing_features_threshold | float | Remove columns with missingness > threshold (0-100). | 50.0
unique_threshold | int | Threshold to classify categorical features (not used for transcriptomics). | 10
scaler | str | Scaler type for numerical features ('minmax', 'standard', 'robust'). | 'minmax'
imputer | str | Imputation method to use ('knn' or 'mice'). | 'mice'
imputer_params | dict | Method-specific params. For 'knn': e.g., {'n_neighbors': 5}. For 'mice': e.g., {'iterations': 20, 'n_estimators': 300}. | None
low_expression_variance_threshold | float | Variance cutoff for near-constant genes when remove_low_expression_genes is enabled. | 0.0005
add_indicators | bool | Whether to add missingness indicator columns during imputation. | True
verbose | bool | Print progress from lower-level utilities. | True

Returns:

Type | Description
---- | -----------
DataFrame | Processed transcriptomics data with imputed features and optional indicators.

Raises:

Type | Description
---- | -----------
KeyError | If transcriptomics_id_column is not found in the input DataFrame.
TypeError | If transcriptomics_data is not a pandas DataFrame.
ValueError | On invalid thresholds or processing errors in underlying steps.

Source code in src/synomicsbench/processing/pipeline.py
def process_transcriptomics_data(
    self,
    transcriptomics_data: pd.DataFrame,
    transcriptomics_id_column: str,
    steps_config: Dict,
    overmissing_samples_threshold: float = 50.0,
    overmissing_features_threshold: float = 50.0,
    unique_threshold: int = 10,
    scaler: str = "minmax",
    imputer: str = "mice",
    imputer_params: Optional[Dict[str, Any]] = None,
    low_expression_variance_threshold: float = 0.0005,
    add_indicators: bool = True,
    verbose: bool = True,
) -> pd.DataFrame:
    """
    Process raw transcriptomics data with configurable steps.

    Args:
        transcriptomics_data (pd.DataFrame): Raw transcriptomics data.
        transcriptomics_id_column (str): Column name for sample/patient IDs.
        steps_config (dict): Flags controlling which steps to run.
        overmissing_samples_threshold (float): Remove rows with missingness > threshold (0-100).
        overmissing_features_threshold (float): Remove columns with missingness > threshold (0-100).
        unique_threshold (int): Threshold to classify categorical features (not used for transcriptomics).
        scaler (str): Scaler type for numerical features ('minmax', 'standard', 'robust').
        imputer (str): Imputation method to use ('knn' or 'mice').
        imputer_params (dict, optional): Method-specific params. For 'knn': e.g., {'n_neighbors': 5}.
                                        For 'mice': e.g., {'iterations': 20, 'n_estimators': 300}.
        low_expression_variance_threshold (float): Variance cutoff for near-constant genes when
            `remove_low_expression_genes` is enabled.
        add_indicators (bool): Whether to add missingness indicator columns during imputation.
        verbose (bool): Print progress from lower-level utilities.

    Returns:
        pd.DataFrame: Processed transcriptomics data with imputed features and optional indicators.

    Raises:
        KeyError: If transcriptomics_id_column is not found in the input DataFrame.
        TypeError: If transcriptomics_data is not a pandas DataFrame.
        ValueError: On invalid thresholds or processing errors in underlying steps.
    """
    start_total = time.perf_counter()
    if not isinstance(transcriptomics_data, pd.DataFrame):
        raise TypeError("transcriptomics_data must be a pandas DataFrame")
    if transcriptomics_id_column not in transcriptomics_data.columns:
        raise KeyError(
            f"ID column '{transcriptomics_id_column}' not found in transcriptomics data"
        )

    # Initialize gene query helper
    self.gene_querier = GeneQuery(
        fields=["symbol"],
        scopes=["ensemblgene"],
        species=["human"],
        output_dir=self.output_dir,
    )

    data = transcriptomics_data.copy()
    self.logger.info(
        f"[TRANSCRIPTOMICS] Initial shape: rows={data.shape[0]}, cols={data.shape[1]}"
    )

    # Step 1: Remove undefined data (rows with missing IDs)
    if steps_config.get("remove_undefined", True):
        t0 = time.perf_counter()
        before = data.shape[0]
        self.logger.info(
            f"[T1] Removing undefined samples missing ID in '{transcriptomics_id_column}'"
        )
        data = DataProcessor.remove_unknown_entities(
            data=data, id_column=transcriptomics_id_column
        )
        data = data.reset_index(drop=True)
        after = data.shape[0]
        self.logger.info(
            f"[T1] Removed {before - after} rows; new shape: rows={data.shape[0]}, cols={data.shape[1]} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T1] Skipped remove_undefined")

    # Step 2: Remove duplicates (columns then rows)
    if steps_config.get("remove_duplicates", True):
        # Columns
        t0 = time.perf_counter()
        before_cols = data.shape[1]
        self.logger.info("[T2.1] Removing duplicate columns")
        data = DataProcessor.remove_duplications(data=data, axis=1)
        after_cols = data.shape[1]
        self.logger.info(
            f"[T2.1] Removed {before_cols - after_cols} duplicate columns; cols={after_cols} (took {time.perf_counter() - t0:.2f}s)"
        )

        # Rows
        t0 = time.perf_counter()
        before_rows = data.shape[0]
        self.logger.info("[T2.2] Removing duplicate rows")
        data = DataProcessor.remove_duplications(data=data, axis=0)
        data = data.reset_index(drop=True)
        after_rows = data.shape[0]
        self.logger.info(
            f"[T2.2] Removed {before_rows - after_rows} duplicate rows; rows={after_rows} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T2] Skipped remove_duplicates")

    # Step 3: Remove over-missing samples
    if steps_config.get("remove_overmissing_samples", True):
        t0 = time.perf_counter()
        before = data.shape[0]
        self.logger.info(
            f"[T3] Removing samples with missingness > {overmissing_samples_threshold}%"
        )
        data = DataProcessor.remove_overmissing_entities(
            data=data, threshold=overmissing_samples_threshold
        )
        data = data.reset_index(drop=True)
        after = data.shape[0]
        self.logger.info(
            f"[T3] Removed {before - after} rows; new rows={after} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T3] Skipped remove_overmissing_samples")

    # Split into ID and features
    id_data = data[[transcriptomics_id_column]].copy()
    features_data = data.drop(columns=[transcriptomics_id_column])

    # Step 4.0: Remove low-expressed genes (transcriptomics only)
    remove_low_expr = steps_config.get(
        "remove_low_expression_genes",
        steps_config.get("remove_low_expressed_gene", False),
    )
    if remove_low_expr:
        t0 = time.perf_counter()
        before_cols = features_data.shape[1]
        self.logger.info(
            "[T4.0] Removing low-expressed genes (zero-sum or near-zero variance)"
        )
        features_data = DataProcessor.remove_low_expression_genes(
            data=features_data,
            variance_threshold=low_expression_variance_threshold,
        )
        after_cols = features_data.shape[1]
        self.logger.info(
            f"[T4.0] Removed {before_cols - after_cols} gene(s); cols={after_cols} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T4.0] Skipped remove_low_expression_genes")

    # Step 4.1: Check duplicate genes
    if steps_config.get("check_duplicate_genes", True):
        t0 = time.perf_counter()
        self.logger.info(
            "[T4.1] Checking for genes with identical expression profiles"
        )
        self.dup_genes, self.dup_mapped_genes = self.gene_querier.check_duplicates(
            data=features_data
        )
        n_dup = len(self.dup_genes) if self.dup_genes is not None else 0
        self.logger.info(
            f"[T4.1] Found {n_dup} duplicate gene(s) (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T4.1] Skipped check_duplicate_genes")

    # Step 4.2: Map Ensembl IDs to HUGO symbols
    if steps_config.get("mapping_genes", True):
        t0 = time.perf_counter()
        self.logger.info("[T4.2] Mapping Ensembl IDs to HUGO symbols")
        self.gene_info_df = self.gene_querier.convert_genes(data=features_data)
        self.logger.info(
            f"[T4.2] Gene mapping produced shape={getattr(self.gene_info_df, 'shape', None)} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T4.2] Skipped mapping_genes")

    # Step 5: Feature engineering for transcriptomics (all numerical)
    if steps_config.get("feature_engineering", True):
        t0 = time.perf_counter()
        self.logger.info(
            f"[T5] Feature engineering (remove over-missing features > {overmissing_features_threshold}%, impute via {imputer.upper()})"
        )
        features_data, self.transcriptomics_num_columns, _ = (
            DataProcessor.feature_engineering(
                data=features_data,
                data_type="transcriptomics",
                overmissing_threshold=overmissing_features_threshold,
                ordinal_cat_columns=None,
                unique_threshold=unique_threshold,
                scaler=scaler,
                imputer=imputer,
                imputer_params=imputer_params or {},
                add_indicators=add_indicators,
                verbose=verbose,
            )
        )
        self.logger.info(
            f"[T5] Feature engineering complete; features shape={features_data.shape} (took {time.perf_counter() - t0:.2f}s)"
        )
    else:
        self.logger.info("[T5] Skipped feature_engineering")

    processed_data = pd.concat([id_data, features_data], axis=1)
    self.logger.info(
        f"[TRANSCRIPTOMICS] Final shape: rows={processed_data.shape[0]}, cols={processed_data.shape[1]} (total {time.perf_counter() - start_total:.2f}s)"
    )
    return processed_data
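
An illustrative call, assuming `pipeline` and `transcriptomics` from the earlier sketches; it enables the optional low-expression filter, which is off in the default steps configuration.

```python
processed_rna = pipeline.process_transcriptomics_data(
    transcriptomics_data=transcriptomics,
    transcriptomics_id_column="PATIENT_ID",
    steps_config={
        "remove_low_expression_genes": True,  # off by default
        "mapping_genes": False,               # skip the Ensembl-to-HUGO lookup for a quick dry run
    },
    low_expression_variance_threshold=1e-3,
    imputer="knn",
    imputer_params={"n_neighbors": 5},
)
```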

run_pipeline(clinical_data, transcriptomics_data, clinical_id_column, transcriptomics_id_column, integration_id_column='PATIENT_ID', steps_config=None, overmissing_samples_threshold=50.0, overmissing_features_threshold=50.0, unique_threshold=10, scaler='minmax', ordinal_cat_columns=None, imputer='mice', imputer_params=None, low_expression_variance_threshold=0.0005, add_indicators=True, verbose=True)

Run the full pipeline across transcriptomics and clinical data and optionally integrate them.

Parameters:

Name | Type | Description | Default
---- | ---- | ----------- | -------
clinical_data | DataFrame | Raw clinical data. | required
transcriptomics_data | DataFrame | Raw transcriptomics data. | required
clinical_id_column | str | ID column name in clinical_data. | required
transcriptomics_id_column | str | ID column name in transcriptomics_data. | required
integration_id_column | str | Common ID column name for integration. | 'PATIENT_ID'
steps_config | dict | Dict specifying which steps to run; defaults enable all steps except remove_low_expression_genes. | None
overmissing_samples_threshold | float | Remove rows with missingness > threshold (0-100). | 50.0
overmissing_features_threshold | float | Remove columns with missingness > threshold (0-100). | 50.0
unique_threshold | int | Unique value threshold to classify categorical features (clinical). | 10
scaler | str | Scaler type for numerical features ('minmax', 'standard', 'robust'). | 'minmax'
ordinal_cat_columns | list | Ordinal categorical columns in clinical data. | None
imputer | str | Imputation method to use ('knn' or 'mice'). | 'mice'
imputer_params | dict | Method-specific params. | None
low_expression_variance_threshold | float | Variance cutoff for near-constant genes when remove_low_expression_genes is enabled. | 0.0005
add_indicators | bool | Whether to add missingness indicator columns during imputation. | True
verbose | bool | Print progress from lower-level utilities. | True

Returns:

Name | Type | Description
---- | ---- | -----------
dict | Dict[str, Optional[DataFrame]] | Dictionary with 'processed_clinical' (processed clinical data), 'processed_transcriptomics' (processed transcriptomics data), and 'integrated_data' (integrated dataset if integration is enabled, otherwise None).

Raises:

Type | Description
---- | -----------
TypeError | If inputs are not pandas DataFrames.
KeyError | If required ID columns are missing.
ValueError | On processing errors in underlying steps.
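
Besides the returned DataFrames, run_pipeline writes feature_metadata.json into the pipeline's output directory, mapping each feature to one of: numerical, ordinal_categorical, dummy_categorical, missing_categorical, or unclassified. A small sketch of reading it back, assuming the hypothetical output directory from the earlier example:

```python
import json
import os

# Load the feature-type metadata saved alongside the pipeline logs.
with open(os.path.join("outputs/run_01", "feature_metadata.json")) as f:
    feature_types = json.load(f)

# Peek at a few entries, e.g. {"age": "numerical", "missingindicator_age": "missing_categorical"}.
print(dict(list(feature_types.items())[:5]))
```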

Source code in src/synomicsbench/processing/pipeline.py
def run_pipeline(
    self,
    clinical_data: pd.DataFrame,
    transcriptomics_data: pd.DataFrame,
    clinical_id_column: str,
    transcriptomics_id_column: str,
    integration_id_column: str = "PATIENT_ID",
    steps_config: Optional[Dict] = None,
    overmissing_samples_threshold: float = 50.0,
    overmissing_features_threshold: float = 50.0,
    unique_threshold: int = 10,
    scaler: str = "minmax",
    ordinal_cat_columns: Optional[List[str]] = None,
    imputer: str = "mice",
    imputer_params: Optional[Dict[str, Any]] = None,
    low_expression_variance_threshold: float = 0.0005,
    add_indicators: bool = True,
    verbose: bool = True,
) -> Dict[str, Optional[pd.DataFrame]]:
    """
    Run the full pipeline across transcriptomics and clinical data and optionally integrate them.

    Args:
        clinical_data (pd.DataFrame): Raw clinical data.
        transcriptomics_data (pd.DataFrame): Raw transcriptomics data.
        clinical_id_column (str): ID column name in clinical_data.
        transcriptomics_id_column (str): ID column name in transcriptomics_data.
        integration_id_column (str): Common ID column name for integration.
        steps_config (dict, optional): Dict specifying which steps to run; defaults enable all steps except `remove_low_expression_genes`.
        overmissing_samples_threshold (float): Remove rows with missingness > threshold (0-100).
        overmissing_features_threshold (float): Remove columns with missingness > threshold (0-100).
        unique_threshold (int): Unique value threshold to classify categorical features (clinical).
        scaler (str): Scaler type for numerical features ('minmax', 'standard', 'robust').
        ordinal_cat_columns (list, optional): Ordinal categorical columns in clinical data.
        imputer (str): Imputation method to use ('knn' or 'mice').
        imputer_params (dict, optional): Method-specific params.
        low_expression_variance_threshold (float): Variance cutoff for near-constant genes when
            `remove_low_expression_genes` is enabled.
        add_indicators (bool): Whether to add missingness indicator columns during imputation.
        verbose (bool): Print progress from lower-level utilities.

    Returns:
        dict: Dictionary containing:
            - 'processed_clinical' (pd.DataFrame): Processed clinical data
            - 'processed_transcriptomics' (pd.DataFrame): Processed transcriptomics data
            - 'integrated_data' (pd.DataFrame or None): Integrated dataset if integration enabled

    Raises:
        TypeError: If inputs are not pandas DataFrames.
        KeyError: If required ID columns are missing.
        ValueError: On processing errors in underlying steps.
    """
    self.logger.info("=== STARTING DATA INTEGRATION PIPELINE ===")
    pipeline_start = time.perf_counter()

    default_steps = {
        "remove_undefined": True,
        "remove_duplicates": True,
        "remove_overmissing_samples": True,
        "remove_low_expression_genes": False,
        "check_duplicate_genes": True,
        "mapping_genes": True,
        "feature_engineering": True,
        "integrate_data": True,
    }
    steps_config = {**default_steps, **(steps_config or {})}
    self.logger.info(f"[CONFIG] Steps: {steps_config}")

    # Process transcriptomics
    processed_transcriptomics = self.process_transcriptomics_data(
        transcriptomics_data=transcriptomics_data,
        transcriptomics_id_column=transcriptomics_id_column,
        steps_config=steps_config,
        overmissing_samples_threshold=overmissing_samples_threshold,
        overmissing_features_threshold=overmissing_features_threshold,
        unique_threshold=unique_threshold,
        scaler=scaler,
        imputer=imputer,
        imputer_params=imputer_params,
        low_expression_variance_threshold=low_expression_variance_threshold,
        add_indicators=add_indicators,
        verbose=verbose,
    )

    # Process clinical
    processed_clinical = self.process_clinical_data(
        clinical_data=clinical_data,
        clinical_id_column=clinical_id_column,
        steps_config=steps_config,
        overmissing_samples_threshold=overmissing_samples_threshold,
        overmissing_features_threshold=overmissing_features_threshold,
        unique_threshold=unique_threshold,
        scaler=scaler,
        imputer=imputer,
        imputer_params=imputer_params,
        ordinal_cat_columns=ordinal_cat_columns,
        add_indicators=add_indicators,
        verbose=verbose,
    )

    # Integrate
    integrated_data = self.integrate_data(
        processed_clinical=processed_clinical,
        processed_transcriptomics=processed_transcriptomics,
        clinical_id_column=clinical_id_column,
        transcriptomics_id_column=transcriptomics_id_column,
        integration_id_column=integration_id_column,
        steps_config=steps_config,
    )

    # Build feature metadata regardless of integration outcome
    self.logger.info("[METADATA] Building feature type metadata")
    clinical_features = processed_clinical.drop(
        columns=[clinical_id_column]
    ).columns.tolist()
    transcriptomics_features = processed_transcriptomics.drop(
        columns=[transcriptomics_id_column]
    ).columns.tolist()

    if integrated_data is not None:
        integrated_features = integrated_data.drop(
            columns=[integration_id_column]
        ).columns.tolist()
    else:
        # If not integrated, use union of features for metadata
        integrated_features = sorted(
            set(clinical_features).union(set(transcriptomics_features))
        )

    ordinal_cat_columns = ordinal_cat_columns or []
    clinical_dummy = getattr(self, "clinical_dummy_columns", []) or []
    clinical_num = getattr(self, "clinical_num_columns", []) or []
    transcript_num = getattr(self, "transcriptomics_num_columns", []) or []

    feature_type_meta: Dict[str, str] = {}
    for col in integrated_features:
        if col in clinical_dummy:
            feature_type_meta[col] = "dummy_categorical"
        elif col in clinical_num or col in transcript_num:
            feature_type_meta[col] = "numerical"
        elif col in ordinal_cat_columns:
            feature_type_meta[col] = "ordinal_categorical"
        elif col.startswith("missingindicator_"):
            feature_type_meta[col] = "missing_categorical"
        else:
            feature_type_meta[col] = "unclassified"

    # Log summary
    counts = {
        "numerical": sum(v == "numerical" for v in feature_type_meta.values()),
        "ordinal_categorical": sum(
            v == "ordinal_categorical" for v in feature_type_meta.values()
        ),
        "dummy_categorical": sum(
            v == "dummy_categorical" for v in feature_type_meta.values()
        ),
        "missing_categorical": sum(
            v == "missing_categorical" for v in feature_type_meta.values()
        ),
        "unclassified": sum(
            v == "unclassified" for v in feature_type_meta.values()
        ),
    }
    self.logger.info(f"[METADATA] Counts: {counts}")
    if counts["unclassified"] > 0:
        self.logger.warning(
            f"[METADATA] {counts['unclassified']} unclassified feature(s). Check your metadata or processing configuration."
        )

    # Save metadata JSON
    json_path = os.path.join(self.output_dir, "feature_metadata.json")
    with open(json_path, "w") as f:
        json.dump(feature_type_meta, f, indent=4)
    self.logger.info(f"[METADATA] Saved feature metadata to: {json_path}")

    self.logger.info(
        f"=== PIPELINE COMPLETED in {time.perf_counter() - pipeline_start:.2f}s ==="
    )
    return {
        "processed_clinical": processed_clinical,
        "processed_transcriptomics": processed_transcriptomics,
        "integrated_data": integrated_data,
    }
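
A minimal usage sketch (hedged: the pipeline instance name, file paths, and ID column names below are illustrative assumptions, not part of the API):

import pandas as pd

# Illustrative inputs; file names and ID columns are assumptions.
clinical = pd.read_csv("clinical.csv")
transcriptomics = pd.read_csv("transcriptomics.csv")

results = pipeline.run_pipeline(   # `pipeline` is an instance of the class documented above
    clinical_data=clinical,
    transcriptomics_data=transcriptomics,
    clinical_id_column="PATIENT_ID",
    transcriptomics_id_column="PATIENT_ID",
    integration_id_column="PATIENT_ID",
    imputer="knn",
    imputer_params={"n_neighbors": 5},
    verbose=False,
)

processed_clinical = results["processed_clinical"]
integrated = results["integrated_data"]  # None if the integration step was disabled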

DataProcessor

DataProcessor provides static methods for preprocessing, encoding, imputing, and engineering features in omics or clinical datasets.

This class enables duplicate removal, missing value filtering, encoding of categorical features, normalization, KNN imputation, and feature engineering classification.

Methods are tailored for pandas DataFrame workflows in bioinformatics and genomics.

Source code in src/synomicsbench/processing/preprocessing.py
class DataProcessor:
    """
    DataProcessor provides static methods for preprocessing, encoding, imputing, and engineering features
    in omics or clinical datasets.

    This class enables duplicate removal, missing value filtering, encoding of categorical features,
    normalization, KNN imputation, and feature engineering classification.

    Methods are tailored for pandas DataFrame workflows in bioinformatics and genomics.

    """

    def __init__(self):
        pass

    @staticmethod
    def remove_duplications(data: pd.DataFrame, axis: int) -> pd.DataFrame:
        """
        Remove duplicate rows or columns from the DataFrame.

        Args:
            data (pd.DataFrame): Input DataFrame.
            axis (int): 0 to remove duplicate rows, 1 to remove duplicate columns.

        Returns:
            pd.DataFrame: DataFrame with duplicates removed.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            ValueError: If axis is not 0 or 1.
            RuntimeError: On error during duplicate removal.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        if axis not in [0, 1]:
            raise ValueError("Axis must be 0 (rows) or 1 (columns)")
        try:
            if axis == 0:
                return data.drop_duplicates()
            elif axis == 1:
                return data.loc[:, ~data.columns.duplicated()]
        except Exception as e:
            raise RuntimeError(f"Error removing duplicates: {e}")

    @staticmethod
    def remove_unknown_entities(data: pd.DataFrame, id_column: str) -> pd.DataFrame:
        """
        Remove rows where the identifier column contains missing values.

        Args:
            data (pd.DataFrame): Input DataFrame.
            id_column (str): Name of the identifier column.

        Returns:
            pd.DataFrame: Filtered DataFrame.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            KeyError: If id_column is not in DataFrame.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Data input must be a pandas DataFrame")
        if id_column not in data.columns:
            raise KeyError(f"ID column '{id_column}' not in DataFrame")
        return data.dropna(subset=[id_column])

    @staticmethod
    def remove_overmissing_entities(
        data: pd.DataFrame, threshold: float
    ) -> pd.DataFrame:
        """
        Remove rows (entities) with missing value percentage above a threshold.

        Args:
            data (pd.DataFrame): Input DataFrame.
            threshold (float): Percentage threshold (0-100).

        Returns:
            pd.DataFrame: DataFrame with over-missing entities removed.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            ValueError: If threshold is not between 0 and 100.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        if not 0 <= threshold <= 100:
            raise ValueError("Threshold must be between 0 and 100")
        try:
            # Calculate the percentage of missing values per row
            missing_percentage = data.isnull().mean(axis=1) * 100
            # Keep rows where missing percentage is <= threshold
            keep_rows = missing_percentage <= threshold
            return data[keep_rows].copy()
        except Exception as e:
            raise ValueError(f"Error removing over-missing samples: {e}")

    @staticmethod
    def find_missing_percent(data: pd.DataFrame) -> pd.DataFrame:
        """
        Calculate the percentage of missing values for each column.

        Args:
            data (pd.DataFrame): Input DataFrame.

        Returns:
            pd.DataFrame: DataFrame with columns 'ColumnName' and 'PercentMissing'.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            RuntimeError: On error during calculation.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Data input must be a pandas DataFrame")
        try:
            return (
                (data.isnull().sum() / len(data) * 100)
                .reset_index()
                .rename(columns={0: "PercentMissing", "index": "ColumnName"})
            )
        except Exception as e:
            raise RuntimeError(f"Error calculating missing percentages: {e}")

    @staticmethod
    def remove_overmissing_features(
        data: pd.DataFrame, threshold: float
    ) -> pd.DataFrame:
        """
        Remove columns (features) with missing value percentage above a threshold.

        Args:
            data (pd.DataFrame): Input DataFrame.
            threshold (float): Percentage threshold (0-100).

        Returns:
            pd.DataFrame: DataFrame with over-missing features removed.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            ValueError: If threshold is not between 0 and 100.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        if not 0 <= threshold <= 100:
            raise ValueError("Threshold must be between 0 and 100")
        try:
            miss_df = (
                (data.isnull().sum() / len(data) * 100)
                .reset_index()
                .rename(columns={0: "PercentMissing", "index": "ColumnName"})
            )
            drop_cols = miss_df[miss_df["PercentMissing"] > threshold][
                "ColumnName"
            ].tolist()
            return data.drop(drop_cols, axis=1)
        except Exception as e:
            raise ValueError(f"Error removing over-missing columns: {e}")

    @staticmethod
    def remove_low_expression_genes(
        data: pd.DataFrame,
        gene_id_column: str = "gene_id",
        variance_threshold: float = 0.0005,
    ) -> pd.DataFrame:
        """
        Remove low-expressed genes from transcriptomics expression data.

        This utility supports two common transcriptomics orientations:
        - Genes as rows and samples as columns (optionally with a `gene_id` column).
        - Samples as rows and genes as columns (pipeline convention in synomicsbench).

        A gene is removed if either:
        - Total expression equals 0 across all samples, OR
        - Variance is <= `variance_threshold` across all samples.

        Args:
            data (pd.DataFrame): Expression DataFrame.
            gene_id_column (str): Column containing gene IDs when genes are rows.
            variance_threshold (float): Near-zero variance threshold.

        Returns:
            pd.DataFrame: Filtered DataFrame in the same orientation as the input.

        Raises:
            TypeError: If data is not a pandas DataFrame.
            ValueError: If variance_threshold is negative or data contains non-numeric columns.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        if variance_threshold < 0:
            raise ValueError("variance_threshold must be >= 0")

        # Orientation: genes are rows (optionally with a gene_id column).
        if gene_id_column in data.columns:
            expression_matrix = data.set_index(gene_id_column)
            numeric = expression_matrix.apply(pd.to_numeric, errors="coerce")
            non_numeric_cols = numeric.columns[
                numeric.notna().sum(axis=0) == 0
            ].tolist()
            if non_numeric_cols:
                raise ValueError(
                    "Non-numeric sample columns found in expression matrix: "
                    f"{non_numeric_cols}"
                )

            sum_expression = numeric.sum(axis=1)
            zero_genes = sum_expression[sum_expression == 0].index.tolist()

            variances_expression = numeric.var(axis=1)
            near_zero_var = variances_expression[
                variances_expression <= variance_threshold
            ].index.tolist()

            remove_genes = set(zero_genes) | set(near_zero_var)
            filtered = expression_matrix.drop(index=list(remove_genes), errors="ignore")
            return filtered.reset_index()

        # Orientation: samples are rows, genes are columns.
        numeric = data.apply(pd.to_numeric, errors="coerce")
        non_numeric_cols = numeric.columns[numeric.notna().sum(axis=0) == 0].tolist()
        if non_numeric_cols:
            raise ValueError(
                "Non-numeric gene columns found in expression matrix: "
                f"{non_numeric_cols}"
            )

        sum_expression = numeric.sum(axis=0)
        zero_genes = sum_expression[sum_expression == 0].index.tolist()

        variances_expression = numeric.var(axis=0)
        near_zero_var = variances_expression[
            variances_expression <= variance_threshold
        ].index.tolist()

        remove_genes = set(zero_genes) | set(near_zero_var)
        return data.drop(columns=list(remove_genes), errors="ignore")

    ### Prepare for Imputation ###
    @staticmethod
    def encode_dummy_features(data: pd.DataFrame) -> pd.DataFrame:
        """
        Encode categorical features into dummy/one-hot columns.

        Args:
            data (pd.DataFrame): DataFrame with categorical features.

        Returns:
            pd.DataFrame: Dummy-encoded DataFrame.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            ValueError: If encoding fails.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            cat_data_encoded = pd.get_dummies(
                data, dtype="int64", dummy_na=True, columns=data.columns
            )
            return cat_data_encoded
        except Exception as e:
            raise ValueError(f"Error encoding categorical features: {e}")

    @staticmethod
    def encode_ordinal_features(
        data: pd.DataFrame,
    ) -> Tuple[pd.DataFrame, OrdinalEncoder]:
        """
        Encode ordinal categorical features using OrdinalEncoder.

        Args:
            data (pd.DataFrame): DataFrame with ordinal categorical features.

        Returns:
            Tuple[pd.DataFrame, OrdinalEncoder]: Encoded DataFrame and fitted encoder.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            ValueError: If encoding fails.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            oe = OrdinalEncoder()
            encoded_values = oe.fit_transform(data)
            df_ordinal_encoded = pd.DataFrame(
                encoded_values, columns=data.columns, index=data.index
            )
            return df_ordinal_encoded, oe
        except Exception as e:
            raise ValueError(f"Error encoding ordinal categorical features: {e}")

    @staticmethod
    def standardization(
        data: pd.DataFrame, scaler: str
    ) -> Tuple[pd.DataFrame, BaseEstimator]:
        """
        Standardize numerical features using specified scaler.

        Args:
            data (pd.DataFrame): DataFrame with numerical features.
            scaler (str): Scaler type ('standard', 'minmax', or 'robust').

        Returns:
            Tuple[pd.DataFrame, BaseEstimator]: Scaled DataFrame and scaler object.

        Raises:
            ValueError: If scaler type is invalid.
            TypeError: If input is not a pandas DataFrame.
            RuntimeError: If scaling fails.
        """
        scalers = {
            "standard": StandardScaler(),
            "minmax": MinMaxScaler(),
            "robust": RobustScaler(),
        }
        if scaler not in scalers:
            raise ValueError("Scaler must be 'standard', 'minmax', or 'robust'")
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            scaler_obj = scalers[scaler]
            scaled_data = pd.DataFrame(
                scaler_obj.fit_transform(data), columns=data.columns, index=data.index
            )
            return scaled_data, scaler_obj
        except Exception as e:
            raise RuntimeError(f"Error in standardization: {e}")

    @staticmethod
    def _add_indicators(data: pd.DataFrame) -> pd.DataFrame:
        """
        Append missingness indicators only for columns that have missing values.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        mi = MissingIndicator(features="missing-only", error_on_new=False)
        mask = mi.fit_transform(data)  # shape: (n_samples, n_missing_cols)
        missing_cols = data.columns[mi.features_]  # Only columns with missingness
        ind_cols = [f"missingindicator_{c}" for c in missing_cols]
        indicators_df = pd.DataFrame(
            mask.astype(np.int8), columns=ind_cols, index=data.index
        )
        return indicators_df

    @staticmethod
    def mice_imputation(
        data: pd.DataFrame,
        *,
        random_state: int = 42,
        iterations: int = 20,
        n_estimators: int = 300,
        add_indicators: bool = True,
        verbose: bool = True,
        **mice_kwargs: Any,
    ) -> pd.DataFrame:
        """
        Impute with miceforest and optionally append missing indicators.
        Additional miceforest arguments can be passed via **mice_kwargs (e.g., variable_schema).
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")

        if mf is None:
            raise ImportError(
                "miceforest is required for MICE imputation. Install it or use imputer='knn'."
            )

        out = data.copy()

        # Cast object to category so miceforest can treat them as categorical if present
        for c in out.select_dtypes(include=["object"]).columns:
            out[c] = out[c].astype("category")

        kernel = mf.ImputationKernel(
            data=out,
            num_datasets=1,  # miceforest uses `datasets`
            random_state=random_state,
            **mice_kwargs,
        )

        kernel.mice(iterations=iterations, n_estimators=n_estimators, verbose=verbose)

        df_imputed = kernel.complete_data(dataset=0)

        if add_indicators:
            df_ind = DataProcessor._add_indicators(
                data
            )  # build indicators from original
            df_imputed = pd.concat([df_imputed, df_ind], axis=1)

        return df_imputed

    @staticmethod
    def impute_router(
        data: pd.DataFrame,
        method: str = "mice",
        *,
        add_indicators: bool = True,
        knn_params: Optional[Dict[str, Any]] = None,
        mice_params: Optional[Dict[str, Any]] = None,
    ) -> pd.DataFrame:
        """
        Dispatch imputation to KNN or MICE with method-specific parameters.

        Args:
            data (pd.DataFrame): Input frame (already preprocessed if needed).
            method (str): 'knn' or 'mice'.
            add_indicators (bool): Whether to append missing indicators.
            knn_params (dict): Passed to knn_imputer (n_neighbors, etc.).
            mice_params (dict): Passed to mice_imputation (iterations, n_estimators, etc.).

        Returns:
            pd.DataFrame: Imputed dataframe (with indicators if requested).
        """
        method = method.lower()
        knn_params = knn_params or {}
        mice_params = mice_params or {}

        if method == "knn":
            # Expecting this low-level utility to receive an already-encoded/normalized frame.
            # If you want the full pipeline, keep using knn_imputation (the full pipeline) elsewhere.
            return DataProcessor.knn_imputer(
                data=data,
                dummy_cat_columns=knn_params.pop("dummy_cat_columns", []),
                ordinal_cat_columns=knn_params.pop("ordinal_cat_columns", []),
                n_neighbors=knn_params.pop("n_neighbors", 5),
                add_indicators=add_indicators,
                **knn_params,
            )
        elif method == "mice":
            return DataProcessor.mice_imputation(
                data=data, add_indicators=add_indicators, **mice_params
            )
        else:
            raise ValueError("method must be 'knn' or 'mice'")

    @staticmethod
    def knn_imputer(
        data: pd.DataFrame,
        dummy_cat_columns: list,
        ordinal_cat_columns: Optional[List] = None,
        n_neighbors: int = 5,
        add_indicators: bool = True,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Impute missing values using KNNImputer, with special handling for dummy and ordinal features.

        Args:
            data (pd.DataFrame): Preprocessed DataFrame.
            dummy_cat_columns (list): List of dummy categorical feature names.
            ordinal_cat_columns (List, optional): List of ordinal categorical feature names.
            n_neighbors (int): Number of neighbors for KNN.
            add_indicators (bool, optional): Whether to add missing indicators.
            **kwargs: Additional KNNImputer arguments.

        Returns:
            pd.DataFrame: DataFrame with imputed values and indicators.

        Raises:
            TypeError: On invalid input types.
            ValueError: On imputation errors.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        if not isinstance(dummy_cat_columns, list):
            raise TypeError("dummy_cat_columns must be a list")
        ordinal_cat_columns = ordinal_cat_columns or []
        if not isinstance(ordinal_cat_columns, list):
            raise TypeError("ordinal_cat_columns must be a list")
        try:
            dummy_cat_features_dict = {}
            nan_dummy_features = []
            # Create a dictionary to hold dummy categorical features
            for dummy_cat_feature in dummy_cat_columns:
                for column in data.columns:
                    if column.startswith(dummy_cat_feature):
                        dummy_cat_features_dict[dummy_cat_feature] = (
                            dummy_cat_features_dict.get(dummy_cat_feature, [])
                            + [column]
                        )
            # Check for NaN dummy columns and handle them
            NAN_SUFFIX = "_nan"
            for dummy_cat_feature, dummy_columns in dummy_cat_features_dict.items():
                for column in dummy_columns:
                    if column.endswith(NAN_SUFFIX):
                        nan_dummy_features.append(column)
            # Refill np.nan to dummy categorical features
            index_dict = {}
            for key, value in dummy_cat_features_dict.items():
                num_type_cat_feature = len(value) - 1
                index_dict[key] = []
                for row_i in data.index:
                    if np.array_equal(
                        data.loc[row_i, dummy_cat_features_dict[key]].values,
                        np.array([0] * num_type_cat_feature + [1]),
                    ):
                        data.loc[row_i, dummy_cat_features_dict[key]] = [
                            np.nan
                        ] * num_type_cat_feature + [1]
                        index_dict[key].append(row_i)
            # Remove NaN dummy columns from the DataFrame
            if len(nan_dummy_features) > 0:
                data = data.drop(nan_dummy_features, axis=1)
            # Use KNN imputer to fill missing values
            imputer = KNNImputer(
                n_neighbors=n_neighbors,
                weights="distance",
                add_indicator=add_indicators,
                **kwargs,
            )
            data_imputed = pd.DataFrame(
                imputer.fit_transform(data),
                columns=imputer.get_feature_names_out(),
                index=data.index,
            )
            # Post-process the imputed data
            data_imputed[ordinal_cat_columns] = data_imputed[
                ordinal_cat_columns
            ].round()
            for key, value in index_dict.items():
                features = dummy_cat_features_dict[key][0:-1]
                argmax_index = np.argmax(
                    data_imputed.loc[value, features].values, axis=1
                )
                for i in range(len(value)):
                    data_imputed.loc[value[i], features] = [0] * len(features)
                    data_imputed.loc[value[i], features[argmax_index[i]]] = 1

            # Combine duplicate missing indicators for dummy variables
            if add_indicators:
                indicator_cols = [
                    col
                    for col in data_imputed.columns
                    if col.startswith("missingindicator_")
                ]
                # Map: dummy feature -> list of indicator columns
                indicator_map = {}
                for col in indicator_cols:
                    for dummy_cat_feature in dummy_cat_columns:
                        if col.startswith(f"missingindicator_{dummy_cat_feature}"):
                            indicator_map.setdefault(dummy_cat_feature, []).append(col)
                # For each dummy feature, keep only one indicator column (since all are identical)
                for dummy_cat_feature, cols in indicator_map.items():
                    if len(cols) > 1:
                        # Keep the first column, drop the rest, and rename to standard name
                        combined_col = f"missingindicator_{dummy_cat_feature}"
                        data_imputed[combined_col] = data_imputed[cols[0]]
                        data_imputed.drop(columns=cols, inplace=True)
                    elif len(cols) == 1:
                        # Rename to standard name if needed
                        col = cols[0]
                        combined_col = f"missingindicator_{dummy_cat_feature}"
                        if col != combined_col:
                            data_imputed.rename(
                                columns={col: combined_col}, inplace=True
                            )
            return data_imputed
        except Exception as e:
            raise ValueError(f"Error in KNN imputation: {e}")

    @staticmethod
    def extract_missingindicator_columns(data: pd.DataFrame) -> pd.DataFrame:
        """
        Extract columns indicating missingness from imputed DataFrame.

        Args:
            data (pd.DataFrame): DataFrame after imputation.

        Returns:
            pd.DataFrame: DataFrame with only missing indicator columns.

        Raises:
            TypeError: If input is not a pandas DataFrame.
            ValueError: On extraction errors.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            missing_indicator_cols = [
                col for col in data.columns if col.startswith("missingindicator_")
            ]
            return data[missing_indicator_cols]
        except Exception as e:
            raise ValueError(f"Error extracting missing indicator columns: {e}")

    ### Processing after imputation ###
    @staticmethod
    def inverse_dummy_features(
        data: pd.DataFrame, dummy_cat_columns: list
    ) -> pd.DataFrame:
        """
        Decode dummy-encoded categorical features back to original categories.

        Args:
            data (pd.DataFrame): DataFrame with dummy-encoded features.
            dummy_cat_columns (list): List of dummy categorical feature names.

        Returns:
            pd.DataFrame: DataFrame with decoded categorical columns.

        Raises:
            TypeError: On invalid input types or empty list.
            ValueError: On decoding errors.
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input data must be a pandas DataFrame")
        if not isinstance(dummy_cat_columns, list) or not dummy_cat_columns:
            raise TypeError(
                "dummy_cat_columns must be a non-empty list of column names"
            )
        try:
            cat_features_dict = {}
            for cat_feature in dummy_cat_columns:
                for column in data.columns:
                    if column.startswith(cat_feature):
                        cat_features_dict[cat_feature] = cat_features_dict.get(
                            cat_feature, []
                        ) + [column]
            for feature, cols in cat_features_dict.items():
                data[feature] = np.array(cols)[data[cols].to_numpy().argmax(axis=1)]
                data[feature] = data[feature].str.replace(f"{feature}_", "")
                data.drop(columns=cols, inplace=True)
            data = data[dummy_cat_columns].astype("category")
            return data
        except Exception as e:
            raise ValueError(f"Error in inverse encoding categorical features: {e}")

    @staticmethod
    def inverse_ordinal_features(
        data: pd.DataFrame, encoder: OrdinalEncoder
    ) -> pd.DataFrame:
        """
        Decode ordinal categorical features from encoded values back to original categories.

        Args:
            data (pd.DataFrame): DataFrame with encoded ordinal features.
            encoder (OrdinalEncoder): Fitted ordinal encoder.

        Returns:
            pd.DataFrame: Decoded ordinal categorical features.

        Raises:
            AttributeError: If encoder is not initialized.
            TypeError: If input is not a DataFrame.
            ValueError: On decoding errors.
        """
        if not isinstance(encoder, OrdinalEncoder):
            raise AttributeError(
                "Ordinal encoder not initialized. Run encode_ordinal_features first"
            )
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input data must be a pandas DataFrame")
        try:
            data_round = data.round()
            return pd.DataFrame(
                encoder.inverse_transform(data_round),
                columns=data.columns,
                index=data.index,
            )
        except Exception as e:
            raise ValueError(f"Error in inverse ordinal encoding: {e}")

    @staticmethod
    def inverse_standardization(
        data: pd.DataFrame, scaler: BaseEstimator
    ) -> pd.DataFrame:
        """
        Undo normalization/standardization on numerical features.

        Args:
            data (pd.DataFrame): DataFrame with normalized features.
            scaler (BaseEstimator): Fitted scaler.

        Returns:
            pd.DataFrame: DataFrame with original scale restored.

        Raises:
            AttributeError: If scaler is not initialized.
            TypeError: If input is not a DataFrame.
            ValueError: On inverse normalization errors.
        """
        if not isinstance(scaler, BaseEstimator):
            raise AttributeError("Scaler not initialized. Run standardization first")
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            data_inverse_normalized = pd.DataFrame(
                scaler.inverse_transform(data),
                columns=data.columns,
                index=data.index,
            )
            return data_inverse_normalized.astype("float64")
        except Exception as e:
            raise ValueError(f"Error in inverse normalization: {e}")

    @staticmethod
    def knn_imputation(
        data: pd.DataFrame,
        dummy_cat_columns: Optional[List] = None,
        ordinal_cat_columns: Optional[List] = None,
        numerical_columns: Optional[List] = None,
        scaler: str = "minmax",
        n_neighbors: int = 5,
        add_indicators: bool = True,
        verbose: bool = True,
        **kwargs,
    ) -> pd.DataFrame:
        """
        Perform full preprocessing and KNN imputation, including encoding, normalization, and postprocessing.

        Args:
            data (pd.DataFrame): Input DataFrame.
            dummy_cat_columns (List, optional): Dummy categorical columns.
            ordinal_cat_columns (List, optional): Ordinal categorical columns.
            numerical_columns (List, optional): Numerical columns.
            scaler (str, optional): Scaler type ('minmax', 'standard', 'robust').
            n_neighbors (int, optional): KNN neighbors.
            add_indicators (bool, optional): Add missing indicators.
            verbose (bool, optional): Print progress.
            **kwargs: Additional KNNImputer arguments.

        Returns:
            pd.DataFrame: DataFrame after imputation and decoding.

        Raises:
            TypeError: On invalid input.
            ValueError: On missing columns or errors.
        """
        dummy_cat_columns = dummy_cat_columns or []
        ordinal_cat_columns = ordinal_cat_columns or []
        numerical_columns = numerical_columns or []

        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")

        all_cols = dummy_cat_columns + ordinal_cat_columns + numerical_columns
        if not all(col in data.columns for col in all_cols):
            missing_cols = [col for col in all_cols if col not in data.columns]
            raise ValueError(f"Some specified columns not in dataset: {missing_cols}")

        try:
            if verbose:
                print("IMPUTATION: Starting preprocessing")
            preprocessed_parts = []

            or_encoder: Optional[OrdinalEncoder] = None
            scaler_obj: Optional[BaseEstimator] = None

            if ordinal_cat_columns:
                if verbose:
                    print(
                        "IMPUTATION-PREPROCESSING: Encoding ordinal categorical features"
                    )
                ordinal_data = data[ordinal_cat_columns]
                ordinal_encoded, or_encoder = DataProcessor.encode_ordinal_features(
                    ordinal_data
                )
                preprocessed_parts.append(ordinal_encoded)

            if dummy_cat_columns:
                if verbose:
                    print(
                        "IMPUTATION-PREPROCESSING: Encoding dummy categorical features"
                    )
                dummy_data = data[dummy_cat_columns]
                dummy_encoded = DataProcessor.encode_dummy_features(dummy_data)
                preprocessed_parts.append(dummy_encoded)

            if numerical_columns:
                if verbose:
                    print("IMPUTATION-PREPROCESSING: Standardizing numerical features")
                numerical_data = data[numerical_columns]
                numerical_normalized, scaler_obj = DataProcessor.standardization(
                    numerical_data, scaler=scaler
                )
                preprocessed_parts.append(numerical_normalized)

            if not preprocessed_parts:
                raise ValueError("No columns specified for processing")

            data_preprocessed = pd.concat(preprocessed_parts, axis=1)

            if verbose:
                print("IMPUTATION: Initializing KNN imputer")
            data_imputed = DataProcessor.knn_imputer(
                data_preprocessed,
                dummy_cat_columns=dummy_cat_columns,
                ordinal_cat_columns=ordinal_cat_columns,
                n_neighbors=n_neighbors,
                add_indicators=add_indicators,
                **kwargs,
            )

            missing_indicator_cols = [
                col
                for col in data_imputed.columns
                if col.startswith("missingindicator_")
            ]
            postprocessed_parts = []
            if verbose:
                print("IMPUTATION: Starting postprocessing")

            if ordinal_cat_columns:
                if verbose:
                    print(
                        "IMPUTATION-POSTPROCESSING: Decoding ordinal categorical features"
                    )
                ordinal_imputed = data_imputed[ordinal_cat_columns]
                if or_encoder is None:
                    raise RuntimeError("Ordinal encoder is not initialized")
                ordinal_decoded = DataProcessor.inverse_ordinal_features(
                    ordinal_imputed, or_encoder
                )
                postprocessed_parts.append(ordinal_decoded)

            if dummy_cat_columns:
                if verbose:
                    print(
                        "IMPUTATION-POSTPROCESSING: Decoding dummy categorical features"
                    )
                dummy_cols_after_encoding = [
                    col
                    for col in data_imputed.columns
                    if any(col.startswith(cat) for cat in dummy_cat_columns)
                ]
                dummy_imputed = data_imputed[dummy_cols_after_encoding]
                dummy_decoded = DataProcessor.inverse_dummy_features(
                    dummy_imputed, dummy_cat_columns
                )
                postprocessed_parts.append(dummy_decoded)

            if numerical_columns:
                if verbose:
                    print(
                        "IMPUTATION-POSTPROCESSING: Inverse normalizing numerical features"
                    )
                numerical_imputed = data_imputed[numerical_columns]
                if scaler_obj is None:
                    raise RuntimeError("Scaler is not initialized")
                numerical_denormalized = DataProcessor.inverse_standardization(
                    numerical_imputed, scaler_obj
                )
                postprocessed_parts.append(numerical_denormalized)
            # Extract missing indicator columns
            if verbose:
                print("IMPUTATION-POSTPROCESSING: Extracting missing indicator columns")
            missing_indicator_cols = DataProcessor.extract_missingindicator_columns(
                data_imputed
            )
            postprocessed_parts.append(missing_indicator_cols)

            data_final = pd.concat(postprocessed_parts, axis=1)
            if verbose:
                print("IMPUTATION: Data imputation completed")
            return data_final
        except Exception as e:
            raise ValueError(f"Error in data imputation: {e}")

    @staticmethod
    def feature_engineering(
        data: pd.DataFrame,
        data_type: str,
        overmissing_threshold: float = 50,
        imputer: str = "mice",
        ordinal_cat_columns: Optional[List[str]] = None,
        unique_threshold: int = 10,
        scaler: str = "minmax",
        imputer_params: Optional[Dict[str, Any]] = None,
        add_indicators: bool = True,
        verbose: bool = True,
    ) -> tuple:
        """
        Perform feature engineering and imputation using a chosen method with its specific params.

        Args:
            data (pd.DataFrame): Input DataFrame.
            data_type (str): 'transcriptomics' or 'clinical'.
            overmissing_threshold (float): Drop features with missingness > threshold.
            imputer (str): 'mice' or 'knn'.
            ordinal_cat_columns (List[str], optional): Known ordinal categorical columns.
            unique_threshold (int): Threshold for classifying dummy features (clinical).
            scaler (str): 'minmax' | 'standard' | 'robust'.
            imputer_params (dict, optional): Method-specific params.
                - For 'knn': pass keys as in knn_imputer (e.g., n_neighbors, dummy_cat_columns, ordinal_cat_columns)
                - For 'mice': pass keys as in mice_imputation (e.g., iterations, n_estimators, random_state, verbose, …)
            add_indicators (bool): Append missing indicators.
            verbose (bool): Print progress.

        Returns:
            tuple: (processed_df, numerical_features, dummy_categorical_features)
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        if data_type not in ["transcriptomics", "clinical"]:
            raise ValueError("Data type must be 'transcriptomics' or 'clinical'")
        imputer_params = imputer_params or {}
        ordinal_cat_columns = ordinal_cat_columns or []

        if verbose:
            print(
                f"FEATURE ENGINEERING - OVERMISSING: Removing features with missing > {overmissing_threshold}%"
            )
        x = DataProcessor.remove_overmissing_features(
            data=data, threshold=overmissing_threshold
        )

        if data_type == "clinical":
            if verbose:
                print(
                    "FEATURE ENGINEERING - METADATA: Classifying feature types for clinical data"
                )
            dummy_cat_features, num_features = MetaData.classify_features_types(
                data=x,
                threshold_unique_values=unique_threshold,
                ordinal_features=ordinal_cat_columns,
            )
            if verbose:
                print(
                    f"  Dummy categorical: {len(dummy_cat_features)} | Numerical: {len(num_features)} | Ordinal: {len(ordinal_cat_columns)}"
                )
        else:
            if verbose:
                print("FEATURE ENGINEERING - METADATA: Transcriptomics → all numerical")
            dummy_cat_features = []
            num_features = x.columns.tolist()

        if verbose:
            print(f"FEATURE ENGINEERING - IMPUTATION: Using {imputer.upper()}")

        try:
            from lightgbm.basic import LightGBMError  # type: ignore

            _mice_excs = (ValueError, IndexError, LightGBMError)
        except Exception:
            _mice_excs = (ValueError, IndexError)

        if imputer.lower() == "knn":
            # Keep existing full KNN pipeline (encode→scale→impute→decode)
            processed_df = DataProcessor.knn_imputation(
                data=x,
                dummy_cat_columns=dummy_cat_features,
                ordinal_cat_columns=ordinal_cat_columns,
                numerical_columns=num_features,
                scaler=scaler,
                add_indicators=add_indicators,
                verbose=verbose,
                **imputer_params,  # method-specific extras go here
            )

        elif imputer.lower() == "mice":
            try:
                # For MICE, assume you've already preprocessed externally OR want to impute raw mixed df.
                # If you want to enforce encoding rules first, you can plug a preprocessing step here.
                processed_df = DataProcessor.mice_imputation(
                    data=x,
                    add_indicators=add_indicators,
                    verbose=verbose,
                    **imputer_params,  # iterations, n_estimators, random_state, verbose, etc.
                )
            except _mice_excs as e:
                # MICE failed: report the error, then fall back to the full KNN pipeline.
                n_fallback = 5
                print(
                    f"\nMICE failed ({e.__class__.__name__}: {e}). Falling back to KNN (n_neighbors={n_fallback})."
                )
                processed_df = DataProcessor.knn_imputation(
                    data=x,
                    dummy_cat_columns=dummy_cat_features,
                    ordinal_cat_columns=ordinal_cat_columns,
                    numerical_columns=num_features,
                    scaler=scaler,
                    add_indicators=add_indicators,
                    verbose=False,
                    n_neighbors=n_fallback,
                )
        else:
            raise ValueError("Imputer must be 'mice' or 'knn'")

        if verbose:
            print("FEATURE ENGINEERING: Completed")
        return processed_df, num_features, dummy_cat_features
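
The static utilities above can be chained directly on a DataFrame. A hedged sketch (the input file, ID column, and thresholds are illustrative):

import pandas as pd

df = pd.read_csv("clinical.csv")  # illustrative input

df = DataProcessor.remove_unknown_entities(df, id_column="PATIENT_ID")  # drop rows with missing IDs
df = DataProcessor.remove_duplications(df, axis=0)                      # drop duplicate rows
df = DataProcessor.remove_overmissing_entities(df, threshold=50.0)      # rows > 50% missing
df = DataProcessor.remove_overmissing_features(df, threshold=50.0)      # columns > 50% missing

# Route imputation to MICE (or 'knn'), appending missingness indicator columns.
df_imputed = DataProcessor.impute_router(df, method="mice", add_indicators=True)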

encode_dummy_features(data) staticmethod

Encode categorical features into dummy/one-hot columns.

Parameters:

    data (DataFrame): DataFrame with categorical features. Required.

Returns:

    DataFrame: Dummy-encoded DataFrame.

Raises:

    TypeError: If input is not a pandas DataFrame.
    ValueError: If encoding fails.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def encode_dummy_features(data: pd.DataFrame) -> pd.DataFrame:
    """
    Encode categorical features into dummy/one-hot columns.

    Args:
        data (pd.DataFrame): DataFrame with categorical features.

    Returns:
        pd.DataFrame: Dummy-encoded DataFrame.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: If encoding fails.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    try:
        cat_data_encoded = pd.get_dummies(
            data, dtype="int64", dummy_na=True, columns=data.columns
        )
        return cat_data_encoded
    except Exception as e:
        raise ValueError(f"Error encoding categorical features: {e}")

encode_ordinal_features(data) staticmethod

Encode ordinal categorical features using OrdinalEncoder.

Parameters:

    data (DataFrame): DataFrame with ordinal categorical features. Required.

Returns:

    Tuple[DataFrame, OrdinalEncoder]: Encoded DataFrame and fitted encoder.

Raises:

    TypeError: If input is not a pandas DataFrame.
    ValueError: If encoding fails.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def encode_ordinal_features(
    data: pd.DataFrame,
) -> Tuple[pd.DataFrame, OrdinalEncoder]:
    """
    Encode ordinal categorical features using OrdinalEncoder.

    Args:
        data (pd.DataFrame): DataFrame with ordinal categorical features.

    Returns:
        Tuple[pd.DataFrame, OrdinalEncoder]: Encoded DataFrame and fitted encoder.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: If encoding fails.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    try:
        oe = OrdinalEncoder()
        encoded_values = oe.fit_transform(data)
        df_ordinal_encoded = pd.DataFrame(
            encoded_values, columns=data.columns, index=data.index
        )
        return df_ordinal_encoded, oe
    except Exception as e:
        raise ValueError(f"Error encoding ordinal categorical features: {e}")

extract_missingindicator_columns(data) staticmethod

Extract columns indicating missingness from imputed DataFrame.

Parameters:

    data (DataFrame): DataFrame after imputation. Required.

Returns:

    DataFrame: DataFrame with only missing indicator columns.

Raises:

    TypeError: If input is not a pandas DataFrame.
    ValueError: On extraction errors.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def extract_missingindicator_columns(data: pd.DataFrame) -> pd.DataFrame:
    """
    Extract columns indicating missingness from imputed DataFrame.

    Args:
        data (pd.DataFrame): DataFrame after imputation.

    Returns:
        pd.DataFrame: DataFrame with only missing indicator columns.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: On extraction errors.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    try:
        missing_indicator_cols = [
            col for col in data.columns if col.startswith("missingindicator_")
        ]
        return data[missing_indicator_cols]
    except Exception as e:
        raise ValueError(f"Error extracting missing indicator columns: {e}")

feature_engineering(data, data_type, overmissing_threshold=50, imputer='mice', ordinal_cat_columns=None, unique_threshold=10, scaler='minmax', imputer_params=None, add_indicators=True, verbose=True) staticmethod

Perform feature engineering and imputation using a chosen method with its specific params.

Parameters:

    data (DataFrame): Input DataFrame. Required.
    data_type (str): 'transcriptomics' or 'clinical'. Required.
    overmissing_threshold (float): Drop features with missingness > threshold. Default: 50.
    imputer (str): 'mice' or 'knn'. Default: 'mice'.
    ordinal_cat_columns (List[str]): Known ordinal categorical columns. Default: None.
    unique_threshold (int): Threshold for classifying dummy features (clinical). Default: 10.
    scaler (str): 'minmax' | 'standard' | 'robust'. Default: 'minmax'.
    imputer_params (dict): Method-specific params. For 'knn': pass keys as in knn_imputer (e.g., n_neighbors, dummy_cat_columns, ordinal_cat_columns). For 'mice': pass keys as in mice_imputation (e.g., iterations, n_estimators, random_state, verbose, ...). Default: None.
    add_indicators (bool): Append missing indicators. Default: True.
    verbose (bool): Print progress. Default: True.

Returns:

    tuple: (processed_df, numerical_features, dummy_categorical_features)
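
A hedged example call for a clinical table (the DataFrame `clinical_df` and the parameter values are illustrative):

processed_df, num_cols, dummy_cols = DataProcessor.feature_engineering(
    data=clinical_df,  # illustrative clinical DataFrame
    data_type="clinical",
    overmissing_threshold=50,
    imputer="mice",
    imputer_params={"iterations": 10, "n_estimators": 100, "random_state": 0},
    add_indicators=True,
    verbose=False,
)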

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def feature_engineering(
    data: pd.DataFrame,
    data_type: str,
    overmissing_threshold: float = 50,
    imputer: str = "mice",
    ordinal_cat_columns: Optional[List[str]] = None,
    unique_threshold: int = 10,
    scaler: str = "minmax",
    imputer_params: Optional[Dict[str, Any]] = None,
    add_indicators: bool = True,
    verbose: bool = True,
) -> tuple:
    """
    Perform feature engineering and imputation using a chosen method with its specific params.

    Args:
        data (pd.DataFrame): Input DataFrame.
        data_type (str): 'transcriptomics' or 'clinical'.
        overmissing_threshold (float): Drop features with missingness > threshold.
        imputer (str): 'mice' or 'knn'.
        ordinal_cat_columns (List[str], optional): Known ordinal categorical columns.
        unique_threshold (int): Threshold for classifying dummy features (clinical).
        scaler (str): 'minmax' | 'standard' | 'robust'.
        imputer_params (dict, optional): Method-specific params.
            - For 'knn': pass keys as in knn_imputer (e.g., n_neighbors, dummy_cat_columns, ordinal_cat_columns)
            - For 'mice': pass keys as in mice_imputation (e.g., iterations, n_estimators, random_state, verbose, …)
        add_indicators (bool): Append missing indicators.
        verbose (bool): Print progress.

    Returns:
        tuple: (processed_df, numerical_features, dummy_categorical_features)
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if data_type not in ["transcriptomics", "clinical"]:
        raise ValueError("Data type must be 'transcriptomics' or 'clinical'")
    imputer_params = imputer_params or {}
    ordinal_cat_columns = ordinal_cat_columns or []

    if verbose:
        print(
            f"FEATURE ENGINEERING - OVERMISSING: Removing features with missing > {overmissing_threshold}%"
        )
    x = DataProcessor.remove_overmissing_features(
        data=data, threshold=overmissing_threshold
    )

    if data_type == "clinical":
        if verbose:
            print(
                "FEATURE ENGINEERING - METADATA: Classifying feature types for clinical data"
            )
        dummy_cat_features, num_features = MetaData.classify_features_types(
            data=x,
            threshold_unique_values=unique_threshold,
            ordinal_features=ordinal_cat_columns,
        )
        if verbose:
            print(
                f"  Dummy categorical: {len(dummy_cat_features)} | Numerical: {len(num_features)} | Ordinal: {len(ordinal_cat_columns)}"
            )
    else:
        if verbose:
            print("FEATURE ENGINEERING - METADATA: Transcriptomics → all numerical")
        dummy_cat_features = []
        num_features = x.columns.tolist()

    if verbose:
        print(f"FEATURE ENGINEERING - IMPUTATION: Using {imputer.upper()}")

    try:
        from lightgbm.basic import LightGBMError  # type: ignore

        _mice_excs = (ValueError, IndexError, LightGBMError)
    except Exception:
        _mice_excs = (ValueError, IndexError)

    if imputer.lower() == "knn":
        # Keep existing full KNN pipeline (encode→scale→impute→decode)
        processed_df = DataProcessor.knn_imputation(
            data=x,
            dummy_cat_columns=dummy_cat_features,
            ordinal_cat_columns=ordinal_cat_columns,
            numerical_columns=num_features,
            scaler=scaler,
            add_indicators=add_indicators,
            verbose=verbose,
            **imputer_params,  # method-specific extras go here
        )

    elif imputer.lower() == "mice":
        try:
            # For MICE, assume you've already preprocessed externally OR want to impute raw mixed df.
            # If you want to enforce encoding rules first, you can plug a preprocessing step here.
            processed_df = DataProcessor.mice_imputation(
                data=x,
                add_indicators=add_indicators,
                verbose=verbose,
                **imputer_params,  # iterations, n_estimators, random_state, verbose, etc.
            )
        except _mice_excs as e:
            verbose = False
            n_fallback = 5
            processed_df = DataProcessor.knn_imputation(
                data=x,
                dummy_cat_columns=dummy_cat_features,
                ordinal_cat_columns=ordinal_cat_columns,
                numerical_columns=num_features,
                scaler=scaler,
                add_indicators=add_indicators,
                verbose=verbose,
                n_neighbors=n_fallback,  # use the fallback neighbor count defined above
            )
            print(
                f"\nMICE failed ({e.__class__.__name__}: {e}). Falling back to KNN (n_neighbors={n_fallback})."
            )
    else:
        raise ValueError("Imputer must be 'mice' or 'knn'")

    if verbose:
        print("FEATURE ENGINEERING: Completed")
    return processed_df, num_features, dummy_cat_features
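
Example: a minimal usage sketch of the clinical feature-engineering pipeline. The import path for DataProcessor is assumed from the source location shown above and may differ in your installation.

import numpy as np
import pandas as pd

from synomicsbench.processing.preprocessing import DataProcessor  # assumed import path

clinical = pd.DataFrame(
    {
        "age": [54.0, 61.0, np.nan, 47.0, 70.0],
        "stage": ["I", "II", "II", np.nan, "III"],
        "smoker": ["no", "yes", "yes", "no", np.nan],
    }
)

processed, numerical, dummy_categorical = DataProcessor.feature_engineering(
    data=clinical,
    data_type="clinical",
    overmissing_threshold=50,
    imputer="knn",                      # "mice" requires the optional miceforest dependency
    imputer_params={"n_neighbors": 3},  # forwarded to the chosen imputer
    unique_threshold=3,                 # keep the 4 distinct ages typed as numerical
    add_indicators=True,
    verbose=False,
)
print(processed.shape, numerical, dummy_categorical)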

find_missing_percent(data) staticmethod

Calculate the percentage of missing values for each column.

Parameters:

    data (DataFrame): Input DataFrame. Required.

Returns:

    pd.DataFrame: DataFrame with columns 'ColumnName' and 'PercentMissing'.

Raises:

    TypeError: If input is not a pandas DataFrame.
    RuntimeError: On error during calculation.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def find_missing_percent(data: pd.DataFrame) -> pd.DataFrame:
    """
    Calculate the percentage of missing values for each column.

    Args:
        data (pd.DataFrame): Input DataFrame.

    Returns:
        pd.DataFrame: DataFrame with columns 'ColumnName' and 'PercentMissing'.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        RuntimeError: On error during calculation.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Data input must be a pandas DataFrame")
    try:
        return (
            (data.isnull().sum() / len(data) * 100)
            .reset_index()
            .rename(columns={0: "PercentMissing", "index": "ColumnName"})
        )
    except Exception as e:
        raise RuntimeError(f"Error calculating missing percentages: {e}")
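
Example: a quick check of per-column missingness (assumed import path).

import numpy as np
import pandas as pd

from synomicsbench.processing.preprocessing import DataProcessor  # assumed import path

df = pd.DataFrame({"a": [1.0, np.nan, 3.0, np.nan], "b": [1, 2, 3, 4]})
missing = DataProcessor.find_missing_percent(df)
print(missing)  # 'a' -> 50.0 percent missing, 'b' -> 0.0 percent missing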

impute_router(data, method='mice', *, add_indicators=True, knn_params=None, mice_params=None) staticmethod

Dispatch imputation to KNN or MICE with method-specific parameters.

Parameters:

    data (DataFrame): Input frame (already preprocessed if needed). Required.
    method (str): 'knn' or 'mice'. Default: 'mice'.
    add_indicators (bool): Whether to append missing indicators. Default: True.
    knn_params (dict): Passed to knn_imputer (n_neighbors, etc.). Default: None.
    mice_params (dict): Passed to mice_imputation (iterations, n_estimators, etc.). Default: None.

Returns:

    pd.DataFrame: Imputed dataframe (with indicators if requested).

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def impute_router(
    data: pd.DataFrame,
    method: str = "mice",
    *,
    add_indicators: bool = True,
    knn_params: Optional[Dict[str, Any]] = None,
    mice_params: Optional[Dict[str, Any]] = None,
) -> pd.DataFrame:
    """
    Dispatch imputation to KNN or MICE with method-specific parameters.

    Args:
        data (pd.DataFrame): Input frame (already preprocessed if needed).
        method (str): 'knn' or 'mice'.
        add_indicators (bool): Whether to append missing indicators.
        knn_params (dict): Passed to knn_imputer (n_neighbors, etc.).
        mice_params (dict): Passed to mice_imputation (iterations, n_estimators, etc.).

    Returns:
        pd.DataFrame: Imputed dataframe (with indicators if requested).
    """
    method = method.lower()
    knn_params = knn_params or {}
    mice_params = mice_params or {}

    if method == "knn":
        # Expecting this low-level utility to receive an already-encoded/normalized frame.
        # If you want the full pipeline, keep using knn_imputation (the full pipeline) elsewhere.
        return DataProcessor.knn_imputer(
            data=data,
            dummy_cat_columns=knn_params.pop("dummy_cat_columns", []),
            ordinal_cat_columns=knn_params.pop("ordinal_cat_columns", []),
            n_neighbors=knn_params.pop("n_neighbors", 5),
            add_indicators=add_indicators,
            **knn_params,
        )
    elif method == "mice":
        return DataProcessor.mice_imputation(
            data=data, add_indicators=add_indicators, **mice_params
        )
    else:
        raise ValueError("method must be 'knn' or 'mice'")
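
Example: a minimal sketch dispatching imputation on an already numeric frame (assumed import path). The 'knn' path calls the low-level knn_imputer, so the frame should already be encoded and scaled.

import numpy as np
import pandas as pd

from synomicsbench.processing.preprocessing import DataProcessor  # assumed import path

x = pd.DataFrame(
    {
        "f1": [0.10, np.nan, 0.40, 0.20, 0.90],
        "f2": [1.00, 0.80, np.nan, 0.50, 0.70],
    }
)

imputed = DataProcessor.impute_router(
    x,
    method="knn",
    add_indicators=True,
    knn_params={"n_neighbors": 2},
)
print(imputed.columns.tolist())  # original columns plus missingindicator_* columns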

inverse_dummy_features(data, dummy_cat_columns) staticmethod

Decode dummy-encoded categorical features back to original categories.

Parameters:

    data (DataFrame): DataFrame with dummy-encoded features. Required.
    dummy_cat_columns (list): List of dummy categorical feature names. Required.

Returns:

    pd.DataFrame: DataFrame with decoded categorical columns.

Raises:

    TypeError: On invalid input types or empty list.
    ValueError: On decoding errors.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def inverse_dummy_features(
    data: pd.DataFrame, dummy_cat_columns: list
) -> pd.DataFrame:
    """
    Decode dummy-encoded categorical features back to original categories.

    Args:
        data (pd.DataFrame): DataFrame with dummy-encoded features.
        dummy_cat_columns (list): List of dummy categorical feature names.

    Returns:
        pd.DataFrame: DataFrame with decoded categorical columns.

    Raises:
        TypeError: On invalid input types or empty list.
        ValueError: On decoding errors.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input data must be a pandas DataFrame")
    if not isinstance(dummy_cat_columns, list) or not dummy_cat_columns:
        raise TypeError(
            "dummy_cat_columns must be a non-empty list of column names"
        )
    try:
        cat_features_dict = {}
        for cat_feature in dummy_cat_columns:
            for column in data.columns:
                if column.startswith(cat_feature):
                    cat_features_dict[cat_feature] = cat_features_dict.get(
                        cat_feature, []
                    ) + [column]
        for feature, cols in cat_features_dict.items():
            data[feature] = np.array(cols)[data[cols].to_numpy().argmax(axis=1)]
            data[feature] = data[feature].str.replace(f"{feature}_", "")
            data.drop(columns=cols, inplace=True)
        data = data[dummy_cat_columns].astype("category")
        return data
    except Exception as e:
        raise ValueError(f"Error in inverse encoding categorical features: {e}")

inverse_ordinal_features(data, encoder) staticmethod

Decode ordinal categorical features from encoded values back to original categories.

Parameters:

    data (DataFrame): DataFrame with encoded ordinal features. Required.
    encoder (OrdinalEncoder): Fitted ordinal encoder. Required.

Returns:

    pd.DataFrame: Decoded ordinal categorical features.

Raises:

    AttributeError: If encoder is not initialized.
    TypeError: If input is not a DataFrame.
    ValueError: On decoding errors.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def inverse_ordinal_features(
    data: pd.DataFrame, encoder: OrdinalEncoder
) -> pd.DataFrame:
    """
    Decode ordinal categorical features from encoded values back to original categories.

    Args:
        data (pd.DataFrame): DataFrame with encoded ordinal features.
        encoder (OrdinalEncoder): Fitted ordinal encoder.

    Returns:
        pd.DataFrame: Decoded ordinal categorical features.

    Raises:
        AttributeError: If encoder is not initialized.
        TypeError: If input is not a DataFrame.
        ValueError: On decoding errors.
    """
    if not isinstance(encoder, OrdinalEncoder):
        raise AttributeError(
            "Ordinal encoder not initialized. Run encode_ordinal_features first"
        )
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input data must be a pandas DataFrame")
    try:
        data_round = data.round()
        return pd.DataFrame(
            encoder.inverse_transform(data_round),
            columns=data.columns,
            index=data.index,
        )
    except Exception as e:
        raise ValueError(f"Error in inverse ordinal encoding: {e}")

inverse_standardization(data, scaler) staticmethod

Undo normalization/standardization on numerical features.

Parameters:

    data (DataFrame): DataFrame with normalized features. Required.
    scaler (BaseEstimator): Fitted scaler. Required.

Returns:

    pd.DataFrame: DataFrame with original scale restored.

Raises:

    AttributeError: If scaler is not initialized.
    TypeError: If input is not a DataFrame.
    ValueError: On inverse normalization errors.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def inverse_standardization(
    data: pd.DataFrame, scaler: BaseEstimator
) -> pd.DataFrame:
    """
    Undo normalization/standardization on numerical features.

    Args:
        data (pd.DataFrame): DataFrame with normalized features.
        scaler (BaseEstimator): Fitted scaler.

    Returns:
        pd.DataFrame: DataFrame with original scale restored.

    Raises:
        AttributeError: If scaler is not initialized.
        TypeError: If input is not a DataFrame.
        ValueError: On inverse normalization errors.
    """
    if not isinstance(scaler, BaseEstimator):
        raise AttributeError("Scaler not initialized. Run standardization first")
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    try:
        data_inverse_normalized = pd.DataFrame(
            scaler.inverse_transform(data),
            columns=data.columns,
            index=data.index,
        )
        return data_inverse_normalized.astype("float64")
    except Exception as e:
        raise ValueError(f"Error in inverse normalization: {e}")

knn_imputation(data, dummy_cat_columns=None, ordinal_cat_columns=None, numerical_columns=None, scaler='minmax', n_neighbors=5, add_indicators=True, verbose=True, **kwargs) staticmethod

Perform full preprocessing and KNN imputation, including encoding, normalization, and postprocessing.

Parameters:

    data (DataFrame): Input DataFrame. Required.
    dummy_cat_columns (List): Dummy categorical columns. Default: None.
    ordinal_cat_columns (List): Ordinal categorical columns. Default: None.
    numerical_columns (List): Numerical columns. Default: None.
    scaler (str): Scaler type ('minmax', 'standard', 'robust'). Default: 'minmax'.
    n_neighbors (int): KNN neighbors. Default: 5.
    add_indicators (bool): Add missing indicators. Default: True.
    verbose (bool): Print progress. Default: True.
    **kwargs: Additional KNNImputer arguments.

Returns:

    pd.DataFrame: DataFrame after imputation and decoding.

Raises:

    TypeError: On invalid input.
    ValueError: On missing columns or errors.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def knn_imputation(
    data: pd.DataFrame,
    dummy_cat_columns: Optional[List] = None,
    ordinal_cat_columns: Optional[List] = None,
    numerical_columns: Optional[List] = None,
    scaler: str = "minmax",
    n_neighbors: int = 5,
    add_indicators: bool = True,
    verbose: bool = True,
    **kwargs,
) -> pd.DataFrame:
    """
    Perform full preprocessing and KNN imputation, including encoding, normalization, and postprocessing.

    Args:
        data (pd.DataFrame): Input DataFrame.
        dummy_cat_columns (List, optional): Dummy categorical columns.
        ordinal_cat_columns (List, optional): Ordinal categorical columns.
        numerical_columns (List, optional): Numerical columns.
        scaler (str, optional): Scaler type ('minmax', 'standard', 'robust').
        n_neighbors (int, optional): KNN neighbors.
        add_indicators (bool, optional): Add missing indicators.
        verbose (bool, optional): Print progress.
        **kwargs: Additional KNNImputer arguments.

    Returns:
        pd.DataFrame: DataFrame after imputation and decoding.

    Raises:
        TypeError: On invalid input.
        ValueError: On missing columns or errors.
    """
    dummy_cat_columns = dummy_cat_columns or []
    ordinal_cat_columns = ordinal_cat_columns or []
    numerical_columns = numerical_columns or []

    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")

    all_cols = dummy_cat_columns + ordinal_cat_columns + numerical_columns
    if not all(col in data.columns for col in all_cols):
        missing_cols = [col for col in all_cols if col not in data.columns]
        raise ValueError(f"Some specified columns not in dataset: {missing_cols}")

    try:
        if verbose:
            print("IMPUTATION: Starting preprocessing")
        preprocessed_parts = []

        or_encoder: Optional[OrdinalEncoder] = None
        scaler_obj: Optional[BaseEstimator] = None

        if ordinal_cat_columns:
            if verbose:
                print(
                    "IMPUTATION-PREPROCESSING: Encoding ordinal categorical features"
                )
            ordinal_data = data[ordinal_cat_columns]
            ordinal_encoded, or_encoder = DataProcessor.encode_ordinal_features(
                ordinal_data
            )
            preprocessed_parts.append(ordinal_encoded)

        if dummy_cat_columns:
            if verbose:
                print(
                    "IMPUTATION-PREPROCESSING: Encoding dummy categorical features"
                )
            dummy_data = data[dummy_cat_columns]
            dummy_encoded = DataProcessor.encode_dummy_features(dummy_data)
            preprocessed_parts.append(dummy_encoded)

        if numerical_columns:
            if verbose:
                print("IMPUTATION-PREPROCESSING: Standardizing numerical features")
            numerical_data = data[numerical_columns]
            numerical_normalized, scaler_obj = DataProcessor.standardization(
                numerical_data, scaler=scaler
            )
            preprocessed_parts.append(numerical_normalized)

        if not preprocessed_parts:
            raise ValueError("No columns specified for processing")

        data_preprocessed = pd.concat(preprocessed_parts, axis=1)

        if verbose:
            print("IMPUTATION: Initializing KNN imputer")
        data_imputed = DataProcessor.knn_imputer(
            data_preprocessed,
            dummy_cat_columns=dummy_cat_columns,
            ordinal_cat_columns=ordinal_cat_columns,
            n_neighbors=n_neighbors,
            add_indicators=add_indicators,
            **kwargs,
        )

        missing_indicator_cols = [
            col
            for col in data_imputed.columns
            if col.startswith("missingindicator_")
        ]
        postprocessed_parts = []
        if verbose:
            print("IMPUTATION: Starting postprocessing")

        if ordinal_cat_columns:
            if verbose:
                print(
                    "IMPUTATION-POSTPROCESSING: Decoding ordinal categorical features"
                )
            ordinal_imputed = data_imputed[ordinal_cat_columns]
            if or_encoder is None:
                raise RuntimeError("Ordinal encoder is not initialized")
            ordinal_decoded = DataProcessor.inverse_ordinal_features(
                ordinal_imputed, or_encoder
            )
            postprocessed_parts.append(ordinal_decoded)

        if dummy_cat_columns:
            if verbose:
                print(
                    "IMPUTATION-POSTPROCESSING: Decoding dummy categorical features"
                )
            dummy_cols_after_encoding = [
                col
                for col in data_imputed.columns
                if any(col.startswith(cat) for cat in dummy_cat_columns)
            ]
            dummy_imputed = data_imputed[dummy_cols_after_encoding]
            dummy_decoded = DataProcessor.inverse_dummy_features(
                dummy_imputed, dummy_cat_columns
            )
            postprocessed_parts.append(dummy_decoded)

        if numerical_columns:
            if verbose:
                print(
                    "IMPUTATION-POSTPROCESSING: Inverse normalizing numerical features"
                )
            numerical_imputed = data_imputed[numerical_columns]
            if scaler_obj is None:
                raise RuntimeError("Scaler is not initialized")
            numerical_denormalized = DataProcessor.inverse_standardization(
                numerical_imputed, scaler_obj
            )
            postprocessed_parts.append(numerical_denormalized)
        # Extract missing indicator columns
        if verbose:
            print("IMPUTATION-POSTPROCESSING: Extracting missing indicator columns")
        missing_indicator_cols = DataProcessor.extract_missingindicator_columns(
            data_imputed
        )
        postprocessed_parts.append(missing_indicator_cols)

        data_final = pd.concat(postprocessed_parts, axis=1)
        if verbose:
            print("IMPUTATION: Data imputation completed")
        return data_final
    except Exception as e:
        raise ValueError(f"Error in data imputation: {e}")

knn_imputer(data, dummy_cat_columns, ordinal_cat_columns=None, n_neighbors=5, add_indicators=True, **kwargs) staticmethod

Impute missing values using KNNImputer, with special handling for dummy and ordinal features.

Parameters:

    data (DataFrame): Preprocessed DataFrame. Required.
    dummy_cat_columns (list): List of dummy categorical feature names. Required.
    ordinal_cat_columns (List): List of ordinal categorical feature names. Default: None.
    n_neighbors (int): Number of neighbors for KNN. Default: 5.
    add_indicators (bool): Whether to add missing indicators. Default: True.
    **kwargs: Additional KNNImputer arguments.

Returns:

    pd.DataFrame: DataFrame with imputed values and indicators.

Raises:

    TypeError: On invalid input types.
    ValueError: On imputation errors.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def knn_imputer(
    data: pd.DataFrame,
    dummy_cat_columns: list,
    ordinal_cat_columns: Optional[List] = None,
    n_neighbors: int = 5,
    add_indicators: bool = True,
    **kwargs,
) -> pd.DataFrame:
    """
    Impute missing values using KNNImputer, with special handling for dummy and ordinal features.

    Args:
        data (pd.DataFrame): Preprocessed DataFrame.
        dummy_cat_columns (list): List of dummy categorical feature names.
        ordinal_cat_columns (List, optional): List of ordinal categorical feature names.
        n_neighbors (int): Number of neighbors for KNN.
        add_indicators (bool, optional): Whether to add missing indicators.
        **kwargs: Additional KNNImputer arguments.

    Returns:
        pd.DataFrame: DataFrame with imputed values and indicators.

    Raises:
        TypeError: On invalid input types.
        ValueError: On imputation errors.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if not isinstance(dummy_cat_columns, list):
        raise TypeError("dummy_cat_columns must be a list")
    ordinal_cat_columns = ordinal_cat_columns or []
    if not isinstance(ordinal_cat_columns, list):
        raise TypeError("ordinal_cat_columns must be a list")
    try:
        dummy_cat_features_dict = {}
        nan_dummy_features = []
        # Create a dictionary to hold dummy categorical features
        for dummy_cat_feature in dummy_cat_columns:
            for column in data.columns:
                if column.startswith(dummy_cat_feature):
                    dummy_cat_features_dict[dummy_cat_feature] = (
                        dummy_cat_features_dict.get(dummy_cat_feature, [])
                        + [column]
                    )
        # Check for NaN dummy columns and handle them
        NAN_SUFFIX = "_nan"
        for dummy_cat_feature, dummy_columns in dummy_cat_features_dict.items():
            for column in dummy_columns:
                if column.endswith(NAN_SUFFIX):
                    nan_dummy_features.append(column)
        # Refill np.nan to dummy categorical features
        index_dict = {}
        for key, value in dummy_cat_features_dict.items():
            num_type_cat_feature = len(value) - 1
            index_dict[key] = []
            for row_i in data.index:
                if np.array_equal(
                    data.loc[row_i, dummy_cat_features_dict[key]].values,
                    np.array([0] * num_type_cat_feature + [1]),
                ):
                    data.loc[row_i, dummy_cat_features_dict[key]] = [
                        np.nan
                    ] * num_type_cat_feature + [1]
                    index_dict[key].append(row_i)
        # Remove NaN dummy columns from the DataFrame
        if len(nan_dummy_features) > 0:
            data = data.drop(nan_dummy_features, axis=1)
        # Use KNN imputer to fill missing values
        imputer = KNNImputer(
            n_neighbors=n_neighbors,
            weights="distance",
            add_indicator=add_indicators,
            **kwargs,
        )
        data_imputed = pd.DataFrame(
            imputer.fit_transform(data),
            columns=imputer.get_feature_names_out(),
            index=data.index,
        )
        # Post-process the imputed data
        data_imputed[ordinal_cat_columns] = data_imputed[
            ordinal_cat_columns
        ].round()
        for key, value in index_dict.items():
            features = dummy_cat_features_dict[key][0:-1]
            argmax_index = np.argmax(
                data_imputed.loc[value, features].values, axis=1
            )
            for i in range(len(value)):
                data_imputed.loc[value[i], features] = [0] * len(features)
                data_imputed.loc[value[i], features[argmax_index[i]]] = 1

        # Combine duplicate missing indicators for dummy variables
        if add_indicators:
            indicator_cols = [
                col
                for col in data_imputed.columns
                if col.startswith("missingindicator_")
            ]
            # Map: dummy feature -> list of indicator columns
            indicator_map = {}
            for col in indicator_cols:
                for dummy_cat_feature in dummy_cat_columns:
                    if col.startswith(f"missingindicator_{dummy_cat_feature}"):
                        indicator_map.setdefault(dummy_cat_feature, []).append(col)
            # For each dummy feature, keep only one indicator column (since all are identical)
            for dummy_cat_feature, cols in indicator_map.items():
                if len(cols) > 1:
                    # Keep the first column, drop the rest, and rename to standard name
                    combined_col = f"missingindicator_{dummy_cat_feature}"
                    data_imputed[combined_col] = data_imputed[cols[0]]
                    data_imputed.drop(columns=cols, inplace=True)
                elif len(cols) == 1:
                    # Rename to standard name if needed
                    col = cols[0]
                    combined_col = f"missingindicator_{dummy_cat_feature}"
                    if col != combined_col:
                        data_imputed.rename(
                            columns={col: combined_col}, inplace=True
                        )
        return data_imputed
    except Exception as e:
        raise ValueError(f"Error in KNN imputation: {e}")

mice_imputation(data, *, random_state=42, iterations=20, n_estimators=300, add_indicators=True, verbose=True, **mice_kwargs) staticmethod

Impute with miceforest and optionally append missing indicators. Additional miceforest arguments can be passed via **mice_kwargs (e.g., variable_schema).

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def mice_imputation(
    data: pd.DataFrame,
    *,
    random_state: int = 42,
    iterations: int = 20,
    n_estimators: int = 300,
    add_indicators: bool = True,
    verbose: bool = True,
    **mice_kwargs: Any,
) -> pd.DataFrame:
    """
    Impute with miceforest and optionally append missing indicators.
    Additional miceforest arguments can be passed via **mice_kwargs (e.g., variable_schema).
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")

    if mf is None:
        raise ImportError(
            "miceforest is required for MICE imputation. Install it or use imputer='knn'."
        )

    out = data.copy()

    # Cast object to category so miceforest can treat them as categorical if present
    for c in out.select_dtypes(include=["object"]).columns:
        out[c] = out[c].astype("category")

    kernel = mf.ImputationKernel(
        data=out,
        num_datasets=1,  # miceforest uses `datasets`
        random_state=random_state,
        **mice_kwargs,
    )

    kernel.mice(iterations=iterations, n_estimators=n_estimators, verbose=verbose)

    df_imputed = kernel.complete_data(dataset=0)

    if add_indicators:
        df_ind = DataProcessor._add_indicators(
            data
        )  # build indicators from original
        df_imputed = pd.concat([df_imputed, df_ind], axis=1)

    return df_imputed
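
Example: a minimal MICE sketch (assumed import path); this requires the optional miceforest dependency, otherwise use the KNN pipeline instead.

import numpy as np
import pandas as pd

from synomicsbench.processing.preprocessing import DataProcessor  # assumed import path

expr = pd.DataFrame(
    {
        "gene_a": [0.2, np.nan, 0.9, 0.4, 0.6, 0.5],
        "gene_b": [1.1, 0.7, np.nan, 0.8, 1.0, 0.9],
    }
)

imputed = DataProcessor.mice_imputation(
    expr,
    random_state=0,
    iterations=5,      # fewer iterations than the default for a quick run
    n_estimators=50,
    add_indicators=True,
    verbose=False,
)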

remove_duplications(data, axis) staticmethod

Remove duplicate rows or columns from the DataFrame.

Parameters:

    data (DataFrame): Input DataFrame. Required.
    axis (int): 0 to remove duplicate rows, 1 to remove duplicate columns. Required.

Returns:

    pd.DataFrame: DataFrame with duplicates removed.

Raises:

    TypeError: If input is not a pandas DataFrame.
    ValueError: If axis is not 0 or 1.
    RuntimeError: On error during duplicate removal.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def remove_duplications(data: pd.DataFrame, axis: int) -> pd.DataFrame:
    """
    Remove duplicate rows or columns from the DataFrame.

    Args:
        data (pd.DataFrame): Input DataFrame.
        axis (int): 0 to remove duplicate rows, 1 to remove duplicate columns.

    Returns:
        pd.DataFrame: DataFrame with duplicates removed.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: If axis is not 0 or 1.
        RuntimeError: On error during duplicate removal.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if axis not in [0, 1]:
        raise ValueError("Axis must be 0 (rows) or 1 (columns)")
    try:
        if axis == 0:
            return data.drop_duplicates()
        elif axis == 1:
            return data.loc[:, ~data.columns.duplicated()]
    except Exception as e:
        raise RuntimeError(f"Error removing duplicates: {e}")

remove_low_expression_genes(data, gene_id_column='gene_id', variance_threshold=0.0005) staticmethod

Remove low-expressed genes from transcriptomics expression data.

This utility supports two common transcriptomics orientations:

- Genes as rows and samples as columns (optionally with a gene_id column).
- Samples as rows and genes as columns (pipeline convention in synomicsbench).

A gene is removed if either:

- Total expression equals 0 across all samples, OR
- Variance is <= variance_threshold across all samples.

Parameters:

    data (DataFrame): Expression DataFrame. Required.
    gene_id_column (str): Column containing gene IDs when genes are rows. Default: 'gene_id'.
    variance_threshold (float): Near-zero variance threshold. Default: 0.0005.

Returns:

    pd.DataFrame: Filtered DataFrame in the same orientation as the input.

Raises:

    TypeError: If data is not a pandas DataFrame.
    ValueError: If variance_threshold is negative or data contains non-numeric columns.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def remove_low_expression_genes(
    data: pd.DataFrame,
    gene_id_column: str = "gene_id",
    variance_threshold: float = 0.0005,
) -> pd.DataFrame:
    """
    Remove low-expressed genes from transcriptomics expression data.

    This utility supports two common transcriptomics orientations:
    - Genes as rows and samples as columns (optionally with a `gene_id` column).
    - Samples as rows and genes as columns (pipeline convention in synomicsbench).

    A gene is removed if either:
    - Total expression equals 0 across all samples, OR
    - Variance is <= `variance_threshold` across all samples.

    Args:
        data (pd.DataFrame): Expression DataFrame.
        gene_id_column (str): Column containing gene IDs when genes are rows.
        variance_threshold (float): Near-zero variance threshold.

    Returns:
        pd.DataFrame: Filtered DataFrame in the same orientation as the input.

    Raises:
        TypeError: If data is not a pandas DataFrame.
        ValueError: If variance_threshold is negative or data contains non-numeric columns.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if variance_threshold < 0:
        raise ValueError("variance_threshold must be >= 0")

    # Orientation: genes are rows (optionally with a gene_id column).
    if gene_id_column in data.columns:
        expression_matrix = data.set_index(gene_id_column)
        numeric = expression_matrix.apply(pd.to_numeric, errors="coerce")
        non_numeric_cols = numeric.columns[
            numeric.notna().sum(axis=0) == 0
        ].tolist()
        if non_numeric_cols:
            raise ValueError(
                "Non-numeric sample columns found in expression matrix: "
                f"{non_numeric_cols}"
            )

        sum_expression = numeric.sum(axis=1)
        zero_genes = sum_expression[sum_expression == 0].index.tolist()

        variances_expression = numeric.var(axis=1)
        near_zero_var = variances_expression[
            variances_expression <= variance_threshold
        ].index.tolist()

        remove_genes = set(zero_genes) | set(near_zero_var)
        filtered = expression_matrix.drop(index=list(remove_genes), errors="ignore")
        return filtered.reset_index()

    # Orientation: samples are rows, genes are columns.
    numeric = data.apply(pd.to_numeric, errors="coerce")
    non_numeric_cols = numeric.columns[numeric.notna().sum(axis=0) == 0].tolist()
    if non_numeric_cols:
        raise ValueError(
            "Non-numeric gene columns found in expression matrix: "
            f"{non_numeric_cols}"
        )

    sum_expression = numeric.sum(axis=0)
    zero_genes = sum_expression[sum_expression == 0].index.tolist()

    variances_expression = numeric.var(axis=0)
    near_zero_var = variances_expression[
        variances_expression <= variance_threshold
    ].index.tolist()

    remove_genes = set(zero_genes) | set(near_zero_var)
    return data.drop(columns=list(remove_genes), errors="ignore")
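
Example: a minimal sketch filtering a genes-as-rows expression matrix (assumed import path).

import pandas as pd

from synomicsbench.processing.preprocessing import DataProcessor  # assumed import path

expr = pd.DataFrame(
    {
        "gene_id": ["g1", "g2", "g3"],
        "sample_1": [5.0, 0.0, 2.0],
        "sample_2": [6.0, 0.0, 2.1],
    }
)

filtered = DataProcessor.remove_low_expression_genes(expr)
print(filtered["gene_id"].tolist())  # ['g1', 'g3'] - g2 has zero total expression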

remove_overmissing_entities(data, threshold) staticmethod

Remove rows (entities) with missing value percentage above a threshold.

Parameters:

    data (DataFrame): Input DataFrame. Required.
    threshold (float): Percentage threshold (0-100). Required.

Returns:

    pd.DataFrame: DataFrame with over-missing entities removed.

Raises:

    TypeError: If input is not a pandas DataFrame.
    ValueError: If threshold is not between 0 and 100.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def remove_overmissing_entities(
    data: pd.DataFrame, threshold: float
) -> pd.DataFrame:
    """
    Remove rows (entities) with missing value percentage above a threshold.

    Args:
        data (pd.DataFrame): Input DataFrame.
        threshold (float): Percentage threshold (0-100).

    Returns:
        pd.DataFrame: DataFrame with over-missing entities removed.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: If threshold is not between 0 and 100.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if not 0 <= threshold <= 100:
        raise ValueError("Threshold must be between 0 and 100")
    try:
        # Calculate the percentage of missing values per row
        missing_percentage = data.isnull().mean(axis=1) * 100
        # Keep rows where missing percentage is <= threshold
        keep_rows = missing_percentage <= threshold
        return data[keep_rows].copy()
    except Exception as e:
        raise ValueError(f"Error removing over-missing samples: {e}")

remove_overmissing_features(data, threshold) staticmethod

Remove columns (features) with missing value percentage above a threshold.

Parameters:

    data (DataFrame): Input DataFrame. Required.
    threshold (float): Percentage threshold (0-100). Required.

Returns:

    pd.DataFrame: DataFrame with over-missing features removed.

Raises:

    TypeError: If input is not a pandas DataFrame.
    ValueError: If threshold is not between 0 and 100.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def remove_overmissing_features(
    data: pd.DataFrame, threshold: float
) -> pd.DataFrame:
    """
    Remove columns (features) with missing value percentage above a threshold.

    Args:
        data (pd.DataFrame): Input DataFrame.
        threshold (float): Percentage threshold (0-100).

    Returns:
        pd.DataFrame: DataFrame with over-missing features removed.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        ValueError: If threshold is not between 0 and 100.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    if not 0 <= threshold <= 100:
        raise ValueError("Threshold must be between 0 and 100")
    try:
        miss_df = (
            (data.isnull().sum() / len(data) * 100)
            .reset_index()
            .rename(columns={0: "PercentMissing", "index": "ColumnName"})
        )
        drop_cols = miss_df[miss_df["PercentMissing"] > threshold][
            "ColumnName"
        ].tolist()
        return data.drop(drop_cols, axis=1)
    except Exception as e:
        raise ValueError(f"Error removing over-missing columns: {e}")

remove_unknown_entities(data, id_column) staticmethod

Remove rows where the identifier column contains missing values.

Parameters:

    data (DataFrame): Input DataFrame. Required.
    id_column (str): Name of the identifier column. Required.

Returns:

    pd.DataFrame: Filtered DataFrame.

Raises:

    TypeError: If input is not a pandas DataFrame.
    KeyError: If id_column is not in DataFrame.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def remove_unknown_entities(data: pd.DataFrame, id_column: str) -> pd.DataFrame:
    """
    Remove rows where the identifier column contains missing values.

    Args:
        data (pd.DataFrame): Input DataFrame.
        id_column (str): Name of the identifier column.

    Returns:
        pd.DataFrame: Filtered DataFrame.

    Raises:
        TypeError: If input is not a pandas DataFrame.
        KeyError: If id_column is not in DataFrame.
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Data input must be a pandas DataFrame")
    if id_column not in data.columns:
        raise KeyError(f"ID column '{id_column}' not in DataFrame")
    return data.dropna(subset=[id_column])

standardization(data, scaler) staticmethod

Standardize numerical features using specified scaler.

Parameters:

    data (DataFrame): DataFrame with numerical features. Required.
    scaler (str): Scaler type ('standard', 'minmax', or 'robust'). Required.

Returns:

    Tuple[pd.DataFrame, BaseEstimator]: Scaled DataFrame and scaler object.

Raises:

    ValueError: If scaler type is invalid.
    TypeError: If input is not a pandas DataFrame.
    RuntimeError: If scaling fails.

Source code in src/synomicsbench/processing/preprocessing.py
@staticmethod
def standardization(
    data: pd.DataFrame, scaler: str
) -> Tuple[pd.DataFrame, BaseEstimator]:
    """
    Standardize numerical features using specified scaler.

    Args:
        data (pd.DataFrame): DataFrame with numerical features.
        scaler (str): Scaler type ('standard', 'minmax', or 'robust').

    Returns:
        Tuple[pd.DataFrame, BaseEstimator]: Scaled DataFrame and scaler object.

    Raises:
        ValueError: If scaler type is invalid.
        TypeError: If input is not a pandas DataFrame.
        RuntimeError: If scaling fails.
    """
    scalers = {
        "standard": StandardScaler(),
        "minmax": MinMaxScaler(),
        "robust": RobustScaler(),
    }
    if scaler not in scalers:
        raise ValueError("Scaler must be 'standard', 'minmax', or 'robust'")
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    try:
        scaler_obj = scalers[scaler]
        scaled_data = pd.DataFrame(
            scaler_obj.fit_transform(data), columns=data.columns, index=data.index
        )
        return scaled_data, scaler_obj
    except Exception as e:
        raise RuntimeError(f"Error in standardization: {e}")

postprocessing

anonymize_ids(ids, synthetic_data, output_path, mapping_file='anonymized_ids.json')

Anonymize a list of IDs using random UUIDs and save the mapping.

Parameters:

    ids (list or Series): List of original IDs to anonymize. Required.
    synthetic_data (DataFrame): Synthetic data. Required.
    output_path (str): Directory to save the ID mapping. Required.
    mapping_file (str): Filename for the mapping JSON. Default: 'anonymized_ids.json'.

Returns:

    pd.DataFrame: Synthetic data with anonymized IDs.

Source code in src/synomicsbench/processing/postprocessing.py
def anonymize_ids(ids: list, 
                  synthetic_data: pd.DataFrame, 
                  output_path: str, 
                  mapping_file="anonymized_ids.json") -> pd.DataFrame:
    """
    Anonymize a list of IDs using random UUIDs and save the mapping.

    Args:
        ids (list or pd.Series): List of original IDs to anonymize.
        synthetic_data (pd.DataFrame): synthetic data
        output_path (str): Directory to save the ID mapping.
        mapping_file (str): Filename for the mapping JSON. Defaults to 'anonymized_ids.json'.

    Returns:
        pd.DataFrame: synthetic data with anonymized IDs
    """
    mapping_path = os.path.join(output_path, mapping_file)
    id_mapping = {}

    # Load existing mapping if it exists
    if os.path.exists(mapping_path):
        with open(mapping_path, 'r') as f:
            id_mapping = json.load(f)

    anonymized_ids = []
    for id_value in ids:
        id_str = str(id_value)
        if id_str not in id_mapping:
            # Generate random UUID
            id_mapping[id_str] = str(uuid.uuid4())
        anonymized_ids.append(id_mapping[id_str])

    # Save updated mapping
    with open(mapping_path, 'w') as f:
        json.dump(id_mapping, f, indent=4)

    synthetic_data_anonymized = synthetic_data.copy()
    synthetic_data_anonymized.insert(0, 'Patient_ID', anonymized_ids)
    return synthetic_data_anonymized
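
Example: a minimal sketch that maps original patient IDs to stable UUIDs and stores the mapping (assumed import path; the output directory must already exist).

import pandas as pd

from synomicsbench.processing.postprocessing import anonymize_ids  # assumed import path

synthetic = pd.DataFrame({"age": [54, 61], "stage": ["I", "II"]})

anonymized = anonymize_ids(
    ids=["patient_001", "patient_002"],
    synthetic_data=synthetic,
    output_path=".",  # writes/updates anonymized_ids.json in this directory
)
print(anonymized.columns.tolist())  # ['Patient_ID', 'age', 'stage']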

apply_min_max(synthetic_data, numerical_columns, min_values, max_values)

Apply minimum and maximum constraints to numerical columns in synthetic data.

Parameters:

    synthetic_data (pd.DataFrame): Synthetic data DataFrame. Required.
    numerical_columns (list): List of numerical column names. Required.
    min_values (dict): Minimum allowed value for each column. Required.
    max_values (dict): Maximum allowed value for each column. Required.

Source code in src/synomicsbench/processing/postprocessing.py
def apply_min_max(synthetic_data: pd.DataFrame, numerical_columns: list, min_values: dict, max_values: dict) -> pd.DataFrame:
    """
    Apply minimum and maximum constraints to numerical columns in synthetic data.
    Args:
        synthetic_data (pd.DataFrame): Synthetic data DataFrame.
        numerical_columns (list): List of numerical column names.
        min_values (dict): Minimum allowed value for each column.
        max_values (dict): Maximum allowed value for each column.
    """
    for column in numerical_columns:
        if column in synthetic_data.columns:
            synthetic_data[column] = synthetic_data[column].clip(lower=min_values[column], upper=max_values[column])
        else:
            continue
    return synthetic_data

apply_rounding(synthetic_data, numerical_columns, rounding_digits)

Round numerical columns to the stored number of decimal places.

Parameters:

    synthetic_data (pd.DataFrame): Synthetic data DataFrame. Required.
    numerical_columns (list): List of numerical column names. Required.
    rounding_digits (dict): Number of decimal places for each numerical column. Required.

Source code in src/synomicsbench/processing/postprocessing.py
def apply_rounding(synthetic_data: pd.DataFrame, numerical_columns: list, rounding_digits : dict):
    """
    Round numerical columns to the stored number of decimal places.
    Args:
        synthetic_data (pd.DataFrame): Synthetic data DataFrame.
        numerical_columns (list): List of numerical column names
        rounding_digits (dict): Dictionary containing the number of rounding digits for each numerical column.
    """
    for column in numerical_columns:
        if column in rounding_digits:
            synthetic_data[column] = synthetic_data[column].round(rounding_digits[column])
        else:
            continue
    return synthetic_data
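
Example: a minimal sketch combining apply_min_max and apply_rounding on synthetic numerical columns (assumed import path).

import pandas as pd

from synomicsbench.processing.postprocessing import apply_min_max, apply_rounding  # assumed import path

synthetic = pd.DataFrame({"age": [17.4, 103.9], "bmi": [18.237, 44.901]})

synthetic = apply_min_max(
    synthetic,
    numerical_columns=["age", "bmi"],
    min_values={"age": 18.0, "bmi": 15.0},
    max_values={"age": 100.0, "bmi": 45.0},
)
synthetic = apply_rounding(
    synthetic,
    numerical_columns=["age", "bmi"],
    rounding_digits={"age": 0, "bmi": 1},
)
print(synthetic)  # age clipped to [18, 100] and rounded; bmi rounded to 1 decimal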

load_metadata(metadata_path)

Load metadata from a JSON file.

Parameters:

    metadata_path (str): Path to the metadata JSON file. Required.

Returns:

    dict: Loaded metadata.

Raises:

    FileNotFoundError: If the metadata file does not exist.
    json.JSONDecodeError: If the metadata file is not a valid JSON.

Source code in src/synomicsbench/processing/postprocessing.py
def load_metadata(metadata_path: str) -> dict:
    """
    Load metadata from a JSON file.

    Args:
        metadata_path (str): Path to the metadata JSON file.

    Returns:
        dict: Loaded metadata.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
        json.JSONDecodeError: If the metadata file is not a valid JSON.
    """
    try:    
        with open(metadata_path, 'r') as file:
            metadata = json.load(file)
        return metadata
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Metadata file not found: {e}. Please run classify_feature_type function in Preprocess module")
    except json.JSONDecodeError as e:
        raise json.JSONDecodeError(f"Error decoding JSON from metadata file: {e.msg}", doc=e.doc, pos=e.pos)
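
Example: a minimal sketch loading a previously saved metadata mapping (assumed import path; the file path is hypothetical).

from synomicsbench.processing.postprocessing import load_metadata  # assumed import path

metadata = load_metadata("output/metadata.json")  # hypothetical path to a saved mapping
numerical_cols = [col for col, kind in metadata.items() if kind == "numerical"]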

MetaData

Source code in src/synomicsbench/processing/metadata.py
class MetaData:
    def __init__(self):
        pass

    @staticmethod
    def classify_features_types(
        data: pd.DataFrame,
        threshold_unique_values: int,
        ordinal_features: Optional[List] = None,
        binary_values: set = {0, 1}
    ) -> (List[str], List[str]):
        """
        Classify columns into dummy categorical and numerical features.

        Args:
            data (pd.DataFrame): Input dataframe.
            threshold_unique_values (int): Threshold for unique values to consider as categorical.
            ordinal_features (Optional[List]): List of columns to treat as ordinal (will be excluded from both outputs).
            binary_values (set): Set of values to consider as binary for dummy categorical.

        Returns:
            tuple: (dummy_categorical_columns, numerical_columns)
        """
        if not isinstance(data, pd.DataFrame):
            raise TypeError("Input must be a pandas DataFrame")
        try:
            if ordinal_features is None:
                ordinal_features = []
            # Identify dummy categorical features:
            dummy_cat_columns = []
            num_columns = []
            for col in data.columns:
                col_data = data[col].dropna()
                unique_vals = set(col_data.unique())
                if pd.api.types.is_object_dtype(data[col]) or isinstance(data[col].dtype, pd.CategoricalDtype):
                    dummy_cat_columns.append(col)
                elif unique_vals <= binary_values and len(unique_vals) > 0:
                    dummy_cat_columns.append(col)
                elif len(unique_vals) <= threshold_unique_values:
                    dummy_cat_columns.append(col)
                elif pd.api.types.is_numeric_dtype(data[col]):
                    # Prioritize dtype: all numeric columns are NUMERICAL, even with low unique values
                    num_columns.append(col)
            # Remove ordinal features from outputs if provided
            dummy_cat_columns = list(set(dummy_cat_columns) - set(ordinal_features))
            num_columns = list(set(num_columns) - set(ordinal_features))
            return dummy_cat_columns, num_columns
        except Exception as e:
            raise ValueError(f"Error classifying feature types: {e}")

    @staticmethod
    def get_metadata(
        data: pd.DataFrame,
        threshold_unique_values: int,
        id_columns: Optional[List] = None,
        ordinal_features: Optional[List] = None,
        transcriptomic_cols: Optional[List] = None,
        binary_values: set = {0, 1},
    ) -> dict[str, str]:
        """
        Extract column type metadata from a DataFrame and assign columns to corresponding types.

        Args:
            data (pd.DataFrame): DataFrame with features.
            threshold_unique_values (int): Threshold for unique values to consider as categorical.
            id_columns (Optional[List]): List of columns to ignore (e.g., sample IDs).
            ordinal_features (Optional[List]): List of columns to treat as ordinal.
            binary_values (set): Set of values to consider as binary.

        Returns:
            dict: Dictionary mapping each column to its feature type:
                'numerical', 'ordinal_categorical', 'dummy_categorical', 'missing_categorical', or 'unclassified'.

        Raises:
            ValueError: If classification fails.
        """
        if id_columns is not None:
            data = data.drop(id_columns, axis=1)
        features = data.columns.to_list()
        metadata_dict = {}
        dummy_cols, numerical_cols = MetaData.classify_features_types(
            data=data,
            threshold_unique_values=threshold_unique_values,
            ordinal_features=ordinal_features,
            binary_values=binary_values
        )
        if transcriptomic_cols is None:
            transcriptomic_cols = []

        for col in features:
            if col in transcriptomic_cols:
                metadata_dict[col] = "numerical"
                continue

            if ordinal_features and col in ordinal_features:
                metadata_dict[col] = "ordinal_categorical"
            elif col in dummy_cols:
                metadata_dict[col] = "dummy_categorical"
            elif col in numerical_cols:
                metadata_dict[col] = "numerical"
            elif col.startswith("missingindicator_"):
                metadata_dict[col] = "missing_categorical"
            else:
                metadata_dict[col] = "unclassified"
        if "unclassified" in metadata_dict.values():
            count_unclassified = list(metadata_dict.values()).count("unclassified")
            warnings.warn(
                f"There are {count_unclassified} unclassified features in the metadata. "
                "Consider reviewing your metadata or feature typing logic.",
                UserWarning
            )

        return metadata_dict

    @staticmethod
    def grouping_features_astype(data: pd.DataFrame, metadata: dict) -> dict:
        """
        Group the DataFrame's columns by feature type according to a metadata dictionary.

        Args:
            data (pd.DataFrame): DataFrame used to filter available columns.
            metadata (dict): Metadata dictionary mapping column names to feature types.

        Returns:
            dict: Dictionary with keys 'numerical', 'ordinal_categorical', 'dummy_categorical', and
                'missing_categorical', mapping to lists of column names.
        """

        dummy_cat_cols = []
        ordinal_cat_cols = []
        num_cols = []
        missing_indicator_cols = []
        for k, v in metadata.items():
            if k in data.columns:
                if v == "numerical":
                    num_cols.append(k)
                elif v == "ordinal_categorical":
                    ordinal_cat_cols.append(k)
                elif v == "missing_categorical":
                    missing_indicator_cols.append(k)
                else:
                    dummy_cat_cols.append(k)
        return {
            "numerical": num_cols,
            "ordinal_categorical": ordinal_cat_cols,
            "dummy_categorical": dummy_cat_cols,
            "missing_categorical": missing_indicator_cols,
        }

    @staticmethod
    def metadata_as_SDV(data: pd.DataFrame, metadata: dict) -> dict:
        """
        Convert a metadata dictionary into the column-type format expected by the SDV/SDMetrics libraries.

        Args:
            data (pd.DataFrame): DataFrame used to filter available columns.
            metadata (dict): Metadata dictionary mapping column names to feature types.

        Returns:
            dict: Metadata dictionary in SDV format ({'columns': {column_name: {'sdtype': ...}}}).
        """
        metadata = MetaData.grouping_features_astype(data, metadata)
        # Build SDMetrics metadata dict
        metadata_sdmetrics = {'columns': {}}
        for col in metadata.get("numerical"):
            metadata_sdmetrics['columns'][col] = {'sdtype': 'numerical'}
        for col in metadata.get("ordinal_categorical"):
            metadata_sdmetrics['columns'][col] = {'sdtype': 'categorical'}  # or 'ordinal' if supported
        for col in  metadata.get("dummy_categorical"):
            metadata_sdmetrics['columns'][col] = {'sdtype': 'categorical'}
        for col in metadata.get("missing_categorical"):
            metadata_sdmetrics['columns'][col] = {'sdtype': 'categorical'}
        return metadata_sdmetrics

    @staticmethod
    def get_column_indices(data: pd.DataFrame, column_list: list) -> np.array:
        """
        Get indices of columns in a given list in a given dataframe.

        Args:
            data (pd.DataFrame): The DataFrame from which column indices are extracted.
            column_list (list): List of columns.

        Returns:
            np.array: Array containing the column indices.

        Raises:
            ValueError: If a specified column does not exist in the DataFrame.
        """
        column_df = data.columns.tolist()
        missing = [column for column in column_list if column not in column_df]
        if missing:
            raise ValueError(f"The following columns from column_list are not found in the DataFrame: {missing}")
        col_indices = []
        for col in column_list:
            col_indice = column_df.index(col)
            col_indices.append(col_indice)
        col_indices_array = np.array(col_indices, dtype=np.int64)
        return col_indices_array

    @staticmethod
    def save(metadata: dict, output_dir: str = "", filename: str = "metadata"):
        """Save a metadata dictionary to a JSON file at <output_dir>/<filename>.json."""
        json_path = os.path.join(output_dir, f"{filename}.json")
        with open(json_path, 'w') as file:
            json.dump(metadata, file, indent=4)

    @staticmethod
    def load(metadata_dir: str) -> dict:
        """
        Load metadata from a JSON file.

        Args:
            metadata_dir (str): Path to the metadata JSON file.

        Returns:
            dict: Loaded metadata.

        Raises:
            FileNotFoundError: If the metadata file does not exist.
            json.JSONDecodeError: If the metadata file is not a valid JSON.
        """
        try:
            with open(metadata_dir, 'r') as file:
                metadata = json.load(file)
            return metadata
        except FileNotFoundError as e:
            raise FileNotFoundError(
                f"Metadata file not found: {e}. Please run the get_metadata and save functions in the MetaData module."
            ) from e
        except json.JSONDecodeError as e:
            raise json.JSONDecodeError(
                f"Error decoding JSON from metadata file: {e.msg}", e.doc, e.pos
            ) from e
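
Example usage of the metadata pipeline (a minimal sketch: the import path synomicsbench.processing.metadata is assumed from the source layout, and the toy columns and threshold are illustrative):

import pandas as pd

from synomicsbench.processing.metadata import MetaData  # assumed import path

df = pd.DataFrame({
    "sample_id": ["s1", "s2", "s3", "s4"],
    "smoker": [0, 1, 0, 1],                 # binary values -> dummy categorical
    "stage": [1, 2, 3, 2],                  # declared ordinal below
    "expression_A": [0.12, 3.4, 2.2, 0.9],  # continuous -> numerical
})

metadata = MetaData.get_metadata(
    data=df,
    threshold_unique_values=2,     # columns with <= 2 unique values become categorical
    id_columns=["sample_id"],      # ignored entirely
    ordinal_features=["stage"],    # forced to 'ordinal_categorical'
)
# {'smoker': 'dummy_categorical', 'stage': 'ordinal_categorical', 'expression_A': 'numerical'}

MetaData.save(metadata, output_dir=".", filename="metadata")   # writes ./metadata.json
metadata = MetaData.load("metadata.json")                      # round-trip check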

classify_features_types(data, threshold_unique_values, ordinal_features=None, binary_values={0, 1}) staticmethod

Classify columns into dummy categorical and numerical features.

Parameters:

- data (DataFrame): Input dataframe. Required.
- threshold_unique_values (int): Threshold for unique values to consider as categorical. Required.
- ordinal_features (Optional[List]): List of columns to treat as ordinal (will be excluded from both outputs). Default: None.
- binary_values (set): Set of values to consider as binary for dummy categorical. Default: {0, 1}.

Returns:

- tuple (List[str], List[str]): (dummy_categorical_columns, numerical_columns)

Source code in src/synomicsbench/processing/metadata.py
@staticmethod
def classify_features_types(
    data: pd.DataFrame,
    threshold_unique_values: int,
    ordinal_features: Optional[List] = None,
    binary_values: set = {0, 1}
) -> (List[str], List[str]):
    """
    Classify columns into dummy categorical and numerical features.

    Args:
        data (pd.DataFrame): Input dataframe.
        threshold_unique_values (int): Threshold for unique values to consider as categorical.
        ordinal_features (Optional[List]): List of columns to treat as ordinal (will be excluded from both outputs).
        binary_values (set): Set of values to consider as binary for dummy categorical.

    Returns:
        tuple: (dummy_categorical_columns, numerical_columns)
    """
    if not isinstance(data, pd.DataFrame):
        raise TypeError("Input must be a pandas DataFrame")
    try:
        if ordinal_features is None:
            ordinal_features = []
        # Identify dummy categorical features:
        dummy_cat_columns = []
        num_columns = []
        for col in data.columns:
            col_data = data[col].dropna()
            unique_vals = set(col_data.unique())
            if pd.api.types.is_object_dtype(data[col]) or isinstance(data[col].dtype, pd.CategoricalDtype):
                dummy_cat_columns.append(col)
            elif unique_vals <= binary_values and len(unique_vals) > 0:
                dummy_cat_columns.append(col)
            elif len(unique_vals) <= threshold_unique_values:
                dummy_cat_columns.append(col)
            elif pd.api.types.is_numeric_dtype(data[col]):
                # Prioritize dtype: all numeric columns are NUMERICAL, even with low unique values
                num_columns.append(col)
        # Remove ordinal features from outputs if provided
        dummy_cat_columns = list(set(dummy_cat_columns) - set(ordinal_features))
        num_columns = list(set(num_columns) - set(ordinal_features))
        return dummy_cat_columns, num_columns
    except Exception as e:
        raise ValueError(f"Error classifying feature types: {e}")

get_column_indices(data, column_list) staticmethod

Get indices of columns in a given list in a given dataframe.

Parameters:

- data (DataFrame): The DataFrame from which column indices are extracted. Required.
- column_list (list): List of columns. Required.

Returns:

- np.array: Array containing the column indices.

Raises:

- ValueError: If a specified column does not exist in the DataFrame.

Source code in src/synomicsbench/processing/metadata.py
@staticmethod
def get_column_indices(data: pd.DataFrame, column_list: list) -> np.array:
    """
    Get indices of columns in a given list in a given dataframe.

    Args:
        data (pd.DataFrame): The DataFrame from which column indices are extracted.
        column_list (list): List of columns.

    Returns:
        np.array: Array containing the column indices.

    Raises:
        ValueError: If a specified column does not exist in the DataFrame.
    """
    column_df = data.columns.tolist()
    missing = [column for column in column_list if column not in column_df]
    if missing:
        raise ValueError(f"The following columns from column_list are not found in the DataFrame: {missing}")
    col_indices = []
    for col in column_list:
        col_indice = column_df.index(col)
        col_indices.append(col_indice)
    col_indices_array = np.array(col_indices, dtype=np.int64)
    return col_indices_array
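
For example (a minimal sketch; import path assumed from the source layout):

import pandas as pd

from synomicsbench.processing.metadata import MetaData  # assumed import path

df = pd.DataFrame({"a": [1], "b": [2], "c": [3]})
print(MetaData.get_column_indices(df, ["c", "a"]))   # [2 0], dtype int64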

get_metadata(data, threshold_unique_values, id_columns=None, ordinal_features=None, transcriptomic_cols=None, binary_values={0, 1}) staticmethod

Extract column type metadata from a DataFrame and assign columns to corresponding types.

Parameters:

- data (DataFrame): DataFrame with features. Required.
- threshold_unique_values (int): Threshold for unique values to consider as categorical. Required.
- id_columns (Optional[List]): List of columns to ignore (e.g., sample IDs). Default: None.
- ordinal_features (Optional[List]): List of columns to treat as ordinal. Default: None.
- transcriptomic_cols (Optional[List]): List of columns to always treat as numerical (e.g., gene-expression features). Default: None.
- binary_values (set): Set of values to consider as binary. Default: {0, 1}.

Returns:

- dict[str, str]: Dictionary mapping each column to its feature type: 'numerical', 'ordinal_categorical', 'dummy_categorical', 'missing_categorical', or 'unclassified'.

Raises:

- ValueError: If classification fails.

Source code in src/synomicsbench/processing/metadata.py
@staticmethod
def get_metadata(
    data: pd.DataFrame,
    threshold_unique_values: int,
    id_columns: Optional[List] = None,
    ordinal_features: Optional[List] = None,
    transcriptomic_cols: Optional[List] = None,
    binary_values: set = {0, 1},
) -> dict[str, str]:
    """
    Extract column type metadata from a DataFrame and assign columns to corresponding types.

    Args:
        data (pd.DataFrame): DataFrame with features.
        threshold_unique_values (int): Threshold for unique values to consider as categorical.
        id_columns (Optional[List]): List of columns to ignore (e.g., sample IDs).
        ordinal_features (Optional[List]): List of columns to treat as ordinal.
        transcriptomic_cols (Optional[List]): List of columns to always treat as numerical (e.g., gene-expression features).
        binary_values (set): Set of values to consider as binary.

    Returns:
        dict: Dictionary mapping each column to its feature type:
            'numerical', 'ordinal_categorical', 'dummy_categorical', 'missing_categorical', or 'unclassified'.

    Raises:
        ValueError: If classification fails.
    """
    if id_columns is not None:
        data = data.drop(id_columns, axis=1)
    features = data.columns.to_list()
    metadata_dict = {}
    dummy_cols, numerical_cols = MetaData.classify_features_types(
        data=data,
        threshold_unique_values=threshold_unique_values,
        ordinal_features=ordinal_features,
        binary_values=binary_values
    )
    if transcriptomic_cols is None:
        transcriptomic_cols = []

    for col in features:
        if col in transcriptomic_cols:
            metadata_dict[col] = "numerical"
            continue

        if ordinal_features and col in ordinal_features:
            metadata_dict[col] = "ordinal_categorical"
        elif col in dummy_cols:
            metadata_dict[col] = "dummy_categorical"
        elif col in numerical_cols:
            metadata_dict[col] = "numerical"
        elif col.startswith("missingindicator_"):
            metadata_dict[col] = "missing_categorical"
        else:
            metadata_dict[col] = "unclassified"
    if "unclassified" in metadata_dict.values():
        count_unclassified = list(metadata_dict.values()).count("unclassified")
        warnings.warn(
            f"There are {count_unclassified} unclassified features in the metadata. "
            "Consider reviewing your metadata or feature typing logic.",
            UserWarning
        )

    return metadata_dict
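
Note that columns listed in transcriptomic_cols are always typed as numerical, even when their low cardinality would otherwise classify them as categorical. A hedged sketch (import path assumed, values illustrative):

import pandas as pd

from synomicsbench.processing.metadata import MetaData  # assumed import path

df = pd.DataFrame({
    "ENSG00000141510": [0.0, 0.0, 1.2, 0.0],   # only two distinct values
    "age": [34.0, 51.0, 47.0, 62.0],
})

metadata = MetaData.get_metadata(
    data=df,
    threshold_unique_values=3,
    transcriptomic_cols=["ENSG00000141510"],
)
print(metadata)   # {'ENSG00000141510': 'numerical', 'age': 'numerical'}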

grouping_features_astype(data, metadata) staticmethod

Group the DataFrame's columns by feature type according to a metadata dictionary.

Parameters:

- data (DataFrame): DataFrame used to filter available columns. Required.
- metadata (dict): Metadata dictionary mapping column names to feature types. Required.

Returns:

- dict: Dictionary with keys 'numerical', 'ordinal_categorical', 'dummy_categorical', and 'missing_categorical', mapping to lists of column names.

Source code in src/synomicsbench/processing/metadata.py
@staticmethod
def grouping_features_astype(data: pd.DataFrame, metadata: dict) -> dict:
    """
    Group the DataFrame's columns by feature type according to a metadata dictionary.

    Args:
        data (pd.DataFrame): DataFrame used to filter available columns.
        metadata (dict): Metadata dictionary mapping column names to feature types.

    Returns:
        dict: Dictionary with keys 'numerical', 'ordinal_categorical', 'dummy_categorical', and
            'missing_categorical', mapping to lists of column names.
    """

    dummy_cat_cols = []
    ordinal_cat_cols = []
    num_cols = []
    missing_indicator_cols = []
    for k, v in metadata.items():
        if k in data.columns:
            if v == "numerical":
                num_cols.append(k)
            elif v == "ordinal_categorical":
                ordinal_cat_cols.append(k)
            elif v == "missing_categorical":
                missing_indicator_cols.append(k)
            else:
                dummy_cat_cols.append(k)
    return {
        "numerical": num_cols,
        "ordinal_categorical": ordinal_cat_cols,
        "dummy_categorical": dummy_cat_cols,
        "missing_categorical": missing_indicator_cols,
    }
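
For example, a flat column-to-type dictionary is regrouped into per-type lists, and entries that are not present in the DataFrame are skipped (a minimal sketch; import path assumed):

import pandas as pd

from synomicsbench.processing.metadata import MetaData  # assumed import path

metadata = {
    "age": "numerical",
    "stage": "ordinal_categorical",
    "smoker": "dummy_categorical",
    "missingindicator_age": "missing_categorical",
    "dropped_col": "numerical",    # not in the DataFrame -> ignored
}
df = pd.DataFrame(columns=["age", "stage", "smoker", "missingindicator_age"])

print(MetaData.grouping_features_astype(df, metadata))
# {'numerical': ['age'], 'ordinal_categorical': ['stage'],
#  'dummy_categorical': ['smoker'], 'missing_categorical': ['missingindicator_age']}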

load(metadata_dir) staticmethod

Load metadata from a JSON file.

Parameters:

- metadata_dir (str): Path to the metadata JSON file. Required.

Returns:

- dict: Loaded metadata.

Raises:

- FileNotFoundError: If the metadata file does not exist.
- json.JSONDecodeError: If the metadata file is not valid JSON.

Source code in src/synomicsbench/processing/metadata.py
@staticmethod
def load(metadata_dir: str) -> dict:
    """
    Load metadata from a JSON file.

    Args:
        metadata_dir (str): Path to the metadata JSON file.

    Returns:
        dict: Loaded metadata.

    Raises:
        FileNotFoundError: If the metadata file does not exist.
        json.JSONDecodeError: If the metadata file is not a valid JSON.
    """
    try:
        with open(metadata_dir, 'r') as file:
            metadata = json.load(file)
        return metadata
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Metadata file not found: {e}. Please run the get_metadata and save functions in the MetaData module."
        ) from e
    except json.JSONDecodeError as e:
        raise json.JSONDecodeError(
            f"Error decoding JSON from metadata file: {e.msg}", e.doc, e.pos
        ) from e

metadata_as_SDV(data, metadata) staticmethod

Convert a metadata dictionary into the column-type format expected by the SDV/SDMetrics libraries.

Parameters:

- data (DataFrame): DataFrame used to filter available columns. Required.
- metadata (dict): Metadata dictionary mapping column names to feature types. Required.

Returns:

- dict: Metadata dictionary in SDV format ({'columns': {column_name: {'sdtype': ...}}}).

Source code in src/synomicsbench/processing/metadata.py
@staticmethod
def metadata_as_SDV(data: pd.DataFrame, metadata: dict) -> dict:
    """
    Convert a metadata dictionary into the column-type format expected by the SDV/SDMetrics libraries.

    Args:
        data (pd.DataFrame): DataFrame used to filter available columns.
        metadata (dict): Metadata dictionary mapping column names to feature types.

    Returns:
        dict: Metadata dictionary in SDV format ({'columns': {column_name: {'sdtype': ...}}}).
    """
    metadata = MetaData.grouping_features_astype(data, metadata)
    # Build SDMetrics metadata dict
    metadata_sdmetrics = {'columns': {}}
    for col in metadata.get("numerical"):
        metadata_sdmetrics['columns'][col] = {'sdtype': 'numerical'}
    for col in metadata.get("ordinal_categorical"):
        metadata_sdmetrics['columns'][col] = {'sdtype': 'categorical'}  # or 'ordinal' if supported
    for col in  metadata.get("dummy_categorical"):
        metadata_sdmetrics['columns'][col] = {'sdtype': 'categorical'}
    for col in metadata.get("missing_categorical"):
        metadata_sdmetrics['columns'][col] = {'sdtype': 'categorical'}
    return metadata_sdmetrics
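
The resulting dictionary can be passed directly to SDMetrics-based metrics such as UnivariateSimilarity. A minimal sketch (import path assumed):

import pandas as pd

from synomicsbench.processing.metadata import MetaData  # assumed import path

df = pd.DataFrame({"age": [34.0, 51.0], "smoker": [0, 1]})
metadata = {"age": "numerical", "smoker": "dummy_categorical"}

print(MetaData.metadata_as_SDV(df, metadata))
# {'columns': {'age': {'sdtype': 'numerical'}, 'smoker': {'sdtype': 'categorical'}}}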

GeneQuery

A class to query gene information using MyGeneInfo and convert Ensembl IDs to HUGO symbols.

This class facilitates querying gene data, checking for duplicates, and mapping Ensembl gene IDs to HUGO symbols. Results are saved as JSON/CSV files, and operations are logged for debugging.

Attributes:

- fields (List[str]): Fields to retrieve from MyGeneInfo (e.g., ["symbol"]).
- scopes (List[str]): Scopes for gene queries (e.g., ["ensembl.gene"]).
- species (List[str]): Species to query (e.g., ["human"]).
- output_dir (str): Directory to save results and logs.
- logger (logging.Logger): Logger instance for debugging and error tracking.

Source code in src/synomicsbench/processing/gene_query.py
class GeneQuery:
    """A class to query gene information using MyGeneInfo and convert Ensembl IDs to HUGO symbols.

    This class facilitates querying gene data, checking for duplicates, and mapping Ensembl gene IDs
    to HUGO symbols. Results are saved as JSON/CSV files, and operations are logged for debugging.

    Attributes:
        fields (List[str]): Fields to retrieve from MyGeneInfo (e.g., ["symbol"]).
        scopes (List[str]): Scopes for gene queries (e.g., ["ensembl.gene"]).
        species (List[str]): Species to query (e.g., ["human"]).
        output_dir (str): Directory to save results and logs.
        logger (logging.Logger): Logger instance for debugging and error tracking.
    """

    def __init__(self, fields: List[str], scopes: List[str], species: List[str], output_dir: str):
        """Initialize GeneQuery with query parameters and logging setup.

        Args:
            fields: Fields to retrieve from MyGeneInfo queries.
            scopes: Scopes defining the gene identifier types.
            species: Species to include in the query.
            output_dir: Directory to store output files and logs.

        Raises:
            RuntimeError: If the log file cannot be created due to permissions or path issues.
        """
        self.fields = fields
        self.scopes = scopes
        self.species = species
        self.output_dir = output_dir
        try:
            os.makedirs(output_dir, exist_ok=True)
        except OSError as e:
            raise RuntimeError(f"Failed to create output directory {output_dir}: {e}")

        # Set up logger
        logger_name = f"{self.__class__.__name__}_{id(self)}"
        log_file_name = f"{self.__class__.__name__}_{id(self)}.log"
        self.logger = set_logger(logger_name, self.output_dir, log_file_name)

    def gene_query(self, genes_list: List[str], **kwargs: Any) -> List[Dict[str, Any]]:
        """Query gene information using MyGeneInfo.

        Args:
            genes_list: List of gene identifiers to query.
            **kwargs: Additional arguments for MyGeneInfo.querymany.

        Returns:
            List of dictionaries containing gene mapping information.

        Raises:
            ValueError: If the query fails due to network issues or invalid parameters.
        """
        try:
            mg = mygene.MyGeneInfo()
            return mg.querymany(
                genes_list,
                scopes=self.scopes,
                fields=self.fields, 
                species=self.species,
                **kwargs
            )
        except Exception as e:
            self.logger.error(f"Gene query failed: {e}", exc_info=True)
            raise ValueError(f"Gene query failed: {e}")

    def check_duplicates(self, data: pd.DataFrame, **kwargs: Any) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
        """Identify and group duplicate genes based on expression values.

        Args:
            data: DataFrame with samples as rows and genes as columns.
            **kwargs: Additional arguments for gene_query.

        Returns:
            Tuple of two dictionaries:
            - Mapping of unique gene IDs to lists of duplicate gene IDs.
            - Mapping of unique gene IDs to their HUGO symbols.

        Raises:
            ValueError: If duplicate checking fails due to invalid data or query errors.
        """
        try:
            self.logger.info("Grouping genes with identical expression values...")
            # Transpose to check duplicates across genes
            genes_duplicated = data.T[data.T.duplicated(keep=False)].T
            if genes_duplicated.empty:
                self.logger.info("No duplicate genes found")
                return {}, {}

            # Map each gene to its expression values
            gene_values = {col: vals.tolist() for col, vals in genes_duplicated.items()}
            duplicated_genes_dict = {col: [] for col in genes_duplicated.columns}
            mapped_genes_dict = duplicated_genes_dict.copy()

            # Group duplicates by expression values
            seen_values = set()
            for column in tqdm(genes_duplicated.columns, desc="Processing duplicates"):
                values = tuple(gene_values[column])  # Use tuple for hashability
                if values in seen_values:
                    continue
                seen_values.add(values)
                duplicates = [col for col, vals in gene_values.items() if vals == gene_values[column]]
                for dup in duplicates:
                    duplicated_genes_dict[dup] = duplicates

                # Query HUGO symbols for duplicates
                query_ids = [col.split(".")[0] for col in duplicates]
                mapping_info = self.gene_query(query_ids, **kwargs)
                mapping_df = pd.DataFrame(mapping_info)
                symbols = mapping_df["symbol"].tolist() if "symbol" in mapping_df else ["not_found"] * len(query_ids)
                for dup in duplicates:
                    mapped_genes_dict[dup] = symbols

            # Save results
            self._save_json(duplicated_genes_dict, "grouped_dup_genes.json")
            self._save_json(mapped_genes_dict, "grouped_mapped_dup_genes.json")
            self.logger.info(f"Duplicate gene information saved to {self.output_dir}")
            return duplicated_genes_dict, mapped_genes_dict
        except Exception as e:
            self.logger.error(f"Duplicate check failed: {e}", exc_info=True)
            raise ValueError(f"Duplicate check failed: {e}")

    def convert_genes(self, data: pd.DataFrame, 
                      **kwargs: Any) -> pd.DataFrame:
        """Convert Ensembl gene IDs to HUGO symbols.

        Args:
            data: DataFrame with Ensembl gene IDs as columns.
            **kwargs: Additional arguments for gene_query.

        Returns:
            DataFrame containing gene mapping information.

        Raises:
            ValueError: If conversion fails due to invalid input or query errors.
        """
        try:
            # Extract Ensembl IDs
            genes_list = data.columns.str.split(".").str[0]
            self.logger.info(f"Querying {len(genes_list)} genes...")

            # Query gene information
            gene_info = self.gene_query(genes_list, **kwargs)
            gene_info_df = pd.DataFrame(gene_info)
            if "symbol" not in gene_info_df:
                self.logger.error("No 'symbol' column in query results")
                raise ValueError("No 'symbol' column in query results")

            # Log unfound and unmapped genes
            if "notfound" in gene_info_df.columns:
                unfound_genes = gene_info_df[gene_info_df["notfound"].fillna(False).astype(bool)]["query"].tolist()
                unfound_names = ", ".join(unfound_genes) if unfound_genes else "None"
                self.logger.debug(f"Unfound genes: {len(unfound_genes)} ({unfound_names})")

                unmapped_genes = gene_info_df[gene_info_df[["symbol", "notfound"]].isna().all(axis=1)]["query"].tolist()
                unmapped_names = ", ".join(unmapped_genes) if unmapped_genes else "None"
                self.logger.debug(f"Unmapped genes: {len(unmapped_genes)} ({unmapped_names})")
            else:
                self.logger.debug(f"All {len(genes_list)} genes have found their HUGO ID")

            # Save mapping dictionary
            ens_id_symbol_map = gene_info_df.groupby("query")["symbol"].apply(list).to_dict()
            self._save_json(ens_id_symbol_map, "mapping_genes.json")
            self.logger.info(f"Gene mapping saved to {self.output_dir}")

            # Save the full mapping DataFrame
            gene_info_df_path = os.path.join(self.output_dir, "gene_mapping_df.csv")
            gene_info_df.to_csv(gene_info_df_path, index=False)
            return gene_info_df

        except ValueError as e:
            self.logger.error(f"Gene conversion failed: {e}", exc_info=True)
            raise 
        except Exception as e:
            self.logger.error(f"Gene conversion failed: {e}", exc_info=True)
            raise ValueError(f"Gene conversion failed: {e}")

    def _save_json(self, data: Dict, filename: str) -> None:
        """Save a dictionary to a JSON file.

        Args:
            data: Dictionary to save.
            filename: Name of the JSON file.
        """
        path = os.path.join(self.output_dir, filename)
        with open(path, "w") as f:
            json.dump(data, f, indent=4)
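
Example usage (a minimal sketch: the import path synomicsbench.processing.gene_query is assumed from the source layout, the call performs a live MyGeneInfo query and therefore needs network access, and the example genes are illustrative):

import pandas as pd

from synomicsbench.processing.gene_query import GeneQuery  # assumed import path

# Expression matrix: samples as rows, (optionally versioned) Ensembl IDs as columns.
expr = pd.DataFrame(
    {"ENSG00000141510.15": [5.2, 3.1], "ENSG00000012048.22": [1.4, 0.0]},
    index=["sample_1", "sample_2"],
)

gq = GeneQuery(
    fields=["symbol"],
    scopes=["ensembl.gene"],
    species=["human"],
    output_dir="gene_query_out",
)

# Queries MyGeneInfo and writes mapping_genes.json and gene_mapping_df.csv to output_dir.
mapping_df = gq.convert_genes(expr)
print(mapping_df[["query", "symbol"]])   # e.g. TP53, BRCA1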

__init__(fields, scopes, species, output_dir)

Initialize GeneQuery with query parameters and logging setup.

Parameters:

- fields (List[str]): Fields to retrieve from MyGeneInfo queries. Required.
- scopes (List[str]): Scopes defining the gene identifier types. Required.
- species (List[str]): Species to include in the query. Required.
- output_dir (str): Directory to store output files and logs. Required.

Raises:

- RuntimeError: If the output directory or log file cannot be created due to permissions or path issues.

Source code in src/synomicsbench/processing/gene_query.py
def __init__(self, fields: List[str], scopes: List[str], species: List[str], output_dir: str):
    """Initialize GeneQuery with query parameters and logging setup.

    Args:
        fields: Fields to retrieve from MyGeneInfo queries.
        scopes: Scopes defining the gene identifier types.
        species: Species to include in the query.
        output_dir: Directory to store output files and logs.

    Raises:
        RuntimeError: If the log file cannot be created due to permissions or path issues.
    """
    self.fields = fields
    self.scopes = scopes
    self.species = species
    self.output_dir = output_dir
    try:
        os.makedirs(output_dir, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Failed to create output directory {output_dir}: {e}")

    # Set up logger
    logger_name = f"{self.__class__.__name__}_{id(self)}"
    log_file_name = f"{self.__class__.__name__}_{id(self)}.log"
    self.logger = set_logger(logger_name, self.output_dir, log_file_name)

check_duplicates(data, **kwargs)

Identify and group duplicate genes based on expression values.

Parameters:

- data (DataFrame): DataFrame with samples as rows and genes as columns. Required.
- **kwargs (Any): Additional arguments for gene_query. Default: {}.

Returns:

- Tuple[Dict[str, List[str]], Dict[str, List[str]]]: Two dictionaries: a mapping of unique gene IDs to lists of duplicate gene IDs, and a mapping of unique gene IDs to their HUGO symbols.

Raises:

- ValueError: If duplicate checking fails due to invalid data or query errors.

Source code in src/synomicsbench/processing/gene_query.py
def check_duplicates(self, data: pd.DataFrame, **kwargs: Any) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Identify and group duplicate genes based on expression values.

    Args:
        data: DataFrame with samples as rows and genes as columns.
        **kwargs: Additional arguments for gene_query.

    Returns:
        Tuple of two dictionaries:
        - Mapping of unique gene IDs to lists of duplicate gene IDs.
        - Mapping of unique gene IDs to their HUGO symbols.

    Raises:
        ValueError: If duplicate checking fails due to invalid data or query errors.
    """
    try:
        self.logger.info("Grouping genes with identical expression values...")
        # Transpose to check duplicates across genes
        genes_duplicated = data.T[data.T.duplicated(keep=False)].T
        if genes_duplicated.empty:
            self.logger.info("No duplicate genes found")
            return {}, {}

        # Map each gene to its expression values
        gene_values = {col: vals.tolist() for col, vals in genes_duplicated.items()}
        duplicated_genes_dict = {col: [] for col in genes_duplicated.columns}
        mapped_genes_dict = duplicated_genes_dict.copy()

        # Group duplicates by expression values
        seen_values = set()
        for column in tqdm(genes_duplicated.columns, desc="Processing duplicates"):
            values = tuple(gene_values[column])  # Use tuple for hashability
            if values in seen_values:
                continue
            seen_values.add(values)
            duplicates = [col for col, vals in gene_values.items() if vals == gene_values[column]]
            for dup in duplicates:
                duplicated_genes_dict[dup] = duplicates

            # Query HUGO symbols for duplicates
            query_ids = [col.split(".")[0] for col in duplicates]
            mapping_info = self.gene_query(query_ids, **kwargs)
            mapping_df = pd.DataFrame(mapping_info)
            symbols = mapping_df["symbol"].tolist() if "symbol" in mapping_df else ["not_found"] * len(query_ids)
            for dup in duplicates:
                mapped_genes_dict[dup] = symbols

        # Save results
        self._save_json(duplicated_genes_dict, "grouped_dup_genes.json")
        self._save_json(mapped_genes_dict, "grouped_mapped_dup_genes.json")
        self.logger.info(f"Duplicate gene information saved to {self.output_dir}")
        return duplicated_genes_dict, mapped_genes_dict
    except Exception as e:
        self.logger.error(f"Duplicate check failed: {e}", exc_info=True)
        raise ValueError(f"Duplicate check failed: {e}")

convert_genes(data, **kwargs)

Convert Ensembl gene IDs to HUGO symbols.

Parameters:

- data (DataFrame): DataFrame with Ensembl gene IDs as columns. Required.
- **kwargs (Any): Additional arguments for gene_query. Default: {}.

Returns:

- DataFrame: DataFrame containing gene mapping information.

Raises:

- ValueError: If conversion fails due to invalid input or query errors.

Source code in src/synomicsbench/processing/gene_query.py
def convert_genes(self, data: pd.DataFrame, 
                  **kwargs: Any) -> pd.DataFrame:
    """Convert Ensembl gene IDs to HUGO symbols.

    Args:
        data: DataFrame with Ensembl gene IDs as columns.
        **kwargs: Additional arguments for gene_query.

    Returns:
        DataFrame containing gene mapping information.

    Raises:
        ValueError: If conversion fails due to invalid input or query errors.
    """
    try:
        # Extract Ensembl IDs
        genes_list = data.columns.str.split(".").str[0]
        self.logger.info(f"Querying {len(genes_list)} genes...")

        # Query gene information
        gene_info = self.gene_query(genes_list, **kwargs)
        gene_info_df = pd.DataFrame(gene_info)
        if "symbol" not in gene_info_df:
            self.logger.error("No 'symbol' column in query results")
            raise ValueError("No 'symbol' column in query results")

        # Log unfound and unmapped genes
        if "notfound" in gene_info_df.columns:
            unfound_genes = gene_info_df[gene_info_df["notfound"].fillna(False).astype(bool)]["query"].tolist()
            unfound_names = ", ".join(unfound_genes) if unfound_genes else "None"
            self.logger.debug(f"Unfound genes: {len(unfound_genes)} ({unfound_names})")

            unmapped_genes = gene_info_df[gene_info_df[["symbol", "notfound"]].isna().all(axis=1)]["query"].tolist()
            unmapped_names = ", ".join(unmapped_genes) if unmapped_genes else "None"
            self.logger.debug(f"Unmapped genes: {len(unmapped_genes)} ({unmapped_names})")
        else:
            self.logger.debug(f"All {len(genes_list)} genes have found their HUGO ID")

        # Save mapping dictionary
        ens_id_symbol_map = gene_info_df.groupby("query")["symbol"].apply(list).to_dict()
        self._save_json(ens_id_symbol_map, "mapping_genes.json")
        self.logger.info(f"Gene mapping saved to {self.output_dir}")

        # Save the full mapping DataFrame
        gene_info_df_path = os.path.join(self.output_dir, "gene_mapping_df.csv")
        gene_info_df.to_csv(gene_info_df_path, index=False)
        return gene_info_df

    except ValueError as e:
        self.logger.error(f"Gene conversion failed: {e}", exc_info=True)
        raise 
    except Exception as e:
        self.logger.error(f"Gene conversion failed: {e}", exc_info=True)
        raise ValueError(f"Gene conversion failed: {e}")

gene_query(genes_list, **kwargs)

Query gene information using MyGeneInfo.

Parameters:

- genes_list (List[str]): List of gene identifiers to query. Required.
- **kwargs (Any): Additional arguments for MyGeneInfo.querymany. Default: {}.

Returns:

- List[Dict[str, Any]]: List of dictionaries containing gene mapping information.

Raises:

- ValueError: If the query fails due to network issues or invalid parameters.

Source code in src/synomicsbench/processing/gene_query.py
def gene_query(self, genes_list: List[str], **kwargs: Any) -> List[Dict[str, Any]]:
    """Query gene information using MyGeneInfo.

    Args:
        genes_list: List of gene identifiers to query.
        **kwargs: Additional arguments for MyGeneInfo.querymany.

    Returns:
        List of dictionaries containing gene mapping information.

    Raises:
        ValueError: If the query fails due to network issues or invalid parameters.
    """
    try:
        mg = mygene.MyGeneInfo()
        return mg.querymany(
            genes_list,
            scopes=self.scopes,
            fields=self.fields, 
            species=self.species,
            **kwargs
        )
    except Exception as e:
        self.logger.error(f"Gene query failed: {e}", exc_info=True)
        raise ValueError(f"Gene query failed: {e}")

Metrics: Fidelity

The metrics.fidelity module provides assessment tools for evaluating the distribution quality of synthetic data against real data.

UnivariateSimilarity

Compute and validate univariate similarity between original and synthetic data, including score computation, logging, result saving, and visualization.

Parameters:

- output_dir (str): Directory to save outputs and logs. Required.
- logger_name (str): Logger name. Default: 'UnivariateSimilarity'.

Attributes:

- output_dir (str): Output directory path.
- column_shapes (ColumnShapes): SDMetrics ColumnShapes property.
- logger (logging.Logger): Logger for this class.

Source code in src/synomicsbench/metrics/fidelity/UnivariateSimilarity.py
class UnivariateSimilarity:
    """
    Compute and validate univariate similarity between original and synthetic data,
    including score computation, logging, result saving, and visualization.

    Args:
        output_dir (str): Directory to save outputs and logs.
        logger_name (str): Logger name
    Attributes:
        output_dir (str): Output directory path.
        column_shapes (ColumnShapes): SDMetrics ColumnShapes property.
        logger (logging.Logger): Logger for this class.
    """

    def __init__(self, output_dir: str, logger_name: str = "UnivariateSimilarity"):
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.logger = set_logger(logger_name = logger_name, output_path = self.output_dir)
        self.logger_name = logger_name


    def get_univariate_score(self, original_data: pd.DataFrame,
                            synthetic_data: pd.DataFrame,
                            metadata: dict,
                            save: bool = True) -> float:
        """
        Compute univariate similarity score between original and synthetic data.

        Args:
            original_data (pd.DataFrame): Original data.
            synthetic_data (pd.DataFrame): Synthetic data.
            metadata (dict): Column type metadata dictionary (converted internally to SDMetrics format).
            save (bool): Save score dataframe and score distribution.

        Returns:
            float: Overall univariate similarity score.

        Raises:
            ValueError: If score computation fails.
        """
        try:
            self.logger.info("Starting univariate similarity computation.")
            self.column_shapes = ColumnShapes()
            metadata_sdmetrics = MetaData.metadata_as_SDV(data=original_data, metadata=metadata)
            with tqdm(total=len(metadata_sdmetrics['columns'])) as pbar:
                score = self.column_shapes.get_score(
                    original_data,
                    synthetic_data,
                    metadata_sdmetrics,
                    progress_bar=pbar
                )
            self.logger.info(f"Univariate similarity score: {score}")
            # Save details
            if save:
                details_df = self.get_detail_df()
                details_csv_path = os.path.join(self.output_dir, f"Detail_score_{self.logger_name}.csv")
                details_df.to_csv(details_csv_path, index = False)
                self.logger.info(f"Details DataFrame saved to {details_csv_path}")
                # Save visualization
                fig = self.plot_column_score_histogram(details_df, data_name = self.logger_name)
                fig_path = os.path.join(self.output_dir, f"{self.logger_name}.png")
                fig.savefig(fig_path, dpi=300)
                plt.close(fig)
                self.logger.info(f"Histogram figure saved to {fig_path}")
            return score
        except Exception as e:
            self.logger.error(f"Failed to validate the Univariate Similarity: {e}")
            raise ValueError(f"Failed to validate the Univariate Similarity: {e}")

    def get_detail_df(self):
        """
        Get the DataFrame with column-level univariate similarity details.

        Returns:
            pd.DataFrame: Details DataFrame.

        Raises:
            AttributeError: If column_shapes is not initialized.
        """
        if not hasattr(self, 'column_shapes'):
            raise AttributeError("column_shapes is not initialized. Run get_univariate_score first.")
        return self.column_shapes.details

    def summarize(self):
        """
        Summarize the scores by metric type.

        Returns:
            pd.DataFrame: Grouped summary statistics by metric.
        """
        detail_df = self.get_detail_df()
        return detail_df.groupby('Metric')['Score'].describe()

    def get_visualization(self, plotly: bool = False, data_name: str = "", bins: int = 50):
        """
        Get a visualization of the column shape scores.

        Args:
            plotly (bool): Whether to use Plotly for visualization.

        Returns:
            Figure: Matplotlib or Plotly figure.
        """
        details_df = self.get_detail_df()
        if plotly:
            return self.column_shapes.get_visualization()
        else:
            return self.plot_column_score_histogram(details_df, data_name, bins)

    @staticmethod
    def plot_column_score_histogram(details, data_name: str = "", bins: int = 50):
        """
        Plot a decorated histogram of column shape scores for a set of features.

        Args:
            details (pd.DataFrame): DataFrame containing at least a 'Score' column with numerical values.
            data_name (str): Optional label for the data (for title).
            bins (int): Number of bins for the histogram.

        Returns:
            matplotlib.figure.Figure: Figure object for saving or further manipulation.

        Raises:
            KeyError: If the 'Score' column is not present in the DataFrame.
            TypeError: If details is not a pandas DataFrame.
        """
        if not isinstance(details, pd.DataFrame):
            raise TypeError("details must be a pandas DataFrame.")
        if 'Score' not in details.columns:
            raise KeyError("The 'Score' column must be present in the details DataFrame.")

        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(figsize=(10, 6))

        # Plot histogram only (no KDE)
        n, bins_, patches = ax.hist(
            details['Score'].dropna(),
            bins=bins,
            color='#4682B4',
            alpha=0.75,
            edgecolor='white',
            linewidth=1.2,
            label='Features'
        )

        # Add vertical lines for key statistics
        mean_score = details['Score'].mean()
        median_score = details['Score'].median()
        ax.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.2f}')
        ax.axvline(median_score, color='purple', linestyle='-.', linewidth=2, label=f'Median: {median_score:.2f}')

        # Customize ticks and grid
        ax.set_xticks(np.linspace(0, 1, 11))
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(axis='y', alpha=0.3)

        # Add labels and title with font size and weight
        ax.set_xlabel('Score', fontsize=14, fontweight='bold')
        ax.set_ylabel('Number of Features', fontsize=14, fontweight='bold')
        ax.set_title(f'Distribution of Column Shape Scores {data_name}', fontsize=16, fontweight='bold', pad=20)

        # Add legend
        ax.legend(fontsize=12, frameon=True)

        # Add annotation for total features
        ax.annotate(
            f"Total features: {details['Score'].dropna().shape[0]}",
            xy=(0.99, 0.95), xycoords='axes fraction',
            ha='right', va='top', fontsize=12, color='gray'
        )

        plt.tight_layout()
        return fig
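
Example usage (a minimal sketch: the import path synomicsbench.metrics.fidelity.UnivariateSimilarity is assumed from the source layout, and the data are randomly generated for illustration):

import numpy as np
import pandas as pd

from synomicsbench.metrics.fidelity.UnivariateSimilarity import UnivariateSimilarity  # assumed import path

rng = np.random.default_rng(0)
real = pd.DataFrame({"age": rng.normal(50, 10, 200), "smoker": rng.integers(0, 2, 200)})
synthetic = pd.DataFrame({"age": rng.normal(52, 11, 200), "smoker": rng.integers(0, 2, 200)})
metadata = {"age": "numerical", "smoker": "dummy_categorical"}

sim = UnivariateSimilarity(output_dir="fidelity_out", logger_name="demo_run")
score = sim.get_univariate_score(real, synthetic, metadata, save=True)
print(score)             # overall score in [0, 1]
print(sim.summarize())   # per-metric summary of the column-level scores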

get_detail_df()

Get the DataFrame with column-level univariate similarity details.

Returns:

- pd.DataFrame: Details DataFrame.

Raises:

- AttributeError: If column_shapes is not initialized.

Source code in src/synomicsbench/metrics/fidelity/UnivariateSimilarity.py
def get_detail_df(self):
    """
    Get the DataFrame with column-level univariate similarity details.

    Returns:
        pd.DataFrame: Details DataFrame.

    Raises:
        AttributeError: If column_shapes is not initialized.
    """
    if not hasattr(self, 'column_shapes'):
        raise AttributeError("column_shapes is not initialized. Run get_univariate_score first.")
    return self.column_shapes.details

get_univariate_score(original_data, synthetic_data, metadata, save=True)

Compute univariate similarity score between original and synthetic data.

Parameters:

- original_data (DataFrame): Original data. Required.
- synthetic_data (DataFrame): Synthetic data. Required.
- metadata (dict): Column type metadata dictionary (converted internally to SDMetrics format). Required.
- save (bool): Save the score dataframe and score distribution plot. Default: True.

Returns:

- float: Overall univariate similarity score.

Raises:

- ValueError: If score computation fails.

Source code in src/synomicsbench/metrics/fidelity/UnivariateSimilarity.py
def get_univariate_score(self, original_data: pd.DataFrame,
                        synthetic_data: pd.DataFrame,
                        metadata: dict,
                        save: bool = True) -> float:
    """
    Compute univariate similarity score between original and synthetic data.

    Args:
        original_data (pd.DataFrame): Original data.
        synthetic_data (pd.DataFrame): Synthetic data.
        metadata (dict): Column type metadata dictionary (converted internally to SDMetrics format).
        save (bool): Save score dataframe and score distribution.

    Returns:
        float: Overall univariate similarity score.

    Raises:
        ValueError: If score computation fails.
    """
    try:
        self.logger.info("Starting univariate similarity computation.")
        self.column_shapes = ColumnShapes()
        metadata_sdmetrics = MetaData.metadata_as_SDV(data=original_data, metadata=metadata)
        with tqdm(total=len(metadata_sdmetrics['columns'])) as pbar:
            score = self.column_shapes.get_score(
                original_data,
                synthetic_data,
                metadata_sdmetrics,
                progress_bar=pbar
            )
        self.logger.info(f"Univariate similarity score: {score}")
        # Save details
        if save:
            details_df = self.get_detail_df()
            details_csv_path = os.path.join(self.output_dir, f"Detail_score_{self.logger_name}.csv")
            details_df.to_csv(details_csv_path, index = False)
            self.logger.info(f"Details DataFrame saved to {details_csv_path}")
            # Save visualization
            fig = self.plot_column_score_histogram(details_df, data_name = self.logger_name)
            fig_path = os.path.join(self.output_dir, f"{self.logger_name}.png")
            fig.savefig(fig_path, dpi=300)
            plt.close(fig)
            self.logger.info(f"Histogram figure saved to {fig_path}")
        return score
    except Exception as e:
        self.logger.error(f"Failed to validate the Univariate Similarity: {e}")
        raise ValueError(f"Failed to validate the Univariate Similarity: {e}")

get_visualization(plotly=False, data_name='', bins=50)

Get a visualization of the column shape scores.

Parameters:

- plotly (bool): Whether to use Plotly for the visualization. Default: False.
- data_name (str): Optional label for the data (used in the plot title). Default: ''.
- bins (int): Number of bins for the histogram. Default: 50.

Returns:

- Figure: Matplotlib or Plotly figure.

Source code in src/synomicsbench/metrics/fidelity/UnivariateSimilarity.py
def get_visualization(self, plotly: bool = False, data_name: str = "", bins: int = 50):
    """
    Get a visualization of the column shape scores.

    Args:
        plotly (bool): Whether to use Plotly for visualization.

    Returns:
        Figure: Matplotlib or Plotly figure.
    """
    details_df = self.get_detail_df()
    if plotly:
        return self.column_shapes.get_visualization()
    else:
        return self.plot_column_score_histogram(details_df, data_name, bins)

plot_column_score_histogram(details, data_name='', bins=50) staticmethod

Plot a decorated histogram of column shape scores for a set of features.

Parameters:

- details (DataFrame): DataFrame containing at least a 'Score' column with numerical values. Required.
- data_name (str): Optional label for the data (for the title). Default: ''.
- bins (int): Number of bins for the histogram. Default: 50.

Returns:

- matplotlib.figure.Figure: Figure object for saving or further manipulation.

Raises:

- KeyError: If the 'Score' column is not present in the DataFrame.
- TypeError: If details is not a pandas DataFrame.

Source code in src/synomicsbench/metrics/fidelity/UnivariateSimilarity.py
@staticmethod
def plot_column_score_histogram(details, data_name: str = "", bins: int = 50):
    """
    Plot a decorated histogram of column shape scores for a set of features.

    Args:
        details (pd.DataFrame): DataFrame containing at least a 'Score' column with numerical values.
        data_name (str): Optional label for the data (for title).
        bins (int): Number of bins for the histogram.

    Returns:
        matplotlib.figure.Figure: Figure object for saving or further manipulation.

    Raises:
        KeyError: If the 'Score' column is not present in the DataFrame.
        TypeError: If details is not a pandas DataFrame.
    """
    if not isinstance(details, pd.DataFrame):
        raise TypeError("details must be a pandas DataFrame.")
    if 'Score' not in details.columns:
        raise KeyError("The 'Score' column must be present in the details DataFrame.")

    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))

    # Plot histogram only (no KDE)
    n, bins_, patches = ax.hist(
        details['Score'].dropna(),
        bins=bins,
        color='#4682B4',
        alpha=0.75,
        edgecolor='white',
        linewidth=1.2,
        label='Features'
    )

    # Add vertical lines for key statistics
    mean_score = details['Score'].mean()
    median_score = details['Score'].median()
    ax.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.2f}')
    ax.axvline(median_score, color='purple', linestyle='-.', linewidth=2, label=f'Median: {median_score:.2f}')

    # Customize ticks and grid
    ax.set_xticks(np.linspace(0, 1, 11))
    ax.tick_params(axis='x', labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.grid(axis='y', alpha=0.3)

    # Add labels and title with font size and weight
    ax.set_xlabel('Score', fontsize=14, fontweight='bold')
    ax.set_ylabel('Number of Features', fontsize=14, fontweight='bold')
    ax.set_title(f'Distribution of Column Shape Scores {data_name}', fontsize=16, fontweight='bold', pad=20)

    # Add legend
    ax.legend(fontsize=12, frameon=True)

    # Add annotation for total features
    ax.annotate(
        f"Total features: {details['Score'].dropna().shape[0]}",
        xy=(0.99, 0.95), xycoords='axes fraction',
        ha='right', va='top', fontsize=12, color='gray'
    )

    plt.tight_layout()
    return fig

summarize()

Summarize the scores by metric type.

Returns:

- pd.DataFrame: Grouped summary statistics by metric.

Source code in src/synomicsbench/metrics/fidelity/UnivariateSimilarity.py
def summarize(self):
    """
    Summarize the scores by metric type.

    Returns:
        pd.DataFrame: Grouped summary statistics by metric.
    """
    detail_df = self.get_detail_df()
    return detail_df.groupby('Metric')['Score'].describe()

PairwiseSimilarity

Compute and analyze pairwise similarity metrics (Pearson and contingency) between columns of original and synthetic datasets.

Parameters:

- original_data (DataFrame): Input original dataset. Required.
- synthetic_data (DataFrame): Input synthetic dataset. Required.
- metadata (dict): Column type metadata. Required.
- output_dir (str): Directory for logs and outputs. Default: ''.

Returns:

- PairwiseSimilarity: An initialized instance for similarity analysis.

Raises:

- OSError: If the output directory cannot be created.

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
class PairwiseSimilarity:
    """
    Compute and analyze pairwise similarity metrics (Spearman/Pearson for numeric column pairs and Cramér's V for pairs involving categorical columns) between columns of original and synthetic datasets.

    Args:
        original_data (pd.DataFrame): Input original dataset.
        synthetic_data (pd.DataFrame): Input synthetic dataset.
        metadata (dict): Column type metadata.
        output_dir (str): Directory for logs and outputs.

    Returns:
        PairwiseSimilarity: An initialized instance for similarity analysis.

    Raises:
        OSError: If the output directory cannot be created.
    """

    def __init__(self, original_data: pd.DataFrame, synthetic_data: pd.DataFrame, metadata: dict, output_dir: str="", verbose: bool=False, save: bool=True, name: str=''):
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        self.all_columns = original_data.columns.to_list()
        self.original_data = original_data
        self.synthetic_data = synthetic_data
        self.metadata = metadata
        if set(metadata.keys()) != set(original_data.columns):
            raise RuntimeError("metadata does not match columns in the data.")
        self.grouping_metadata = MetaData.grouping_features_astype(self.original_data, self.metadata)
        self.numerical_features = self.grouping_metadata.get("numerical")
        self.ordinal_features = self.grouping_metadata.get("ordinal_categorical") 
        self.dummy_features = self.grouping_metadata.get("dummy_categorical") 
        self.missing_indicators = self.grouping_metadata.get("missing_categorical")

        self.categorical_features = self.dummy_features + self.ordinal_features + self.missing_indicators
        self.numerical_indices = MetaData.get_column_indices(self.original_data, self.numerical_features)
        self.categorical_indices = MetaData.get_column_indices(self.original_data, self.categorical_features)
        self.verbose = verbose
        self.square_form_or_corr = None
        self.square_form_syn_corr = None
        self.square_form_score_matrix = None
        self.save = save
        self.name = name

    @staticmethod
    def save(path, condensed):
        # Note: __init__ stores the boolean flag as self.save, which shadows this
        # static method on instances; call it via the class, e.g.
        # PairwiseSimilarity.save(path, condensed).
        condensed = np.asarray(condensed)
        np.save(path, condensed)

    def _process(self, data: pd.DataFrame, metadata: dict):
        """
        Preprocess data by encoding categorical features and returning the processed DataFrame.

        Args:
            data (pd.DataFrame): Input data.
            metadata (dict): Metadata for grouping features.

        Returns:
            pd.DataFrame: Preprocessed DataFrame.

        Raises:
            ValueError: If no columns are specified for processing.
        """
        column_order = data.columns.to_list()
        self.grouping_metadata = MetaData.grouping_features_astype(data, metadata)
        num_features =  self.grouping_metadata.get("numerical")
        # print(num_features)
        ordinal_features = self.grouping_metadata.get("ordinal_categorical")
        # print(ordinal_features)
        dummy_features = self.grouping_metadata.get("dummy_categorical")
        # print(dummy_features)
        missing_indicators = self.grouping_metadata.get("missing_categorical")
        # print(missing_indicators)
        preprocessed_parts = []
        if ordinal_features: 
            ordinal_data = data[ordinal_features]
            ordinal_encoded, _ = DataProcessor.encode_ordinal_features(ordinal_data)
            preprocessed_parts.append(ordinal_encoded)

        if dummy_features:
            dummy_data = data[dummy_features]
            dummy_encoded,_ = DataProcessor.encode_ordinal_features(dummy_data)
            preprocessed_parts.append(dummy_encoded)

        if num_features:
            numerical_data = data[num_features]
            preprocessed_parts.append(numerical_data)

        if missing_indicators:
            missing_indicators_data = data[missing_indicators]
            preprocessed_parts.append(missing_indicators_data)

        if not preprocessed_parts:
            raise ValueError("No columns specified for processing")

        data_preprocessed = pd.concat(preprocessed_parts, axis=1)
        data_preprocessed = data_preprocessed[column_order]
        return data_preprocessed            

    def _calculate_corr_matrix(self, data: pd.DataFrame, method: str = "spearman", n_bins: int=10):
        """
        Calculate the mixed correlation matrix (Pearson, Spearman, Cramér's V) for the given data.

        Args:
            data (pd.DataFrame): Input data.
            method (str): Correlation method.
            n_bins (int): Number of bins for discretization.

        Returns:
            np.ndarray: Correlation matrix.
        """
        corr_matrix_computer = MixedCorrelation(self.categorical_indices, self.numerical_indices, method=method, n_bins=n_bins)
        corr_matrix = corr_matrix_computer.compute(data)
        return corr_matrix

    def _calculate_score_matrix(self, original_matrix: np.array, synthetic_matrix: np.array):
        """
        Calculate the similarity score matrix between original and synthetic data.
        Scoring rules (direction-aware):
            - numeric-numeric (Spearman/Pearson in [-1, 1]):
                score = 1 - 0.5 * |delta|
            - categorical-categorical OR categorical-(binned numeric) (Cramér's V in [0, 1]):
                score = 1 - |delta|
        Args:
            original_matrix (np.array): Correlation matrix for original data.
            synthetic_matrix (np.array): Correlation matrix for synthetic data.

        Returns:
            np.array: Similarity score matrix.

        Raises:
            ValueError: If calculation fails.
        """

        # try:
        #     score_matrix = 1 - 0.5 * np.abs(original_matrix - synthetic_matrix)
        #     return score_matrix
        # except Exception as e:
        #     raise ValueError(f"Error in calculating score matrix: {e}")

        if original_matrix.shape != synthetic_matrix.shape:
            raise ValueError("original_matrix and synthetic_matrix must have the same shape.")

        A = np.asarray(original_matrix, dtype=float)
        B = np.asarray(synthetic_matrix, dtype=float)
        n = A.shape[0]

        num_set = set(int(i) for i in self.numerical_indices)
        score = np.ones((n, n), dtype=float)

        for i in range(n):
            score[i, i] = 1.0
            for j in range(i + 1, n):
                if (i in num_set) and (j in num_set):
                    # Spearman/Pearson: [-1, 1] => max delta = 2
                    d = abs(A[i, j] - B[i, j])
                    s = 1.0 - 0.5 * d
                else:
                    # Cramér's V (including cat-num after discretization): [0, 1] => max delta = 1
                    d = abs(A[i, j] - B[i, j])
                    s = 1.0 - d

                score[i, j] = s
                score[j, i] = s

        np.clip(score, 0.0, 1.0, out=score)
        return score

    def print_condensed_matrix_summary(self, condensed_matrix: np.ndarray, name: str = "Condensed Matrix"):
        """
        Print statistical summary of a condensed matrix (1D array containing the upper triangle of a square matrix).

        Args:
            condensed_matrix (np.ndarray): 1D numpy array containing condensed matrix values.
            name (str): Name for display.

        Returns:
            None

        Raises:
            ValueError: If condensed_matrix is not a 1D numpy array.
        """
        if not isinstance(condensed_matrix, np.ndarray) or condensed_matrix.ndim != 1:
            raise ValueError("condensed_matrix must be a 1D numpy array.")
        print(f"Summary of {name}:")
        print(f"  Number of elements: {condensed_matrix.size}")
        print(f"  Min value: {np.min(condensed_matrix):.4f}")
        print(f"  Max value: {np.max(condensed_matrix):.4f}")
        print(f"  Mean: {np.mean(condensed_matrix):.4f}")
        print(f"  Median: {np.median(condensed_matrix):.4f}")
        print(f"  Standard deviation: {np.std(condensed_matrix):.4f}")

    @monitor_resources
    def get_pairwise_scores(self, method: str, n_bins: int = 10):
        """
        Compute the pairwise similarity matrices and print summary statistics.

        Args:
            method (str): Correlation method for MixedCorrelation.
            n_bins (int): Number of bins for discretization (default 10).

        Returns:
            dict: Dictionary containing pairwise score matrix, original condensed correlation, and synthetic condensed correlation.

        Raises:
            Exception: If any computation or extraction fails.
        """
        try:
            if self.verbose: 
                print("Processing data")
            processed_or_df = self._process(self.original_data, self.metadata)
            processed_syn_df = self._process(self.synthetic_data, self.metadata)

            if self.verbose:
                print(f"Calculating correlation matrix for both original data and synthetic data. This matrix is mixed between {method} correlation and Cramér's V.")

            or_corr_matrix = self._calculate_corr_matrix(processed_or_df, method, n_bins)
            syn_corr_matrix = self._calculate_corr_matrix(processed_syn_df, method, n_bins)

            self.square_form_or_corr = MixedCorrelation.get_squareform(or_corr_matrix)
            self.square_form_syn_corr = MixedCorrelation.get_squareform(syn_corr_matrix)

            # self.square_form_score_matrix = self._calculate_score_matrix(self.square_form_or_corr, self.square_form_syn_corr)
            # Compute score on square matrix, then condense
            score_square = self._calculate_score_matrix(or_corr_matrix, syn_corr_matrix)
            self.square_form_score_matrix = MixedCorrelation.get_squareform(score_square)

            # overall_score = np.mean(self.square_form_score_matrix)
            if self.verbose:
                print("\nSummary of Pairwise score:")
                self.print_condensed_matrix_summary(self.square_form_score_matrix, "Pairwise Score Matrix")
            if self.save:
                fig = PairwiseSimilarity.plot_column_score_histogram(self.square_form_score_matrix, 
                                                              data_name = self.name)
                fig_path = os.path.join(self.output_dir, f"{self.name}.png")
                fig.savefig(fig_path, dpi=300, bbox_inches='tight')
                plt.close(fig)

                # dict_path = os.path.join(self.output_dir, f"{self.name}_pairwisescore.npy")
                # PairwiseSimilarity.save(dict_path, self.square_form_score_matrix)

            dict_results = {
                "PairwiseScore": self.square_form_score_matrix,
                "OriginalCorrelation": self.square_form_or_corr,
                "SyntheticCorrelation": self.square_form_syn_corr
            }
            return dict_results
        except Exception as e:
            raise Exception(f"Error in get_scores: {e}")

    def get_single_associations(
        self,
        feature_1: Union[str, int],
        feature_2: Union[str, int],
        score_matrix: Optional[np.array] = None,
        original_matrix: Optional[np.array] = None,
        synthetic_matrix: Optional[np.array] = None
    ):
        """
        Retrieve detailed correlation and score between two features from the condensed correlation matrices.

        Args:
            feature_1 (Union[str, int]): First feature name or column index.
            feature_2 (Union[str, int]): Second feature name or column index.
            score_matrix (Optional[np.array]): Condensed score matrix.
            original_matrix (Optional[np.array]): Condensed correlation matrix for original data.
            synthetic_matrix (Optional[np.array]): Condensed correlation matrix for synthetic data.

        Returns:
            dict: Dictionary containing feature names, metric type, original correlation, synthetic correlation, and score.

        Raises:
            RuntimeError: If features are not in the data, or indices are out-of-bounds.
            ValueError: If correlation matrices are not provided.
        """
        # Convert feature names to indices if necessary
        if isinstance(feature_1, str):
            if feature_1 not in self.all_columns or feature_2 not in self.all_columns:
                raise RuntimeError(f"{feature_1} and {feature_2} are not in data")
            feature_1 = MetaData.get_column_indices(self.original_data, [feature_1])[0]
            feature_2 = MetaData.get_column_indices(self.original_data, [feature_2])[0]
        else:
            n_cols = self.original_data.shape[1]
            if feature_1 >= n_cols or feature_2 >= n_cols or feature_1 < 0 or feature_2 < 0:
                raise RuntimeError(f"Exceed the range. Data has {len(self.all_columns)} columns")

        # Use instance variables if parameters not provided
        if original_matrix is None:
            if self.square_form_or_corr is None:
                raise ValueError("Original correlation matrix is not provided.")
            original_matrix = self.square_form_or_corr
        if synthetic_matrix is None:
            if self.square_form_syn_corr is None:
                raise ValueError("Synthetic correlation matrix is not provided.")
            synthetic_matrix = self.square_form_syn_corr
        if score_matrix is None:
            if self.square_form_score_matrix is None:
                raise ValueError("Score matrix is not provided.")
            score_matrix = self.square_form_score_matrix

        n = self.original_data.shape[1]
        # Always make sure i < j for condensed index
        i, j = sorted([feature_1, feature_2])
        if i == j:
            or_corr = 1.0
            syn_corr = 1.0
            score = 1.0
        else:
            k = n * i - i * (i + 1) // 2 + (j - i - 1)
            or_corr = original_matrix[k]
            syn_corr = synthetic_matrix[k]
            score = score_matrix[k]

        metrics = "CramersV_Correlation" if (feature_1 in self.categorical_indices or feature_2 in self.categorical_indices) else "Spearman_Correlation"

        results = {
            "Feature_1": self.all_columns[feature_1],
            "Feature_2": self.all_columns[feature_2],
            "Metrics": metrics,
            "Original_Correlation": or_corr,
            "Synthetic_Correlation": syn_corr,
            "Score": score
        }
        return results


    @monitor_resources
    def get_multiple_associations(
        self,
        feature_pairs: List[Tuple[Union[str, int], Union[str, int]]],
        score_matrix: Optional[np.ndarray] = None,
        original_matrix: Optional[np.ndarray] = None,
        synthetic_matrix: Optional[np.ndarray] = None,
        condensed: bool = True,
    ) -> pd.DataFrame:
        """
        Extract association metrics for multiple feature pairs using vectorized indexing.

        This implementation minimizes Python-loop overhead by computing all indices at once and
        gathering values from correlation/score matrices via NumPy advanced indexing.

        Args:
            feature_pairs (List[Tuple[Union[str, int], Union[str, int]]]):
                List of (feature_1, feature_2) pairs specified by name or index.
            score_matrix (Optional[np.ndarray]):
                Score matrix. If `condensed=True`, must be 1D of length n*(n-1)/2; if `condensed=False`,
                must be square (n x n). If None, uses `self.square_form_score_matrix`.
            original_matrix (Optional[np.ndarray]):
                Original correlation matrix (condensed or square depending on `condensed`).
                If None, uses `self.square_form_or_corr`.
            synthetic_matrix (Optional[np.ndarray]):
                Synthetic correlation matrix (condensed or square depending on `condensed`).
                If None, uses `self.square_form_syn_corr`.
            condensed (bool):
                If True, matrices are interpreted as condensed upper-triangular vectors (no diagonal).
                If False, matrices are interpreted as square (n x n) arrays.

        Returns:
            pd.DataFrame: DataFrame with columns:
                - 'Feature_1'
                - 'Feature_2'
                - 'Metrics' ("CramersV_Correlation" if either feature is categorical, else "Spearman_Correlation")
                - 'Original_Correlation'
                - 'Synthetic_Correlation'
                - 'Score'

        Raises:
            RuntimeError:
                - If any feature name is not in the data.
                - If any feature index is out-of-bounds.
                - If any pair repeats the same feature (i == j) when `condensed=True`.
            ValueError:
                - If required matrices are not provided.
                - If provided matrix shapes are inconsistent with `condensed`.
        """
        # Resolve matrices
        if original_matrix is None:
            if self.square_form_or_corr is None:
                raise ValueError("Original correlation matrix is not provided.")
            original_matrix = self.square_form_or_corr
        if synthetic_matrix is None:
            if self.square_form_syn_corr is None:
                raise ValueError("Synthetic correlation matrix is not provided.")
            synthetic_matrix = self.square_form_syn_corr
        if score_matrix is None:
            if self.square_form_score_matrix is None:
                raise ValueError("Score matrix is not provided.")
            score_matrix = self.square_form_score_matrix

        # Validate shapes against number of columns
        n_cols = int(self.original_data.shape[1])
        if condensed:
            expected_len = n_cols * (n_cols - 1) // 2
            for name, arr in (
                ("original_matrix", original_matrix),
                ("synthetic_matrix", synthetic_matrix),
                ("score_matrix", score_matrix),
            ):
                if arr.ndim != 1:
                    raise ValueError(f"{name} must be 1D (condensed), got shape {arr.shape}.")
                if arr.size != expected_len:
                    raise ValueError(
                        f"{name} length {arr.size} != expected {expected_len} for n={n_cols}."
                    )
        else:
            for name, arr in (
                ("original_matrix", original_matrix),
                ("synthetic_matrix", synthetic_matrix),
                ("score_matrix", score_matrix),
            ):
                if arr.ndim != 2 or arr.shape != (n_cols, n_cols):
                    raise ValueError(
                        f"{name} must be square with shape ({n_cols}, {n_cols}), got {arr.shape}."
                    )

        # Resolve pairs to indices (vector-friendly via list -> arrays)
        name_to_idx = {name: idx for idx, name in enumerate(self.all_columns)}

        def resolve_one(x: Union[str, int]) -> int:
            if isinstance(x, str):
                if x not in name_to_idx:
                    raise RuntimeError(f"Feature '{x}' is not in data.")
                return name_to_idx[x]
            if isinstance(x, int):
                return x
            raise RuntimeError(f"Feature identifier must be str or int, got {type(x)}")

        idx_pairs = np.array([(resolve_one(a), resolve_one(b)) for a, b in feature_pairs], dtype=int)
        i_arr = idx_pairs[:, 0]
        j_arr = idx_pairs[:, 1]

        # Bounds check
        if (i_arr.min() < 0) or (j_arr.min() < 0) or (i_arr.max() >= n_cols) or (j_arr.max() >= n_cols):
            raise RuntimeError(f"Index out of bounds. Data has {len(self.all_columns)} columns.")

        # Disallow i == j for condensed form (undefined index on diagonal)
        if condensed and np.any(i_arr == j_arr):
            raise RuntimeError("Pairs with identical features (i == j) are not allowed in condensed form.")

        # For condensed, enforce i<j to index the upper triangle vector
        if condensed:
            imin = np.minimum(i_arr, j_arr)
            jmax = np.maximum(i_arr, j_arr)
            # condensed index formula (upper triangle, no diag), matches squareform
            # k = n*i - i*(i+1)//2 + (j - i - 1)
            k_arr = n_cols * imin - imin * (imin + 1) // 2 + (jmax - imin - 1)
            original_vals = original_matrix[k_arr]
            synthetic_vals = synthetic_matrix[k_arr]
            score_vals = score_matrix[k_arr]
            # Keep display names aligned to original input order (use original i_arr, j_arr)
            f1_idx, f2_idx = i_arr, j_arr
        else:
            # Square matrix direct advanced indexing
            original_vals = original_matrix[i_arr, j_arr]
            synthetic_vals = synthetic_matrix[i_arr, j_arr]
            score_vals = score_matrix[i_arr, j_arr]
            f1_idx, f2_idx = i_arr, j_arr

        # Determine metric labels vectorized
        categorical_set = np.array(self.categorical_indices, dtype=int)
        # Fast path when no categorical indices
        if categorical_set.size == 0:
            metrics = np.full(i_arr.shape[0], "Spearman_Correlation", dtype=object)
        else:
            is_cat = np.isin(f1_idx, categorical_set) | np.isin(f2_idx, categorical_set)
            metrics = np.where(is_cat, "CramersV_Correlation", "Spearman_Correlation")

        # Build DataFrame
        col_names = np.array(self.all_columns, dtype=object)
        df = pd.DataFrame({
            "Feature_1": col_names[f1_idx],
            "Feature_2": col_names[f2_idx],
            "Metrics": metrics,
            "Original_Correlation": original_vals,
            "Synthetic_Correlation": synthetic_vals,
            "Score": score_vals,
        })

        return df

    @staticmethod
    def summarize(results: pd.DataFrame):
        """
        Summarize pairwise similarity scores by metric type.

        Args:
            results (pd.DataFrame): DataFrame containing pairwise association results.

        Returns:
            pd.DataFrame: Summary statistics (count, mean, std, min, max, etc.) grouped by metric type.

        Raises:
            Exception: If summary calculation fails.
        """
        try:
            print("\nSummary by Metric Type:")
            summary = results.groupby('Metrics')['Score'].describe()
            return summary
        except Exception as e:
            raise Exception(f"Error in summarize: {e}")

    @staticmethod
    def plot_column_score_histogram(results: np.ndarray, data_name: str = "", bins: int = 50):
        """
        Plot a histogram of pairwise similarity scores.

        Args:
            results (np.ndarray): 1D numpy array of pairwise score values (NaN entries are dropped).
            data_name (str): Optional label for the data (for title).
            bins (int): Number of bins for the histogram.

        Returns:
            matplotlib.figure.Figure: Figure object for saving or further manipulation.

        Raises:
            TypeError: If results is not a numpy array.
        """
        if not isinstance(results,np.ndarray):
            raise TypeError("results must be a Numpy array.")
        # if 'Score' not in results.columns:
        #     raise KeyError("The 'Score' column must be present in the results DataFrame.")

        sns.set_theme(style="whitegrid")
        fig, ax = plt.subplots(figsize=(10, 6))
        results_no_nan = results[~np.isnan(results)]
        # Plot histogram only (no KDE)
        n, bins_, patches = ax.hist(
            results_no_nan,
            bins=bins,
            color='#4682B4',
            alpha=0.75,
            edgecolor='white',
            linewidth=1.2
        )

        # Add vertical lines for key statistics
        mean_score = np.mean(results_no_nan)
        median_score = np.median(results_no_nan)
        ax.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.2f}')
        ax.axvline(median_score, color='purple', linestyle='-.', linewidth=2, label=f'Median: {median_score:.2f}')

        # Customize ticks and grid
        ax.set_xticks(np.linspace(0, 1, 11))
        ax.tick_params(axis='x', labelsize=12)
        ax.tick_params(axis='y', labelsize=12)
        ax.grid(axis='y', alpha=0.3)

        # Add labels and title with font size and weight
        ax.set_xlabel('Score', fontsize=14, fontweight='bold')
        ax.set_ylabel('Number of Features', fontsize=14, fontweight='bold')
        ax.set_title(f'Distribution of PairWise Scores {data_name}', fontsize=16, fontweight='bold', pad=20)

        # Add legend
        ax.legend(fontsize=12, frameon=True)

        # Add annotation for total features
        ax.annotate(
            f"Total associations: {results_no_nan.shape[0]}",
            xy=(0.01, 0.05), xycoords='axes fraction',
            ha='left', va='bottom', fontsize=12, color='gray'
        )

        plt.tight_layout()
        return fig

    @monitor_resources
    def get_cross_group_associations(
        self,
        group_1_features: List[str],
        group_2_features: List[str],
        score_matrix: Optional[np.ndarray] = None,
        original_matrix: Optional[np.ndarray] = None,
        synthetic_matrix: Optional[np.ndarray] = None,
        condensed: bool = True,
    ) -> pd.DataFrame:
        """
        Calculate bivariate scores ONLY between two groups of features (e.g., clinical vs transcriptomics).

        Args:
            group_1_features (List[str]): List of feature names in first group (e.g., clinical).
            group_2_features (List[str]): List of feature names in second group (e.g., transcriptomics).
            score_matrix (Optional[np.ndarray]): Score matrix (condensed or square).
            original_matrix (Optional[np.ndarray]): Original correlation matrix.
            synthetic_matrix (Optional[np.ndarray]): Synthetic correlation matrix.
            condensed (bool): Whether matrices are in condensed form.

        Returns:
            pd.DataFrame: DataFrame with cross-group associations only.

        Raises:
            ValueError: If any feature name is not found in the data.
        """
        # Validate feature names
        missing_g1 = set(group_1_features) - set(self.all_columns)
        missing_g2 = set(group_2_features) - set(self.all_columns)

        if missing_g1:
            raise ValueError(f"Features not found in group 1: {missing_g1}")
        if missing_g2:
            raise ValueError(f"Features not found in group 2: {missing_g2}")

        # Check for overlap
        overlap = set(group_1_features) & set(group_2_features)
        if overlap:
            raise ValueError(f"Features cannot be in both groups: {overlap}")

        # Generate all cross-group pairs (Cartesian product)
        feature_pairs = [
            (f1, f2) 
            for f1 in group_1_features 
            for f2 in group_2_features
        ]

        if self.verbose:
            print(f"Computing {len(feature_pairs)} cross-group associations:")
            print(f"  Group 1 ({len(group_1_features)} features): {group_1_features[:3]}...")
            print(f"  Group 2 ({len(group_2_features)} features): {group_2_features[:3]}...")

        # Use existing method to compute associations
        results_df = self.get_multiple_associations(
            feature_pairs=feature_pairs,
            score_matrix=score_matrix,
            original_matrix=original_matrix,
            synthetic_matrix=synthetic_matrix,
            condensed=condensed
        )

        if self.verbose:
            print("\nCross-group association summary:")
            print(f"  Total pairs: {len(results_df)}")
            print(f"  Mean score: {results_df['Score'].mean():.4f}")
            print(f"  Median score: {results_df['Score'].median():.4f}")

        return results_df

get_cross_group_associations(group_1_features, group_2_features, score_matrix=None, original_matrix=None, synthetic_matrix=None, condensed=True)

Calculate bivariate scores ONLY between two groups of features (e.g., clinical vs transcriptomics).

Parameters:
  • group_1_features (List[str]): List of feature names in first group (e.g., clinical). (required)
  • group_2_features (List[str]): List of feature names in second group (e.g., transcriptomics). (required)
  • score_matrix (Optional[ndarray]): Score matrix (condensed or square). (default: None)
  • original_matrix (Optional[ndarray]): Original correlation matrix. (default: None)
  • synthetic_matrix (Optional[ndarray]): Synthetic correlation matrix. (default: None)
  • condensed (bool): Whether matrices are in condensed form. (default: True)

Returns:
  pd.DataFrame: DataFrame with cross-group associations only.

Raises:
  ValueError: If any feature name is not found in the data.
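
A minimal sketch (column names are illustrative; get_pairwise_scores must have been called first so the condensed matrices are cached on the instance, or the matrices must be passed in explicitly):

ps.get_pairwise_scores(method="spearman")
cross_df = ps.get_cross_group_associations(
    group_1_features=["age", "tumor_stage"],    # e.g., clinical columns
    group_2_features=["GENE_A", "GENE_B"],      # e.g., transcriptomic columns
)
print(cross_df[["Feature_1", "Feature_2", "Metrics", "Score"]].head())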

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
@monitor_resources
def get_cross_group_associations(
    self,
    group_1_features: List[str],
    group_2_features: List[str],
    score_matrix: Optional[np.ndarray] = None,
    original_matrix: Optional[np.ndarray] = None,
    synthetic_matrix: Optional[np.ndarray] = None,
    condensed: bool = True,
) -> pd.DataFrame:
    """
    Calculate bivariate scores ONLY between two groups of features (e.g., clinical vs transcriptomics).

    Args:
        group_1_features (List[str]): List of feature names in first group (e.g., clinical).
        group_2_features (List[str]): List of feature names in second group (e.g., transcriptomics).
        score_matrix (Optional[np.ndarray]): Score matrix (condensed or square).
        original_matrix (Optional[np.ndarray]): Original correlation matrix.
        synthetic_matrix (Optional[np.ndarray]): Synthetic correlation matrix.
        condensed (bool): Whether matrices are in condensed form.

    Returns:
        pd.DataFrame: DataFrame with cross-group associations only.

    Raises:
        ValueError: If any feature name is not found in the data.
    """
    # Validate feature names
    missing_g1 = set(group_1_features) - set(self.all_columns)
    missing_g2 = set(group_2_features) - set(self.all_columns)

    if missing_g1:
        raise ValueError(f"Features not found in group 1: {missing_g1}")
    if missing_g2:
        raise ValueError(f"Features not found in group 2: {missing_g2}")

    # Check for overlap
    overlap = set(group_1_features) & set(group_2_features)
    if overlap:
        raise ValueError(f"Features cannot be in both groups: {overlap}")

    # Generate all cross-group pairs (Cartesian product)
    feature_pairs = [
        (f1, f2) 
        for f1 in group_1_features 
        for f2 in group_2_features
    ]

    if self.verbose:
        print(f"Computing {len(feature_pairs)} cross-group associations:")
        print(f"  Group 1 ({len(group_1_features)} features): {group_1_features[:3]}...")
        print(f"  Group 2 ({len(group_2_features)} features): {group_2_features[:3]}...")

    # Use existing method to compute associations
    results_df = self.get_multiple_associations(
        feature_pairs=feature_pairs,
        score_matrix=score_matrix,
        original_matrix=original_matrix,
        synthetic_matrix=synthetic_matrix,
        condensed=condensed
    )

    if self.verbose:
        print("\nCross-group association summary:")
        print(f"  Total pairs: {len(results_df)}")
        print(f"  Mean score: {results_df['Score'].mean():.4f}")
        print(f"  Median score: {results_df['Score'].median():.4f}")

    return results_df

get_multiple_associations(feature_pairs, score_matrix=None, original_matrix=None, synthetic_matrix=None, condensed=True)

Extract association metrics for multiple feature pairs using vectorized indexing.

This implementation minimizes Python-loop overhead by computing all indices at once and gathering values from correlation/score matrices via NumPy advanced indexing.

Parameters:
  • feature_pairs (List[Tuple[Union[str, int], Union[str, int]]]): List of (feature_1, feature_2) pairs specified by name or index. (required)
  • score_matrix (Optional[ndarray]): Score matrix. If condensed=True, must be 1D of length n*(n-1)/2; if condensed=False, must be square (n x n). If None, uses self.square_form_score_matrix. (default: None)
  • original_matrix (Optional[ndarray]): Original correlation matrix (condensed or square depending on condensed). If None, uses self.square_form_or_corr. (default: None)
  • synthetic_matrix (Optional[ndarray]): Synthetic correlation matrix (condensed or square depending on condensed). If None, uses self.square_form_syn_corr. (default: None)
  • condensed (bool): If True, matrices are interpreted as condensed upper-triangular vectors (no diagonal). If False, matrices are interpreted as square (n x n) arrays. (default: True)

Returns:
  pd.DataFrame: DataFrame with columns 'Feature_1', 'Feature_2', 'Metrics' ("CramersV_Correlation" if either feature is categorical, else "Spearman_Correlation"), 'Original_Correlation', 'Synthetic_Correlation', and 'Score'.

Raises:
  RuntimeError:
    • If any feature name is not in the data.
    • If any feature index is out-of-bounds.
    • If any pair repeats the same feature (i == j) when condensed=True.
  ValueError:
    • If required matrices are not provided.
    • If provided matrix shapes are inconsistent with condensed.
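
A minimal sketch (the pair entries are illustrative; feature names and integer column indices can be mixed, and the condensed matrices cached by get_pairwise_scores are used when no matrices are passed):

pairs = [("age", "GENE_A"), (0, 5), (2, 7)]
assoc_df = ps.get_multiple_associations(feature_pairs=pairs, condensed=True)
print(assoc_df.head())
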
Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
@monitor_resources
def get_multiple_associations(
    self,
    feature_pairs: List[Tuple[Union[str, int], Union[str, int]]],
    score_matrix: Optional[np.ndarray] = None,
    original_matrix: Optional[np.ndarray] = None,
    synthetic_matrix: Optional[np.ndarray] = None,
    condensed: bool = True,
) -> pd.DataFrame:
    """
    Extract association metrics for multiple feature pairs using vectorized indexing.

    This implementation minimizes Python-loop overhead by computing all indices at once and
    gathering values from correlation/score matrices via NumPy advanced indexing.

    Args:
        feature_pairs (List[Tuple[Union[str, int], Union[str, int]]]):
            List of (feature_1, feature_2) pairs specified by name or index.
        score_matrix (Optional[np.ndarray]):
            Score matrix. If `condensed=True`, must be 1D of length n*(n-1)/2; if `condensed=False`,
            must be square (n x n). If None, uses `self.square_form_score_matrix`.
        original_matrix (Optional[np.ndarray]):
            Original correlation matrix (condensed or square depending on `condensed`).
            If None, uses `self.square_form_or_corr`.
        synthetic_matrix (Optional[np.ndarray]):
            Synthetic correlation matrix (condensed or square depending on `condensed`).
            If None, uses `self.square_form_syn_corr`.
        condensed (bool):
            If True, matrices are interpreted as condensed upper-triangular vectors (no diagonal).
            If False, matrices are interpreted as square (n x n) arrays.

    Returns:
        pd.DataFrame: DataFrame with columns:
            - 'Feature_1'
            - 'Feature_2'
            - 'Metrics' ("CramersV_Correlation" if either feature is categorical, else "Spearman_Correlation")
            - 'Original_Correlation'
            - 'Synthetic_Correlation'
            - 'Score'

    Raises:
        RuntimeError:
            - If any feature name is not in the data.
            - If any feature index is out-of-bounds.
            - If any pair repeats the same feature (i == j) when `condensed=True`.
        ValueError:
            - If required matrices are not provided.
            - If provided matrix shapes are inconsistent with `condensed`.
    """
    # Resolve matrices
    if original_matrix is None:
        if self.square_form_or_corr is None:
            raise ValueError("Original correlation matrix is not provided.")
        original_matrix = self.square_form_or_corr
    if synthetic_matrix is None:
        if self.square_form_syn_corr is None:
            raise ValueError("Synthetic correlation matrix is not provided.")
        synthetic_matrix = self.square_form_syn_corr
    if score_matrix is None:
        if self.square_form_score_matrix is None:
            raise ValueError("Score matrix is not provided.")
        score_matrix = self.square_form_score_matrix

    # Validate shapes against number of columns
    n_cols = int(self.original_data.shape[1])
    if condensed:
        expected_len = n_cols * (n_cols - 1) // 2
        for name, arr in (
            ("original_matrix", original_matrix),
            ("synthetic_matrix", synthetic_matrix),
            ("score_matrix", score_matrix),
        ):
            if arr.ndim != 1:
                raise ValueError(f"{name} must be 1D (condensed), got shape {arr.shape}.")
            if arr.size != expected_len:
                raise ValueError(
                    f"{name} length {arr.size} != expected {expected_len} for n={n_cols}."
                )
    else:
        for name, arr in (
            ("original_matrix", original_matrix),
            ("synthetic_matrix", synthetic_matrix),
            ("score_matrix", score_matrix),
        ):
            if arr.ndim != 2 or arr.shape != (n_cols, n_cols):
                raise ValueError(
                    f"{name} must be square with shape ({n_cols}, {n_cols}), got {arr.shape}."
                )

    # Resolve pairs to indices (vector-friendly via list -> arrays)
    name_to_idx = {name: idx for idx, name in enumerate(self.all_columns)}

    def resolve_one(x: Union[str, int]) -> int:
        if isinstance(x, str):
            if x not in name_to_idx:
                raise RuntimeError(f"Feature '{x}' is not in data.")
            return name_to_idx[x]
        if isinstance(x, int):
            return x
        raise RuntimeError(f"Feature identifier must be str or int, got {type(x)}")

    idx_pairs = np.array([(resolve_one(a), resolve_one(b)) for a, b in feature_pairs], dtype=int)
    i_arr = idx_pairs[:, 0]
    j_arr = idx_pairs[:, 1]

    # Bounds check
    if (i_arr.min() < 0) or (j_arr.min() < 0) or (i_arr.max() >= n_cols) or (j_arr.max() >= n_cols):
        raise RuntimeError(f"Index out of bounds. Data has {len(self.all_columns)} columns.")

    # Disallow i == j for condensed form (undefined index on diagonal)
    if condensed and np.any(i_arr == j_arr):
        raise RuntimeError("Pairs with identical features (i == j) are not allowed in condensed form.")

    # For condensed, enforce i<j to index the upper triangle vector
    if condensed:
        imin = np.minimum(i_arr, j_arr)
        jmax = np.maximum(i_arr, j_arr)
        # condensed index formula (upper triangle, no diag), matches squareform
        # k = n*i - i*(i+1)//2 + (j - i - 1)
        k_arr = n_cols * imin - imin * (imin + 1) // 2 + (jmax - imin - 1)
        original_vals = original_matrix[k_arr]
        synthetic_vals = synthetic_matrix[k_arr]
        score_vals = score_matrix[k_arr]
        # Keep display names aligned to original input order (use original i_arr, j_arr)
        f1_idx, f2_idx = i_arr, j_arr
    else:
        # Square matrix direct advanced indexing
        original_vals = original_matrix[i_arr, j_arr]
        synthetic_vals = synthetic_matrix[i_arr, j_arr]
        score_vals = score_matrix[i_arr, j_arr]
        f1_idx, f2_idx = i_arr, j_arr

    # Determine metric labels vectorized
    categorical_set = np.array(self.categorical_indices, dtype=int)
    # Fast path when no categorical indices
    if categorical_set.size == 0:
        metrics = np.full(i_arr.shape[0], "Spearman_Correlation", dtype=object)
    else:
        is_cat = np.isin(f1_idx, categorical_set) | np.isin(f2_idx, categorical_set)
        metrics = np.where(is_cat, "CramersV_Correlation", "Spearman_Correlation")

    # Build DataFrame
    col_names = np.array(self.all_columns, dtype=object)
    df = pd.DataFrame({
        "Feature_1": col_names[f1_idx],
        "Feature_2": col_names[f2_idx],
        "Metrics": metrics,
        "Original_Correlation": original_vals,
        "Synthetic_Correlation": synthetic_vals,
        "Score": score_vals,
    })

    return df

get_pairwise_scores(method, n_bins=10)

Compute the pairwise similarity matrices and print summary statistics.

Parameters:
  • method (str): Correlation method for MixedCorrelation. (required)
  • n_bins (int): Number of bins for discretization. (default: 10)

Returns:
  dict: Dictionary containing pairwise score matrix, original condensed correlation, and synthetic condensed correlation.

Raises:
  Exception: If any computation or extraction fails.
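
A minimal sketch of working with the returned condensed vectors (this assumes they follow SciPy's condensed squareform ordering, which matches the index formula used by get_single_associations):

import numpy as np
from scipy.spatial.distance import squareform

res = ps.get_pairwise_scores(method="spearman")
score_square = squareform(res["PairwiseScore"])   # square matrix with zeros on the diagonal
np.fill_diagonal(score_square, 1.0)               # a feature paired with itself scores 1.0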

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
@monitor_resources
def get_pairwise_scores(self, method: str, n_bins: int = 10):
    """
    Compute the pairwise similarity matrices and print summary statistics.

    Args:
        method (str): Correlation method for MixedCorrelation.
        n_bins (int): Number of bins for discretization (default 10).

    Returns:
        dict: Dictionary containing pairwise score matrix, original condensed correlation, and synthetic condensed correlation.

    Raises:
        Exception: If any computation or extraction fails.
    """
    try:
        if self.verbose: 
            print("Processing data")
        processed_or_df = self._process(self.original_data, self.metadata)
        processed_syn_df = self._process(self.synthetic_data, self.metadata)

        if self.verbose:
            print(f"Calculating correlation matrix for both original data and synthetic data. This matrix is mixed between {method} correlation and Cramér's V.")

        or_corr_matrix = self._calculate_corr_matrix(processed_or_df, method, n_bins)
        syn_corr_matrix = self._calculate_corr_matrix(processed_syn_df, method, n_bins)

        self.square_form_or_corr = MixedCorrelation.get_squareform(or_corr_matrix)
        self.square_form_syn_corr = MixedCorrelation.get_squareform(syn_corr_matrix)

        # self.square_form_score_matrix = self._calculate_score_matrix(self.square_form_or_corr, self.square_form_syn_corr)
        # Compute score on square matrix, then condense
        score_square = self._calculate_score_matrix(or_corr_matrix, syn_corr_matrix)
        self.square_form_score_matrix = MixedCorrelation.get_squareform(score_square)

        # overall_score = np.mean(self.square_form_score_matrix)
        if self.verbose:
            print("\nSummary of Pairwise score:")
            self.print_condensed_matrix_summary(self.square_form_score_matrix, "Pairwise Score Matrix")
        if self.save:
            fig = PairwiseSimilarity.plot_column_score_histogram(self.square_form_score_matrix, 
                                                          data_name = self.name)
            fig_path = os.path.join(self.output_dir, f"{self.name}.png")
            fig.savefig(fig_path, dpi=300, bbox_inches='tight')
            plt.close(fig)

            # dict_path = os.path.join(self.output_dir, f"{self.name}_pairwisescore.npy")
            # PairwiseSimilarity.save(dict_path, self.square_form_score_matrix)

        dict_results = {
            "PairwiseScore": self.square_form_score_matrix,
            "OriginalCorrelation": self.square_form_or_corr,
            "SyntheticCorrelation": self.square_form_syn_corr
        }
        return dict_results
    except Exception as e:
        raise Exception(f"Error in get_scores: {e}")

get_single_associations(feature_1, feature_2, score_matrix=None, original_matrix=None, synthetic_matrix=None)

Retrieve detailed correlation and score between two features from the condensed correlation matrices.

Parameters:
  • feature_1 (Union[str, int]): First feature name or column index. (required)
  • feature_2 (Union[str, int]): Second feature name or column index. (required)
  • score_matrix (Optional[array]): Condensed score matrix. (default: None)
  • original_matrix (Optional[array]): Condensed correlation matrix for original data. (default: None)
  • synthetic_matrix (Optional[array]): Condensed correlation matrix for synthetic data. (default: None)

Returns:
  dict: Dictionary containing feature names, metric type, original correlation, synthetic correlation, and score.

Raises:
  RuntimeError: If features are not in the data, or indices are out-of-bounds.
  ValueError: If correlation matrices are not provided.
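
A minimal sketch (feature identifiers are illustrative; get_pairwise_scores must have been called first, or the condensed matrices passed in explicitly):

info = ps.get_single_associations("age", "GENE_A")
print(info["Metrics"], info["Original_Correlation"], info["Synthetic_Correlation"], info["Score"])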

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
def get_single_associations(
    self,
    feature_1: Union[str, int],
    feature_2: Union[str, int],
    score_matrix: Optional[np.array] = None,
    original_matrix: Optional[np.array] = None,
    synthetic_matrix: Optional[np.array] = None
):
    """
    Retrieve detailed correlation and score between two features from the condensed correlation matrices.

    Args:
        feature_1 (Union[str, int]): First feature name or column index.
        feature_2 (Union[str, int]): Second feature name or column index.
        score_matrix (Optional[np.array]): Condensed score matrix.
        original_matrix (Optional[np.array]): Condensed correlation matrix for original data.
        synthetic_matrix (Optional[np.array]): Condensed correlation matrix for synthetic data.

    Returns:
        dict: Dictionary containing feature names, metric type, original correlation, synthetic correlation, and score.

    Raises:
        RuntimeError: If features are not in the data, or indices are out-of-bounds.
        ValueError: If correlation matrices are not provided.
    """
    # Convert feature names to indices if necessary
    if isinstance(feature_1, str):
        if feature_1 not in self.all_columns or feature_2 not in self.all_columns:
            raise RuntimeError(f"{feature_1} and {feature_2} are not in data")
        feature_1 = MetaData.get_column_indices(self.original_data, [feature_1])[0]
        feature_2 = MetaData.get_column_indices(self.original_data, [feature_2])[0]
    else:
        n_cols = self.original_data.shape[1]
        if feature_1 >= n_cols or feature_2 >= n_cols or feature_1 < 0 or feature_2 < 0:
            raise RuntimeError(f"Exceed the range. Data has {len(self.all_columns)} columns")

    # Use instance variables if parameters not provided
    if original_matrix is None:
        if self.square_form_or_corr is None:
            raise ValueError("Original correlation matrix is not provided.")
        original_matrix = self.square_form_or_corr
    if synthetic_matrix is None:
        if self.square_form_syn_corr is None:
            raise ValueError("Synthetic correlation matrix is not provided.")
        synthetic_matrix = self.square_form_syn_corr
    if score_matrix is None:
        if self.square_form_score_matrix is None:
            raise ValueError("Score matrix is not provided.")
        score_matrix = self.square_form_score_matrix

    n = self.original_data.shape[1]
    # Always make sure i < j for condensed index
    i, j = sorted([feature_1, feature_2])
    if i == j:
        or_corr = 1.0
        syn_corr = 1.0
        score = 1.0
    else:
        k = n * i - i * (i + 1) // 2 + (j - i - 1)
        or_corr = original_matrix[k]
        syn_corr = synthetic_matrix[k]
        score = score_matrix[k]

    metrics = "CramersV_Correlation" if (feature_1 in self.categorical_indices or feature_2 in self.categorical_indices) else "Spearman_Correlation"

    results = {
        "Feature_1": self.all_columns[feature_1],
        "Feature_2": self.all_columns[feature_2],
        "Metrics": metrics,
        "Original_Correlation": or_corr,
        "Synthetic_Correlation": syn_corr,
        "Score": score
    }
    return results

plot_column_score_histogram(results, data_name='', bins=50) staticmethod

Plot a histogram of pairwise similarity scores.

Parameters:
  • results (ndarray): 1D numpy array of pairwise score values (NaN entries are dropped). (required)
  • data_name (str): Optional label for the data (for title). (default: '')
  • bins (int): Number of bins for the histogram. (default: 50)

Returns:
  matplotlib.figure.Figure: Figure object for saving or further manipulation.

Raises:
  TypeError: If results is not a numpy array.
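
A minimal sketch (assuming `res` is the dict returned by get_pairwise_scores; the file name is illustrative, and the same figure is written automatically when the instance was created with save=True):

fig = PairwiseSimilarity.plot_column_score_histogram(res["PairwiseScore"], data_name="experiment_1")
fig.savefig("pairwise_scores.png", dpi=300, bbox_inches="tight")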

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
@staticmethod
def plot_column_score_histogram(results: np.ndarray, data_name: str = "", bins: int = 50):
    """
    Plot a histogram of pairwise similarity scores.

    Args:
        results (np.ndarray): 1D numpy array of pairwise score values (NaN entries are dropped).
        data_name (str): Optional label for the data (for title).
        bins (int): Number of bins for the histogram.

    Returns:
        matplotlib.figure.Figure: Figure object for saving or further manipulation.

    Raises:
        TypeError: If results is not a numpy array.
    """
    if not isinstance(results,np.ndarray):
        raise TypeError("results must be a Numpy array.")
    # if 'Score' not in results.columns:
    #     raise KeyError("The 'Score' column must be present in the results DataFrame.")

    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(10, 6))
    results_no_nan = results[~np.isnan(results)]
    # Plot histogram only (no KDE)
    n, bins_, patches = ax.hist(
        results_no_nan,
        bins=bins,
        color='#4682B4',
        alpha=0.75,
        edgecolor='white',
        linewidth=1.2
    )

    # Add vertical lines for key statistics
    mean_score = np.mean(results_no_nan)
    median_score = np.median(results_no_nan)
    ax.axvline(mean_score, color='green', linestyle='--', linewidth=2, label=f'Mean: {mean_score:.2f}')
    ax.axvline(median_score, color='purple', linestyle='-.', linewidth=2, label=f'Median: {median_score:.2f}')

    # Customize ticks and grid
    ax.set_xticks(np.linspace(0, 1, 11))
    ax.tick_params(axis='x', labelsize=12)
    ax.tick_params(axis='y', labelsize=12)
    ax.grid(axis='y', alpha=0.3)

    # Add labels and title with font size and weight
    ax.set_xlabel('Score', fontsize=14, fontweight='bold')
    ax.set_ylabel('Number of Features', fontsize=14, fontweight='bold')
    ax.set_title(f'Distribution of PairWise Scores {data_name}', fontsize=16, fontweight='bold', pad=20)

    # Add legend
    ax.legend(fontsize=12, frameon=True)

    # Add annotation for total features
    ax.annotate(
        f"Total associations: {results_no_nan.shape[0]}",
        xy=(0.01, 0.05), xycoords='axes fraction',
        ha='left', va='bottom', fontsize=12, color='gray'
    )

    plt.tight_layout()
    return fig

print_condensed_matrix_summary(condensed_matrix, name='Condensed Matrix')

Print statistical summary of a condensed matrix (1D array containing the upper triangle of a square matrix).

Parameters:
  • condensed_matrix (ndarray): 1D numpy array containing condensed matrix values. (required)
  • name (str): Name for display. (default: 'Condensed Matrix')

Returns:
  None

Raises:
  ValueError: If condensed_matrix is not a 1D numpy array.
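
A minimal sketch (assuming `res` is the dict returned by get_pairwise_scores; prints the element count, min, max, mean, median, and standard deviation of the condensed values):

ps.print_condensed_matrix_summary(res["PairwiseScore"], name="Pairwise Score Matrix")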

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py
def print_condensed_matrix_summary(self, condensed_matrix: np.ndarray, name: str = "Condensed Matrix"):
    """
    Print statistical summary of a condensed matrix (1D array containing the upper triangle of a square matrix).

    Args:
        condensed_matrix (np.ndarray): 1D numpy array containing condensed matrix values.
        name (str): Name for display.

    Returns:
        None

    Raises:
        ValueError: If condensed_matrix is not a 1D numpy array.
    """
    if not isinstance(condensed_matrix, np.ndarray) or condensed_matrix.ndim != 1:
        raise ValueError("condensed_matrix must be a 1D numpy array.")
    print(f"Summary of {name}:")
    print(f"  Number of elements: {condensed_matrix.size}")
    print(f"  Min value: {np.min(condensed_matrix):.4f}")
    print(f"  Max value: {np.max(condensed_matrix):.4f}")
    print(f"  Mean: {np.mean(condensed_matrix):.4f}")
    print(f"  Median: {np.median(condensed_matrix):.4f}")
    print(f"  Standard deviation: {np.std(condensed_matrix):.4f}")

summarize(results) staticmethod

Summarize pairwise similarity scores by metric type.

Parameters:
  • results (DataFrame): DataFrame containing pairwise association results. (required)

Returns:
  pd.DataFrame: Summary statistics (count, mean, std, min, max, etc.) grouped by metric type.

Raises:
  Exception: If summary calculation fails.

Source code in src/synomicsbench/metrics/fidelity/PairwiseSimilarity.py, lines 490-509
@staticmethod
def summarize(results: pd.DataFrame):
    """
    Summarize pairwise similarity scores by metric type.

    Args:
        results (pd.DataFrame): DataFrame containing pairwise association results.

    Returns:
        pd.DataFrame: Summary statistics (count, mean, std, min, max, etc.) grouped by metric type.

    Raises:
        Exception: If summary calculation fails.
    """
    try:
        print("\nSummary by Metric Type:")
        summary = results.groupby('Metrics')['Score'].describe()
        return summary
    except Exception as e:
        raise Exception(f"Error in summarize: {e}")

MissingValueSimilarity

MissingValue_Similarity(origin_data, synthetic_data, missing_indicators)

Compute missing value similarity scores for each target column.

Parameters:
    origin_data (DataFrame, required): Original data.
    synthetic_data (DataFrame, required): Synthetic data.
    missing_indicators (list, required): List of column names with missing indicators.

Returns:
    dict: Mapping from column name to similarity score.

Raises:
    Exception: If computation fails for any column.

Source code in src/synomicsbench/metrics/fidelity/MissingValueSimilarity.py, lines 4-32
def MissingValue_Similarity(
    origin_data: pd.DataFrame, synthetic_data: pd.DataFrame, missing_indicators: list
):
    """
    Compute missing value similarity scores for each target column.

    Args:
        origin_data (pd.DataFrame): Original data.
        synthetic_data (pd.DataFrame): Synthetic data.
        missing_indicators (list): List of column names with missing indicators.

    Returns:
        dict: Mapping from column name to similarity score.

    Raises:
        Exception: If computation fails for any column.
    """
    MissingValue_dict = {}
    target_cols = [col.replace("missingindicator_", "") for col in missing_indicators]
    for target_col in target_cols:
        try:
            simi_score = MissingValueSimilarity.compute(
                real_data=origin_data[target_col], synthetic_data=synthetic_data[target_col]
            )
            MissingValue_dict[target_col] = simi_score
        except Exception:
            # Continue with the remaining columns; record None for columns that fail
            MissingValue_dict[target_col] = None
    return MissingValue_dict
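
A minimal sketch of how the indicator naming convention is used (column names and values are illustrative; the import path is assumed from the source location above). The "missingindicator_" prefix is stripped to recover each target column:

import pandas as pd
from synomicsbench.metrics.fidelity.MissingValueSimilarity import MissingValue_Similarity  # path assumed

origin = pd.DataFrame({"protein_A": [1.2, None, 0.8], "protein_B": [0.5, 0.6, None]})
synthetic = pd.DataFrame({"protein_A": [1.1, 0.9, None], "protein_B": [None, 0.7, 0.4]})

scores = MissingValue_Similarity(
    origin_data=origin,
    synthetic_data=synthetic,
    missing_indicators=["missingindicator_protein_A", "missingindicator_protein_B"],
)
print(scores)   # {"protein_A": <score or None>, "protein_B": <score or None>}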

BayesianComparison

BaycompStyle dataclass

Store manuscript-style visualization settings for baycomp heatmaps.

Parameters:
    fontsize (int, default 11): Base font size used in matplotlib rcParams.
    nature_font (Dict[str, Sequence[str]], default dict(_FONT)): Font family configuration.
    plot_colors (Dict[str, str], default dict(STABILITY_PLOT_COLORS)): Common background/grid colors.
    pbetter_cmap (LinearSegmentedColormap, default PBETTER_FOCUS_CMAP): Colormap for P(Better) heatmaps.
    cancer_colors (Dict[str, str], default dict(CANCER_COLORS)): Color strip mapping for each cancer panel.

Raises:
    ValueError: If fontsize is not positive.

Source code in src/synomicsbench/metrics/fidelity/BayesianComparison.py, lines 78-101
@dataclass(frozen=True)
class BaycompStyle:
    """
    Store manuscript-style visualization settings for baycomp heatmaps.

    Args:
        fontsize (int): Base font size used in matplotlib rcParams.
        nature_font (Dict[str, Sequence[str]]): Font family configuration.
        plot_colors (Dict[str, str]): Common background/grid colors.
        pbetter_cmap (LinearSegmentedColormap): Colormap for P(Better) heatmaps.
        cancer_colors (Dict[str, str]): Color strip mapping for each cancer panel.

    Raises:
        ValueError: If fontsize is not positive.
    """
    fontsize: int = 11
    nature_font: Dict[str, Sequence[str]] = field(default_factory=lambda: dict(_FONT))
    plot_colors: Dict[str, str] = field(default_factory=lambda: dict(STABILITY_PLOT_COLORS))
    pbetter_cmap: LinearSegmentedColormap = field(default_factory=lambda: PBETTER_FOCUS_CMAP)
    cancer_colors: Dict[str, str] = field(default_factory=lambda: dict(CANCER_COLORS))

    def __post_init__(self) -> None:
        if self.fontsize <= 0:
            raise ValueError("fontsize must be a positive integer.")

BayesianComparison dataclass

Compute Bayesian pairwise comparisons using the external baycomp package.

Parameters:
    rope (float, default 0.01): ROPE threshold for practical equivalence.
    seed (int, default 0): Random seed for reproducibility in baycomp.
    style (BaycompStyle, default BaycompStyle()): Plot styling settings for consistent manuscript figures.

Raises:
    ValueError: If rope is not positive.

Source code in src/synomicsbench/metrics/fidelity/BayesianComparison.py, lines 104-356
@dataclass
class BayesianComparison:
    """
    Compute Bayesian pairwise comparisons using the external baycomp package.

    Args:
        rope (float): ROPE threshold for practical equivalence.
        seed (int): Random seed for reproducibility in baycomp.
        style (BaycompStyle): Plot styling settings for consistent manuscript figures.

    Raises:
        ValueError: If rope is not positive.
    """
    rope: float = 0.01
    seed: int = 0
    style: BaycompStyle = field(default_factory=BaycompStyle)

    def __post_init__(self) -> None:
        if self.rope <= 0:
            raise ValueError("rope must be > 0.")

    @staticmethod
    def _validate_scores(scores: Sequence[float], method: str) -> np.ndarray:
        """
        Validate and coerce a score sequence to a finite numpy array.
        """
        x = np.asarray(scores, dtype=float)
        x = x[np.isfinite(x)]
        if x.size < 2:
            raise ValueError(f"Method '{method}' must have at least 2 finite scores for baycomp.")
        return x

    @staticmethod
    def _resolve_methods_order(
        method_to_scores: Mapping[str, Sequence[float]],
        methods_order: Optional[Sequence[str]],
    ) -> List[str]:
        """
        Resolve and validate the method ordering.
        """
        if not method_to_scores or len(method_to_scores) < 2:
            raise ValueError("method_to_scores must contain at least 2 methods.")

        if methods_order is None:
            return list(method_to_scores.keys())

        resolved = list(methods_order)
        missing = [m for m in resolved if m not in method_to_scores]
        if missing:
            raise ValueError(f"methods_order contains methods missing in method_to_scores: {missing}")
        return resolved

    def compare_methods(
        self,
        method_to_scores: Mapping[str, Sequence[float]],
        methods_order: Optional[Sequence[str]] = None,
    ) -> pd.DataFrame:
        """
        Compute ordered-pair Bayesian comparison probabilities for a set of methods.

        Args:
            method_to_scores (Mapping[str, Sequence[float]]): Mapping method -> scores across seeds/folds.
            methods_order (Optional[Sequence[str]]): Optional ordering for methods.

        Returns:
            pd.DataFrame: Table with columns:
                ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

        Raises:
            ValueError: If fewer than 2 methods are provided.
            ValueError: If any method has fewer than 2 finite scores.
        """
        methods = self._resolve_methods_order(method_to_scores, methods_order)
        scores_np = {m: self._validate_scores(method_to_scores[m], method=m) for m in methods}

        rows: List[dict] = []
        for i, m1 in enumerate(methods):
            for j, m2 in enumerate(methods):
                if i == j:
                    continue

                probs = baycomp.two_on_single(scores_np[m1], scores_np[m2], rope=self.rope)
                p_better = float(probs[0])
                p_equiv = float(probs[1])
                p_worse = float(probs[2])

                rows.append(
                    {
                        "Method 1": m1,
                        "Method 2": m2,
                        "Better Prob": p_better,
                        "Worse Prob": p_worse,
                        "Equivalent Prob": p_equiv,
                    }
                )

        return pd.DataFrame(rows)

    @staticmethod
    def comparison_to_matrix(
        comparison_df: pd.DataFrame,
        methods_order: Sequence[str],
        value_col: str = "Better Prob",
        nan_diagonal: bool = True,
    ) -> pd.DataFrame:
        """
        Convert a comparison table into a square matrix suitable for heatmap plotting.

        Args:
            comparison_df (pd.DataFrame): Output of compare_methods.
            methods_order (Sequence[str]): Method ordering for rows and columns.
            value_col (str): Which column to visualize.
            nan_diagonal (bool): If True, set diagonal values to NaN.

        Returns:
            pd.DataFrame: Square matrix with index=Method 1 and columns=Method 2.

        Raises:
            ValueError: If value_col is not present in comparison_df.
        """
        if value_col not in comparison_df.columns:
            raise ValueError(f"value_col must be a column in comparison_df, got '{value_col}'.")

        mat = comparison_df.pivot(index="Method 1", columns="Method 2", values=value_col)
        mat = mat.reindex(index=list(methods_order), columns=list(methods_order))

        if nan_diagonal:
            for m in methods_order:
                if m in mat.index and m in mat.columns:
                    mat.loc[m, m] = np.nan

        return mat

    def plot_pbetter_heatmap_grid(
        self,
        cancer_to_method_scores: Dict[str, Dict[str, Sequence[float]]],
        cancers_order: Sequence[str] = ("ccRCC", "Melanoma", "NSCLC"),
        methods_order: Optional[Sequence[str]] = None,
        value_col: str = "Better Prob",
        figsize: Tuple[int, int] = (18, 5.5),
        annot: bool = True,
        fmt: str = ".2f",
        show: bool = True,
    ) -> Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
        """
        Plot a 1xN grid of Bayesian comparison probability heatmaps (one per cancer cohort).

        Args:
            cancer_to_method_scores (Dict[str, Dict[str, Sequence[float]]]): cancer -> {method -> scores across seeds/folds}.
            cancers_order (Sequence[str]): Order of cohorts in the grid.
            methods_order (Optional[Sequence[str]]): Global method ordering. If None, uses union over cancers.
            value_col (str): One of ["Better Prob", "Worse Prob", "Equivalent Prob"] to visualize.
            figsize (Tuple[int, int]): Figure size in inches.
            annot (bool): If True, annotate each cell with numeric values.
            fmt (str): Annotation format passed to seaborn.
            show (bool): If True, calls plt.show().

        Returns:
            Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
                fig, axes, matrices_by_cancer, comparisons_by_cancer

        Raises:
            ValueError: If value_col is invalid.
            ValueError: If cancers_order contains a cohort not present in cancer_to_method_scores.
        """
        allowed = {"Better Prob", "Worse Prob", "Equivalent Prob"}
        if value_col not in allowed:
            raise ValueError(f"value_col must be one of {sorted(allowed)}.")

        cancers = list(cancers_order)
        for c in cancers:
            if c not in cancer_to_method_scores:
                raise ValueError(f"Cancer '{c}' missing from cancer_to_method_scores.")

        if methods_order is None:
            union_methods: List[str] = []
            for c in cancers:
                for m in cancer_to_method_scores[c].keys():
                    if m not in union_methods:
                        union_methods.append(m)
            methods = union_methods
        else:
            methods = list(methods_order)

        _set_nature_rcparams(fontsize=self.style.fontsize)
        sns.set(style="white", rc={"axes.facecolor": self.style.plot_colors["heatmap_bg"]})

        fig, axes = plt.subplots(1, len(cancers), figsize=figsize)
        if len(cancers) == 1:
            axes = np.array([axes])

        norm = TwoSlopeNorm(vmin=0.0, vcenter=0.5, vmax=1.0)

        mats_by_cancer: Dict[str, pd.DataFrame] = {}
        comps_by_cancer: Dict[str, pd.DataFrame] = {}

        for ax, cancer in zip(axes, cancers):
            comp_df = self.compare_methods(
                method_to_scores=cancer_to_method_scores[cancer],
                methods_order=methods,
            )
            comps_by_cancer[cancer] = comp_df

            mat = self.comparison_to_matrix(
                comparison_df=comp_df,
                methods_order=methods,
                value_col=value_col,
                nan_diagonal=True,
            )
            mats_by_cancer[cancer] = mat

            if value_col == "Better Prob":
                cmap = self.style.pbetter_cmap.copy()
                used_norm = norm
            else:
                cmap = plt.cm.viridis.copy()
                used_norm = None
            cmap.set_bad(color="#D3D3D3")

            sns.heatmap(
                mat,
                ax=ax,
                cmap=cmap,
                norm=used_norm,
                annot=annot,
                fmt=fmt,
                linewidths=0.5,
                linecolor="lightgray",
                cbar=(ax is axes[-1]),
                cbar_kws={"label": value_col} if (ax is axes[-1]) else None,
                square=True,
            )

            ax.set_title(cancer, fontsize=self.style.fontsize + 1, fontweight="bold", pad=10)

            cancer_color = self.style.cancer_colors.get(cancer, "#333333")
            ax.plot(
                [0.02, 0.98],
                [1.02, 1.02],
                transform=ax.transAxes,
                color=cancer_color,
                lw=4,
                clip_on=False,
            )

            ax.tick_params(axis="x", rotation=45)
            ax.tick_params(axis="y", rotation=0)

        plt.tight_layout()
        if show:
            plt.show()

        return fig, axes, mats_by_cancer, comps_by_cancer
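
A short end-to-end sketch of the comparison workflow (method names and scores are illustrative; the import path is assumed from the source location above, and the external baycomp package must be installed):

from synomicsbench.metrics.fidelity.BayesianComparison import BayesianComparison  # path assumed

# Scores across seeds/folds for two synthesizers.
method_to_scores = {
    "MethodA": [0.71, 0.69, 0.73, 0.70, 0.72],
    "MethodB": [0.66, 0.68, 0.65, 0.67, 0.66],
}

bc = BayesianComparison(rope=0.01, seed=0)
comp_df = bc.compare_methods(method_to_scores)                                     # long table of ordered pairs
pbetter = bc.comparison_to_matrix(comp_df, methods_order=["MethodA", "MethodB"])   # square P(Better) matrix
print(pbetter)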

compare_methods(method_to_scores, methods_order=None)

Compute ordered-pair Bayesian comparison probabilities for a set of methods.

Parameters:
    method_to_scores (Mapping[str, Sequence[float]], required): Mapping method -> scores across seeds/folds.
    methods_order (Optional[Sequence[str]], default None): Optional ordering for methods.

Returns:
    pd.DataFrame: Table with columns ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

Raises:
    ValueError: If fewer than 2 methods are provided.
    ValueError: If any method has fewer than 2 finite scores.

Source code in src/synomicsbench/metrics/fidelity/BayesianComparison.py, lines 156-200
def compare_methods(
    self,
    method_to_scores: Mapping[str, Sequence[float]],
    methods_order: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """
    Compute ordered-pair Bayesian comparison probabilities for a set of methods.

    Args:
        method_to_scores (Mapping[str, Sequence[float]]): Mapping method -> scores across seeds/folds.
        methods_order (Optional[Sequence[str]]): Optional ordering for methods.

    Returns:
        pd.DataFrame: Table with columns:
            ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

    Raises:
        ValueError: If fewer than 2 methods are provided.
        ValueError: If any method has fewer than 2 finite scores.
    """
    methods = self._resolve_methods_order(method_to_scores, methods_order)
    scores_np = {m: self._validate_scores(method_to_scores[m], method=m) for m in methods}

    rows: List[dict] = []
    for i, m1 in enumerate(methods):
        for j, m2 in enumerate(methods):
            if i == j:
                continue

            probs = baycomp.two_on_single(scores_np[m1], scores_np[m2], rope=self.rope)
            p_better = float(probs[0])
            p_equiv = float(probs[1])
            p_worse = float(probs[2])

            rows.append(
                {
                    "Method 1": m1,
                    "Method 2": m2,
                    "Better Prob": p_better,
                    "Worse Prob": p_worse,
                    "Equivalent Prob": p_equiv,
                }
            )

    return pd.DataFrame(rows)

comparison_to_matrix(comparison_df, methods_order, value_col='Better Prob', nan_diagonal=True) staticmethod

Convert a comparison table into a square matrix suitable for heatmap plotting.

Parameters:
    comparison_df (DataFrame, required): Output of compare_methods.
    methods_order (Sequence[str], required): Method ordering for rows and columns.
    value_col (str, default 'Better Prob'): Which column to visualize.
    nan_diagonal (bool, default True): If True, set diagonal values to NaN.

Returns:
    pd.DataFrame: Square matrix with index=Method 1 and columns=Method 2.

Raises:
    ValueError: If value_col is not present in comparison_df.

Source code in src/synomicsbench/metrics/fidelity/BayesianComparison.py, lines 202-235
@staticmethod
def comparison_to_matrix(
    comparison_df: pd.DataFrame,
    methods_order: Sequence[str],
    value_col: str = "Better Prob",
    nan_diagonal: bool = True,
) -> pd.DataFrame:
    """
    Convert a comparison table into a square matrix suitable for heatmap plotting.

    Args:
        comparison_df (pd.DataFrame): Output of compare_methods.
        methods_order (Sequence[str]): Method ordering for rows and columns.
        value_col (str): Which column to visualize.
        nan_diagonal (bool): If True, set diagonal values to NaN.

    Returns:
        pd.DataFrame: Square matrix with index=Method 1 and columns=Method 2.

    Raises:
        ValueError: If value_col is not present in comparison_df.
    """
    if value_col not in comparison_df.columns:
        raise ValueError(f"value_col must be a column in comparison_df, got '{value_col}'.")

    mat = comparison_df.pivot(index="Method 1", columns="Method 2", values=value_col)
    mat = mat.reindex(index=list(methods_order), columns=list(methods_order))

    if nan_diagonal:
        for m in methods_order:
            if m in mat.index and m in mat.columns:
                mat.loc[m, m] = np.nan

    return mat

plot_pbetter_heatmap_grid(cancer_to_method_scores, cancers_order=('ccRCC', 'Melanoma', 'NSCLC'), methods_order=None, value_col='Better Prob', figsize=(18, 5.5), annot=True, fmt='.2f', show=True)

Plot a 1xN grid of Bayesian comparison probability heatmaps (one per cancer cohort).

Parameters:
    cancer_to_method_scores (Dict[str, Dict[str, Sequence[float]]], required): cancer -> {method -> scores across seeds/folds}.
    cancers_order (Sequence[str], default ('ccRCC', 'Melanoma', 'NSCLC')): Order of cohorts in the grid.
    methods_order (Optional[Sequence[str]], default None): Global method ordering. If None, uses union over cancers.
    value_col (str, default 'Better Prob'): One of ["Better Prob", "Worse Prob", "Equivalent Prob"] to visualize.
    figsize (Tuple[int, int], default (18, 5.5)): Figure size in inches.
    annot (bool, default True): If True, annotate each cell with numeric values.
    fmt (str, default '.2f'): Annotation format passed to seaborn.
    show (bool, default True): If True, calls plt.show().

Returns:
    Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: fig, axes, matrices_by_cancer, comparisons_by_cancer.

Raises:
    ValueError: If value_col is invalid.
    ValueError: If cancers_order contains a cohort not present in cancer_to_method_scores.

Source code in src/synomicsbench/metrics/fidelity/BayesianComparison.py, lines 237-356
def plot_pbetter_heatmap_grid(
    self,
    cancer_to_method_scores: Dict[str, Dict[str, Sequence[float]]],
    cancers_order: Sequence[str] = ("ccRCC", "Melanoma", "NSCLC"),
    methods_order: Optional[Sequence[str]] = None,
    value_col: str = "Better Prob",
    figsize: Tuple[int, int] = (18, 5.5),
    annot: bool = True,
    fmt: str = ".2f",
    show: bool = True,
) -> Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Plot a 1xN grid of Bayesian comparison probability heatmaps (one per cancer cohort).

    Args:
        cancer_to_method_scores (Dict[str, Dict[str, Sequence[float]]]): cancer -> {method -> scores across seeds/folds}.
        cancers_order (Sequence[str]): Order of cohorts in the grid.
        methods_order (Optional[Sequence[str]]): Global method ordering. If None, uses union over cancers.
        value_col (str): One of ["Better Prob", "Worse Prob", "Equivalent Prob"] to visualize.
        figsize (Tuple[int, int]): Figure size in inches.
        annot (bool): If True, annotate each cell with numeric values.
        fmt (str): Annotation format passed to seaborn.
        show (bool): If True, calls plt.show().

    Returns:
        Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
            fig, axes, matrices_by_cancer, comparisons_by_cancer

    Raises:
        ValueError: If value_col is invalid.
        ValueError: If cancers_order contains a cohort not present in cancer_to_method_scores.
    """
    allowed = {"Better Prob", "Worse Prob", "Equivalent Prob"}
    if value_col not in allowed:
        raise ValueError(f"value_col must be one of {sorted(allowed)}.")

    cancers = list(cancers_order)
    for c in cancers:
        if c not in cancer_to_method_scores:
            raise ValueError(f"Cancer '{c}' missing from cancer_to_method_scores.")

    if methods_order is None:
        union_methods: List[str] = []
        for c in cancers:
            for m in cancer_to_method_scores[c].keys():
                if m not in union_methods:
                    union_methods.append(m)
        methods = union_methods
    else:
        methods = list(methods_order)

    _set_nature_rcparams(fontsize=self.style.fontsize)
    sns.set(style="white", rc={"axes.facecolor": self.style.plot_colors["heatmap_bg"]})

    fig, axes = plt.subplots(1, len(cancers), figsize=figsize)
    if len(cancers) == 1:
        axes = np.array([axes])

    norm = TwoSlopeNorm(vmin=0.0, vcenter=0.5, vmax=1.0)

    mats_by_cancer: Dict[str, pd.DataFrame] = {}
    comps_by_cancer: Dict[str, pd.DataFrame] = {}

    for ax, cancer in zip(axes, cancers):
        comp_df = self.compare_methods(
            method_to_scores=cancer_to_method_scores[cancer],
            methods_order=methods,
        )
        comps_by_cancer[cancer] = comp_df

        mat = self.comparison_to_matrix(
            comparison_df=comp_df,
            methods_order=methods,
            value_col=value_col,
            nan_diagonal=True,
        )
        mats_by_cancer[cancer] = mat

        if value_col == "Better Prob":
            cmap = self.style.pbetter_cmap.copy()
            used_norm = norm
        else:
            cmap = plt.cm.viridis.copy()
            used_norm = None
        cmap.set_bad(color="#D3D3D3")

        sns.heatmap(
            mat,
            ax=ax,
            cmap=cmap,
            norm=used_norm,
            annot=annot,
            fmt=fmt,
            linewidths=0.5,
            linecolor="lightgray",
            cbar=(ax is axes[-1]),
            cbar_kws={"label": value_col} if (ax is axes[-1]) else None,
            square=True,
        )

        ax.set_title(cancer, fontsize=self.style.fontsize + 1, fontweight="bold", pad=10)

        cancer_color = self.style.cancer_colors.get(cancer, "#333333")
        ax.plot(
            [0.02, 0.98],
            [1.02, 1.02],
            transform=ax.transAxes,
            color=cancer_color,
            lw=4,
            clip_on=False,
        )

        ax.tick_params(axis="x", rotation=45)
        ax.tick_params(axis="y", rotation=0)

    plt.tight_layout()
    if show:
        plt.show()

    return fig, axes, mats_by_cancer, comps_by_cancer
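
A sketch of the grid plot, reusing the `bc` instance and illustrative method names from the compare_methods example above (cohort names are placeholders):

cancer_to_method_scores = {
    "ccRCC":    {"MethodA": [0.71, 0.69, 0.73], "MethodB": [0.66, 0.68, 0.65]},
    "Melanoma": {"MethodA": [0.62, 0.64, 0.61], "MethodB": [0.63, 0.60, 0.62]},
}
fig, axes, mats, comps = bc.plot_pbetter_heatmap_grid(
    cancer_to_method_scores,
    cancers_order=("ccRCC", "Melanoma"),   # override the default three-cohort order
    value_col="Better Prob",
    show=False,                            # keep the figure for saving instead of calling plt.show()
)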

visualization

plot_violin_grid_by_cancer(cancer_to_method_scores, value_name='Score', methods_order=None, palette=DATASET_COLORS, figsize=(18, 5), fontsize=11, annotate_mean=True, mean_fmt='{:.3f}', mean_marker='D', mean_marker_size=70.0, mean_text_offset_frac=0.03, sharey=True, show=True)

Plot a 1xN grid of violin plots, one per cancer type, using a shared method color palette. The mean of each method is shown as a white diamond and optionally annotated as text.

Parameters:
    cancer_to_method_scores (Dict[str, Dict[str, List[float]]], required): Mapping cancer -> {method -> list of scores}.
    value_name (str, default 'Score'): Y-axis label for the score metric.
    methods_order (Optional[Sequence[str]], default None): Global ordering of methods across all panels. If None, uses the union of methods in insertion order.
    palette (Optional[Dict[str, str]], default DATASET_COLORS): Mapping method -> color. If None, uses DATASET_COLORS.
    figsize (Tuple[int, int], default (18, 5)): Figure size.
    fontsize (int, default 11): Base font size.
    annotate_mean (bool, default True): If True, write the mean value above each violin.
    mean_fmt (str, default '{:.3f}'): Format string for the mean annotation (e.g., "{:.3f}").
    mean_marker (str, default 'D'): Marker for mean point.
    mean_marker_size (float, default 70.0): Marker size for mean point.
    mean_text_offset_frac (float, default 0.03): Vertical offset for mean text as a fraction of y-span.
    sharey (bool, default True): If True, share y-axis across panels.
    show (bool, default True): If True, calls plt.show().

Returns:
    Tuple[plt.Figure, np.ndarray, Dict[str, pd.Series]]: (figure, axes, mean_by_cancer).

Raises:
    ValueError: If cancer_to_method_scores is empty.
    ValueError: If any cancer has no methods or no numeric scores after cleaning.
    ValueError: If methods_order contains methods missing from all cancers.

Source code in src/synomicsbench/metrics/fidelity/visualization.py, lines 122-288
def plot_violin_grid_by_cancer(
    cancer_to_method_scores: Dict[str, Dict[str, List[float]]],
    value_name: str = "Score",
    methods_order: Optional[Sequence[str]] = None,
    palette: Optional[Dict[str, str]] = DATASET_COLORS,
    figsize: Tuple[int, int] = (18, 5),
    fontsize: int = 11,
    annotate_mean: bool = True,
    mean_fmt: str = "{:.3f}",
    mean_marker: str = "D",
    mean_marker_size: float = 70.0,
    mean_text_offset_frac: float = 0.03,
    sharey: bool = True,
    show: bool = True,
) -> Tuple[plt.Figure, np.ndarray, Dict[str, pd.Series]]:
    """
    Plot a 1xN grid of violin plots, one per cancer type, using a shared method color palette.
    The mean of each method is shown as a white diamond and optionally annotated as text.

    Args:
        cancer_to_method_scores (Dict[str, Dict[str, List[float]]]): Mapping cancer -> {method -> list of scores}.
        value_name (str): Y-axis label for the score metric.
        methods_order (Optional[Sequence[str]]): Global ordering of methods across all panels.
            If None, uses the union of methods in insertion order.
        palette (Optional[Dict[str, str]]): Mapping method -> color. If None, uses DATASET_COLORS.
        figsize (Tuple[int, int]): Figure size.
        fontsize (int): Base font size.
        annotate_mean (bool): If True, write the mean value above each violin.
        mean_fmt (str): Format string for the mean annotation (e.g., "{:.3f}").
        mean_marker (str): Marker for mean point.
        mean_marker_size (float): Marker size for mean point.
        mean_text_offset_frac (float): Vertical offset for mean text as a fraction of y-span.
        sharey (bool): If True, share y-axis across panels.
        show (bool): If True, calls plt.show().

    Returns:
        Tuple[plt.Figure, np.ndarray, Dict[str, pd.Series]]: (figure, axes, mean_by_cancer).

    Raises:
        ValueError: If cancer_to_method_scores is empty.
        ValueError: If any cancer has no methods or no numeric scores after cleaning.
        ValueError: If methods_order contains methods missing from all cancers.
    """
    if not cancer_to_method_scores:
        raise ValueError("cancer_to_method_scores must be a non-empty dictionary.")

    cancers = list(cancer_to_method_scores.keys())

    # Determine global method order
    if methods_order is None:
        union_methods: List[str] = []
        for cancer in cancers:
            for m in list(cancer_to_method_scores.get(cancer, {}).keys()):
                if m not in union_methods:
                    union_methods.append(m)
        methods_order = union_methods
    else:
        methods_order = list(methods_order)

    if len(methods_order) == 0:
        raise ValueError("methods_order resolved to an empty list.")

    # Palette: default to DATASET_COLORS (requested) and ensure all needed methods exist
    if palette is None:
        palette = dict(DATASET_COLORS)

    pal_map = {m: palette.get(m, "#333333") for m in methods_order}

    _set_nature_rcparams(fontsize=fontsize)
    sns.set_style("whitegrid", rc={"grid.color": STABILITY_PLOT_COLORS["grid"]})

    n = len(cancers)
    fig, axes = plt.subplots(1, n, figsize=figsize, sharey=sharey)
    if n == 1:
        axes = np.array([axes])

    mean_by_cancer: Dict[str, pd.Series] = {}

    for ax, cancer in zip(axes, cancers):
        method_scores = cancer_to_method_scores.get(cancer, {})
        if method_scores is None or len(method_scores) == 0:
            raise ValueError(f"Cancer '{cancer}' has no method scores.")

        # Build long df
        rows = []
        for m in methods_order:
            vals = method_scores.get(m, None)
            if vals is None:
                continue
            arr = pd.to_numeric(pd.Series(vals), errors="coerce").dropna().to_numpy(dtype=float)
            for v in arr:
                rows.append({"Method": m, value_name: float(v)})

        df_long = pd.DataFrame(rows)
        if df_long.empty:
            raise ValueError(f"Cancer '{cancer}' has no valid numeric scores after cleaning.")

        df_long["Method"] = pd.Categorical(df_long["Method"], categories=list(methods_order), ordered=True)

        sns.violinplot(
            data=df_long,
            x="Method",
            y=value_name,
            order=list(methods_order),
            palette=pal_map,
            cut=0,
            inner="quartile",
            linewidth=1.2,
            ax=ax,
        )

        # Means
        mean_by_method = df_long.groupby("Method")[value_name].mean().reindex(list(methods_order))
        mean_by_cancer[cancer] = mean_by_method

        y_min, y_max = ax.get_ylim()
        y_span = (y_max - y_min) if (y_max > y_min) else 1.0

        for i, m in enumerate(list(methods_order)):
            mu = mean_by_method.loc[m]
            if pd.isna(mu):
                continue

            ax.scatter(
                i,
                float(mu),
                marker=mean_marker,
                s=mean_marker_size,
                facecolor=STABILITY_PLOT_COLORS["mean_face"],
                edgecolor=STABILITY_PLOT_COLORS["mean_edge"],
                linewidth=1.3,
                zorder=10,
            )

            if annotate_mean:
                ax.text(
                    i,
                    float(mu) + float(mean_text_offset_frac) * y_span,
                    mean_fmt.format(float(mu)),
                    ha="center",
                    va="bottom",
                    fontsize=fontsize - 2,
                    color="black",
                )

        ax.set_xlabel("")
        if ax is axes[0]:
            ax.set_ylabel(value_name, fontsize=fontsize)
        else:
            ax.set_ylabel("")

        ax.tick_params(axis="x", rotation=45)
        # ax.tick_params(axis="x", rotation=0)

        for lbl in ax.get_xticklabels():
            lbl.set_fontweight("bold")  # or a numeric weight such as 700

        cancer_color = CANCER_COLORS.get(cancer, "#333333")
        ax.set_title(cancer, fontsize=fontsize + 1, fontweight="bold", color="black", pad=10)
        ax.plot([0.02, 0.98], [1.02, 1.02], transform=ax.transAxes, color=cancer_color, lw=4, clip_on=False)

    plt.tight_layout()
    if show:
        plt.show()

    return fig, axes, mean_by_cancer
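
A minimal usage sketch (import path assumed from the source location above; cohort and method names are placeholders):

from synomicsbench.metrics.fidelity.visualization import plot_violin_grid_by_cancer  # path assumed

cancer_to_method_scores = {
    "ccRCC": {"MethodA": [0.71, 0.69, 0.73], "MethodB": [0.66, 0.68, 0.65]},
    "NSCLC": {"MethodA": [0.62, 0.64, 0.61], "MethodB": [0.63, 0.60, 0.62]},
}
fig, axes, mean_by_cancer = plot_violin_grid_by_cancer(
    cancer_to_method_scores,
    value_name="Pairwise score",
    show=False,
)
print(mean_by_cancer["ccRCC"])   # per-method means for one cohort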

Metrics: Biological Utility

The metrics.narrow_utility module evaluates how well synthetic data supports downstream bioinformatics analyses.

DGE

GCSAnalyzer

Refactor of the original GCS computation code into a class, preserving identical logic.

Parameters:
    term_col (str, default 'Gene'): Column name for pathway / gene-set names.
    nes_col (str, default 'Log2FC'): Column name for effect size (e.g., Log2FC).
    q_col (str, default 'Q_value'): Column name for Q-value.
    seed (int, default 42): Seed used for np.random.seed to generate jitter.
    q_thr (float, default 0.05): Q-value threshold used to define significance boundary.
    w (float, default 0.5): Weight for non-significant concordance in GCS.

Raises:
    ValueError: If q_thr is not in (0, 1].
    ValueError: If w is negative.

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py, lines 62-399
class GCSAnalyzer:
    """
    Refactor of the original GCS computation code into a class, preserving identical logic.

    Args:
        term_col (str): Column name for pathway / gene-set names.
        nes_col (str): Column name for effect size (e.g., Log2FC).
        q_col (str): Column name for Q-value.
        seed (int): Seed used for np.random.seed to generate jitter.
        q_thr (float): Q-value threshold used to define significance boundary.
        w (float): Weight for non-significant concordance in GCS.

    Raises:
        ValueError: If q_thr is not in (0, 1].
        ValueError: If w is negative.
    """

    def __init__(
        self,
        term_col: str = "Gene",
        nes_col: str = "Log2FC",
        q_col: str = "Q_value",
        seed: int = 42,
        q_thr: float = 0.05,
        w: float = 0.5,
    ) -> None:
        if not (0.0 < q_thr <= 1.0):
            raise ValueError("q_thr must be in (0, 1].")
        if w < 0:
            raise ValueError("w must be >= 0.")

        self.term_col = term_col
        self.nes_col = nes_col
        self.q_col = q_col
        self.seed = seed
        self.q_thr = q_thr
        self.w = w

    @staticmethod
    def _load_df(data: Union[str, "os.PathLike[str]", pd.DataFrame]) -> pd.DataFrame:
        """Helper to load a CSV if a path is provided, otherwise return the DataFrame."""
        if isinstance(data, pd.DataFrame):
            return data
        return pd.read_csv(data)

    def compute_rank_score(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Compute rank score = sign(Log2FC) * -log10(Q-value) with tiny jitter.

        Args:
            df (pd.DataFrame): DGE results table.

        Returns:
            pd.DataFrame: Ranked table indexed by term_col with columns rank_score and qval.
        """
        np.random.seed(self.seed)

        q = pd.to_numeric(df[self.q_col], errors="coerce").clip(lower=0.001)
        nes = pd.to_numeric(df[self.nes_col], errors="coerce")

        rank_score = np.sign(nes) * (-np.log10(q))

        jitter = np.random.uniform(-1e-10, 1e-10, size=len(rank_score))
        rank_score = rank_score + jitter

        rnk = (
            df[[self.term_col]]
            .assign(rank_score=rank_score, qval=q)
            .dropna()
            .groupby(self.term_col)
            .mean()
            .sort_values("rank_score", ascending=False)
        )

        return rnk

    @staticmethod
    def align_rank_scores(rnk_ori: pd.DataFrame, rnk_syn: pd.DataFrame) -> pd.DataFrame:
        """
        Inner join on pathway / gene names.

        Args:
            rnk_ori (pd.DataFrame): Original rank table indexed by term.
            rnk_syn (pd.DataFrame): Synthetic rank table indexed by term.

        Returns:
            pd.DataFrame: Aligned rank table.
        """
        aligned = (
            rnk_ori.rename(columns={"rank_score": "rank_ori", "qval": "q_ori"}).join(
                rnk_syn.rename(columns={"rank_score": "rank_syn"}),
                how="inner",
            )
        )
        return aligned

    def gene_set_concordance_score(
        self,
        df_rank: pd.DataFrame,
        ori_rank_size: int,
    ) -> Tuple[float, int, int, float]:
        """
        Compute GCS using original pathway count to avoid bias due to pathway loss after alignment.

        Args:
            df_rank (pd.DataFrame): Aligned rank table with rank_ori, rank_syn, q_ori.
            ori_rank_size (int): Number of pathways in original ranking.

        Returns:
            tuple[float, int, int, float]: (GCS, N_sign, N_non_sign, M).
        """
        x = df_rank["rank_ori"].values
        y = df_rank["rank_syn"].values

        s_thr = -np.log10(self.q_thr)

        zone1 = (x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)
        zone2 = (x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)

        zone3 = (x <= -s_thr) & (y <= -s_thr)
        zone4 = (x >= s_thr) & (y >= s_thr)

        n_sign = int(zone3.sum() + zone4.sum())
        n_non_sign = int(zone1.sum() + zone2.sum())

        num_sign = int((df_rank["q_ori"] < self.q_thr).sum())
        num_non_sign = int(ori_rank_size - num_sign)

        m = float(num_sign + self.w * num_non_sign)
        gcs = float((n_sign + self.w * n_non_sign) / m) if m > 0 else np.nan

        return gcs, n_sign, n_non_sign, m

    def process_single_dge_result(
        self,
        dge_ori: Union[str, "os.PathLike[str]", pd.DataFrame],
        dge_syn: Union[str, "os.PathLike[str]", pd.DataFrame],
    ) -> Tuple[np.ndarray, np.ndarray, float, int, int, int, int, float, int, int, int]:
        """
        Process original and synthetic DGE results (paths or DataFrames).

        Args:
            dge_ori (str | PathLike | pd.DataFrame): Original DGE table/path.
            dge_syn (str | PathLike | pd.DataFrame): Synthetic DGE table/path.

        Returns:
            tuple: (x, y, GCS, n_zone1, n_zone2, n_zone3, n_zone4, M, ori_rank_size, aligned_size, seed_used).
        """
        df_ori = self._load_df(dge_ori)
        df_syn = self._load_df(dge_syn)

        rnk_ori = self.compute_rank_score(df_ori)
        rnk_syn = self.compute_rank_score(df_syn)

        ori_rank_size = int(len(rnk_ori))

        aligned = self.align_rank_scores(rnk_ori, rnk_syn)

        x = aligned["rank_ori"].values
        y = aligned["rank_syn"].values

        s_thr = -np.log10(self.q_thr)

        n_zone1 = int(((x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)).sum())
        n_zone2 = int(((x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)).sum())
        n_zone3 = int(((x <= -s_thr) & (y <= -s_thr)).sum())
        n_zone4 = int(((x >= s_thr) & (y >= s_thr)).sum())

        gcs, _, _, m = self.gene_set_concordance_score(
            aligned,
            ori_rank_size=ori_rank_size,
        )

        return (
            x,
            y,
            gcs,
            n_zone1,
            n_zone2,
            n_zone3,
            n_zone4,
            m,
            ori_rank_size,
            int(len(aligned)),
            int(self.seed),
        )

    def plot_single_gcs_panel(
        self,
        ax: plt.Axes,
        x: np.ndarray,
        y: np.ndarray,
        gcs: float,
        n_zone1: int,
        n_zone2: int,
        n_zone3: int,
        n_zone4: int,
        tool_name: str,
        m: float,
    ) -> None:
        """
        Plot single GCS comparison panel (identical logic and styling to original code).

        Args:
            ax (matplotlib.axes.Axes): Target axis.
            x (np.ndarray): Rank scores (original).
            y (np.ndarray): Rank scores (synthetic).
            gcs (float): GCS value.
            n_zone1 (int): Count in zone1.
            n_zone2 (int): Count in zone2.
            n_zone3 (int): Count in zone3.
            n_zone4 (int): Count in zone4.
            tool_name (str): Tool/dataset name for title.
            m (float): M normalization used in GCS.

        Returns:
            None
        """
        s_thr = -np.log10(self.q_thr)
        lo, hi = -2, 2

        zone1 = (x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)
        zone2 = (x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)
        zone3 = (x <= -s_thr) & (y <= -s_thr)
        zone4 = (x >= s_thr) & (y >= s_thr)
        other = ~(zone1 | zone2 | zone3 | zone4)

        ax.add_patch(
            plt.Rectangle((-s_thr, -s_thr), s_thr, s_thr, color=GCS_ZONE_BG_COLORS["zone1"], alpha=0.4)
        )
        ax.add_patch(
            plt.Rectangle((0, 0), s_thr, s_thr, color=GCS_ZONE_BG_COLORS["zone2"], alpha=0.4)
        )
        ax.add_patch(
            plt.Rectangle((lo, lo), -s_thr - lo, -s_thr - lo, color=GCS_ZONE_BG_COLORS["zone3"], alpha=0.2)
        )
        ax.add_patch(
            plt.Rectangle((s_thr, s_thr), hi - s_thr, hi - s_thr, color=GCS_ZONE_BG_COLORS["zone4"], alpha=0.2)
        )

        ax.scatter(x[other], y[other], c=GCS_ZONE_COLORS["other"], alpha=0.3, s=20)
        ax.scatter(x[zone1], y[zone1], c=GCS_ZONE_COLORS["zone1"], s=60, edgecolor="k")
        ax.scatter(x[zone2], y[zone2], c=GCS_ZONE_COLORS["zone2"], s=60, edgecolor="k")
        ax.scatter(x[zone3], y[zone3], c=GCS_ZONE_COLORS["zone3"], s=70, edgecolor="k")
        ax.scatter(x[zone4], y[zone4], c=GCS_ZONE_COLORS["zone4"], s=70, edgecolor="k")

        ax.plot([lo, hi], [lo, hi], "--", c="black", lw=1)
        for v in [-s_thr, 0, s_thr]:
            ax.axhline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)
            ax.axvline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)

        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)
        ax.set_xlabel("Rank score (Original)")
        ax.set_ylabel("Rank score (Synthetic)")
        ax.set_title(f"{tool_name}\nGCS={gcs:.3f} (w={self.w})")

        ax.text(
            -1.9,
            1.8,
            "Zones\n"
            f"NonSig LL: {n_zone1}\n"
            f"NonSig UR: {n_zone2}\n"
            f"Sig LL: {n_zone3}\n"
            f"Sig UR: {n_zone4}\n"
            f"M={m}",
            fontsize=10,
            va="top",
            bbox=dict(fc="white", alpha=0.7, ec="none"),
        )

        ax.grid(True)

    def plot_gcs_datasets(
        self,
        ori_data: Union[str, "os.PathLike[str]", pd.DataFrame],
        dataset_dict: Mapping[str, Union[str, "os.PathLike[str]", pd.DataFrame]],
        figsize: Tuple[int, int] = (18, 10),
    ) -> Tuple[plt.Figure, Dict[str, float]]:
        """
        Plot all datasets in a grid.

        Args:
            ori_data (str | PathLike | pd.DataFrame): Path to original CSV or DataFrame.
            dataset_dict (Mapping[str, str | PathLike | pd.DataFrame]): Tool name -> path or DataFrame.
            figsize (tuple[int, int]): Figure size.

        Returns:
            tuple[matplotlib.figure.Figure, dict[str, float]]: (figure, gcs_dict).
        """
        df_ori = self._load_df(ori_data)

        fig, axes = plt.subplots(2, 3, figsize=figsize)
        axes = axes.flatten()

        gcs_dict: Dict[str, float] = {}

        for idx, (tool, data) in enumerate(dataset_dict.items()):
            try:
                (
                    x,
                    y,
                    gcs,
                    n1,
                    n2,
                    n3,
                    n4,
                    m,
                    _ori_rank_size,
                    _aligned_size,
                    _seed_used,
                ) = self.process_single_dge_result(df_ori, data)

                self.plot_single_gcs_panel(
                    axes[idx],
                    x,
                    y,
                    gcs,
                    n1,
                    n2,
                    n3,
                    n4,
                    tool,
                    m,
                )

                gcs_dict[tool] = gcs

            except Exception as e:
                print(f"Skipping {tool}: {e}")

        for j in range(len(dataset_dict), len(axes)):
            fig.delaxes(axes[j])

        plt.tight_layout()
        plt.show()

        return fig, gcs_dict
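
A minimal usage sketch (import path assumed from the source location above; the file names are placeholders for DGE result tables with Gene / Log2FC / Q_value columns):

from synomicsbench.metrics.narrow_utility.DGE import GCSAnalyzer  # path assumed

analyzer = GCSAnalyzer(term_col="Gene", nes_col="Log2FC", q_col="Q_value", q_thr=0.05, w=0.5)

fig, gcs_by_tool = analyzer.plot_gcs_datasets(
    ori_data="dge_original.csv",
    dataset_dict={"MethodA": "dge_method_a.csv", "MethodB": "dge_method_b.csv"},
)
print(gcs_by_tool)   # {"MethodA": <GCS>, "MethodB": <GCS>}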

align_rank_scores(rnk_ori, rnk_syn) staticmethod

Inner join on pathway / gene names.

Parameters:
    rnk_ori (DataFrame, required): Original rank table indexed by term.
    rnk_syn (DataFrame, required): Synthetic rank table indexed by term.

Returns:
    pd.DataFrame: Aligned rank table.

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py, lines 138-156
@staticmethod
def align_rank_scores(rnk_ori: pd.DataFrame, rnk_syn: pd.DataFrame) -> pd.DataFrame:
    """
    Inner join on pathway / gene names.

    Args:
        rnk_ori (pd.DataFrame): Original rank table indexed by term.
        rnk_syn (pd.DataFrame): Synthetic rank table indexed by term.

    Returns:
        pd.DataFrame: Aligned rank table.
    """
    aligned = (
        rnk_ori.rename(columns={"rank_score": "rank_ori", "qval": "q_ori"}).join(
            rnk_syn.rename(columns={"rank_score": "rank_syn"}),
            how="inner",
        )
    )
    return aligned

compute_rank_score(df)

Compute rank score = sign(Log2FC) * -log10(Q-value) with tiny jitter.

Parameters:
    df (DataFrame, required): DGE results table.

Returns:
    pd.DataFrame: Ranked table indexed by term_col with columns rank_score and qval.

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py, lines 107-136
def compute_rank_score(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute rank score = sign(Log2FC) * -log10(Q-value) with tiny jitter.

    Args:
        df (pd.DataFrame): DGE results table.

    Returns:
        pd.DataFrame: Ranked table indexed by term_col with columns rank_score and qval.
    """
    np.random.seed(self.seed)

    q = pd.to_numeric(df[self.q_col], errors="coerce").clip(lower=0.001)
    nes = pd.to_numeric(df[self.nes_col], errors="coerce")

    rank_score = np.sign(nes) * (-np.log10(q))

    jitter = np.random.uniform(-1e-10, 1e-10, size=len(rank_score))
    rank_score = rank_score + jitter

    rnk = (
        df[[self.term_col]]
        .assign(rank_score=rank_score, qval=q)
        .dropna()
        .groupby(self.term_col)
        .mean()
        .sort_values("rank_score", ascending=False)
    )

    return rnk
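
To make the ranking rule concrete, a small sketch (gene names and values are illustrative; GCSAnalyzer is assumed importable as in the example above):

import pandas as pd

# With the defaults, rank_score = sign(Log2FC) * -log10(Q_value), where Q-values
# are clipped at 0.001 before the log and a tiny seeded jitter breaks ties.
dge = pd.DataFrame({
    "Gene":    ["TP53", "EGFR", "BRCA1"],
    "Log2FC":  [1.8, -0.4, 0.9],
    "Q_value": [0.001, 0.20, 0.03],
})
rnk = GCSAnalyzer().compute_rank_score(dge)
print(rnk)   # indexed by Gene, sorted by rank_score descending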

gene_set_concordance_score(df_rank, ori_rank_size)

Compute GCS using original pathway count to avoid bias due to pathway loss after alignment.

Parameters:
    df_rank (DataFrame, required): Aligned rank table with rank_ori, rank_syn, q_ori.
    ori_rank_size (int, required): Number of pathways in original ranking.

Returns:
    tuple[float, int, int, float]: (GCS, N_sign, N_non_sign, M).

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py, lines 158-193
def gene_set_concordance_score(
    self,
    df_rank: pd.DataFrame,
    ori_rank_size: int,
) -> Tuple[float, int, int, float]:
    """
    Compute GCS using original pathway count to avoid bias due to pathway loss after alignment.

    Args:
        df_rank (pd.DataFrame): Aligned rank table with rank_ori, rank_syn, q_ori.
        ori_rank_size (int): Number of pathways in original ranking.

    Returns:
        tuple[float, int, int, float]: (GCS, N_sign, N_non_sign, M).
    """
    x = df_rank["rank_ori"].values
    y = df_rank["rank_syn"].values

    s_thr = -np.log10(self.q_thr)

    zone1 = (x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)
    zone2 = (x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)

    zone3 = (x <= -s_thr) & (y <= -s_thr)
    zone4 = (x >= s_thr) & (y >= s_thr)

    n_sign = int(zone3.sum() + zone4.sum())
    n_non_sign = int(zone1.sum() + zone2.sum())

    num_sign = int((df_rank["q_ori"] < self.q_thr).sum())
    num_non_sign = int(ori_rank_size - num_sign)

    m = float(num_sign + self.w * num_non_sign)
    gcs = float((n_sign + self.w * n_non_sign) / m) if m > 0 else np.nan

    return gcs, n_sign, n_non_sign, m
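
Written out, the score returned above is

$$\mathrm{GCS} = \frac{N_{\mathrm{sign}} + w\,N_{\mathrm{non\text{-}sign}}}{M},$$

where N_sign counts aligned pathways that are significant and sign-concordant in both datasets (zones 3 and 4), N_non-sign counts concordant non-significant pathways (zones 1 and 2), and the denominator M = num_sign + w * num_non_sign is built from the original q-values and the full original pathway count (ori_rank_size), so pathways lost during alignment still lower the score.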

plot_gcs_datasets(ori_data, dataset_dict, figsize=(18, 10))

Plot all datasets in a grid.

Parameters:
    ori_data (str | PathLike | DataFrame, required): Path to original CSV or DataFrame.
    dataset_dict (Mapping[str, str | PathLike | DataFrame], required): Tool name -> path or DataFrame.
    figsize (tuple[int, int], default (18, 10)): Figure size.

Returns:
    tuple[matplotlib.figure.Figure, dict[str, float]]: (figure, gcs_dict).

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py, lines 335-399
def plot_gcs_datasets(
    self,
    ori_data: Union[str, "os.PathLike[str]", pd.DataFrame],
    dataset_dict: Mapping[str, Union[str, "os.PathLike[str]", pd.DataFrame]],
    figsize: Tuple[int, int] = (18, 10),
) -> Tuple[plt.Figure, Dict[str, float]]:
    """
    Plot all datasets in a grid.

    Args:
        ori_data (str | PathLike | pd.DataFrame): Path to original CSV or DataFrame.
        dataset_dict (Mapping[str, str | PathLike | pd.DataFrame]): Tool name -> path or DataFrame.
        figsize (tuple[int, int]): Figure size.

    Returns:
        tuple[matplotlib.figure.Figure, dict[str, float]]: (figure, gcs_dict).
    """
    df_ori = self._load_df(ori_data)

    fig, axes = plt.subplots(2, 3, figsize=figsize)
    axes = axes.flatten()

    gcs_dict: Dict[str, float] = {}

    for idx, (tool, data) in enumerate(dataset_dict.items()):
        try:
            (
                x,
                y,
                gcs,
                n1,
                n2,
                n3,
                n4,
                m,
                _ori_rank_size,
                _aligned_size,
                _seed_used,
            ) = self.process_single_dge_result(df_ori, data)

            self.plot_single_gcs_panel(
                axes[idx],
                x,
                y,
                gcs,
                n1,
                n2,
                n3,
                n4,
                tool,
                m,
            )

            gcs_dict[tool] = gcs

        except Exception as e:
            print(f"Skipping {tool}: {e}")

    for j in range(len(dataset_dict), len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()
    plt.show()

    return fig, gcs_dict

plot_single_gcs_panel(ax, x, y, gcs, n_zone1, n_zone2, n_zone3, n_zone4, tool_name, m)

Plot single GCS comparison panel (identical logic and styling to original code).

Parameters:
    ax (Axes, required): Target axis.
    x (ndarray, required): Rank scores (original).
    y (ndarray, required): Rank scores (synthetic).
    gcs (float, required): GCS value.
    n_zone1 (int, required): Count in zone1.
    n_zone2 (int, required): Count in zone2.
    n_zone3 (int, required): Count in zone3.
    n_zone4 (int, required): Count in zone4.
    tool_name (str, required): Tool/dataset name for title.
    m (float, required): M normalization used in GCS.

Returns:
    None

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py, lines 249-333
def plot_single_gcs_panel(
    self,
    ax: plt.Axes,
    x: np.ndarray,
    y: np.ndarray,
    gcs: float,
    n_zone1: int,
    n_zone2: int,
    n_zone3: int,
    n_zone4: int,
    tool_name: str,
    m: float,
) -> None:
    """
    Plot single GCS comparison panel (identical logic and styling to original code).

    Args:
        ax (matplotlib.axes.Axes): Target axis.
        x (np.ndarray): Rank scores (original).
        y (np.ndarray): Rank scores (synthetic).
        gcs (float): GCS value.
        n_zone1 (int): Count in zone1.
        n_zone2 (int): Count in zone2.
        n_zone3 (int): Count in zone3.
        n_zone4 (int): Count in zone4.
        tool_name (str): Tool/dataset name for title.
        m (float): M normalization used in GCS.

    Returns:
        None
    """
    s_thr = -np.log10(self.q_thr)
    lo, hi = -2, 2

    zone1 = (x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)
    zone2 = (x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)
    zone3 = (x <= -s_thr) & (y <= -s_thr)
    zone4 = (x >= s_thr) & (y >= s_thr)
    other = ~(zone1 | zone2 | zone3 | zone4)

    ax.add_patch(
        plt.Rectangle((-s_thr, -s_thr), s_thr, s_thr, color=GCS_ZONE_BG_COLORS["zone1"], alpha=0.4)
    )
    ax.add_patch(
        plt.Rectangle((0, 0), s_thr, s_thr, color=GCS_ZONE_BG_COLORS["zone2"], alpha=0.4)
    )
    ax.add_patch(
        plt.Rectangle((lo, lo), -s_thr - lo, -s_thr - lo, color=GCS_ZONE_BG_COLORS["zone3"], alpha=0.2)
    )
    ax.add_patch(
        plt.Rectangle((s_thr, s_thr), hi - s_thr, hi - s_thr, color=GCS_ZONE_BG_COLORS["zone4"], alpha=0.2)
    )

    ax.scatter(x[other], y[other], c=GCS_ZONE_COLORS["other"], alpha=0.3, s=20)
    ax.scatter(x[zone1], y[zone1], c=GCS_ZONE_COLORS["zone1"], s=60, edgecolor="k")
    ax.scatter(x[zone2], y[zone2], c=GCS_ZONE_COLORS["zone2"], s=60, edgecolor="k")
    ax.scatter(x[zone3], y[zone3], c=GCS_ZONE_COLORS["zone3"], s=70, edgecolor="k")
    ax.scatter(x[zone4], y[zone4], c=GCS_ZONE_COLORS["zone4"], s=70, edgecolor="k")

    ax.plot([lo, hi], [lo, hi], "--", c="black", lw=1)
    for v in [-s_thr, 0, s_thr]:
        ax.axhline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)
        ax.axvline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)

    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)
    ax.set_xlabel("Rank score (Original)")
    ax.set_ylabel("Rank score (Synthetic)")
    ax.set_title(f"{tool_name}\nGCS={gcs:.3f} (w={self.w})")

    ax.text(
        -1.9,
        1.8,
        "Zones\n"
        f"NonSig LL: {n_zone1}\n"
        f"NonSig UR: {n_zone2}\n"
        f"Sig LL: {n_zone3}\n"
        f"Sig UR: {n_zone4}\n"
        f"M={m}",
        fontsize=10,
        va="top",
        bbox=dict(fc="white", alpha=0.7, ec="none"),
    )

    ax.grid(True)

process_single_dge_result(dge_ori, dge_syn)

Process original and synthetic DGE results (paths or DataFrames).

Parameters:

- dge_ori (str | PathLike | DataFrame): Original DGE table/path. Required.
- dge_syn (str | PathLike | DataFrame): Synthetic DGE table/path. Required.

Returns:

- tuple (Tuple[ndarray, ndarray, float, int, int, int, int, float, int, int, int]): (x, y, GCS, n_zone1, n_zone2, n_zone3, n_zone4, M, ori_rank_size, aligned_size, seed_used).

Source code in src/synomicsbench/metrics/narrow_utility/DGE.py
def process_single_dge_result(
    self,
    dge_ori: Union[str, "os.PathLike[str]", pd.DataFrame],
    dge_syn: Union[str, "os.PathLike[str]", pd.DataFrame],
) -> Tuple[np.ndarray, np.ndarray, float, int, int, int, int, float, int, int, int]:
    """
    Process original and synthetic DGE results (paths or DataFrames).

    Args:
        dge_ori (str | PathLike | pd.DataFrame): Original DGE table/path.
        dge_syn (str | PathLike | pd.DataFrame): Synthetic DGE table/path.

    Returns:
        tuple: (x, y, GCS, n_zone1, n_zone2, n_zone3, n_zone4, M, ori_rank_size, aligned_size, seed_used).
    """
    df_ori = self._load_df(dge_ori)
    df_syn = self._load_df(dge_syn)

    rnk_ori = self.compute_rank_score(df_ori)
    rnk_syn = self.compute_rank_score(df_syn)

    ori_rank_size = int(len(rnk_ori))

    aligned = self.align_rank_scores(rnk_ori, rnk_syn)

    x = aligned["rank_ori"].values
    y = aligned["rank_syn"].values

    s_thr = -np.log10(self.q_thr)

    n_zone1 = int(((x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)).sum())
    n_zone2 = int(((x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)).sum())
    n_zone3 = int(((x <= -s_thr) & (y <= -s_thr)).sum())
    n_zone4 = int(((x >= s_thr) & (y >= s_thr)).sum())

    gcs, _, _, m = self.gene_set_concordance_score(
        aligned,
        ori_rank_size=ori_rank_size,
    )

    return (
        x,
        y,
        gcs,
        n_zone1,
        n_zone2,
        n_zone3,
        n_zone4,
        m,
        ori_rank_size,
        int(len(aligned)),
        int(self.seed),
    )

GCSResult dataclass

Container for GCS computation outputs.

Parameters:

- gcs (float): Gene-set Concordance Score (same formula as PCS, renamed). Required.
- n_sign (int): Number of concordant significant pathways (zones 3 + 4). Required.
- n_non_sign (int): Number of concordant non-significant pathways (zones 1 + 2). Required.
- m (float): Normalization constant M used in the GCS formula. Required.
- n_zone1 (int): Count in non-significant lower-left zone. Required.
- n_zone2 (int): Count in non-significant upper-right zone. Required.
- n_zone3 (int): Count in significant lower-left zone. Required.
- n_zone4 (int): Count in significant upper-right zone. Required.
- ori_rank_size (int): Number of ranked pathways in the original ranking. Required.
- aligned_size (int): Number of pathways after alignment (inner join). Required.
Source code in src/synomicsbench/metrics/narrow_utility/DGE.py
@dataclass(frozen=True)
class GCSResult:
    """
    Container for GCS computation outputs.

    Args:
        gcs (float): Gene-set Concordance Score (same formula as PCS, renamed).
        n_sign (int): Number of concordant significant pathways (zones 3 + 4).
        n_non_sign (int): Number of concordant non-significant pathways (zones 1 + 2).
        m (float): Normalization constant M used in the GCS formula.
        n_zone1 (int): Count in non-significant lower-left zone.
        n_zone2 (int): Count in non-significant upper-right zone.
        n_zone3 (int): Count in significant lower-left zone.
        n_zone4 (int): Count in significant upper-right zone.
        ori_rank_size (int): Number of ranked pathways in the original ranking.
        aligned_size (int): Number of pathways after alignment (inner join).
    """
    gcs: float
    n_sign: int
    n_non_sign: int
    m: float
    n_zone1: int
    n_zone2: int
    n_zone3: int
    n_zone4: int
    ori_rank_size: int
    aligned_size: int
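
For reference, the score stored in gcs follows the same form as the PCS defined later on this page; a sketch of the formula in terms of the zone counts above (the exact way gene_set_concordance_score derives the denominator from ori_rank_size is not reproduced in this excerpt):

\[
\mathrm{GCS} = \frac{N_{\mathrm{sig}} + w\,N_{\mathrm{non\text{-}sig}}}{M},
\qquad
N_{\mathrm{sig}} = n_{\mathrm{zone3}} + n_{\mathrm{zone4}},
\quad
N_{\mathrm{non\text{-}sig}} = n_{\mathrm{zone1}} + n_{\mathrm{zone2}},
\]

where M counts the significant and non-significant entries of the original ranking with the same weight w.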

GSEA

PCSAnalyzer

Compute PCS from GSEA pathway enrichment tables and generate manuscript-style panels.

This class preserves the original logic provided in gsea_pathway_analysis.py, including jitter behavior and normalization by aligned pathways (len(x)).

Parameters:

- term_col (str): Pathway term column name. Default: 'Term'.
- nes_col (str): Normalized enrichment score column name. Default: 'NES'.
- q_col (str): FDR q-value column name. Default: 'FDR q-val'.
- seed (int): Seed passed to np.random.seed for jitter reproducibility. Default: 42.
- q_thr (float): Q-value threshold for significance zones. Default: 0.05.
- w (float): Weight for non-significant concordance. Default: 0.5.

Raises:

- ValueError: If q_thr is not in (0, 1].
- ValueError: If w is negative.

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
class PCSAnalyzer:
    """
    Compute PCS from GSEA pathway enrichment tables and generate manuscript-style panels.

    This class preserves the original logic provided in `gsea_pathway_analysis.py`,
    including jitter behavior and normalization by aligned pathways (len(x)).

    Args:
        term_col (str): Pathway term column name.
        nes_col (str): Normalized enrichment score column name.
        q_col (str): FDR q-value column name.
        seed (int): Seed passed to np.random.seed for jitter reproducibility.
        q_thr (float): Q-value threshold for significance zones.
        w (float): Weight for non-significant concordance.

    Raises:
        ValueError: If q_thr is not in (0, 1].
        ValueError: If w is negative.
    """

    def __init__(
        self,
        term_col: str = "Term",
        nes_col: str = "NES",
        q_col: str = "FDR q-val",
        seed: int = 42,
        q_thr: float = 0.05,
        w: float = 0.5,
    ) -> None:
        if not (0.0 < q_thr <= 1.0):
            raise ValueError("q_thr must be in (0, 1].")
        if w < 0:
            raise ValueError("w must be >= 0.")

        self.term_col = term_col
        self.nes_col = nes_col
        self.q_col = q_col
        self.seed = seed
        self.q_thr = q_thr
        self.w = w

    @staticmethod
    def _load_df(data: Union[str, "os.PathLike[str]", pd.DataFrame]) -> pd.DataFrame:
        """Helper to load a CSV if a path is provided, otherwise return the DataFrame."""
        if isinstance(data, pd.DataFrame):
            return data
        return pd.read_csv(data)

    def gsea_rank_score(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Compute pathway rank score = sign(NES) * -log10(FDR-q) with tiny jitter.

        Args:
            df (pd.DataFrame): GSEA results table.

        Returns:
            pd.DataFrame: Ranked table indexed by term_col with columns rank_score and qval.

        Raises:
            KeyError: If required columns are missing.
        """
        np.random.seed(self.seed)

        q = pd.to_numeric(df[self.q_col], errors="coerce").clip(lower=0.001)
        nes = pd.to_numeric(df[self.nes_col], errors="coerce")

        rank_score = np.sign(nes) * (-np.log10(q))

        jitter = np.random.uniform(-1e-10, 1e-10, size=len(rank_score))
        rank_score += jitter

        rnk = (
            df[[self.term_col]]
            .assign(rank_score=rank_score, qval=q)
            .dropna()
            .groupby(self.term_col)
            .mean()
            .sort_values("rank_score", ascending=False)
        )
        return rnk

    @staticmethod
    def align_rank_scores(rnk_ori: pd.DataFrame, rnk_syn: pd.DataFrame) -> pd.DataFrame:
        """
        Inner join on pathway names.

        Args:
            rnk_ori (pd.DataFrame): Original rank table.
            rnk_syn (pd.DataFrame): Synthetic rank table.

        Returns:
            pd.DataFrame: Aligned rank table.
        """
        aligned = (
            rnk_ori.rename(columns={"rank_score": "rank_ori", "qval": "q_ori"}).join(
                rnk_syn.rename(columns={"rank_score": "rank_syn"}),
                how="inner",
            )
        )
        return aligned

    def pathway_concordance_score(self, df_rank: pd.DataFrame) -> Tuple[float, int, int, float]:
        """
        Compute PCS for aligned rank scores.

        Args:
            df_rank (pd.DataFrame): Aligned rank table with rank_ori, rank_syn, q_ori.

        Returns:
            tuple[float, int, int, float]: (PCS, N_sign, N_non_sign, M).
        """
        x = np.asarray(df_rank["rank_ori"].values)
        y = np.asarray(df_rank["rank_syn"].values)

        sign_threshold = -np.log10(self.q_thr)

        zone1_mask = (x >= -sign_threshold) & (x <= 0) & (y >= -sign_threshold) & (y <= 0)
        zone2_mask = (x >= 0) & (x <= sign_threshold) & (y >= 0) & (y <= sign_threshold)

        zone3_mask = (x <= -sign_threshold) & (y <= -sign_threshold)
        zone4_mask = (x >= sign_threshold) & (y >= sign_threshold)

        n_sign = int(np.sum(zone3_mask) + np.sum(zone4_mask))
        n_non_sign = int(np.sum(zone1_mask) + np.sum(zone2_mask))

        num_sign = int(np.sum(df_rank["q_ori"] < self.q_thr))
        num_non_sign = int(len(x) - num_sign)

        m = float(num_sign + self.w * num_non_sign)
        pcs = float((n_sign + self.w * n_non_sign) / m) if m > 0 else np.nan

        return pcs, n_sign, n_non_sign, m

    def process_single_gsea_result(
        self,
        gsea_ori: Union[str, "os.PathLike[str]", pd.DataFrame],
        gsea_syn: Union[str, "os.PathLike[str]", pd.DataFrame],
    ) -> Tuple[np.ndarray, np.ndarray, PCSResult]:
        """
        Process original and synthetic GSEA results (paths or DataFrames).

        Args:
            gsea_ori (str | PathLike | pd.DataFrame): Original GSEA table.
            gsea_syn (str | PathLike | pd.DataFrame): Synthetic GSEA table.

        Returns:
            tuple[np.ndarray, np.ndarray, PCSResult]: (x, y, result).
        """
        df_ori = self._load_df(gsea_ori)
        df_syn = self._load_df(gsea_syn)

        rnk_ori = self.gsea_rank_score(df_ori)
        rnk_syn = self.gsea_rank_score(df_syn)

        aligned = self.align_rank_scores(rnk_ori, rnk_syn)

        x = aligned["rank_ori"].values
        y = aligned["rank_syn"].values

        s_thr = -np.log10(self.q_thr)

        n_zone1 = int(((x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)).sum())
        n_zone2 = int(((x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)).sum())
        n_zone3 = int(((x <= -s_thr) & (y <= -s_thr)).sum())
        n_zone4 = int(((x >= s_thr) & (y >= s_thr)).sum())

        pcs, n_sign, n_non_sign, m = self.pathway_concordance_score(aligned)

        result = PCSResult(
            pcs=pcs,
            n_sign=n_sign,
            n_non_sign=n_non_sign,
            m=m,
            n_zone1=n_zone1,
            n_zone2=n_zone2,
            n_zone3=n_zone3,
            n_zone4=n_zone4,
            aligned_size=int(len(aligned)),
        )
        return x, y, result

    def plot_single_gsea_panel(
        self,
        ax: plt.Axes,
        x: np.ndarray,
        y: np.ndarray,
        result: PCSResult,
        tool_name: str,
        show_xlabel: bool = False,
    ) -> None:
        """
        Plot one PCS scatter panel (same geometry/logic as original script).

        Args:
            ax (matplotlib.axes.Axes): Matplotlib axis.
            x (np.ndarray): Original rank scores.
            y (np.ndarray): Synthetic rank scores.
            result (PCSResult): PCS result container.
            tool_name (str): Tool name for title.
            show_xlabel (bool): Whether to show x-axis label/ticks.

        Returns:
            None
        """
        s_thr = -np.log10(self.q_thr)
        lo, hi = -3, 3

        zone1 = (x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)
        zone2 = (x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)
        zone3 = (x <= -s_thr) & (y <= -s_thr)
        zone4 = (x >= s_thr) & (y >= s_thr)
        other = ~(zone1 | zone2 | zone3 | zone4)

        ax.add_patch(plt.Rectangle((-s_thr, -s_thr), s_thr, s_thr, color="#d0d0f4", alpha=0.4))
        ax.add_patch(plt.Rectangle((0, 0), s_thr, s_thr, color="#d0f4d0", alpha=0.4))
        ax.add_patch(plt.Rectangle((lo, lo), -s_thr - lo, -s_thr - lo, color="#89a5e6", alpha=0.2))
        ax.add_patch(plt.Rectangle((s_thr, s_thr), hi - s_thr, hi - s_thr, color="#66e699", alpha=0.2))

        ax.scatter(x[other], y[other], c="gray", alpha=0.3, s=20)
        ax.scatter(x[zone1], y[zone1], c="#38499b", s=60, edgecolor="k")
        ax.scatter(x[zone2], y[zone2], c="#239933", s=60, edgecolor="k")
        ax.scatter(x[zone3], y[zone3], c="#222e5c", s=70, edgecolor="k")
        ax.scatter(x[zone4], y[zone4], c="#13753a", s=70, edgecolor="k")

        ax.plot([lo, hi], [lo, hi], "--", c="black", lw=1)
        for v in [-s_thr, 0, s_thr]:
            ax.axhline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)
            ax.axvline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)

        ax.set_xlim(lo, hi)
        ax.set_ylim(lo, hi)

        if show_xlabel:
            ax.set_xlabel("Rank score (Original)")
        else:
            ax.set_xlabel("")
            ax.tick_params(axis="x", labelbottom=False)

        ax.set_ylabel("Rank score\n(Synthetic)")

        title_color = DATASET_COLORS.get(tool_name, "black")
        ax.set_title(
            f"{tool_name}\nPCS={result.pcs:.3f} (w={self.w})",
            fontsize=12,
            color=title_color,
            fontweight="bold",
        )

        ax.text(
            -2.9,
            2.8,
            f"NonSig LL: {result.n_zone1}\n"
            f"NonSig UR: {result.n_zone2}\n"
            f"Sig LL: {result.n_zone3}\n"
            f"Sig UR: {result.n_zone4}\n",
            fontsize=10,
            va="top",
            bbox=dict(fc="white", alpha=0.7, ec="none"),
        )

        ax.grid(True)

    def plot_gsea_datasets(
        self,
        ori_data: Union[str, "os.PathLike[str]", pd.DataFrame],
        dataset_dict: Mapping[str, Union[str, "os.PathLike[str]", pd.DataFrame]],
        figsize: Tuple[float, float] = (9, 10),
    ) -> Tuple[plt.Figure, Dict[str, float]]:
        """
        Plot multiple synthetic datasets against one original GSEA file.

        Args:
            ori_data (str | PathLike | pd.DataFrame): Original GSEA data or path.
            dataset_dict (Mapping[str, str | PathLike | pd.DataFrame]): Tool name -> synthetic data or path.
            figsize (tuple[float, float]): Figure size.

        Returns:
            tuple[matplotlib.figure.Figure, dict[str, float]]: (fig, pcs_dict).
        """
        df_ori = self._load_df(ori_data)

        fig, axes = plt.subplots(3, 2, figsize=figsize, constrained_layout=True)
        axes = axes.flatten()

        pcs_dict: Dict[str, float] = {}

        for idx, (tool, data) in enumerate(dataset_dict.items()):
            try:
                x, y, result = self.process_single_gsea_result(df_ori, data)

                row = idx // 2
                is_bottom = row == 2

                self.plot_single_gsea_panel(
                    axes[idx],
                    x,
                    y,
                    result,
                    tool,
                    show_xlabel=is_bottom,
                )

                pcs_dict[tool] = result.pcs
            except Exception as e:
                print(f"Skipping {tool}: {e}")

        for j in range(len(dataset_dict), len(axes)):
            fig.delaxes(axes[j])

        plt.tight_layout()
        plt.show()

        return fig, pcs_dict
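
A minimal usage sketch of PCSAnalyzer follows. The file names are placeholders, the tables are assumed to contain the default Term, NES, and FDR q-val columns, and the import path is inferred from the source-file location shown on this page.

from synomicsbench.metrics.narrow_utility.GSEA import PCSAnalyzer

analyzer = PCSAnalyzer(q_thr=0.05, w=0.5)

# Paths (or DataFrames) holding GSEA exports with "Term", "NES", "FDR q-val" columns.
x, y, result = analyzer.process_single_gsea_result(
    gsea_ori="gsea_original.csv",   # placeholder path
    gsea_syn="gsea_synthetic.csv",  # placeholder path
)
print(result.pcs, result.n_sign, result.n_non_sign, result.aligned_size)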

align_rank_scores(rnk_ori, rnk_syn) staticmethod

Inner join on pathway names.

Parameters:

- rnk_ori (DataFrame): Original rank table. Required.
- rnk_syn (DataFrame): Synthetic rank table. Required.

Returns:

- pd.DataFrame: Aligned rank table.

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
@staticmethod
def align_rank_scores(rnk_ori: pd.DataFrame, rnk_syn: pd.DataFrame) -> pd.DataFrame:
    """
    Inner join on pathway names.

    Args:
        rnk_ori (pd.DataFrame): Original rank table.
        rnk_syn (pd.DataFrame): Synthetic rank table.

    Returns:
        pd.DataFrame: Aligned rank table.
    """
    aligned = (
        rnk_ori.rename(columns={"rank_score": "rank_ori", "qval": "q_ori"}).join(
            rnk_syn.rename(columns={"rank_score": "rank_syn"}),
            how="inner",
        )
    )
    return aligned

gsea_rank_score(df)

Compute pathway rank score = sign(NES) * -log10(FDR-q) with tiny jitter.

Parameters:

- df (DataFrame): GSEA results table. Required.

Returns:

- pd.DataFrame: Ranked table indexed by term_col with columns rank_score and qval.

Raises:

- KeyError: If required columns are missing.

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
def gsea_rank_score(self, df: pd.DataFrame) -> pd.DataFrame:
    """
    Compute pathway rank score = sign(NES) * -log10(FDR-q) with tiny jitter.

    Args:
        df (pd.DataFrame): GSEA results table.

    Returns:
        pd.DataFrame: Ranked table indexed by term_col with columns rank_score and qval.

    Raises:
        KeyError: If required columns are missing.
    """
    np.random.seed(self.seed)

    q = pd.to_numeric(df[self.q_col], errors="coerce").clip(lower=0.001)
    nes = pd.to_numeric(df[self.nes_col], errors="coerce")

    rank_score = np.sign(nes) * (-np.log10(q))

    jitter = np.random.uniform(-1e-10, 1e-10, size=len(rank_score))
    rank_score += jitter

    rnk = (
        df[[self.term_col]]
        .assign(rank_score=rank_score, qval=q)
        .dropna()
        .groupby(self.term_col)
        .mean()
        .sort_values("rank_score", ascending=False)
    )
    return rnk
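
To make the rank-score definition concrete, here is a small self-contained sketch with hand-made values (column names match the constructor defaults; the import path is inferred from the source-file location shown on this page):

import pandas as pd
from synomicsbench.metrics.narrow_utility.GSEA import PCSAnalyzer

df = pd.DataFrame({
    "Term": ["HALLMARK_A", "HALLMARK_B", "HALLMARK_C"],
    "NES": [2.1, -1.5, 0.8],
    "FDR q-val": [0.001, 0.20, 0.80],
})

rnk = PCSAnalyzer().gsea_rank_score(df)
# rank_score = sign(NES) * -log10(q) plus negligible jitter, so approximately:
#   HALLMARK_A ->  3.000  (q clipped at 0.001)
#   HALLMARK_C ->  0.097
#   HALLMARK_B -> -0.699
print(rnk)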

pathway_concordance_score(df_rank)

Compute PCS for aligned rank scores.

Parameters:

- df_rank (DataFrame): Aligned rank table with rank_ori, rank_syn, q_ori. Required.

Returns:

- tuple[float, int, int, float]: (PCS, N_sign, N_non_sign, M).

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
def pathway_concordance_score(self, df_rank: pd.DataFrame) -> Tuple[float, int, int, float]:
    """
    Compute PCS for aligned rank scores.

    Args:
        df_rank (pd.DataFrame): Aligned rank table with rank_ori, rank_syn, q_ori.

    Returns:
        tuple[float, int, int, float]: (PCS, N_sign, N_non_sign, M).
    """
    x = np.asarray(df_rank["rank_ori"].values)
    y = np.asarray(df_rank["rank_syn"].values)

    sign_threshold = -np.log10(self.q_thr)

    zone1_mask = (x >= -sign_threshold) & (x <= 0) & (y >= -sign_threshold) & (y <= 0)
    zone2_mask = (x >= 0) & (x <= sign_threshold) & (y >= 0) & (y <= sign_threshold)

    zone3_mask = (x <= -sign_threshold) & (y <= -sign_threshold)
    zone4_mask = (x >= sign_threshold) & (y >= sign_threshold)

    n_sign = int(np.sum(zone3_mask) + np.sum(zone4_mask))
    n_non_sign = int(np.sum(zone1_mask) + np.sum(zone2_mask))

    num_sign = int(np.sum(df_rank["q_ori"] < self.q_thr))
    num_non_sign = int(len(x) - num_sign)

    m = float(num_sign + self.w * num_non_sign)
    pcs = float((n_sign + self.w * n_non_sign) / m) if m > 0 else np.nan

    return pcs, n_sign, n_non_sign, m
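
The sketch below builds a tiny aligned table by hand to show how the normalization works (illustrative numbers; import path inferred from the source-file location shown on this page). With q_thr = 0.05 the significance threshold on the rank scale is about 1.301.

import pandas as pd
from synomicsbench.metrics.narrow_utility.GSEA import PCSAnalyzer

df_rank = pd.DataFrame({
    "rank_ori": [2.0, -2.0, 0.5, -0.5, 2.0,  0.5],
    "rank_syn": [1.8, -1.6, 0.4, -0.2, 0.1, -0.5],
    "q_ori":    [0.01, 0.01, 0.3,  0.3, 0.01, 0.3],
})

pcs, n_sign, n_non_sign, m = PCSAnalyzer().pathway_concordance_score(df_rank)
# Original table: 3 significant + 3 non-significant pathways -> M = 3 + 0.5 * 3 = 4.5
# Concordant: 2 significant (zones 3/4) + 2 non-significant (zones 1/2)
# PCS = (2 + 0.5 * 2) / 4.5 ~= 0.667
print(pcs, n_sign, n_non_sign, m)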

plot_gsea_datasets(ori_data, dataset_dict, figsize=(9, 10))

Plot multiple synthetic datasets against one original GSEA file.

Parameters:

- ori_data (str | PathLike | DataFrame): Original GSEA data or path. Required.
- dataset_dict (Mapping[str, str | PathLike | DataFrame]): Tool name -> synthetic data or path. Required.
- figsize (tuple[float, float]): Figure size. Default: (9, 10).

Returns:

- tuple[matplotlib.figure.Figure, dict[str, float]]: (fig, pcs_dict).

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
def plot_gsea_datasets(
    self,
    ori_data: Union[str, "os.PathLike[str]", pd.DataFrame],
    dataset_dict: Mapping[str, Union[str, "os.PathLike[str]", pd.DataFrame]],
    figsize: Tuple[float, float] = (9, 10),
) -> Tuple[plt.Figure, Dict[str, float]]:
    """
    Plot multiple synthetic datasets against one original GSEA file.

    Args:
        ori_data (str | PathLike | pd.DataFrame): Original GSEA data or path.
        dataset_dict (Mapping[str, str | PathLike | pd.DataFrame]): Tool name -> synthetic data or path.
        figsize (tuple[float, float]): Figure size.

    Returns:
        tuple[matplotlib.figure.Figure, dict[str, float]]: (fig, pcs_dict).
    """
    df_ori = self._load_df(ori_data)

    fig, axes = plt.subplots(3, 2, figsize=figsize, constrained_layout=True)
    axes = axes.flatten()

    pcs_dict: Dict[str, float] = {}

    for idx, (tool, data) in enumerate(dataset_dict.items()):
        try:
            x, y, result = self.process_single_gsea_result(df_ori, data)

            row = idx // 2
            is_bottom = row == 2

            self.plot_single_gsea_panel(
                axes[idx],
                x,
                y,
                result,
                tool,
                show_xlabel=is_bottom,
            )

            pcs_dict[tool] = result.pcs
        except Exception as e:
            print(f"Skipping {tool}: {e}")

    for j in range(len(dataset_dict), len(axes)):
        fig.delaxes(axes[j])

    plt.tight_layout()
    plt.show()

    return fig, pcs_dict
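
A short sketch of the grid plot, assuming one original GSEA export and up to six synthetic ones (the 3 x 2 grid drops unused panels; paths are placeholders and the import path is inferred from the source-file location shown on this page):

from synomicsbench.metrics.narrow_utility.GSEA import PCSAnalyzer

analyzer = PCSAnalyzer()
fig, pcs_dict = analyzer.plot_gsea_datasets(
    ori_data="gsea_original.csv",      # placeholder path
    dataset_dict={
        "ToolA": "gsea_toolA.csv",     # placeholder paths
        "ToolB": "gsea_toolB.csv",
    },
)
print(pcs_dict)  # tool name -> PCS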

plot_single_gsea_panel(ax, x, y, result, tool_name, show_xlabel=False)

Plot one PCS scatter panel (same geometry/logic as original script).

Parameters:

- ax (Axes): Matplotlib axis. Required.
- x (ndarray): Original rank scores. Required.
- y (ndarray): Synthetic rank scores. Required.
- result (PCSResult): PCS result container. Required.
- tool_name (str): Tool name for title. Required.
- show_xlabel (bool): Whether to show x-axis label/ticks. Default: False.

Returns:

- None

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
def plot_single_gsea_panel(
    self,
    ax: plt.Axes,
    x: np.ndarray,
    y: np.ndarray,
    result: PCSResult,
    tool_name: str,
    show_xlabel: bool = False,
) -> None:
    """
    Plot one PCS scatter panel (same geometry/logic as original script).

    Args:
        ax (matplotlib.axes.Axes): Matplotlib axis.
        x (np.ndarray): Original rank scores.
        y (np.ndarray): Synthetic rank scores.
        result (PCSResult): PCS result container.
        tool_name (str): Tool name for title.
        show_xlabel (bool): Whether to show x-axis label/ticks.

    Returns:
        None
    """
    s_thr = -np.log10(self.q_thr)
    lo, hi = -3, 3

    zone1 = (x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)
    zone2 = (x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)
    zone3 = (x <= -s_thr) & (y <= -s_thr)
    zone4 = (x >= s_thr) & (y >= s_thr)
    other = ~(zone1 | zone2 | zone3 | zone4)

    ax.add_patch(plt.Rectangle((-s_thr, -s_thr), s_thr, s_thr, color="#d0d0f4", alpha=0.4))
    ax.add_patch(plt.Rectangle((0, 0), s_thr, s_thr, color="#d0f4d0", alpha=0.4))
    ax.add_patch(plt.Rectangle((lo, lo), -s_thr - lo, -s_thr - lo, color="#89a5e6", alpha=0.2))
    ax.add_patch(plt.Rectangle((s_thr, s_thr), hi - s_thr, hi - s_thr, color="#66e699", alpha=0.2))

    ax.scatter(x[other], y[other], c="gray", alpha=0.3, s=20)
    ax.scatter(x[zone1], y[zone1], c="#38499b", s=60, edgecolor="k")
    ax.scatter(x[zone2], y[zone2], c="#239933", s=60, edgecolor="k")
    ax.scatter(x[zone3], y[zone3], c="#222e5c", s=70, edgecolor="k")
    ax.scatter(x[zone4], y[zone4], c="#13753a", s=70, edgecolor="k")

    ax.plot([lo, hi], [lo, hi], "--", c="black", lw=1)
    for v in [-s_thr, 0, s_thr]:
        ax.axhline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)
        ax.axvline(v, ls=":", c="black", lw=1 if v != 0 else 0.5)

    ax.set_xlim(lo, hi)
    ax.set_ylim(lo, hi)

    if show_xlabel:
        ax.set_xlabel("Rank score (Original)")
    else:
        ax.set_xlabel("")
        ax.tick_params(axis="x", labelbottom=False)

    ax.set_ylabel("Rank score\n(Synthetic)")

    title_color = DATASET_COLORS.get(tool_name, "black")
    ax.set_title(
        f"{tool_name}\nPCS={result.pcs:.3f} (w={self.w})",
        fontsize=12,
        color=title_color,
        fontweight="bold",
    )

    ax.text(
        -2.9,
        2.8,
        f"NonSig LL: {result.n_zone1}\n"
        f"NonSig UR: {result.n_zone2}\n"
        f"Sig LL: {result.n_zone3}\n"
        f"Sig UR: {result.n_zone4}\n",
        fontsize=10,
        va="top",
        bbox=dict(fc="white", alpha=0.7, ec="none"),
    )

    ax.grid(True)

process_single_gsea_result(gsea_ori, gsea_syn)

Process original and synthetic GSEA results (paths or DataFrames).

Parameters:

- gsea_ori (str | PathLike | DataFrame): Original GSEA table. Required.
- gsea_syn (str | PathLike | DataFrame): Synthetic GSEA table. Required.

Returns:

- tuple[np.ndarray, np.ndarray, PCSResult]: (x, y, result).

Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
def process_single_gsea_result(
    self,
    gsea_ori: Union[str, "os.PathLike[str]", pd.DataFrame],
    gsea_syn: Union[str, "os.PathLike[str]", pd.DataFrame],
) -> Tuple[np.ndarray, np.ndarray, PCSResult]:
    """
    Process original and synthetic GSEA results (paths or DataFrames).

    Args:
        gsea_ori (str | PathLike | pd.DataFrame): Original GSEA table.
        gsea_syn (str | PathLike | pd.DataFrame): Synthetic GSEA table.

    Returns:
        tuple[np.ndarray, np.ndarray, PCSResult]: (x, y, result).
    """
    df_ori = self._load_df(gsea_ori)
    df_syn = self._load_df(gsea_syn)

    rnk_ori = self.gsea_rank_score(df_ori)
    rnk_syn = self.gsea_rank_score(df_syn)

    aligned = self.align_rank_scores(rnk_ori, rnk_syn)

    x = aligned["rank_ori"].values
    y = aligned["rank_syn"].values

    s_thr = -np.log10(self.q_thr)

    n_zone1 = int(((x >= -s_thr) & (x <= 0) & (y >= -s_thr) & (y <= 0)).sum())
    n_zone2 = int(((x >= 0) & (x <= s_thr) & (y >= 0) & (y <= s_thr)).sum())
    n_zone3 = int(((x <= -s_thr) & (y <= -s_thr)).sum())
    n_zone4 = int(((x >= s_thr) & (y >= s_thr)).sum())

    pcs, n_sign, n_non_sign, m = self.pathway_concordance_score(aligned)

    result = PCSResult(
        pcs=pcs,
        n_sign=n_sign,
        n_non_sign=n_non_sign,
        m=m,
        n_zone1=n_zone1,
        n_zone2=n_zone2,
        n_zone3=n_zone3,
        n_zone4=n_zone4,
        aligned_size=int(len(aligned)),
    )
    return x, y, result

PCSResult dataclass

Container for PCS computation outputs.

Parameters:

- pcs (float): Pathway Concordance Score. Required.
- n_sign (int): Number of concordant significant pathways (zones 3 + 4). Required.
- n_non_sign (int): Number of concordant non-significant pathways (zones 1 + 2). Required.
- m (float): Normalization constant M used in the PCS formula. Required.
- n_zone1 (int): Count in non-significant lower-left zone. Required.
- n_zone2 (int): Count in non-significant upper-right zone. Required.
- n_zone3 (int): Count in significant lower-left zone. Required.
- n_zone4 (int): Count in significant upper-right zone. Required.
- aligned_size (int): Number of pathways after alignment (inner join). Required.
Source code in src/synomicsbench/metrics/narrow_utility/GSEA.py
@dataclass(frozen=True)
class PCSResult:
    """
    Container for PCS computation outputs.

    Args:
        pcs (float): Pathway Concordance Score.
        n_sign (int): Number of concordant significant pathways (zones 3 + 4).
        n_non_sign (int): Number of concordant non-significant pathways (zones 1 + 2).
        m (float): Normalization constant M used in the PCS formula.
        n_zone1 (int): Count in non-significant lower-left zone.
        n_zone2 (int): Count in non-significant upper-right zone.
        n_zone3 (int): Count in significant lower-left zone.
        n_zone4 (int): Count in significant upper-right zone.
        aligned_size (int): Number of pathways after alignment (inner join).
    """
    pcs: float
    n_sign: int
    n_non_sign: int
    m: float
    n_zone1: int
    n_zone2: int
    n_zone3: int
    n_zone4: int
    aligned_size: int

cell_deconvolution

Aitchison distance for compositional cell-type deconvolution data.

Provides functions to compute the Aitchison distance between immune cell composition profiles estimated by CIBERSORTx (or similar tools) on original and synthetic datasets.

The Aitchison distance is computed in Centered Log-Ratio (CLR) space after multiplicative replacement of zeros and geometric mean centering.
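
For reference, with g(x) the geometric mean of a composition x = (x_1, ..., x_D), the transform and distance used here are the standard definitions:

\[
\operatorname{clr}(x) = \left(\ln\frac{x_1}{g(x)}, \ldots, \ln\frac{x_D}{g(x)}\right),
\qquad
d_{\mathrm{A}}(x, y) = \left\lVert \operatorname{clr}(x) - \operatorname{clr}(y) \right\rVert_2 .
\]

In this module, x and y are the compositional centers (per-cell-type geometric means) of the original and synthetic samples after multiplicative zero replacement.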

aitchison_distance(df_orig, df_syn, cell_types)

Compute the Aitchison distance between two cell-type composition datasets.

The distance is defined as the Euclidean distance between the CLR- transformed compositional centers of the original and synthetic datasets.

Parameters:

- df_orig (DataFrame): Original CIBERSORTx results (rows = samples, columns include the cell-type columns). Required.
- df_syn (DataFrame): Synthetic CIBERSORTx results with the same cell-type columns. Required.
- cell_types (List[str]): List of column names for the cell types to compare. Required.

Returns:

- float: Aitchison distance (non-negative float).

Source code in src/synomicsbench/metrics/narrow_utility/cell_deconvolution.py
def aitchison_distance(
    df_orig: pd.DataFrame,
    df_syn: pd.DataFrame,
    cell_types: List[str],
) -> float:
    """Compute the Aitchison distance between two cell-type composition datasets.

    The distance is defined as the Euclidean distance between the CLR-
    transformed compositional centers of the original and synthetic datasets.

    Args:
        df_orig: Original CIBERSORTx results (rows = samples, columns include
            the cell-type columns).
        df_syn: Synthetic CIBERSORTx results with the same cell-type columns.
        cell_types: List of column names for the cell types to compare.

    Returns:
        Aitchison distance (non-negative float).
    """
    data_orig = df_orig[cell_types].values.astype(float)
    data_syn = df_syn[cell_types].values.astype(float)

    # Replace zeros via multiplicative strategy
    data_orig_filled = multi_replace(data_orig)
    data_syn_filled = multi_replace(data_syn)

    # Compute compositional centers
    center_orig = geometric_center(data_orig_filled)
    center_syn = geometric_center(data_syn_filled)

    # CLR transform the centers
    clr_orig = clr(center_orig.reshape(1, -1)).flatten()
    clr_syn = clr(center_syn.reshape(1, -1)).flatten()

    # Euclidean distance in CLR space
    return euclidean(clr_orig, clr_syn)

aitchison_score(distance)

Convert Aitchison distance to a similarity score via exp(-d).

Higher values indicate better agreement. A distance of 0 yields a score of 1.0.

Parameters:

- distance (float): Non-negative Aitchison distance. Required.

Returns:

- float: Score in (0, 1].

Source code in src/synomicsbench/metrics/narrow_utility/cell_deconvolution.py
def aitchison_score(distance: float) -> float:
    """Convert Aitchison distance to a similarity score via exp(-d).

    Higher values indicate better agreement.  A distance of 0 yields a
    score of 1.0.

    Args:
        distance: Non-negative Aitchison distance.

    Returns:
        Score in (0, 1].
    """
    return float(np.exp(-distance))
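
A minimal sketch of the two functions together, assuming toy CIBERSORTx-style fraction tables (column names and values are illustrative; import path inferred from the source-file location shown on this page):

import pandas as pd
from synomicsbench.metrics.narrow_utility.cell_deconvolution import (
    aitchison_distance,
    aitchison_score,
)

cell_types = ["T cells CD8", "Macrophages M1", "NK cells"]  # illustrative columns
df_orig = pd.DataFrame(
    [[0.50, 0.30, 0.20],
     [0.55, 0.25, 0.20]],
    columns=cell_types,
)
df_syn = pd.DataFrame(
    [[0.45, 0.35, 0.20],
     [0.50, 0.30, 0.20]],
    columns=cell_types,
)

d = aitchison_distance(df_orig, df_syn, cell_types)
print(d, aitchison_score(d))  # distance >= 0, score = exp(-d) in (0, 1]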

geometric_center(X)

Compute the compositional center (geometric mean) for each component.

Parameters:

- X (ndarray): 2-D array of shape (n_samples, n_parts). Required.

Returns:

- ndarray: 1-D array of length n_parts representing the geometric mean per part.

Source code in src/synomicsbench/metrics/narrow_utility/cell_deconvolution.py
def geometric_center(X: np.ndarray) -> np.ndarray:
    """Compute the compositional center (geometric mean) for each component.

    Args:
        X: 2-D array of shape (n_samples, n_parts).

    Returns:
        1-D array of length n_parts representing the geometric mean per part.
    """
    return np.exp(np.mean(np.log(X), axis=0))
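
A quick numeric check of the column-wise geometric mean (hand-made values; import path inferred from the source-file location shown on this page):

import numpy as np
from synomicsbench.metrics.narrow_utility.cell_deconvolution import geometric_center

X = np.array([[0.2, 0.8],
              [0.8, 0.2]])
print(geometric_center(X))  # [0.4 0.4], since sqrt(0.2 * 0.8) = 0.4 for each column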

survival_analysis

SurvivalEvaluator

Perform survival analysis and visualization across multiple datasets comparing two phenotype groups.

Parameters:

- datasets_dict (Dict[str, DataFrame]): Mapping from dataset names to DataFrames. Required.
- phenotype (Dict[str, List[Any]]): {column_name: [value_A, value_B]}, specifying the phenotype column and comparison values. Required.
- time_target (str): Name of the survival time column. Default: 'OS'.
- event_target (str): Name of the event indicator column. Default: 'OS_CNSR'.
- dataset_order (Optional[List[str]]): Custom plotting order for datasets. If None, use input order. Default: None.
- dataset_colors (Optional[Dict[str, str]]): Colors for dataset annotation strips. Default: None.
- group_colors (Optional[Dict[str, str]]): Colors for phenotype groups (A, B). Default: None.
- font_scale (float): Global scaling for plot text sizes. Default: 1.0.

Returns:

- None

Raises:

- ValueError: If datasets_dict is empty or phenotype is not a {col: [A,B]} dict.

Source code in src/synomicsbench/metrics/narrow_utility/survival_analysis.py
class SurvivalEvaluator:
    """
    Perform survival analysis and visualization across multiple datasets
    comparing two phenotype groups.

    Args:
        datasets_dict (Dict[str, pd.DataFrame]): Mapping from dataset names to DataFrames.
        phenotype (Dict[str, List[Any]]): {column_name: [value_A, value_B]}, specifying the phenotype column and comparison values.
        time_target (str): Name of the survival time column.
        event_target (str): Name of the event indicator column.
        dataset_order (Optional[List[str]]): Custom plotting order for datasets. If None, use input order.
        dataset_colors (Optional[Dict[str, str]]): Colors for dataset annotation strips.
        group_colors (Optional[Dict[str, str]]): Colors for phenotype groups (A, B).
        font_scale (float): Global scaling for plot text sizes.

    Returns:
        None

    Raises:
        ValueError: If datasets_dict is empty or phenotype is not a {col: [A,B]} dict.
    """

    def __init__(
        self,
        datasets_dict: Dict[str, pd.DataFrame],
        phenotype: Dict[str, List[Any]],
        time_target: str = "OS",
        event_target: str = "OS_CNSR",
        dataset_order: Optional[List[str]] = None,
        dataset_colors: Optional[Dict[str, str]] = None,
        group_colors: Optional[Dict[str, str]] = None,
        font_scale: float = 1.0,
        is_pdf: bool = False,
        original_name: str = "Origin"
    ) -> None:
        """
        Initialize SurvivalGridEvaluator for grid-based survival comparison.

        Args:
            datasets_dict (Dict[str, pd.DataFrame]): Input datasets.
            phenotype (Dict[str, List[Any]]): {col: [value_A, value_B]}.
            time_target (str): Survival duration column.
            event_target (str): Event indicator column.
            dataset_order (List[str], optional): Custom dataset plotting order.
            dataset_colors (Dict[str, str], optional): Strip colors per dataset.
            group_colors (Dict[str, str], optional): Colors for phenotype groups.
            font_scale (float): Plot font scaling.
            is_pdf (boolean): save as pdf or png.
            original_name (str): Name of the reference dataset (default: "Origin").

        Returns:
            None

        Raises:
            ValueError: On empty datasets or invalid phenotype specification.
        """
        if not datasets_dict:
            raise ValueError("datasets_dict must not be empty.")
        if not isinstance(phenotype, dict) or len(phenotype) != 1:
            raise ValueError("phenotype must be a dict with one mapping {column_name: [valA, valB]}.")
        ph_column, ph_values = next(iter(phenotype.items()))
        if not isinstance(ph_values, (list, tuple)) or len(ph_values) != 2:
            raise ValueError("phenotype value must be a list/tuple of exactly two values: [value_A, value_B].")
        self.datasets_dict = datasets_dict
        self.phenotype_column = ph_column
        self.value_A, self.value_B = ph_values
        self.time_target = time_target
        self.event_target = event_target
        self.dataset_order = dataset_order
        self.dataset_colors = dataset_colors or DATASET_COLORS.copy()
        self.group_colors = group_colors or GROUP_COLORS.copy()
        self.font_scale = font_scale
        self.is_pdf = is_pdf
        self.original_name = original_name
        self.dataset_names = self._get_dataset_order()
        self.summary_df = None

    def _get_dataset_order(self) -> List[str]:
        """
        Determine the dataset order for visualization.

        Returns:
            List[str]: Dataset names in desired plotting order.

        Raises:
            ValueError: If dataset_order contains a missing dataset.
        """
        if self.dataset_order is None:
            order = list(self.datasets_dict.keys())
        else:
            missing = [d for d in self.dataset_order if d not in self.datasets_dict]
            if missing:
                raise ValueError(f"dataset_order contains datasets absent from datasets_dict: {missing}")
            order = self.dataset_order
        order = [d for d in [self.original_name] if d in order] + [d for d in order if d != self.original_name]
        return order

    def compute_survival_metrics(self) -> pd.DataFrame:
        """
        Compute log-rank test p-values and C-index for each dataset grid panel.

        Args:
            None

        Returns:
            pd.DataFrame: DataFrame with columns ['Dataset', 'pvalue', 'C-index', 'n_A', 'n_B'].
        """
        summary_rows = []
        for ds_name in self.dataset_names:
            df_raw = self.datasets_dict[ds_name].copy()
            df = df_raw.copy()
            df[self.event_target] = pd.to_numeric(df.get(self.event_target, 0), errors="coerce").fillna(0).astype(int)
            df[self.time_target] = pd.to_numeric(df.get(self.time_target, np.nan), errors="coerce")
            df_A = df[df[self.phenotype_column] == self.value_A]
            df_B = df[df[self.phenotype_column] == self.value_B]

            n_A, n_B = len(df_A), len(df_B)
            pvalue = np.nan
            try:
                if n_A > 0 and n_B > 0:
                    gA = df_A.dropna(subset=[self.time_target, self.event_target])
                    gB = df_B.dropna(subset=[self.time_target, self.event_target])
                    if len(gA) > 0 and len(gB) > 0:
                        res = logrank_test(
                            gA[self.time_target], gB[self.time_target],
                            event_observed_A=gA[self.event_target], event_observed_B=gB[self.event_target]
                        )
                        pvalue = float(res.p_value)
            except Exception:
                pvalue = np.nan

            cindex_text = "NA"
            try:
                df_cox = df[[self.time_target, self.event_target]].copy()
                df_cox["phenotype_binary"] = pd.NA
                df_cox.loc[df[self.phenotype_column] == self.value_A, "phenotype_binary"] = 1
                df_cox.loc[df[self.phenotype_column] == self.value_B, "phenotype_binary"] = 0
                df_cox_fit = df_cox.dropna(subset=[self.time_target, "phenotype_binary"])
                df_cox_fit[self.time_target] = pd.to_numeric(df_cox_fit[self.time_target], errors="coerce")
                df_cox_fit = df_cox_fit.dropna(subset=[self.time_target])
                df_cox_fit["phenotype_binary"] = df_cox_fit["phenotype_binary"].astype(int)
                if (
                    len(df_cox_fit) > 0
                    and df_cox_fit["phenotype_binary"].nunique() > 1
                    and df_cox_fit[self.event_target].sum() > 0
                ):
                    cph = CoxPHFitter()
                    cph.fit(df_cox_fit, duration_col=self.time_target, event_col=self.event_target, show_progress=False)
                    partial_h = cph.predict_partial_hazard(df_cox_fit)
                    cindex = concordance_index(df_cox_fit[self.time_target], -partial_h, df_cox_fit[self.event_target])
                    cindex_text = f"{cindex:.3f}"
                else:
                    cindex_text = "insufficient events"
            except Exception:
                cindex_text = "fit error"

            summary_rows.append(
                {
                    "Dataset": ds_name,
                    "pvalue": (np.nan if (pvalue is None or np.isnan(pvalue)) else float(pvalue)),
                    "C-index": cindex_text,
                    "n_A": int(n_A),
                    "n_B": int(n_B),
                }
            )
        summary_df = pd.DataFrame(summary_rows, columns=["Dataset", "pvalue", "C-index", "n_A", "n_B"])
        summary_df["Dataset"] = pd.Categorical(summary_df["Dataset"], categories=self.dataset_names, ordered=True)
        summary_df = summary_df.sort_values("Dataset").reset_index(drop=True)
        self.summary_df = summary_df
        return summary_df

    def compute_cindex_scores(self, original_name: Optional[str] = None) -> pd.DataFrame:
        """
        Compute C-index similarity scores between the reference and synthetic datasets.

        Args:
            original_name (Optional[str]): Reference dataset for score calculation. 
                If None, uses self.original_name.

        Returns:
            pd.DataFrame: DataFrame with additional column 'C-index_score'.

        Raises:
            KeyError: If required columns or reference row are missing.
            RuntimeError: If compute_survival_metrics() was not called prior.
        """
        if self.summary_df is None:
            raise RuntimeError("Must call compute_survival_metrics() before scoring.")
        df = self.summary_df.copy()
        ref_name = original_name or self.original_name
        if "Dataset" not in df.columns:
            raise KeyError("'Dataset' column not found.")
        if ref_name not in df["Dataset"].values:
            raise KeyError(f"Reference dataset '{ref_name}' not found in summary dataframe.")
        parsed_cindex = []
        for val in df["C-index"]:
            try:
                f = float(str(val).split()[0])
                if not math.isfinite(f):
                    f = np.nan
            except Exception:
                f = np.nan
            parsed_cindex.append(f)
        df["_cindex_parsed"] = parsed_cindex
        c_orig = df.loc[df["Dataset"] == ref_name, "_cindex_parsed"].iloc[0]
        scores = []
        for c_syn in df["_cindex_parsed"]:
            if pd.isna(c_orig) or pd.isna(c_syn):
                score = np.nan
            else:
                diff = abs(c_orig - c_syn)
                score = 1.0 - diff
                score = max(0.0, min(1.0, score))
            scores.append(score)
        df["C-index_score"] = scores
        df = df.drop(columns=["_cindex_parsed"])
        return df

    def plot_grid(
        self,
        figsize: Optional[Tuple[float, float]] = (18, 6),
        show_censors: bool = True,
        ci_show: bool = False,
        title_prefix: Optional[str] = "Survival",
        dataset_strip_height: float = 0.02,
        dataset_strip_y: float = 1.01,
        save_dir: Optional[str] = None
    ) -> Tuple[plt.Figure, pd.DataFrame]:
        """
        Plot a manuscript-style grid of Kaplan–Meier survival curves for all datasets.

        Args:
            figsize (Tuple[float, float], optional): Figure dimensions (W, H).
            show_censors (bool): Whether to display censor marks.
            ci_show (bool): Whether to render CI for KM curves.
            title_prefix (str, optional): Prefix for subplot titles.
            dataset_strip_height (float): Height of colored dataset strip.
            dataset_strip_y (float): Y-position of top dataset strip.
            save_dir (str, optional): If set, save individual dataset KM curves to this folder.

        Returns:
            Tuple[plt.Figure, pd.DataFrame]: The matplotlib Figure and the summary dataframe.

        Raises:
            None
        """
        plt.style.use(["science", "nature", "notebook"])
        plt.rcParams.update({
            "font.size": 10 * self.font_scale,
            "axes.titlesize": 11 * self.font_scale,
            "axes.labelsize": 10 * self.font_scale,
            "legend.fontsize": 9 * self.font_scale,
            "xtick.labelsize": 9 * self.font_scale,
            "ytick.labelsize": 9 * self.font_scale,
            "axes.linewidth": 1.0,
            "text.usetex": False,
            "axes.edgecolor": "#333333",
            "xtick.minor.visible": False,
            "ytick.minor.visible": False,
            "xtick.top": False,
            "ytick.right": False,
        })
        fig = plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(2, 4, figure=fig, wspace=0.1, hspace=0.3,
                               width_ratios=[2, 1.25, 1.25, 1.25], height_ratios=[1, 1])
        ax_origin = fig.add_subplot(gs[:, 0])
        ax_origin.set_box_aspect(1)
        axes_small = []
        for r in range(2):
            for c in range(1, 4):
                ax = fig.add_subplot(gs[r, c])
                axes_small.append(ax)
        max_small_axes = len(axes_small)
        synthetic_names = [d for d in self.dataset_names if d != self.original_name]
        color_A = self.group_colors.get("GroupA", "#0072B2")
        color_B = self.group_colors.get("GroupB", "#D55E00")

        def _plot_single(ds_name: str, ax: plt.Axes, is_large: bool = False, save: bool = False) -> None:
            """
            Plot KM curve for a single dataset.

            Args:
                ds_name (str): Dataset name.
                ax (plt.Axes): Matplotlib axes.
                is_large (bool): If True, use larger annotation and legend.
                save (bool): If True, save individual plot.

            Returns:
                None
            """
            df_raw = self.datasets_dict[ds_name].copy()
            df = df_raw.copy()
            df[self.event_target] = pd.to_numeric(df.get(self.event_target, 0), errors="coerce").fillna(0).astype(int)
            df[self.time_target] = pd.to_numeric(df.get(self.time_target, np.nan), errors="coerce")
            df_A = df[df[self.phenotype_column] == self.value_A].copy()
            df_B = df[df[self.phenotype_column] == self.value_B].copy()
            n_A, n_B = len(df_A), len(df_B)
            ds_color = self.dataset_colors.get(ds_name, "#cccccc")
            ax.add_patch(
                plt.Rectangle(
                    (0, dataset_strip_y),
                    1,
                    dataset_strip_height,
                    transform=ax.transAxes,
                    facecolor=ds_color,
                    edgecolor="none",
                    clip_on=False,
                    zorder=10,
                )
            )
            kmf = KaplanMeierFitter()
            plotted_any = False
            for group_df, label, group_color in [
                (df_B, f"{self.value_B} (n={n_B})", color_B),
                (df_A, f"{self.value_A} (n={n_A})", color_A),
            ]:
                if len(group_df) == 0 or not group_df[self.time_target].notna().any():
                    continue
                g = group_df.dropna(subset=[self.time_target, self.event_target])
                if len(g) == 0:
                    continue
                kmf.fit(g[self.time_target], event_observed=g[self.event_target], label=label)
                kmf.plot_survival_function(
                    ax=ax,
                    ci_show=ci_show,
                    show_censors=show_censors,
                    color=group_color,
                    linewidth=2.0 if is_large else 1.8,
                    censor_styles={"ms": 4, "marker": "|", "mew": 1.2} if show_censors else None,
                )
                plotted_any = True
            # Annotations
            pvalue = self.summary_df.loc[self.summary_df["Dataset"] == ds_name, "pvalue"].values[0] \
                if self.summary_df is not None else np.nan
            cindex_text = self.summary_df.loc[self.summary_df["Dataset"] == ds_name, "C-index"].values[0] \
                if self.summary_df is not None else "NA"
            ptext = "p = NA" if (pvalue is None or np.isnan(pvalue)) else f"p = {pvalue:.4g}"
            ax.text(
                0.98, 0.96, ptext, transform=ax.transAxes,
                fontsize=10 * self.font_scale,
                horizontalalignment="right", verticalalignment="top", zorder=10,
            )
            ax.text(
                0.98, 0.88, f"C-index: {cindex_text}", transform=ax.transAxes,
                fontsize=10 * self.font_scale,
                horizontalalignment="right", verticalalignment="top", zorder=10,
            )
            title_t = f"{ds_name}" if title_prefix is None else f"{title_prefix} — {ds_name}"
            ax.set_title(title_t, fontweight="bold", pad=12)
            ax.set_xlabel(f"Time ({self.time_target})")
            ax.set_ylabel("Survival probability" if is_large else "")
            ax.set_ylim(0, 1.02)
            ax.grid(True, axis="y", linestyle="--", linewidth=0.6, alpha=0.5)
            if plotted_any:
                ax.legend(
                    frameon=True, framealpha=1.0, loc="lower left",
                    bbox_to_anchor=(0, 0.02), borderaxespad=0.2,
                )
            else:
                ax.text(
                    0.5, 0.5, "No valid survival data",
                    ha="center", va="center", transform=ax.transAxes,
                    fontsize=11 * self.font_scale,
                )
                leg = ax.get_legend()
                if leg is not None:
                    leg.remove()
            if save and save_dir:
                fig_indiv = plt.figure(figsize=(6, 5))
                ax_indiv = fig_indiv.add_subplot(111)
                _plot_single(ds_name, ax_indiv, is_large=True, save=False)
                fig_indiv.tight_layout()
                if self.is_pdf:
                    outpath = os.path.join(save_dir, f"{ds_name}_KMcurve.pdf")
                else: 
                    outpath = os.path.join(save_dir, f"{ds_name}_KMcurve.png")
                fig_indiv.savefig(outpath, dpi=300)
                plt.close(fig_indiv)

        _plot_single(self.original_name, ax_origin, is_large=True, save=bool(save_dir))
        for i, ds_name in enumerate(synthetic_names):
            if i >= max_small_axes:
                break
            _plot_single(ds_name, axes_small[i], is_large=False, save=bool(save_dir))
        for j in range(len(synthetic_names), max_small_axes):
            axes_small[j].axis("off")
        plt.tight_layout()
        return fig, self.summary_df

__init__(datasets_dict, phenotype, time_target='OS', event_target='OS_CNSR', dataset_order=None, dataset_colors=None, group_colors=None, font_scale=1.0, is_pdf=False, original_name='Origin')

Initialize SurvivalGridEvaluator for grid-based survival comparison.

Parameters:

    datasets_dict (Dict[str, DataFrame]): Input datasets. Required.
    phenotype (Dict[str, List[Any]]): {col: [value_A, value_B]}. Required.
    time_target (str): Survival duration column. Default: 'OS'.
    event_target (str): Event indicator column. Default: 'OS_CNSR'.
    dataset_order (List[str]): Custom dataset plotting order. Default: None.
    dataset_colors (Dict[str, str]): Strip colors per dataset. Default: None.
    group_colors (Dict[str, str]): Colors for phenotype groups. Default: None.
    font_scale (float): Plot font scaling. Default: 1.0.
    is_pdf (bool): Save individual curves as PDF if True, otherwise PNG. Default: False.
    original_name (str): Name of the reference dataset. Default: 'Origin'.

Returns:

    None

Raises:

    ValueError: On empty datasets or invalid phenotype specification.

Source code in src/synomicsbench/metrics/narrow_utility/survival_analysis.py
def __init__(
    self,
    datasets_dict: Dict[str, pd.DataFrame],
    phenotype: Dict[str, List[Any]],
    time_target: str = "OS",
    event_target: str = "OS_CNSR",
    dataset_order: Optional[List[str]] = None,
    dataset_colors: Optional[Dict[str, str]] = None,
    group_colors: Optional[Dict[str, str]] = None,
    font_scale: float = 1.0,
    is_pdf: bool = False,
    original_name: str = "Origin"
) -> None:
    """
    Initialize SurvivalGridEvaluator for grid-based survival comparison.

    Args:
        datasets_dict (Dict[str, pd.DataFrame]): Input datasets.
        phenotype (Dict[str, List[Any]]): {col: [value_A, value_B]}.
        time_target (str): Survival duration column.
        event_target (str): Event indicator column.
        dataset_order (List[str], optional): Custom dataset plotting order.
        dataset_colors (Dict[str, str], optional): Strip colors per dataset.
        group_colors (Dict[str, str], optional): Colors for phenotype groups.
        font_scale (float): Plot font scaling.
        is_pdf (bool): Save individual curves as PDF if True, otherwise PNG.
        original_name (str): Name of the reference dataset (default: "Origin").

    Returns:
        None

    Raises:
        ValueError: On empty datasets or invalid phenotype specification.
    """
    if not datasets_dict:
        raise ValueError("datasets_dict must not be empty.")
    if not isinstance(phenotype, dict) or len(phenotype) != 1:
        raise ValueError("phenotype must be a dict with one mapping {column_name: [valA, valB]}.")
    ph_column, ph_values = next(iter(phenotype.items()))
    if not isinstance(ph_values, (list, tuple)) or len(ph_values) != 2:
        raise ValueError("phenotype value must be a list/tuple of exactly two values: [value_A, value_B].")
    self.datasets_dict = datasets_dict
    self.phenotype_column = ph_column
    self.value_A, self.value_B = ph_values
    self.time_target = time_target
    self.event_target = event_target
    self.dataset_order = dataset_order
    self.dataset_colors = dataset_colors or DATASET_COLORS.copy()
    self.group_colors = group_colors or GROUP_COLORS.copy()
    self.font_scale = font_scale
    self.is_pdf = is_pdf
    self.original_name = original_name
    self.dataset_names = self._get_dataset_order()
    self.summary_df = None
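
A minimal construction sketch (hypothetical toy data; the import path is inferred from the source location above, and the dataset and column names are illustrative):

import pandas as pd
from synomicsbench.metrics.narrow_utility.survival_analysis import SurvivalGridEvaluator  # assumed import path

# Toy frames with a survival time, an event indicator, and one binary phenotype column
real = pd.DataFrame({"OS": [12.0, 30.0, 7.0, 45.0], "OS_CNSR": [1, 0, 1, 0], "SEX": ["M", "F", "M", "F"]})
syn = pd.DataFrame({"OS": [10.0, 28.0, 9.0, 40.0], "OS_CNSR": [1, 0, 0, 0], "SEX": ["F", "M", "M", "F"]})

evaluator = SurvivalGridEvaluator(
    datasets_dict={"Origin": real, "SynthA": syn},
    phenotype={"SEX": ["M", "F"]},  # exactly one column mapped to [value_A, value_B]
    time_target="OS",
    event_target="OS_CNSR",
)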

compute_cindex_scores(original_name=None)

Compute C-index similarity scores between the reference and synthetic datasets.

Parameters:

    original_name (Optional[str]): Reference dataset for score calculation. If None, uses self.original_name. Default: None.

Returns:

    pd.DataFrame: DataFrame with additional column 'C-index_score'.

Raises:

    KeyError: If required columns or reference row are missing.
    RuntimeError: If compute_survival_metrics() was not called prior.

Source code in src/synomicsbench/metrics/narrow_utility/survival_analysis.py
def compute_cindex_scores(self, original_name: Optional[str] = None) -> pd.DataFrame:
    """
    Compute C-index similarity scores between the reference and synthetic datasets.

    Args:
        original_name (Optional[str]): Reference dataset for score calculation. 
            If None, uses self.original_name.

    Returns:
        pd.DataFrame: DataFrame with additional column 'C-index_score'.

    Raises:
        KeyError: If required columns or reference row are missing.
        RuntimeError: If compute_survival_metrics() was not called prior.
    """
    if self.summary_df is None:
        raise RuntimeError("Must call compute_survival_metrics() before scoring.")
    df = self.summary_df.copy()
    ref_name = original_name or self.original_name
    if "Dataset" not in df.columns:
        raise KeyError("'Dataset' column not found.")
    if ref_name not in df["Dataset"].values:
        raise KeyError(f"Reference dataset '{ref_name}' not found in summary dataframe.")
    parsed_cindex = []
    for val in df["C-index"]:
        try:
            f = float(str(val).split()[0])
            if not math.isfinite(f):
                f = np.nan
        except Exception:
            f = np.nan
        parsed_cindex.append(f)
    df["_cindex_parsed"] = parsed_cindex
    c_orig = df.loc[df["Dataset"] == ref_name, "_cindex_parsed"].iloc[0]
    scores = []
    for c_syn in df["_cindex_parsed"]:
        if pd.isna(c_orig) or pd.isna(c_syn):
            score = np.nan
        else:
            diff = abs(c_orig - c_syn)
            score = 1.0 - diff
            score = max(0.0, min(1.0, score))
        scores.append(score)
    df["C-index_score"] = scores
    df = df.drop(columns=["_cindex_parsed"])
    return df
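
A hedged usage sketch, continuing from the constructed evaluator above; compute_survival_metrics(), documented below, must run first so that summary_df is populated:

summary = evaluator.compute_survival_metrics()   # populates evaluator.summary_df
scored = evaluator.compute_cindex_scores()       # scores each dataset against "Origin"
print(scored[["Dataset", "C-index", "C-index_score"]])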

compute_survival_metrics()

Compute log-rank test p-values and C-index for each dataset grid panel.

Returns:

    pd.DataFrame: DataFrame with columns ['Dataset', 'pvalue', 'C-index', 'n_A', 'n_B'].

Source code in src/synomicsbench/metrics/narrow_utility/survival_analysis.py
def compute_survival_metrics(self) -> pd.DataFrame:
    """
    Compute log-rank test p-values and C-index for each dataset grid panel.

    Args:
        None

    Returns:
        pd.DataFrame: DataFrame with columns ['Dataset', 'pvalue', 'C-index', 'n_A', 'n_B'].
    """
    summary_rows = []
    for ds_name in self.dataset_names:
        df_raw = self.datasets_dict[ds_name].copy()
        df = df_raw.copy()
        df[self.event_target] = pd.to_numeric(df.get(self.event_target, 0), errors="coerce").fillna(0).astype(int)
        df[self.time_target] = pd.to_numeric(df.get(self.time_target, np.nan), errors="coerce")
        df_A = df[df[self.phenotype_column] == self.value_A]
        df_B = df[df[self.phenotype_column] == self.value_B]

        n_A, n_B = len(df_A), len(df_B)
        pvalue = np.nan
        try:
            if n_A > 0 and n_B > 0:
                gA = df_A.dropna(subset=[self.time_target, self.event_target])
                gB = df_B.dropna(subset=[self.time_target, self.event_target])
                if len(gA) > 0 and len(gB) > 0:
                    res = logrank_test(
                        gA[self.time_target], gB[self.time_target],
                        event_observed_A=gA[self.event_target], event_observed_B=gB[self.event_target]
                    )
                    pvalue = float(res.p_value)
        except Exception:
            pvalue = np.nan

        cindex_text = "NA"
        try:
            df_cox = df[[self.time_target, self.event_target]].copy()
            df_cox["phenotype_binary"] = pd.NA
            df_cox.loc[df[self.phenotype_column] == self.value_A, "phenotype_binary"] = 1
            df_cox.loc[df[self.phenotype_column] == self.value_B, "phenotype_binary"] = 0
            df_cox_fit = df_cox.dropna(subset=[self.time_target, "phenotype_binary"])
            df_cox_fit[self.time_target] = pd.to_numeric(df_cox_fit[self.time_target], errors="coerce")
            df_cox_fit = df_cox_fit.dropna(subset=[self.time_target])
            df_cox_fit["phenotype_binary"] = df_cox_fit["phenotype_binary"].astype(int)
            if (
                len(df_cox_fit) > 0
                and df_cox_fit["phenotype_binary"].nunique() > 1
                and df_cox_fit[self.event_target].sum() > 0
            ):
                cph = CoxPHFitter()
                cph.fit(df_cox_fit, duration_col=self.time_target, event_col=self.event_target, show_progress=False)
                partial_h = cph.predict_partial_hazard(df_cox_fit)
                cindex = concordance_index(df_cox_fit[self.time_target], -partial_h, df_cox_fit[self.event_target])
                cindex_text = f"{cindex:.3f}"
            else:
                cindex_text = "insufficient events"
        except Exception:
            cindex_text = "fit error"

        summary_rows.append(
            {
                "Dataset": ds_name,
                "pvalue": (np.nan if (pvalue is None or np.isnan(pvalue)) else float(pvalue)),
                "C-index": cindex_text,
                "n_A": int(n_A),
                "n_B": int(n_B),
            }
        )
    summary_df = pd.DataFrame(summary_rows, columns=["Dataset", "pvalue", "C-index", "n_A", "n_B"])
    summary_df["Dataset"] = pd.Categorical(summary_df["Dataset"], categories=self.dataset_names, ordered=True)
    summary_df = summary_df.sort_values("Dataset").reset_index(drop=True)
    self.summary_df = summary_df
    return summary_df
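
Illustrative call with the same hypothetical evaluator as above (the output file name is made up):

summary = evaluator.compute_survival_metrics()
# One row per dataset: log-rank p-value, Cox C-index (as text), and per-group sample sizes
summary.to_csv("survival_summary.csv", index=False)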

plot_grid(figsize=(18, 6), show_censors=True, ci_show=False, title_prefix='Survival', dataset_strip_height=0.02, dataset_strip_y=1.01, save_dir=None)

Plot a manuscript-style grid of Kaplan–Meier survival curves for all datasets.

Parameters:

    figsize (Tuple[float, float]): Figure dimensions (W, H). Default: (18, 6).
    show_censors (bool): Whether to display censor marks. Default: True.
    ci_show (bool): Whether to render CI for KM curves. Default: False.
    title_prefix (str): Prefix for subplot titles. Default: 'Survival'.
    dataset_strip_height (float): Height of colored dataset strip. Default: 0.02.
    dataset_strip_y (float): Y-position of top dataset strip. Default: 1.01.
    save_dir (str): If set, save individual dataset KM curves to this folder. Default: None.

Returns:

    Tuple[plt.Figure, pd.DataFrame]: The matplotlib Figure and the summary dataframe.

Source code in src/synomicsbench/metrics/narrow_utility/survival_analysis.py
def plot_grid(
    self,
    figsize: Optional[Tuple[float, float]] = (18, 6),
    show_censors: bool = True,
    ci_show: bool = False,
    title_prefix: Optional[str] = "Survival",
    dataset_strip_height: float = 0.02,
    dataset_strip_y: float = 1.01,
    save_dir: Optional[str] = None
) -> Tuple[plt.Figure, pd.DataFrame]:
    """
    Plot a manuscript-style grid of Kaplan–Meier survival curves for all datasets.

    Args:
        figsize (Tuple[float, float], optional): Figure dimensions (W, H).
        show_censors (bool): Whether to display censor marks.
        ci_show (bool): Whether to render CI for KM curves.
        title_prefix (str, optional): Prefix for subplot titles.
        dataset_strip_height (float): Height of colored dataset strip.
        dataset_strip_y (float): Y-position of top dataset strip.
        save_dir (str, optional): If set, save individual dataset KM curves to this folder.

    Returns:
        Tuple[plt.Figure, pd.DataFrame]: The matplotlib Figure and the summary dataframe.

    Raises:
        None
    """
    plt.style.use(["science", "nature", "notebook"])
    plt.rcParams.update({
        "font.size": 10 * self.font_scale,
        "axes.titlesize": 11 * self.font_scale,
        "axes.labelsize": 10 * self.font_scale,
        "legend.fontsize": 9 * self.font_scale,
        "xtick.labelsize": 9 * self.font_scale,
        "ytick.labelsize": 9 * self.font_scale,
        "axes.linewidth": 1.0,
        "text.usetex": False,
        "axes.edgecolor": "#333333",
        "xtick.minor.visible": False,
        "ytick.minor.visible": False,
        "xtick.top": False,
        "ytick.right": False,
    })
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(2, 4, figure=fig, wspace=0.1, hspace=0.3,
                           width_ratios=[2, 1.25, 1.25, 1.25], height_ratios=[1, 1])
    ax_origin = fig.add_subplot(gs[:, 0])
    ax_origin.set_box_aspect(1)
    axes_small = []
    for r in range(2):
        for c in range(1, 4):
            ax = fig.add_subplot(gs[r, c])
            axes_small.append(ax)
    max_small_axes = len(axes_small)
    synthetic_names = [d for d in self.dataset_names if d != self.original_name]
    color_A = self.group_colors.get("GroupA", "#0072B2")
    color_B = self.group_colors.get("GroupB", "#D55E00")

    def _plot_single(ds_name: str, ax: plt.Axes, is_large: bool = False, save: bool = False) -> None:
        """
        Plot KM curve for a single dataset.

        Args:
            ds_name (str): Dataset name.
            ax (plt.Axes): Matplotlib axes.
            is_large (bool): If True, use larger annotation and legend.
            save (bool): If True, save individual plot.

        Returns:
            None
        """
        df_raw = self.datasets_dict[ds_name].copy()
        df = df_raw.copy()
        df[self.event_target] = pd.to_numeric(df.get(self.event_target, 0), errors="coerce").fillna(0).astype(int)
        df[self.time_target] = pd.to_numeric(df.get(self.time_target, np.nan), errors="coerce")
        df_A = df[df[self.phenotype_column] == self.value_A].copy()
        df_B = df[df[self.phenotype_column] == self.value_B].copy()
        n_A, n_B = len(df_A), len(df_B)
        ds_color = self.dataset_colors.get(ds_name, "#cccccc")
        ax.add_patch(
            plt.Rectangle(
                (0, dataset_strip_y),
                1,
                dataset_strip_height,
                transform=ax.transAxes,
                facecolor=ds_color,
                edgecolor="none",
                clip_on=False,
                zorder=10,
            )
        )
        kmf = KaplanMeierFitter()
        plotted_any = False
        for group_df, label, group_color in [
            (df_B, f"{self.value_B} (n={n_B})", color_B),
            (df_A, f"{self.value_A} (n={n_A})", color_A),
        ]:
            if len(group_df) == 0 or not group_df[self.time_target].notna().any():
                continue
            g = group_df.dropna(subset=[self.time_target, self.event_target])
            if len(g) == 0:
                continue
            kmf.fit(g[self.time_target], event_observed=g[self.event_target], label=label)
            kmf.plot_survival_function(
                ax=ax,
                ci_show=ci_show,
                show_censors=show_censors,
                color=group_color,
                linewidth=2.0 if is_large else 1.8,
                censor_styles={"ms": 4, "marker": "|", "mew": 1.2} if show_censors else None,
            )
            plotted_any = True
        # Annotations
        pvalue = self.summary_df.loc[self.summary_df["Dataset"] == ds_name, "pvalue"].values[0] \
            if self.summary_df is not None else np.nan
        cindex_text = self.summary_df.loc[self.summary_df["Dataset"] == ds_name, "C-index"].values[0] \
            if self.summary_df is not None else "NA"
        ptext = "p = NA" if (pvalue is None or np.isnan(pvalue)) else f"p = {pvalue:.4g}"
        ax.text(
            0.98, 0.96, ptext, transform=ax.transAxes,
            fontsize=10 * self.font_scale,
            horizontalalignment="right", verticalalignment="top", zorder=10,
        )
        ax.text(
            0.98, 0.88, f"C-index: {cindex_text}", transform=ax.transAxes,
            fontsize=10 * self.font_scale,
            horizontalalignment="right", verticalalignment="top", zorder=10,
        )
        title_t = f"{ds_name}" if title_prefix is None else f"{title_prefix} — {ds_name}"
        ax.set_title(title_t, fontweight="bold", pad=12)
        ax.set_xlabel(f"Time ({self.time_target})")
        ax.set_ylabel("Survival probability" if is_large else "")
        ax.set_ylim(0, 1.02)
        ax.grid(True, axis="y", linestyle="--", linewidth=0.6, alpha=0.5)
        if plotted_any:
            ax.legend(
                frameon=True, framealpha=1.0, loc="lower left",
                bbox_to_anchor=(0, 0.02), borderaxespad=0.2,
            )
        else:
            ax.text(
                0.5, 0.5, "No valid survival data",
                ha="center", va="center", transform=ax.transAxes,
                fontsize=11 * self.font_scale,
            )
            leg = ax.get_legend()
            if leg is not None:
                leg.remove()
        if save and save_dir:
            fig_indiv = plt.figure(figsize=(6, 5))
            ax_indiv = fig_indiv.add_subplot(111)
            _plot_single(ds_name, ax_indiv, is_large=True, save=False)
            fig_indiv.tight_layout()
            if self.is_pdf:
                outpath = os.path.join(save_dir, f"{ds_name}_KMcurve.pdf")
            else: 
                outpath = os.path.join(save_dir, f"{ds_name}_KMcurve.png")
            fig_indiv.savefig(outpath, dpi=300)
            plt.close(fig_indiv)

    _plot_single(self.original_name, ax_origin, is_large=True, save=bool(save_dir))
    for i, ds_name in enumerate(synthetic_names):
        if i >= max_small_axes:
            break
        _plot_single(ds_name, axes_small[i], is_large=False, save=bool(save_dir))
    for j in range(len(synthetic_names), max_small_axes):
        axes_small[j].axis("off")
    plt.tight_layout()
    return fig, self.summary_df
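
A hedged end-to-end sketch (file and folder names are hypothetical); compute_survival_metrics() is called first so the p-value and C-index annotations are available:

import os

evaluator.compute_survival_metrics()
os.makedirs("km_curves", exist_ok=True)          # hypothetical output folder for per-dataset curves
fig, summary = evaluator.plot_grid(figsize=(18, 6), show_censors=True, save_dir="km_curves")
fig.savefig("km_grid.png", dpi=300)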

predictive_model_comp

BayesianComparison

BaycompStyle dataclass

Store manuscript-style visualization settings for baycomp heatmaps.

Parameters:

    fontsize (int): Base font size used in matplotlib rcParams. Default: 11.
    nature_font (Dict[str, Sequence[str]]): Font family configuration. Default: None.
    plot_colors (Dict[str, str]): Common background/grid colors. Default: None.
    pbetter_cmap (LinearSegmentedColormap): Colormap for P(Better) heatmaps. Default: PBETTER_FOCUS_CMAP (via default_factory).
    cancer_colors (Dict[str, str]): Color strip mapping for each cancer panel. Default: None.

Raises:

    ValueError: If fontsize is not positive.

Source code in src/synomicsbench/metrics/narrow_utility/BayesianComparison.py
@dataclass(frozen=True)
class BaycompStyle:
    """
    Store manuscript-style visualization settings for baycomp heatmaps.

    Args:
        fontsize (int): Base font size used in matplotlib rcParams.
        nature_font (Dict[str, Sequence[str]]): Font family configuration.
        plot_colors (Dict[str, str]): Common background/grid colors.
        pbetter_cmap (LinearSegmentedColormap): Colormap for P(Better) heatmaps.
        cancer_colors (Dict[str, str]): Color strip mapping for each cancer panel.

    Raises:
        ValueError: If fontsize is not positive.
    """
    fontsize: int = 11
    nature_font: Dict[str, Sequence[str]] = None  # type: ignore[assignment]
    plot_colors: Dict[str, str] = None  # type: ignore[assignment]
    # pbetter_cmap: LinearSegmentedColormap = PBETTER_FOCUS_CMAP
    pbetter_cmap: LinearSegmentedColormap = field(
        default_factory=lambda: PBETTER_FOCUS_CMAP
    ) 
    cancer_colors: Dict[str, str] = None  # type: ignore[assignment]

    def __post_init__(self) -> None:
        if self.fontsize <= 0:
            raise ValueError("fontsize must be a positive integer.")
        object.__setattr__(self, "nature_font", dict(NATURE_FONT))
        object.__setattr__(self, "plot_colors", dict(STABILITY_PLOT_COLORS))
        object.__setattr__(self, "cancer_colors", dict(CANCER_COLORS))
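
A short sketch (assumed import path). Note that only fontsize and pbetter_cmap are effectively configurable through the constructor, because __post_init__ re-populates the font and color dictionaries from module-level constants:

from synomicsbench.metrics.narrow_utility.BayesianComparison import BaycompStyle  # assumed import path

style = BaycompStyle(fontsize=12)
print(style.fontsize, sorted(style.cancer_colors)[:3])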

BayesianComparison dataclass

Compute Bayesian pairwise comparisons using the external baycomp package and plot heatmaps.

Parameters:

    rope (float): ROPE threshold for practical equivalence. Default: 0.01.
    seed (int): Random seed for reproducibility in baycomp. Default: 0.
    style (BaycompStyle): Plot styling settings for consistent manuscript figures. Default: BaycompStyle().

Raises:

    ValueError: If rope is not positive.

Source code in src/synomicsbench/metrics/narrow_utility/BayesianComparison.py
@dataclass
class BayesianComparison:
    """
    Compute Bayesian pairwise comparisons using the external baycomp package and plot heatmaps.

    Args:
        rope (float): ROPE threshold for practical equivalence.
        seed (int): Random seed for reproducibility in baycomp.
        style (BaycompStyle): Plot styling settings for consistent manuscript figures.

    Raises:
        ValueError: If rope is not positive.
    """
    rope: float = 0.01
    seed: int = 0
    style: BaycompStyle = BaycompStyle()

    def __post_init__(self) -> None:
        if self.rope <= 0:
            raise ValueError("rope must be > 0.")

    @staticmethod
    def _validate_scores(scores: Sequence[float], method: str) -> np.ndarray:
        """
        Validate and coerce a score sequence to a finite numpy array.

        Args:
            scores (Sequence[float]): Metric values across repeated runs (seeds/folds).
            method (str): Method name for error messages.

        Returns:
            np.ndarray: Finite float numpy array of scores.

        Raises:
            ValueError: If fewer than 2 finite scores are present.
        """
        x = np.asarray(scores, dtype=float)
        x = x[np.isfinite(x)]
        if x.size < 2:
            raise ValueError(f"Method '{method}' must have at least 2 finite scores for baycomp.")
        return x

    @staticmethod
    def _resolve_methods_order(
        method_to_scores: Mapping[str, Sequence[float]],
        methods_order: Optional[Sequence[str]],
    ) -> List[str]:
        """
        Resolve and validate the method ordering.

        Args:
            method_to_scores (Mapping[str, Sequence[float]]): Mapping method -> scores.
            methods_order (Optional[Sequence[str]]): Optional explicit ordering.

        Returns:
            List[str]: Ordered list of methods.

        Raises:
            ValueError: If fewer than 2 methods are provided.
            ValueError: If methods_order contains methods not present in method_to_scores.
        """
        if not method_to_scores or len(method_to_scores) < 2:
            raise ValueError("method_to_scores must contain at least 2 methods.")

        if methods_order is None:
            return list(method_to_scores.keys())

        resolved = list(methods_order)
        missing = [m for m in resolved if m not in method_to_scores]
        if missing:
            raise ValueError(f"methods_order contains methods missing in method_to_scores: {missing}")
        return resolved

    def compare_methods(
        self,
        method_to_scores: Mapping[str, Sequence[float]],
        methods_order: Optional[Sequence[str]] = None,
    ) -> pd.DataFrame:
        """
        Compute ordered-pair Bayesian comparison probabilities for a set of methods.

        Args:
            method_to_scores (Mapping[str, Sequence[float]]): Mapping method -> scores across seeds/folds.
            methods_order (Optional[Sequence[str]]): Optional ordering for methods.

        Returns:
            pd.DataFrame: Table with columns:
                ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

        Raises:
            ValueError: If fewer than 2 methods are provided.
            ValueError: If any method has fewer than 2 finite scores.
        """
        methods = self._resolve_methods_order(method_to_scores, methods_order)
        scores_np = {m: self._validate_scores(method_to_scores[m], method=m) for m in methods}

        rows: List[dict] = []
        for i, m1 in enumerate(methods):
            for j, m2 in enumerate(methods):
                if i == j:
                    continue

                p_left, p_rope, p_right = baycomp.two_on_single(
                    scores_np[m1],
                    scores_np[m2],
                    rope=self.rope,
                    plot=False,
                )

                rows.append(
                    {
                        "Method 1": m1,
                        "Method 2": m2,
                        "Better Prob": float(p_left),
                        "Worse Prob": float(p_right),
                        "Equivalent Prob": float(p_rope),
                    }
                )

        return pd.DataFrame(rows)

    @staticmethod
    def comparison_to_matrix(
        comparison_df: pd.DataFrame,
        methods_order: Sequence[str],
        value_col: str = "Better Prob",
        nan_diagonal: bool = True,
    ) -> pd.DataFrame:
        """
        Convert a comparison table into a square matrix suitable for heatmap plotting.

        Args:
            comparison_df (pd.DataFrame): Output of compare_methods.
            methods_order (Sequence[str]): Method ordering for rows and columns.
            value_col (str): Which column to visualize.
            nan_diagonal (bool): If True, set diagonal values to NaN.

        Returns:
            pd.DataFrame: Square matrix with index=Method 1 and columns=Method 2.

        Raises:
            ValueError: If value_col is not present in comparison_df.
        """
        if value_col not in comparison_df.columns:
            raise ValueError(f"value_col must be a column in comparison_df, got '{value_col}'.")

        mat = comparison_df.pivot(index="Method 1", columns="Method 2", values=value_col)
        mat = mat.reindex(index=list(methods_order), columns=list(methods_order))

        if nan_diagonal:
            for m in methods_order:
                if m in mat.index and m in mat.columns:
                    mat.loc[m, m] = np.nan

        return mat

    def plot_pbetter_heatmap_grid(
        self,
        cancer_to_method_scores: Mapping[str, Mapping[str, Sequence[float]]],
        cancers_order: Sequence[str] = ("ccRCC", "Melanoma", "NSCLC"),
        methods_order: Optional[Sequence[str]] = None,
        value_col: str = "Better Prob",
        figsize: Tuple[float, float] = (18, 5),
        annot: bool = True,
        fmt: str = ".2f",
        missing_hatch: str = "///",
        show: bool = True,
    ) -> Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
        """
        Plot a 1xN grid of baycomp probability heatmaps (one per cancer).

        Args:
            cancer_to_method_scores (Mapping[str, Mapping[str, Sequence[float]]]): Cancer -> method/tool -> score list.
            cancers_order (Sequence[str]): Order of cancers in the grid.
            methods_order (Optional[Sequence[str]]): Global method ordering. If None, uses union across cancers.
            value_col (str): Which probability to visualize:
                "Better Prob", "Worse Prob", or "Equivalent Prob".
            figsize (Tuple[float, float]): Figure size.
            annot (bool): If True, annotate cells (NaN cells are blank).
            fmt (str): Annotation format.
            missing_hatch (str): Hatch pattern for missing (NaN) cells.
            show (bool): If True, calls plt.show().

        Returns:
            Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: Tuple containing:
                - fig: Matplotlib Figure object.
                - axes: Array of Axes objects.
                - matrices_by_cancer: Probability matrices for each cancer.
                - comparison_dfs_by_cancer: Comparison DataFrames for each cancer.

        Raises:
            ValueError: If value_col is invalid.
            ValueError: If any requested cancer is missing from cancer_to_method_scores.
            ValueError: If resolved methods_order is empty.
        """
        allowed = {"Better Prob", "Worse Prob", "Equivalent Prob"}
        if value_col not in allowed:
            raise ValueError(f"value_col must be one of {sorted(allowed)}.")

        cancers = list(cancers_order)
        for c in cancers:
            if c not in cancer_to_method_scores:
                raise ValueError(f"Cancer '{c}' missing from cancer_to_method_scores.")

        if methods_order is None:
            union_methods: List[str] = []
            for c in cancers:
                for m in cancer_to_method_scores[c].keys():
                    if m not in union_methods:
                        union_methods.append(m)
            methods_order_resolved = union_methods
        else:
            methods_order_resolved = list(methods_order)

        if len(methods_order_resolved) == 0:
            raise ValueError("methods_order resolved to an empty list.")

        _set_nature_rcparams(fontsize=self.style.fontsize)
        sns.set(style="white", rc={"axes.facecolor": self.style.plot_colors["heatmap_bg"]})

        fig, axes = plt.subplots(1, len(cancers), figsize=figsize)
        if len(cancers) == 1:
            axes = np.array([axes])

        norm = TwoSlopeNorm(vmin=0.0, vcenter=0.5, vmax=1.0)
        cmap = self.style.pbetter_cmap if value_col == "Better Prob" else plt.get_cmap("viridis")

        matrices_by_cancer: Dict[str, pd.DataFrame] = {}
        comparison_by_cancer: Dict[str, pd.DataFrame] = {}

        for ax, cancer in zip(axes, cancers):
            present_methods = set(cancer_to_method_scores[cancer].keys())

            comp_df = build_baycomp_comparison_df(
                method_to_scores=cancer_to_method_scores[cancer],
                methods_order=methods_order_resolved,
                rope=self.rope,
            )
            comparison_by_cancer[cancer] = comp_df

            mat = pd.DataFrame(index=methods_order_resolved, columns=methods_order_resolved, dtype=float)
            if not comp_df.empty:
                mat_present = comp_df.pivot(index="Method 1", columns="Method 2", values=value_col)
                mat.loc[mat_present.index, mat_present.columns] = mat_present

            for m in methods_order_resolved:
                if m in present_methods:
                    mat.loc[m, m] = 0.5

            matrices_by_cancer[cancer] = mat

            mask = mat.isna()
            sns.heatmap(
                mat,
                ax=ax,
                cmap=cmap,
                norm=norm if value_col == "Better Prob" else None,
                mask=mask,
                annot=annot,
                fmt=fmt,
                linewidths=0.5,
                linecolor="lightgray",
                cbar=(ax is axes[-1]),
                cbar_kws={"label": value_col} if (ax is axes[-1]) else None,
                square=True,
            )

            _overlay_hatched_missing_cells(
                ax=ax,
                mat=mat,
                hatch=missing_hatch,
                edgecolor=self.style.plot_colors.get("missing_edge", STABILITY_PLOT_COLORS["missing_edge"]),
                facecolor=self.style.plot_colors.get("missing_face", STABILITY_PLOT_COLORS["missing_face"]),
                linewidth=0.6,
            )

            for i, m in enumerate(methods_order_resolved):
                if m in present_methods:
                    rect = mpatches.Rectangle(
                        (i, i),
                        1.0,
                        1.0,
                        facecolor="#D3D3D3",
                        edgecolor="lightgray",
                        linewidth=0.5,
                        zorder=11,
                    )
                    ax.add_patch(rect)

            ax.set_title(cancer, fontsize=self.style.fontsize + 1, fontweight="bold", pad=10)

            cancer_color = self.style.cancer_colors.get(cancer, "#333333")
            ax.plot(
                [0.02, 0.98],
                [1.02, 1.02],
                transform=ax.transAxes,
                color=cancer_color,
                lw=4,
                clip_on=False,
            )

            ax.set_xlabel("Method 2", fontsize=self.style.fontsize)
            ax.set_ylabel("Method 1", fontsize=self.style.fontsize)
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
            ax.tick_params(axis="y", rotation=0)

        plt.tight_layout()
        if show:
            plt.show()

        return fig, axes, matrices_by_cancer, comparison_by_cancer

compare_methods(method_to_scores, methods_order=None)

Compute ordered-pair Bayesian comparison probabilities for a set of methods.

Parameters:

    method_to_scores (Mapping[str, Sequence[float]]): Mapping method -> scores across seeds/folds. Required.
    methods_order (Optional[Sequence[str]]): Optional ordering for methods. Default: None.

Returns:

    pd.DataFrame: Table with columns ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

Raises:

    ValueError: If fewer than 2 methods are provided.
    ValueError: If any method has fewer than 2 finite scores.

Source code in src/synomicsbench/metrics/narrow_utility/BayesianComparison.py
def compare_methods(
    self,
    method_to_scores: Mapping[str, Sequence[float]],
    methods_order: Optional[Sequence[str]] = None,
) -> pd.DataFrame:
    """
    Compute ordered-pair Bayesian comparison probabilities for a set of methods.

    Args:
        method_to_scores (Mapping[str, Sequence[float]]): Mapping method -> scores across seeds/folds.
        methods_order (Optional[Sequence[str]]): Optional ordering for methods.

    Returns:
        pd.DataFrame: Table with columns:
            ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

    Raises:
        ValueError: If fewer than 2 methods are provided.
        ValueError: If any method has fewer than 2 finite scores.
    """
    methods = self._resolve_methods_order(method_to_scores, methods_order)
    scores_np = {m: self._validate_scores(method_to_scores[m], method=m) for m in methods}

    rows: List[dict] = []
    for i, m1 in enumerate(methods):
        for j, m2 in enumerate(methods):
            if i == j:
                continue

            p_left, p_rope, p_right = baycomp.two_on_single(
                scores_np[m1],
                scores_np[m2],
                rope=self.rope,
                plot=False,
            )

            rows.append(
                {
                    "Method 1": m1,
                    "Method 2": m2,
                    "Better Prob": float(p_left),
                    "Worse Prob": float(p_right),
                    "Equivalent Prob": float(p_rope),
                }
            )

    return pd.DataFrame(rows)
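
A minimal sketch with made-up scores (assumed import path; the baycomp package must be installed):

from synomicsbench.metrics.narrow_utility.BayesianComparison import BayesianComparison  # assumed import path

comparator = BayesianComparison(rope=0.01, seed=0)
scores = {
    "ToolA": [0.71, 0.69, 0.73, 0.70, 0.72],
    "ToolB": [0.66, 0.68, 0.65, 0.67, 0.66],
}
pairwise = comparator.compare_methods(scores)
print(pairwise)  # one row per ordered pair with Better/Worse/Equivalent probabilities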

comparison_to_matrix(comparison_df, methods_order, value_col='Better Prob', nan_diagonal=True) staticmethod

Convert a comparison table into a square matrix suitable for heatmap plotting.

Parameters:

    comparison_df (pd.DataFrame): Output of compare_methods. Required.
    methods_order (Sequence[str]): Method ordering for rows and columns. Required.
    value_col (str): Which column to visualize. Default: 'Better Prob'.
    nan_diagonal (bool): If True, set diagonal values to NaN. Default: True.

Returns:

    pd.DataFrame: Square matrix with index=Method 1 and columns=Method 2.

Raises:

    ValueError: If value_col is not present in comparison_df.

Source code in src/synomicsbench/metrics/narrow_utility/BayesianComparison.py
@staticmethod
def comparison_to_matrix(
    comparison_df: pd.DataFrame,
    methods_order: Sequence[str],
    value_col: str = "Better Prob",
    nan_diagonal: bool = True,
) -> pd.DataFrame:
    """
    Convert a comparison table into a square matrix suitable for heatmap plotting.

    Args:
        comparison_df (pd.DataFrame): Output of compare_methods.
        methods_order (Sequence[str]): Method ordering for rows and columns.
        value_col (str): Which column to visualize.
        nan_diagonal (bool): If True, set diagonal values to NaN.

    Returns:
        pd.DataFrame: Square matrix with index=Method 1 and columns=Method 2.

    Raises:
        ValueError: If value_col is not present in comparison_df.
    """
    if value_col not in comparison_df.columns:
        raise ValueError(f"value_col must be a column in comparison_df, got '{value_col}'.")

    mat = comparison_df.pivot(index="Method 1", columns="Method 2", values=value_col)
    mat = mat.reindex(index=list(methods_order), columns=list(methods_order))

    if nan_diagonal:
        for m in methods_order:
            if m in mat.index and m in mat.columns:
                mat.loc[m, m] = np.nan

    return mat
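
Continuing the sketch above, the pairwise table can be pivoted into a heatmap-ready matrix:

mat = BayesianComparison.comparison_to_matrix(
    comparison_df=pairwise,
    methods_order=["ToolA", "ToolB"],
    value_col="Better Prob",
)
print(mat)  # square matrix, NaN on the diagonal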

plot_pbetter_heatmap_grid(cancer_to_method_scores, cancers_order=('ccRCC', 'Melanoma', 'NSCLC'), methods_order=None, value_col='Better Prob', figsize=(18, 5), annot=True, fmt='.2f', missing_hatch='///', show=True)

Plot a 1xN grid of baycomp probability heatmaps (one per cancer).

Parameters:

    cancer_to_method_scores (Mapping[str, Mapping[str, Sequence[float]]]): Cancer -> method/tool -> score list. Required.
    cancers_order (Sequence[str]): Order of cancers in the grid. Default: ('ccRCC', 'Melanoma', 'NSCLC').
    methods_order (Optional[Sequence[str]]): Global method ordering. If None, uses union across cancers. Default: None.
    value_col (str): Which probability to visualize: "Better Prob", "Worse Prob", or "Equivalent Prob". Default: 'Better Prob'.
    figsize (Tuple[float, float]): Figure size. Default: (18, 5).
    annot (bool): If True, annotate cells (NaN cells are blank). Default: True.
    fmt (str): Annotation format. Default: '.2f'.
    missing_hatch (str): Hatch pattern for missing (NaN) cells. Default: '///'.
    show (bool): If True, calls plt.show(). Default: True.

Returns:

    Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: Tuple containing:
        - fig: Matplotlib Figure object.
        - axes: Array of Axes objects.
        - matrices_by_cancer: Probability matrices for each cancer.
        - comparison_dfs_by_cancer: Comparison DataFrames for each cancer.

Raises:

    ValueError: If value_col is invalid.
    ValueError: If any requested cancer is missing from cancer_to_method_scores.
    ValueError: If resolved methods_order is empty.

Source code in src/synomicsbench/metrics/narrow_utility/BayesianComparison.py
def plot_pbetter_heatmap_grid(
    self,
    cancer_to_method_scores: Mapping[str, Mapping[str, Sequence[float]]],
    cancers_order: Sequence[str] = ("ccRCC", "Melanoma", "NSCLC"),
    methods_order: Optional[Sequence[str]] = None,
    value_col: str = "Better Prob",
    figsize: Tuple[float, float] = (18, 5),
    annot: bool = True,
    fmt: str = ".2f",
    missing_hatch: str = "///",
    show: bool = True,
) -> Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]:
    """
    Plot a 1xN grid of baycomp probability heatmaps (one per cancer).

    Args:
        cancer_to_method_scores (Mapping[str, Mapping[str, Sequence[float]]]): Cancer -> method/tool -> score list.
        cancers_order (Sequence[str]): Order of cancers in the grid.
        methods_order (Optional[Sequence[str]]): Global method ordering. If None, uses union across cancers.
        value_col (str): Which probability to visualize:
            "Better Prob", "Worse Prob", or "Equivalent Prob".
        figsize (Tuple[float, float]): Figure size.
        annot (bool): If True, annotate cells (NaN cells are blank).
        fmt (str): Annotation format.
        missing_hatch (str): Hatch pattern for missing (NaN) cells.
        show (bool): If True, calls plt.show().

    Returns:
        Tuple[plt.Figure, np.ndarray, Dict[str, pd.DataFrame], Dict[str, pd.DataFrame]]: Tuple containing:
            - fig: Matplotlib Figure object.
            - axes: Array of Axes objects.
            - matrices_by_cancer: Probability matrices for each cancer.
            - comparison_dfs_by_cancer: Comparison DataFrames for each cancer.

    Raises:
        ValueError: If value_col is invalid.
        ValueError: If any requested cancer is missing from cancer_to_method_scores.
        ValueError: If resolved methods_order is empty.
    """
    allowed = {"Better Prob", "Worse Prob", "Equivalent Prob"}
    if value_col not in allowed:
        raise ValueError(f"value_col must be one of {sorted(allowed)}.")

    cancers = list(cancers_order)
    for c in cancers:
        if c not in cancer_to_method_scores:
            raise ValueError(f"Cancer '{c}' missing from cancer_to_method_scores.")

    if methods_order is None:
        union_methods: List[str] = []
        for c in cancers:
            for m in cancer_to_method_scores[c].keys():
                if m not in union_methods:
                    union_methods.append(m)
        methods_order_resolved = union_methods
    else:
        methods_order_resolved = list(methods_order)

    if len(methods_order_resolved) == 0:
        raise ValueError("methods_order resolved to an empty list.")

    _set_nature_rcparams(fontsize=self.style.fontsize)
    sns.set(style="white", rc={"axes.facecolor": self.style.plot_colors["heatmap_bg"]})

    fig, axes = plt.subplots(1, len(cancers), figsize=figsize)
    if len(cancers) == 1:
        axes = np.array([axes])

    norm = TwoSlopeNorm(vmin=0.0, vcenter=0.5, vmax=1.0)
    cmap = self.style.pbetter_cmap if value_col == "Better Prob" else plt.get_cmap("viridis")

    matrices_by_cancer: Dict[str, pd.DataFrame] = {}
    comparison_by_cancer: Dict[str, pd.DataFrame] = {}

    for ax, cancer in zip(axes, cancers):
        present_methods = set(cancer_to_method_scores[cancer].keys())

        comp_df = build_baycomp_comparison_df(
            method_to_scores=cancer_to_method_scores[cancer],
            methods_order=methods_order_resolved,
            rope=self.rope,
        )
        comparison_by_cancer[cancer] = comp_df

        mat = pd.DataFrame(index=methods_order_resolved, columns=methods_order_resolved, dtype=float)
        if not comp_df.empty:
            mat_present = comp_df.pivot(index="Method 1", columns="Method 2", values=value_col)
            mat.loc[mat_present.index, mat_present.columns] = mat_present

        for m in methods_order_resolved:
            if m in present_methods:
                mat.loc[m, m] = 0.5

        matrices_by_cancer[cancer] = mat

        mask = mat.isna()
        sns.heatmap(
            mat,
            ax=ax,
            cmap=cmap,
            norm=norm if value_col == "Better Prob" else None,
            mask=mask,
            annot=annot,
            fmt=fmt,
            linewidths=0.5,
            linecolor="lightgray",
            cbar=(ax is axes[-1]),
            cbar_kws={"label": value_col} if (ax is axes[-1]) else None,
            square=True,
        )

        _overlay_hatched_missing_cells(
            ax=ax,
            mat=mat,
            hatch=missing_hatch,
            edgecolor=self.style.plot_colors.get("missing_edge", STABILITY_PLOT_COLORS["missing_edge"]),
            facecolor=self.style.plot_colors.get("missing_face", STABILITY_PLOT_COLORS["missing_face"]),
            linewidth=0.6,
        )

        for i, m in enumerate(methods_order_resolved):
            if m in present_methods:
                rect = mpatches.Rectangle(
                    (i, i),
                    1.0,
                    1.0,
                    facecolor="#D3D3D3",
                    edgecolor="lightgray",
                    linewidth=0.5,
                    zorder=11,
                )
                ax.add_patch(rect)

        ax.set_title(cancer, fontsize=self.style.fontsize + 1, fontweight="bold", pad=10)

        cancer_color = self.style.cancer_colors.get(cancer, "#333333")
        ax.plot(
            [0.02, 0.98],
            [1.02, 1.02],
            transform=ax.transAxes,
            color=cancer_color,
            lw=4,
            clip_on=False,
        )

        ax.set_xlabel("Method 2", fontsize=self.style.fontsize)
        ax.set_ylabel("Method 1", fontsize=self.style.fontsize)
        plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
        ax.tick_params(axis="y", rotation=0)

    plt.tight_layout()
    if show:
        plt.show()

    return fig, axes, matrices_by_cancer, comparison_by_cancer
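
An illustrative call with made-up per-cancer scores, reusing the hypothetical comparator from the compare_methods sketch above:

cancer_scores = {
    "ccRCC": {"ToolA": [0.71, 0.69, 0.73], "ToolB": [0.66, 0.68, 0.65]},
    "Melanoma": {"ToolA": [0.62, 0.64, 0.63], "ToolB": [0.61, 0.60, 0.62]},
    "NSCLC": {"ToolA": [0.58, 0.59, 0.60], "ToolB": [0.57, 0.58, 0.56]},
}
fig, axes, mats, comps = comparator.plot_pbetter_heatmap_grid(
    cancer_to_method_scores=cancer_scores,
    show=False,             # skip plt.show() in non-interactive runs
)
fig.savefig("pbetter_grid.png", dpi=300)  # hypothetical output file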

build_baycomp_comparison_df(method_to_scores, methods_order=None, rope=0.01)

Compute pairwise Bayesian comparison probabilities (better/worse/equivalent) using baycomp.two_on_single for all ordered pairs of methods.

Parameters:

    method_to_scores (Mapping[str, Sequence[float]]): Mapping tool/method -> list/array of scores. Required.
    methods_order (Optional[Sequence[str]]): Optional method ordering. If None, uses dict keys. Default: None.
    rope (float): ROPE threshold for practical equivalence. Default: 0.01.

Returns:

    pd.DataFrame: Comparison table with columns ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

Raises:

    ValueError: If rope is not positive.
    ValueError: If fewer than 2 methods are provided.
    ValueError: If any method has fewer than 2 finite scores.

Source code in src/synomicsbench/metrics/narrow_utility/BayesianComparison.py
def build_baycomp_comparison_df(
    method_to_scores: Mapping[str, Sequence[float]],
    methods_order: Optional[Sequence[str]] = None,
    rope: float = 0.01,
) -> pd.DataFrame:
    """
    Compute pairwise Bayesian comparison probabilities (better/worse/equivalent)
    using baycomp.two_on_single for all ordered pairs of methods.

    Args:
        method_to_scores (Mapping[str, Sequence[float]]): Mapping tool/method -> list/array of scores.
        methods_order (Optional[Sequence[str]]): Optional method ordering. If None, uses dict keys.
        rope (float): ROPE threshold for practical equivalence.

    Returns:
        pd.DataFrame: Comparison table with columns:
            ["Method 1", "Method 2", "Better Prob", "Worse Prob", "Equivalent Prob"].

    Raises:
        ValueError: If rope is not positive.
        ValueError: If fewer than 2 methods are provided.
        ValueError: If any method has fewer than 2 finite scores.
    """
    if rope <= 0:
        raise ValueError("rope must be > 0.")

    if not method_to_scores or len(method_to_scores) < 2:
        raise ValueError("method_to_scores must contain at least 2 methods.")

    methods = list(method_to_scores.keys()) if methods_order is None else list(methods_order)

    scores_np: Dict[str, np.ndarray] = {}
    for m in methods:
        if m not in method_to_scores:
            continue
        x = np.asarray(method_to_scores[m], dtype=float)
        x = x[np.isfinite(x)]
        if x.size < 2:
            raise ValueError(f"Method '{m}' must have at least 2 finite scores for baycomp.")
        scores_np[m] = x

    present_methods = [m for m in methods if m in scores_np]
    if len(present_methods) < 2:
        raise ValueError("Fewer than 2 methods remain after filtering invalid scores / missing methods.")

    rows: List[dict] = []
    for m1 in present_methods:
        for m2 in present_methods:
            if m1 == m2:
                continue

            p_left, p_rope, p_right = baycomp.two_on_single(
                scores_np[m1],
                scores_np[m2],
                rope=rope,
                plot=False,
            )

            rows.append(
                {
                    "Method 1": m1,
                    "Method 2": m2,
                    "Better Prob": float(p_left),
                    "Worse Prob": float(p_right),
                    "Equivalent Prob": float(p_rope),
                }
            )

    return pd.DataFrame(rows)
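
A standalone sketch of the helper (assumed import path; the scores are made up):

from synomicsbench.metrics.narrow_utility.BayesianComparison import build_baycomp_comparison_df  # assumed import path

df = build_baycomp_comparison_df(
    method_to_scores={"ToolA": [0.71, 0.69, 0.73], "ToolB": [0.66, 0.68, 0.65]},
    rope=0.01,
)
print(df[["Method 1", "Method 2", "Better Prob", "Equivalent Prob"]])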

Metrics: Privacy

The metrics.privacy module provides tools to assess the likelihood of privacy attacks using synthetic data.

singling_out

Singling-out risk evaluation for synthetic data.

Provides helper functions that wrap the Anonymeter SinglingOutEvaluator to run univariate and multivariate singling-out attacks across multiple feature-sampling proportions.

eval_singling_out_multivariate(ori, syns, n_cols_list=(2, 3, 5, 7, 10, 20, 50), n_attacks=10000, max_attempts=1000000, seed=42)

Run multivariate singling-out attacks at varying attribute-combination sizes.

Parameters:

    ori (pd.DataFrame): Original dataset. Required.
    syns (Dict[str, pd.DataFrame]): Mapping from synthetic dataset name to its DataFrame. Required.
    n_cols_list (Sequence[int]): Number of columns the attacker uses per predicate. Default: (2, 3, 5, 7, 10, 20, 50).
    n_attacks (int): Number of attack predicates per evaluation. Default: 10000.
    max_attempts (int): Maximum predicate generation attempts. Default: 1000000.
    seed (int): Random seed. Default: 42.

Returns:

    Dict[str, List[SinglingOutEvaluator]]: Mapping from synthetic dataset name to a list of evaluated SinglingOutEvaluator instances (one per n_cols value).

Source code in src/synomicsbench/metrics/privacy/singling_out.py
def eval_singling_out_multivariate(
    ori: pd.DataFrame,
    syns: Dict[str, pd.DataFrame],
    n_cols_list: Sequence[int] = (2, 3, 5, 7, 10, 20, 50),
    n_attacks: int = 10_000,
    max_attempts: int = 1_000_000,
    seed: int = 42,
) -> Dict[str, List[SinglingOutEvaluator]]:
    """Run multivariate singling-out attacks at varying attribute-combination sizes.

    Args:
        ori: Original dataset.
        syns: Mapping from synthetic dataset name to its DataFrame.
        n_cols_list: Number of columns the attacker uses per predicate.
        n_attacks: Number of attack predicates per evaluation.
        max_attempts: Maximum predicate generation attempts.
        seed: Random seed.

    Returns:
        Mapping from synthetic dataset name to a list of evaluated
        ``SinglingOutEvaluator`` instances (one per ``n_cols`` value).
    """
    results: Dict[str, List[SinglingOutEvaluator]] = {}

    for syn_name, syn_df in syns.items():
        risks: List[SinglingOutEvaluator] = []
        for n_col in n_cols_list:
            evaluator = SinglingOutEvaluator(
                ori=ori,
                syn=syn_df,
                n_cols=n_col,
                n_attacks=n_attacks,
                max_attempts=max_attempts,
            )
            evaluator.evaluate(mode="multivariate")
            risks.append(evaluator)
        results[syn_name] = risks

    return results
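
A hedged usage sketch (assumed import path; real_df and synth_df are hypothetical, pre-processed DataFrames, and the risk accessor at the end follows Anonymeter's evaluator API):

from synomicsbench.metrics.privacy.singling_out import eval_singling_out_multivariate  # assumed import path

results = eval_singling_out_multivariate(
    ori=real_df,                      # hypothetical original DataFrame
    syns={"SynthA": synth_df},        # hypothetical synthetic DataFrame
    n_cols_list=(2, 3, 5),
    n_attacks=500,                    # reduced for a quick run
)
for name, evaluators in results.items():
    for n_cols, ev in zip((2, 3, 5), evaluators):
        print(name, n_cols, ev.risk())  # Anonymeter risk estimate per n_cols setting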

eval_singling_out_univariate(ori, syns, n_attacks=10000, max_attempts=1000000, proportions=(0.25, 0.5, 0.75, 1.0), seed=42)

Run univariate singling-out attacks at varying feature proportions.

For each synthetic dataset and each proportion p, a random subset of p × n_features columns is selected and the SinglingOutEvaluator is run in univariate mode.

Parameters:

    ori (pd.DataFrame): Original (real) dataset after post-processing. Required.
    syns (Dict[str, pd.DataFrame]): Mapping from synthetic dataset name to its DataFrame. Required.
    n_attacks (int): Number of attack predicates to generate per evaluation. Default: 10000.
    max_attempts (int): Maximum predicate generation attempts. Default: 1000000.
    proportions (Sequence[float]): Fractions of columns to sample (e.g. 25 %, 50 %, …). Default: (0.25, 0.5, 0.75, 1.0).
    seed (int): Random seed for column sampling and the evaluator. Default: 42.

Returns:

    Dict[str, List[SinglingOutEvaluator]]: Mapping from synthetic dataset name to a list of evaluated SinglingOutEvaluator instances (one per proportion).

Source code in src/synomicsbench/metrics/privacy/singling_out.py
def eval_singling_out_univariate(
    ori: pd.DataFrame,
    syns: Dict[str, pd.DataFrame],
    n_attacks: int = 10_000,
    max_attempts: int = 1_000_000,
    proportions: Sequence[float] = (0.25, 0.50, 0.75, 1.0),
    seed: int = 42,
) -> Dict[str, List[SinglingOutEvaluator]]:
    """Run univariate singling-out attacks at varying feature proportions.

    For each synthetic dataset and each proportion *p*, a random subset of
    *p × n_features* columns is selected and the ``SinglingOutEvaluator``
    is run in ``univariate`` mode.

    Args:
        ori: Original (real) dataset after post-processing.
        syns: Mapping from synthetic dataset name to its DataFrame.
        n_attacks: Number of attack predicates to generate per evaluation.
        max_attempts: Maximum predicate generation attempts.
        proportions: Fractions of columns to sample (e.g. 25 %, 50 %, …).
        seed: Random seed for column sampling and the evaluator.

    Returns:
        Mapping from synthetic dataset name to a list of evaluated
        ``SinglingOutEvaluator`` instances (one per proportion).
    """
    results: Dict[str, List[SinglingOutEvaluator]] = {}

    for syn_name, syn_df in syns.items():
        random.seed(seed)
        cols = ori.columns.tolist()
        risks: List[SinglingOutEvaluator] = []

        for p in proportions:
            n = int(len(cols) * p)
            sampled_cols = random.sample(cols, n)
            evaluator = SinglingOutEvaluator(
                ori=ori[sampled_cols],
                syn=syn_df[sampled_cols],
                n_attacks=n_attacks,
                max_attempts=max_attempts,
            )
            evaluator.evaluate(mode="univariate")
            risks.append(evaluator)

        results[syn_name] = risks

    return results
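
For the univariate variant, one evaluator is returned per column proportion. A sketch of collecting the per-proportion risks into a small summary table; the import path, file names, and the `results().risk().value` accessor follow the conventions above and are illustrative only.

```python
import pandas as pd

from synomicsbench.metrics.privacy.singling_out import eval_singling_out_univariate

ori = pd.read_csv("real_data.csv")
syns = {"ctgan": pd.read_csv("ctgan_samples.csv")}

proportions = (0.25, 0.50, 0.75, 1.0)
results = eval_singling_out_univariate(ori=ori, syns=syns, proportions=proportions)

rows = []
for name, evaluators in results.items():
    for p, ev in zip(proportions, evaluators):
        # One evaluated SinglingOutEvaluator per sampled column proportion.
        rows.append({"dataset": name, "proportion": p, "risk": ev.results().risk().value})

summary = pd.DataFrame(rows)
print(summary)
```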

linkability

Linkability risk evaluation for synthetic data.

Evaluates whether an attacker can use synthetic data to link molecular (gene expression) profiles to clinical attributes of the same individual.

eval_linkability_genes_clinical(ori, syns, clinical_cols, transcriptomic_cols=None, n_neighbors=1, proportions=(0.25, 0.5, 0.75, 1.0), seed=42)

Evaluate linkability risk between gene expression and clinical features.

The attack attempts to link two disjoint attribute sets of the same patient — clinical variables and a randomly sampled subset of gene expression features — using the synthetic dataset as a bridge.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `ori` | `DataFrame` | Original dataset. | *required* |
| `syns` | `Dict[str, DataFrame]` | Mapping from synthetic dataset name to its DataFrame. | *required* |
| `clinical_cols` | `Union[List[int], List[str]]` | List of column names or integer indices for clinical features. If indices are provided, they refer to `ori` column positions. | *required* |
| `transcriptomic_cols` | `Union[List[int], List[str]]` | List of column names or integer indices for transcriptomic (gene) features. If `None`, uses all remaining columns not in `clinical_cols`. | `None` |
| `n_neighbors` | `int` | Number of nearest neighbors for the linkability attack. | `1` |
| `proportions` | `Sequence[float]` | Fractions of gene columns to sample for the attack. | `(0.25, 0.5, 0.75, 1.0)` |
| `seed` | `int` | Random seed for gene column sampling. | `42` |

Returns:

| Type | Description |
|------|-------------|
| `Dict[str, List[LinkabilityEvaluator]]` | Mapping from dataset name to a list of evaluated `LinkabilityEvaluator` instances (one per proportion). |

Source code in src/synomicsbench/metrics/privacy/linkability.py
def eval_linkability_genes_clinical(
    ori: pd.DataFrame,
    syns: Dict[str, pd.DataFrame],
    clinical_cols: Union[List[int], List[str]],
    transcriptomic_cols: Union[List[int], List[str]] = None,
    n_neighbors: int = 1,
    proportions: Sequence[float] = (0.25, 0.50, 0.75, 1.0),
    seed: int = 42,
) -> Dict[str, List[LinkabilityEvaluator]]:
    """Evaluate linkability risk between gene expression and clinical features.

    The attack attempts to link two disjoint attribute sets of the same
    patient — clinical variables and a randomly sampled subset of gene
    expression features — using the synthetic dataset as a bridge.

    Args:
        ori: Original dataset.
        syns: Mapping from synthetic dataset name to its DataFrame.
        clinical_cols: List of column names or integer indices for clinical
            features. If indices are provided, they refer to ``ori`` column
            positions.
        transcriptomic_cols: List of column names or integer indices for
            transcriptomic (gene) features. If ``None``, uses all remaining
            columns not in ``clinical_cols``.
        n_neighbors: Number of nearest neighbors for the linkability attack.
        proportions: Fractions of gene columns to sample for the attack.
        seed: Random seed for gene column sampling.

    Returns:
        Mapping from dataset name to a list of evaluated
        ``LinkabilityEvaluator`` instances (one per proportion).
    """
    results: Dict[str, List[LinkabilityEvaluator]] = {}

    all_columns = ori.columns.tolist()

    # Resolve clinical columns (convert indices to names if needed)
    if clinical_cols and isinstance(clinical_cols[0], int):
        clinical_col_names = [all_columns[i] for i in clinical_cols]
    else:
        clinical_col_names = list(clinical_cols)

    # Resolve transcriptomic columns
    if transcriptomic_cols is not None:
        if isinstance(transcriptomic_cols[0], int):
            genes_cols = [all_columns[i] for i in transcriptomic_cols]
        else:
            genes_cols = list(transcriptomic_cols)
    else:
        genes_cols = [col for col in all_columns if col not in clinical_col_names]

    for syn_name, syn_df in syns.items():
        random.seed(seed)
        risks: List[LinkabilityEvaluator] = []

        for p in proportions:
            n = int(len(genes_cols) * p)
            sampled_genes = random.sample(genes_cols, n)
            aux_cols = (clinical_col_names, sampled_genes)

            evaluator = LinkabilityEvaluator(
                ori=ori,
                syn=syn_df,
                n_attacks=ori.shape[0],
                aux_cols=aux_cols,
                n_neighbors=n_neighbors,
            )
            evaluator.evaluate(n_jobs=-2)
            risks.append(evaluator)

        results[syn_name] = risks

    return results
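
A hedged usage sketch for the linkability attack: the clinical column names and file paths are placeholders, and the import path is inferred from the source location above. Internally, each proportion pairs the clinical columns with a random gene subset as the two auxiliary attribute sets.

```python
import pandas as pd

from synomicsbench.metrics.privacy.linkability import eval_linkability_genes_clinical

ori = pd.read_csv("real_data.csv")
syns = {"ctgan": pd.read_csv("ctgan_samples.csv")}

# Placeholder clinical columns; every remaining column is treated as a gene.
clinical_cols = ["age", "sex", "stage"]

results = eval_linkability_genes_clinical(
    ori=ori,
    syns=syns,
    clinical_cols=clinical_cols,
    n_neighbors=1,
    proportions=(0.25, 0.5),
)

for name, evaluators in results.items():
    for ev in evaluators:
        # Each LinkabilityEvaluator exposes its risk through Anonymeter's API.
        print(name, ev.results().risk())
```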

inference

Attribute inference risk evaluation for synthetic data. Evaluates whether an adversary can infer sensitive clinical attributes (secrets) from auxiliary gene expression features using the synthetic dataset.

eval_inference_genes_clinical(ori, syns, clinical_cols, transcriptomic_cols=None, save_path=None)

Evaluate attribute inference risk for each clinical variable.

For every synthetic dataset and each clinical feature (treated as a secret), the InferenceEvaluator from Anonymeter uses all gene expression columns as auxiliary information to predict the secret attribute.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `ori` | `DataFrame` | Original dataset. | *required* |
| `syns` | `Dict[str, DataFrame]` | Mapping from synthetic dataset name to its DataFrame. | *required* |
| `clinical_cols` | `Union[List[int], List[str]]` | List of column names or integer indices for clinical features (secrets). If indices are provided, they refer to `ori` column positions. | *required* |
| `transcriptomic_cols` | `Union[List[int], List[str]]` | List of column names or integer indices for transcriptomic (gene) features (auxiliary). If `None`, uses all remaining columns not in `clinical_cols`. | `None` |
| `save_path` | `Optional[str]` | Optional path to incrementally save results as pickle. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `Dict[str, List[Tuple[str, object]]]` | Mapping from dataset name to a list of `(secret_name, results)` tuples, where `results` is an `EvaluationResults` object or an error dict if the evaluation failed. |

Source code in src/synomicsbench/metrics/privacy/inference.py
def eval_inference_genes_clinical(
    ori: pd.DataFrame,
    syns: Dict[str, pd.DataFrame],
    clinical_cols: Union[List[int], List[str]],
    transcriptomic_cols: Union[List[int], List[str]] = None,
    save_path: Optional[str] = None,
) -> Dict[str, List[Tuple[str, object]]]:
    """Evaluate attribute inference risk for each clinical variable.

    For every synthetic dataset and each clinical feature (treated as a
    *secret*), the ``InferenceEvaluator`` from Anonymeter uses all gene
    expression columns as auxiliary information to predict the secret
    attribute.

    Args:
        ori: Original dataset.
        syns: Mapping from synthetic dataset name to its DataFrame.
        clinical_cols: List of column names or integer indices for clinical
            features (secrets). If indices are provided, they refer to
            ``ori`` column positions.
        transcriptomic_cols: List of column names or integer indices for
            transcriptomic (gene) features (auxiliary). If ``None``, uses
            all remaining columns not in ``clinical_cols``.
        save_path: Optional path to incrementally save results as pickle.

    Returns:
        Mapping from dataset name to a list of ``(secret_name, results)``
        tuples, where ``results`` is an ``EvaluationResults`` object or
        an error dict if the evaluation failed.
    """
    all_columns = ori.columns.tolist()

    # Resolve clinical columns (convert indices to names if needed)
    if clinical_cols and isinstance(clinical_cols[0], int):
        clinical_col_names = [all_columns[i] for i in clinical_cols]
    else:
        clinical_col_names = list(clinical_cols)

    # Resolve transcriptomic columns
    if transcriptomic_cols is not None:
        if isinstance(transcriptomic_cols[0], int):
            genes_cols = [all_columns[i] for i in transcriptomic_cols]
        else:
            genes_cols = list(transcriptomic_cols)
    else:
        genes_cols = [col for col in all_columns if col not in clinical_col_names]

    results_all: Dict[str, List[Tuple[str, object]]] = {}

    for tool, syn_data in syns.items():
        results: List[Tuple[str, object]] = []
        logger.info("Starting inference evaluation for %s", tool)

        for secret in clinical_col_names:
            try:
                evaluator = InferenceEvaluator(
                    ori=ori,
                    syn=syn_data,
                    aux_cols=genes_cols,
                    secret=secret,
                    n_attacks=ori.shape[0],
                )
                evaluator.evaluate(n_jobs=-2)
                results.append((secret, evaluator.results()))
                logger.info(
                    "Tool %s: secret %s risk=%s",
                    tool,
                    secret,
                    evaluator.results().risk().value,
                )
            except Exception as ex:
                logger.exception(
                    "Error evaluating tool=%s secret=%s: %s", tool, secret, ex
                )
                results.append((secret, {"error": str(ex)}))

        results_all[tool] = results

    return results_all
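
A sketch of consuming the per-secret output of `eval_inference_genes_clinical`. Failed evaluations appear as an error dict instead of an `EvaluationResults` object, so the loop checks for that case; paths, column names, and the import path are illustrative assumptions.

```python
import pandas as pd

from synomicsbench.metrics.privacy.inference import eval_inference_genes_clinical

ori = pd.read_csv("real_data.csv")
syns = {"ctgan": pd.read_csv("ctgan_samples.csv")}

results_all = eval_inference_genes_clinical(
    ori=ori,
    syns=syns,
    clinical_cols=["age", "sex", "stage"],  # secrets to infer
)

for tool, per_secret in results_all.items():
    for secret, res in per_secret:
        if isinstance(res, dict) and "error" in res:
            print(f"{tool} / {secret}: failed ({res['error']})")
        else:
            # res is an Anonymeter EvaluationResults object.
            print(f"{tool} / {secret}: risk={res.risk().value:.3f}")
```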

Utilities

Utility modules provide monitoring capabilities, evaluation utilities, and correlation analysis tools used throughout the framework.

monitoring

monitor_resources(func)

Decorator to monitor CPU, RAM, and GPU (NVIDIA via nvidia-ml-py) resources during function execution.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `func` | `callable` | The function to be monitored. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `callable` | The wrapped function with resource monitoring. |

Raises:

| Type | Description |
|------|-------------|
| `Exception` | Re-raises any exception from the wrapped function after reporting. |

Source code in src/synomicsbench/utils/monitoring.py
def monitor_resources(func):
    """Decorator to monitor CPU, RAM, and GPU (NVIDIA via nvidia-ml-py) resources during function execution.

    Args:
        func (callable): The function to be monitored.

    Returns:
        callable: The wrapped function with resource monitoring.

    Raises:
        Exception: Re-raises any exception from the wrapped function after reporting.
    """
    import time

    def _decode_bytes(s):
        # nvidia-ml-py returns str, so this is a no-op but kept for compatibility
        return s

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("--- System & Process Info ---")
        current_time = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
        print(f"Current Date and Time (UTC): {current_time}")

        # Get user info
        try:
            print(f"Current User's Login: {getpass.getuser()}")
        except Exception:
            print("Could not determine user login")

        cpu_model = platform.processor() or getattr(platform.uname(), "processor", "") or platform.machine()
        print(f"CPU Model: {cpu_model}")
        print(f"Physical Cores: {psutil.cpu_count(logical=False)}")
        print(f"Logical Processors: {psutil.cpu_count(logical=True)}")

        process = psutil.Process()
        MB = 1024 * 1024
        memory_before = process.memory_info().rss / MB
        print(f"Process RAM before execution: {memory_before:.2f} MB")

        print("\n--- Function Execution ---")

        psutil.cpu_percent(percpu=True)
        start_time = time.perf_counter()

        exc = None
        result = None
        try:
            result = func(*args, **kwargs)
            return result
        except Exception as e:
            exc = e
        finally:
            execution_time = time.perf_counter() - start_time

            print("\n--- Resource Usage Summary ---")
            print(f"Execution time: {execution_time:.6f} seconds")

            memory_after = process.memory_info().rss / MB
            mem_delta = memory_after - memory_before
            print(f"Process RAM after execution: {memory_after:.2f} MB")
            print(f"Process RAM used by function: {mem_delta:.2f} MB")

            cpu_per_core_during = psutil.cpu_percent(percpu=True)
            utilized_cores = sum(1 for v in cpu_per_core_during if v > 10.0)
            print(f"Average per-core CPU during execution: {[round(v, 1) for v in cpu_per_core_during]}")
            print(f"CPU Cores utilized (>10%): {utilized_cores} of {psutil.cpu_count(logical=True)}")
            print("\n" + "=" * 40 + "\n")

            if exc is not None:
                raise exc

    return wrapper
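
Because `monitor_resources` is a plain decorator, it can wrap any callable. A minimal sketch (the import path is inferred from the source location above, and the workload is a toy example):

```python
import numpy as np

# Assumed import path for the utilities module documented here.
from synomicsbench.utils.monitoring import monitor_resources


@monitor_resources
def expensive_step(n: int) -> float:
    """Toy workload: sum of a large random matrix product."""
    a = np.random.rand(n, n)
    return float((a @ a).sum())


# Prints the system/process info, runs the function, then prints the resource summary.
result = expensive_step(1_000)
print(f"result = {result:.2f}")
```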

set_logger(logger_name, output_path, log_file_name='activity.log')

Configures a logger with both file and console handlers and returns it.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `logger_name` | `str` | The name for the logger (e.g., `__name__` or `'myscript'`). | *required* |
| `output_path` | `str` | The directory where the log file should be created. | *required* |
| `log_file_name` | `str` | The name of the log file. | `'activity.log'` |

Returns:

| Type | Description |
|------|-------------|
| `logging.Logger` | The configured logger instance. |

Source code in src/synomicsbench/utils/monitoring.py
def set_logger(logger_name: str, output_path: str, log_file_name: str = "activity.log"):
    """
    Configures a logger with both file and console handlers and returns it.

    Args:
        logger_name (str): The name for the logger (e.g., __name__ or 'myscript').
        output_path (str): The directory where the log file should be created.
        log_file_name (str): The name of the log file.

    Returns:
        logging.Logger: The configured logger instance.
    """
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)

    # Prevent re-adding handlers
    if logger.handlers:
        logger.handlers.clear()

    # Define a consistent formatter
    formatter = logging.Formatter(
        "%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
    )

    # File Handler Setup (same as before)
    log_path = os.path.join(output_path, log_file_name)
    try:
        os.makedirs(output_path, exist_ok=True)
        file_handler = logging.FileHandler(log_path)
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    except OSError as e:
        print(f"WARNING: Could not create log file at {log_path}. Error: {e}", file=sys.stderr)

    # Console Handler Setup (same as before)
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    logger.debug("Standalone logger initialized successfully.")

    return logger
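
A minimal sketch of `set_logger`: it writes to both the console and `<output_path>/<log_file_name>` and returns the configured logger. The logger name, directory, and file name below are placeholders, and the import path is inferred from the source location above.

```python
from synomicsbench.utils.monitoring import set_logger

# Creates ./logs/benchmark.log (the directory is created if missing) and
# also echoes every record to stdout.
logger = set_logger("benchmark", output_path="./logs", log_file_name="benchmark.log")

logger.info("Benchmark started")
logger.debug("Detailed diagnostic message")
```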

correlations

MixedCorrelation

Compute mixed-type correlation matrices combining BLAS/vectorized numeric-numeric and Numba-accelerated categorical interactions.

Missing handling:

- Numerical: pairwise deletion (drop rows with NaN in either column) for Pearson/Spearman.
- Categorical (including discretized numeric for Cramér's V): missing is its own category.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `categorical_indices` | array-like of int | Column indices treated as categorical. | *required* |
| `numerical_indices` | array-like of int | Column indices treated as numerical. | *required* |
| `method` | `{'pearson', 'spearman'}` | Correlation for numeric-numeric pairs. | `'pearson'` |
| `n_bins` | `int` | Number of bins for discretizing numerical columns when paired with categorical. | `10` |
| `engine` | `{'auto', 'numba', 'blas'}` | `'blas'`: use block GEMM path for numeric-numeric (fast for many features). `'numba'`: compute everything in Numba loops. `'auto'`: choose `'blas'` when many numeric features, else `'numba'`. | `'auto'` |
| `block_cols` | `int` | Block width for BLAS engine. | `512` |

Source code in src/synomicsbench/utils/correlations.py
class MixedCorrelation:
    """
    Compute mixed-type correlation matrices combining BLAS/vectorized numeric-numeric and
    Numba-accelerated categorical interactions.

    Missing handling:
      - Numerical: pairwise deletion (drop rows with NaN in either column) for Pearson/Spearman.
      - Categorical (including discretized numeric for Cramér's V): missing is its own category.

    Parameters
    ----------
    categorical_indices : array-like of int
        Column indices treated as categorical.
    numerical_indices : array-like of int
        Column indices treated as numerical.
    method : {'pearson','spearman'}, default 'pearson'
        Correlation for numeric-numeric pairs.
    n_bins : int, default 10
        Number of bins for discretizing numerical columns when paired with categorical.
    engine : {'auto','numba','blas'}, default 'auto'
        - 'blas': use block GEMM path for numeric-numeric (fast for many features).
        - 'numba': compute everything in Numba loops.
        - 'auto': choose 'blas' when many numeric features, else 'numba'.
    block_cols : int, default 512
        Block width for BLAS engine.
    """

    def __init__(
        self,
        categorical_indices,
        numerical_indices,
        method: str = "pearson",
        n_bins: int = 10,
        engine: str = "auto",
        block_cols: int = 512
    ):
        self.categorical_indices = np.array(categorical_indices, dtype=np.int64)
        self.numerical_indices = np.array(numerical_indices, dtype=np.int64)
        if self.categorical_indices.size > 0:
            self.categorical_indices = np.sort(self.categorical_indices)
        if self.numerical_indices.size > 0:
            self.numerical_indices = np.sort(self.numerical_indices)

        self.method = method.lower()
        self.n_bins = int(n_bins)
        self.engine = engine.lower()
        self.block_cols = int(block_cols)

        if self.method not in ("pearson", "spearman"):
            raise ValueError("method must be one of {'pearson', 'spearman'}")
        if self.engine not in ("auto", "numba", "blas"):
            raise ValueError("engine must be one of {'auto','numba','blas'}")

    def _validate_indices(self, n_features: int):
        if np.any(self.categorical_indices < 0) or np.any(self.categorical_indices >= n_features):
            raise ValueError("categorical index out of bounds.")
        if np.any(self.numerical_indices < 0) or np.any(self.numerical_indices >= n_features):
            raise ValueError("numerical index out of bounds.")
        if self.categorical_indices.size and self.numerical_indices.size:
            if np.intersect1d(self.categorical_indices, self.numerical_indices).size > 0:
                raise ValueError("Overlap between categorical and numerical indices.")

    def compute(self, data):
        """
        Compute the mixed-type correlation matrix on the provided data.

        Returns
        -------
        corr : np.ndarray of shape (n_features, n_features)
        """
        # Convert to contiguous NumPy array
        if hasattr(data, "values"):
            X = np.ascontiguousarray(data.values, dtype=np.float64)
        else:
            X = np.ascontiguousarray(data, dtype=np.float64)
        if X.ndim != 2:
            raise ValueError("data must be 2D")

        n_samples, n_features = X.shape
        self._validate_indices(n_features)

        # Choose engine
        engine = self.engine
        if engine == "auto":
            Fn = int(self.numerical_indices.size)
            # Aggressive switch to BLAS for many numerical features
            engine = "blas" if Fn >= 256 else "numba"

        if engine == "numba":
            # Full Numba path (still benefits from bincount-based Cramér's V)
            # Build masks and precompute dense labels once
            is_cat, is_num, num_pos, cat_pos = _build_masks_and_maps(n_features, self.categorical_indices, self.numerical_indices)
            disc_num = _discretize_all_numericals(X, self.numerical_indices, self.n_bins)
            cat_lab = _convert_all_categoricals(X, self.categorical_indices)
            dense_cat_lab, cat_sizes, dense_disc_num, num_sizes = _precompute_dense_mappings(cat_lab, disc_num)

            corr = np.zeros((n_features, n_features), dtype=np.float64)

            # numeric-numeric via pairwise numba loops
            for i in range(n_features):
                corr[i, i] = 1.0

            # numeric-numeric
            for a in range(self.numerical_indices.shape[0]):
                ia = int(self.numerical_indices[a])
                for b in range(a, self.numerical_indices.shape[0]):
                    ib = int(self.numerical_indices[b])
                    if ia == ib:
                        c = 1.0
                    else:
                        if self.method == "pearson":
                            c = _pearson_ignore_nan(X[:, ia], X[:, ib])
                        else:
                            c = _spearman_ignore_nan(X[:, ia], X[:, ib])
                    if c > 1.0:
                        c = 1.0
                    elif c < -1.0:
                        c = -1.0
                    corr[ia, ib] = c
                    corr[ib, ia] = c

            # categorical-involving pairs
            _fill_categorical_pairs_only(
                corr_matrix=corr,
                is_cat=is_cat,
                is_num=is_num,
                cat_pos=cat_pos,
                num_pos=num_pos,
                dense_cat_lab=dense_cat_lab,
                cat_sizes=cat_sizes,
                dense_disc_num=dense_disc_num,
                num_sizes=num_sizes,
            )
            return corr

        # BLAS engine: numeric-numeric via GEMM; categorical via Numba
        corr = np.zeros((n_features, n_features), dtype=np.float64)

        # 1) numeric-numeric submatrix
        corr_nn = _compute_numeric_numeric_blas(X, self.numerical_indices, self.method, self.block_cols)
        if corr_nn is not None:
            idx = self.numerical_indices
            corr[np.ix_(idx, idx)] = corr_nn

        # 2) categorical interactions
        if self.categorical_indices.size > 0:
            disc_num = _discretize_all_numericals(X, self.numerical_indices, self.n_bins)
            cat_lab = _convert_all_categoricals(X, self.categorical_indices)
            dense_cat_lab, cat_sizes, dense_disc_num, num_sizes = _precompute_dense_mappings(cat_lab, disc_num)
            is_cat, is_num, num_pos, cat_pos = _build_masks_and_maps(n_features, self.categorical_indices, self.numerical_indices)
            _fill_categorical_pairs_only(
                corr_matrix=corr,
                is_cat=is_cat,
                is_num=is_num,
                cat_pos=cat_pos,
                num_pos=num_pos,
                dense_cat_lab=dense_cat_lab,
                cat_sizes=cat_sizes,
                dense_disc_num=dense_disc_num,
                num_sizes=num_sizes,
            )

        # 3) Diagonal
        diag = np.arange(n_features)
        corr[diag, diag] = 1.0

        np.clip(corr, -1.0, 1.0, out=corr)
        return corr

    @staticmethod
    def get_squareform(distance_matrix: np.array):
        return _memory_efficient_squareform(distance_matrix)

    def set_method(self, method: str):
        m = method.lower()
        if m not in ("pearson", "spearman"):
            raise ValueError("method must be one of {'pearson', 'spearman'}")
        self.method = m

compute(data)

Compute the mixed-type correlation matrix on the provided data.

Returns

corr : np.ndarray of shape (n_features, n_features)

Source code in src/synomicsbench/utils/correlations.py
def compute(self, data):
    """
    Compute the mixed-type correlation matrix on the provided data.

    Returns
    -------
    corr : np.ndarray of shape (n_features, n_features)
    """
    # Convert to contiguous NumPy array
    if hasattr(data, "values"):
        X = np.ascontiguousarray(data.values, dtype=np.float64)
    else:
        X = np.ascontiguousarray(data, dtype=np.float64)
    if X.ndim != 2:
        raise ValueError("data must be 2D")

    n_samples, n_features = X.shape
    self._validate_indices(n_features)

    # Choose engine
    engine = self.engine
    if engine == "auto":
        Fn = int(self.numerical_indices.size)
        # Aggressive switch to BLAS for many numerical features
        engine = "blas" if Fn >= 256 else "numba"

    if engine == "numba":
        # Full Numba path (still benefits from bincount-based Cramér's V)
        # Build masks and precompute dense labels once
        is_cat, is_num, num_pos, cat_pos = _build_masks_and_maps(n_features, self.categorical_indices, self.numerical_indices)
        disc_num = _discretize_all_numericals(X, self.numerical_indices, self.n_bins)
        cat_lab = _convert_all_categoricals(X, self.categorical_indices)
        dense_cat_lab, cat_sizes, dense_disc_num, num_sizes = _precompute_dense_mappings(cat_lab, disc_num)

        corr = np.zeros((n_features, n_features), dtype=np.float64)

        # numeric-numeric via pairwise numba loops
        for i in range(n_features):
            corr[i, i] = 1.0

        # numeric-numeric
        for a in range(self.numerical_indices.shape[0]):
            ia = int(self.numerical_indices[a])
            for b in range(a, self.numerical_indices.shape[0]):
                ib = int(self.numerical_indices[b])
                if ia == ib:
                    c = 1.0
                else:
                    if self.method == "pearson":
                        c = _pearson_ignore_nan(X[:, ia], X[:, ib])
                    else:
                        c = _spearman_ignore_nan(X[:, ia], X[:, ib])
                if c > 1.0:
                    c = 1.0
                elif c < -1.0:
                    c = -1.0
                corr[ia, ib] = c
                corr[ib, ia] = c

        # categorical-involving pairs
        _fill_categorical_pairs_only(
            corr_matrix=corr,
            is_cat=is_cat,
            is_num=is_num,
            cat_pos=cat_pos,
            num_pos=num_pos,
            dense_cat_lab=dense_cat_lab,
            cat_sizes=cat_sizes,
            dense_disc_num=dense_disc_num,
            num_sizes=num_sizes,
        )
        return corr

    # BLAS engine: numeric-numeric via GEMM; categorical via Numba
    corr = np.zeros((n_features, n_features), dtype=np.float64)

    # 1) numeric-numeric submatrix
    corr_nn = _compute_numeric_numeric_blas(X, self.numerical_indices, self.method, self.block_cols)
    if corr_nn is not None:
        idx = self.numerical_indices
        corr[np.ix_(idx, idx)] = corr_nn

    # 2) categorical interactions
    if self.categorical_indices.size > 0:
        disc_num = _discretize_all_numericals(X, self.numerical_indices, self.n_bins)
        cat_lab = _convert_all_categoricals(X, self.categorical_indices)
        dense_cat_lab, cat_sizes, dense_disc_num, num_sizes = _precompute_dense_mappings(cat_lab, disc_num)
        is_cat, is_num, num_pos, cat_pos = _build_masks_and_maps(n_features, self.categorical_indices, self.numerical_indices)
        _fill_categorical_pairs_only(
            corr_matrix=corr,
            is_cat=is_cat,
            is_num=is_num,
            cat_pos=cat_pos,
            num_pos=num_pos,
            dense_cat_lab=dense_cat_lab,
            cat_sizes=cat_sizes,
            dense_disc_num=dense_disc_num,
            num_sizes=num_sizes,
        )

    # 3) Diagonal
    diag = np.arange(n_features)
    corr[diag, diag] = 1.0

    np.clip(corr, -1.0, 1.0, out=corr)
    return corr
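
A usage sketch for `MixedCorrelation` on a small mixed-type table; the column layout is illustrative and the import path is inferred from the source location above. Note that `compute` casts the input to `float64`, so categorical columns should already be integer-coded. Numeric pairs get Pearson/Spearman values, while pairs involving a categorical column get Cramér's V (with numeric columns discretized into `n_bins`).

```python
import numpy as np
import pandas as pd

# Assumed import path, mirroring the source location shown above.
from synomicsbench.utils.correlations import MixedCorrelation

rng = np.random.default_rng(0)
df = pd.DataFrame({
    "gene_a": rng.normal(size=200),
    "gene_b": rng.normal(size=200),
    "stage": rng.integers(0, 3, size=200),   # categorical, integer-coded
    "smoker": rng.integers(0, 2, size=200),  # categorical, integer-coded
})

# Column positions: 0-1 numerical, 2-3 categorical.
mc = MixedCorrelation(
    categorical_indices=[2, 3],
    numerical_indices=[0, 1],
    method="spearman",
    engine="auto",
)
corr = mc.compute(df)
print(pd.DataFrame(corr, index=df.columns, columns=df.columns).round(3))
```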

cramers_v_bincount_numba(x, y, nx, ny)

Fast Cramér's V using bincount on combined indices. x in [0..nx-1], y in [0..ny-1].

Source code in src/synomicsbench/utils/correlations.py
@njit
def cramers_v_bincount_numba(x: np.ndarray, y: np.ndarray, nx: int, ny: int) -> float:
    """
    Fast Cramér's V using bincount on combined indices. x in [0..nx-1], y in [0..ny-1].
    """
    n = x.shape[0]
    if nx <= 1 or ny <= 1:
        return 0.0

    # combined index: idx = x + nx * y
    combined = np.empty(n, dtype=np.int64)
    for i in range(n):
        combined[i] = int(x[i]) + int(nx) * int(y[i])

    counts = np.bincount(combined, minlength=int(nx) * int(ny)).astype(np.float64)

    # row and column sums
    row_sums = np.zeros(nx, dtype=np.float64)
    col_sums = np.zeros(ny, dtype=np.float64)
    total = 0.0
    # counts are laid out with x varying fastest
    idx = 0
    for j in range(ny):
        col_sum = 0.0
        for i in range(nx):
            c = counts[idx]
            row_sums[i] += c
            col_sum += c
            total += c
            idx += 1
        col_sums[j] = col_sum

    if total <= 0.0:
        return 0.0

    # chi^2
    chi2 = 0.0
    idx = 0
    for j in range(ny):
        for i in range(nx):
            e = (row_sums[i] * col_sums[j]) / total
            if e > 0.0:
                diff = counts[idx] - e
                chi2 += (diff * diff) / e
            idx += 1

    phi2 = chi2 / total
    denom = nx - 1
    if ny - 1 < denom:
        denom = ny - 1
    if denom <= 0:
        return 0.0

    v = np.sqrt(phi2 / denom)
    if v > 1.0:
        v = 1.0
    elif v < 0.0:
        v = 0.0
    return v
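
The helper expects dense integer labels in `[0, nx)` and `[0, ny)`. A sketch of preparing such labels with `pandas.factorize` and calling the function directly (in the pipeline this is normally handled by the precomputed dense mappings; Numba must be installed for the JIT compile):

```python
import numpy as np
import pandas as pd

from synomicsbench.utils.correlations import cramers_v_bincount_numba

# Two categorical columns as raw labels.
a = pd.Series(["low", "high", "high", "mid", "low", "mid"] * 50)
b = pd.Series(["yes", "no", "yes", "yes", "no", "no"] * 50)

# factorize maps each distinct label to a dense code 0..k-1 (missing values
# come back as -1 and would need to be remapped to their own category first).
x, x_levels = pd.factorize(a)
y, y_levels = pd.factorize(b)

v = cramers_v_bincount_numba(
    x.astype(np.int64), y.astype(np.int64), len(x_levels), len(y_levels)
)
print(f"Cramér's V = {v:.3f}")
```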

pearson_correlation_matrix_numba(X)

Pearson correlation matrix (dropping NaN pairs), Numba-parallel.

Source code in src/synomicsbench/utils/correlations.py
@njit(parallel=True)
def pearson_correlation_matrix_numba(X):
    """
    Pearson correlation matrix (dropping NaN pairs), Numba-parallel.
    """
    n_samples, n_features = X.shape
    corr_matrix = np.empty((n_features, n_features), dtype=np.float32)

    for i in prange(n_features):
        for j in range(i, n_features):
            if i == j:
                corr = 1.0
            else:
                corr = _pearson_ignore_nan(X[:, i], X[:, j])
            if corr > 1.0:
                corr = 1.0
            elif corr < -1.0:
                corr = -1.0

            corr_matrix[i, j] = corr
            corr_matrix[j, i] = corr  # symmetric

    return corr_matrix

spearman_correlation_matrix_numba(X)

Spearman correlation matrix (dropping NaN pairs), Numba-parallel.

Source code in src/synomicsbench/utils/correlations.py
@njit(parallel=True)
def spearman_correlation_matrix_numba(X):
    """
    Spearman correlation matrix (dropping NaN pairs), Numba-parallel.
    """
    n_samples, n_features = X.shape
    corr_matrix = np.empty((n_features, n_features), dtype=np.float32)

    for i in prange(n_features):
        for j in range(i, n_features):
            if i == j:
                corr = 1.0
            else:
                corr = _spearman_ignore_nan(X[:, i], X[:, j])
            if corr > 1.0:
                corr = 1.0
            elif corr < -1.0:
                corr = -1.0
            corr_matrix[i, j] = corr
            corr_matrix[j, i] = corr  # symmetric

    return corr_matrix

utils

DataProcessValidation

Bases: DataProcessor

Preprocessing pipeline for preparing tabular data for validation, including encoding, scaling, and imputation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `DataFrame` | Input data to preprocess. | *required* |
| `target_col` | `str` | Name of the target column. | *required* |
| `output_dir` | `str` | Directory for outputs. Defaults to ".". | *required* |
| `ordinal_cat_columns` | `Optional[List[str]]` | List of ordinal categorical column names. | *required* |
| `dummy_cat_columns` | `Optional[List[str]]` | List of dummy categorical column names. | *required* |
| `numerical_columns` | `Optional[List[str]]` | List of numerical column names. | *required* |
| `scaler` | `str` | Which scaler to use for numerical columns. Defaults to "minmax". | `'minmax'` |
| `n_neighbors` | `int` | Number of neighbors for KNN imputation. Defaults to 5. | `5` |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the target column is not found in the input data. |

Source code in src/synomicsbench/metrics/fidelity/utils.py
class DataProcessValidation(DataProcessor):
    """
    Preprocessing pipeline for preparing tabular data for validation, including encoding, scaling, and imputation.

    Args:
        data (pd.DataFrame): Input data to preprocess.
        target_col (str): Name of the target column.
        output_dir (str): Directory for outputs. Defaults to ".".
        ordinal_cat_columns (Optional[List[str]]): List of ordinal categorical column names.
        dummy_cat_columns (Optional[List[str]]): List of dummy categorical column names.
        numerical_columns (Optional[List[str]]): List of numerical column names.
        scaler (str): Which scaler to use for numerical columns. Defaults to "minmax".
        n_neighbors (int): Number of neighbors for KNN imputation. Defaults to 5.

    Raises:
        ValueError: If the target column is not found in the input data.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        metadata: dict = None, 
        scaler: str = "minmax",
        n_neighbors: int = 5,
        **kwargs,
    ):
        """
        Initialize the PreprocessingForValidation class.

        Args:
            data (pd.DataFrame): Input data to preprocess.
            metadata (dict) : Metadata
            scaler (str): Which scaler to use for numerical columns. Defaults to "minmax".
            n_neighbors (int): Number of neighbors for KNN imputation. Defaults to 5.

        Raises:
            ValueError: If the target column is not found in the input data.
        """
        self.data = data
        self.ordinal_cat_columns = None 
        self.dummy_cat_columns = None
        self.numerical_columns = None
        self.scaler = scaler
        self.n_neighbors = n_neighbors


    def fit(self) -> pd.DataFrame:
        """
        Preprocess the data by encoding categorical features, normalizing numerical features,
        imputing missing values, and encoding the target variable.

        Returns:
            pd.DataFrame: The preprocessed data with transformed features and target column.

        Raises:
            ValueError: If required columns are missing or preprocessing fails.
            RuntimeError: If an unexpected error occurs during preprocessing.
        """
        try:
            preprocessed_parts = []

            for col_list in [
                self.ordinal_cat_columns,
                self.dummy_cat_columns,
                self.numerical_columns,
            ]:
                if col_list and self.target_col in col_list:
                    col_list.remove(self.target_col)

            y = pd.DataFrame(self.data[self.target_col])
            le = LabelEncoder()
            transformed_y = pd.DataFrame(
                le.fit_transform(y), index=y.index, columns=[self.target_col]
            )

            if self.ordinal_cat_columns:
                ordinal_data = self.data[self.ordinal_cat_columns]
                ordinal_encoded = super().encode_ordinal_cat_features(ordinal_data)
                preprocessed_parts.append(ordinal_encoded)

            if self.dummy_cat_columns:
                dummy_data = self.data[self.dummy_cat_columns]
                dummy_encoded = super().encode_dummy_cat_features(dummy_data)
                preprocessed_parts.append(dummy_encoded)
            # Normalize numerical columns
            if self.numerical_columns:
                numerical_data = self.data[self.numerical_columns]
                numerical_normalized = super().standardization(
                    numerical_data, scaler=self.scaler
                )
                preprocessed_parts.append(numerical_normalized)

            if not preprocessed_parts:
                raise ValueError("No columns specified for processing")

            data_preprocessed = pd.concat(preprocessed_parts, axis=1)

            # Imputation
            data_imputed = super().knn_imputer(
                data_preprocessed,
                dummy_cat_columns=self.dummy_cat_columns,
                ordinal_cat_columns=self.ordinal_cat_columns,
                n_neighbors=self.n_neighbors,
                add_indicators = False
            )
            data_preprocessed = pd.concat([transformed_y, data_imputed], axis=1)
            return data_preprocessed

        except KeyError as ke:
            raise ValueError(
                f"KeyError in class '{self.__class__.__name__}', method 'fit': {ke}"
            )
        except ValueError as ve:
            raise ValueError(
                f"ValueError in class '{self.__class__.__name__}', method 'fit': {ve}"
            )
        except Exception as e:
            raise RuntimeError(
                f"Error in class '{self.__class__.__name__}', method 'fit': {str(e)}"
            )

__init__(data, metadata=None, scaler='minmax', n_neighbors=5, **kwargs)

Initialize the PreprocessingForValidation class.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `DataFrame` | Input data to preprocess. | *required* |
| `metadata` | `dict` | Metadata. | `None` |
| `scaler` | `str` | Which scaler to use for numerical columns. Defaults to "minmax". | `'minmax'` |
| `n_neighbors` | `int` | Number of neighbors for KNN imputation. Defaults to 5. | `5` |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If the target column is not found in the input data. |

Source code in src/synomicsbench/metrics/fidelity/utils.py
def __init__(
    self,
    data: pd.DataFrame,
    metadata: dict = None, 
    scaler: str = "minmax",
    n_neighbors: int = 5,
    **kwargs,
):
    """
    Initialize the PreprocessingForValidation class.

    Args:
        data (pd.DataFrame): Input data to preprocess.
        metadata (dict) : Metadata
        scaler (str): Which scaler to use for numerical columns. Defaults to "minmax".
        n_neighbors (int): Number of neighbors for KNN imputation. Defaults to 5.

    Raises:
        ValueError: If the target column is not found in the input data.
    """
    self.data = data
    self.ordinal_cat_columns = None 
    self.dummy_cat_columns = None
    self.numerical_columns = None
    self.scaler = scaler
    self.n_neighbors = n_neighbors

fit()

Preprocess the data by encoding categorical features, normalizing numerical features, imputing missing values, and encoding the target variable.

Returns:

| Type | Description |
|------|-------------|
| `DataFrame` | The preprocessed data with transformed features and target column. |

Raises:

| Type | Description |
|------|-------------|
| `ValueError` | If required columns are missing or preprocessing fails. |
| `RuntimeError` | If an unexpected error occurs during preprocessing. |

Source code in src/synomicsbench/metrics/fidelity/utils.py
def fit(self) -> pd.DataFrame:
    """
    Preprocess the data by encoding categorical features, normalizing numerical features,
    imputing missing values, and encoding the target variable.

    Returns:
        pd.DataFrame: The preprocessed data with transformed features and target column.

    Raises:
        ValueError: If required columns are missing or preprocessing fails.
        RuntimeError: If an unexpected error occurs during preprocessing.
    """
    try:
        preprocessed_parts = []

        for col_list in [
            self.ordinal_cat_columns,
            self.dummy_cat_columns,
            self.numerical_columns,
        ]:
            if col_list and self.target_col in col_list:
                col_list.remove(self.target_col)

        y = pd.DataFrame(self.data[self.target_col])
        le = LabelEncoder()
        transformed_y = pd.DataFrame(
            le.fit_transform(y), index=y.index, columns=[self.target_col]
        )

        if self.ordinal_cat_columns:
            ordinal_data = self.data[self.ordinal_cat_columns]
            ordinal_encoded = super().encode_ordinal_cat_features(ordinal_data)
            preprocessed_parts.append(ordinal_encoded)

        if self.dummy_cat_columns:
            dummy_data = self.data[self.dummy_cat_columns]
            dummy_encoded = super().encode_dummy_cat_features(dummy_data)
            preprocessed_parts.append(dummy_encoded)
        # Normalize numerical columns
        if self.numerical_columns:
            numerical_data = self.data[self.numerical_columns]
            numerical_normalized = super().standardization(
                numerical_data, scaler=self.scaler
            )
            preprocessed_parts.append(numerical_normalized)

        if not preprocessed_parts:
            raise ValueError("No columns specified for processing")

        data_preprocessed = pd.concat(preprocessed_parts, axis=1)

        # Imputation
        data_imputed = super().knn_imputer(
            data_preprocessed,
            dummy_cat_columns=self.dummy_cat_columns,
            ordinal_cat_columns=self.ordinal_cat_columns,
            n_neighbors=self.n_neighbors,
            add_indicators = False
        )
        data_preprocessed = pd.concat([transformed_y, data_imputed], axis=1)
        return data_preprocessed

    except KeyError as ke:
        raise ValueError(
            f"KeyError in class '{self.__class__.__name__}', method 'fit': {ke}"
        )
    except ValueError as ve:
        raise ValueError(
            f"ValueError in class '{self.__class__.__name__}', method 'fit': {ve}"
        )
    except Exception as e:
        raise RuntimeError(
            f"Error in class '{self.__class__.__name__}', method 'fit': {str(e)}"
        )

check_column_consistency(origin_data, synthetic_data)

Check if columns and their data types match between the original and synthetic DataFrames.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `origin_data` | `DataFrame` | The original DataFrame. | *required* |
| `synthetic_data` | `DataFrame` | The synthetic DataFrame to compare. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `bool` | True if both column names and data types match, False otherwise. |

Source code in src/synomicsbench/metrics/fidelity/utils.py
def check_column_consistency(
    origin_data: pd.DataFrame, synthetic_data: pd.DataFrame
) -> bool:
    """
    Check if columns and their data types match between the original and synthetic DataFrames.

    Args:
        origin_data (pd.DataFrame): The original DataFrame.
        synthetic_data (pd.DataFrame): The synthetic DataFrame to compare.

    Returns:
        bool: True if both column names and data types match, False otherwise.

    Raises:
        None
    """
    # Check columns
    check_columns = False
    orig_cols = set(origin_data.columns)
    syn_cols = set(synthetic_data.columns)
    if orig_cols != syn_cols:
        check_columns = False
    else:
        # print("Column names match.")
        check_columns = True

    # Check column data types (for matching columns)
    check_type = False
    common_cols = orig_cols & syn_cols
    mismatches = []
    for col in common_cols:
        if origin_data[col].dtype != synthetic_data[col].dtype:
            mismatches.append((col, origin_data[col].dtype, synthetic_data[col].dtype))
    if mismatches:
        check_type = False
    else:
        check_type = True
    return all([check_columns, check_type])
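
A small sketch of the consistency check; it returns a single boolean, so it is typically used as a guard before computing fidelity metrics. The toy frames and the import path (inferred from the source location above) are illustrative.

```python
import pandas as pd

from synomicsbench.metrics.fidelity.utils import check_column_consistency

real = pd.DataFrame({"age": [61, 54], "stage": ["II", "III"]})
synthetic = pd.DataFrame({"age": [58, 49], "stage": ["I", "II"]})

if check_column_consistency(real, synthetic):
    print("Columns and dtypes match; safe to compute fidelity metrics.")
else:
    print("Schema mismatch between real and synthetic data.")
```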

np_encoder(obj)

Convert NumPy data types to native Python types for JSON serialization.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `obj` | | The object to be encoded, potentially a NumPy scalar or array. | *required* |

Returns:

| Type | Description |
|------|-------------|
| `int`, `float`, `bool`, `list` | The object converted to a native Python type suitable for JSON serialization. |

Raises:

| Type | Description |
|------|-------------|
| `TypeError` | If the object type is not supported for conversion. |

Source code in src/synomicsbench/metrics/fidelity/utils.py
def np_encoder(obj):
    """
    Convert NumPy data types to native Python types for JSON serialization.

    Args:
        obj: The object to be encoded, potentially a NumPy scalar or array.

    Returns:
        int, float, bool, list: The object converted to a native Python type suitable for JSON serialization.

    Raises:
        TypeError: If the object type is not supported for conversion.
    """

    if isinstance(obj, (np.integer,)):
        return int(obj)
    if isinstance(obj, (np.floating,)):
        return float(obj)
    if isinstance(obj, (np.bool_, np.bool8)):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Type {type(obj)} not serializable")
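
`np_encoder` is meant to be passed as the `default` hook of `json.dumps`, so NumPy scalars and arrays inside a results dictionary serialize cleanly. A brief sketch with placeholder metric values (import path inferred from the source location above):

```python
import json
import numpy as np

from synomicsbench.metrics.fidelity.utils import np_encoder

metrics = {
    "ks_statistic": np.float64(0.042),
    "n_significant": np.int64(17),
    "per_gene": np.array([0.01, 0.03, 0.2]),
}

# default=np_encoder converts NumPy types to native Python types on the fly.
print(json.dumps(metrics, default=np_encoder, indent=2))
```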