autogluon.multimodal.MultiModalPredictor
- class autogluon.multimodal.MultiModalPredictor(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')
AutoMM is designed to simplify fine-tuning foundation models for downstream applications, with as few as three lines of code. It integrates with popular model zoos such as HuggingFace Transformers, TIMM, and MMDetection, and supports image, text, tabular, and document data, used individually or in combination. Supported tasks include classification, regression, object detection, named entity recognition, semantic matching, and image segmentation.
- __init__(label: str | None = None, problem_type: str | None = None, query: str | List[str] | None = None, response: str | List[str] | None = None, match_label: int | str | None = None, presets: str | None = None, eval_metric: str | Scorer | None = None, hyperparameters: dict | None = None, path: str | None = None, verbosity: int | None = 2, num_classes: int | None = None, classes: list | None = None, warn_if_exist: bool | None = True, enable_progress_bar: bool | None = None, pretrained: bool | None = True, validation_metric: str | None = None, sample_data_path: str | None = None, use_ensemble: bool | None = False, ensemble_size: int | None = 2, ensemble_mode: str | None = 'one_shot')
- Parameters:
label – Name of one pd.DataFrame column that contains the target variable to predict.
problem_type –
Type of problem. We support standard problems such as
'binary': Binary classification
'multiclass': Multi-class classification
'regression': Regression
'classification': Classification problems, covering both 'binary' and 'multiclass' classification.
In addition, we support advanced problems such as
'object_detection': Object detection
'ner' or 'named_entity_recognition': Named entity recognition
'text_similarity': Text-text semantic matching
'image_similarity': Image-image semantic matching
'image_text_similarity': Text-image semantic matching
'feature_extraction': Feature extraction (inference only)
'zero_shot_image_classification': Zero-shot image classification (inference only)
'few_shot_classification': Few-shot classification for image or text data
'semantic_segmentation': Semantic segmentation with the Segment Anything Model
For certain problem types, the predictor loads a pretrained model by default (based on the presets / hyperparameters) and can run zero-shot inference, i.e., inference without calling .fit(). These problem types are:
'object_detection'
'text_similarity'
'image_similarity'
'image_text_similarity'
'feature_extraction'
'zero_shot_image_classification'
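The problem-type rules above (the 'ner' alias, 'classification' as an umbrella for 'binary' and 'multiclass', and the zero-shot list) can be summarized as a small lookup. This is a hedged sketch for illustration only: the constants and the helpers canonical_problem_type, is_classification, and can_predict_without_fit are hypothetical and are not AutoGluon's internal logic.

```python
# Hypothetical lookup tables mirroring the documented problem types.
# NOT AutoGluon's internal implementation.

# 'ner' and 'named_entity_recognition' name the same task.
_ALIASES = {"named_entity_recognition": "ner"}

# 'classification' is an umbrella for these two concrete subtypes.
_CLASSIFICATION_SUBTYPES = {"binary", "multiclass"}

# Problem types documented as supporting zero-shot inference (no .fit()).
ZERO_SHOT_TYPES = {
    "object_detection",
    "text_similarity",
    "image_similarity",
    "image_text_similarity",
    "feature_extraction",
    "zero_shot_image_classification",
}

# Problem types documented as inference-only.
INFERENCE_ONLY_TYPES = {"feature_extraction", "zero_shot_image_classification"}


def canonical_problem_type(problem_type: str) -> str:
    """Map a documented alias onto a single canonical name."""
    return _ALIASES.get(problem_type, problem_type)


def is_classification(problem_type: str) -> bool:
    """True for 'classification' and its concrete subtypes."""
    pt = canonical_problem_type(problem_type)
    return pt == "classification" or pt in _CLASSIFICATION_SUBTYPES


def can_predict_without_fit(problem_type: str) -> bool:
    """True if the docs list this type as loading a pretrained model for zero-shot use."""
    return canonical_problem_type(problem_type) in ZERO_SHOT_TYPES
```

Note that every inference-only type is also a zero-shot type, which is consistent with the lists above: those predictors are usable without ever calling .fit().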
query – Name of one pd.DataFrame column that has the query data in semantic matching tasks.
response – Name of one pd.DataFrame column that contains the response data in semantic matching tasks. If no label column is provided, the query and response pairs in one pd.DataFrame row are assumed to be positive pairs.
match_label – The label class that indicates the <query, response> pair is counted as a “match”. This is used when the task belongs to semantic matching, and the labels are binary. For example, the label column can contain [“duplicate”, “not duplicate”] in a duplicate detection task. The match_label should be “duplicate” since it means that two items match.
presets – Presets regarding model quality, e.g., ‘best_quality’, ‘high_quality’ (default), and ‘medium_quality’. Each quality has its corresponding HPO presets: ‘best_quality_hpo’, ‘high_quality_hpo’, and ‘medium_quality_hpo’.
eval_metric – Evaluation metric name. If eval_metric = None, it is chosen automatically based on problem_type. Defaults to 'accuracy' for multiclass classification, 'roc_auc' for binary classification, and 'root_mean_squared_error' for regression.
hyperparameters –
Overrides selected default configurations. For example, the text and image backbones can be changed using any of the following formats:
a string: hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
a list of strings: hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
a dictionary: hyperparameters = {
"model.hf_text.checkpoint_name": "google/electra-small-discriminator", "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
}
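The three hyperparameters formats above carry the same two overrides. As a hedged illustration, the hypothetical helper normalize_overrides below (not part of AutoGluon) flattens each form into one dict so their equivalence can be checked. It assumes whitespace-separated "key=value" tokens in the string form, so values containing spaces are not supported. AutoGluon's own parser may differ in details such as converting value types.

```python
from typing import Any, Dict, List, Union

# Any of the three documented forms, or None for "no overrides".
HyperparameterSpec = Union[str, List[str], Dict[str, Any], None]


def normalize_overrides(spec: HyperparameterSpec) -> Dict[str, Any]:
    """Flatten a string, list of strings, or dict of overrides into one dict.

    Hypothetical helper for illustration; not AutoGluon's parser.
    """
    if spec is None:
        return {}
    if isinstance(spec, dict):
        return dict(spec)  # already key -> value
    # String form: whitespace-separated "key=value" tokens (assumption).
    items = spec.split() if isinstance(spec, str) else list(spec)
    overrides: Dict[str, Any] = {}
    for item in items:
        # Split on the first "=" only, so a value may itself contain "=".
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected 'key=value', got {item!r}")
        overrides[key] = value
    return overrides
```

For example, normalize_overrides applied to the string form and to the list form above returns the same dict as the dictionary form.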
path – Path to the directory where models and related artifacts are saved. If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]" is created in the working directory. Note: to call fit() twice and keep the results of each fit, you must specify a different path for each call or leave path unspecified.
verbosity – Verbosity levels range from 0 to 4, controlling how much logging information is printed. Higher levels correspond to more detailed print statements. You can set verbosity = 0 to suppress warnings.
num_classes – Number of classes (used for object detection). If this is specified and differs from the pretrained model's output shape, the model's head is changed to have num_classes outputs.
classes – All the classes (used for object detection).
warn_if_exist – Whether to raise a warning if the specified path already exists (default True).
enable_progress_bar – Whether to show a progress bar (default True). The progress bar is disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.
pretrained – Whether to initialize the model with pretrained weights (default True). If False, it creates a model with random initialization.
validation_metric – Validation metric used to select the best model and for early stopping during training. If not provided, it is chosen automatically based on the problem type.
sample_data_path – The path to sample data from which we can infer num_classes or classes used for object detection.
use_ensemble – Whether to use ensembling when fitting the predictor (Default False). Currently, it works only on multimodal data (image+text, image+tabular, text+tabular, image+text+tabular) with classification or regression tasks.
ensemble_size – A multiplier on the number of models in the ensembling pool (default 2). The actual ensemble size is ensemble_size * the number of models.
ensemble_mode – The ensembling mode:
one_shot: the classic ensemble selection.
sequential: repeatedly calls the classic ensemble selection, each time growing the model zoo with the next best model.
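ensemble_mode='one_shot' is described above as the classic ensemble selection. The sketch below shows the widely used greedy forward selection with replacement in plain Python, and it is not AutoGluon's implementation. It assumes regression-style validation predictions scored by mean squared error (lower is better). It also assumes that the "actual ensemble size" corresponds to the number of greedy selection rounds, ensemble_size * number of models. The 'sequential' mode is not shown.

```python
from typing import List, Sequence


def _mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


def ensemble_selection(
    val_preds: List[List[float]],
    val_target: List[float],
    ensemble_size: int = 2,
) -> List[float]:
    """Greedy ensemble selection with replacement.

    Returns one weight per model. Runs ensemble_size * len(val_preds) rounds
    (an assumption based on the documented ensemble-size rule).
    """
    if not val_preds:
        raise ValueError("need at least one model")
    n_rounds = ensemble_size * len(val_preds)
    counts = [0] * len(val_preds)
    running_sum = [0.0] * len(val_target)  # sum of the picked models' predictions
    for step in range(1, n_rounds + 1):
        best_idx, best_loss = 0, float("inf")
        for i, pred in enumerate(val_preds):
            # Candidate ensemble = average of current picks plus model i.
            candidate = [(s + p) / step for s, p in zip(running_sum, pred)]
            loss = _mse(candidate, val_target)
            if loss < best_loss:  # strict '<': ties go to the earlier model
                best_idx, best_loss = i, loss
        counts[best_idx] += 1
        running_sum = [s + p for s, p in zip(running_sum, val_preds[best_idx])]
    return [c / n_rounds for c in counts]
```

Because models can be picked repeatedly, the returned weights are multiples of 1 / n_rounds, so a larger ensemble_size allows finer-grained weights.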
Methods
dump_model
Save model weights and config to a local directory.
evaluate
Evaluate the model on a given dataset.
export_onnx
Export this predictor's model to an ONNX file.
extract_embedding
Extract features for each sample, i.e., one row in the provided pd.DataFrame data.
fit
Fit models to predict a column of a data table (label) based on the other columns (features).
fit_summary
Output the training summary information from fit().
get_num_gpus
Get the number of GPUs from config.
list_supported_models
List supported models for each problem type.
load
Load a predictor object from a directory specified by path.
optimize_for_inference
Optimize the predictor's model for inference.
predict
Predict the label column values for new data.
predict_proba
Predict class probabilities rather than class labels.
save
Save this predictor to files in the directory specified by path.
set_num_gpus
Set the number of GPUs in config.
set_verbosity
Set the verbosity level of the log.
Attributes
class_labels
The original name of the class labels.
classes
Object classes for the object detection problem type.
column_types
Column types in the pd.DataFrame.
eval_metric
What metric is used to evaluate predictive performance.
label
Name of one pd.DataFrame column that contains the target variable to predict.
match_label
The label class that indicates the <query, response> pair is counted as "match" in the semantic matching tasks.
model_size
Returns the model size in megabytes.
path
Path to directory where the model and related artifacts are stored.
positive_class
Name of the class label that will be mapped to 1.
problem_property
Property of the problem, storing the problem type and its related properties.
problem_type
What type of prediction problem this predictor has been trained for.
query
Name of one pd.DataFrame column that has the query data in semantic matching tasks.
response
Name of one pd.DataFrame column that contains the response data in semantic matching tasks.
total_parameters
The number of model parameters.
trainable_parameters
The number of trainable model parameters, usually those with requires_grad=True.
validation_metric
Validation metric for selecting the best model and early-stopping during training.
verbosity
Verbosity levels range from 0 to 4 and control how much information is printed.
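The total_parameters, trainable_parameters, and model_size attributes can be illustrated with a framework-free sketch. ParamInfo and the three functions below are hypothetical. With PyTorch, the analogous counts are usually computed as sum(p.numel() for p in model.parameters()), and the same sum filtered on p.requires_grad. The float32 element size and the 1 MB = 2**20 bytes convention are assumptions, and the library may compute model_size differently.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class ParamInfo:
    """Hypothetical stand-in for one weight tensor."""
    numel: int                   # number of scalar entries in the tensor
    requires_grad: bool          # whether the optimizer updates it
    bytes_per_element: int = 4   # float32 assumed


def total_parameters(params: Iterable[ParamInfo]) -> int:
    """Count every scalar parameter."""
    return sum(p.numel for p in params)


def trainable_parameters(params: Iterable[ParamInfo]) -> int:
    """Count only parameters that receive gradient updates."""
    return sum(p.numel for p in params if p.requires_grad)


def model_size_mb(params: Iterable[ParamInfo]) -> float:
    """Parameter storage in megabytes (assumes 1 MB = 2**20 bytes)."""
    return sum(p.numel * p.bytes_per_element for p in params) / 2**20
```

For example, a frozen backbone tensor plus a trainable head gives a total_parameters value larger than trainable_parameters, which is typical when only part of a pretrained model is fine-tuned.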