hyperimpute.plugins.prediction.classifiers.plugin_neural_nets module

class BasicNet(n_unit_in: int, categories_cnt: int, n_layers_hidden: int = 1, n_units_hidden: int = 100, nonlin: str = 'relu', lr: float = 0.001, weight_decay: float = 0.001, n_iter: int = 300, batch_size: int = 1024, n_iter_print: int = 10, random_state: int = 0, patience: int = 10, n_iter_min: int = 100, dropout: float = 0.1, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True)

Bases: Module

Basic neural net.

Parameters:
  • n_unit_in (int) – Number of features

  • categories_cnt (int) – Number of target classes (size of the output layer)

  • n_layers_hidden (int) – Number of hidden (hypothesis) layers: the network stacks n_layers_hidden layers of n_units_hidden units each, followed by one output Linear layer

  • n_units_hidden (int) – Number of hidden units in each hypothesis layer

  • nonlin (str, default 'relu') – Nonlinearity used in the network. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.

  • lr (float) – Learning rate for the optimizer (the equivalent of step_size in the JAX version).

  • weight_decay (float) – L2 (ridge) penalty on the weights.

  • n_iter (int) – Maximum number of iterations.

  • batch_size (int) – Batch size

  • n_iter_print (int) – Number of iterations after which to print updates and check the validation loss.

  • random_state (int) – Random seed, for reproducibility

  • val_split_prop (float) – Proportion of samples used for validation split (can be 0)

  • patience (int) – Number of iterations without a decrease in validation loss to wait before stopping early

  • n_iter_min (int) – Minimum number of iterations to go through before starting early stopping

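  • clipping_value (int, default 1) – Gradient clipping value

  • dropout (float) – Dropout probability

  • batch_norm (bool) – Whether to apply batch normalization

  • early_stopping (bool) – Whether to stop training early based on the validation loss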
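The layer layout described for n_layers_hidden can be sketched as a list of Linear-layer shapes. This is a minimal sketch under assumptions, not the library's code: it assumes the first hidden layer maps n_unit_in features to n_units_hidden units and that the output layer has one unit per class (categories_cnt). The helpers layer_shapes and count_linear_params are hypothetical, and the parameter count ignores batch-norm parameters.

```python
from typing import List, Tuple


def layer_shapes(
    n_unit_in: int,
    categories_cnt: int,
    n_layers_hidden: int = 1,
    n_units_hidden: int = 100,
) -> List[Tuple[int, int]]:
    """Hypothetical helper: (in_features, out_features) of each Linear layer.

    Assumes n_layers_hidden hidden layers of width n_units_hidden,
    followed by one output Linear layer with categories_cnt units.
    """
    shapes = []
    width = n_unit_in
    for _ in range(n_layers_hidden):
        shapes.append((width, n_units_hidden))  # hidden layer (nonlin/dropout not shown)
        width = n_units_hidden
    shapes.append((width, categories_cnt))  # final Linear layer: one unit per class
    return shapes


def count_linear_params(shapes: List[Tuple[int, int]]) -> int:
    """Weights plus biases of the Linear layers only (batch norm not counted)."""
    return sum(n_in * n_out + n_out for n_in, n_out in shapes)


# Iris-sized problem: 4 features, 3 classes, default hidden settings.
shapes = layer_shapes(n_unit_in=4, categories_cnt=3)
print(shapes)                       # [(4, 100), (100, 3)]
print(count_linear_params(shapes))  # 803
```

Increasing n_layers_hidden only adds square n_units_hidden x n_units_hidden layers, so the input and output layers keep the same shapes.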

_check_tensor(X: Tensor) Tensor
forward(X: Tensor) Tensor

Defines the computation performed at every call, i.e. the forward pass of the network on X.

Note

Although the forward pass is defined in this method, call the module instance (net(X)) rather than net.forward(X): calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
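The difference between calling the instance and calling forward() directly can be illustrated with a toy class that mimics the pattern. ToyModule is a hypothetical, heavily simplified stand-in written only for this illustration, not torch.nn.Module itself; the real Module.__call__ also handles pre-hooks, backward hooks and more.

```python
from typing import Callable, List


class ToyModule:
    """Toy stand-in for the torch.nn.Module call pattern (not the real class)."""

    def __init__(self) -> None:
        self.forward_hooks: List[Callable] = []

    def register_forward_hook(self, hook: Callable) -> None:
        self.forward_hooks.append(hook)

    def forward(self, x: float) -> float:
        return 2 * x  # the "recipe" for the forward pass

    def __call__(self, x: float) -> float:
        out = self.forward(x)
        for hook in self.forward_hooks:  # hooks only run via __call__
            hook(self, x, out)
        return out


calls = []
net = ToyModule()
net.register_forward_hook(lambda module, inp, out: calls.append(out))

net(3.0)          # runs forward() and then the hook
net.forward(3.0)  # runs forward() only; the hook is skipped
print(calls)      # [6.0]
```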

train(X: Tensor, y: Tensor) BasicNet

Fits the network on the features X and the class labels y, using the training settings given to the constructor (lr, weight_decay, n_iter, batch_size, patience, etc.).

Note

This method overrides torch.nn.Module.train(mode) with a different signature: it takes training data instead of a boolean mode flag.

Parameters:

  • X (Tensor) – Input features

  • y (Tensor) – Target class labels

Returns:

self, the fitted network

Return type:

BasicNet
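The interplay of n_iter, n_iter_print, patience, n_iter_min and early_stopping during training can be sketched as a plain-Python loop over precomputed validation losses. This is an assumption-laden sketch, not the library's implementation: stopping_iteration is a hypothetical helper, the optimisation step is elided, and it assumes patience counts validation checks (made every n_iter_print iterations) rather than raw iterations. The exact comparisons used by hyperimpute may differ.

```python
from typing import Sequence


def stopping_iteration(
    val_losses: Sequence[float],
    n_iter: int = 300,
    n_iter_print: int = 10,
    patience: int = 10,
    n_iter_min: int = 100,
    early_stopping: bool = True,
) -> int:
    """Hypothetical sketch: return the iteration at which training would stop.

    val_losses[i] stands in for the validation loss measured at iteration i.
    Validation is checked every n_iter_print iterations; once at least
    n_iter_min iterations have run, training stops after `patience`
    checks in a row fail to improve on the best loss seen so far.
    """
    best = float("inf")
    bad_checks = 0
    for i in range(n_iter):
        # ... one optimisation step on a mini-batch would happen here ...
        if i % n_iter_print != 0:
            continue  # validation loss is only checked every n_iter_print iterations
        loss = val_losses[i]
        if loss < best:
            best, bad_checks = loss, 0
        else:
            bad_checks += 1
        if early_stopping and i >= n_iter_min and bad_checks >= patience:
            return i
    return n_iter  # ran the full budget


losses = [10.0 - i for i in range(10)] + [5.0] * 90  # improves, then plateaus
print(stopping_iteration(losses, n_iter=100, n_iter_print=1, patience=3, n_iter_min=5))  # 12
```

With a larger n_iter_min, the same plateau is tolerated for longer, because early stopping is not allowed to trigger before that many iterations have run.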

training: bool
class NeuralNetsPlugin(n_layers_hidden: int = 1, n_units_hidden: int = 100, nonlin: str = 'relu', lr: float = 0.001, weight_decay: float = 0.001, n_iter: int = 1000, batch_size: int = 128, n_iter_print: int = 10, random_state: int = 0, patience: int = 10, n_iter_min: int = 100, dropout: float = 0.1, clipping_value: int = 1, batch_norm: bool = True, early_stopping: bool = True, hyperparam_search_iterations: Optional[int] = None, **kwargs: Any)

Bases: ClassifierPlugin

Classification plugin based on neural networks.

Example

>>> from hyperimpute.plugins.prediction import Predictions
>>> plugin = Predictions(category="classifiers").get("neural_nets")
>>> from sklearn.datasets import load_iris
>>> X, y = load_iris(return_X_y=True)
>>> plugin.fit_predict(X, y)  # fits the plugin, then returns the predicted class for each sample
>>> plugin.predict_proba(X)  # returns the probabilities for each class
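Per-class probabilities and predicted labels are usually related by picking, for each row, the class with the highest probability. The sketch below assumes that convention (it is not confirmed here that NeuralNetsPlugin._predict works this way) and uses a hypothetical proba_to_labels helper on plain lists rather than DataFrames.

```python
from typing import List, Sequence


def proba_to_labels(proba: Sequence[Sequence[float]]) -> List[int]:
    """Hypothetical helper: index of the most probable class in each row."""
    return [max(range(len(row)), key=row.__getitem__) for row in proba]


proba = [
    [0.90, 0.05, 0.05],  # sample 0: class 0 most likely
    [0.10, 0.30, 0.60],  # sample 1: class 2 most likely
    [0.20, 0.70, 0.10],  # sample 2: class 1 most likely
]
print(proba_to_labels(proba))  # [0, 2, 1]
```

On ties, max() keeps the first maximal index, so the lowest class index wins.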
_fit(X: DataFrame, *args: Any, **kwargs: Any) NeuralNetsPlugin
_predict(X: DataFrame, *args: Any, **kwargs: Any) DataFrame
_predict_proba(X: DataFrame, *args: Any, **kwargs: Any) DataFrame
static hyperparameter_space(*args: Any, **kwargs: Any) List[Params]
module_relative_path: Optional[Path]
static name() str
plugin

alias of NeuralNetsPlugin