"""Implementation of the multimodal predictor"""
from __future__ import annotations
import json
import logging
import os
import warnings
from typing import Dict, List, Optional, Union
import pandas as pd
import transformers
from autogluon.common.utils.log_utils import set_logger_verbosity, verbosity2loglevel
from autogluon.core.metrics import Scorer
from .constants import AUTOMM_TUTORIAL_MODE, FEW_SHOT_CLASSIFICATION, NER, OBJECT_DETECTION, SEMANTIC_SEGMENTATION
from .learners import (
BaseLearner,
FewShotSVMLearner,
MultiModalMatcher,
NERLearner,
ObjectDetectionLearner,
SemanticSegmentationLearner,
)
from .problem_types import PROBLEM_TYPES_REG
from .utils import get_dir_ckpt_paths
logger = logging.getLogger(__name__)
# Detach the Lightning logger from its ancestors' handlers (propagate=False), presumably
# to avoid duplicated log output; see https://github.com/Lightning-AI/lightning/issues/4621
pl_logger = logging.getLogger("lightning")
pl_logger.propagate = False
class MultiModalPredictor:
    """
    AutoMM is designed to simplify the fine-tuning of foundation models
    for downstream applications with just three lines of code.
    AutoMM seamlessly integrates with popular model zoos such as
    `HuggingFace Transformers <https://github.com/huggingface/transformers>`_,
    `TIMM <https://github.com/huggingface/pytorch-image-models>`_,
    and `MMDetection <https://github.com/open-mmlab/mmdetection>`_,
    accommodating a diverse range of data modalities,
    including image, text, tabular, and document data, whether used individually or in combination.
    It offers support for an array of tasks, encompassing classification, regression,
    object detection, named entity recognition, semantic matching, and image segmentation.

    Every public method and property delegates to an internal learner object
    (``self._learner``), whose class is chosen from the problem type at construction time.
    """

    def __init__(
        self,
        label: Optional[str] = None,
        problem_type: Optional[str] = None,
        query: Optional[Union[str, List[str]]] = None,
        response: Optional[Union[str, List[str]]] = None,
        match_label: Optional[Union[int, str]] = None,
        presets: Optional[str] = None,
        eval_metric: Optional[Union[str, Scorer]] = None,
        hyperparameters: Optional[dict] = None,
        path: Optional[str] = None,
        verbosity: Optional[int] = 2,
        num_classes: Optional[int] = None,  # TODO: can we infer this from data?
        classes: Optional[list] = None,
        warn_if_exist: Optional[bool] = True,
        enable_progress_bar: Optional[bool] = None,
        pretrained: Optional[bool] = True,
        validation_metric: Optional[str] = None,
        sample_data_path: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        label
            Name of one pd.DataFrame column that contains the target variable to predict.
        problem_type
            Type of problem. We support standard problems like

            - 'binary': Binary classification
            - 'multiclass': Multi-class classification
            - 'regression': Regression
            - 'classification': Classification problems include 'binary' and 'multiclass' classification.

            In addition, we support advanced problems such as

            - 'object_detection': Object detection
            - 'ner' or 'named_entity_recognition': Named entity extraction
            - 'text_similarity': Text-text semantic matching
            - 'image_similarity': Image-image semantic matching
            - 'image_text_similarity': Text-image semantic matching
            - 'feature_extraction': Extracting feature (only support inference)
            - 'zero_shot_image_classification': Zero-shot image classification (only support inference)
            - 'few_shot_classification': Few-shot classification for image or text data.
            - 'semantic_segmentation': Semantic segmentation with Segment Anything Model.

            For certain problem types, the default behavior is to load a pretrained model based on
            the presets / hyperparameters and the predictor can do zero-shot inference
            (running inference without .fit()). Those include the following
            problem types:

            - 'object_detection'
            - 'text_similarity'
            - 'image_similarity'
            - 'image_text_similarity'
            - 'feature_extraction'
            - 'zero_shot_image_classification'

            The value is lower-cased before lookup; an unknown value fails an assertion.
        query
            Name of one pd.DataFrame column that has the query data in semantic matching tasks.
        response
            Name of one pd.DataFrame column that contains the response data in semantic matching tasks.
            If no label column is provided, the query and response pairs in
            one pd.DataFrame row are assumed to be positive pairs.
        match_label
            The label class that indicates the <query, response> pair is counted as a "match".
            This is used when the task belongs to semantic matching, and the labels are binary.
            For example, the label column can contain ["duplicate", "not duplicate"] in a duplicate detection task.
            The match_label should be "duplicate" since it means that two items match.
        presets
            Presets regarding model quality, e.g., 'best_quality', 'high_quality' (default), and 'medium_quality'.
            Each quality has its corresponding HPO presets: 'best_quality_hpo', 'high_quality_hpo', and 'medium_quality_hpo'.
        eval_metric
            Evaluation metric name. If `eval_metric = None`, it is automatically chosen based on `problem_type`.
            Defaults to 'accuracy' for multiclass classification, 'roc_auc' for binary classification,
            and 'root_mean_squared_error' for regression.
        hyperparameters
            This is to override some default configurations.
            For example, changing the text and image backbones can be done by formatting:

            a string
            hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"

            or a list of strings
            hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]

            or a dictionary
            hyperparameters = {
                "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
            }
        path
            Path to directory where models and related artifacts should be saved.
            If unspecified, a time-stamped folder called "AutogluonAutoMM/ag-[TIMESTAMP]"
            will be created in the working directory.
            Note: To call `fit()` twice and save all results of each fit,
            you must specify different `path` locations or don't specify `path` at all.
        verbosity
            Verbosity levels range from 0 to 4, controlling how much logging information is printed.
            Higher levels correspond to more detailed print statements.
            You can set verbosity = 0 to suppress warnings.
        num_classes
            Number of classes (used for object detection).
            If this is specified and is different from the pretrained model's output shape,
            the model's head will be changed to have <num_classes> output.
        classes
            All the classes (used for object detection).
        warn_if_exist
            Whether to raise warning if the specified path already exists (Default True).
        enable_progress_bar
            Whether to show progress bar (default True). It would be
            disabled if the environment variable os.environ["AUTOMM_DISABLE_PROGRESS_BAR"] is set.
            It is also forced off here when the environment variable named by the
            ``AUTOMM_TUTORIAL_MODE`` constant is set (which also disables the
            progress bars of the ``transformers`` package).
        pretrained
            Whether to initialize the model with pretrained weights (default True).
            If False, it creates a model with random initialization.
        validation_metric
            Validation metric for selecting the best model and early-stopping during training.
            If not provided, it would be automatically chosen based on the problem type.
        sample_data_path
            The path to sample data from which we can infer num_classes or classes used for object detection.
        """
        if problem_type is not None:
            problem_type = problem_type.lower()
            assert problem_type in PROBLEM_TYPES_REG, (
                f"problem_type='{problem_type}' is not supported yet. You may pick a problem type from"
                f" {PROBLEM_TYPES_REG.list_keys()}."
            )
            problem_property = PROBLEM_TYPES_REG.get(problem_type)
            if problem_property.experimental:
                warnings.warn(
                    f"problem_type='{problem_type}' is currently experimental.",
                    UserWarning,
                )
            # Replace the user-supplied key with the registry's canonical name
            # (the docstring lists aliases such as 'ner' / 'named_entity_recognition').
            problem_type = problem_property.name
        else:
            problem_property = None

        if os.environ.get(AUTOMM_TUTORIAL_MODE):
            enable_progress_bar = False
            # Also disable progress bar of transformers package
            transformers.logging.disable_progress_bar()

        if verbosity is not None:
            set_logger_verbosity(verbosity)
        self._verbosity = verbosity

        # Matching problems are identified by a registry property, not by name.
        if problem_property and problem_property.is_matching:
            learner_class = MultiModalMatcher
        else:
            learner_class = self._get_learner_class(problem_type)

        self._learner = learner_class(
            label=label,
            problem_type=problem_type,
            presets=presets,
            eval_metric=eval_metric,
            hyperparameters=hyperparameters,
            path=path,
            verbosity=verbosity,
            num_classes=num_classes,
            classes=classes,
            warn_if_exist=warn_if_exist,
            enable_progress_bar=enable_progress_bar,
            pretrained=pretrained,
            sample_data_path=sample_data_path,
            validation_metric=validation_metric,
            query=query,
            response=response,
            match_label=match_label,
        )

    @staticmethod
    def _get_learner_class(problem_type: Optional[str]):
        """
        Map a non-matching problem type to the learner class that implements it.

        Shared by ``__init__`` and ``load`` so both resolve learners identically.
        Matching problems (``MultiModalMatcher``) are resolved by the callers before
        reaching this helper; any unlisted problem type falls back to ``BaseLearner``.
        """
        if problem_type == OBJECT_DETECTION:
            return ObjectDetectionLearner
        if problem_type == NER:
            return NERLearner
        if problem_type == FEW_SHOT_CLASSIFICATION:
            return FewShotSVMLearner
        if problem_type == SEMANTIC_SEGMENTATION:
            return SemanticSegmentationLearner
        return BaseLearner

    @property
    def path(self):
        """
        Path to directory where the model and related artifacts are stored.
        """
        return self._learner.path

    @property
    def label(self):
        """
        Name of one pd.DataFrame column that contains the target variable to predict.
        """
        return self._learner.label

    @property
    def query(self):
        """
        Name of one pd.DataFrame column that has the query data in semantic matching tasks.
        """
        return self._learner.query

    @property
    def response(self):
        """
        Name of one pd.DataFrame column that contains the response data in semantic matching tasks.
        """
        return self._learner.response

    @property
    def match_label(self):
        """
        The label class that indicates the <query, response> pair is counted as "match" in the semantic matching tasks.
        """
        return self._learner.match_label

    @property
    def problem_type(self):
        """
        What type of prediction problem this predictor has been trained for.
        """
        return self._learner.problem_type

    @property
    def problem_property(self):
        """
        Property of the problem, storing the problem type and its related properties.
        """
        return self._learner.problem_property

    @property
    def column_types(self):
        """
        Column types in the pd.DataFrame.
        """
        return self._learner.column_types

    @property
    def eval_metric(self):
        """
        What metric is used to evaluate predictive performance.
        """
        return self._learner.eval_metric

    @property
    def validation_metric(self):
        """
        Validation metric for selecting the best model and early-stopping during training.
        Note that the validation metric may be different from the evaluation metric.
        """
        return self._learner.validation_metric

    @property
    def verbosity(self):
        """
        Verbosity levels range from 0 to 4 and control how much information is printed.
        Higher levels correspond to more detailed print statements.
        """
        return self._verbosity

    @property
    def total_parameters(self) -> int:
        """
        The number of model parameters.
        """
        return self._learner.total_parameters

    @property
    def trainable_parameters(self) -> int:
        """
        The number of trainable model parameters, usually those with requires_grad=True.
        """
        return self._learner.trainable_parameters

    @property
    def model_size(self) -> float:
        """
        Returns the model size in Megabyte.
        """
        return self._learner.model_size

    @property
    def classes(self):
        """
        Object classes for the object detection problem type.
        """
        return self._learner.classes

    @property
    def class_labels(self):
        """
        The original name of the class labels.
        For example, the tabular data may contain classes equal to
        "entailment", "contradiction", "neutral". Internally, these will be converted to
        0, 1, 2, ...
        This function returns the original names of these raw labels.

        Returns
        -------
        List that contain the class names. It will be None if it's not a classification problem.
        """
        return self._learner.class_labels

    @property
    def positive_class(self):
        """
        Name of the class label that will be mapped to 1.
        This is only meaningful for binary classification problems.

        It is useful for computing metrics such as F1 which require a positive and negative class.
        You may refer to https://en.wikipedia.org/wiki/F-score for more details.
        In binary classification, :class:`MultiModalPredictor.predict_proba(as_multiclass=False)`
        returns the estimated probability that each row belongs to the positive class.
        Will print a warning and return None if called when `predictor.problem_type != 'binary'`.

        Returns
        -------
        The positive class name in binary classification or None if the problem is not binary classification.
        """
        return self._learner.positive_class

    # This func is required by the abstract trainer of TabularPredictor.
    def set_verbosity(self, verbosity: int):
        """Set the verbosity level of the log.

        Parameters
        ----------
        verbosity
            The verbosity level.

            0 --> only errors
            1 --> only warnings and critical print statements
            2 --> key print statements which should be shown by default
            3 --> more-detailed printing
            4 --> everything
        """
        self._verbosity = verbosity
        set_logger_verbosity(verbosity)
        # TODO: align verbosity2loglevel with https://huggingface.co/docs/transformers/main_classes/logging#transformers.utils.logging.get_verbosity

    def set_num_gpus(self, num_gpus):
        """
        Set the number of GPUs in config.
        """
        self._learner.set_num_gpus(num_gpus)

    def get_num_gpus(self):
        """
        Get the number of GPUs from config.

        Returns
        -------
        The value reported by the underlying learner's ``get_num_gpus()``.
        """
        return self._learner.get_num_gpus()

    def fit(
        self,
        train_data: Union[pd.DataFrame, str],
        presets: Optional[str] = None,
        tuning_data: Optional[Union[pd.DataFrame, str]] = None,
        max_num_tuning_data: Optional[int] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        time_limit: Optional[int] = None,
        save_path: Optional[str] = None,
        hyperparameters: Optional[Union[str, Dict, List[str]]] = None,
        column_types: Optional[dict] = None,
        holdout_frac: Optional[float] = None,
        teacher_predictor: Optional[Union[str, MultiModalPredictor]] = None,
        seed: Optional[int] = 0,
        standalone: Optional[bool] = True,
        hyperparameter_tune_kwargs: Optional[dict] = None,
        clean_ckpts: Optional[bool] = True,
    ):
        """
        Fit models to predict a column of a data table (label) based on the other columns (features).

        Parameters
        ----------
        train_data
            A pd.DataFrame containing training data.
        presets
            Presets regarding model quality, e.g., best_quality, high_quality, and medium_quality.
            Each quality has its corresponding HPO presets: 'best_quality_hpo', 'high_quality_hpo', and 'medium_quality_hpo'.
        tuning_data
            A pd.DataFrame containing validation data, which should have the same columns as the train_data.
            If `tuning_data = None`, `fit()` will automatically hold out some random validation data from `train_data`.
        max_num_tuning_data
            The maximum number of tuning samples (used for object detection).
        id_mappings
            Id-to-content mappings (used for semantic matching). The contents can be text, image, etc.
            This is used when the pd.DataFrame contains the query/response identifiers instead of their contents.
        time_limit
            How long `fit()` should run for (wall clock time in seconds).
            If not specified, `fit()` will run until the model has completed training.
        save_path
            Path to directory where models and artifacts should be saved.
        hyperparameters
            This is to override some default configurations.
            For example, changing the text and image backbones can be done by formatting:

            a string
            hyperparameters = "model.hf_text.checkpoint_name=google/electra-small-discriminator model.timm_image.checkpoint_name=swin_small_patch4_window7_224"

            or a list of strings
            hyperparameters = ["model.hf_text.checkpoint_name=google/electra-small-discriminator", "model.timm_image.checkpoint_name=swin_small_patch4_window7_224"]

            or a dictionary
            hyperparameters = {
                "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                "model.timm_image.checkpoint_name": "swin_small_patch4_window7_224",
            }
        column_types
            A dictionary that maps column names to their data types.
            For example: `column_types = {"item_name": "text", "image": "image_path",
            "product_description": "text", "height": "numerical"}`
            may be used for a table with columns: "item_name", "brand", "product_description", and "height".
            If None, column_types will be automatically inferred from the data.
            The current supported types are:

            - "image_path": each row in this column is one image path.
            - "text": each row in this column contains text (sentence, paragraph, etc.).
            - "numerical": each row in this column contains a number.
            - "categorical": each row in this column belongs to one of K categories.
        holdout_frac
            Fraction of train_data to holdout as tuning_data for optimizing hyperparameters or
            early stopping (ignored unless `tuning_data = None`).
            Default value (if None) is selected based on the number of rows in the training data
            and whether hyperparameter optimization is utilized.
        teacher_predictor
            The pre-trained teacher predictor or its saved path. If provided, `fit()` can distill its
            knowledge to a student predictor, i.e., the current predictor.
            A path is forwarded to the learner as-is; a predictor object contributes its internal learner.
        seed
            The random seed to be used for training (default 0).
        standalone
            Whether to save the entire model for offline deployment.
        hyperparameter_tune_kwargs
            Hyperparameter tuning strategy and kwargs (for example, how many HPO trials to run).
            If None, then hyperparameter tuning will not be performed.

                num_trials: int
                    How many HPO trials to run. Either `num_trials` or `time_limit` to `fit` needs to be specified.
                scheduler: Union[str, ray.tune.schedulers.TrialScheduler]
                    If str is passed, AutoGluon will create the scheduler for you with some default parameters.
                    If ray.tune.schedulers.TrialScheduler object is passed, you are responsible for initializing the object.
                scheduler_init_args: Optional[dict] = None
                    If provided str to `scheduler`, you can optionally provide custom init_args to the scheduler
                searcher: Union[str, ray.tune.search.SearchAlgorithm, ray.tune.search.Searcher]
                    If str is passed, AutoGluon will create the searcher for you with some default parameters.
                    If a ray.tune.search.SearchAlgorithm or ray.tune.search.Searcher object is passed,
                    you are responsible for initializing the object.
                    You don't need to worry about `metric` and `mode` of the searcher object. AutoGluon will figure it out by itself.
                searcher_init_args: Optional[dict] = None
                    If provided str to `searcher`, you can optionally provide custom init_args to the searcher
                    You don't need to worry about `metric` and `mode`. AutoGluon will figure it out by itself.
        clean_ckpts
            Whether to clean the intermediate checkpoints after training.

        Returns
        -------
        A "MultiModalPredictor" object (itself).
        """
        # The learner expects either nothing, a saved path, or another learner.
        if teacher_predictor is None:
            teacher_learner = None
        elif isinstance(teacher_predictor, str):
            teacher_learner = teacher_predictor
        else:
            teacher_learner = teacher_predictor._learner

        self._learner.fit(
            train_data=train_data,
            presets=presets,
            tuning_data=tuning_data,
            max_num_tuning_data=max_num_tuning_data,
            time_limit=time_limit,
            save_path=save_path,
            hyperparameters=hyperparameters,
            column_types=column_types,
            holdout_frac=holdout_frac,
            teacher_learner=teacher_learner,
            seed=seed,
            standalone=standalone,
            hyperparameter_tune_kwargs=hyperparameter_tune_kwargs,
            clean_ckpts=clean_ckpts,
            id_mappings=id_mappings,
        )
        return self

    def evaluate(
        self,
        data: Union[pd.DataFrame, dict, list, str],
        query_data: Optional[list] = None,
        response_data: Optional[list] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        metrics: Optional[Union[str, List[str]]] = None,
        chunk_size: Optional[int] = 1024,
        similarity_type: Optional[str] = "cosine",
        cutoffs: Optional[List[int]] = None,
        label: Optional[str] = None,
        return_pred: Optional[bool] = False,
        realtime: Optional[bool] = False,
        eval_tool: Optional[str] = None,
    ):
        """
        Evaluate the model on a given dataset.

        Parameters
        ----------
        data
            A pd.DataFrame, containing the same columns as the training data.
            Or a str, that is a path of the annotation file for detection.
        query_data
            Query data used for ranking.
        response_data
            Response data used for ranking.
        id_mappings
            Id-to-content mappings. The contents can be text, image, etc.
            This is used when data/query_data/response_data contain the query/response identifiers instead of their contents.
        metrics
            A list of metric names to report.
            If None, we only return the score for the stored `_eval_metric_name`.
        chunk_size
            Scan the response data by chunk_size each time. Increasing the value increases the speed, but requires more memory.
        similarity_type
            Use what function (cosine/dot_prod) to score the similarity (default: cosine).
        cutoffs
            A list of cutoff values to evaluate ranking. If None, ``[1, 5, 10]`` is used.
        label
            The label column name in data. Some tasks, e.g., image<-->text matching, have no label column in training data,
            but the label column may be still required in evaluation.
        return_pred
            Whether to return the prediction result of each row.
        realtime
            Whether to do realtime inference, which is efficient for small data (default False).
            If provided None, we would infer it based on the data modalities
            and sample number.
        eval_tool
            The eval_tool for object detection. Could be "pycocotools" or "torchmetrics".

        Returns
        -------
        A dictionary with the metric names and their corresponding scores.
        Optionally return a pd.DataFrame of prediction results.
        """
        if cutoffs is None:
            # Build the default per call so no list object is shared across calls.
            cutoffs = [1, 5, 10]
        return self._learner.evaluate(
            data=data,
            metrics=metrics,
            return_pred=return_pred,
            realtime=realtime,
            eval_tool=eval_tool,
            query_data=query_data,
            response_data=response_data,
            id_mappings=id_mappings,
            chunk_size=chunk_size,
            similarity_type=similarity_type,
            cutoffs=cutoffs,
            label=label,
        )

    def predict(
        self,
        data: Union[pd.DataFrame, dict, list, str],
        candidate_data: Optional[Union[pd.DataFrame, dict, list]] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        as_pandas: Optional[bool] = None,
        realtime: Optional[bool] = False,
        save_results: Optional[bool] = None,
        **kwargs,
    ):
        """
        Predict the label column values for new data.

        Parameters
        ----------
        data
            The data to make predictions for. Should contain same column names as training data and
            follow same format (except for the `label` column).
        candidate_data
            The candidate data from which to search the query data's matches.
        id_mappings
            Id-to-content mappings. The contents can be text, image, etc.
            This is used when data contain the query/response identifiers instead of their contents.
        as_pandas
            Whether to return the output as a pandas DataFrame(Series) (True) or numpy array (False).
        realtime
            Whether to do realtime inference, which is efficient for small data (default False).
            If provided None, we would infer it based on the data modalities
            and sample number.
        save_results
            Whether to save the prediction results (only works for detection now)
        **kwargs
            Additional keyword arguments to pass to the underlying learner's predict method.
            For example, `as_coco` for object detection tasks.

        Returns
        -------
        Array of predictions, one corresponding to each row in given dataset.
        Format depends on the specific learner and provided arguments.
        """
        return self._learner.predict(
            data=data,
            candidate_data=candidate_data,
            as_pandas=as_pandas,
            realtime=realtime,
            save_results=save_results,
            id_mappings=id_mappings,
            **kwargs,
        )

    def predict_proba(
        self,
        data: Union[pd.DataFrame, dict, list],
        candidate_data: Optional[Union[pd.DataFrame, dict, list]] = None,
        id_mappings: Optional[Union[Dict[str, Dict], Dict[str, pd.Series]]] = None,
        as_pandas: Optional[bool] = None,
        as_multiclass: Optional[bool] = True,
        realtime: Optional[bool] = False,
    ):
        """
        Predict class probabilities rather than class labels.
        Note that this is only for the classification tasks.
        Calling it for a regression task will throw an exception.

        Parameters
        ----------
        data
            The data to make predictions for. Should contain same column names as training data and
            follow same format (except for the `label` column).
        candidate_data
            The candidate data from which to search the query data's matches.
        id_mappings
            Id-to-content mappings. The contents can be text, image, etc.
            This is used when data contain the query/response identifiers instead of their contents.
        as_pandas
            Whether to return the output as a pandas DataFrame(Series) (True) or numpy array (False).
        as_multiclass
            Whether to return the probability of all labels or
            just return the probability of the positive class for binary classification problems.
        realtime
            Whether to do realtime inference, which is efficient for small data (default False).
            If provided None, we would infer it based on the data modalities
            and sample number.

        Returns
        -------
        Array of predicted class-probabilities, corresponding to each row in the given data.
        When as_multiclass is True, the output will always have shape (#samples, #classes).
        Otherwise, the output will have shape (#samples,)
        """
        return self._learner.predict_proba(
            data=data,
            candidate_data=candidate_data,
            as_pandas=as_pandas,
            as_multiclass=as_multiclass,
            realtime=realtime,
            id_mappings=id_mappings,
        )

    def save(self, path: str, standalone: Optional[bool] = True):
        """
        Save this predictor to file in directory specified by `path`.

        Parameters
        ----------
        path
            The directory to save this predictor.
        standalone
            Whether to save the downloaded model for offline deployment.
            When standalone = True, save the transformers.CLIPModel and transformers.AutoModel to os.path.join(path,model_name),
            and reset the associate model.model_name.checkpoint_name start with `local://` in config.yaml.
            When standalone = False, the saved artifact may require an online environment to process in load().
        """
        self._learner.save(path=path, standalone=standalone)

    @classmethod
    def load(
        cls,
        path: str,
        resume: Optional[bool] = False,
        verbosity: Optional[int] = 3,
    ):
        """
        Load a predictor object from a directory specified by `path`. The to-be-loaded predictor
        can be completely or partially trained by .fit(). If a previous training has completed,
        it will load the checkpoint `model.ckpt`. Otherwise, if a previous training accidentally
        collapses in the middle, it can load the `last.ckpt` checkpoint by setting `resume=True`.
        It also supports loading one specific checkpoint given its path.

        .. warning::

            :meth:`autogluon.multimodal.MultiModalPredictor.load` uses `pickle` module implicitly, which is known to
            be insecure. It is possible to construct malicious pickle data which will execute arbitrary code during
            unpickling. Never load data that could have come from an untrusted source, or that could have been tampered
            with. **Only load data you trust.**

        Parameters
        ----------
        path
            The directory to load the predictor object.
        resume
            Whether to resume training from `last.ckpt`. This is useful when a training was accidentally
            broken during the middle, and we want to resume the training from the last saved checkpoint.
        verbosity
            Verbosity levels range from 0 to 4 and control how much information is printed.
            Higher levels correspond to more detailed print statements.
            You can set verbosity = 0 to suppress warnings.

        Returns
        -------
        The loaded predictor object.
        """
        dir_path, _ = get_dir_ckpt_paths(path=path)
        assert os.path.isdir(dir_path), f"'{dir_path}' must be an existing directory."
        # Placeholder predictor whose learner is replaced below. Passing `verbosity`
        # keeps `predictor.verbosity` consistent with the level used for loading.
        predictor = cls(label="dummy_label", verbosity=verbosity)

        with open(os.path.join(dir_path, "assets.json"), "r", encoding="utf-8") as fp:
            assets = json.load(fp)

        # Matchers are recorded by class name; other learners by problem type.
        if assets.get("class_name") == "MultiModalMatcher":
            learner_class = MultiModalMatcher
        else:
            learner_class = cls._get_learner_class(assets["problem_type"])

        predictor._learner = learner_class.load(path=path, resume=resume, verbosity=verbosity)
        return predictor

    def dump_model(self, save_path: Optional[str] = None):
        """
        Save model weights and config to a local directory.
        Model weights are saved in the file `pytorch_model.bin` (for `timm_image` or `hf_text`)
        or '<ckpt_name>.pth' (for `mmdet_image`).
        Configs are saved in the file `config.json` (for `timm_image` or `hf_text`)
        or '<ckpt_name>.py' (for `mmdet_image`).

        Parameters
        ----------
        save_path : str
            Path to directory where models and configs should be saved.
        """
        return self._learner.dump_model(save_path=save_path)

    def export_onnx(
        self,
        data: Union[dict, pd.DataFrame],
        path: Optional[str] = None,
        batch_size: Optional[int] = None,
        verbose: Optional[bool] = False,
        opset_version: Optional[int] = 16,
        truncate_long_and_double: Optional[bool] = False,
    ):
        """
        Export this predictor's model to an ONNX file.

        When `path` argument is not provided, the method would not save the model into disk.
        Instead, it would export the onnx model into BytesIO and return its binary as bytes.

        Parameters
        ----------
        data
            Raw data used to trace and export the model.
            If this is None, will check if a processed batch is provided.
        path : str, default=None
            The export path of onnx model. If path is not provided, the method would export model to memory.
        batch_size
            The batch_size of export model's input.
            Normally the batch_size is a dynamic axis, so we could use a small value for faster export.
        verbose
            verbose flag in torch.onnx.export.
        opset_version
            opset_version flag in torch.onnx.export.
        truncate_long_and_double: bool, default False
            Truncate weights provided in int64 or double (float64) to int32 and float32

        Returns
        -------
        onnx_path : str or bytes
            A string that indicates location of the exported onnx model, if `path` argument is provided.
            Otherwise, would return the onnx model as bytes.
        """
        # Make sure _model is initialized
        self._learner.on_predict_start()
        return self._learner.export_onnx(
            data=data,
            path=path,
            batch_size=batch_size,
            verbose=verbose,
            opset_version=opset_version,
            truncate_long_and_double=truncate_long_and_double,
        )

    def optimize_for_inference(
        self,
        providers: Optional[Union[dict, List[str]]] = None,
    ):
        """
        Optimize the predictor's model for inference.

        Under the hood, the implementation would convert the PyTorch module into an ONNX module, so that
        we can leverage efficient execution providers in onnxruntime for faster inference.

        Parameters
        ----------
        providers : dict or list of str, default=None
            A list of execution providers for model prediction in onnxruntime.

            By default, the providers argument is None. The method would generate an ONNX module that
            would perform model inference with TensorrtExecutionProvider in onnxruntime, if tensorrt
            package is properly installed. Otherwise, the onnxruntime would fallback to use CUDA or CPU
            execution providers instead.

        Returns
        -------
        onnx_module : OnnxModule
            The onnx-based module that can be used to replace predictor._model for model inference.
        """
        return self._learner.optimize_for_inference(providers=providers)

    def fit_summary(self, verbosity=0, show_plot=False):
        """
        Output the training summary information from `fit()`.

        Parameters
        ----------
        verbosity : int, default = 0
            Verbosity levels range from 0 to 4 and control how much information is printed.
            verbosity = 0 for no output printing.
            TODO: Higher levels correspond to more detailed print statements
        show_plot : bool, default = False
            If True, shows the model summary plot in browser when verbosity > 1.

        Returns
        -------
        Dict containing various detailed information.
        We do not recommend directly printing this dict as it may be very large.
        """
        return self._learner.fit_summary(verbosity=verbosity, show_plot=show_plot)

    def list_supported_models(self, pretrained=True):
        """
        List supported models for each problem type.

        Parameters
        ----------
        pretrained : bool, default = True
            If True, only return the models with pretrained weights.
            If False, return all the models as long as there is model definition.

        Returns
        -------
        a list of model names
        """
        return self._learner.list_supported_models(pretrained=pretrained)