import copy
import io
import json
import shutil
import warnings
import zipfile
from abc import abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import numpy as np
import scipy
import torch
from pytorch_tabnet import tab_network
from pytorch_tabnet.callbacks import (
CallbackContainer,
EarlyStopping,
History,
LRSchedulerCallback,
)
from pytorch_tabnet.metrics import MetricContainer, check_metrics
from pytorch_tabnet.utils import (
ComplexEncoder,
PredictDataset,
SparsePredictDataset,
check_embedding_parameters,
check_input,
check_warm_start,
create_dataloaders,
create_explain_matrix,
create_group_matrix,
define_device,
validate_eval_set,
)
from scipy.sparse import csc_matrix
from sklearn.base import BaseEstimator
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
[docs]
@dataclass
class TabModel(BaseEstimator):
    """Base estimator holding the hyper-parameters and training loop of a TabNet model.

    The dataclass fields below are the constructor parameters. Several of
    them are forwarded unchanged to ``tab_network.TabNet`` when the network
    is built in ``_set_network``; the others configure the optimizer, the
    learning-rate scheduler, the device and the training loop.
    """

    # Forwarded to tab_network.TabNet under the same keyword.
    n_d: int = 8
    n_a: int = 8
    n_steps: int = 3
    gamma: float = 1.3
    # Categorical settings: normalized by check_embedding_parameters in
    # __post_init__, then forwarded to tab_network.TabNet.
    cat_idxs: list[int] = field(default_factory=list)
    cat_dims: list[int] = field(default_factory=list)
    cat_emb_dim: int = 1
    # Forwarded to tab_network.TabNet under the same keyword.
    n_independent: int = 2
    n_shared: int = 2
    epsilon: float = 1e-15
    momentum: float = 0.02
    # Coefficient applied to the network's M_loss term in _train_batch.
    lambda_sparse: float = 1e-3
    # Passed to torch.manual_seed in __post_init__ and _set_network.
    seed: int = 0
    # Max norm handed to clip_grad_norm_; a falsy value disables clipping.
    clip_value: int = 1
    # Non-zero emits the "Device used" warning; also passed to History.
    verbose: int = 1
    # The optimizer is built as optimizer_fn(network.parameters(), **optimizer_params).
    optimizer_fn: Any = torch.optim.Adam
    optimizer_params: dict = field(default_factory=lambda: dict(lr=2e-2))
    # When scheduler_fn is not None an LRSchedulerCallback is registered with
    # scheduler_params; an optional "is_batch_level" entry is extracted and
    # passed to the callback separately.
    scheduler_fn: Any = None
    scheduler_params: dict = field(default_factory=dict)
    # Forwarded to tab_network.TabNet under the same keyword.
    mask_type: str = "sparsemax"
    # Overwritten by fit() with X_train.shape[1].
    input_dim: int = None
    # Forwarded to tab_network.TabNet; never assigned in this class
    # (presumably set by subclasses - TODO confirm).
    output_dim: int = None
    # Resolved through define_device into self.device.
    device_name: str = "auto"
    # NOTE(review): not read by any method of this class.
    n_shared_decoder: int = 1
    n_indep_decoder: int = 1
    # Passed to create_group_matrix together with input_dim.
    grouped_features: list[list[int]] = field(default_factory=list)
def __post_init__(self):
    """Finish construction once the dataclass fields have been assigned.

    Sets default batch sizes, seeds torch, resolves the device, takes
    private copies of the optimizer / scheduler factories and normalizes
    the categorical embedding settings.
    """
    # Defaults needed for saving the model; fit() overrides both values.
    self.batch_size, self.virtual_batch_size = 1024, 128

    torch.manual_seed(self.seed)

    # Turn the requested device name into a torch.device.
    resolved_device = define_device(self.device_name)
    self.device = torch.device(resolved_device)
    if self.verbose != 0:
        warnings.warn(f"Device used : {self.device}")

    # Deep copies so this instance does not share the objects it was given.
    self.optimizer_fn = copy.deepcopy(self.optimizer_fn)
    self.scheduler_fn = copy.deepcopy(self.scheduler_fn)

    self.cat_dims, self.cat_idxs, self.cat_emb_dim = check_embedding_parameters(
        self.cat_dims,
        self.cat_idxs,
        self.cat_emb_dim,
    )
