alpbench.util.pytorch_tabnet.pretraining
Classes
- class alpbench.util.pytorch_tabnet.pretraining.TabNetPretrainer(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=<factory>, cat_dims=<factory>, cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=0, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params=<factory>, scheduler_fn=None, scheduler_params=<factory>, mask_type='sparsemax', input_dim=None, output_dim=None, device_name='auto', n_shared_decoder=1, n_indep_decoder=1, grouped_features=<factory>)
Bases: TabModel
- fit(X_train, eval_set=None, eval_name=None, loss_fn=None, pretraining_ratio=0.5, weights=0, max_epochs=100, patience=10, batch_size=1024, virtual_batch_size=128, num_workers=0, drop_last=True, callbacks=None, pin_memory=True, warm_start=False)
Train the neural network stored in self.network by self-supervised reconstruction of masked features, using train_dataloader for training data and valid_dataloader for validation.
- Parameters:
X_train (np.ndarray) – Training set to reconstruct during self-supervised pretraining.
eval_set (list of np.array) – List of evaluation sets. The last one is used for early stopping.
eval_metric (list of str) – List of evaluation metrics. The last metric is used for early stopping.
loss_fn (callable or None) – A PyTorch loss function. Leave it as None for standard self-supervised pretraining unless you are an expert user.
pretraining_ratio (float) – Fraction of features, between 0 and 1, to mask for reconstruction.
weights (np.array) – Sampling weights for each example.
max_epochs (int) – Maximum number of training epochs.
patience (int) – Number of consecutive epochs without improvement before early stopping.
batch_size (int) – Training batch size.
virtual_batch_size (int) – Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size).
num_workers (int) – Number of workers used by torch.utils.data.DataLoader.
drop_last (bool) – Whether to drop the last incomplete batch during training.
callbacks (list of callback function) – List of custom callbacks.
pin_memory (bool) – Whether to use pinned memory (pin_memory=True) in the DataLoader during training.
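The roles of pretraining_ratio and loss_fn can be illustrated with a small NumPy sketch. It assumes that each entry of a batch is hidden independently with probability pretraining_ratio and replaced by zero, and it scores the reconstruction only on the hidden entries, scaling each feature's squared error by that feature's batch variance (loosely following the normalized reconstruction objective described in the TabNet paper). The helpers random_feature_mask and masked_reconstruction_loss are hypothetical; this is not the library's own masking step or default loss, which may differ in details.

```python
import numpy as np


# Hypothetical helper (not part of alpbench): hide a random subset of entries.
def random_feature_mask(x, pretraining_ratio, rng):
    """Return (masked_x, mask); mask[i, j] == 1 marks an entry hidden from the model."""
    if not 0.0 <= pretraining_ratio <= 1.0:
        raise ValueError("pretraining_ratio must be between 0 and 1")
    x = np.asarray(x, dtype=float)
    # Assumption: each entry is masked independently with probability pretraining_ratio.
    mask = rng.binomial(1, pretraining_ratio, size=x.shape).astype(float)
    # Hidden entries are set to zero; the model must reconstruct them.
    masked_x = (1.0 - mask) * x
    return masked_x, mask


# Hypothetical helper: squared error on hidden entries only, scaled per feature.
def masked_reconstruction_loss(x_hat, x, mask, eps=1e-9):
    x_hat = np.asarray(x_hat, dtype=float)
    x = np.asarray(x, dtype=float)
    sq_err = (mask * (x_hat - x)) ** 2
    # Per-feature batch variance; constant features fall back to 1 to avoid division by zero.
    var = x.var(axis=0, ddof=1)
    var[var == 0] = 1.0
    # Average the scaled error over the number of hidden entries in each row.
    per_row = (sq_err / var).sum(axis=1) / (mask.sum(axis=1) + eps)
    return float(per_row.mean())


rng = np.random.default_rng(0)
x = rng.normal(size=(256, 10))
masked_x, mask = random_feature_mask(x, pretraining_ratio=0.5, rng=rng)
print(mask.mean())  # close to 0.5
print(masked_reconstruction_loss(x, x, mask))  # 0.0 for a perfect reconstruction
```

Only the hidden entries contribute to the loss, so a model cannot lower it by simply copying the visible inputs.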
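Ghost Batch Normalization, which virtual_batch_size controls, normalizes each training batch in smaller virtual chunks rather than all at once. The NumPy sketch below assumes the batch is split into ceil(batch_size / virtual_batch_size) chunks, each normalized with its own mean and variance; it omits the learned scale/shift parameters and running statistics that a real normalization layer keeps. ghost_batch_norm is a hypothetical helper, not part of this module.

```python
import math

import numpy as np


# Hypothetical helper (not part of alpbench): normalization over "ghost" batches.
def ghost_batch_norm(x, virtual_batch_size, eps=1e-5):
    """Normalize each virtual batch with its own mean and variance.

    Learned scale/shift parameters and running statistics are omitted.
    """
    if virtual_batch_size <= 0:
        raise ValueError("virtual_batch_size must be positive")
    x = np.asarray(x, dtype=float)
    # Assumption: the batch is cut into ceil(len(x) / virtual_batch_size) chunks.
    n_chunks = math.ceil(x.shape[0] / virtual_batch_size)
    normalized = []
    for chunk in np.array_split(x, n_chunks, axis=0):
        mean = chunk.mean(axis=0)
        var = chunk.var(axis=0)  # biased variance, as batch norm uses in training mode
        normalized.append((chunk - mean) / np.sqrt(var + eps))
    return np.concatenate(normalized, axis=0)


rng = np.random.default_rng(1)
batch = rng.normal(loc=3.0, scale=2.0, size=(1024, 4))  # batch_size=1024
out = ghost_batch_norm(batch, virtual_batch_size=128)  # 8 ghost batches of 128 rows
print(out[:128].mean(axis=0))  # approximately zero for every feature
```

With the defaults batch_size=1024 and virtual_batch_size=128, each training batch is normalized as 8 groups of 128 rows, so the statistics come from the smaller groups rather than from the full batch.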
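The interaction of eval_set, patience, and max_epochs can be sketched in plain Python, assuming a single validation loss per epoch computed on the last evaluation set and that lower is better. early_stopping_epoch is a hypothetical helper; the library's own early-stopping logic may differ in details such as an improvement tolerance or how epochs without improvement are counted.

```python
from typing import List, Tuple


# Hypothetical helper (not part of alpbench): patience-based early stopping.
def early_stopping_epoch(val_losses: List[float], patience: int, max_epochs: int) -> Tuple[int, int]:
    """Return (best_epoch, last_epoch_run), assuming lower validation loss is better."""
    best_loss = float("inf")
    best_epoch = -1
    last_epoch = -1
    wait = 0  # consecutive epochs without improvement
    for epoch, loss in enumerate(val_losses[:max_epochs]):
        last_epoch = epoch
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                break  # patience exhausted: stop training
    return best_epoch, last_epoch


# One loss per epoch, measured on the last evaluation set.
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69]
print(early_stopping_epoch(losses, patience=3, max_epochs=100))  # (2, 5)
```

With patience=3, training stops after epoch 5 because epochs 3, 4, and 5 fail to beat the loss from epoch 2, so the later improvement at epoch 6 is never reached.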