alpbench.util.pytorch_tabnet.tab_model

Classes

TabNetClassifier([n_d, n_a, n_steps, gamma, ...])

TabNetRegressor([n_d, n_a, n_steps, gamma, ...])

class alpbench.util.pytorch_tabnet.tab_model.TabNetClassifier(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=<factory>, cat_dims=<factory>, cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=0, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params=<factory>, scheduler_fn=None, scheduler_params=<factory>, mask_type='sparsemax', input_dim=None, output_dim=None, device_name='auto', n_shared_decoder=1, n_indep_decoder=1, grouped_features=<factory>)

Bases: TabModel

compute_loss(y_pred, y_true)
predict_func(outputs)
predict_proba(X, get_embeds=False)

Make class-probability predictions on a batch (validation).

Parameters:

X (torch.Tensor or scipy.sparse.csr_matrix) – Input data.

Returns:

res

Return type:

np.ndarray
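The page does not say how raw network outputs become the returned array. As a minimal sketch, assuming the common classifier convention of a row-wise softmax applied per batch and then stacked, the computation might look like the following. The names softmax_rows and predict_proba_sketch are hypothetical and are not part of this module.

```python
import numpy as np


def softmax_rows(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax (assumed output activation, not confirmed by this page)."""
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


def predict_proba_sketch(batches):
    """Hypothetical helper: turn per-batch logits into probabilities and stack them."""
    return np.vstack([softmax_rows(batch) for batch in batches])


# Two batches of raw scores for a two-class problem (illustrative values).
batches = [np.array([[2.0, 0.0], [0.0, 0.0]]), np.array([[-1.0, 1.0]])]
proba = predict_proba_sketch(batches)
print(proba.shape)  # (3, 2)
```

Under this assumption, each row of the result sums to 1 and the row count equals the total number of samples across batches.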

prepare_target(y)
stack_batches(list_y_true, list_y_score)
update_fit_params(X_train, y_train, eval_set, weights)
weight_updater(weights)

Updates weights dictionary according to target_mapper.

Parameters:

weights (bool or dict) – Given weights for balancing training.

Returns:

Same bool if weights are bool, updated dict otherwise.

Return type:

bool or dict
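The behavior described above can be sketched as a standalone function. This is an illustration, not the method's actual code: it assumes target_mapper is a dict from original class labels to integer class indices, which this page does not state. The name update_weights and the example labels are hypothetical.

```python
def update_weights(weights, target_mapper):
    """Re-key a class-weight dict through target_mapper; pass anything else through."""
    if isinstance(weights, dict):
        # Assumption: every key in weights is an original label known to target_mapper.
        return {target_mapper[label]: w for label, w in weights.items()}
    # A bool (or any other non-dict value) is returned unchanged.
    return weights


# Hypothetical mapping from original labels to class indices.
target_mapper = {"cat": 0, "dog": 1}

print(update_weights({"cat": 1.0, "dog": 3.0}, target_mapper))  # {0: 1.0, 1: 3.0}
print(update_weights(True, target_mapper))  # True
```

Re-keying in this way would let callers supply weights using their own label values while the model works with indices internally.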

class alpbench.util.pytorch_tabnet.tab_model.TabNetRegressor(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=<factory>, cat_dims=<factory>, cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=0, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params=<factory>, scheduler_fn=None, scheduler_params=<factory>, mask_type='sparsemax', input_dim=None, output_dim=None, device_name='auto', n_shared_decoder=1, n_indep_decoder=1, grouped_features=<factory>)

Bases: TabModel

compute_loss(y_pred, y_true)
predict_func(outputs)
prepare_target(y)
stack_batches(list_y_true, list_y_score)
update_fit_params(X_train, y_train, eval_set, weights)