def __update__(self, **kwargs):
    """Overwrite architecture parameters with the ones given in ``kwargs``.

    Only the names listed in ``update_list`` are considered; every other
    keyword is ignored. For a considered name:

    * if the attribute does not exist yet, it is created silently;
    * if it exists with a different value, a warning is emitted and the
      attribute is overwritten;
    * if it exists with an equal value, nothing happens.

    Parameters
    ----------
    **kwargs
        Candidate ``name=value`` pairs (``fit`` passes the parameters of
        the pretrained model here).
    """
    update_list = (
        "cat_dims",
        "cat_emb_dim",
        "cat_idxs",
        "input_dim",
        "mask_type",
        "n_a",
        "n_d",
        "n_independent",
        "n_shared",
        "n_steps",
        "grouped_features",
    )
    for var_name, value in kwargs.items():
        if var_name not in update_list:
            continue
        # Plain attribute access instead of exec(): no code is built from
        # strings and no module-level global is created as a side effect.
        try:
            previous_val = getattr(self, var_name)
        except AttributeError:
            setattr(self, var_name, value)
            continue
        if previous_val != value:
            wrn_msg = f"Pretraining: {var_name} changed from {previous_val} to {value}"
            warnings.warn(wrn_msg)
            setattr(self, var_name, value)
