Skip to content

Commit b914827

Browse files
authored
【开源实习】MarkupLM模型迁移 (#1404)
1 parent cfdb7ff commit b914827

11 files changed

+4458
-0
lines changed

mindnlp/transformers/models/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@
119119
lxmert,
120120
mamba,
121121
marian,
122+
markuplm,
122123
m2m_100,
123124
mask2former,
124125
mbart,
@@ -320,6 +321,7 @@
320321
from .lxmert import *
321322
from .mamba import *
322323
from .marian import *
324+
from .markuplm import *
323325
from .mask2former import *
324326
from .mbart import *
325327
from .mbart50 import *
@@ -520,6 +522,7 @@
520522
__all__.extend(lxmert.__all__)
521523
__all__.extend(mamba.__all__)
522524
__all__.extend(marian.__all__)
525+
__all__.extend(markuplm.__all__)
523526
__all__.extend(mask2former.__all__)
524527
__all__.extend(mbart.__all__)
525528
__all__.extend(mbart50.__all__)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2023 Huawei Technologies Co., Ltd
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
"""
16+
MarkupLM Model.
17+
"""
18+
from . import configuration_markuplm, modeling_markuplm, tokenization_markuplm, tokenization_markuplm_fast
19+
from . import feature_extraction_markuplm,processing_markuplm
20+
from .modeling_markuplm import *
21+
from .configuration_markuplm import *
22+
from .tokenization_markuplm import *
23+
from .tokenization_markuplm_fast import *
24+
from .feature_extraction_markuplm import *
25+
from .processing_markuplm import *
26+
27+
# Public API of the MarkupLM package: the concatenation of every submodule's
# own `__all__`, in the same order the submodules are star-imported above.
__all__ = [
    *modeling_markuplm.__all__,
    *configuration_markuplm.__all__,
    *tokenization_markuplm.__all__,
    *tokenization_markuplm_fast.__all__,
    *feature_extraction_markuplm.__all__,
    *processing_markuplm.__all__,
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
# coding=utf-8
2+
# Copyright 2021, The Microsoft Research Asia MarkupLM Team authors
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""MarkupLM model configuration"""
16+
17+
from mindnlp.utils import logging
18+
from ...configuration_utils import PretrainedConfig
19+
20+
21+
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
22+
23+
24+
class MarkupLMConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MarkupLMModel`]. It is used to instantiate a
    MarkupLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the MarkupLM
    [microsoft/markuplm-base](https://huggingface.co/microsoft/markuplm-base) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the MarkupLM model. Defines the different tokens that can be represented by the
            *inputs_ids* passed to the forward method of [`MarkupLMModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed into [`MarkupLMModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Id of the padding token. Forwarded to `PretrainedConfig.__init__`.
        bos_token_id (`int`, *optional*, defaults to 0):
            Id of the beginning-of-sequence token. Forwarded to `PretrainedConfig.__init__`.
        eos_token_id (`int`, *optional*, defaults to 2):
            Id of the end-of-sequence token. Forwarded to `PretrainedConfig.__init__`.
        max_xpath_tag_unit_embeddings (`int`, *optional*, defaults to 256):
            The maximum value that the xpath tag unit embedding might ever use. Typically set this to something large
            just in case (e.g., 256).
        max_xpath_subs_unit_embeddings (`int`, *optional*, defaults to 1024):
            The maximum value that the xpath subscript unit embedding might ever use. Typically set this to something
            large just in case (e.g., 1024).
        tag_pad_id (`int`, *optional*, defaults to 216):
            The id of the padding token in the xpath tags.
        subs_pad_id (`int`, *optional*, defaults to 1001):
            The id of the padding token in the xpath subscripts.
        xpath_unit_hidden_size (`int`, *optional*, defaults to 32):
            The hidden size of each tree id unit. One complete tree index will have
            (50*xpath_unit_hidden_size)-dim.
        max_depth (`int`, *optional*, defaults to 50):
            The maximum depth in xpath.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Stored on the configuration unchanged. Presumably selects the kind of position embedding used by the
            model; the accepted values are defined by the modeling code, which is not part of this module.
        use_cache (`bool`, *optional*, defaults to `True`):
            Stored on the configuration unchanged. Presumably controls whether the model returns cached key/value
            states; confirm against the modeling code.
        classifier_dropout (`float`, *optional*):
            Stored on the configuration unchanged. Presumably the dropout ratio of the classification head;
            confirm against the modeling code.
        **kwargs:
            Any additional keyword arguments are forwarded to `PretrainedConfig.__init__`.

    Examples:

    ```python
    >>> from mindnlp.transformers.models import MarkupLMModel, MarkupLMConfig

    >>> # Initializing a MarkupLM microsoft/markuplm-base style configuration
    >>> configuration = MarkupLMConfig()

    >>> # Initializing a model from the microsoft/markuplm-base style configuration
    >>> model = MarkupLMModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    # Identifier string for this model family.
    model_type = "markuplm"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        bos_token_id=0,
        eos_token_id=2,
        max_xpath_tag_unit_embeddings=256,
        max_xpath_subs_unit_embeddings=1024,
        tag_pad_id=216,
        subs_pad_id=1001,
        xpath_unit_hidden_size=32,
        max_depth=50,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        **kwargs,
    ):
        # The special-token ids and any unrecognised keyword arguments are
        # handled by the base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Transformer-encoder hyper-parameters.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        # MarkupLM-specific properties describing the xpath embeddings.
        self.max_depth = max_depth
        self.max_xpath_tag_unit_embeddings = max_xpath_tag_unit_embeddings
        self.max_xpath_subs_unit_embeddings = max_xpath_subs_unit_embeddings
        self.tag_pad_id = tag_pad_id
        self.subs_pad_id = subs_pad_id
        self.xpath_unit_hidden_size = xpath_unit_hidden_size
155+
# Names exported by `from .configuration_markuplm import *`.
__all__ = ['MarkupLMConfig']
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
# coding=utf-8
2+
# Copyright 2022 The HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
Feature extractor class for MarkupLM.
17+
"""
18+
19+
import html
20+
21+
import bs4
22+
from bs4 import BeautifulSoup
23+
from mindnlp.utils import logging
24+
from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
25+
# from mindnlp.utils import is_bs4_available, logging, requires_backends
26+
# if is_bs4_available():
27+
# import bs4
28+
# from bs4 import BeautifulSoup
29+
30+
# Module-level logger, named after this module via `__name__`.
logger = logging.get_logger(__name__)
31+
32+
33+
class MarkupLMFeatureExtractor(FeatureExtractionMixin):
    r"""
    Constructs a MarkupLM feature extractor. This can be used to get a list of nodes and corresponding xpaths from HTML
    strings.

    This feature extractor inherits from [`FeatureExtractionMixin`]. Users should refer to this superclass for more
    information regarding the methods it provides.

    Parsing relies on the `bs4` (Beautiful Soup) package, which this module imports unconditionally.
    """

    def xpath_soup(self, element):
        """
        Compute the xpath of a parsed node as two parallel lists.

        Args:
            element: A Beautiful Soup node. When it has no tag name (e.g. a text node), the xpath of its parent tag
                is computed instead.

        Returns:
            `Tuple[List[str], List[int]]`: The tag names and the subscripts of every step of the path, ordered from
            the outermost tag to the innermost one. A subscript is the 1-based position of the tag among the direct
            children of its parent that share the same tag name, or `0` when it is the only such child.
        """
        xpath_tags = []
        xpath_subscripts = []
        child = element if element.name else element.parent
        # Walk upwards; each `parent` is a bs4.element.Tag (or the document root).
        for parent in child.parents:
            siblings = parent.find_all(child.name, recursive=False)
            xpath_tags.append(child.name)
            xpath_subscripts.append(
                0 if len(siblings) == 1 else next(i for i, s in enumerate(siblings, 1) if s is child)
            )
            child = parent
        # The walk collected the steps innermost-first; an xpath reads outermost-first.
        xpath_tags.reverse()
        xpath_subscripts.reverse()
        return xpath_tags, xpath_subscripts

    def get_three_from_single(self, html_string):
        """
        Extract the text nodes of one HTML document together with their xpaths.

        Args:
            html_string: The HTML document, parsed with Beautiful Soup's `"html.parser"`.

        Returns:
            `Tuple[List[str], List[List[str]], List[List[int]]]`: Three parallel lists holding, for every non-empty
            text node, its unescaped and stripped text, its xpath tag names and its xpath subscripts.

        Raises:
            ValueError: If the three result lists end up with different lengths.
        """
        html_code = BeautifulSoup(html_string, "html.parser")

        all_doc_strings = []
        string2xtag_seq = []
        string2xsubs_seq = []

        for element in html_code.descendants:
            if not isinstance(element, bs4.element.NavigableString):
                continue
            # Exact type comparison on purpose: `isinstance` would also accept subclasses of Tag
            # (in bs4 the document root object is one), which the original check rejects.
            if type(element.parent) is not bs4.element.Tag:
                continue

            text_in_this_tag = html.unescape(element).strip()
            if not text_in_this_tag:
                continue

            xpath_tags, xpath_subscripts = self.xpath_soup(element)
            all_doc_strings.append(text_in_this_tag)
            string2xtag_seq.append(xpath_tags)
            string2xsubs_seq.append(xpath_subscripts)

        # Sanity checks: the three lists must stay aligned element by element.
        if len(all_doc_strings) != len(string2xtag_seq):
            raise ValueError("Number of doc strings and xtags does not correspond")
        if len(all_doc_strings) != len(string2xsubs_seq):
            raise ValueError("Number of doc strings and xsubs does not correspond")

        return all_doc_strings, string2xtag_seq, string2xsubs_seq

    def construct_xpath(self, xpath_tags, xpath_subscripts):
        """
        Render parallel lists of tag names and subscripts as an xpath string.

        Args:
            xpath_tags: Tag names, outermost first.
            xpath_subscripts: Subscript of each tag; `0` means "no subscript".

        Returns:
            `str`: The xpath, e.g. `"/html/body/div[2]"`. Extra items in the longer list are ignored.
        """
        return "".join(
            f"/{tagname}[{subs}]" if subs != 0 else f"/{tagname}"
            for tagname, subs in zip(xpath_tags, xpath_subscripts)
        )

    def __call__(self, html_strings) -> BatchFeature:
        """
        Main method to prepare for the model one or several HTML strings.

        Args:
            html_strings (`str`, `List[str]`):
                The HTML string or batch of HTML strings from which to extract nodes and corresponding xpaths. An
                empty list or tuple is accepted and yields empty `nodes` and `xpaths`.

        Returns:
            [`BatchFeature`]: A [`BatchFeature`] with the following fields:

            - **nodes** -- Nodes.
            - **xpaths** -- Corresponding xpaths.

        Raises:
            ValueError: If `html_strings` is neither a `str` nor a list/tuple that is empty or starts with a `str`.

        Examples:

        ```python
        >>> from mindnlp.transformers.models import MarkupLMFeatureExtractor

        >>> page_name_1 = "page1.html"
        >>> page_name_2 = "page2.html"
        >>> page_name_3 = "page3.html"

        >>> with open(page_name_1) as f:
        ...     single_html_string = f.read()

        >>> feature_extractor = MarkupLMFeatureExtractor()

        >>> # single example
        >>> encoding = feature_extractor(single_html_string)
        >>> print(encoding.keys())
        >>> # dict_keys(['nodes', 'xpaths'])

        >>> # batched example

        >>> multi_html_strings = []

        >>> with open(page_name_2) as f:
        ...     multi_html_strings.append(f.read())
        >>> with open(page_name_3) as f:
        ...     multi_html_strings.append(f.read())

        >>> encoding = feature_extractor(multi_html_strings)
        >>> print(encoding.keys())
        >>> # dict_keys(['nodes', 'xpaths'])
        ```"""
        # Input type checking for clearer error. Only the first item of a batch is inspected.
        is_single = isinstance(html_strings, str)
        is_batched = isinstance(html_strings, (list, tuple)) and (
            len(html_strings) == 0 or isinstance(html_strings[0], str)
        )

        if not (is_single or is_batched):
            raise ValueError(
                "HTML strings must of type `str`, `List[str]` (batch of examples), "
                f"but is of type {type(html_strings)}."
            )

        # An empty batch stays a batch; it must not be probed with `html_strings[0]`,
        # which previously raised IndexError after validation had accepted it.
        if is_single:
            html_strings = [html_strings]

        # Get nodes + xpaths
        nodes = []
        xpaths = []
        for html_string in html_strings:
            all_doc_strings, string2xtag_seq, string2xsubs_seq = self.get_three_from_single(html_string)
            nodes.append(all_doc_strings)
            xpaths.append(
                [
                    self.construct_xpath(tag_list, sub_list)
                    for tag_list, sub_list in zip(string2xtag_seq, string2xsubs_seq)
                ]
            )

        # return as Dict
        data = {"nodes": nodes, "xpaths": xpaths}
        return BatchFeature(data=data, tensor_type=None)
183+
184+
# Names exported by `from .feature_extraction_markuplm import *`.
__all__ = ["MarkupLMFeatureExtractor"]

0 commit comments

Comments
 (0)