autogluon.multimodal.MultiModalPredictor
- class autogluon.multimodal.MultiModalPredictor(label: Optional[str] = None, problem_type: Optional[str] = None, query: Optional[Union[str, List[str]]] = None, response: Optional[Union[str, List[str]]] = None, match_label: Optional[Union[int, str]] = None, pipeline: Optional[str] = None, presets: Optional[str] = None, eval_metric: Optional[str] = None, hyperparameters: Optional[dict] = None, path: Optional[str] = None, verbosity: Optional[int] = 2, num_classes: Optional[int] = None, classes: Optional[list] = None, warn_if_exist: Optional[bool] = True, enable_progress_bar: Optional[bool] = None, init_scratch: Optional[bool] = False, sample_data_path: Optional[str] = None)
MultiModalPredictor is a deep learning “model zoo” of model zoos. It automatically builds deep learning models suited to multimodal datasets. You only need to format your data as a multimodal dataframe; MultiModalPredictor can then predict the values of one column conditioned on the features in the other columns.
The prediction can be either classification or regression. The feature columns can contain image paths, text, numerical, and categorical values.
- __init__(label: Optional[str] = None, problem_type: Optional[str] = None, query: Optional[Union[str, List[str]]] = None, response: Optional[Union[str, List[str]]] = None, match_label: Optional[Union[int, str]] = None, pipeline: Optional[str] = None, presets: Optional[str] = None, eval_metric: Optional[str] = None, hyperparameters: Optional[dict] = None, path: Optional[str] = None, verbosity: Optional[int] = 2, num_classes: Optional[int] = None, classes: Optional[list] = None, warn_if_exist: Optional[bool] = True, enable_progress_bar: Optional[bool] = None, init_scratch: Optional[bool] = False, sample_data_path: Optional[str] = None)
- Parameters
label – Name of the column that contains the target variable to predict.
problem_type –
Type of the prediction problem. Supported standard problems include:
'binary': Binary classification
'multiclass': Multi-class classification
'regression': Regression
'classification': Classification problems, covering both 'binary' and 'multiclass' classification.
In addition, the following advanced problems are supported:
'object_detection': Object detection
'ner' or 'named_entity_recognition': Named entity recognition
'text_similarity': Text-text similarity
'image_similarity': Image-image similarity
'image_text_similarity': Image-text similarity
'feature_extraction': Feature extraction (inference only)
'zero_shot_image_classification': Zero-shot image classification (inference only)
'few_shot_text_classification': (experimental) Few-shot text classification
'ocr_text_detection': (experimental) OCR text detection
'ocr_text_recognition': (experimental) OCR text recognition
For certain problem types, the predictor loads a pretrained model by default (selected according to presets / hyperparameters) and supports zero-shot inference, i.e., running inference without calling .fit(). This applies to the following problem types:
'object_detection'
'text_similarity'
'image_similarity'
'image_text_similarity'
'feature_extraction'
'zero_shot_image_classification'
'few_shot_text_classification' (experimental)
'ocr_text_detection' (experimental)
'ocr_text_recognition' (experimental)
query – Column names of query data (used for matching).
response – Column names of response data (used for matching). If no label column is provided, query and response columns form positive pairs.
match_label – The label class indicating that a <query, response> pair counts as a “match”. It is used when problem_type is one of the matching problem types and the labels are binary. For example, if the label column contains [“duplicate”, “not duplicate”], match_label could be “duplicate”. If match_label is not provided, every sample is assumed to have a unique label.
pipeline – Deprecated; its functionality has been merged into problem_type.
presets – Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
eval_metric – Evaluation metric name. If eval_metric = None, it is automatically chosen based on problem_type. Defaults to ‘accuracy’ for binary and multiclass classification, ‘root_mean_squared_error’ for regression.
hyperparameters –
Overrides default configurations. For example, the text and image backbones can be changed using any of these formats:
a string: hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
a list of strings: hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]
a dictionary: hyperparameters = {
    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
}
path – Path to the directory where models and intermediate outputs are saved. If unspecified, a time-stamped folder called “AutogluonAutoMM/ag-[TIMESTAMP]” is created in the working directory to store all models. Note: to call fit() twice and keep the results of each fit, either specify a different path for each call or leave path unspecified. Otherwise, files from the first fit() will be overwritten by the second.
verbosity – Verbosity levels range from 0 to 4 and control how much information is printed. Higher levels correspond to more detailed print statements (set verbosity = 0 to suppress warnings). If using logging, you can alternatively control the amount of information printed via logger.setLevel(L), where L ranges from 0 to 50. Note: higher values of L correspond to fewer print statements, the opposite of verbosity levels.
num_classes – Number of classes, used in classification tasks. If specified and different from the pretrained model’s output size, the model’s head is replaced so that it has <num_classes> outputs.
classes – All classes in this dataset.
warn_if_exist – Whether to raise a warning if the specified path already exists.
enable_progress_bar – Whether to show a progress bar. Defaults to True; the progress bar is also disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.
init_scratch – Whether to initialize the model from scratch. This is useful when loading a checkpoint without its weights.
sample_data_path – Path to sample data, used to automatically infer num_classes, classes, or label.
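The zero-shot list under problem_type can be encoded as a lookup, for example to decide whether a script must call fit() before predict(). This is a minimal sketch in plain Python; the constant and function names are hypothetical and not part of AutoGluon, and the set simply mirrors the documented list.

```python
# Problem types documented as loading a pretrained model by default, so the
# predictor can run inference without fit() (zero-shot inference).
ZERO_SHOT_PROBLEM_TYPES = frozenset({
    "object_detection",
    "text_similarity",
    "image_similarity",
    "image_text_similarity",
    "feature_extraction",              # inference only
    "zero_shot_image_classification",  # inference only
    "few_shot_text_classification",    # experimental
    "ocr_text_detection",              # experimental
    "ocr_text_recognition",            # experimental
})


def needs_fit_before_predict(problem_type: str) -> bool:
    """Hypothetical helper: True if the documented defaults require fit() first."""
    return problem_type not in ZERO_SHOT_PROBLEM_TYPES


print(needs_fit_before_predict("regression"))        # True
print(needs_fit_before_predict("image_similarity"))  # False
```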
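The match_label parameter picks which of two label values marks a <query, response> pair as a match. The sketch below shows that mapping on a toy dataset; the helper and the column names "query", "response", and "label" are hypothetical, and AutoGluon performs its own conversion internally.

```python
from typing import Hashable, List, Optional, Sequence


def to_match_targets(labels: Sequence[Hashable], match_label: Optional[Hashable]) -> List[int]:
    """Map binary pair labels to 1 (match) or 0 (no match).

    Hypothetical helper illustrating the documented meaning of match_label;
    it is not an AutoGluon function.
    """
    if match_label is None:
        raise ValueError("match_label is required to binarize pair labels")
    if len(set(labels)) > 2:
        raise ValueError("match_label applies to binary labels only")
    return [1 if label == match_label else 0 for label in labels]


# Each row is one <query, response> pair plus its label.
rows = [
    {"query": "How do I reset my password?",
     "response": "Steps to recover a forgotten password",
     "label": "duplicate"},
    {"query": "How do I reset my password?",
     "response": "Shipping times for overseas orders",
     "label": "not duplicate"},
]
targets = to_match_targets([row["label"] for row in rows], match_label="duplicate")
print(targets)  # [1, 0]
```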
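The three hyperparameters formats carry the same "key=value" overrides. The sketch below normalizes them into one dictionary to make that equivalence concrete. The helper is hypothetical (AutoGluon parses these forms itself), and it assumes values contain no whitespace because the string form is split on spaces.

```python
from typing import Dict, List, Union

HyperparameterSpec = Union[str, List[str], Dict[str, str]]


def normalize_hyperparameters(spec: HyperparameterSpec) -> Dict[str, str]:
    """Convert the string, list, and dict forms into a single dict (hypothetical helper)."""
    if isinstance(spec, dict):
        return dict(spec)
    items = spec.split() if isinstance(spec, str) else spec
    overrides: Dict[str, str] = {}
    for item in items:
        key, sep, value = item.partition("=")  # split on the first "=" only
        if not sep:
            raise ValueError(f"expected 'key=value', got {item!r}")
        overrides[key] = value
    return overrides


as_string = (
    "model.hf_text.checkpoint_name=google/electra-small-discriminator "
    "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"
)
as_list = as_string.split()
as_dict = {
    "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
    "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
}
print(normalize_hyperparameters(as_string) == normalize_hyperparameters(as_list) == as_dict)  # True
```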
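Because a second fit() into the same path overwrites the first, one option is to generate a fresh path per run. This is a sketch only: the helper name, the timestamp format, and the random suffix (which guards against two runs starting in the same second) are assumptions, not AutoGluon's own naming scheme.

```python
import os
import uuid
from datetime import datetime


def make_run_path(root: str = "AutogluonAutoMM") -> str:
    """Return a fresh, time-stamped directory path for a single fit() run (hypothetical helper)."""
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(root, f"ag-{stamp}-{uuid.uuid4().hex[:8]}")


first_run = make_run_path()
second_run = make_run_path()
print(first_run != second_run)  # True: the second fit() will not overwrite the first
```

Each returned string could then be passed as the path argument when constructing a separate predictor.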
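The verbosity and enable_progress_bar entries describe two ways to quiet output that need only the standard library. A minimal sketch, assuming AutoGluon's module loggers sit under the "autogluon" logger namespace and that any value of AUTOMM_DISABLE_PROGRESS_BAR counts as "set" (the docs do not specify a value):

```python
import logging
import os

# Assumption: AutoGluon's module loggers live under the "autogluon" namespace,
# so a level set here is inherited by them.
logger = logging.getLogger("autogluon")

# Logging levels run opposite to verbosity: a higher level prints less.
logger.setLevel(logging.WARNING)  # 30: keep warnings and errors, drop info/debug

# The progress bar is documented as disabled whenever this variable is set;
# "1" is an arbitrary value. To be safe, set it before creating the predictor.
os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] = "1"

print(logger.isEnabledFor(logging.INFO))   # False
print(logger.isEnabledFor(logging.ERROR))  # True
```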
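The sample_data_path entry does not state the file format. Object detection annotations are commonly stored in the COCO JSON format, so the sketch below shows how num_classes and classes could be derived from one. The COCO format, the ordering by category id, and the helper name are all assumptions; this is not AutoGluon's implementation, and it does not cover inferring label.

```python
import json
import os
import tempfile
from typing import List, Tuple


def infer_classes_from_coco(annotation_path: str) -> Tuple[int, List[str]]:
    """Return (num_classes, classes) from a COCO-format annotation file (hypothetical helper)."""
    with open(annotation_path, encoding="utf-8") as f:
        coco = json.load(f)
    # Order class names by their COCO category id.
    categories = sorted(coco["categories"], key=lambda c: c["id"])
    classes = [c["name"] for c in categories]
    return len(classes), classes


# Write a tiny COCO-style file so the example is self-contained.
sample = {
    "images": [],
    "annotations": [],
    "categories": [{"id": 2, "name": "dog"}, {"id": 1, "name": "cat"}],
}
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False, encoding="utf-8") as tmp:
    json.dump(sample, tmp)
try:
    num_classes, classes = infer_classes_from_coco(tmp.name)
finally:
    os.remove(tmp.name)

print(num_classes, classes)  # 2 ['cat', 'dog']
```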
Methods
Evaluate model on a test dataset.
Extract features for each sample, i.e., one row in the provided dataframe data.
Fit MultiModalPredictor to predict the label column of a dataframe based on the other columns, which may contain image paths, text, numeric, or categorical features.
Output summary of information about models produced during fit().
Return the classes of the detection model (only works for object detection).
List the supported models for each problem_type, so users know which checkpoint names they can choose during fit().
Load a predictor object from a directory specified by path.
Predict values for the label column of new data.
Predict class probabilities rather than class labels.
Save this predictor to file in directory specified by path.
Set the verbosity level of the log.
Attributes
class_labels
The original names of the class labels.
column_types
label
match_label
path
positive_class
Name of the class label that will be mapped to 1.
problem_property
problem_type
query
response