[docs]
def fit(
    self,
    X_train,
    y_train,
    eval_set=None,
    eval_name=None,
    eval_metric=None,
    loss_fn=None,
    weights=0,
    max_epochs=100,
    patience=10,
    batch_size=1024,
    virtual_batch_size=128,
    num_workers=0,
    drop_last=True,
    callbacks=None,
    pin_memory=True,
    from_unsupervised=None,
    warm_start=False,
    augmentations=None,
    compute_importance=True,
):
    """Train ``self.network`` on the training data, evaluating after every epoch.

    Parameters
    ----------
    X_train : np.ndarray
        Training features; ``X_train.shape[1]`` becomes ``self.input_dim``.
    y_train : np.array
        Training targets.
    eval_set : list of tuple, optional
        List of ``(X, y)`` evaluation sets. The last one is used for
        early stopping.
    eval_name : list of str, optional
        Names of the evaluation sets.
    eval_metric : list of str, optional
        Evaluation metrics. The last one is used for early stopping.
    loss_fn : callable or None
        A PyTorch loss function; ``self._default_loss`` is used when None.
    weights : bool or dict
        0 for no balancing, 1 for automated balancing, dict for custom
        weights per class. Passed on to ``update_fit_params``.
    max_epochs : int
        Maximum number of training epochs.
    patience : int
        Number of consecutive non-improving epochs before early stopping.
    batch_size : int
        Training batch size.
    virtual_batch_size : int
        Batch size for Ghost Batch Normalization
        (virtual_batch_size < batch_size).
    num_workers : int
        Number of workers used in torch.utils.data.DataLoader.
    drop_last : bool
        Whether to drop the last batch during training.
    callbacks : list of callback, optional
        Custom callbacks, appended after the default ones.
    pin_memory : bool
        Requested pin_memory setting; forced off when the device is a CPU.
    from_unsupervised : unsupervised trained model, optional
        Previously self-supervised model whose parameters and weights are
        used as the starting point.
    warm_start : bool
        If True and a network already exists, it is kept instead of being
        rebuilt.
    augmentations : object, optional
        Called as ``augmentations(X, y)`` on every training batch; its
        ``_set_seed()`` method is invoked at the start of ``fit``.
    compute_importance : bool
        Whether to compute ``feature_importances_`` once training is done.
    """
    # Record the fit-time settings on the instance; other methods read them.
    self.max_epochs, self.patience = max_epochs, patience
    self.batch_size, self.virtual_batch_size = batch_size, virtual_batch_size
    self.num_workers, self.drop_last = num_workers, drop_last
    self.input_dim = X_train.shape[1]
    self._stop_training = False
    self.pin_memory = pin_memory and (self.device.type != "cpu")
    self.augmentations = augmentations
    self.compute_importance = compute_importance

    if self.augmentations is not None:
        # Re-seed the augmentations so that runs are reproducible.
        self.augmentations._set_seed()

    eval_set = eval_set or []
    self.loss_fn = self._default_loss if loss_fn is None else loss_fn

    check_input(X_train)
    check_warm_start(warm_start, from_unsupervised)

    self.update_fit_params(X_train, y_train, eval_set, weights)

    # Validate and reformat the eval sets against the training data.
    eval_names, eval_set = validate_eval_set(eval_set, eval_name, X_train, y_train)
    train_dataloader, valid_dataloaders = self._construct_loaders(X_train, y_train, eval_set)

    pretrained = from_unsupervised is not None
    if pretrained:
        # Align the architecture parameters with the pretrained model.
        self.__update__(**from_unsupervised.get_params())

    # Build a network unless one already exists AND a warm start was requested.
    if not (hasattr(self, "network") and warm_start):
        self._set_network()
    self._update_network_params()
    self._set_metrics(eval_metric, eval_names)
    self._set_optimizer()
    self._set_callbacks(callbacks)

    if pretrained:
        self.load_weights_from_unsupervised(from_unsupervised)
        warnings.warn("Loading weights from unsupervised pretraining")

    self._callback_container.on_train_begin()

    for epoch_idx in range(self.max_epochs):
        self._callback_container.on_epoch_begin(epoch_idx)

        self._train_epoch(train_dataloader)

        # Score every evaluation set with the freshly updated weights.
        for set_name, set_loader in zip(eval_names, valid_dataloaders):
            self._predict_epoch(set_name, set_loader)

        self._callback_container.on_epoch_end(epoch_idx, logs=self.history.epoch_metrics)

        # A callback (e.g. early stopping) may have requested a stop.
        if self._stop_training:
            break

    self._callback_container.on_train_end()
    self.network.eval()

    if self.compute_importance:
        # Computed last, after the callbacks have finished.
        self.feature_importances_ = self._compute_feature_importances(X_train)
[docs]
def predict(self, X):
    """Run the network over ``X`` and return post-processed predictions.

    Parameters
    ----------
    X : a :tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
        Input data. Sparse inputs are wrapped in ``SparsePredictDataset``,
        everything else in ``PredictDataset``.

    Returns
    -------
    predictions : np.array
        The stacked network outputs after being passed through
        ``self.predict_func`` (not defined in this class; presumably
        provided by the concrete subclass).
    """
    self.network.eval()

    dataset = SparsePredictDataset(X) if scipy.sparse.issparse(X) else PredictDataset(X)
    dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    results = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(self.device).float()
            output, _ = self.network(data)
            results.append(output.cpu().detach().numpy())

    res = np.vstack(results)
    return self.predict_func(res)
[docs]
def explain(self, X, normalize=False):
    """Return per-sample explanations and the attention masks of the network.

    Parameters
    ----------
    X : tensor: `torch.Tensor` or matrix: `scipy.sparse.csr_matrix`
        Input data.
    normalize : bool (default False)
        Whether to divide each row of the explanation matrix by its sum,
        so that every sample's importances add up to 1.

    Returns
    -------
    M_explain : matrix
        Importance per sample, per column: the output of
        ``network.forward_masks`` multiplied by ``self.reducing_matrix``.
    masks : dict
        One stacked matrix per key returned by ``network.forward_masks``,
        each multiplied by ``self.reducing_matrix`` as well.
    """
    self.network.eval()

    dataset = SparsePredictDataset(X) if scipy.sparse.issparse(X) else PredictDataset(X)
    dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False)

    explain_chunks = []
    mask_chunks = {}
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for data in dataloader:
            data = data.to(self.device).float()
            M_explain, masks = self.network.forward_masks(data)
            for key, value in masks.items():
                reduced_mask = csc_matrix.dot(value.cpu().detach().numpy(), self.reducing_matrix)
                mask_chunks.setdefault(key, []).append(reduced_mask)
            explain_chunks.append(
                csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix)
            )

    # Stack once at the end rather than re-copying the accumulated result
    # on every batch.
    res_explain = np.vstack(explain_chunks)
    res_masks = {key: np.vstack(chunks) for key, chunks in mask_chunks.items()}

    if normalize:
        res_explain /= np.sum(res_explain, axis=1)[:, None]
    return res_explain, res_masks
