dai.train.Rd
If not specified, the optional parameters are determined by the Driverless AI
platform; see also dai.suggest_model_params.
dai.train(training_frame, target_col, is_classification, is_timeseries,
  testing_frame = NULL, validation_frame = NULL, resumed_model = NULL,
  weight_col = NULL, fold_col = NULL, time_col = NULL, scorer = NULL,
  cols_to_drop = NULL, accuracy = NULL, time = NULL, interpretability = NULL,
  time_groups_columns = NULL, time_period_in_seconds = NULL,
  num_prediction_periods = NULL, num_gap_periods = NULL, enable_gpus = NULL,
  config_overrides = NULL, seed = NULL,
  progress = getOption("dai.progress", TRUE))
Argument | Description |
---|---|
training_frame | DAIFrame to use to build the model. |
target_col | The name of the target variable. |
is_classification | Whether the predicted variable is categorical (TRUE) or numerical (FALSE). |
is_timeseries | Whether the target variable is a time series (TRUE) or not (FALSE). |
testing_frame | DAIFrame to evaluate the model on at the end. It is not used for the model training (optional). |
validation_frame | DAIFrame to use for the model validation during the model training (optional). |
resumed_model | Model or model key used for retraining/re-ensembling/starting from checkpoint (optional). |
weight_col | Weights column name (optional). |
fold_col | Fold column name (optional). |
time_col | Time column name, containing time ordering for timeseries problems (optional). |
scorer | Name of one of the available scorers (optional). |
cols_to_drop | A character vector of column names to be dropped from the data (optional). |
accuracy | Accuracy setting [1-10] (optional). |
time | Time setting [1-10] (optional). |
interpretability | Interpretability setting [1-10] (optional). |
time_groups_columns | List of column names, contributing to time ordering (optional). |
time_period_in_seconds | Size of lag features in seconds, used in timeseries problems (optional). |
num_prediction_periods | Timeseries forecast horizon in time period units (optional). |
num_gap_periods | Number of time periods after which forecast starts (optional). |
enable_gpus | Whether to use GPUs (optional). |
config_overrides | Driverless AI config overrides for a separate experiment, given as a TOML string (optional). |
seed | The random number generator's seed (optional). |
progress | Whether to display a progress bar (optional). |
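
The accuracy, time, and interpretability settings are documented above as values in [1-10], and config_overrides is documented as a TOML string. The following is a minimal sketch of checking those inputs locally before calling dai.train. The helpers `check_knob` and `toml_overrides` are hypothetical names introduced here for illustration; they are not part of the package, and `toml_overrides` only handles string-valued settings.

```r
# Hypothetical helper (not part of the package): check that a setting such as
# accuracy, time, or interpretability is either NULL or a whole number in [1, 10].
check_knob <- function(value, name) {
  if (is.null(value)) return(invisible(NULL))  # NULL leaves the choice to the platform
  ok <- is.numeric(value) && length(value) == 1 && !is.na(value) &&
    value == round(value) && value >= 1 && value <= 10
  if (!ok) stop(sprintf("'%s' must be a whole number between 1 and 10", name))
  invisible(value)
}

# Hypothetical helper: turn a named list of string settings into TOML
# "key = 'value'" lines, one line per setting.
toml_overrides <- function(settings) {
  paste(sprintf("%s = '%s'", names(settings), unlist(settings)), collapse = "\n")
}

check_knob(10, "accuracy")   # passes silently
check_knob(NULL, "time")     # NULL is allowed
overrides <- toml_overrides(list(recipe = "compliant"))
print(overrides)
```

The resulting string has the same form as the `config_overrides` value used in the examples on this page and could be passed as `config_overrides = overrides`.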
Returns a DAIModel.
See also: dai.suggest_model_params
# NOT RUN {
dai.connect(uri = 'http://127.0.0.1:12345', username = 'h2oai', password = 'h2oai')
iris_dai <- as.DAIFrame(iris, progress = FALSE)

# Simple model with minimal parameters
simple_model <- dai.train(training_frame = iris_dai,
                          target_col = 'Species',
                          is_classification = TRUE,
                          is_timeseries = FALSE,
                          time = 1, accuracy = 1, interpretability = 10,
                          progress = FALSE)
print(simple_model)
# }
# NOT RUN {
# More complex model that may take more time to train
model <- dai.train(training_frame = iris_dai,
                   target_col = 'Species',
                   is_classification = TRUE,
                   is_timeseries = FALSE,
                   time = 10, accuracy = 10, interpretability = 5,
                   progress = FALSE)
print(model)

# Custom config to enable compliant recipe (see config.toml for more details)
compliant_model <- dai.train(training_frame = iris_dai,
                             target_col = 'Species',
                             is_classification = TRUE,
                             is_timeseries = FALSE,
                             time = 10, accuracy = 10,
                             config_overrides = "recipe = 'compliant'",
                             progress = FALSE)
print(compliant_model)
# }
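
The examples above cover classification only. As a rough sketch of how the time-series arguments from the argument table might fit together, the block below builds a toy daily data set and an argument list. The data, the column names (`date`, `store`, `units`), and the chosen values are assumptions for illustration, and passing `time_groups_columns` as a character vector is also an assumption. The call to dai.train is guarded with `if (FALSE)` because it needs a running Driverless AI server.

```r
# Hypothetical daily sales data for two stores; column names are illustrative only.
sales <- data.frame(
  date  = rep(seq(as.Date("2020-01-01"), by = "day", length.out = 30), times = 2),
  store = rep(c("A", "B"), each = 30),
  units = rep(c(10, 12, 11, 13, 15, 9), times = 10)
)

# Argument names come from the dai.train usage; the values are assumptions.
ts_args <- list(
  target_col = "units",
  is_classification = FALSE,               # numerical target
  is_timeseries = TRUE,
  time_col = "date",                       # column that orders the rows in time
  time_groups_columns = "store",           # one series per store (assumed usage)
  time_period_in_seconds = 24 * 60 * 60,   # one day
  num_prediction_periods = 7,              # forecast horizon, in time period units
  num_gap_periods = 0,                     # forecast starts right after training data
  progress = FALSE
)

# Requires a connection made with dai.connect, so it is not executed here.
if (FALSE) {
  sales_dai <- as.DAIFrame(sales, progress = FALSE)
  ts_model <- do.call(dai.train, c(list(training_frame = sales_dai), ts_args))
  print(ts_model)
}
```

Keeping the settings in a list makes it easy to inspect or adjust them before the experiment is started with `do.call`.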