# src.train
Training module. Runs Optuna hyperparameter optimization across 6 model families, saves the best model, generates metrics and charts, and optionally logs Optuna trials to Weights & Biases and MLflow.
## Usage

```bash
python src/train.py
```
## Constants

### MODEL_NAMES
```python
MODEL_NAMES: list[str] = [
    "random_forest",
    "extra_trees",
    "gradient_boosting",
    "hist_gradient_boosting",
    "xgboost",
    "catboost",
]
```
The list of model families that Optuna searches across.
## Functions

### prepare_catboost_frame
```python
def prepare_catboost_frame(
    frame: pd.DataFrame,
    categorical_features: list[str],
) -> pd.DataFrame
```
Prepares a DataFrame for CatBoost by filling NaN values in categorical columns with `"missing"` and casting them to string type.
Parameters:

| Parameter | Type | Description |
|---|---|---|
| `frame` | `pd.DataFrame` | Input DataFrame |
| `categorical_features` | `list[str]` | Names of categorical columns to prepare |
Returns: A copy of the DataFrame with categorical columns cleaned.
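For reference, a minimal sketch of the described behavior, assuming only pandas (the real implementation may differ in details):

```python
import pandas as pd

def prepare_catboost_frame(
    frame: pd.DataFrame,
    categorical_features: list[str],
) -> pd.DataFrame:
    out = frame.copy()
    for col in categorical_features:
        # CatBoost rejects NaN in categorical features, so impute a
        # sentinel value and force a uniform string dtype
        out[col] = out[col].fillna("missing").astype(str)
    return out
```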
### sample_model_params

```python
def sample_model_params(trial: optuna.Trial) -> dict
```
Samples hyperparameters for a given Optuna trial: first selects a model family via `trial.suggest_categorical`, then samples family-specific hyperparameters.
Parameters:

| Parameter | Type | Description |
|---|---|---|
| `trial` | `optuna.Trial` | The current Optuna trial object |
Returns:

| Type | Description |
|---|---|
| `dict` | Dictionary containing `"model_name"` and all sampled hyperparameters |
Hyperparameter Ranges by Model:
| Model | Key Parameters |
|---|---|
| Random Forest | n_estimators (200–700), max_depth (3–16), min_samples_split, min_samples_leaf, max_features |
| Extra Trees | Same ranges as Random Forest |
| Gradient Boosting | n_estimators (100–500), learning_rate (0.01–0.2), max_depth (2–6), subsample (0.6–1.0) |
| Hist Gradient Boosting | max_iter (150–600), learning_rate (0.01–0.2), max_leaf_nodes (15–63), l2_regularization |
| XGBoost | n_estimators (150–700), learning_rate (0.01–0.2), max_depth (3–10), reg_alpha, reg_lambda |
| CatBoost | iterations (200–800), learning_rate (0.01–0.2), depth (4–10), l2_leaf_reg, bagging_temperature |
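An abridged sketch of the branching pattern this implies, showing two of the six families (the per-branch Optuna parameter names are assumptions, not necessarily the module's actual keys):

```python
import optuna

def sample_model_params(trial: optuna.Trial) -> dict:
    # The model family itself is the first hyperparameter Optuna samples
    model_name = trial.suggest_categorical("model_name", MODEL_NAMES)
    params = {"model_name": model_name}
    if model_name == "random_forest":
        params["n_estimators"] = trial.suggest_int("rf_n_estimators", 200, 700)
        params["max_depth"] = trial.suggest_int("rf_max_depth", 3, 16)
    elif model_name == "xgboost":
        params["n_estimators"] = trial.suggest_int("xgb_n_estimators", 150, 700)
        params["learning_rate"] = trial.suggest_float("xgb_learning_rate", 0.01, 0.2)
        params["max_depth"] = trial.suggest_int("xgb_max_depth", 3, 10)
    # ... the remaining four families follow the same pattern
    return params
```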
### build_model_from_params

```python
def build_model_from_params(
    params: dict,
    onehot_preprocessor: ColumnTransformer,
    hist_preprocessor: ColumnTransformer,
    scale_pos_weight: float,
    hist_categorical_feature_idx: list[int],
    categorical_features: list[str],
    X_train: pd.DataFrame,
    X_train_catboost: pd.DataFrame,
) -> tuple
```
Constructs a full estimator (with preprocessing pipeline) from sampled parameters.
Parameters:

| Parameter | Type | Description |
|---|---|---|
| `params` | `dict` | Hyperparameters from `sample_model_params()` |
| `onehot_preprocessor` | `ColumnTransformer` | Fitted OneHotEncoder-based preprocessor |
| `hist_preprocessor` | `ColumnTransformer` | Fitted OrdinalEncoder-based preprocessor |
| `scale_pos_weight` | `float` | Class imbalance ratio for XGBoost |
| `hist_categorical_feature_idx` | `list[int]` | Categorical column indices for HistGBT |
| `categorical_features` | `list[str]` | Categorical column names for CatBoost |
| `X_train` | `pd.DataFrame` | Training features (standard) |
| `X_train_catboost` | `pd.DataFrame` | Training features (CatBoost-prepared) |
Returns:

| Type | Description |
|---|---|
| `tuple[estimator, DataFrame, dict]` | `(model_or_pipeline, X_to_use, fit_kwargs)` |
Model-to-Preprocessor Mapping:

| Model | Preprocessor | Wrapping |
|---|---|---|
| Random Forest | `onehot_preprocessor` | `sklearn.Pipeline` |
| Extra Trees | `onehot_preprocessor` | `sklearn.Pipeline` |
| Gradient Boosting | `onehot_preprocessor` | `sklearn.Pipeline` |
| Hist Gradient Boosting | `hist_preprocessor` | `sklearn.Pipeline` |
| XGBoost | `onehot_preprocessor` | `sklearn.Pipeline` |
| CatBoost | None (native) | Standalone `CatBoostClassifier` |
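A simplified sketch of two branches of this mapping, using a reduced signature for brevity (the CatBoost constructor arguments shown are illustrative, not confirmed by the source):

```python
from catboost import CatBoostClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline

def build_model_from_params(params, onehot_preprocessor,
                            categorical_features, X_train, X_train_catboost):
    # Abridged: two of the six branches; the real function also takes the
    # hist preprocessor, scale_pos_weight, and HistGBT categorical indices
    if params["model_name"] == "random_forest":
        model = Pipeline([
            ("preprocess", onehot_preprocessor),
            ("clf", RandomForestClassifier(
                n_estimators=params["n_estimators"],
                max_depth=params["max_depth"],
            )),
        ])
        return model, X_train, {}  # (model_or_pipeline, X_to_use, fit_kwargs)
    if params["model_name"] == "catboost":
        # CatBoost consumes raw categorical columns, so no ColumnTransformer
        model = CatBoostClassifier(
            iterations=params["iterations"],
            depth=params["depth"],
            cat_features=categorical_features,
            verbose=False,
        )
        return model, X_train_catboost, {}
```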
### _init_wandb_if_available

```python
def _init_wandb_if_available() -> wandb | None
```
Checks if `WANDB_API_KEY` is set in the environment. If so, calls `wandb.login()` and returns the `wandb` module. If not, logs a warning and returns `None`.
Returns:

| Type | Description |
|---|---|
| `wandb` module or `None` | The `wandb` module if available, else `None` |
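A plausible sketch of this guard (the warning text is illustrative):

```python
import logging
import os

def _init_wandb_if_available():
    if not os.environ.get("WANDB_API_KEY"):
        logging.warning("WANDB_API_KEY not set; skipping Weights & Biases logging")
        return None
    import wandb
    wandb.login()  # reads WANDB_API_KEY from the environment
    return wandb
```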
### _init_mlflow_if_available

```python
def _init_mlflow_if_available(study_name: str) -> MLflowCallback | None
```
Checks if `MLFLOW_TRACKING_URI` is set in the environment. If so, configures MLflow tracking and returns an Optuna `MLflowCallback`. If not, logs a warning and returns `None`.
Returns:

| Type | Description |
|---|---|
| `MLflowCallback` or `None` | The Optuna MLflow callback if enabled, else `None` |
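A sketch of this helper under the stated behavior; using the study name as the MLflow experiment name and `"roc_auc"` as the metric name are assumptions:

```python
import logging
import os

import mlflow
from optuna.integration import MLflowCallback

def _init_mlflow_if_available(study_name: str):
    tracking_uri = os.environ.get("MLFLOW_TRACKING_URI")
    if not tracking_uri:
        logging.warning("MLFLOW_TRACKING_URI not set; skipping MLflow logging")
        return None
    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(study_name)  # experiment name assumed to mirror the study
    return MLflowCallback(tracking_uri=tracking_uri, metric_name="roc_auc")
```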
### run_training

```python
def run_training() -> None
```
Main training entry point. Orchestrates the full Optuna optimization workflow.
Steps:

- Load processed data (`X_train`, `y_train`) from `data/processed/`
- Load preprocessing pipelines from pickle
- Initialize wandb if `WANDB_API_KEY` is available
- Initialize MLflow if `MLFLOW_TRACKING_URI` is available
- Define the Optuna objective function (5-fold StratifiedKFold, ROC-AUC metric); see the sketch after this list
- Run the Optuna study with `n_trials` from config
- Retrain the best model on the full training set
- Save the model bundle to pickle
- Save metrics to `reports/train_metrics.json`
- Generate the top-10 trials bar chart
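The objective and study steps roughly correspond to the following sketch; it assumes the surrounding variables (the preprocessors, `y_train`, a `cfg` dict holding the config keys below, and `mlflow_callback`) are in scope, and the exact wiring is an assumption:

```python
import optuna
from sklearn.model_selection import StratifiedKFold, cross_val_score

def objective(trial: optuna.Trial) -> float:
    params = sample_model_params(trial)
    model, X_to_use, fit_kwargs = build_model_from_params(
        params, onehot_preprocessor, hist_preprocessor, scale_pos_weight,
        hist_categorical_feature_idx, categorical_features,
        X_train, X_train_catboost,
    )
    cv = StratifiedKFold(
        n_splits=cfg["training"]["cv_folds"],  # 5 folds
        shuffle=True,
        random_state=cfg["training"]["random_state"],
    )
    # fit_kwargs handling omitted for brevity
    return cross_val_score(model, X_to_use, y_train,
                           cv=cv, scoring="roc_auc").mean()

study = optuna.create_study(
    study_name=cfg["training"]["study_name"], direction="maximize"
)
callbacks = [mlflow_callback] if mlflow_callback else []
study.optimize(objective, n_trials=cfg["training"]["n_trials"],
               callbacks=callbacks)
```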
Config Keys Used:

| Key | Description |
|---|---|
| `training.n_trials` | Number of Optuna trials |
| `training.cv_folds` | Number of cross-validation folds |
| `training.random_state` | Random seed |
| `training.model_path` | Where to save the best model |
| `training.study_name` | Optuna study name |
Output Files:

| File | Content |
|---|---|
| `models/best_model.pkl` | Serialized best model bundle |
| `reports/train_metrics.json` | Training metrics (accuracy, ROC-AUC, best params) |
| `reports/figures/optuna_top10_accuracy.png` | Top 10 trials bar chart |
Model Bundle Contents:
```python
{
    "model": estimator,        # Trained model (Pipeline or CatBoostClassifier)
    "model_name": str,         # e.g. "xgboost"
    "best_roc_auc_cv": float,  # Best cross-validation ROC-AUC
    "best_params": dict,       # Best hyperparameters
}
```
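For downstream use, loading the bundle might look like this (`X_new` is a hypothetical placeholder for new feature rows):

```python
import pickle

# Path from the Output Files table above
with open("models/best_model.pkl", "rb") as f:
    bundle = pickle.load(f)

print(bundle["model_name"], bundle["best_roc_auc_cv"])
proba = bundle["model"].predict_proba(X_new)[:, 1]  # positive-class scores
```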