[docs]
def load_weights_from_unsupervised(self, unsupervised_model):
    """Copy the weights shared with a pretrained model into ``self.network``.

    Parameter names of the pretrained network that start with ``encoder``
    are looked up under the ``tabnet.`` prefix; all other names are looked
    up unchanged. Only entries that already exist in this network's state
    dict are overwritten, the rest of the pretrained weights is ignored.

    Parameters
    ----------
    unsupervised_model : object
        Pretrained model exposing a ``network`` with a ``state_dict()``.
    """
    # Read the current state once: state_dict() builds a new mapping on
    # every call, so it must not be called per parameter.
    own_state = self.network.state_dict()
    update_state_dict = copy.deepcopy(own_state)
    for param, weights in unsupervised_model.network.state_dict().items():
        # Convert the encoder's layer names to the names used here.
        new_param = "tabnet." + param if param.startswith("encoder") else param
        if own_state.get(new_param) is not None:
            # Update only the layers both networks have in common.
            update_state_dict[new_param] = weights
    self.network.load_state_dict(update_state_dict)
[docs]
def load_class_attrs(self, class_attrs):
    """Set every ``name -> value`` pair of ``class_attrs`` as an attribute of this instance."""
    for pair in class_attrs.items():
        setattr(self, *pair)
[docs]
def save_model(self, path):
    """Save the model as a single zip archive at ``<path>.zip``.

    The archive contains two files: ``model_params.json`` (constructor
    parameters plus ``preds_mapper``) and ``network.pt`` (the network's
    state dict).

    Parameters
    ----------
    path : str
        Path of the model, without the ".zip" extension.

    Returns
    -------
    str
        input filepath with ".zip" appended
    """
    import tempfile  # only needed here; keeps the file-level imports unchanged

    # Keep every constructor parameter except those whose value is a class
    # (e.g. the optimizer class), which is not saved.
    init_params = {
        key: val for key, val in self.get_params().items() if not isinstance(val, type)
    }
    saved_params = {
        "init_params": init_params,
        "class_attrs": {"preds_mapper": self.preds_mapper},
    }

    # Stage the files in a throwaway directory instead of in `path` itself:
    # a directory that already exists at `path` is neither zipped nor
    # deleted, and nothing is left behind if writing fails.
    with tempfile.TemporaryDirectory() as staging:
        staging_dir = Path(staging)

        # Save model params
        with open(staging_dir / "model_params.json", "w", encoding="utf8") as f:
            json.dump(saved_params, f, cls=ComplexEncoder)

        # Save state_dict
        torch.save(self.network.state_dict(), staging_dir / "network.pt")

        shutil.make_archive(str(path), "zip", staging)

    print(f"Successfully saved model at {path}.zip")
    return f"{path}.zip"
