autogluon.multimodal.MultiModalPredictor

class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')[source]

AutoMM is designed to simplify the fine-tuning of foundation models for downstream applications with just three lines of code. AutoMM seamlessly integrates with popular model zoos such as HuggingFace Transformers, TIMM, and MMDetection, accommodating a diverse range of data modalities, including image, text, tabular, and document data, whether used individually or in combination. It offers support for an array of tasks, encompassing classification, regression, object detection, named entity recognition, semantic matching, and image segmentation.
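As a quick illustration, here is a minimal sketch of that three-line workflow on a toy text classification task; the DataFrame and column names are hypothetical, and any mix of image-path, text, and tabular columns works the same way:

    import pandas as pd
    from autogluon.multimodal import MultiModalPredictor

    # Hypothetical toy data; real tasks would use a much larger pd.DataFrame.
    train_df = pd.DataFrame({
        "text": ["great movie", "terrible plot", "loved it", "boring and slow"],
        "label": [1, 0, 1, 0],
    })

    predictor = MultiModalPredictor(label="label")  # 1. create the predictor
    predictor.fit(train_df)                         # 2. fine-tune a foundation model
    preds = predictor.predict(pd.DataFrame({"text": ["pretty good overall"]}))  # 3. predict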

__init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')[source]
Parameters:
  • label – Name of one pd.DataFrame column that contains the target variable to predict.

  • problem_type

    Type of problem. We support standard problems like

    • 'binary': Binary classification

    • 'multiclass': Multi-class classification

    • 'regression': Regression

    • 'classification': Classification problems include 'binary' and 'multiclass' classification.

    In addition, we support advanced problems such as

    • 'object_detection': Object detection

    • 'ner' or 'named_entity_recognition': Named entity extraction

    • 'text_similarity': Text-text semantic matching

    • 'image_similarity': Image-image semantic matching

    • 'image_text_similarity': Text-image semantic matching

    • 'feature_extraction': Feature extraction (supports inference only)

    • 'zero_shot_image_classification': Zero-shot image classification (supports inference only)

    • 'few_shot_classification': Few-shot classification for image or text data.

    • 'semantic_segmentation': Semantic segmentation with the Segment Anything Model.

    For certain problem types, the default behavior is to load a pretrained model based on the presets / hyperparameters, so the predictor can do zero-shot inference (running inference without calling .fit()); see the zero-shot sketch after this parameter list. These problem types are:

    • 'object_detection'

    • 'text_similarity'

    • 'image_similarity'

    • 'image_text_similarity'

    • 'feature_extraction'

    • 'zero_shot_image_classification'

  • query – Name of one pd.DataFrame column that has the query data in semantic matching tasks.

  • response – Name of one pd.DataFrame column that contains the response data in semantic matching tasks. If no label column is provided, the query and response pairs in one pd.DataFrame row are assumed to be positive pairs.

  • match_label – The label class indicating that a <query, response> pair counts as a "match". This is used when the task is semantic matching and the labels are binary. For example, in a duplicate detection task the label column may contain ["duplicate", "not duplicate"]; match_label should then be "duplicate", since it means the two items match. A setup sketch follows this parameter list.

  • presets – Presets regarding model quality, e.g., 'best_quality', 'high_quality' (default), and 'medium_quality'. Each quality level has a corresponding HPO preset: 'best_quality_hpo', 'high_quality_hpo', and 'medium_quality_hpo'.

  • eval_metric – Evaluation metric name. If eval_metric = None, it is automatically chosen based on problem_type. Defaults to 'accuracy' for multiclass classification, 'roc_auc' for binary classification, and 'root_mean_squared_error' for regression.

  • hyperparameters

    Overrides some of the default configurations. For example, the text and image backbones can be changed by passing:

    a string: hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"

    a list of strings: hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]

    or a dictionary: hyperparameters = {"model.hf_text.checkpoint_name": "google/electra-small-discriminator", "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224"}

    See the configuration sketch after this parameter list.

  • path – Path to the directory where models and related artifacts should be saved. If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]" will be created in the working directory. Note: to call fit() twice and keep the results of each fit, you must either specify different path locations or leave path unspecified so that a new time-stamped folder is created each time.

  • verbosity – Verbosity levels range from 0 to 4, controlling how much logging information is printed. Higher levels correspond to more detailed print statements. You can set verbosity = 0 to suppress warnings.

  • num_classes – Number of classes (used for object detection). If this is specified and differs from the pretrained model's output shape, the model's head will be replaced to produce <num_classes> outputs.

  • classes – All the classes (used for object detection).

  • warn_if_exist – Whether to raise a warning if the specified path already exists (default True).

  • enable_progress_bar – Whether to show the progress bar (default True). It is disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.

  • pretrained – Whether to initialize the model with pretrained weights (default True). If False, it creates a model with random initialization.

  • validation_metric – Validation metric for selecting the best model and for early stopping during training. If not provided, it is automatically chosen based on the problem type.

  • sample_data_path – The path to sample data from which we can infer num_classes or classes used for object detection.

  • use_ensemble – Whether to use ensembling when fitting the predictor (default False). Currently, it works only on multimodal data (image+text, image+tabular, text+tabular, image+text+tabular) with classification or regression tasks.

  • ensemble_size – A multiplier on the number of models in the ensembling pool (default 2). The actual ensemble size = ensemble_size * the number of models.

  • ensemble_mode – How ensembling is conducted:

    • 'one_shot': the classic ensemble selection.

    • 'sequential': iteratively call the classic ensemble selection, each time growing the model zoo by the next best model.
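The zero-shot sketch referenced under problem_type: a predictor is created and used for inference without .fit(), here with 'feature_extraction'. The sentence-transformers checkpoint and the "sentence" column are illustrative assumptions, not documented defaults; the other zero-shot problem types follow the same create-then-infer pattern:

    import pandas as pd
    from autogluon.multimodal import MultiModalPredictor

    # No .fit() call: a pretrained backbone is loaded based on the hyperparameters.
    # The checkpoint name below is an illustrative choice, not a default.
    predictor = MultiModalPredictor(
        problem_type="feature_extraction",
        hyperparameters={
            "model.hf_text.checkpoint_name": "sentence-transformers/all-MiniLM-L6-v2",
        },
    )
    # extract_embedding returns embeddings for each row of the input DataFrame.
    embeddings = predictor.extract_embedding(pd.DataFrame({"sentence": ["AutoMM is easy to use."]}))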
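The setup sketch for the duplicate detection example under match_label; the column names are hypothetical stand-ins for a text-text matching dataset:

    from autogluon.multimodal import MultiModalPredictor

    # Each row pairs a query text with a response text plus a binary label
    # in {"duplicate", "not duplicate"} (hypothetical column names).
    predictor = MultiModalPredictor(
        problem_type="text_similarity",
        query="question1",
        response="question2",
        label="label",
        match_label="duplicate",  # the class meaning the two items match
    )
    # predictor.fit(train_df)  # train_df: a hypothetical pd.DataFrame with those columns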
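And the configuration sketch for hyperparameters, showing the three equivalent forms quoted above side by side:

    from autogluon.multimodal import MultiModalPredictor

    # The same backbone overrides expressed three equivalent ways.
    as_string = (
        "model.hf_text.checkpoint_name=google/electra-small-discriminator "
        "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
    )
    as_list = [
        "model.hf_text.checkpoint_name=google/electra-small-discriminator",
        "model.timm_image.checkpoint_name=swin_small_patch4_window7_224",
    ]
    as_dict = {
        "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
        "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
    }
    predictor = MultiModalPredictor(label="label", hyperparameters=as_dict)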

Methods

dump_model – Save model weights and config to a local directory.
evaluate – Evaluate the model on a given dataset.
export_onnx – Export this predictor's model to an ONNX file.
extract_embedding – Extract features for each sample, i.e., each row in the provided pd.DataFrame.
fit – Fit models to predict a column of a data table (label) based on the other columns (features).
fit_summary – Output the training summary information from fit().
get_num_gpus – Get the number of GPUs from the config.
list_supported_models – List supported models for each problem type.
load – Load a predictor object from a directory specified by path.
optimize_for_inference – Optimize the predictor's model for inference.
predict – Predict the label column values for new data.
predict_proba – Predict class probabilities rather than class labels.
save – Save this predictor to the directory specified by path.
set_num_gpus – Set the number of GPUs in the config.
set_verbosity – Set the verbosity level of the log.
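A short persistence sketch using fit, save, load, and predict from the table above; the toy data, time budget, and path are hypothetical:

    import pandas as pd
    from autogluon.multimodal import MultiModalPredictor

    train_df = pd.DataFrame({
        "text": ["good", "bad", "great", "awful"],
        "label": [1, 0, 1, 0],
    })
    predictor = MultiModalPredictor(label="label")
    predictor.fit(train_df, time_limit=60)  # small budget, just for the sketch

    predictor.save("my_predictor")                       # persist weights and config
    restored = MultiModalPredictor.load("my_predictor")  # restore from the same path
    print(restored.predict(pd.DataFrame({"text": ["not bad"]})))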

Attributes

class_labels – The original names of the class labels.
classes – Object classes for the object detection problem type.
column_types – Column types in the pd.DataFrame.
eval_metric – The metric used to evaluate predictive performance.
label – Name of the pd.DataFrame column that contains the target variable to predict.
match_label – The label class indicating that a <query, response> pair counts as a "match" in semantic matching tasks.
model_size – The model size in megabytes.
path – Path to the directory where the model and related artifacts are stored.
positive_class – Name of the class label that will be mapped to 1.
problem_property – Property of the problem, storing the problem type and its related properties.
problem_type – The type of prediction problem this predictor has been trained for.
query – Name of the pd.DataFrame column that holds the query data in semantic matching tasks.
response – Name of the pd.DataFrame column that contains the response data in semantic matching tasks.
total_parameters – The total number of model parameters.
trainable_parameters – The number of trainable model parameters, usually those with requires_grad=True.
validation_metric – Validation metric for selecting the best model and for early stopping during training.
verbosity – Verbosity level, from 0 to 4, controlling how much information is printed.