AutoMM for Chinese Named Entity Recognition
In this tutorial, we demonstrate how to use AutoMM for Chinese Named Entity Recognition (NER) on an e-commerce dataset extracted from Taobao.com, one of the most popular online marketplaces. The dataset was collected and labelled by Jie et al., and its text column consists mainly of product descriptions. The following figure shows an example of a Taobao product description.

Fig. 1 Taobao product description: a rabbit toy for Lunar New Year decoration.
Load the Data
We have preprocessed the dataset so that it is ready to use with AutoMM.
import autogluon.multimodal
from autogluon.core.utils.loaders import load_pd
from autogluon.multimodal.utils import visualize_ner
train_data = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/ner/taobao-ner/chinese_ner_train.csv')
dev_data = load_pd.load('https://automl-mm-bench.s3.amazonaws.com/ner/taobao-ner/chinese_ner_dev.csv')
train_data.head(5)
 | text_snippet | entity_annotations |
---|---|---|
0 | 雄争霸点卡/七雄争霸元宝/七雄争霸100元1000元宝直充,自动充值 | [{"entity_group": "HCCX", "start": 3, "end": 5... |
1 | 简约韩版粗跟艾熙百思图亲子鞋冬季百搭街头母女圆头翻边绒面厚底 | [{"entity_group": "HPPX", "start": 6, "end": 8... |
2 | 羚跑商务背包双肩包男士防盗多功能出差韩版休闲15.6寸电脑包皮潮 | [{"entity_group": "HPPX", "start": 0, "end": 2... |
3 | 热水袋防爆充电暖宝卡通毛绒萌萌可爱注水暖宫暖手宝暖水袋 | [{"entity_group": "HCCX", "start": 0, "end": 3... |
4 | 童装11周岁13儿童夏装男童套装2017新款10中大童15男孩12秋季5潮7 | [{"entity_group": "HCCX", "start": 0, "end": 2... |
HPPX, HCCX, XH, and MISC stand for brand, product, pattern, and miscellaneous information (e.g., product specification), respectively. Let’s visualize one of the examples, which is about top-up services for online games.
visualize_ner(train_data["text_snippet"].iloc[0], train_data["entity_annotations"].iloc[0])
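To make the annotation format concrete, here is a minimal sketch that turns one label into (type, text) pairs without any AutoMM dependency. It assumes the `entity_annotations` column holds a JSON string and that `start`/`end` are character offsets with an exclusive `end`, which matches the visible rows (for example, offsets 0 to 3 in row 3 cover 热水袋). The helper name `extract_entities` is hypothetical, and the input below uses only the part of row 0 that is visible in the truncated table.

```python
import json


def extract_entities(text, annotations_json):
    """Return (entity_group, surface_text) pairs for each annotated span.

    Assumes `annotations_json` is a JSON list of dicts with the keys
    "entity_group", "start", and "end" (end-exclusive character offsets).
    """
    spans = json.loads(annotations_json)
    return [(s["entity_group"], text[s["start"]:s["end"]]) for s in spans]


# Illustrative input: the beginning of training row 0 and the single
# annotation that is visible before the table truncates the label column.
text = "雄争霸点卡/七雄争霸元宝"
annotations = '[{"entity_group": "HCCX", "start": 3, "end": 5}]'

print(extract_entities(text, annotations))
```

With the real data, the same helper could be applied row by row, e.g. to `train_data["text_snippet"].iloc[0]` and `train_data["entity_annotations"].iloc[0]`.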
Training
With AutoMM, the process for Chinese entity recognition is the same as for English entity recognition. All you need to do is select a suitable foundation model checkpoint that is pretrained on Chinese or multilingual documents. Here we use the 'hfl/chinese-lert-small' backbone for demonstration purposes.
Now, let’s create a predictor for named entity recognition by setting the problem_type to ner and specifying the label column. Afterwards, we call predictor.fit() to train the model for a few minutes.
from autogluon.multimodal import MultiModalPredictor
import uuid
label_col = "entity_annotations"
model_path = f"./tmp/{uuid.uuid4().hex}-automm_ner" # You can rename it to the model path you like
predictor = MultiModalPredictor(problem_type="ner", label=label_col, path=model_path)
predictor.fit(
    train_data=train_data,
    hyperparameters={'model.ner_text.checkpoint_name': 'hfl/chinese-lert-small'},
    time_limit=300,  # training time budget in seconds
)
Global seed set to 123
Auto select gpus: [0]
Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type              | Params
--------------------------------------------------------
0 | model             | HFAutoModelForNER | 15.1 M
1 | validation_metric | F1Score           | 0
2 | loss_func         | CrossEntropyLoss  | 0
--------------------------------------------------------
15.1 M    Trainable params
0         Non-trainable params
15.1 M    Total params
30.173    Total estimated model params size (MB)
Epoch 0, global step 21: 'val_ner_token_f1' reached 0.25100 (best 0.25100), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=0-step=21.ckpt' as top 3
/home/ci/opt/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py:33: LightningDeprecationWarning: pytorch_lightning.utilities.cloud_io.get_filesystem has been deprecated in v1.8.0 and will be removed in v1.10.0. Please use lightning_lite.utilities.cloud_io.get_filesystem instead.
  rank_zero_deprecation(
/home/ci/opt/venv/lib/python3.8/site-packages/pytorch_lightning/utilities/cloud_io.py:25: LightningDeprecationWarning: pytorch_lightning.utilities.cloud_io.atomic_save has been deprecated in v1.8.0 and will be removed in v1.10.0. This function is internal but you can copy over its implementation.
  rank_zero_deprecation(
Epoch 0, global step 42: 'val_ner_token_f1' reached 0.65349 (best 0.65349), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=0-step=42.ckpt' as top 3
Epoch 1, global step 64: 'val_ner_token_f1' reached 0.73087 (best 0.73087), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=1-step=64.ckpt' as top 3
Epoch 1, global step 85: 'val_ner_token_f1' reached 0.75412 (best 0.75412), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=1-step=85.ckpt' as top 3
Epoch 2, global step 107: 'val_ner_token_f1' reached 0.78199 (best 0.78199), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=2-step=107.ckpt' as top 3
Epoch 2, global step 128: 'val_ner_token_f1' reached 0.79217 (best 0.79217), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=2-step=128.ckpt' as top 3
Epoch 3, global step 150: 'val_ner_token_f1' reached 0.79538 (best 0.79538), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=3-step=150.ckpt' as top 3
Epoch 3, global step 171: 'val_ner_token_f1' reached 0.81451 (best 0.81451), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=3-step=171.ckpt' as top 3
Epoch 4, global step 193: 'val_ner_token_f1' reached 0.82358 (best 0.82358), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=4-step=193.ckpt' as top 3
Time limit reached. Elapsed time is 0:05:02. Signaling Trainer to stop.
Epoch 4, global step 193: 'val_ner_token_f1' reached 0.82358 (best 0.82358), saving model to '/home/ci/autogluon/docs/_build/eval/tutorials/multimodal/text_prediction/tmp/0626dec24710486f92dd9359c27fa896-automm_ner/epoch=4-step=193-v1.ckpt' as top 3
Global seed set to 123
<autogluon.multimodal.predictor.MultiModalPredictor at 0x7f99050de970>
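Because the backbone is chosen through a single hyperparameter, trying a different checkpoint only requires a different dictionary. The sketch below builds one dictionary per candidate; the helper `ner_hyperparameters` and the candidate list are hypothetical, and 'xlm-roberta-base' is included only as an example of a multilingual Hugging Face checkpoint (an assumption, it has not been validated on this dataset or within the 300-second budget).

```python
def ner_hyperparameters(checkpoint_name):
    """Build the AutoMM hyperparameter dict that selects the NER backbone.

    Only the key used in this tutorial is set; everything else stays at
    AutoMM's defaults.
    """
    return {"model.ner_text.checkpoint_name": checkpoint_name}


# Hypothetical candidate list: the tutorial's Chinese checkpoint plus one
# multilingual checkpoint chosen purely for illustration.
CANDIDATE_CHECKPOINTS = ["hfl/chinese-lert-small", "xlm-roberta-base"]

configs = {name: ner_hyperparameters(name) for name in CANDIDATE_CHECKPOINTS}
for name, hyperparameters in configs.items():
    print(name, "->", hyperparameters)
```

Each dictionary can then be passed as the `hyperparameters` argument of `predictor.fit(...)`, using a fresh `MultiModalPredictor` and `path` per run so that checkpoints do not overwrite each other.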
Evaluation
To check the model's performance on the held-out development set, all you need to do is call predictor.evaluate(...).
predictor.evaluate(dev_data)
Global seed set to 123
{'hccx': {'precision': 0.7382003395585739,
'recall': 0.8522148177185418,
'f1': 0.7911208151382825,
'number': 2551},
'hppx': {'precision': 0.5314685314685315,
'recall': 0.5467625899280576,
'f1': 0.5390070921985817,
'number': 278},
'misc': {'precision': 0.5620915032679739,
'recall': 0.6825396825396826,
'f1': 0.6164874551971327,
'number': 504},
'xh': {'precision': 0.5677655677655677,
'recall': 0.6512605042016807,
'f1': 0.6066536203522505,
'number': 238},
'overall_precision': 0.6863459669582118,
'overall_recall': 0.791094931391767,
'overall_f1': 0.7350071549369064,
'overall_accuracy': 0.8567792529407718}
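The reported numbers can be sanity-checked with a little arithmetic. The sketch below copies the scores from the output above and checks two relationships: each per-class F1 is the harmonic mean of its precision and recall, and `overall_recall` equals the per-class recalls weighted by `number` (the count of gold entities), which suggests that the overall scores are micro-averaged. That last reading is an inference from the numbers, not something the output states.

```python
# Scores copied verbatim from the predictor.evaluate(dev_data) output.
scores = {
    "hccx": {"precision": 0.7382003395585739, "recall": 0.8522148177185418,
             "f1": 0.7911208151382825, "number": 2551},
    "hppx": {"precision": 0.5314685314685315, "recall": 0.5467625899280576,
             "f1": 0.5390070921985817, "number": 278},
    "misc": {"precision": 0.5620915032679739, "recall": 0.6825396825396826,
             "f1": 0.6164874551971327, "number": 504},
    "xh": {"precision": 0.5677655677655677, "recall": 0.6512605042016807,
           "f1": 0.6066536203522505, "number": 238},
}
overall_recall = 0.791094931391767


def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    total = precision + recall
    return 2 * precision * recall / total if total else 0.0


# Difference between the recomputed and the reported F1 for each class.
f1_gaps = {
    label: abs(f1_score(s["precision"], s["recall"]) - s["f1"])
    for label, s in scores.items()
}

# Recall weighted by the number of gold entities per class.
total_gold = sum(s["number"] for s in scores.values())
micro_recall = sum(s["recall"] * s["number"] for s in scores.values()) / total_gold

print("max F1 gap:", max(f1_gaps.values()))
print("micro recall gap:", abs(micro_recall - overall_recall))
```

Both gaps come out at floating-point noise level. One practical consequence is that the overall scores are dominated by HCCX, which accounts for 2551 of the 3571 gold entities, so the weaker HPPX and XH classes are easy to overlook.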
Prediction and Visualization
You can easily obtain predictions for input sentences by calling predictor.predict(...).
output = predictor.predict(dev_data)
visualize_ner(dev_data["text_snippet"].iloc[0], output[0])
Global seed set to 123
Now, let’s make predictions on the rabbit toy example.
sentence = "2023年兔年挂件新年装饰品小挂饰乔迁之喜门挂小兔子"
predictions = predictor.predict({'text_snippet': [sentence]})
visualize_ner(sentence, predictions[0])
Global seed set to 123
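Besides visualizing them, predictions can be post-processed as plain Python data. The sketch below groups spans by entity type. It assumes each prediction is a list of dicts with `entity_group`, `start`, and `end` keys, mirroring the label format of the training data; the helper `group_by_type` is hypothetical, and the spans are written by hand for illustration, so they are not real model output.

```python
from collections import defaultdict


def group_by_type(text, spans):
    """Map each entity type to the list of surface strings found in `text`.

    Assumes end-exclusive character offsets, as in the training labels.
    """
    grouped = defaultdict(list)
    for span in spans:
        grouped[span["entity_group"]].append(text[span["start"]:span["end"]])
    return dict(grouped)


sentence = "2023年兔年挂件新年装饰品小挂饰乔迁之喜门挂小兔子"

# Hypothetical spans written by hand for illustration; NOT real model output.
example_spans = [
    {"entity_group": "HCCX", "start": 7, "end": 9},
    {"entity_group": "HCCX", "start": 11, "end": 14},
    {"entity_group": "MISC", "start": 0, "end": 5},
]

print(group_by_type(sentence, example_spans))
```

With a trained predictor, `predictions[0]` would take the place of `example_spans`, provided the output follows the assumed format.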
Other Examples
You may go to AutoMM Examples to explore other examples about AutoMM.
Customization
To learn how to customize AutoMM, please refer to Customize AutoMM.