[docs]
def load_model(self, filepath):
    """Load a model previously written by ``save_model`` into this instance.

    The instance is re-initialized with the saved constructor parameters
    (keeping its current ``device_name``), the network is rebuilt, the
    saved weights are loaded and the network is put in eval mode.

    Parameters
    ----------
    filepath : str
        Path of the model archive (the ".zip" file).

    Raises
    ------
    KeyError
        If ``model_params.json`` or ``network.pt`` is missing from the archive.
    """
    try:
        with zipfile.ZipFile(filepath) as z:
            with z.open("model_params.json") as f:
                loaded_params = json.load(f)
                # Keep the device of the current instance, not the saved one.
                loaded_params["init_params"]["device_name"] = self.device_name
            with z.open("network.pt") as f:
                # NOTE(security): torch.load unpickles the file content;
                # only load archives that come from a trusted source.
                try:
                    saved_state_dict = torch.load(f, map_location=self.device)
                except io.UnsupportedOperation:
                    # In Python <3.7, the returned file object is not seekable (which at least
                    # some versions of PyTorch require) - so we'll try buffering it in to a
                    # BytesIO instead:
                    saved_state_dict = torch.load(
                        io.BytesIO(f.read()),
                        map_location=self.device,
                    )
    except KeyError as err:
        # Chain the original error so the name of the missing entry is kept.
        raise KeyError("Your zip file is missing at least one component") from err

    self.__init__(**loaded_params["init_params"])

    self._set_network()
    self.network.load_state_dict(saved_state_dict)
    self.network.eval()
    self.load_class_attrs(loaded_params["class_attrs"])
def _train_epoch(self, train_loader):
    """Run one training pass over ``train_loader``.

    Parameters
    ----------
    train_loader : a :class: `torch.utils.data.Dataloader`
        DataLoader with the train set, yielding ``(X, y)`` pairs.
    """
    self.network.train()

    for step, (features, targets) in enumerate(train_loader):
        self._callback_container.on_batch_begin(step)
        step_logs = self._train_batch(features, targets)
        self._callback_container.on_batch_end(step, step_logs)

    # Record the learning rate of the last parameter group for this epoch.
    current_lr = self._optimizer.param_groups[-1]["lr"]
    self.history.epoch_metrics.update({"lr": current_lr})
def _train_batch(self, X, y):
    """Run one optimization step on a single batch.

    Parameters
    ----------
    X : torch.Tensor
        Train matrix
    y : torch.Tensor
        Target matrix

    Returns
    -------
    batch_logs : dict
        Dictionary with "batch_size" and "loss".
    """
    # The batch size is read before any augmentation is applied.
    batch_logs = {"batch_size": X.shape[0]}

    X, y = X.to(self.device).float(), y.to(self.device).float()

    if self.augmentations is not None:
        X, y = self.augmentations(X, y)

    # Drop the gradients left over from the previous step.
    for weight in self.network.parameters():
        weight.grad = None

    output, M_loss = self.network(X)

    # Task loss combined with the network's M_loss term, scaled by lambda_sparse.
    loss = self.compute_loss(output, y) - self.lambda_sparse * M_loss

    # Backward pass, optional gradient clipping, then the optimizer step.
    loss.backward()
    if self.clip_value:
        clip_grad_norm_(self.network.parameters(), self.clip_value)
    self._optimizer.step()

    batch_logs["loss"] = loss.detach().cpu().numpy().item()
    return batch_logs
def _predict_epoch(self, name, loader):
    """Score one evaluation set and record its metrics in the history.

    Parameters
    ----------
    name : str
        Name of the validation set
    loader : torch.utils.data.Dataloader
        DataLoader with the validation set, yielding ``(X, y)`` pairs.
    """
    # Evaluation mode while scoring.
    self.network.eval()

    list_y_true, list_y_score = [], []
    for features, targets in loader:
        batch_scores = self._predict_batch(features)
        list_y_true.append(targets)
        list_y_score.append(batch_scores)

    y_true, scores = self.stack_batches(list_y_true, list_y_score)

    container = self._metric_container_dict[name]
    metrics_logs = container(y_true, scores)

    # Back to training mode before the next epoch.
    self.network.train()
    self.history.epoch_metrics.update(metrics_logs)
