alpbench.util.pytorch_tabnet.pretraining

Classes

TabNetPretrainer([n_d, n_a, n_steps, gamma, ...])

class alpbench.util.pytorch_tabnet.pretraining.TabNetPretrainer(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=<factory>, cat_dims=<factory>, cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=0, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params=<factory>, scheduler_fn=None, scheduler_params=<factory>, mask_type='sparsemax', input_dim=None, output_dim=None, device_name='auto', n_shared_decoder=1, n_indep_decoder=1, grouped_features=<factory>)

Bases: TabModel

compute_loss(output, embedded_x, obf_vars)
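compute_loss receives the network output, the embedded input and the obfuscation mask obf_vars. The sketch below illustrates a masked reconstruction loss with that argument layout, assuming it follows the TabNet pretraining objective: squared errors count only on masked entries, each feature is normalized by its batch variance, each row's error is averaged over its number of masked entries, and the result is averaged over the batch. The helper name masked_reconstruction_loss is hypothetical, it uses plain Python lists instead of tensors, and constant features fall back to a variance of 1.0 as a simplification. It is not the library's implementation.

```python
from typing import List

Matrix = List[List[float]]


def masked_reconstruction_loss(output: Matrix, embedded_x: Matrix,
                               obf_vars: Matrix, eps: float = 1e-9) -> float:
    # Hypothetical helper: a sketch, not alpbench's compute_loss.
    n_rows = len(embedded_x)
    n_feats = len(embedded_x[0])
    # Per-feature batch mean of the ground truth.
    means = [sum(row[j] for row in embedded_x) / n_rows for j in range(n_feats)]
    variances = []
    for j in range(n_feats):
        # Sample variance of feature j across the batch.
        var = sum((row[j] - means[j]) ** 2 for row in embedded_x) / max(n_rows - 1, 1)
        # Assumption: constant features use 1.0 to avoid division by zero.
        variances.append(var if var > 0 else 1.0)
    row_losses = []
    for out_row, x_row, mask_row in zip(output, embedded_x, obf_vars):
        # Only masked entries (mask == 1) contribute to the error.
        err = sum(((o - x) * m) ** 2 / v
                  for o, x, m, v in zip(out_row, x_row, mask_row, variances))
        # Average over the number of masked entries in this row.
        row_losses.append(err / (sum(mask_row) + eps))
    # Mean over the batch.
    return sum(row_losses) / n_rows


truth = [[1.0, 0.0], [3.0, 2.0]]
recon = [[1.0, 2.0], [3.0, 4.0]]
mask = [[0, 1], [0, 1]]
print(masked_reconstruction_loss(recon, truth, mask))
```

Unmasked positions are multiplied by zero, so a model is never penalized for what it saw in the input; a perfect reconstruction of the masked entries gives a loss of 0.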

fit(X_train, eval_set=None, eval_name=None, loss_fn=None, pretraining_ratio=0.5, weights=0, max_epochs=100, patience=10, batch_size=1024, virtual_batch_size=128, num_workers=0, drop_last=True, callbacks=None, pin_memory=True, warm_start=False)

Train the neural network stored in self.network, using train_dataloader for training data and valid_dataloader for validation. In this pretrainer, the network learns to reconstruct masked features of X_train.

Parameters:
  • X_train (np.ndarray) – Training set to reconstruct during self-supervised pretraining

  • eval_set (list of np.array) – List of evaluation sets. The last one is used for early stopping.

  • eval_name (list of str) – List of eval set names.

  • eval_metric (list of str) – List of evaluation metrics. The last metric is used for early stopping. (Not part of the fit signature shown above.)

  • loss_fn (callable or None) – A PyTorch loss function. Should be left as None for self-supervised pretraining, especially by non-experts.

  • pretraining_ratio (float) – Fraction of features to mask for reconstruction, between 0 and 1

  • weights (np.array) – Sampling weights for each example.

  • max_epochs (int) – Maximum number of epochs during training

  • patience (int) – Number of consecutive epochs without improvement before early stopping

  • batch_size (int) – Training batch size

  • virtual_batch_size (int) – Batch size for Ghost Batch Normalization (virtual_batch_size < batch_size)

  • num_workers (int) – Number of workers used in torch.utils.data.DataLoader

  • drop_last (bool) – Whether to drop the last incomplete batch during training

  • callbacks (list of callback functions) – List of custom callbacks

  • pin_memory (bool) – Whether to set pin_memory to True or False during training
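pretraining_ratio controls how much of each row is hidden before the network tries to reconstruct it. A minimal sketch of that masking step, assuming each entry is masked independently with probability pretraining_ratio and hidden values are replaced by zero. random_obfuscation is a hypothetical helper working on plain Python lists; the returned mask (1 = hidden) plays the role of the obf_vars argument of compute_loss.

```python
import random
from typing import List, Optional, Tuple


def random_obfuscation(
    x: List[List[float]], pretraining_ratio: float, seed: Optional[int] = None
) -> Tuple[List[List[float]], List[List[int]]]:
    # Hypothetical helper: a sketch of feature masking for pretraining.
    if not 0.0 <= pretraining_ratio <= 1.0:
        raise ValueError("pretraining_ratio must be between 0 and 1")
    rng = random.Random(seed)
    # Each entry is hidden independently with probability pretraining_ratio.
    obf_vars = [[1 if rng.random() < pretraining_ratio else 0 for _ in row] for row in x]
    # Hidden entries are zeroed; the network must reconstruct them.
    masked_x = [[v * (1 - m) for v, m in zip(row, mask)]
                for row, mask in zip(x, obf_vars)]
    return masked_x, obf_vars


X = [[0.5, 1.2, 3.0, 7.1], [2.2, 0.1, 4.4, 5.0]]
masked, obf = random_obfuscation(X, pretraining_ratio=0.5, seed=0)
print(masked)
print(obf)
```

With pretraining_ratio=0 nothing is hidden and with pretraining_ratio=1 every entry is hidden; values in between trade off how hard the reconstruction task is.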

predict(X)

Make predictions on a batch of input data (e.g. a validation set).

Parameters:

X (torch.Tensor or scipy.sparse.csr_matrix) – Input data

Returns:

predictions – Reconstruction of the input produced by the pretraining network, together with the embedded inputs

Return type:

tuple of np.array

prepare_target(y)

stack_batches(list_output, list_embedded_x, list_obfuscation)
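stack_batches takes per-batch lists of outputs, embedded inputs and obfuscation masks. A plausible reading, stated here as an assumption, is that it concatenates each list along the batch axis so the loss can be computed over a whole evaluation set at once. The pure-Python stand-in below shows that row-wise concatenation (comparable to numpy.vstack); it is a sketch, not the library's code.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def stack_batches(list_output: List[Matrix], list_embedded_x: List[Matrix],
                  list_obfuscation: List[Matrix]) -> Tuple[Matrix, Matrix, Matrix]:
    # Sketch: concatenate per-batch results row-wise (along the batch axis).
    def vstack(batches: List[Matrix]) -> Matrix:
        return [row for batch in batches for row in batch]

    return vstack(list_output), vstack(list_embedded_x), vstack(list_obfuscation)


# Two batches of sizes 1 and 2 become three rows.
out, emb, obf = stack_batches(
    [[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]],
    [[[0.0, 0.0]], [[1.0, 1.0], [2.0, 2.0]]],
    [[[1, 0]], [[0, 1], [1, 1]]],
)
print(out, emb, obf)
```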

update_fit_params(weights)