MultiModalPredictor.fit#

MultiModalPredictor.fit(train_data: Union[DataFrame, str], presets: Optional[str] = None, config: Optional[dict] = None, tuning_data: Optional[Union[DataFrame, str]] = None, max_num_tuning_data: Optional[int] = None, id_mappings: Optional[Union[Dict[str, Dict], Dict[str, Series]]] = None, time_limit: Optional[int] = None, save_path: Optional[str] = None, hyperparameters: Optional[Union[str, Dict, List[str]]] = None, column_types: Optional[dict] = None, holdout_frac: Optional[float] = None, teacher_predictor: Optional[Union[str, MultiModalPredictor]] = None, seed: Optional[int] = 0, standalone: Optional[bool] = True, hyperparameter_tune_kwargs: Optional[dict] = None, clean_ckpts: Optional[bool] = True)[source]#

Fit the MultiModalPredictor to predict the label column of a dataframe based on the other columns, which may contain image paths, text, numeric, or categorical features.

Parameters
  • train_data – A dataframe containing training data.

  • presets – Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.

  • config

    A dictionary with four keys: "model", "data", "optimization", and "environment". Each key's value can be a string, a YAML file path, or an OmegaConf DictConfig. Strings should be the file names (WITHOUT the ".yaml" suffix) in automm/configs/model, automm/configs/data, automm/configs/optimization, and automm/configs/environment. For example, you can configure a late-fusion model for image, text, and tabular data as follows:

    config = {
        "model": "fusion_mlp_image_text_tabular",
        "data": "default",
        "optimization": "adamw",
        "environment": "default",
    }

    or

    config = {
        "model": "/path/to/model/config.yaml",
        "data": "/path/to/data/config.yaml",
        "optimization": "/path/to/optimization/config.yaml",
        "environment": "/path/to/environment/config.yaml",
    }

    or

    config = {
        "model": OmegaConf.load("/path/to/model/config.yaml"),
        "data": OmegaConf.load("/path/to/data/config.yaml"),
        "optimization": OmegaConf.load("/path/to/optimization/config.yaml"),
        "environment": OmegaConf.load("/path/to/environment/config.yaml"),
    }

  • tuning_data – A dataframe containing validation data, which should have the same columns as train_data. If tuning_data = None, fit() will automatically hold out some random validation examples from train_data.

  • id_mappings – Id-to-content mappings. The contents can be text, image, etc. This is used when the dataframe contains the query/response identifiers instead of their contents.

  • time_limit – How long fit() should run for (wall clock time in seconds). If not specified, fit() will run until the model has completed training.

  • save_path – Path to directory where models and intermediate outputs should be saved.

  • hyperparameters

    Overrides for some default configurations. For example, the text and image backbones can be changed with:

    a string:

    hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"

    or a list of strings:

    hyperparameters = [
        "model.hf_text.checkpoint_name=google/electra-small-discriminator",
        "model.timm_image.checkpoint_name=swin_small_patch4_window7_224",
    ]

    or a dictionary:

    hyperparameters = {
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
    }
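    The three formats above carry the same information. As a hedged, pure-Python illustration (no AutoGluon required), the string form is a whitespace-separated list of "key=value" pairs that is equivalent to the dictionary form:

    ```python
    # Illustrative only: parse the string form of hyperparameters into the
    # dictionary form to show the "key=value" convention.
    overrides_str = (
        "model.hf_text.checkpoint_name=google/electra-small-discriminator "
        "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
    )

    # Split on whitespace, then on the first "=" of each entry.
    overrides_dict = dict(item.split("=", 1) for item in overrides_str.split())

    print(overrides_dict["model.timm_image.checkpoint_name"])
    # swin_small_patch4_window7_224
    ```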

  • column_types

    A dictionary that maps column names to their data types. For example, column_types = {"item_name": "text", "image": "image_path", "product_description": "text", "height": "numerical"} may be used for a table with columns "item_name", "image", "product_description", and "height". If None, column_types will be automatically inferred from the data. The currently supported types are:

    • "image_path": each row in this column is one image path.

    • "text": each row in this column contains text (sentence, paragraph, etc.).

    • "numerical": each row in this column contains a number.

    • "categorical": each row in this column belongs to one of K categories.
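    A hedged sketch of how column_types lines up with a training dataframe (the column names and values below are made up for illustration):

    ```python
    import pandas as pd

    # Toy training data: one row per product. The "image" column holds paths
    # to image files, not pixel data.
    train_data = pd.DataFrame(
        {
            "item_name": ["red mug", "blue vase"],
            "image": ["imgs/mug.jpg", "imgs/vase.jpg"],
            "product_description": ["A ceramic mug.", "A glass vase."],
            "height": [9.5, 24.0],
            "label": [0, 1],
        }
    )

    # One entry per feature column; any omitted column would be inferred.
    column_types = {
        "item_name": "text",
        "image": "image_path",
        "product_description": "text",
        "height": "numerical",
    }

    # Every key in column_types must name a real dataframe column.
    assert set(column_types) <= set(train_data.columns)
    ```

    The predictor would then be fit with, e.g., predictor.fit(train_data, column_types=column_types).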

  • holdout_frac – Fraction of train_data to holdout as tuning_data for optimizing hyper-parameters or early stopping (ignored unless tuning_data = None). Default value (if None) is selected based on the number of rows in the training data and whether hyper-parameter-tuning is utilized.

  • teacher_predictor – The pre-trained teacher predictor or its saved path. If provided, fit() can distill its knowledge to a student predictor, i.e., the current predictor.

  • seed – The random seed to use for this training run. Defaults to 0.

  • standalone – Whether to save the entire model for offline deployment, or only the trained parameters of a parameter-efficient fine-tuning strategy.

  • hyperparameter_tune_kwargs

    Hyperparameter tuning strategy and kwargs (for example, how many HPO trials to run). If None, hyperparameter tuning will not be performed.

    num_trials: int

    How many HPO trials to run. Either num_trials or the time_limit of fit() needs to be specified.

    scheduler: Union[str, ray.tune.schedulers.TrialScheduler]

    If a str is passed, AutoGluon will create the scheduler for you with some default parameters. If a ray.tune.schedulers.TrialScheduler object is passed, you are responsible for initializing the object.

    scheduler_init_args: Optional[dict] = None

    If a str was provided to scheduler, you can optionally provide custom init_args to the scheduler.

    searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]

    If a str is passed, AutoGluon will create the searcher for you with some default parameters. If a ray.tune.search.SearchAlgorithm or ray.tune.search.Searcher object is passed, you are responsible for initializing the object. You don't need to worry about the metric and mode of the searcher object; AutoGluon will figure them out by itself.

    searcher_init_args: Optional[dict] = None

    If a str was provided to searcher, you can optionally provide custom init_args to the searcher. You don't need to worry about metric and mode; AutoGluon will figure them out by itself.
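    A minimal sketch of such a kwargs dictionary. Note that the specific shorthand strings "bayes" and "ASHA" are assumptions about accepted searcher/scheduler names; check the AutoGluon documentation for the values your version supports:

    ```python
    # Illustrative config fragment only.
    hyperparameter_tune_kwargs = {
        "num_trials": 5,      # run 5 HPO trials
        "searcher": "bayes",  # let AutoGluon build the searcher (assumed name)
        "scheduler": "ASHA",  # let AutoGluon build the scheduler (assumed name)
    }
    ```

    This dictionary would then be passed as predictor.fit(..., hyperparameter_tune_kwargs=hyperparameter_tune_kwargs).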

  • clean_ckpts – Whether to clean the checkpoints of each validation step after training.

Return type

The MultiModalPredictor object (itself).