def _predict_batch(self, X):
    """Compute the network scores for one batch, as numpy data.

    Parameters
    ----------
    X : torch.Tensor
        Batch of input features.

    Returns
    -------
    np.array or list of np.array
        Model scores. A list is returned when the network itself returns
        a list of tensors, with one array per entry.
    """
    X = X.to(self.device).float()

    # Scoring only: no_grad avoids building an autograd graph for every
    # validation batch.
    with torch.no_grad():
        scores, _ = self.network(X)

    if isinstance(scores, list):
        return [x.cpu().detach().numpy() for x in scores]
    return scores.cpu().detach().numpy()
def _set_network(self):
    """Build ``self.network`` together with its group and explain matrices."""
    torch.manual_seed(self.seed)

    self.group_matrix = create_group_matrix(self.grouped_features, self.input_dim)

    # Keyword arguments of the TabNet constructor, taken from this instance.
    network_kwargs = {
        "n_d": self.n_d,
        "n_a": self.n_a,
        "n_steps": self.n_steps,
        "gamma": self.gamma,
        "cat_idxs": self.cat_idxs,
        "cat_dims": self.cat_dims,
        "cat_emb_dim": self.cat_emb_dim,
        "n_independent": self.n_independent,
        "n_shared": self.n_shared,
        "epsilon": self.epsilon,
        "virtual_batch_size": self.virtual_batch_size,
        "momentum": self.momentum,
        "mask_type": self.mask_type,
        "group_attention_matrix": self.group_matrix.to(self.device),
    }
    network = tab_network.TabNet(self.input_dim, self.output_dim, **network_kwargs)
    self.network = network.to(self.device)

    # Matrix applied to the masks and explanations in explain().
    self.reducing_matrix = create_explain_matrix(
        self.network.input_dim,
        self.network.cat_emb_dim,
        self.network.cat_idxs,
        self.network.post_embed_dim,
    )
def _set_metrics(self, metrics, eval_names):
    """Create the metric containers and pick the early-stopping metric.

    Parameters
    ----------
    metrics : list of str
        List of eval metric names; ``self._default_metric`` is used when
        the list is empty or None.
    eval_names : list of str
        List of eval set names.
    """
    metrics = check_metrics(metrics or [self._default_metric])

    # One container per evaluation set, its metric names prefixed by the set name.
    self._metric_container_dict = {
        name: MetricContainer(metrics, prefix=f"{name}_") for name in eval_names
    }

    self._metrics = []
    self._metrics_names = []
    for container in self._metric_container_dict.values():
        self._metrics.extend(container.metrics)
        self._metrics_names.extend(container.names)

    # The last metric of the last eval set drives early stopping.
    self.early_stopping_metric = self._metrics_names[-1] if self._metrics_names else None
def _set_callbacks(self, custom_callbacks):
    """Assemble the callback container used by ``fit``.

    The container holds, in order: the history callback, the early-stopping
    callback (when an early-stopping metric exists and ``patience > 0``),
    the learning-rate scheduler callback (when ``scheduler_fn`` is set) and
    finally the custom callbacks.

    Parameters
    ----------
    custom_callbacks : list of func
        List of callback functions.
    """
    callbacks = []

    self.history = History(self, verbose=self.verbose)
    callbacks.append(self.history)

    if (self.early_stopping_metric is not None) and (self.patience > 0):
        early_stopping = EarlyStopping(
            early_stopping_metric=self.early_stopping_metric,
            is_maximize=(self._metrics[-1]._maximize if len(self._metrics) > 0 else None),
            patience=self.patience,
        )
        callbacks.append(early_stopping)
    else:
        wrn_msg = "No early stopping will be performed, last training weights will be used."
        warnings.warn(wrn_msg)

    if self.scheduler_fn is not None:
        # Pop from a copy: removing "is_batch_level" from
        # self.scheduler_params itself would change a constructor
        # parameter, so a second fit() would fall back to the default
        # (False) and get_params() would no longer report the flag.
        scheduler_params = dict(self.scheduler_params)
        is_batch_level = scheduler_params.pop("is_batch_level", False)
        scheduler = LRSchedulerCallback(
            scheduler_fn=self.scheduler_fn,
            scheduler_params=scheduler_params,
            optimizer=self._optimizer,
            early_stopping_metric=self.early_stopping_metric,
            is_batch_level=is_batch_level,
        )
        callbacks.append(scheduler)

    if custom_callbacks:
        callbacks.extend(custom_callbacks)

    self._callback_container = CallbackContainer(callbacks)
    self._callback_container.set_trainer(self)
def _set_optimizer(self):
    """Instantiate ``self._optimizer`` over the parameters of the network."""
    build_optimizer = self.optimizer_fn
    trainable = self.network.parameters()
    self._optimizer = build_optimizer(trainable, **self.optimizer_params)
def _construct_loaders(self, X_train, y_train, eval_set):
    """Generate dataloaders for train and eval set.

    Targets of the training set and of every eval set are first passed
    through ``prepare_target``.

    Parameters
    ----------
    X_train : np.array
        Train set.
    y_train : np.array
        Train targets.
    eval_set : list of tuple
        List of eval tuple set (X, y). The list is not modified.

    Returns
    -------
    train_dataloader : `torch.utils.data.Dataloader`
        Training dataloader.
    valid_dataloaders : list of `torch.utils.data.Dataloader`
        List of validation dataloaders.
    """
    y_train_mapped = self.prepare_target(y_train)

    # Build a new list instead of overwriting the entries of `eval_set`:
    # the list belongs to the caller, and reusing it for another fit()
    # would otherwise hand already-mapped targets to prepare_target again.
    mapped_eval_set = [(X, self.prepare_target(y)) for X, y in eval_set]

    train_dataloader, valid_dataloaders = create_dataloaders(
        X_train,
        y_train_mapped,
        mapped_eval_set,
        self.updated_weights,
        self.batch_size,
        self.num_workers,
        self.drop_last,
        self.pin_memory,
    )
    return train_dataloader, valid_dataloaders
def _compute_feature_importances(self, X):
    """Compute global feature importance.

    Parameters
    ----------
    X : array-like
        Data handed to ``explain``.

    Returns
    -------
    np.ndarray
        Column sums of the (non-normalized) explanation matrix, divided
        by their total.
    """
    per_sample, _ = self.explain(X, normalize=False)
    column_totals = per_sample.sum(axis=0)
    return column_totals / np.sum(column_totals)
def _update_network_params(self):
    """Copy the current ``virtual_batch_size`` onto the existing network."""
    virtual_batch_size = self.virtual_batch_size
    self.network.virtual_batch_size = virtual_batch_size
[docs]
@abstractmethod
def update_fit_params(self, X_train, y_train, eval_set, weights):
    """Set the attributes that depend on the data passed to ``fit``.

    Must be overridden; this base implementation always raises.

    Parameters
    ----------
    X_train : np.ndarray
        Train set
    y_train : np.array
        Train targets
    eval_set : list of tuple
        List of eval tuple set (X, y).
    weights : bool or dict
        0 for no balancing
        1 for automated balancing

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "users must define update_fit_params to use this base class"
    raise NotImplementedError(message)
[docs]
@abstractmethod
def compute_loss(self, y_score, y_true):
    """Compute the training loss for one batch.

    Must be overridden; this base implementation always raises.

    Parameters
    ----------
    y_score : a :tensor: `torch.Tensor`
        Score matrix
    y_true : a :tensor: `torch.Tensor`
        Target matrix

    Returns
    -------
    float
        Loss value

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "users must define compute_loss to use this base class"
    raise NotImplementedError(message)
[docs]
@abstractmethod
def prepare_target(self, y):
    """Convert a target array into the form used for training.

    Must be overridden; this base implementation always raises.

    Parameters
    ----------
    y : a :tensor: `torch.Tensor`
        Target matrix.

    Returns
    -------
    `torch.Tensor`
        Converted target matrix.

    Raises
    ------
    NotImplementedError
        Always.
    """
    message = "users must define prepare_target to use this base class"
    raise NotImplementedError(message)