diff --git a/configs/cls/mobilenetv3/cls_mv3.yaml b/configs/cls/mobilenetv3/cls_mv3.yaml
index 2d5d036f0..4ca3b2eb2 100644
--- a/configs/cls/mobilenetv3/cls_mv3.yaml
+++ b/configs/cls/mobilenetv3/cls_mv3.yaml
@@ -142,3 +142,47 @@ eval:
     drop_remainder: False
     max_rowsize: 12
     num_workers: 8
+
+predict:
+  backend: MindSpore
+  device_target: Ascend
+  device_id: 1
+  max_device_memory: 8GB
+  amp_level: O0
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/cls_mobilenetv3-92db9c58.ckpt
+  dataset_sink_mode: False
+  dataset:
+    type: RecDataset
+    dataset_root: dir/to/dataset
+    data_dir: all_images
+    label_file: val_cls_gt.txt
+    sample_ratio: 1.0
+    shuffle: False
+    transform_pipeline:
+      - DecodeImage:
+          img_mode: BGR
+          to_float32: False
+      - Rotate90IfVertical:
+          threshold: 2.0
+          direction: counterclockwise
+      - RecResizeImg:
+          image_shape: [48, 192] # H, W
+          padding: False # aspect ratio will be preserved if true.
+      - NormalizeImage:
+          bgr_to_rgb: True
+          is_hwc: True
+          mean: [127.0, 127.0, 127.0]
+          std: [127.0, 127.0, 127.0]
+      - ToCHWImage:
+    # the order of the dataloader list, matching the network inputs, the labels for the loss function, and optional data for debugging/visualization
+    output_columns: ['image', 'label'] # TODO: return text string padded to a fixed length, and a scalar to indicate the length
+    net_input_column_index: [0] # input indices for network forward func in output_columns
+    label_column_index: [1] # input indices marked as label
+
+  loader:
+    shuffle: False
+    batch_size: 8
+    drop_remainder: False
+    max_rowsize: 12
+    num_workers: 8
\ No newline at end of file
diff --git a/configs/det/dbnet/db_r50_icdar15.yaml b/configs/det/dbnet/db_r50_icdar15.yaml
index f6eb56b3b..adf9cbcb3 100644
--- a/configs/det/dbnet/db_r50_icdar15.yaml
+++ b/configs/det/dbnet/db_r50_icdar15.yaml
@@ -157,7 +157,13 @@ eval:
     num_workers: 2
 
 predict:
-  ckpt_load_path: tmp_det/best.ckpt
+  backend: MindSpore
+  device_target: Ascend
+  device_id: 0
+  max_device_memory: 8GB
+  amp_level: O0
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/dbnet_resnet50-c3a4aa24.ckpt
   output_save_dir: ./output
   dataset_sink_mode: False
   dataset:
diff --git a/configs/layout/yolov8/yolov8n.yaml b/configs/layout/yolov8/yolov8n.yaml
index c3e1696e0..b2279939b 100644
--- a/configs/layout/yolov8/yolov8n.yaml
+++ b/configs/layout/yolov8/yolov8n.yaml
@@ -151,3 +151,37 @@ eval:
     drop_remainder: False
     max_rowsize: 12
     num_workers: 8
+
+predict:
+  backend: MindSpore
+  device_target: Ascend
+  device_id: 3
+  max_device_memory: 8GB
+  amp_level: O0
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/yolov8n-4b9e8004.ckpt
+  dataset_sink_mode: False
+  dataset:
+    type: PublayNetDataset
+    dataset_path: publaynet/val.txt
+    annotations_path: *annotations_path
+    img_size: 800
+    transform_pipeline:
+      - func_name: letterbox
+        scaleup: False
+      - func_name: image_norm
+        scale: 255.
+      - func_name: image_transpose
+        bgr2rgb: True
+        hwc2chw: True
+    batch_size: *refine_batch_size
+    stride: 64
+    output_columns: ['image', 'labels', 'image_ids', 'hw_ori', 'hw_scale', 'pad']
+    net_input_column_index: [ 0 ] # input indices for network forward func in output_columns
+    meta_data_column_index: [ 2, 3, 4, 5 ] # input indices marked as meta data
+  loader:
+    shuffle: False
+    batch_size: *refine_batch_size
+    drop_remainder: False
+    max_rowsize: 12
+    num_workers: 8
diff --git a/configs/rec/crnn/crnn_resnet34.yaml b/configs/rec/crnn/crnn_resnet34.yaml
index 893e481f5..df8ec0b4e 100644
--- a/configs/rec/crnn/crnn_resnet34.yaml
+++ b/configs/rec/crnn/crnn_resnet34.yaml
@@ -150,7 +150,13 @@ eval:
     num_workers: 8
 
 predict:
-  ckpt_load_path: ./tmp_rec/best.ckpt
+  backend: MindSpore
+  device_target: Ascend
+  device_id: 2
+  max_device_memory: 8GB
+  amp_level: O3
+  mode: 0
+  ckpt_load_path: /root/.mindspore/models/crnn_resnet34-83f37f07.ckpt
   vis_font_path: tools/utils/simfang.ttf
   dataset_sink_mode: False
   dataset:
diff --git a/deploy/py_infer/example/dataset/layout/example1.png b/deploy/py_infer/example/dataset/layout/example1.png
new file mode 100644
index 000000000..1678f5b87
Binary files /dev/null and b/deploy/py_infer/example/dataset/layout/example1.png differ
diff --git a/deploy/py_infer/example/dataset/layout/example2.png b/deploy/py_infer/example/dataset/layout/example2.png
new file mode 100644
index 000000000..eaba31a12
Binary files /dev/null and b/deploy/py_infer/example/dataset/layout/example2.png differ
diff --git a/deploy/py_infer/example/dataset/layout/example3.png b/deploy/py_infer/example/dataset/layout/example3.png
new file mode 100644
index 000000000..6b44c7b66
Binary files /dev/null and b/deploy/py_infer/example/dataset/layout/example3.png differ
diff --git a/deploy/py_infer/src/core/model/model.py b/deploy/py_infer/src/core/model/model.py
index 9a809533f..fa23d8da1 100644
--- a/deploy/py_infer/src/core/model/model.py
+++ b/deploy/py_infer/src/core/model/model.py
@@ -106,8 +106,8 @@ def warmup(self):
         height, width = hw_list[0]
         warmup_shape = [(*other_shape, height, width)]  # Only single input
 
-        dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
-        self.model.infer(dummy_tensor)
+        # dummy_tensor = [np.random.randn(*shape).astype(dtype) for shape, dtype in zip(warmup_shape, self.input_dtype)]
+        # self.model.infer(dummy_tensor)
 
     def __del__(self):
         if hasattr(self, "model") and self.model:
diff --git a/deploy/py_infer/src/data_process/postprocess/builder.py b/deploy/py_infer/src/data_process/postprocess/builder.py
index 092f415af..ddb7892a0 100644
--- a/deploy/py_infer/src/data_process/postprocess/builder.py
+++ b/deploy/py_infer/src/data_process/postprocess/builder.py
@@ -44,6 +44,7 @@ def get_device_status():
     def _get_status():
         nonlocal status
         try:
+            ms.set_context(max_device_memory="0.01GB")
            status = ms.Tensor([0])[0:].asnumpy()[0]
         except RuntimeError:
             status = 1
diff --git a/deploy/py_infer/src/infer_args.py b/deploy/py_infer/src/infer_args.py
index fc7285939..fbc55db16 100644
--- a/deploy/py_infer/src/infer_args.py
+++ b/deploy/py_infer/src/infer_args.py
@@ -119,6 +119,9 @@ def get_args():
         "--show_log", type=str2bool, default=False, required=False, help="Whether show log when inferring."
     )
     parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
+    parser.add_argument(
+        "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
+    )
 
     args = parser.parse_args()
     setup_logger(args)
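Context for the new --is_concat switch: when it is on, DetPostNode (next hunk) resizes every detected crop to a common height and stitches them into one wide image before recognition. A minimal standalone sketch of that resize-and-stitch step, using dummy arrays rather than repository data:

import cv2
import numpy as np

# two dummy BGR crops of different sizes (H, W, C)
crops = [
    np.full((32, 100, 3), 255, dtype=np.uint8),
    np.full((48, 80, 3), 128, dtype=np.uint8),
]

max_height = max(crop.shape[0] for crop in crops)  # 48
resized = [
    cv2.resize(c, (int(c.shape[1] / c.shape[0] * max_height), max_height), interpolation=cv2.INTER_LINEAR)
    for c in crops
]
wide = np.concatenate(resized, axis=1)
print(wide.shape)  # (48, 230, 3): the 32x100 crop became 48x150, then the 48x80 crop was appended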
diff --git a/deploy/py_infer/src/parallel/module/detection/det_post_node.py b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
index fba6a5abc..18d4ce5e0 100644
--- a/deploy/py_infer/src/parallel/module/detection/det_post_node.py
+++ b/deploy/py_infer/src/parallel/module/detection/det_post_node.py
@@ -1,3 +1,4 @@
+import cv2
 import numpy as np
 
 from ....data_process.utils import cv_utils
@@ -10,12 +11,35 @@ def __init__(self, args, msg_queue):
         super(DetPostNode, self).__init__(args, msg_queue)
         self.text_detector = None
         self.task_type = self.args.task_type
+        self.is_concat = self.args.is_concat
 
     def init_self_args(self):
         self.text_detector = TextDetector(self.args)
         self.text_detector.init(preprocess=False, model=False, postprocess=True)
         super().init_self_args()
 
+    def concat_crops(self, crops: list):
+        """
+        Concatenates the list of cropped images horizontally after resizing them to have the same height.
+
+        Args:
+            crops (list): A list of cropped images represented as numpy arrays.
+
+        Returns:
+            numpy.ndarray: A horizontally concatenated image array.
+        """
+        max_height = max(crop.shape[0] for crop in crops)
+        resized_crops = []
+        for crop in crops:
+            h, w, c = crop.shape
+            new_h = max_height
+            new_w = int((w / h) * new_h)
+
+            resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+            resized_crops.append(resized_img)
+        crops_concatenated = np.concatenate(resized_crops, axis=1)
+        return crops_concatenated
+
     def process(self, input_data):
         if input_data.skip:
             self.send_to_next_module(input_data)
@@ -23,6 +47,8 @@
         data = input_data.data
         boxes = self.text_detector.postprocess(data["pred"], data["shape_list"])
+        if self.is_concat:
+            boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
 
         infer_res_list = []
         for box in boxes:
@@ -39,6 +65,8 @@
         for box in infer_res_list:
             sub_image = cv_utils.crop_box_from_image(image, np.array(box))
             sub_image_list.append(sub_image)
+        if self.is_concat:
+            sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
 
         input_data.sub_image_list = sub_image_list
         input_data.data = None
diff --git a/mindocr/data/transforms/layout_transform.py b/mindocr/data/transforms/layout_transform.py
new file mode 100644
index 000000000..1cfeb7ab3
--- /dev/null
+++ b/mindocr/data/transforms/layout_transform.py
@@ -0,0 +1,90 @@
+import cv2
+import numpy as np
+
+import os
+import sys
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../../")))
+
+from mindocr.data.layout_dataset import xyxy2xywh
+
+
+def letterbox(scaleup):
+    def func(data):
+        image = data["image"]
+        hw_ori = data["raw_img_shape"]
+        new_shape = data["target_size"]
+        color = (114, 114, 114)
+        # Resize and pad image while meeting stride-multiple constraints
+        shape = image.shape[:2]  # current shape [height, width]
+        h, w = shape[:]
+        if isinstance(new_shape, int):
+            new_shape = (new_shape, new_shape)
+        h0, w0 = new_shape
+        hw_scale = np.array([h0 / h, w0 / w])
+
+        # Scale ratio (new / old)
+        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
+        if not scaleup:  # only scale down, do not scale up (for better test mAP)
+            r = min(r, 1.0)
+
+        # Compute padding
+        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
+        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
+
+        dw, dh = dw / 2, dh / 2  # divide padding into 2 sides
+        hw_pad = np.array([dh, dw])
+
+        if shape[::-1] != new_unpad:  # resize
+            image = cv2.resize(image, new_unpad, interpolation=cv2.INTER_LINEAR)
+        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
+        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
+        image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
+
+        data["image"] = image
+        data["image_ids"] = 0
+        data["hw_ori"] = hw_ori
+        data["hw_scale"] = hw_scale
+        data["pad"] = hw_pad
+        return data
+
+    return func
+
+
+def image_norm(scale=255.0):
+    def func(data):
+        image = data["image"]
+        image = image.astype(np.float32, copy=False)
+        image /= scale
+        data["image"] = image
+        return data
+
+    return func
+
+
+def image_transpose(bgr2rgb=True, hwc2chw=True):
+    def func(data):
+        image = data["image"]
+        if bgr2rgb:
+            image = image[:, :, ::-1]
+        if hwc2chw:
+            image = image.transpose(2, 0, 1)
+        data["image"] = image
+        return data
+
+    return func
+
+def label_norm(xyxy2xywh_=True):
+    def func(data):
+        # operates on the transform dict like the other functions here: expects "labels" and "image" keys
+        labels = data["labels"]
+        if len(labels) == 0:
+            return data
+
+        if xyxy2xywh_:
+            labels[:, 1:5] = xyxy2xywh(labels[:, 1:5])  # convert xyxy to xywh
+
+        labels[:, [2, 4]] /= data["image"].shape[0]  # normalize height to 0-1
+        labels[:, [1, 3]] /= data["image"].shape[1]  # normalize width to 0-1
+        data["labels"] = labels
+        return data
+    return func
\ No newline at end of file
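The letterbox transform above resizes with a preserved aspect ratio and pads the remainder with gray (114, 114, 114), recording data["hw_scale"] and data["pad"] for the postprocess stage. A worked example of the padding arithmetic for a 792x601 page mapped to an 800x800 target (the numbers follow from the formulas above; they are not a repository fixture):

shape = (792, 601)      # (H, W) of the raw page
new_shape = (800, 800)  # target_size

r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])  # min(1.0101, 1.3311) -> 1.0101
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))  # (W, H) = (607, 800)
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # (193, 0)
dw, dh = dw / 2, dh / 2  # split over both sides -> (96.5, 0.0)
# downstream nodes then receive data["pad"] = [dh, dw] = [0.0, 96.5]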
diff --git a/mindocr/data/transforms/transforms_factory.py b/mindocr/data/transforms/transforms_factory.py
index e040267be..e1f763d56 100644
--- a/mindocr/data/transforms/transforms_factory.py
+++ b/mindocr/data/transforms/transforms_factory.py
@@ -15,6 +15,7 @@
 from .rec_transforms import *
 from .svtr_transform import *
 from .table_transform import *
+from .layout_transform import *
 
 __all__ = ["create_transforms", "run_transforms", "transforms_dbnet_icdar15"]
 _logger = logging.getLogger(__name__)
diff --git a/mindocr/infer/classification/__init__.py b/mindocr/infer/classification/__init__.py
new file mode 100644
index 000000000..cb61c6ab6
--- /dev/null
+++ b/mindocr/infer/classification/__init__.py
@@ -0,0 +1,3 @@
+from .cls_infer_node import ClsInferNode
+from .cls_post_node import ClsPostNode
+from .cls_pre_node import ClsPreNode
diff --git a/mindocr/infer/classification/classification.py b/mindocr/infer/classification/classification.py
new file mode 100644
index 000000000..8ed1a1eb3
--- /dev/null
+++ b/mindocr/infer/classification/classification.py
@@ -0,0 +1,62 @@
+import logging
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+from typing import List
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from tools.infer.text.utils import get_ckpt_file
+from mindocr.data.transforms import create_transforms, run_transforms
+from mindocr.postprocess import build_postprocess
+from mindocr.infer.utils.model import MSModel, LiteModel
+
+
+algo_to_model_name = {
+    "MV3": "cls_mobilenet_v3_small_100_model",
+}
+logger = logging.getLogger("mindocr")
+
+class ClsPreprocess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.cls_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline)
+
+    def __call__(self, img):
+        data = {"image": img}
+        data = run_transforms(data, self.transforms[1:])
+        return data
+
+
+class ClsModelMS(MSModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.cls_algorithm]
+        self.config_path = args.cls_config_path
+        self._init_model(self.model_name, self.config_path)
+
+
+class ClsModelLite(LiteModel):
+    def __init__(self, args):
+        self.args = args
+        self.model_name = algo_to_model_name[args.cls_algorithm]
+        self.config_path = args.cls_config_path
+        self._init_model(self.model_name, self.config_path)
+
+INFER_CLS_MAP = {"MindSporeLite": ClsModelLite, "MindSpore": ClsModelMS}
+
+class ClsPostprocess(object):
+    def __init__(self, args):
+        self.args = args
+        with open(args.cls_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.postprocessor = build_postprocess(self.yaml_cfg.postprocess)
+
+    def __call__(self, pred):
+        return self.postprocessor(pred)
\ No newline at end of file
diff --git a/mindocr/infer/classification/cls_infer_node.py b/mindocr/infer/classification/cls_infer_node.py
new file mode 100644
index 000000000..13d1f8c88
--- /dev/null
+++ b/mindocr/infer/classification/cls_infer_node.py
@@ -0,0 +1,57 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from .classification import INFER_CLS_MAP
+
+
+class ClsInferNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(ClsInferNode, self).__init__(args, msg_queue, tqdm_info)
+        self.args = args
+        self.cls_model = None
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        with open(self.args.cls_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.batch_size = self.yaml_cfg.predict.loader.batch_size
+        ClsModel = INFER_CLS_MAP[self.yaml_cfg.predict.backend]
+        self.cls_model = ClsModel(self.args)
+        super().init_self_args()
+
+    def process(self, input_data):
+        """
+        Input:
+            - input_data.data: [np.ndarray], shape:[3,h,w], e.g. [3,48,192]
+        Output:
+            - input_data.data: [np.ndarray], shape:[?,2]
+        """
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data["cls_pre_res"]
+        data = [np.expand_dims(d, 0) for d in data if len(d.shape) == 3]
+        data = np.concatenate(data, axis=0)
+
+        preds = []
+        for batch_i in range(data.shape[0] // self.batch_size + 1):
+            start_i = batch_i * self.batch_size
+            end_i = (batch_i + 1) * self.batch_size
+            d = data[start_i:end_i]
+            if d.shape[0] == 0:
+                continue
+            pred = self.cls_model([d])
+            preds.append(pred[0])
+        preds = np.concatenate(preds, axis=0)
+        input_data.data["cls_infer_res"] = {"pred": preds}
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/classification/cls_post_node.py b/mindocr/infer/classification/cls_post_node.py
new file mode 100644
index 000000000..8c95fccfc
--- /dev/null
+++ b/mindocr/infer/classification/cls_post_node.py
@@ -0,0 +1,61 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+import cv2
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from .classification import ClsPostprocess
+from tools.infer.text.utils import crop_text_region
+from pipeline.data_process.utils.cv_utils import crop_box_from_image
+
+
+class ClsPostNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(ClsPostNode, self).__init__(args, msg_queue, tqdm_info)
+        self.cls_postprocess = ClsPostprocess(args)
+        self.task_type = self.args.task_type
+        self.cls_thresh = 0.9
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_data):
+        """
+        Input:
+            - input_data.data: [np.ndarray], shape:[?,2]
+        Output:
+            - input_data.sub_image_list: [np.ndarray], shape:[1,3,-1,-1], e.g. [1,3,48,192]
+            - input_data.data = None
+            or
+            - input_data.infer_result = [(str, float)]
+        """
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data["cls_infer_res"]
+        pred = data["pred"]
+        output = self.cls_postprocess(pred)
+        angles = output["angles"]
+        scores = np.array(output["scores"]).tolist()
+
+        batch = input_data.sub_image_size
+        if self.task_type.value in (TaskType.DET_CLS_REC.value, TaskType.LAYOUT_DET_CLS_REC.value):
+            sub_images = input_data.sub_image_list
+            for i in range(batch):
+                angle, score = angles[i], scores[i]
+                if "180" == angle and score > self.cls_thresh:
+                    sub_images[i] = cv2.rotate(sub_images[i], cv2.ROTATE_180)
+            input_data.sub_image_list = sub_images
+        else:
+            input_data.infer_result = [(angle, score) for angle, score in zip(angles, scores)]
+
+        self.send_to_next_module(input_data)
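ClsPostNode above only rotates a crop when the classifier votes "180" with a score above cls_thresh (0.9); otherwise the crop passes through untouched. The decision rule in isolation, with a fabricated postprocess output standing in for ClsPostprocess:

import cv2
import numpy as np

cls_thresh = 0.9
sub_images = [np.zeros((48, 192, 3), dtype=np.uint8)]  # one dummy crop
output = {"angles": ["180"], "scores": [0.97]}          # made-up classifier result

for i, (angle, score) in enumerate(zip(output["angles"], output["scores"])):
    if angle == "180" and score > cls_thresh:
        sub_images[i] = cv2.rotate(sub_images[i], cv2.ROTATE_180)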
diff --git a/mindocr/infer/classification/cls_pre_node.py b/mindocr/infer/classification/cls_pre_node.py
new file mode 100644
index 000000000..f00b8059b
--- /dev/null
+++ b/mindocr/infer/classification/cls_pre_node.py
@@ -0,0 +1,40 @@
+import argparse
+import os
+import time
+import sys
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from .classification import ClsPreprocess
+
+
+class ClsPreNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(ClsPreNode, self).__init__(args, msg_queue, tqdm_info)
+        self.cls_preprocessor = ClsPreprocess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+        return {"batch_size": 1}
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        if self.task_type.value == TaskType.REC.value:
+            image = input_data.frame[0]
+            data = [self.cls_preprocessor(image)["image"]]
+            input_data.sub_image_size = 1
+            input_data.data = {"cls_pre_res": data}
+            self.send_to_next_module(input_data)
+        else:
+            sub_image_list = input_data.sub_image_list
+            data = [self.cls_preprocessor(split_image)["image"] for split_image in sub_image_list]
+            input_data.data["cls_pre_res"] = data
+            self.send_to_next_module(input_data)
diff --git a/mindocr/infer/common/__init__.py b/mindocr/infer/common/__init__.py
new file mode 100644
index 000000000..98af74e6b
--- /dev/null
+++ b/mindocr/infer/common/__init__.py
@@ -0,0 +1,3 @@
+from .collect_node2 import CollectNode
+from .decode_node import DecodeNode
+from .handout_node import HandoutNode
diff --git a/mindocr/infer/common/collect_node.py b/mindocr/infer/common/collect_node.py
new file mode 100644
index 000000000..7f536e18e
--- /dev/null
+++ b/mindocr/infer/common/collect_node.py
@@ -0,0 +1,175 @@
+import os
+from collections import defaultdict
+from ctypes import c_uint64
+from multiprocessing import Manager
+
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.tasks import TaskType
+from pipeline.utils import log, safe_list_writer, visual_utils
+from pipeline.datatype import ProcessData, ProfilingData, StopData
+from pipeline.framework.module_base import ModuleBase
+
+RESULTS_SAVE_FILENAME = {
+    TaskType.DET: "det_results.txt",
+    TaskType.CLS: "cls_results.txt",
+    TaskType.REC: "rec_results.txt",
+    TaskType.DET_REC: "pipeline_results.txt",
+    TaskType.DET_CLS_REC: "pipeline_results.txt",
+    TaskType.LAYOUT: "layout_results.txt",
+    TaskType.LAYOUT_DET_REC: "pipeline_results.txt",
+}
+
+
+class CollectNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super().__init__(args, msg_queue, tqdm_info)
+        self.image_sub_remaining = defaultdict(defaultdict)
+        self.image_pipeline_res = defaultdict(defaultdict)
+        self.infer_size = defaultdict(int)
+        self.image_total = Manager().Value(c_uint64, 0)
+        self.task_type = args.task_type
+        self.res_save_dir = args.res_save_dir
+        self.save_filename = RESULTS_SAVE_FILENAME[TaskType(self.task_type.value)]
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def _collect_stop(self, input_data):
+        self.image_total.value = input_data.image_total
+
+    def _vis_results(self, image_name, image, taskid, data_type):
+        if self.args.crop_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.crop_save_dir, os.path.splitext(basename)[0])
+            box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
+            crop_list = visual_utils.vis_crop(image, box_list)
+            for i, crop in enumerate(crop_list):
+                cv_utils.img_write(filename + "_crop_" + str(i) + ".jpg", crop)
+
+        if self.args.vis_pipeline_save_dir:
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.vis_pipeline_save_dir, os.path.splitext(basename)[0])
+            box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
+            text_list = [x["transcription"] for x in self.image_pipeline_res[taskid][image_name]]
+            box_text = visual_utils.vis_bbox_text(image, box_list, text_list, font_path=self.args.vis_font_path)
+            cv_utils.img_write(filename + ".jpg", box_text)
+
+        if self.args.vis_det_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.vis_det_save_dir, os.path.splitext(basename)[0])
+            box_list = [np.array(x).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
+            box_line = visual_utils.vis_bbox(image, box_list, [255, 255, 0], 2)
+            cv_utils.img_write(filename + ".jpg", box_line)
+
+        # log.info(f"{image_name} is finished.")
+
+    def final_text_save(self):
+        rst_dict = dict()
+        for rst in self.image_pipeline_res.values():
+            rst_dict.update(rst)
+        save_filename = os.path.join(self.res_save_dir, self.save_filename)
+        safe_list_writer(rst_dict, save_filename)
+        # log.info(f"save infer result to {save_filename} successfully")
+
+    def _collect_results(self, input_data: ProcessData):
+        taskid = input_data.taskid
+        if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):
+            image_path = input_data.image_path[0]  # bs=1
+            for result in input_data.infer_result:
+                if result[-1] > 0.5:
+                    if self.args.result_contain_score:
+                        self.image_pipeline_res[taskid][image_path].append(
+                            {"transcription": result[-2], "points": result[:-2], "score": str(result[-1])}
+                        )
+                    else:
+                        self.image_pipeline_res[taskid][image_path].append(
+                            {"transcription": result[-2], "points": result[:-2]}
+                        )
+            if not input_data.infer_result:
+                self.image_pipeline_res[taskid][image_path] = []
+        elif self.task_type.value == TaskType.DET.value:
+            image_path = input_data.image_path[0]  # bs=1
+            self.image_pipeline_res[taskid][image_path] = input_data.infer_result
+        elif self.task_type.value in (TaskType.REC.value, TaskType.CLS.value):
+            for image_path, infer_result in zip(input_data.image_path, input_data.infer_result):
+                self.image_pipeline_res[taskid][image_path] = infer_result
+        elif self.task_type.value == TaskType.LAYOUT.value:
+            for infer_result in input_data.infer_result:
+                image_path = infer_result.pop("image_id")
+                if image_path in self.image_pipeline_res[taskid]:
+                    self.image_pipeline_res[taskid][image_path].append(infer_result)
+                else:
+                    self.image_pipeline_res[taskid][image_path] = [infer_result]
+        else:
+            raise NotImplementedError("Task type is not supported.")
+
+        self._update_remaining(input_data)
+
+    def _update_remaining(self, input_data: ProcessData):
+        taskid = input_data.taskid
+        data_type = input_data.data_type
+        if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):  # with sub image
+            for idx, image_path in enumerate(input_data.image_path):
+                if image_path in self.image_sub_remaining[taskid]:
+                    self.image_sub_remaining[taskid][image_path] -= input_data.sub_image_size
+                    if not self.image_sub_remaining[taskid][image_path]:
+                        self.image_sub_remaining[taskid].pop(image_path)
+                        self.infer_size[taskid] += 1
+                        self._vis_results(
+                            image_path, input_data.frame[idx], taskid, data_type
+                        ) if input_data.frame else ...
+                else:
+                    remaining = input_data.sub_image_total - input_data.sub_image_size
+                    if remaining:
+                        self.image_sub_remaining[taskid][image_path] = remaining
+                    else:
+                        self.infer_size[taskid] += 1
+                        self._vis_results(
+                            image_path, input_data.frame[idx], taskid, data_type
+                        ) if input_data.frame else ...
+        else:  # without sub image
+            for idx, image_path in enumerate(input_data.image_path):
+                self.infer_size[taskid] += 1
+                self._vis_results(image_path, input_data.frame[idx], taskid, data_type) if input_data.frame else ...
+
+    def process(self, input_data):
+        if isinstance(input_data, ProcessData):
+            taskid = input_data.taskid
+            if input_data.taskid not in self.image_sub_remaining.keys():
+                self.image_sub_remaining[input_data.taskid] = defaultdict(int)
+            if input_data.taskid not in self.image_pipeline_res.keys():
+                self.image_pipeline_res[input_data.taskid] = defaultdict(list)
+            self._collect_results(input_data)
+            if self.infer_size[taskid] == input_data.task_images_num:
+                self.send_to_next_module({taskid: self.image_pipeline_res[taskid]})
+
+        elif isinstance(input_data, StopData):
+            self._collect_stop(input_data)
+            if input_data.exception:
+                self.stop_manager.value = True
+        else:
+            raise ValueError("unknown input data")
+
+        infer_size_sum = sum(self.infer_size.values())
+        if self.image_total.value and infer_size_sum == self.image_total.value:
+            self.final_text_save()
+            self.stop_manager.value = True
+
+    def stop(self):
+        profiling_data = ProfilingData(
+            module_name=self.module_name,
+            instance_id=self.instance_id,
+            process_cost_time=self.process_cost.value,
+            send_cost_time=self.send_cost.value,
+            image_total=self.image_total.value,
+        )
+        self.msg_queue.put(profiling_data, block=False)
+        self.is_stop = True
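The subtle part of CollectNode is _update_remaining: an image with sub-images only counts as finished once every crop batch for it has arrived. The bookkeeping reduced to plain dictionaries (a hypothetical helper with the same arithmetic as above, not repository code):

from collections import defaultdict

remaining = defaultdict(int)
finished = 0

def on_batch(image_path, sub_total, sub_size):
    """sub_total: crops detected for the image; sub_size: crops in this batch."""
    global finished
    if image_path in remaining:
        remaining[image_path] -= sub_size
        if not remaining[image_path]:
            remaining.pop(image_path)
            finished += 1
    elif sub_total - sub_size:  # first batch, more batches still to come
        remaining[image_path] = sub_total - sub_size
    else:  # a single batch covered every crop
        finished += 1

on_batch("a.jpg", sub_total=5, sub_size=3)
on_batch("a.jpg", sub_total=5, sub_size=2)
print(finished)  # 1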
diff --git a/mindocr/infer/common/collect_node2.py b/mindocr/infer/common/collect_node2.py
new file mode 100644
index 000000000..7a3a52ef5
--- /dev/null
+++ b/mindocr/infer/common/collect_node2.py
@@ -0,0 +1,189 @@
+import os
+from collections import defaultdict
+from ctypes import c_uint64
+from multiprocessing import Manager
+
+import numpy as np
+
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.tasks import TaskType
+from pipeline.utils import log, safe_list_writer, visual_utils
+from pipeline.datatype import ProcessData, ProfilingData, StopData
+from pipeline.framework.module_base import ModuleBase
+
+RESULTS_SAVE_FILENAME = {
+    TaskType.DET: "det_results.txt",
+    TaskType.CLS: "cls_results.txt",
+    TaskType.REC: "rec_results.txt",
+    TaskType.DET_REC: "pipeline_results.txt",
+    TaskType.DET_CLS_REC: "pipeline_results.txt",
+    TaskType.LAYOUT: "layout_results.txt",
+    TaskType.LAYOUT_DET_REC: "pipeline_results.txt",
+    TaskType.LAYOUT_DET_CLS_REC: "pipeline_results.txt",
+}
+
+
+class CollectNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super().__init__(args, msg_queue, tqdm_info)
+        self.image_sub_remaining = defaultdict(defaultdict)
+        self.image_pipeline_res = defaultdict(defaultdict)
+        self.infer_size = defaultdict(int)
+        self.image_total = Manager().Value(c_uint64, 0)
+        self.task_type = args.task_type
+        self.res_save_dir = args.res_save_dir
+        self.save_filename = RESULTS_SAVE_FILENAME[TaskType(self.task_type.value)]
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def _collect_stop(self, input_data):
+        self.image_total.value = input_data.image_total
+
+    def _vis_results(self, image_name, image, taskid, data_type, task=None):
+        if self.args.crop_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.crop_save_dir, os.path.splitext(basename)[0])
+            box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
+            crop_list = visual_utils.vis_crop(image, box_list)
+            for i, crop in enumerate(crop_list):
+                cv_utils.img_write(filename + "_crop_" + str(i) + ".jpg", crop)
+
+        if self.args.vis_pipeline_save_dir:
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.vis_pipeline_save_dir, os.path.splitext(basename)[0])
+            box_list = [np.array(x["points"]).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
+            text_list = [x["transcription"] for x in self.image_pipeline_res[taskid][image_name]]
+            box_text = visual_utils.vis_bbox_text(image, box_list, text_list, font_path=self.args.vis_font_path)
+            cv_utils.img_write(filename + ".jpg", box_text)
+
+        if self.args.vis_det_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.vis_det_save_dir, os.path.splitext(basename)[0])
+            box_list = [np.array(x).reshape(-1, 2) for x in self.image_pipeline_res[taskid][image_name]]
+            box_line = visual_utils.vis_bbox(image, box_list, [255, 255, 0], 2)
+            cv_utils.img_write(filename + ".jpg", box_line)
+
+        if self.args.vis_layout_save_dir and (data_type == 0 or (data_type == 1 and self.args.input_array_save_dir)):
+            basename = os.path.basename(image_name)
+            filename = os.path.join(self.args.vis_layout_save_dir, os.path.splitext(basename)[0])
+            box_list = []
+            for item in self.image_pipeline_res[taskid][image_name]:
+                x, y, dx, dy = item["bbox"]
+                box_list.append(np.array([[x, y + dy], [x + dx, y + dy], [x + dx, y], [x, y]]))
+            box_line = visual_utils.vis_bbox(image, box_list, [255, 255, 0], 2)
+            cv_utils.img_write(filename + ".jpg", box_line)
+
+        # log.info(f"{image_name} is finished.")
+
+    def final_text_save(self):
+        rst_dict = dict()
+        for rst in self.image_pipeline_res.values():
+            rst_dict.update(rst)
+        save_filename = os.path.join(self.res_save_dir, self.save_filename)
+        safe_list_writer(rst_dict, save_filename)
+        # log.info(f"save infer result to {save_filename} successfully")
+
+    def _update_layout_result(self, input_data):
+        taskid = input_data.taskid
+        image_path = input_data.image_path[0]
+        layout_rsts = input_data.data
+
+        for layout_rst in layout_rsts["layout_collect_res"]:
+            layout_bbox = layout_rst.data["layout_result"]
+            lx, ly, _, _ = layout_bbox["bbox"]
+            for rec_rst in layout_rst.infer_result:
+                bbox, transcription, score = rec_rst[:-2], rec_rst[-2], rec_rst[-1]
+                bbox = [[b[0] + lx, b[1] + ly] for b in bbox]
+                if score > 0.5:
+                    if self.args.result_contain_score:
+                        self.image_pipeline_res[taskid][image_path].append(
+                            {"transcription": transcription, "points": bbox, "score": str(score)}
+                        )
+                    else:
+                        self.image_pipeline_res[taskid][image_path].append(
+                            {"transcription": transcription, "points": bbox}
+                        )
+
+    def _collect_results(self, input_data: ProcessData):
+        taskid = input_data.taskid
+        if self.task_type.value in (TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):
+            image_path = input_data.image_path[0]  # bs=1
+            for result in input_data.infer_result:
+                if result[-1] > 0.5:
+                    if self.args.result_contain_score:
+                        self.image_pipeline_res[taskid][image_path].append(
+                            {"transcription": result[-2], "points": result[:-2], "score": str(result[-1])}
+                        )
+                    else:
+                        self.image_pipeline_res[taskid][image_path].append(
+                            {"transcription": result[-2], "points": result[:-2]}
+                        )
+            if not input_data.infer_result:
+                self.image_pipeline_res[taskid][image_path] = []
+        elif self.task_type.value == TaskType.DET.value:
+            image_path = input_data.image_path[0]  # bs=1
+            self.image_pipeline_res[taskid][image_path] = input_data.infer_result
+        elif self.task_type.value in (TaskType.REC.value, TaskType.CLS.value):
+            for image_path, infer_result in zip(input_data.image_path, input_data.infer_result):
+                self.image_pipeline_res[taskid][image_path] = infer_result
+        elif self.task_type.value == TaskType.LAYOUT.value:
+            for infer_result in input_data.infer_result:
+                image_path = infer_result.pop("image_id")[0]
+                if image_path in self.image_pipeline_res[taskid]:
+                    self.image_pipeline_res[taskid][image_path].append(infer_result)
+                else:
+                    self.image_pipeline_res[taskid][image_path] = [infer_result]
+        elif self.task_type.value in (TaskType.LAYOUT_DET_REC.value, TaskType.LAYOUT_DET_CLS_REC.value):
+            self._update_layout_result(input_data)
+        else:
+            raise NotImplementedError("Task type is not supported.")
+
+        self._update_remaining(input_data)
+
+    def _update_remaining(self, input_data: ProcessData):
+        taskid = input_data.taskid
+        data_type = input_data.data_type
+        # unlike collect_node.py, there is no sub-image accounting here: each incoming batch
+        # marks its images as finished and visualizes on the original frame when available
+        for idx, image_path in enumerate(input_data.image_path):
+            self.infer_size[taskid] += 1
+            self._vis_results(image_path, input_data.frame[idx], taskid, data_type) if input_data.frame else ...
+
+    def process(self, input_data):
+        if isinstance(input_data, ProcessData):
+            taskid = input_data.taskid
+            if input_data.taskid not in self.image_sub_remaining.keys():
+                self.image_sub_remaining[input_data.taskid] = defaultdict(int)
+            if input_data.taskid not in self.image_pipeline_res.keys():
+                self.image_pipeline_res[input_data.taskid] = defaultdict(list)
+            self._collect_results(input_data)
+            if self.infer_size[taskid] == input_data.task_images_num:
+                self.send_to_next_module({taskid: self.image_pipeline_res[taskid]})
+
+        elif isinstance(input_data, StopData):
+            self._collect_stop(input_data)
+            if input_data.exception:
+                self.stop_manager.value = True
+        else:
+            raise ValueError("unknown input data")
+
+        infer_size_sum = sum(self.infer_size.values())
+        if self.image_total.value and infer_size_sum == self.image_total.value:
+            self.final_text_save()
+            self.stop_manager.value = True
+
+    def stop(self):
+        profiling_data = ProfilingData(
+            module_name=self.module_name,
+            instance_id=self.instance_id,
+            process_cost_time=self.process_cost.value,
+            send_cost_time=self.send_cost.value,
+            image_total=self.image_total.value,
+        )
+        self.msg_queue.put(profiling_data, block=False)
+        self.is_stop = True
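In collect_node2, _update_layout_result merges recognition results produced inside a layout crop back onto the page by offsetting each point with the crop's top-left corner (lx, ly). With made-up coordinates:

layout_bbox = [120, 340, 300, 60]  # x, y, dx, dy of a layout region on the page
lx, ly = layout_bbox[0], layout_bbox[1]

box = [[10, 5], [90, 5], [90, 25], [10, 25]]  # quad in crop-local coordinates
page_box = [[x + lx, y + ly] for x, y in box]
print(page_box)  # [[130, 345], [210, 345], [210, 365], [130, 365]]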
diff --git a/mindocr/infer/common/decode_node.py b/mindocr/infer/common/decode_node.py
new file mode 100644
index 000000000..eb78f8bbc
--- /dev/null
+++ b/mindocr/infer/common/decode_node.py
@@ -0,0 +1,51 @@
+import os
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.utils import log
+from pipeline.datatype import StopData
+from pipeline.framework.module_base import ModuleBase
+
+
+class DecodeNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super().__init__(args, msg_queue, tqdm_info)
+        self.cost_time = 0
+        self.avail_image_total = 0
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_data):
+        if isinstance(input_data, StopData):
+            input_data.image_total = self.avail_image_total
+            self.send_to_next_module(input_data)
+            return
+
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        # the input already contains np.ndarray frames, so there is no need to read the images again
+        if len(input_data.frame) == len(input_data.image_path) and len(input_data.frame) > 0:
+            self.avail_image_total += len(input_data.frame)
+            self.send_to_next_module(input_data)
+        else:
+            img_read, img_path_read = [], []
+            for image_path in input_data.image_path:
+                try:
+                    img_read.append(cv_utils.img_read(image_path))
+                    img_path_read.append(image_path)
+                    self.avail_image_total += 1
+                except ValueError:
+                    log.info(f"{image_path} is unavailable and skipped")
+                    continue
+            input_data.frame = img_read
+            input_data.image_path = img_path_read
+            self.send_to_next_module(input_data)
diff --git a/mindocr/infer/common/handout_node.py b/mindocr/infer/common/handout_node.py
new file mode 100644
index 000000000..122db5879
--- /dev/null
+++ b/mindocr/infer/common/handout_node.py
@@ -0,0 +1,96 @@
+import os
+import sys
+
+import cv2
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.data_process.utils import cv_utils
+from pipeline.utils import log
+from pipeline.datatype import ProcessData, StopData, StopSign
+from pipeline.framework.module_base import ModuleBase
+
+
+class HandoutNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super().__init__(args, msg_queue, tqdm_info)
+        self.image_total = 0
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def process(self, input_mix_data):
+        if isinstance(input_mix_data, StopSign):
+            data = self.process_stop_sign()
+            self.send_to_next_module(data)
+        elif isinstance(input_mix_data, np.ndarray):
+            input_data, info_data = input_mix_data
+            data = self.process_image_array([input_data])
+            data.task_images_num = info_data[0]
+            data.taskid = info_data[1]
+            data.data_type = 1
+            self.send_to_next_module(data)
+        elif isinstance(input_mix_data, (tuple, list)):
+            input_data, info_data = input_mix_data
+            if len(input_data) == 0:
+                return
+            if cv_utils.check_type_in_container(input_data, str):
+                data = self.process_image_path(input_data)
+                data.data_type = 0
+            elif cv_utils.check_type_in_container(input_data, np.ndarray):
+                data = self.process_image_array(input_data)
+                data.data_type = 1
+            else:
+                raise ValueError(
+                    "unknown input data, input_data should be StopSign, or tuple&list contains str or np.ndarray"
+                )
+            data.task_images_num = info_data[0]
+            data.taskid = info_data[1]
+            self.send_to_next_module(data)
+        else:
+            raise ValueError(f"unknown input data: {type(input_mix_data)}")
+
+    def process_image_path(self, image_path_list):
+        """
+        image_path_list: List[str], paths to the images
+        """
+        # log.info(f"sending {', '.join([os.path.basename(x) for x in image_path_list])} to pipeline")
+        data = ProcessData(image_path=image_path_list)
+        self.image_total += len(image_path_list)
+        return data
+
+    def process_image_array(self, image_array_list):
+        """
+        image_array_list: List[np.ndarray], array of images
+        """
+        frames = []
+        array_save_path = []
+        image_num = len(image_array_list)
+        for i in range(image_num):
+            if self.args.input_array_save_dir:
+                image_path = os.path.join(self.args.input_array_save_dir, f"input_array_{self.image_total}.jpg")
+                if len(image_array_list[i].shape) != 3:
+                    log.info(f"image_array_list[{i}] array with shape {image_array_list[i].shape} is invalid")
+                    continue
+                try:
+                    cv_utils.img_write(image_path, image_array_list[i])
+                except cv2.error:
+                    log.info(f"image_array_list[{i}] with shape {image_array_list[i].shape} array is invalid")
+                    continue
+                log.info(f"sending array(saved at {image_path}) to pipeline")
+                array_save_path.append(image_path)
+            else:
+                array_save_path.append(str(i))
+            frames.append(image_array_list[i])
+
+            self.image_total += 1
+        data = ProcessData(frame=frames, image_path=array_save_path)
+        return data
+
+    def process_stop_sign(self):
+        # image_total of StopData will be assigned in decode_node
+        return StopData(skip=True)
diff --git a/mindocr/infer/detection/__init__.py b/mindocr/infer/detection/__init__.py
new file mode 100644
index 000000000..b78397356
--- /dev/null
+++ b/mindocr/infer/detection/__init__.py
@@ -0,0 +1,3 @@
+from .det_infer_node import DetInferNode
+from .det_post_node import DetPostNode
+from .det_pre_node import DetPreNode
diff --git a/mindocr/infer/detection/det_infer_node.py b/mindocr/infer/detection/det_infer_node.py
new file mode 100644
index 000000000..d5ed155bf
--- /dev/null
+++ b/mindocr/infer/detection/det_infer_node.py
@@ -0,0 +1,40 @@
+import os
+import time
+import sys
+import numpy as np
+import yaml
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from .detection import INFER_DET_MAP
+
+
+class DetInferNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(DetInferNode, self).__init__(args, msg_queue, tqdm_info)
+        self.args = args
+        self.det_model = None
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        with open(self.args.det_model_name_or_config, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        DetModel = INFER_DET_MAP[self.yaml_cfg.predict.backend]
+        self.det_model = DetModel(self.args)
+        super().init_self_args()
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data["det_pre_res"]["image"]
+        pred = self.det_model([data])
+
+        input_data.data["det_infer_res"] = pred
+
+        self.send_to_next_module(input_data)
diff --git a/mindocr/infer/detection/det_post_node.py b/mindocr/infer/detection/det_post_node.py
new file mode 100644
index 000000000..3201938c3
--- /dev/null
+++ b/mindocr/infer/detection/det_post_node.py
@@ -0,0 +1,86 @@
+import os
+import sys
+import numpy as np
+
+import cv2
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from .detection import DetPostprocess
+from tools.infer.text.utils import crop_text_region
+from pipeline.data_process.utils.cv_utils import crop_box_from_image
+
+class DetPostNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(DetPostNode, self).__init__(args, msg_queue, tqdm_info)
+        self.det_postprocess = DetPostprocess(args)
+        self.task_type = self.args.task_type
+        self.is_concat = self.args.is_concat
+
+    def init_self_args(self):
+        super().init_self_args()
+
+    def concat_crops(self, crops: list):
+        """
+        Concatenates the list of cropped images horizontally after resizing them to have the same height.
+
+        Args:
+            crops (list): A list of cropped images represented as numpy arrays.
+
+        Returns:
+            numpy.ndarray: A horizontally concatenated image array.
+        """
+        max_height = max(crop.shape[0] for crop in crops)
+        resized_crops = []
+        for crop in crops:
+            h, w, c = crop.shape
+            new_h = max_height
+            new_w = int((w / h) * new_h)
+
+            resized_img = cv2.resize(crop, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
+            resized_crops.append(resized_img)
+        crops_concatenated = np.concatenate(resized_crops, axis=1)
+        return crops_concatenated
+
+    def process(self, input_data):
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        pred = input_data.data["det_infer_res"]
+        pred = pred[0]
+        data_dict = {"shape_list": input_data.data["det_pre_res"]["shape_list"]}
+        boxes = self.det_postprocess(pred, data_dict)
+
+        boxes = boxes["polys"][0]
+
+        if self.is_concat:
+            boxes = sorted(boxes, key=lambda points: (points[0][1], points[0][0]))
+
+        infer_res_list = []
+        for box in boxes:
+            infer_res_list.append(box.tolist())
+
+        input_data.infer_result = infer_res_list
+
+        if self.task_type.value in (TaskType.DET.value, TaskType.DET_REC.value, TaskType.DET_CLS_REC.value):
+            if len(input_data.frame) == 0:
+                return
+            image = input_data.frame[0]  # bs=1 for det
+        else:
+            image = input_data.data["layout_images"][0]
+        sub_image_list = []
+        for box in infer_res_list:
+            sub_image = crop_box_from_image(image, np.array(box))
+            sub_image_list.append(sub_image)
+        if self.is_concat:
+            sub_image_list = len(sub_image_list) * [self.concat_crops(sub_image_list)]
+        input_data.sub_image_list = sub_image_list
+
+        if not infer_res_list:
+            input_data.skip = True
+
+        self.send_to_next_module(input_data)
\ No newline at end of file
100644 index 000000000..7717485f8 --- /dev/null +++ b/mindocr/infer/detection/detection.py @@ -0,0 +1,70 @@ +import logging +import os +import time +import sys +import numpy as np +import yaml +from addict import Dict +from typing import List + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from tools.infer.text.utils import get_ckpt_file +from mindocr.data.transforms import create_transforms, run_transforms +from mindocr.postprocess import build_postprocess +from mindocr.infer.utils.model import MSModel, LiteModel + + +algo_to_model_name = { + "DB": "dbnet_resnet50", + "DB++": "dbnetpp_resnet50", + "DB_MV3": "dbnet_mobilenetv3", + "DB_PPOCRv3": "dbnet_ppocrv3", + "PSE": "psenet_resnet152", +} +logger = logging.getLogger("mindocr") + +class DetPreprocess(object): + def __init__(self, args): + self.args = args + with open(args.det_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + for transform in self.yaml_cfg.predict.dataset.transform_pipeline: + if "DecodeImage" in transform: + transform["DecodeImage"].update({"keep_ori": True}) + break + self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline) + + def __call__(self, data): + data = run_transforms(data, self.transforms[1:]) + return data + + +class DetModelMS(MSModel): + def __init__(self, args): + self.args = args + self.model_name = algo_to_model_name[args.det_algorithm] + self.config_path = args.det_config_path + self._init_model(self.model_name, self.config_path) + + +class DetModelLite(LiteModel): + def __init__(self, args): + self.args = args + self.model_name = algo_to_model_name[args.det_algorithm] + self.config_path = args.det_config_path + self._init_model(self.model_name, self.config_path) + +INFER_DET_MAP = {"MindSporeLite": DetModelLite, "MindSpore": DetModelMS} + + +class DetPostprocess(object): + def __init__(self, args): + self.args = args + with open(args.det_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + self.transforms = build_postprocess(self.yaml_cfg.postprocess) + + def __call__(self, pred, data): + return self.transforms(pred, **data) \ No newline at end of file diff --git a/mindocr/infer/layout/__init__.py b/mindocr/infer/layout/__init__.py new file mode 100644 index 000000000..15718372f --- /dev/null +++ b/mindocr/infer/layout/__init__.py @@ -0,0 +1,4 @@ +from .layout_infer_node import LayoutInferNode +from .layout_post_node import LayoutPostNode +from .layout_pre_node import LayoutPreNode +from .layout_collect_node import LayoutCollectNode \ No newline at end of file diff --git a/mindocr/infer/layout/layout.py b/mindocr/infer/layout/layout.py new file mode 100644 index 000000000..f5148d2ae --- /dev/null +++ b/mindocr/infer/layout/layout.py @@ -0,0 +1,88 @@ +import os +import sys +import logging + +import mindspore as ms +import numpy as np + +from addict import Dict + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + + +from tools.infer.text.utils import get_ckpt_file +from mindocr.models.builder import build_model +from mindocr.data.transforms import create_transforms, run_transforms +from mindocr.utils.logger import set_logger +from mindocr.postprocess import build_postprocess + +from typing import List +import yaml + +import copy + +import time + +from mindocr.infer.utils.model import MSModel, LiteModel + +class LayoutPreprocess(object): + def __init__(self, args) -> None: 
+ self.args = args + with open(args.layout_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + + for transform in self.yaml_cfg.predict.dataset.transform_pipeline: + if "DecodeImage" in transform: + transform["DecodeImage"].update({"keep_ori": True}) + break + if "func_name" in transform: + func_name = transform.pop("func_name") + args = copy.copy(transform) + transform.clear() + transform[func_name] = args + self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline) + + def __call__(self, data): + data = run_transforms(data, self.transforms) + return data + + +algo_to_model_name = { + "YOLOV8": "layout_yolov8n", +} +logger = logging.getLogger("mindocr") + + +class LayoutModelMS(MSModel): + def __init__(self, args) -> None: + super().__init__(args) + self.args = args + self.model_name = algo_to_model_name[args.layout_algorithm] + self.config_path = args.layout_config_path + self._init_model(self.model_name, self.config_path) + + +class LayoutModelLite(LiteModel): + def __init__(self, args) -> None: + super().__init__(args) + self.args = args + self.model_name = algo_to_model_name[args.layout_algorithm] + self.config_path = args.layout_config_path + self._init_model(self.model_name, self.config_path) + + +class LayoutPostProcess(object): + def __init__(self, args) -> None: + self.args = args + with open(args.layout_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + self.transforms = build_postprocess(self.yaml_cfg.postprocess) + + self.meta_data_indices = self.yaml_cfg.predict.dataset.pop("meta_data_column_index", None) + + + def __call__(self, pred, img_shape, meta_info): + return self.transforms(pred, img_shape, meta_info=meta_info) + +INFER_LAYOUT_MAP = {"MindSporeLite": LayoutModelLite, "MindSpore": LayoutModelMS} \ No newline at end of file diff --git a/mindocr/infer/layout/layout_collect_node.py b/mindocr/infer/layout/layout_collect_node.py new file mode 100644 index 000000000..d2909f0c5 --- /dev/null +++ b/mindocr/infer/layout/layout_collect_node.py @@ -0,0 +1,63 @@ +import copy +import os +import sys + +import numpy as np +import time + +from collections import defaultdict + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.framework.module_base import ModuleBase +from pipeline.tasks import TaskType +from mindocr.infer.utils.collector import Collector +from pipeline.datatype import ProcessData, ProfilingData, StopData + +class LayoutCollectNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super(LayoutCollectNode, self).__init__(args, msg_queue, tqdm_info) + self.task_type = self.args.task_type + self.collect_dict = defaultdict(Collector) + + def init_self_args(self): + super().init_self_args() + + def process(self, input_data): + """ + Input: + - input_data.data["pred"]: [np.ndarray], shape:[1,?,?], shape e.g. [1,13294, 9] (note:[bs, N, 5+nc]) + - input_data.data["hw_ori"]: (int, int), value e.g. (792,601) + - input_data.data["hw_scale"]: np.ndarray, shape:[1,2], value e.g. (1.0101,1.3311) + - input_data.data["pad"]: np.ndarray, shape:[1,2], value e.g. 
(4,99.5) + Output: + - input_data.data["image_ids"]: [str] + - input_data.infer_result: [{"image_id": str, "category_id": int, "bbox": [x:int, y:int, dx:int, dy:int]}] + - input_data.data["layout_result"]: [{"image_id": str, "category_id": int, "bbox": [x:int, y:int, dx:int, dy:int]}] + - input_data.data["layout_image"]: [np.ndarray], shape:[?,?,3] + """ + # if input_data.skip: + # self.send_to_next_module(input_data) + # return + + + if isinstance(input_data, StopData): + self.send_to_next_module(input_data) + return + + data = input_data.data + image_path = input_data.image_path[0] + if image_path in self.collect_dict.keys(): + self.collect_dict[image_path].update(input_data.data["layout_collect_idx"], input_data) + else: + self.collect_dict[image_path].init_keys(input_data.data["layout_collect_list"]) + self.collect_dict[image_path].update(input_data.data["layout_collect_idx"], input_data) + + if self.collect_dict[image_path].complete(): + data = self.collect_dict[image_path].get_data() + self.collect_dict.pop(image_path) + + input_data_out = copy.deepcopy(input_data) + input_data_out.data = {"layout_collect_res": data} + self.send_to_next_module(input_data_out) diff --git a/mindocr/infer/layout/layout_infer_node.py b/mindocr/infer/layout/layout_infer_node.py new file mode 100644 index 000000000..8060ab711 --- /dev/null +++ b/mindocr/infer/layout/layout_infer_node.py @@ -0,0 +1,51 @@ +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.framework.module_base import ModuleBase +from pipeline.tasks import TaskType +from .layout import INFER_LAYOUT_MAP + +import time +import numpy as np + +import yaml + +from addict import Dict + +class LayoutInferNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super(LayoutInferNode, self).__init__(args, msg_queue, tqdm_info) + self.args = args + self.layout_model = None + self.task_type = self.args.task_type + self.i = 0 + + def init_self_args(self): + with open(self.args.layout_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + LayoutModel = INFER_LAYOUT_MAP[self.yaml_cfg.predict.backend] + self.layout_model = LayoutModel(self.args) + super().init_self_args() + + def process(self, input_data): + """ + Input: + - input_data.data["image"]: np.ndarray, shape:[1,3,800,800] + Output: + - input_data.data["pred"]: [np.ndarray], shape:[1,?,?], shape e.g. [1,13294, 9] (note:[bs, N, 5+nc]) + - input_data.data["img_shape"]: (int, int, int, int), value e.g. 
(1,3,800,800) + """ + if input_data.skip: + self.send_to_next_module(input_data) + return + + data = input_data.data["image"] + pred = self.layout_model([data]) + + input_data.data["pred"] = pred + input_data.data["img_shape"] = input_data.data["image"].shape + + self.send_to_next_module(input_data) \ No newline at end of file diff --git a/mindocr/infer/layout/layout_post_node.py b/mindocr/infer/layout/layout_post_node.py new file mode 100644 index 000000000..632df7e7b --- /dev/null +++ b/mindocr/infer/layout/layout_post_node.py @@ -0,0 +1,75 @@ +import os +import sys +import numpy as np + +import cv2 +import copy + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.framework.module_base import ModuleBase +from pipeline.tasks import TaskType +from .layout import LayoutPostProcess +from tools.infer.text.utils import crop_text_region +from pipeline.data_process.utils.cv_utils import crop_box_from_image + +class LayoutPostNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super(LayoutPostNode, self).__init__(args, msg_queue, tqdm_info) + self.layout_postprocess = LayoutPostProcess(args) + self.task_type = self.args.task_type + self.score_thres = 0.5 + + def init_self_args(self): + super().init_self_args() + + def get_layout_images(self, frame, infer_result): + layout_images = [] + for d in infer_result: + d['bbox'] = [int(v) for v in d['bbox']] + x, y, dx, dy = d['bbox'] + layout_images.append(frame[0][y:(y+dy), x:(x+dx), :]) + return layout_images + + def process(self, input_data): + """ + Input: + - input_data.data["pred"]: [np.ndarray], shape:[1,?,?], shape e.g. [1,13294, 9] (note:[bs, N, 5+nc]) + - input_data.data["hw_ori"]: (int, int), value e.g. (792,601) + - input_data.data["hw_scale"]: np.ndarray, shape:[1,2], value e.g. (1.0101,1.3311) + - input_data.data["pad"]: np.ndarray, shape:[1,2], value e.g. 
(4,99.5)
+        Output:
+            - input_data.data["image_ids"]: [str]
+            - input_data.infer_result: [{"image_id": str, "category_id": int, "bbox": [x: int, y: int, dx: int, dy: int]}]
+            - input_data.data["layout_result"]: {"image_id": str, "category_id": int, "bbox": [x: int, y: int, dx: int, dy: int]}
+            - input_data.data["layout_images"]: [np.ndarray], shape: [?, ?, 3]
+        """
+        if input_data.skip:
+            self.send_to_next_module(input_data)
+            return
+
+        data = input_data.data
+        data["image_ids"] = [input_data.image_path]
+
+        meta_info = (data["image_ids"], [data["hw_ori"]], [data["hw_scale"]], [data["pad"]])
+        output = self.layout_postprocess(data["pred"][0], data["img_shape"], meta_info)
+        # keep only confident predictions of the text (1), title (2) and list (3) categories
+        output = [d for d in output if d["score"] > self.score_thres and d["category_id"] in (1, 2, 3)]
+
+        if self.task_type.value == TaskType.LAYOUT.value:
+            input_data_out = copy.deepcopy(input_data)
+            input_data_out.infer_result = output
+            self.send_to_next_module(input_data_out)
+        else:
+            layout_images = self.get_layout_images(input_data.frame, output)
+            input_data.data["layout_collect_list"] = list(range(len(output)))
+            for layout_images_id in range(len(output)):
+                input_data_out = copy.deepcopy(input_data)
+                input_data_out.data["layout_result"] = output[layout_images_id]
+                input_data_out.data["layout_images"] = [layout_images[layout_images_id]]
+                input_data_out.data["layout_collect_idx"] = layout_images_id
+                self.send_to_next_module(input_data_out)
\ No newline at end of file
diff --git a/mindocr/infer/layout/layout_pre_node.py b/mindocr/infer/layout/layout_pre_node.py
new file mode 100644
index 000000000..a38ab7fc9
--- /dev/null
+++ b/mindocr/infer/layout/layout_pre_node.py
@@ -0,0 +1,59 @@
+import os
+import sys
+
+import numpy as np
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.framework.module_base import ModuleBase
+from pipeline.tasks import TaskType
+from .layout import LayoutPreprocess
+
+
+class LayoutPreNode(ModuleBase):
+    def __init__(self, args, msg_queue, tqdm_info):
+        super(LayoutPreNode, self).__init__(args, msg_queue, tqdm_info)
+        self.layout_preprocesser = LayoutPreprocess(args)
+        self.task_type = self.args.task_type
+
+    def init_self_args(self):
+        super().init_self_args()
+        return {"batch_size": 1}
+
+    def process(self, input_data):
+        """
+        Input:
+            - input_data.frame: [np.ndarray], shape: [H, W, 3], e.g. [792, 601, 3]
+        Output:
+            - input_data.data["image"]: np.ndarray, shape: [1, 3, 800, 800]
+            - input_data.data["raw_img_shape"]: (int, int), e.g. (792, 601)
+            - input_data.data["target_size"]: [int, int], e.g. (800, 800)
+            - input_data.data["image_ids"]: int, e.g. 0
+            - input_data.data["hw_ori"]: (int, int), e.g. (792, 601)
+            - input_data.data["hw_scale"]: np.ndarray, shape: [1, 2], e.g. (1.0101, 1.3311)
+            - input_data.data["pad"]: np.ndarray, shape: [1, 2], e.g. 
(4,99.5) + """ + if input_data.skip: + self.send_to_next_module(input_data) + return + if len(input_data.frame) == 0: + return + + image = input_data.frame[0] # bs = 1 for layout + data = { + "image": image, + "raw_img_shape": image.shape[:2], + "target_size": [800, 800], + } + data = self.layout_preprocesser(data) + + if len(data["image"].shape) == 3: + data["image"] = np.expand_dims(data["image"], 0) + + # if self.task_type.value == TaskType.LAYOUT.value and not (self.args.crop_save_dir or self.args.vis_layout_save_dir): + # input_data.frame = None + + input_data.data = data + + self.send_to_next_module(input_data) diff --git a/mindocr/infer/node_config.py b/mindocr/infer/node_config.py new file mode 100644 index 000000000..37a5a654a --- /dev/null +++ b/mindocr/infer/node_config.py @@ -0,0 +1,130 @@ +from sys import modules + +import os +import sys + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from mindocr.infer.detection import DetInferNode, DetPostNode, DetPreNode +from mindocr.infer.recognition import RecInferNode, RecPostNode, RecPreNode +from mindocr.infer.classification import ClsPreNode, ClsInferNode, ClsPostNode +from mindocr.infer.layout import LayoutPreNode, LayoutInferNode, LayoutPostNode, LayoutCollectNode +from mindocr.infer.common import CollectNode, DecodeNode, HandoutNode +# from deploy.py_infer.src.infer import TaskType +from pipeline.tasks import TaskType +from pipeline.utils import log + +__all__ = ["MODEL_DICT_v2", + "DET_DESC_v2", "CLS_DESC_v2", "REC_DESC_v2", + "DET_REC_DESC_v2", "DET_CLS_REC_DESC_v2", + "LAYOUT_DESC_v2", "LAYOUT_DET_REC_DESC_v2", "LAYOUT_DET_CLS_REC_DESC_v2"] + +DET_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("DetPreNode", "0", 1)), + (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)), + (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)), + (("DetPostNode", "0", 1), ("CollectNode", "0", 1)), +] + +REC_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("RecPreNode", "0", 1)), + (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)), + (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)), + (("RecPostNode", "0", 1), ("CollectNode", "0", 1)), +] + +CLS_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("ClsPreNode", "0", 1)), + (("ClsPreNode", "0", 1), ("ClsInferNode", "0", 1)), + (("ClsInferNode", "0", 1), ("ClsPostNode", "0", 1)), + (("ClsPostNode", "0", 1), ("CollectNode", "0", 1)), +] + +DET_REC_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("DetPreNode", "0", 1)), + (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)), + (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)), + (("DetPostNode", "0", 1), ("RecPreNode", "0", 1)), + (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)), + (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)), + (("RecPostNode", "0", 1), ("CollectNode", "0", 1)), +] + +DET_CLS_REC_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("DetPreNode", "0", 1)), + (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)), + (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)), + (("DetPostNode", "0", 1), ("ClsPreNode", "0", 1)), + (("ClsPreNode", "0", 1), ("ClsInferNode", "0", 1)), + (("ClsInferNode", "0", 1), ("ClsPostNode", "0", 1)), + (("ClsPostNode", "0", 1), ("RecPreNode", "0", 1)), + (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)), + 
(("RecInferNode", "0", 1), ("RecPostNode", "0", 1)), + (("RecPostNode", "0", 1), ("CollectNode", "0", 1)), +] + +LAYOUT_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("LayoutPreNode", "0", 1)), + (("LayoutPreNode", "0", 1), ("LayoutInferNode", "0", 1)), + (("LayoutInferNode", "0", 1), ("LayoutPostNode", "0", 1)), + (("LayoutPostNode", "0", 1), ("CollectNode", "0", 1)), +] + +LAYOUT_DET_REC_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("LayoutPreNode", "0", 1)), + (("LayoutPreNode", "0", 1), ("LayoutInferNode", "0", 1)), + (("LayoutInferNode", "0", 1), ("LayoutPostNode", "0", 1)), + (("LayoutPostNode", "0", 1), ("DetPreNode", "0", 1)), + (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)), + (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)), + (("DetPostNode", "0", 1), ("RecPreNode", "0", 1)), + (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)), + (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)), + (("RecPostNode", "0", 1), ("LayoutCollectNode", "0", 1)), + (("LayoutCollectNode", "0", 1), ("CollectNode", "0", 1)), +] + +LAYOUT_DET_CLS_REC_DESC_v2 = [ + (("HandoutNode", "0", 1), ("DecodeNode", "0", 1)), + (("DecodeNode", "0", 1), ("LayoutPreNode", "0", 1)), + (("LayoutPreNode", "0", 1), ("LayoutInferNode", "0", 1)), + (("LayoutInferNode", "0", 1), ("LayoutPostNode", "0", 1)), + (("LayoutPostNode", "0", 1), ("DetPreNode", "0", 1)), + (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)), + (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)), + (("DetPostNode", "0", 1), ("ClsPreNode", "0", 1)), + (("ClsPreNode", "0", 1), ("ClsInferNode", "0", 1)), + (("ClsInferNode", "0", 1), ("ClsPostNode", "0", 1)), + (("ClsPostNode", "0", 1), ("RecPreNode", "0", 1)), + (("RecPreNode", "0", 1), ("RecInferNode", "0", 1)), + (("RecInferNode", "0", 1), ("RecPostNode", "0", 1)), + (("RecPostNode", "0", 1), ("LayoutCollectNode", "0", 1)), + (("LayoutCollectNode", "0", 1), ("CollectNode", "0", 1)), +] + +MODEL_DICT_v2 = {TaskType.DET: DET_DESC_v2, + TaskType.CLS: CLS_DESC_v2, + TaskType.REC: REC_DESC_v2, + TaskType.DET_REC: DET_REC_DESC_v2, + TaskType.DET_CLS_REC: DET_CLS_REC_DESC_v2, + TaskType.LAYOUT: LAYOUT_DESC_v2, + TaskType.LAYOUT_DET_REC: LAYOUT_DET_REC_DESC_v2, + TaskType.LAYOUT_DET_CLS_REC: LAYOUT_DET_CLS_REC_DESC_v2,} + +def processor_initiator(classname): + try: + processor = getattr(modules.get(__name__), classname) + except AttributeError as error: + log.error("%s doesn't exist.", classname) + raise error + if isinstance(processor, type): + return processor + raise TypeError("%s doesn't exist.", classname) \ No newline at end of file diff --git a/mindocr/infer/recognition/__init__.py b/mindocr/infer/recognition/__init__.py new file mode 100644 index 000000000..bc416dda5 --- /dev/null +++ b/mindocr/infer/recognition/__init__.py @@ -0,0 +1,3 @@ +from .rec_infer_node import RecInferNode +from .rec_post_node import RecPostNode +from .rec_pre_node import RecPreNode diff --git a/mindocr/infer/recognition/rec_infer_node.py b/mindocr/infer/recognition/rec_infer_node.py new file mode 100644 index 000000000..ab6f544cc --- /dev/null +++ b/mindocr/infer/recognition/rec_infer_node.py @@ -0,0 +1,51 @@ +import os +import time +import sys +import numpy as np +import yaml +from addict import Dict + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.framework.module_base import ModuleBase +from pipeline.tasks import TaskType +from 
.recognition import INFER_REC_MAP + + +class RecInferNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super(RecInferNode, self).__init__(args, msg_queue, tqdm_info) + self.args = args + self.rec_model = None + self.task_type = self.args.task_type + + def init_self_args(self): + with open(self.args.rec_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + self.batch_size = self.yaml_cfg.predict.loader.batch_size + RecModel = INFER_REC_MAP[self.yaml_cfg.predict.backend] + self.rec_model = RecModel(self.args) + super().init_self_args() + + def process(self, input_data): + if input_data.skip: + self.send_to_next_module(input_data) + return + + data = input_data.data["rec_pre_res"] + data = [np.expand_dims(d, 0) for d in data if len(d.shape) == 3] + data = np.concatenate(data, axis=0) + + preds = [] + for batch_i in range(data.shape[0] // self.batch_size + 1): + start_i = batch_i * self.batch_size + end_i = (batch_i + 1) * self.batch_size + d = data[start_i:end_i] + if d.shape[0] == 0: + continue + pred = self.rec_model([d]) + preds.append(pred[0]) + preds = np.concatenate(preds, axis=0) + input_data.data["rec_infer_res"] = {"pred": preds} + self.send_to_next_module(input_data) diff --git a/mindocr/infer/recognition/rec_post_node.py b/mindocr/infer/recognition/rec_post_node.py new file mode 100644 index 000000000..db621115c --- /dev/null +++ b/mindocr/infer/recognition/rec_post_node.py @@ -0,0 +1,47 @@ +import os +import time +import sys +import numpy as np +import yaml +from addict import Dict + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +import cv2 +import numpy as np + +from pipeline.framework.module_base import ModuleBase +from pipeline.tasks import TaskType +from .recognition import RecPostprocess +from tools.infer.text.utils import crop_text_region +from pipeline.data_process.utils.cv_utils import crop_box_from_image + + +class RecPostNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super(RecPostNode, self).__init__(args, msg_queue, tqdm_info) + self.rec_postprocess = RecPostprocess(args) + self.task_type = self.args.task_type + + def init_self_args(self): + super().init_self_args() + + + def process(self, input_data): + if input_data.skip: + self.send_to_next_module(input_data) + return + + data = input_data.data["rec_infer_res"] + pred = data["pred"] + output = self.rec_postprocess(pred) + texts = output["texts"] + confs = output["confs"] + if self.task_type.value == TaskType.REC.value: + input_data.infer_result = output["texts"] + else: + for i, (text, conf) in enumerate(zip(texts, confs)): + input_data.infer_result[i].append(text) + input_data.infer_result[i].append(conf) + self.send_to_next_module(input_data) diff --git a/mindocr/infer/recognition/rec_pre_node.py b/mindocr/infer/recognition/rec_pre_node.py new file mode 100644 index 000000000..e77b170cd --- /dev/null +++ b/mindocr/infer/recognition/rec_pre_node.py @@ -0,0 +1,41 @@ +import argparse +import os +import time +import sys +import numpy as np + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.framework.module_base import ModuleBase +from pipeline.tasks import TaskType +from .recognition import RecPreprocess + + +class RecPreNode(ModuleBase): + def __init__(self, args, msg_queue, tqdm_info): + super(RecPreNode, self).__init__(args, msg_queue, tqdm_info) + self.rec_preprocesser = 
RecPreprocess(args) + self.task_type = self.args.task_type + + def init_self_args(self): + super().init_self_args() + return {"batch_size": 1} + + def process(self, input_data): + if input_data.skip: + self.send_to_next_module(input_data) + return + + if self.task_type.value == TaskType.REC.value: + image = input_data.frame[0] + data = [self.rec_preprocesser(image)["image"]] + input_data.sub_image_size = 1 + input_data.data = {"rec_pre_res": data} + self.send_to_next_module(input_data) + else: + sub_image_list = input_data.sub_image_list + data = [self.rec_preprocesser(split_image)["image"] for split_image in sub_image_list] + input_data.sub_image_size = len(sub_image_list) + input_data.data["rec_pre_res"] = data + self.send_to_next_module(input_data) diff --git a/mindocr/infer/recognition/recognition.py b/mindocr/infer/recognition/recognition.py new file mode 100644 index 000000000..79993b467 --- /dev/null +++ b/mindocr/infer/recognition/recognition.py @@ -0,0 +1,68 @@ +import logging +import os +import time +import sys +import numpy as np +import yaml +from addict import Dict +from typing import List + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from tools.infer.text.utils import get_ckpt_file +from mindocr.data.transforms import create_transforms, run_transforms +from mindocr.postprocess import build_postprocess +from mindocr.infer.utils.model import MSModel, LiteModel + + +algo_to_model_name = { + "CRNN": "crnn_resnet34", + "RARE": "rare_resnet34", + "CRNN_CH": "crnn_resnet34_ch", + "RARE_CH": "rare_resnet34_ch", + "SVTR": "svtr_tiny", + "SVTR_PPOCRv3_CH": "svtr_ppocrv3_ch", +} +logger = logging.getLogger("mindocr") + +class RecPreprocess(object): + def __init__(self, args): + self.args = args + with open(args.rec_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + self.transforms = create_transforms(self.yaml_cfg.predict.dataset.transform_pipeline) + + def __call__(self, img): + data = {"image": img} + # ZHQ TODO: [1:] ??? 
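+        # the first transform in the predict pipeline is assumed to be DecodeImage,
+        # which is skipped ([1:]) because `img` is already a decoded array here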
+ data = run_transforms(data, self.transforms[1:]) + return data + + +class RecModelMS(MSModel): + def __init__(self, args): + self.args = args + self.model_name = algo_to_model_name[args.rec_algorithm] + self.config_path = args.rec_config_path + self._init_model(self.model_name, self.config_path) + + +class RecModelLite(LiteModel): + def __init__(self, args): + self.args = args + self.model_name = algo_to_model_name[args.rec_algorithm] + self.config_path = args.rec_config_path + self._init_model(self.model_name, self.config_path) + +INFER_REC_MAP = {"MindSporeLite": RecModelLite, "MindSpore": RecModelMS} + +class RecPostprocess(object): + def __init__(self, args): + self.args = args + with open(args.rec_model_name_or_config, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + self.postprocessor = build_postprocess(self.yaml_cfg.postprocess) + + def __call__(self, pred): + return self.postprocessor(pred) \ No newline at end of file diff --git a/mindocr/infer/utils/collector.py b/mindocr/infer/utils/collector.py new file mode 100644 index 000000000..d77cb37c2 --- /dev/null +++ b/mindocr/infer/utils/collector.py @@ -0,0 +1,22 @@ +# For each image, use one collector +class Collector: + def __init__(self): + self.collect_keys = [] + self.collect_value = dict() + def init_keys(self, key_list): + self.collect_keys.clear() + self.collect_value.clear() + self.collect_keys = key_list + for k in key_list: + self.collect_value[k] = None + def update(self, key, value): + self.collect_value[key] = value + if key in self.collect_keys: + self.collect_keys.remove(key) + def complete(self): + if len(self.collect_keys) == 0: + return True + else: + return False + def get_data(self): + return list(self.collect_value.values()) \ No newline at end of file diff --git a/mindocr/infer/utils/model.py b/mindocr/infer/utils/model.py new file mode 100644 index 000000000..a9fa10536 --- /dev/null +++ b/mindocr/infer/utils/model.py @@ -0,0 +1,123 @@ +import os +from collections import defaultdict +from ctypes import c_uint64 +from multiprocessing import Manager + +from abc import ABCMeta, abstractmethod +import sys +import numpy as np +import yaml +from addict import Dict + +import logging + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from tools.infer.text.utils import get_ckpt_file +from mindocr.models.builder import build_model +from typing import List + +logger = logging.getLogger("mindocr") + +class BaseModel(metaclass=ABCMeta): + def __init__(self, args) -> None: + self.model = None + self.args = args + self.pretrained = True + self.ckpt_load_path = "" + self.amp_level = "O0" + + @abstractmethod + def __call__(self, inputs: List[np.ndarray]) -> List[np.ndarray]: + pass + + @abstractmethod + def _init_model(self, model_name, config_path): + pass + + +class MSModel(BaseModel): + def __init__(self, args) -> None: + super().__init__(args) + + def _init_model(self, model_name, config_path): + global ms + import mindspore as ms + + self.config_path = config_path + with open(self.config_path, "r") as f: + self.yaml_cfg = Dict(yaml.safe_load(f)) + self.ckpt_load_path = self.yaml_cfg.predict.ckpt_load_path + if self.ckpt_load_path is None: + self.pretrained = True + self.ckpt_load_path = None + else: + self.pretrained = False + self.ckpt_load_path = get_ckpt_file(self.ckpt_load_path) + + ms.set_context(device_target=self.yaml_cfg.predict.get("device_target", "Ascend")) + ms.set_context(device_id=self.yaml_cfg.predict.get("device_id", 0)) + 
ms.set_context(mode=self.yaml_cfg.predict.get("mode", 0))
+        if self.yaml_cfg.system.get("distribute", False):
+            ms.communication.init()
+            ms.set_auto_parallel_context(parallel_mode=ms.ParallelMode.DATA_PARALLEL)
+        if self.yaml_cfg.predict.get("max_device_memory", None):
+            ms.set_context(max_device_memory=self.yaml_cfg.predict.get("max_device_memory"))
+        self.amp_level = self.yaml_cfg.predict.get("amp_level", "O0")
+        if ms.get_context("device_target") == "GPU" and self.amp_level == "O3":
+            logger.warning(
+                "Model prediction does not support amp_level O3 on GPU currently. "
+                "The program has switched to amp_level O2 automatically."
+            )
+            self.amp_level = "O2"
+        self.model = build_model(
+            model_name,
+            ckpt_load_path=self.ckpt_load_path,
+            amp_level=self.amp_level,
+        )
+        self.model.set_train(False)
+        logger.info(
+            "Init mindspore model: {}. Model weights loaded from {}".format(
+                model_name, "pretrained url" if self.pretrained else self.ckpt_load_path
+            )
+        )
+
+    def __call__(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
+        input_ms = [ms.Tensor.from_numpy(inp) for inp in inputs]
+        output = self.model(*input_ms)
+        outputs = [output.asnumpy()]
+        return outputs
+
+
+class LiteModel(BaseModel):
+    def __init__(self, args) -> None:
+        super().__init__(args)
+
+    def _init_model(self, model_name, config_path):
+        global mslite
+        import mindspore_lite as mslite
+
+        self.config_path = config_path
+        with open(self.config_path, "r") as f:
+            self.yaml_cfg = Dict(yaml.safe_load(f))
+        self.ckpt_load_path = self.yaml_cfg.predict.ckpt_load_path
+        context = mslite.Context()
+        device_target = self.yaml_cfg.predict.get("device_target", "Ascend")
+        context.target = [device_target.lower()]
+        if device_target.lower() == "ascend":
+            context.ascend.device_id = self.yaml_cfg.predict.get("device_id", 0)
+        elif device_target.lower() == "gpu":
+            context.gpu.device_id = self.yaml_cfg.predict.get("device_id", 0)
+        self.model = mslite.Model()
+        self.model.build_from_file(self.ckpt_load_path, mslite.ModelType.MINDIR, context)
+
+    def __call__(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
+        model_inputs = self.model.get_inputs()
+        inputs_shape = [list(inp.shape) for inp in inputs]
+        self.model.resize(model_inputs, inputs_shape)
+        for i, inp in enumerate(inputs):
+            model_inputs[i].set_data_from_numpy(inp)
+        model_outputs = self.model.predict(model_inputs)
+        outputs = [output.get_data_to_numpy().copy() for output in model_outputs]
+        return outputs
diff --git a/mindocr/losses/det_loss.py b/mindocr/losses/det_loss.py
index 23ca8f4e2..cc97a3210 100644
--- a/mindocr/losses/det_loss.py
+++ b/mindocr/losses/det_loss.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from math import pi
 from typing import Tuple, Union
@@ -10,6 +11,8 @@
 __all__ = ["DBLoss", "PSEDiceLoss", "EASTLoss", "FCELoss"]
 _logger = logging.getLogger(__name__)
 
+OFFLINE_MODE = os.getenv("OFFLINE_MODE", None)
+
 
 class DBLoss(nn.LossBase):
     """
@@ -165,7 +168,13 @@ def construct(self, pred: Tensor, gt: Tensor, mask: Tensor) -> Tensor:
         neg_loss = (loss * negative).view(loss.shape[0], -1)
         neg_vals, _ = ops.sort(neg_loss)
-        neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
+
+        if OFFLINE_MODE is None:
+            neg_index = ops.stack((mnp.arange(loss.shape[0]), neg_vals.shape[1] - neg_count), axis=1)
+        else:
+            neg_index = ops.stack(
+                (ops.arange(loss.shape[0], dtype=neg_count.dtype), neg_vals.shape[1] - neg_count), axis=1
+            )
         min_neg_score = ops.expand_dims(ops.gather_nd(neg_vals, neg_index), axis=1)
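+        # hard negative mining: `min_neg_score` is the threshold of the `neg_count`
+        # largest negative losses; smaller values are masked out in the next line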
neg_loss_mask = (neg_loss >= min_neg_score).astype(ms.float32) # filter values less than top k diff --git a/mindocr/losses/rec_loss.py b/mindocr/losses/rec_loss.py index 09ee8caec..88cb4e7f5 100644 --- a/mindocr/losses/rec_loss.py +++ b/mindocr/losses/rec_loss.py @@ -1,3 +1,5 @@ +import os + import numpy as np import mindspore as ms @@ -6,6 +8,8 @@ __all__ = ["CTCLoss", "AttentionLoss", "VisionLANLoss"] +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + class CTCLoss(LossBase): """ @@ -147,14 +151,21 @@ class AttentionLoss(LossBase): def __init__(self, reduction: str = "mean", ignore_index: int = 0) -> None: super().__init__() # ignore symbol, assume it is placed at 0th index - self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index) + if OFFLINE_MODE is None: + self.criterion = nn.CrossEntropyLoss(reduction=reduction, ignore_index=ignore_index) + else: + self.reduction = reduction + self.ignore_index = ignore_index def construct(self, logits: Tensor, labels: Tensor) -> Tensor: labels = labels[:, 1:] # without symbol num_classes = logits.shape[-1] logits = ops.reshape(logits, (-1, num_classes)) labels = ops.reshape(labels, (-1,)) - return self.criterion(logits, labels) + if OFFLINE_MODE is None: + return self.criterion(logits, labels) + else: + return ops.cross_entropy(logits, labels, reduction=self.reduction, ignore_index=self.ignore_index) class SARLoss(LossBase): diff --git a/mindocr/models/__init__.py b/mindocr/models/__init__.py index 19c21a153..6cd0f6a1a 100644 --- a/mindocr/models/__init__.py +++ b/mindocr/models/__init__.py @@ -14,6 +14,7 @@ from .rec_robustscanner import * from .rec_svtr import * from .rec_visionlan import * +from .layout_yolov8n import * __all__ = [] __all__.extend(builder.__all__) diff --git a/mindocr/models/layout_yolov8n.py b/mindocr/models/layout_yolov8n.py new file mode 100644 index 000000000..6faa21db4 --- /dev/null +++ b/mindocr/models/layout_yolov8n.py @@ -0,0 +1,56 @@ +from ._registry import register_model +from .backbones.mindcv_models.utils import load_pretrained +from .base_model import BaseModel + +__all__ = ['Yolov8n', 'layout_yolov8n'] + +def _cfg(url="", **kwargs): + return {"url": url, **kwargs} + + +default_cfgs = { + "layout_yolov8n": _cfg( + url="https://download.mindspore.cn/toolkits/mindocr/yolov8/yolov8n-4b9e8004.ckpt" + ), +} + +class Yolov8n(BaseModel): + def __init__(self, config): + BaseModel.__init__(self, config) + + +@register_model +def layout_yolov8n(pretrained=False, pretrained_backbone=True, **kwargs): + backbone_ckpt_url = 'https://download.mindspore.cn/toolkits/mindocr/yolov8/yolov8n-4b9e8004.ckpt' + model_config = { + "backbone": { + 'name': 'yolov8_backbone', + "depth_multiple": 0.33, + "width_multiple": 0.25, + "max_channels": 1024, + "nc": 5, + "stride": [ 8, 16, 32, 64 ], + "sync_bn": False, + "out_channels": [ 64, 128, 192, 256 ], + 'pretrained': backbone_ckpt_url if pretrained_backbone else False + }, + "neck": { + "name": 'YOLOv8Neck', + "index": [ 20, 23, 26, 29 ] + }, + "head": { + "name": 'YOLOv8Head', + "nc": 5, + "reg_max": 16, + "stride": [ 8, 16, 32, 64 ], + "sync_bn": False, + } + } + model = Yolov8n(model_config) + + # load pretrained weights + if pretrained: + default_cfg = default_cfgs['layout_yolov8n'] + load_pretrained(model, default_cfg) + + return model \ No newline at end of file diff --git a/mindocr/models/necks/fpn.py b/mindocr/models/necks/fpn.py index 650395554..32a628ce8 100644 --- a/mindocr/models/necks/fpn.py +++ b/mindocr/models/necks/fpn.py @@ -1,3 +1,4 @@ +import os 
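+# OFFLINE_MODE (read from the environment below) selects an export-friendly resize implementation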
from typing import List, Tuple from mindspore import Tensor, nn, ops @@ -7,14 +8,20 @@ from ..utils.attention_cells import SEModule from .asf import AdaptiveScaleFusion +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) -def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None): - if scale == 1 or shape == x.shape[2:]: - return x - if scale: - shape = (x.shape[2] * scale, x.shape[3] * scale) - return ops.ResizeNearestNeighbor(shape)(x) +if OFFLINE_MODE is None: + def _resize_nn(x: Tensor, scale: int = 0, shape: Tuple[int] = None): + if scale == 1 or shape == x.shape[2:]: + return x + + if scale: + shape = (x.shape[2] * scale, x.shape[3] * scale) + return ops.ResizeNearestNeighbor(shape)(x) +else: + def _resize_nn(x: Tensor, shape: Tensor): + return ops.ResizeNearestNeighborV2()(x, shape) class FPN(nn.Cell): @@ -64,11 +71,18 @@ def construct(self, features: List[Tensor]) -> Tensor: for i, uc_op in enumerate(self.unify_channels): features[i] = uc_op(features[i]) - for i in range(2, -1, -1): - features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:]) + if OFFLINE_MODE is None: + for i in range(2, -1, -1): + features[i] += _resize_nn(features[i + 1], shape=features[i].shape[2:]) + + for i, out in enumerate(self.out): + features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:]) + else: + for i in range(2, -1, -1): + features[i] += _resize_nn(features[i + 1], shape=ops.dyn_shape(features[i])[2:]) - for i, out in enumerate(self.out): - features[i] = _resize_nn(out(features[i]), shape=features[0].shape[2:]) + for i, out in enumerate(self.out): + features[i] = _resize_nn(out(features[i]), shape=ops.dyn_shape(features[0])[2:]) return self.fuse(features[::-1]) # matching the reverse order of the original work diff --git a/mindocr/models/transforms/tps_spatial_transformer.py b/mindocr/models/transforms/tps_spatial_transformer.py index 006f72420..9227ef895 100644 --- a/mindocr/models/transforms/tps_spatial_transformer.py +++ b/mindocr/models/transforms/tps_spatial_transformer.py @@ -1,4 +1,5 @@ import itertools +import os from typing import Optional, Tuple import numpy as np @@ -8,6 +9,8 @@ import mindspore.ops as ops from mindspore import Tensor +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + def grid_sample(input: Tensor, grid: Tensor, canvas: Optional[Tensor] = None) -> Tensor: out_type = input.dtype @@ -112,6 +115,9 @@ def __init__( self.target_coordinate_repr = Tensor(target_coordinate_repr, dtype=ms.float32) self.target_control_points = Tensor(target_control_points, dtype=ms.float32) + if OFFLINE_MODE is not None: + self.matmul = ops.BatchMatMul() + def construct( self, input: Tensor, source_control_points: Tensor ) -> Tuple[Tensor, Tensor]: @@ -119,8 +125,12 @@ def construct( padding_matrix = ops.tile(self.padding_matrix, (batch_size, 1, 1)) Y = ops.concat([source_control_points, padding_matrix], axis=1) - mapping_matrix = ops.matmul(self.inverse_kernel, Y) - source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix) + if OFFLINE_MODE is None: + mapping_matrix = ops.matmul(self.inverse_kernel, Y) + source_coordinate = ops.matmul(self.target_coordinate_repr, mapping_matrix) + else: + mapping_matrix = self.matmul(self.inverse_kernel[None, ...], Y) + source_coordinate = self.matmul(self.target_coordinate_repr[None, ...], mapping_matrix) grid = ops.reshape( source_coordinate, (-1, self.target_height, self.target_width, 2), diff --git a/mindocr/models/utils/attention_cells.py b/mindocr/models/utils/attention_cells.py index 016001085..b3f14dc24 
100644 --- a/mindocr/models/utils/attention_cells.py +++ b/mindocr/models/utils/attention_cells.py @@ -1,3 +1,4 @@ +import os from typing import Optional, Tuple import numpy as np @@ -9,6 +10,8 @@ __all__ = ["MultiHeadAttention", "PositionwiseFeedForward", "PositionalEncoding", "SEModule"] +OFFLINE_MODE = os.getenv("OFFLINE_MODE", None) + class MultiHeadAttention(nn.Cell): def __init__( @@ -108,9 +111,14 @@ def __init__( self.pe = Tensor(pe, dtype=ms.float32) def construct(self, input_tensor: Tensor) -> Tensor: - input_tensor = ( - input_tensor + self.pe[:, : input_tensor.shape[1]] - ) # pe 1 5000 512 + if OFFLINE_MODE is None: + input_tensor = ( + input_tensor + self.pe[:, : input_tensor.shape[1]] + ) # pe 1 5000 512 + else: + input_tensor = ( + input_tensor + self.pe[:, : ops.dyn_shape(input_tensor)[1]] + ) # pe 1 5000 512 return self.dropout(input_tensor) diff --git a/mindocr/postprocess/det_db_postprocess.py b/mindocr/postprocess/det_db_postprocess.py index ccac15110..f49c9ef8b 100644 --- a/mindocr/postprocess/det_db_postprocess.py +++ b/mindocr/postprocess/det_db_postprocess.py @@ -5,11 +5,14 @@ from shapely.geometry import Polygon from mindspore import Tensor +import pyclipper +import logging +_logger = logging.getLogger(__name__) from ..data.transforms.det_transforms import expand_poly from .det_base_postprocess import DetBasePostprocess -__all__ = ["DBPostprocess"] +__all__ = ["DBPostprocess", "DBV4Postprocess"] class DBPostprocess(DetBasePostprocess): @@ -177,3 +180,383 @@ def _calc_score(pred, mask, contour): pred[min_vals[1] : max_vals[1] + 1, min_vals[0] : max_vals[0] + 1], mask[min_vals[1] : max_vals[1] + 1, min_vals[0] : max_vals[0] + 1].astype(np.uint8), )[0] + + +class DBV4Postprocess(DetBasePostprocess): + """ + The post process for DBNet, adapted to paddleocrV4. + """ + + def __init__( + self, + binary_thresh: float = 0.3, + box_thresh: float = 0.7, + max_candidates: int = 1000, + expand_ratio: float = 1.5, + box_type: str = "quad", + pred_name: str = "binary", + rescale_fields: List[str] = ["polys"], + if_merge_longedge_bbox: bool = True, + merge_inter_area_thres: int = 300, + merge_ratio: float = 1.3, + merge_angle_theta: float = 10, + if_sort_bbox: bool = True, + sort_bbox_y_delta: int = 10, + ): + super().__init__(rescale_fields, box_type) + + self._min_size = 3 + self._binary_thresh = binary_thresh + self._box_thresh = box_thresh + self._max_candidates = max_candidates + self._expand_ratio = expand_ratio + self._if_merge_longedge_bbox = if_merge_longedge_bbox + self._merge_inter_area_thres = merge_inter_area_thres + self._merge_ratio = merge_ratio + self._merge_angle_theta = merge_angle_theta + self._if_sort_bbox = if_sort_bbox + self._sort_bbox_y_delta = sort_bbox_y_delta + self._out_poly = box_type == "poly" + self._name = pred_name + self._names = {"binary": 0, "thresh": 1, "thresh_binary": 2} + + def __call__( + self, + pred: Union[Tensor, Tuple[Tensor], np.ndarray], + shape_list: Union[np.ndarray, Tensor] = None, + **kwargs, + ) -> dict: + if isinstance(shape_list, Tensor): + shape_list = shape_list.asnumpy() + if shape_list is not None: + assert shape_list.shape[0] and shape_list.shape[1] == 4, ( + "The shape of each item in shape_list must be 4: [raw_img_h, raw_img_w, scale_h, scale_w]. " + f"But got shape_list of shape {shape_list.shape}" + ) + else: + _logger.warning( + "`shape_list` is None in postprocessing. Cannot rescale the prediction result to original " + "image space, which can lead to inaccurate evaluation. 
You may add `shape_list` to `output_columns` " + "list under eval section in yaml config file, or directly parse `shape_list` to postprocess callable " + "function." + ) + self.warned = True + result = self._postprocess(pred, shape_list=shape_list) + src_w, src_h = shape_list[0, 1], shape_list[0, 0] + polys = self.filter_tag_det_res(result["polys"][0], [src_h, src_w]) + if self._if_merge_longedge_bbox: + try: + polys = longedge_bbox_merge( + polys, self._merge_inter_area_thres, self._merge_ratio, self._merge_angle_theta + ) + except Exception as e: + _logger.warning(f"long edge bbox merge failed: {e}") + if self._if_sort_bbox: + polys = sorted_boxes(polys, self._sort_bbox_y_delta) + result["polys"][0] = polys + result["scores"].clear() + return result + + def _postprocess(self, pred: Union[Tensor, Tuple[Tensor], np.ndarray], **kwargs) -> dict: + """ + Postprocess network prediction to get text boxes on the transformed image space (which will be rescaled back to + original image space in __call__ function) + + Args: + pred (Union[Tensor, Tuple[Tensor], np.ndarray]): network prediction consists of + binary: text region segmentation map, with shape (N, 1, H, W) + thresh: [if exists] threshold prediction with shape (N, 1, H, W) (optional) + thresh_binary: [if exists] binarized with threshold, (N, 1, H, W) (optional) + + Returns: + postprocessing result as a dict with keys: + - polys (List[np.ndarray]): predicted polygons on the **transformed** (i.e. resized normally) image + space, of shape (batch_size, num_polygons, num_points, 2). If `box_type` is 'quad', num_points=4. + - scores (np.ndarray): confidence scores for the predicted polygons, shape (batch_size, num_polygons) + """ + if isinstance(pred, tuple): + pred = pred[self._names[self._name]] + if isinstance(pred, Tensor): + pred = pred.asnumpy() + if len(pred.shape) == 4 and pred.shape[1] != 1: # pred shape (N, 3, H, W) + pred = pred[:, 0, :, :] + + if len(pred.shape) == 4: # handle pred shape: (N, H, W) skip + pred = pred.squeeze(1) + + segmentation = pred >= self._binary_thresh + polys, scores = [], [] + src_w, src_h = kwargs["shape_list"][0, 1], kwargs["shape_list"][0, 0] + for pr, segm in zip(pred, segmentation): + poly, score = self.boxes_from_bitmap(pr, segm, src_w, src_h) + polys.append(poly) + scores.append(score) + return {"polys": polys, "scores": scores} + + def unclip(self, box, _expand_ratio): + poly = Polygon(box) + distance = poly.area * _expand_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + def get_mini_boxes(self, contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [points[index_1], points[index_2], points[index_3], points[index_4]] + return box, min(bounding_box[1]) + + def box_score_fast(self, bitmap, _box): + """ + box_score_fast: use bbox mean score as the mean score + """ + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1) + ymax = 
np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1) + return cv2.mean(bitmap[ymin : ymax + 1, xmin : xmax + 1], mask)[0] + + def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + bitmap = _bitmap + height, width = bitmap.shape + + outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + if len(outs) == 3: + contours = outs[1] + elif len(outs) == 2: + contours = outs[0] + + num_contours = min(len(contours), self._max_candidates) + + boxes = [] + scores = [] + for index in range(num_contours): + contour = contours[index] + points, sside = self.get_mini_boxes(contour) + if sside < self._min_size: + continue + points = np.array(points) + + score = self.box_score_fast(pred, points.reshape(-1, 2)) + if self._box_thresh > score: + continue + + box = self.unclip(points, self._expand_ratio).reshape(-1, 1, 2) + box, sside = self.get_mini_boxes(box) + if sside < self._min_size + 2: + continue + box = np.array(box) + + box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip(np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.astype("int32")) + scores.append(score) + return np.array(boxes, dtype="int32"), scores + + def filter_tag_det_res(self, dt_boxes, image_shape): + img_height, img_width = image_shape[0:2] + dt_boxes_new = [] + for box in dt_boxes: + if type(box) is list: + box = np.array(box) + box = self.order_points_clockwise(box) + box = self.clip_det_res(box, img_height, img_width) + rect_width = int(np.linalg.norm(box[0] - box[1])) + rect_height = int(np.linalg.norm(box[0] - box[3])) + if rect_width <= 3 or rect_height <= 3: + continue + dt_boxes_new.append(box) + dt_boxes = np.array(dt_boxes_new) + return dt_boxes + + def order_points_clockwise(self, pts): + rect = np.zeros((4, 2), dtype="float32") + s = pts.sum(axis=1) + rect[0] = pts[np.argmin(s)] + rect[2] = pts[np.argmax(s)] + tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0) + diff = np.diff(np.array(tmp), axis=1) + rect[1] = tmp[np.argmin(diff)] + rect[3] = tmp[np.argmax(diff)] + return rect + + def clip_det_res(self, points, img_height, img_width): + for pno in range(points.shape[0]): + points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1)) + points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1)) + return points + + +def longedge_bbox_merge(boxes, merge_inter_area_thres=300, merge_ratio=1.3, merge_angle_theta=10): + """ + Merge long-edge bboxes according the following rule: + - inter area larger than `merge_inter_area_thres` + - delta of long edge slope of minimum outer rectangle larger than `merge_angle_theta` + - short edge of merged boxes smaller than `merge_ratio` times short edge of boxes + args: + boxes(array): boxes to be merge, shape: (N, 4, 2). N: Number of bboxes + return: + merged boxes(array): merged boxes, shape: (N2, 4, 2). 
N2: Number of merged bboxes + """ + ori_boxes = [box.tolist() for box in boxes] + ori_poly = [Polygon(box) for box in ori_boxes] + minrec_poly = [poly.minimum_rotated_rectangle for poly in ori_poly] + + merge_list = [] + merge_minrec = [] + check_merge = False + + while not check_merge or len(merge_list) > 0: + merge_list.clear() + merge_minrec.clear() + + for i in range(len(ori_boxes)): + for j in range(i + 1, len(ori_boxes)): + # inter area judgement + inter_area = ori_poly[i].intersection(ori_poly[j]).area + uij = ori_poly[i].union(ori_poly[j]) + if inter_area < merge_inter_area_thres: + continue + minrec_i_theta, minrec_i_short_len = 0, 0 + minrec_i_xs, minrec_i_ys = minrec_poly[i].exterior.coords.xy + minrec_i_edge1_len = np.sqrt( + (minrec_i_xs[1] - minrec_i_xs[0]) ** 2 + (minrec_i_ys[1] - minrec_i_ys[0]) ** 2 + ) + minrec_i_edge1_theta = np.arctan( + (minrec_i_ys[1] - minrec_i_ys[0]) / (minrec_i_xs[1] - minrec_i_xs[0] + 1e-5) + ) + minrec_i_edge2_len = np.sqrt( + (minrec_i_xs[2] - minrec_i_xs[1]) ** 2 + (minrec_i_ys[2] - minrec_i_ys[1]) ** 2 + ) + minrec_i_edge2_theta = np.arctan( + (minrec_i_ys[2] - minrec_i_ys[1]) / (minrec_i_xs[2] - minrec_i_xs[1] + 1e-5) + ) + + if minrec_i_edge2_len > minrec_i_edge1_len: + minrec_i_theta = minrec_i_edge2_theta + minrec_i_short_len = minrec_i_edge1_len + else: + minrec_i_theta = minrec_i_edge1_theta + minrec_i_short_len = minrec_i_edge2_len + + minrec_j_theta, minrec_j_short_len = 0, 0 + minrec_j_xs, minrec_j_ys = minrec_poly[j].exterior.coords.xy + minrec_j_edge1_len = np.sqrt( + (minrec_j_xs[1] - minrec_j_xs[0]) ** 2 + (minrec_j_ys[1] - minrec_j_ys[0]) ** 2 + ) + minrec_j_edge1_theta = np.arctan( + (minrec_j_ys[1] - minrec_j_ys[0]) / (minrec_j_xs[1] - minrec_j_xs[0] + 1e-5) + ) + minrec_j_edge2_len = np.sqrt( + (minrec_j_xs[2] - minrec_j_xs[1]) ** 2 + (minrec_j_ys[2] - minrec_j_ys[1]) ** 2 + ) + minrec_j_edge2_theta = np.arctan( + (minrec_j_ys[2] - minrec_j_ys[1]) / (minrec_j_xs[2] - minrec_j_xs[1] + 1e-5) + ) + + if minrec_j_edge2_len > minrec_j_edge1_len: + minrec_j_theta = minrec_j_edge2_theta + minrec_j_short_len = minrec_j_edge1_len + else: + minrec_j_theta = minrec_j_edge1_theta + minrec_j_short_len = minrec_j_edge2_len + + # slope judgement + if np.abs(minrec_j_theta - minrec_i_theta) > merge_angle_theta / 180 * np.pi: + continue + + # short edge judgement + minrec_u = uij.minimum_rotated_rectangle + minrec_u_xs, minrec_u_ys = minrec_u.exterior.coords.xy + minrec_u_edge1_len = np.sqrt( + (minrec_u_xs[1] - minrec_u_xs[0]) ** 2 + (minrec_u_ys[1] - minrec_u_ys[0]) ** 2 + ) + minrec_u_edge2_len = np.sqrt( + (minrec_u_xs[2] - minrec_u_xs[1]) ** 2 + (minrec_u_ys[2] - minrec_u_ys[1]) ** 2 + ) + minrec_u_short_len = min(minrec_u_edge1_len, minrec_u_edge2_len) + if minrec_u_short_len > merge_ratio * max(minrec_i_short_len, minrec_j_short_len): + continue + + merge_list.append([i, j]) + merge_minrec.append(minrec_u) + + if len(merge_minrec) > 0: + ori_boxes = [ori_boxes[i] for i in range(len(ori_boxes)) if i not in merge_list[0]] + ori_poly = [ori_poly[i] for i in range(len(ori_poly)) if i not in merge_list[0]] + minrec_poly = [minrec_poly[i] for i in range(len(minrec_poly)) if i not in merge_list[0]] + + poly = merge_minrec[0] + xs, ys = poly.exterior.coords.xy + xs = xs.tolist() + ys = ys.tolist() + + index = np.argsort(np.linalg.norm(np.array([xs[:-1], ys[:-1]]).T, ord=2, axis=1))[0] + ori_boxes.append( + [ + [xs[index % 4], ys[index % 4]], + [xs[(index + 1) % 4], ys[(index + 1) % 4]], + [xs[(index + 2) % 4], ys[(index + 2) % 4]], + [xs[(index 
+ 3) % 4], ys[(index + 3) % 4]], + ] + ) + + ori_poly.append(Polygon(ori_boxes[-1])) + minrec_poly.append(ori_poly[-1].minimum_rotated_rectangle) + + check_merge = True + return np.array(ori_boxes) + + +def sorted_boxes(dt_boxes, sort_bbox_y_delta): + """ + Sort text boxes in order from top to bottom, left to right + args: + dt_boxes(array):detected text boxes with shape [4, 2] + sort_bbox_y_delta:further sort boxes whose dy smaller than sort_bbox_y_delta + return: + sorted boxes(array) with shape [4, 2] + """ + num_boxes = len(dt_boxes) + sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0])) + _boxes = list(sorted_boxes) + + for i in range(num_boxes - 1): + for j in range(i, -1, -1): + if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < sort_bbox_y_delta and ( + _boxes[j + 1][0][0] < _boxes[j][0][0] + ): + tmp = _boxes[j] + _boxes[j] = _boxes[j + 1] + _boxes[j + 1] = tmp + else: + break + return _boxes diff --git a/mindocr/postprocess/layout_postprocess.py b/mindocr/postprocess/layout_postprocess.py index 6c219a285..5d1d7d1e2 100644 --- a/mindocr/postprocess/layout_postprocess.py +++ b/mindocr/postprocess/layout_postprocess.py @@ -259,7 +259,7 @@ def scale_coords(img1_shape, coords, img0_shape, ratio=None, pad=None): if ratio is None: # calculate from img0_shape ratio = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # ratio = old / new - else: + if isinstance(ratio, (list, np.ndarray)): ratio = ratio[0] if pad is None: diff --git a/pipeline/__init__.py b/pipeline/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pipeline/data_process/utils/cv_utils.py b/pipeline/data_process/utils/cv_utils.py new file mode 100644 index 000000000..54dec7729 --- /dev/null +++ b/pipeline/data_process/utils/cv_utils.py @@ -0,0 +1,72 @@ +import os +from typing import List, Tuple + +import cv2 +import numpy as np + + +def get_hw_of_img(image: np.ndarray): + """ + get the hw of hwc image + """ + if len(image.shape) == 3: + # gbr/rgb + height, width, _ = image.shape + elif len(image.shape) == 2: + # gray + height, width = image.shape + else: + raise TypeError("image is not a image of color/gray") + + return height, width + + +def get_batch_hw_of_img(images: List[np.ndarray]) -> Tuple: + return tuple(get_hw_of_img(img) for img in images) + + +def crop_box_from_image(image, box): + if box.shape != (4, 2): + raise ValueError("shape of crop box must be 4*2") + box = box.astype(np.float32) + img_crop_width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3]))) + img_crop_height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2]))) + pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]]) + m = cv2.getPerspectiveTransform(box, pts_std) + dst_img = cv2.warpPerspective( + image, m, (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC + ) + dst_img_height, dst_img_width = dst_img.shape[0:2] + if dst_img_width != 0 and dst_img_height / dst_img_width >= 1.5: + dst_img = np.rot90(dst_img) + + return dst_img + + +def img_read(path: str): + """ + Read a BGR image. + """ + img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_COLOR) + + if img is None: + raise ValueError(f"Error! 
Cannot load image from {path}")
+
+    return img
+
+
+def img_write(path: str, img: np.ndarray):
+    filename = os.path.abspath(path)
+    cv2.imencode(os.path.splitext(filename)[1], img)[1].tofile(filename)
+
+
+def check_type_in_container(input_data, t, skip_last=False):
+    """Return True if every element of `input_data` (optionally excluding the last) is an instance of `t`."""
+    check_data = input_data[:-1] if skip_last else input_data
+    for data in check_data:
+        if not isinstance(data, t):
+            return False
+    return True
diff --git a/pipeline/datatype/__init__.py b/pipeline/datatype/__init__.py
new file mode 100644
index 000000000..f8469deb2
--- /dev/null
+++ b/pipeline/datatype/__init__.py
@@ -0,0 +1,3 @@
+from .message_data import ProfilingData, StopSign
+from .module_data import ModuleConnectDesc, ModuleDesc, ModuleInitArgs
+from .process_data import ProcessData, StopData
diff --git a/pipeline/datatype/message_data.py b/pipeline/datatype/message_data.py
new file mode 100644
index 000000000..d3a665269
--- /dev/null
+++ b/pipeline/datatype/message_data.py
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+
+
+@dataclass
+class StopSign:
+    stop: bool = True
+
+
+@dataclass
+class ProfilingData:
+    module_name: str = ""
+    instance_id: int = -1
+    device_id: int = 0
+    process_cost_time: float = 0.0
+    send_cost_time: float = 0.0
+    image_total: int = -1
diff --git a/pipeline/datatype/module_data.py b/pipeline/datatype/module_data.py
new file mode 100644
index 000000000..d5c82207b
--- /dev/null
+++ b/pipeline/datatype/module_data.py
@@ -0,0 +1,45 @@
+from dataclasses import dataclass, field
+from enum import Enum
+
+
+class ConnectType(Enum):
+    MODULE_CONNECT_ONE = 0
+    MODULE_CONNECT_CHANNEL = 1
+    MODULE_CONNECT_PAIR = 2
+    MODULE_CONNECT_RANDOM = 3
+
+
+@dataclass
+class ModuleOutputInfo:
+    module_name: str
+    connect_type: ConnectType
+    output_queue_list_size: int
+    output_queue_list: list = field(default_factory=list)
+
+
+@dataclass
+class ModuleInitArgs:
+    pipeline_name: str
+    module_type: str
+    module_name: str
+    instance_id: int = -1
+
+
+@dataclass
+class ModuleDesc:
+    module_type: str  # node type, e.g. HandoutNode
+    module_name: str  # node name, e.g. 1; a node is uniquely identified by {module_type}{module_name}
+    module_count: int
+
+
+@dataclass
+class ModuleConnectDesc:
+    module_send_name: str
+    module_recv_name: str
+    connect_type: ConnectType = field(default_factory=lambda: ConnectType.MODULE_CONNECT_RANDOM)
+
+
+@dataclass
+class ModulesInfo:
+    module_list: list = field(default_factory=list)
+    input_queue_list: list = field(default_factory=list)
diff --git a/pipeline/datatype/process_data.py b/pipeline/datatype/process_data.py
new file mode 100644
index 000000000..b615572cf
--- /dev/null
+++ b/pipeline/datatype/process_data.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass, field
+from typing import Dict, List, Union
+
+import numpy as np
+
+
+@dataclass
+class ProcessData:
+    # whether downstream compute nodes should skip this item
+    skip: bool = False
+    # prediction results of each image
+    infer_result: list = field(default_factory=list)
+
+    # image basic info
+    image_path: List[str] = field(default_factory=list)
+    frame: List[np.ndarray] = field(default_factory=list)
+
+    # sub images of detection boxes, for det (+ cls) + rec
+    sub_image_total: int = 0  # len(sub_image_list_0) + len(sub_image_list_1) + ...
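+    # sub_image_list holds the crops of the current frame; sub_image_size (below) is its length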
+ sub_image_list: list = field(default_factory=lambda: []) + sub_image_size: int = 0 # len of sub_image_list + + # data for preprocess -> infer -> postprocess + data: Union[np.ndarray, List[np.ndarray], Dict] = None + + # confidence of the result from rec + score: float = field(default_factory=lambda: []) + + # the images fed into the ocr system in the same call, share the same taskid + taskid: int = 0 + + # number of images shared the same taskid + task_images_num: int = 0 + + # data type: raw input is string path or np.ndarray. 0: string path, 1: np.ndarray + data_type: int = 0 + + +@dataclass +class StopData: + skip: bool = True + image_total: int = 0 + exception: bool = False diff --git a/pipeline/framework/__init__.py b/pipeline/framework/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pipeline/framework/module_base.py b/pipeline/framework/module_base.py new file mode 100644 index 000000000..5243589b4 --- /dev/null +++ b/pipeline/framework/module_base.py @@ -0,0 +1,135 @@ +import os +import tqdm +import sys +import time +from abc import abstractmethod +from ctypes import c_longdouble +from multiprocessing import Manager + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.datatype import ModuleInitArgs, ProfilingData +from pipeline.datatype import StopData, StopSign +from pipeline.utils import log, safe_div + + +class ModuleBase(object): + def __init__(self, args, msg_queue, tqdm_info): + self.args = args + self.pipeline_name = "" + self.module_name = "" + self.without_input_queue = False + self.instance_id = 0 + self.is_stop = False + self.msg_queue = msg_queue + self.input_queue = None + self.output_queue = None + self.send_cost = Manager().Value(typecode=c_longdouble, value=0) + self.process_cost = Manager().Value(typecode=c_longdouble, value=0) + self.display_id = tqdm_info["i"] + if self.args.visual_pipeline is True: + self.bar = tqdm.tqdm(total=tqdm_info["queue_len"], + desc=f"{self.display_id}. {self.module_name}", + position=self.display_id, + leave=False, + bar_format="{l_bar}{bar}|{n_fmt}/{total_fmt}", + ncols=150) + + def assign_init_args(self, init_args: ModuleInitArgs): + self.pipeline_name = init_args.pipeline_name + self.module_name = init_args.module_name + self.instance_id = init_args.instance_id + + def process_handler(self, stop_manager, module_params, input_queue, output_queue): + self.input_queue = input_queue + self.output_queue = output_queue + self.stop_manager = stop_manager + self.queue_num = 0 + + try: + params = self.init_self_args() + if params: + module_params.update(**params) + except Exception as error: + log.error(f"{self.__class__.__name__} init failed: {error}") + raise error + + # waiting for init sign + while not self.msg_queue.full(): + continue + + # waiting for the release of stop sign + while self.stop_manager.value: + continue + + process_num = 0 + + while True: + time.sleep(self.args.node_fetch_interval) + if self.stop_manager.value: + break + if self.input_queue.empty(): + continue + + process_num += 1 + data = self.input_queue.get(block=True) + if self.args.visual_pipeline is True: + qsize = self.input_queue.qsize() + delta = qsize - self.queue_num + self.bar.update(delta) + self.queue_num = qsize + info = f"{self.display_id}. 
Node:{self.module_name}, Has Processed:{process_num}, " + \ + f"Process Time:{self.process_cost.value - self.send_cost.value:.2f} s, " + \ + f"Wait Time:{self.send_cost.value:.2f} s, Queue Status:" + info = info.ljust(85, " ") + self.bar.set_description(info) + self.call_process(data) + if self.args.visual_pipeline is True: + self.bar.close() + + def call_process(self, send_data=None): + if send_data is not None or self.without_input_queue: + start_time = time.time() + try: + self.process(send_data) + except Exception as error: + self.process(StopData(exception=True)) + image_path = [os.path.basename(filename) for filename in send_data.image_path] + log.exception(f"ERROR occurred in {self.module_name} module for {', '.join(image_path)}: {error}.") + + cost_time = time.time() - start_time + self.process_cost.value += cost_time + + @abstractmethod + def process(self, input_data): + pass + + @abstractmethod + def init_self_args(self): + self.msg_queue.put(f"{self.__class__.__name__} instance id {self.instance_id} init complete") + log.info(f"{self.__class__.__name__} instance id {self.instance_id} init complete") + + def send_to_next_module(self, output_data): + if self.is_stop: + return + start_time = time.time() + self.output_queue.put(output_data, block=True) + cost_time = time.time() - start_time + self.send_cost.value += cost_time + + def get_module_name(self): + return self.module_name + + def get_instance_id(self): + return self.instance_id + + def stop(self): + profiling_data = ProfilingData( + module_name=self.module_name, + instance_id=self.instance_id, + process_cost_time=self.process_cost.value, + send_cost_time=self.send_cost.value, + ) + self.msg_queue.put(profiling_data, block=False) + self.is_stop = True diff --git a/pipeline/framework/module_manager.py b/pipeline/framework/module_manager.py new file mode 100644 index 000000000..358667677 --- /dev/null +++ b/pipeline/framework/module_manager.py @@ -0,0 +1,163 @@ +import os +import tqdm +import sys + +from collections import defaultdict, namedtuple +from ctypes import c_bool +from multiprocessing import Manager, Process, Queue + +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../"))) + +from pipeline.datatype.module_data import ModuleInitArgs, ModulesInfo +from pipeline.utils import log +from mindocr.infer.node_config import processor_initiator + +OutputRegisterInfo = namedtuple("OutputRegisterInfo", ["pipeline_name", "module_send", "module_recv"]) + + +class ModuleManager: + MODULE_QUEUE_MAX_SIZE = 16 + + def __init__(self, msg_queue: Queue, task_queue: Queue, result_queue: Queue, args): + self.pipeline_map = defaultdict(lambda: defaultdict(ModulesInfo)) + self.msg_queue = msg_queue + self.stop_manager = Manager().Value(c_bool, True) + self.args = args + self.pipeline_name = "" + self.process_list = [] + self.queue_list = [] + self.pipeline_queue_map = defaultdict(lambda: defaultdict(list)) + self.task_queue = task_queue # input_queue for HandoutNode + self.result_queue = result_queue # output_queue for CollectNode + self.module_params = Manager().dict() + + @staticmethod + def stop_module(module): + module.stop() + + @staticmethod + def init_module_instance(module_instance, instance_id, pipeline_name, module_type, module_name): + init_args = ModuleInitArgs(pipeline_name=pipeline_name, + module_name=module_name, + module_type=module_type, + instance_id=instance_id) + module_instance.assign_init_args(init_args) + + def register_modules(self, pipeline_name: str, 
+        log.info("----------------------------------------------------")
+        log.info("---------------register_modules start---------------")
+        modules_info_dict = self.pipeline_map[pipeline_name]
+        self.pipeline_name = pipeline_name
+
+        for i, module_desc in enumerate(module_desc_list):
+            log.info("+++++++++++++++++++++++++++++++++++++")
+            log.info(module_desc)
+            log.info("+++++++++++++++++++++++++++++++++++++")
+            module_count = default_count if module_desc.module_count == -1 else module_desc.module_count
+            module_info = ModulesInfo()
+            for instance_id in range(module_count):
+                if i == 0:  # HandoutNode
+                    tqdm_info = {"i": i, "queue_len": self.task_queue._maxsize}
+                else:
+                    tqdm_info = {"i": i, "queue_len": self.MODULE_QUEUE_MAX_SIZE}
+                module_instance = processor_initiator(module_desc.module_type)(self.args, self.msg_queue, tqdm_info)
+                self.init_module_instance(module_instance,
+                                          instance_id,
+                                          pipeline_name,
+                                          module_desc.module_type,
+                                          module_desc.module_name)
+
+                module_info.module_list.append(module_instance)
+            modules_info_dict[module_desc.module_name] = module_info
+
+        self.pipeline_map[pipeline_name] = modules_info_dict
+
+        log.info("----------------register_modules end---------------")
+        log.info("----------------------------------------------------")
+
+    def register_module_connects(self, pipeline_name: str, connect_desc_list: list):
+        if pipeline_name not in self.pipeline_map:
+            return
+
+        log.info("----------------------------------------------------")
+        log.info("-----------register_module_connects start-----------")
+
+        modules_info_dict = self.pipeline_map[pipeline_name]
+        connect_info_dict = self.pipeline_queue_map[pipeline_name]
+        last_module = None
+        for connect_desc in connect_desc_list:
+            send_name = connect_desc.module_send_name
+            recv_name = connect_desc.module_recv_name
+            log.info("+++++++++++++++++++++++++++++++++++++")
+            log.info(f"Add Connection Between {send_name} And {recv_name}")
+            log.info("+++++++++++++++++++++++++++++++++++++")
+
+            if send_name not in modules_info_dict:
+                raise ValueError(f"cannot find send module {send_name}")
+
+            if recv_name not in modules_info_dict:
+                raise ValueError(f"cannot find receive module {recv_name}")
+
+            queue = Queue(self.MODULE_QUEUE_MAX_SIZE)
+            self.queue_list.append(queue)  # keep a reference so deinit can drain and close it
+            connect_info_dict[send_name].append(queue)
+            connect_info_dict[recv_name].append(queue)
+            last_module = recv_name
+        connect_info_dict[last_module].append(self.result_queue)
+
+        log.info("------------register_module_connects end------------")
+        log.info("----------------------------------------------------")
+
+    def run_pipeline(self):
+        log.info("-------------- start pipeline-----------------------")
+        log.info("----------------------------------------------------")
+
+        for pipeline_name in self.pipeline_map.keys():
+            modules_info_dict = self.pipeline_map[pipeline_name]
+            connect_info_dict = self.pipeline_queue_map[pipeline_name]
+            for module_name in modules_info_dict.keys():
+                queue_list = connect_info_dict[module_name]
+                if len(queue_list) == 1:  # the first module only has an output queue; it reads from task_queue
+                    input_queue = self.task_queue
+                    output_queue = queue_list[0]
+                else:
+                    input_queue = queue_list[0]
+                    output_queue = queue_list[1]
+
+                for module in modules_info_dict[module_name].module_list:
+                    self.process_list.append(
+                        Process(
+                            target=module.process_handler,
+                            args=(self.stop_manager, self.module_params, input_queue, output_queue),
+                            daemon=True,
+                        )
+                    )
+
+        for process in self.process_list:
+            process.start()
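+
+    # Illustrative sketch (not part of the original change; node names are
+    # assumptions): for a det+rec task, the wiring built above is a chain of
+    # per-edge queues,
+    #
+    #   task_queue -> [HandoutNode] -> q0 -> [DetNode] -> q1 -> [RecNode] -> result_queue
+    #
+    # Each module's queue list holds [input_queue, output_queue] in
+    # registration order; only the first module has a single entry, which is
+    # why it falls back to reading task_queue.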
+
+    def deinit_pipeline_module(self):
+        # Queue.empty() is not reliable; double-check that the msg queue is
+        # drained before collecting the profiling data
+        while not self.msg_queue.empty():
+            self.msg_queue.get()
+
+        for queue in self.queue_list:
+            while not queue.empty():
+                queue.get(block=False)
+            queue.close()
+            queue.join_thread()
+
+        # ask every module to send its profiling data
+        for pipeline_name in self.pipeline_map.keys():
+            modules_info_dict = self.pipeline_map[pipeline_name]
+            for module_name in modules_info_dict.keys():
+                for module in modules_info_dict[module_name].module_list:
+                    self.stop_module(module=module)
+
+        # release all resources
+        for process in self.process_list:
+            if process.is_alive():
+                process.kill()
+
+        log.info("------------------pipeline stopped------------------")
+        log.info("----------------------------------------------------")
diff --git a/pipeline/framework/pipeline_manager.py b/pipeline/framework/pipeline_manager.py
new file mode 100644
index 000000000..414700d67
--- /dev/null
+++ b/pipeline/framework/pipeline_manager.py
@@ -0,0 +1,153 @@
+import argparse
+import os
+import time
+import sys
+from collections import defaultdict
+from multiprocessing import Manager, Process, Queue
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../../")))
+
+from pipeline.utils import log, safe_div
+from pipeline.datatype import ModuleConnectDesc, ModuleDesc
+from pipeline.datatype import StopSign
+from pipeline.framework.module_manager import ModuleManager
+from pipeline.tasks import SUPPORTED_TASK_BASIC_MODULE, TaskType
+# ZHQ TODO
+from mindocr.infer.node_config import MODEL_DICT_v2 as MODEL_DICT
+
+
+class ParallelPipelineManager:
+    TASK_QUEUE_SIZE = 32
+
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+        self.input_queue = Queue(self.TASK_QUEUE_SIZE)
+        self.result_queue = Queue(self.TASK_QUEUE_SIZE)
+        self.process = Process(target=self._build_pipeline_kernel)
+        self.module_params = Manager().dict()
+
+    def start_pipeline(self):
+        self.process.start()
+        self.input_queue.get(block=True)  # blocks until the kernel signals that init is complete
+
+    def stop_pipeline(self):
+        self.input_queue.put(StopSign(), block=True)
+        self.process.join()
+        self.process.close()
+
+    def fetch_result(self):
+        if not self.result_queue.empty():
+            rst_data = self.result_queue.get(block=True)
+        else:
+            rst_data = None
+        return rst_data
+
+    def pipeline_graph(self, task_type):
+        module_order = SUPPORTED_TASK_BASIC_MODULE[TaskType(task_type.value)]
+        module_desc_names_set = set()
+        module_desc_list = []
+        module_connect_desc_list = []
+
+        for model_name in module_order:
+            for edge in MODEL_DICT.get(model_name, []):
+                # add the nodes of this edge
+                src_node_info, tgt_node_info = edge
+                src_node_name = src_node_info[0] + src_node_info[1]
+                if src_node_name not in module_desc_names_set:
+                    module_desc_list.append(ModuleDesc(src_node_info[0], src_node_name, src_node_info[2]))
+                    module_desc_names_set.add(src_node_name)
+                tgt_node_name = tgt_node_info[0] + tgt_node_info[1]
+                if tgt_node_name not in module_desc_names_set:
+                    module_desc_list.append(ModuleDesc(tgt_node_info[0], tgt_node_name, tgt_node_info[2]))
+                    module_desc_names_set.add(tgt_node_name)
+                module_connect_desc_list.append(ModuleConnectDesc(src_node_name, tgt_node_name))
+        module_size = sum(desc.module_count for desc in module_desc_list)
+        log.info(f"module_size: {module_size}")
+        return module_order, module_size, module_desc_list, module_connect_desc_list
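+
+    # Illustrative sketch (assumed data, inferred from the indexing above; not
+    # an excerpt of mindocr.infer.node_config): MODEL_DICT maps a task to a
+    # list of edges, each node being a (module_type, name_suffix, count) tuple:
+    #
+    #     MODEL_DICT = {
+    #         TaskType.DET: [
+    #             (("HandoutNode", "0", 1), ("DetPreNode", "0", 1)),
+    #             (("DetPreNode", "0", 1), ("DetInferNode", "0", 1)),
+    #             (("DetInferNode", "0", 1), ("DetPostNode", "0", 1)),
+    #             (("DetPostNode", "0", 1), ("CollectNode", "0", 1)),
+    #         ],
+    #     }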
+
+    def _build_pipeline_kernel(self):
+        """
+        build and register the pipeline, then drive it until a stop is requested
+        """
+        task_type = self.args.task_type
+
+        module_order, module_size, module_desc_list, module_connect_desc_list = self.pipeline_graph(task_type)
+
+        msg_queue = Queue(module_size)
+
+        manager = ModuleManager(msg_queue, self.input_queue, self.result_queue, self.args)
+        manager.register_modules(str(os.getpid()), module_desc_list, 1)
+        manager.register_module_connects(str(os.getpid()), module_connect_desc_list)
+
+        # start the pipeline, init start
+        manager.run_pipeline()
+
+        # wait until every module has reported init complete and published its params
+        while not msg_queue.full() or len(manager.module_params) != len(module_order):
+            time.sleep(0.1)
+
+        for _ in range(module_size):
+            msg_queue.get()
+
+        self.module_params.update(**manager.module_params)
+
+        # signal the parent process (blocked in start_pipeline) that init is done
+        self.input_queue.put(StopSign(), block=True)
+
+        manager.stop_manager.value = False
+
+        start_time = time.time()
+
+        while not manager.stop_manager.value:
+            time.sleep(self.args.node_fetch_interval)
+
+        cost_time = time.time() - start_time
+
+        manager.deinit_pipeline_module()
+        # collect the profiling data
+        profiling_data = defaultdict(lambda: [0, 0])
+        image_total = 0
+        for _ in range(module_size):
+            msg_info = msg_queue.get()
+            profiling_data[msg_info.module_name][0] += msg_info.process_cost_time
+            profiling_data[msg_info.module_name][1] += msg_info.send_cost_time
+            if msg_info.module_name != -1:
+                image_total = msg_info.image_total
+        if image_total > 0:
+            self.profiling(profiling_data, image_total)
+        perf_info = (
+            f"Number of images: {image_total}, "
+            f"total cost {cost_time:.2f}s, FPS: "
+            f"{safe_div(image_total, cost_time):.2f}"
+        )
+        print(perf_info)
+        log.info(perf_info)
+
+        msg_queue.close()
+        msg_queue.join_thread()
+
+    def profiling(self, profiling_data, image_total):
+        e2e_cost_time_per_image = 0
+        for module_name in profiling_data:
+            data = profiling_data[module_name]
+            total_time = data[0]
+            process_time = data[0] - data[1]
+            send_time = data[1]
+            process_avg = safe_div(process_time * 1000, image_total)
+            e2e_cost_time_per_image += process_avg
+            log.info(
+                f"{module_name} cost total {total_time:.2f} s, process avg cost {process_avg:.2f} ms, "
+                f"send waiting time avg cost {safe_div(send_time * 1000, image_total):.2f} ms"
+            )
+        log.info("----------------------------------------------------")
+        log.info(f"e2e cost time per image {e2e_cost_time_per_image:.2f} ms")
+
+    def __del__(self):
+        if hasattr(self, "process") and self.process:
+            self.process.close()
diff --git a/pipeline/infer.py b/pipeline/infer.py
new file mode 100644
index 000000000..029d0951c
--- /dev/null
+++ b/pipeline/infer.py
@@ -0,0 +1,20 @@
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../")))
+
+from pipeline import infer_args  # noqa
+from pipeline.parallel_pipeline import ParallelPipeline  # noqa
+
+
+def main():
+    args = infer_args.get_args()
+    parallel_pipeline = ParallelPipeline(args)
+    parallel_pipeline.start_pipeline()
+    parallel_pipeline.infer_for_images(args.input_images_dir, task_id=0)
+    parallel_pipeline.stop_pipeline()
+
+
+if __name__ == "__main__":
+    main()
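+
+# Illustrative sketch (not part of the original change): results can also be
+# consumed programmatically. fetch_result() returns None whenever the result
+# queue is momentarily empty, so a consumer has to poll, e.g.:
+#
+#     import time
+#     pipeline = ParallelPipeline(args)
+#     pipeline.start_pipeline()
+#     pipeline.infer_for_images(args.input_images_dir, task_id=0)
+#     while results_are_pending:  # hypothetical bookkeeping kept by the caller
+#         result = pipeline.fetch_result()
+#         if result is None:
+#             time.sleep(0.1)
+#         else:
+#             print(result)
+#     pipeline.stop_pipeline()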
diff --git a/pipeline/infer_args.py b/pipeline/infer_args.py
new file mode 100644
index 000000000..ac87b2f9f
--- /dev/null
+++ b/pipeline/infer_args.py
@@ -0,0 +1,290 @@
+import argparse
+import itertools
+import os
+import sys
+import yaml
+
+from addict import Dict
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../")))
+
+from pipeline.tasks import TaskType
+from pipeline.utils import get_config_by_name_for_model, log, save_path_init
+
+
+def str2bool(v):
+    # argparse's type=bool treats every non-empty string (including "False") as
+    # truthy, so boolean flags must go through this converter instead
+    return v.lower() in ("true", "t", "1")
+
+
+def get_args():
+    """
+    command line parameters for inference
+    """
+    parser = argparse.ArgumentParser(description="Arguments for inference.")
+    parser.add_argument(
+        "--input_images_dir",
+        type=str,
+        required=True,
+        help="Image or folder path for inference.",
+    )
+
+    parser.add_argument(
+        "--node_fetch_interval",
+        type=float,
+        default=0.001,
+        required=False,
+        help="Interval (seconds) at which each node fetches data from its queue.",
+    )
+
+    parser.add_argument(
+        "--result_contain_score",
+        type=str2bool,
+        default=False,
+        required=False,
+        help="Whether to save the confidence score to the output result.",
+    )
+
+    parser.add_argument(
+        "--det_algorithm",
+        type=str,
+        default="DB++",
+        choices=["DB", "DB++", "DB_MV3", "DB_PPOCRv3", "PSE"],
+        help="Detection algorithm.",
+    )  # determines the network architecture
+    parser.add_argument(
+        "--det_model_name_or_config", type=str, required=False, help="Detection model name or config file path."
+    )
+
+    parser.add_argument(
+        "--cls_algorithm",
+        type=str,
+        default="MV3",
+        choices=["MV3"],
+        help="Classification algorithm.",
+    )  # determines the network architecture
+    parser.add_argument(
+        "--cls_model_name_or_config", type=str, required=False, help="Classification model name or config file path."
+    )
+
+    parser.add_argument(
+        "--rec_algorithm",
+        type=str,
+        default="CRNN",
+        choices=["CRNN", "RARE", "CRNN_CH", "RARE_CH", "SVTR", "SVTR_PPOCRv3_CH"],
+        help="Recognition algorithm.",
+    )
+    parser.add_argument(
+        "--rec_model_name_or_config", type=str, required=False, help="Recognition model name or config file path."
+    )
+
+    parser.add_argument(
+        "--layout_algorithm",
+        type=str,
+        default="YOLOV8",
+        choices=["YOLOV8"],
+        help="Layout algorithm.",
+    )  # determines the network architecture
+    parser.add_argument(
+        "--layout_model_name_or_config", type=str, required=False, help="Layout model name or config file path."
+    )
+
+    parser.add_argument(
+        "--character_dict_path", type=str, required=False, help="Character dict file path for recognition models."
+    )
+
+    parser.add_argument(
+        "--res_save_dir",
+        type=str,
+        default="inference_results",
+        required=False,
+        help="Saving dir for inference results.",
+    )
+
+    parser.add_argument(
+        "--input_array_save_dir",
+        type=str,
+        required=False,
+        help="Saving dir for input arrays.",
+    )
+
+    parser.add_argument(
+        "--vis_det_save_dir", type=str, required=False, help="Saving dir for visualization of detection results."
+    )
+
+    parser.add_argument(
+        "--vis_layout_save_dir", type=str, required=False, help="Saving dir for visualization of layout results."
+    )
+
+    parser.add_argument(
+        "--vis_pipeline_save_dir",
+        type=str,
+        required=False,
+        help="Saving dir for visualization of det+cls(optional)+rec pipeline inference results.",
+    )
+    parser.add_argument(
+        "--crop_save_dir", type=str, required=False, help="Saving dir for images cropped from detection results."
+    )
+    parser.add_argument(
+        "--show_log", type=str2bool, default=False, required=False, help="Whether to show logs when inferring."
+    )
+    parser.add_argument("--save_log_dir", type=str, required=False, help="Log saving dir.")
+    font_default_path = os.path.join(__dir__, "../docs/fonts/simfang.ttf")
+    parser.add_argument(
+        "--vis_font_path",
+        type=str,
+        default=font_default_path,
+        required=False,
+        help="Font file path for visualizing recognition results.",
+    )
+    parser.add_argument(
+        "--visual_pipeline",
+        type=str2bool,
+        default=False,
+        required=False,
+        help="Whether to visualize pipeline progress.",
+    )
+    parser.add_argument(
+        "--is_concat", type=str2bool, default=False, help="Whether to concatenate crops after the detection."
+    )
+    args = parser.parse_args()
+    setup_logger(args)
+    args = update_task_info(args)
+    # check_and_update_args(args)
+    init_save_dir(args)
+
+    return args
+
+
+def setup_logger(args):
+    """
+    initialize the log system
+    """
+    log.init_logger(args.show_log, args.save_log_dir)
+    log.save_args(args)
+
+
+def update_task_info(args):
+    """
+    add internal parameters according to the task type
+    """
+    det = bool(args.det_model_name_or_config)
+    cls = bool(args.cls_model_name_or_config)
+    rec = bool(args.rec_model_name_or_config)
+    layout = bool(args.layout_model_name_or_config)
+
+    task_map = {
+        (True, False, False, False): TaskType.DET,
+        (False, True, False, False): TaskType.CLS,
+        (False, False, True, False): TaskType.REC,
+        (True, False, True, False): TaskType.DET_REC,
+        (True, True, True, False): TaskType.DET_CLS_REC,
+        (False, False, False, True): TaskType.LAYOUT,
+        (True, False, True, True): TaskType.LAYOUT_DET_REC,
+        (True, True, True, True): TaskType.LAYOUT_DET_CLS_REC,
+    }
+
+    task_order = (det, cls, rec, layout)
+    if task_order in task_map:
+        setattr(args, "task_type", task_map[task_order])
+    else:
+        unsupported_task_map = {
+            (False, False, False, False): "empty",
+            (True, True, False, False): "det+cls",
+            (False, True, True, False): "cls+rec",
+        }
+
+        raise ValueError(
+            f"Only det, cls, rec, det+rec, det+cls+rec, layout, layout+det+rec and layout+det+cls+rec "
+            f"are supported, but got {unsupported_task_map.get(task_order, 'an unsupported combination')}. "
+            f"Please check model_path!"
+        )
+
+    if args.det_model_name_or_config:
+        setattr(args, "det_config_path", get_config_by_name_for_model(args.det_model_name_or_config))
+    else:
+        setattr(args, "det_config_path", None)
+    if args.cls_model_name_or_config:
+        setattr(args, "cls_config_path", get_config_by_name_for_model(args.cls_model_name_or_config))
+    else:
+        setattr(args, "cls_config_path", None)
+    if args.rec_model_name_or_config:
+        setattr(args, "rec_config_path", get_config_by_name_for_model(args.rec_model_name_or_config))
+    else:
+        setattr(args, "rec_config_path", None)
+    if args.layout_model_name_or_config:
+        setattr(args, "layout_config_path", get_config_by_name_for_model(args.layout_model_name_or_config))
+    else:
+        setattr(args, "layout_config_path", None)
+
+    return args
+
+
+def check_file(name, file):
+    if not os.path.exists(file):
+        raise ValueError(f"{name} must be a file, but {file} doesn't exist.")
+    if not os.path.isfile(file):
+        raise ValueError(f"{name} must be a file, but got a dir of {file}.")
+
+
+def check_positive(name, value):
+    if value < 1:
+        raise ValueError(f"{name} must be positive, but got {value}.")
+
+
+def check_and_update_args(args):
+    """
+    check parameters
+    """
+    if not args.input_images_dir or not os.path.exists(args.input_images_dir):
+        raise ValueError("input_images_dir must be a directory of images or the path of a single image.")
+
+    if args.crop_save_dir and args.task_type not in (TaskType.DET_REC, TaskType.DET_CLS_REC):
+        raise ValueError("det and rec models can't be empty when crop_save_dir is set.")
+
+    if args.vis_pipeline_save_dir and args.task_type not in (TaskType.DET_REC,
+                                                             TaskType.DET_CLS_REC, TaskType.LAYOUT_DET_CLS_REC):
+        raise ValueError("det and rec models can't be empty when vis_pipeline_save_dir is set.")
+
+    if args.vis_det_save_dir and args.task_type not in (TaskType.DET, TaskType.LAYOUT):
+        raise ValueError(
+            "det model can't be empty and cls/rec models must be empty when vis_det_save_dir is set "
+            "for a single detection task."
+        )
+
+    if not args.res_save_dir:
+        raise ValueError("res_save_dir can't be empty.")
+
+    need_check_file = {
+        "det_config_path": args.det_config_path,
+        "cls_config_path": args.cls_config_path,
+        "rec_config_path": args.rec_config_path,
+        "layout_config_path": args.layout_config_path,
+    }
+    for name, path in need_check_file.items():
+        if path:
+            check_file(name, path)
+            with open(path) as fp:
+                yaml_cfg = Dict(yaml.safe_load(fp))
+                check_file(name, yaml_cfg.predict.ckpt_load_path)
+                check_positive(name, yaml_cfg.predict.loader.batch_size)
+
+    need_check_dir_not_same = {
+        "input_images_dir": args.input_images_dir,
+        "crop_save_dir": args.crop_save_dir,
+        "vis_pipeline_save_dir": args.vis_pipeline_save_dir,
+        "vis_det_save_dir": args.vis_det_save_dir,
+    }
+    for (name1, dir1), (name2, dir2) in itertools.combinations(need_check_dir_not_same.items(), 2):
+        if (dir1 and dir2) and os.path.realpath(os.path.normcase(dir1)) == os.path.realpath(os.path.normcase(dir2)):
+            raise ValueError(f"{name1} and {name2} can't be the same path.")
+
+    return args
+
+
+def init_save_dir(args):
+    if args.res_save_dir:
+        save_path_init(args.res_save_dir, exist_ok=True)
+    if args.crop_save_dir:
+        save_path_init(args.crop_save_dir)
+    if args.vis_pipeline_save_dir:
+        save_path_init(args.vis_pipeline_save_dir)
+    if args.vis_det_save_dir:
+        save_path_init(args.vis_det_save_dir)
+    if args.save_log_dir:
+        save_path_init(args.save_log_dir, exist_ok=True)
diff --git a/pipeline/parallel_pipeline.py b/pipeline/parallel_pipeline.py
new file mode 100644
index 000000000..93115e84a
--- /dev/null
+++ b/pipeline/parallel_pipeline.py
@@ -0,0 +1,98 @@
+import argparse
+import os
+import sys
+
+import numpy as np
+import tqdm
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "../")))
+
+from pipeline.framework.pipeline_manager import ParallelPipelineManager
+from pipeline.data_process.utils import cv_utils
+from pipeline.tasks import TaskType
+
+
+class ParallelPipeline:
+    def __init__(self, args: argparse.Namespace):
+        self.args = args
+        self.pipeline_manager = ParallelPipelineManager(args)
+        self.input_queue = self.pipeline_manager.input_queue
+        self.infer_params = {}
+
+    def start_pipeline(self):
+        self.pipeline_manager.start_pipeline()
+
+    def stop_pipeline(self):
+        self.pipeline_manager.stop_pipeline()
+
+    def infer_for_images(self, input_images_dir, task_id=0):
+        self.infer_params = dict(**self.pipeline_manager.module_params)
+        self.send_image(input_images_dir, task_id)
+
+    def fetch_result(self):
+        return self.pipeline_manager.fetch_result()
+
+    def send_image(self, images: str, task_id=0):
+        """
+        send images to the pipeline input queue
+        """
+        if not (os.path.isdir(images) or os.path.isfile(images)):
+            raise ValueError("images must be an image path or a directory.")
+
+        # det, det(+cls)+rec
+        batch_num = 1
+
+        # cls, rec, layout
+        if self.args.task_type in (TaskType.CLS, TaskType.REC, TaskType.LAYOUT):
+            for name, value in self.infer_params.items():
+                if name.endswith("_batch_num"):
+                    batch_num = max(value)
+
+        self._send_batch_image(images, batch_num, task_id)
+
+    def _send_batch_image(self, images, batch_num, task_id):
+        if os.path.isdir(images):
+            show_progressbar = not self.args.show_log
+            input_image_list = [os.path.join(images, path) for path in os.listdir(images) if not path.endswith(".txt")]
+            images_num = len(input_image_list)
+            for i in (
+                tqdm.tqdm(range(images_num), desc="send image to pipeline") if show_progressbar else range(images_num)
+            ):
+                if i % batch_num == 0:
+                    batch_images = input_image_list[i : i + batch_num]
+                    self.input_queue.put((batch_images, (images_num, task_id)), block=True)
+        else:
+            self.input_queue.put([[images], (1, task_id)], block=True)
+
+    def infer_for_array(self, input_array, task_id=0):
+        self.infer_params = dict(**self.pipeline_manager.module_params)
+        self.send_array(input_array, task_id)
+
+    def send_array(self, images, task_id):
+        if isinstance(images, np.ndarray):
+            self._send_batch_array([images], 1, task_id)
+        elif isinstance(images, (tuple, list)):
+            if len(images) == 0:
+                return
+            if not cv_utils.check_type_in_container(images, np.ndarray):
+                raise ValueError("unknown input data: images should be an np.ndarray, or a tuple/list of np.ndarray")
+            # cls, rec, layout
+            batch_num = 1
+            if self.args.task_type in (TaskType.CLS, TaskType.REC, TaskType.LAYOUT):
+                for name, value in self.infer_params.items():
+                    if name.endswith("_batch_num"):
+                        batch_num = max(value)
+            self._send_batch_array(images, batch_num, task_id)
+        else:
+            raise ValueError(f"unknown input data: {type(images)}")
+
+    def _send_batch_array(self, images, batch_num, task_id):
+        show_progressbar = not self.args.show_log
+        images_num = len(images)
+        for i in tqdm.tqdm(range(images_num), desc="send image to pipeline") if show_progressbar else range(images_num):
+            if i % batch_num == 0:
+                batch_images = images[i : i + batch_num]
+                self.input_queue.put([batch_images, (images_num, task_id)], block=True)
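+
+
+# Illustrative usage sketch (not part of the original change); `args` is the
+# namespace produced by pipeline/infer_args.py:
+#
+#     import cv2
+#     pipeline = ParallelPipeline(args)
+#     pipeline.start_pipeline()
+#     img = cv2.imread("example.jpg")  # H x W x 3 uint8 array
+#     pipeline.infer_for_array([img], task_id=0)
+#     pipeline.stop_pipeline()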
diff --git a/pipeline/tasks/__init__.py b/pipeline/tasks/__init__.py
new file mode 100644
index 000000000..e4d2f0a2c
--- /dev/null
+++ b/pipeline/tasks/__init__.py
@@ -0,0 +1,26 @@
+from enum import Enum
+
+
+class TaskType(Enum):
+    DET = 0  # Detection Model
+    CLS = 1  # Classification Model
+    REC = 2  # Recognition Model
+    DET_REC = 3  # Detection and Recognition Model
+    DET_CLS_REC = 4  # Detection, Classification and Recognition Model
+    LAYOUT = 5  # Layout Model
+    LAYOUT_DET = 6
+    LAYOUT_DET_REC = 7
+    LAYOUT_DET_CLS_REC = 8
+
+
+SUPPORTED_TASK_BASIC_MODULE = {
+    TaskType.DET: [TaskType.DET],
+    TaskType.CLS: [TaskType.CLS],
+    TaskType.REC: [TaskType.REC],
+    TaskType.DET_REC: [TaskType.DET_REC],
+    TaskType.DET_CLS_REC: [TaskType.DET_CLS_REC],
+    TaskType.LAYOUT: [TaskType.LAYOUT],
+    TaskType.LAYOUT_DET: [TaskType.LAYOUT_DET],
+    TaskType.LAYOUT_DET_REC: [TaskType.LAYOUT_DET_REC],
+    TaskType.LAYOUT_DET_CLS_REC: [TaskType.LAYOUT_DET_CLS_REC],
+}
diff --git a/pipeline/utils/__init__.py b/pipeline/utils/__init__.py
new file mode 100644
index 000000000..e04b55d61
--- /dev/null
+++ b/pipeline/utils/__init__.py
@@ -0,0 +1,12 @@
+from .adapted import get_config_by_name_for_model
+from .logger import logger_instance as log
+from .safe_utils import (
+    check_valid_dir,
+    check_valid_file,
+    file_base_check,
+    safe_div,
+    safe_list_writer,
+    save_path_init,
+    suppress_stderr,
+    suppress_stdout,
+)
diff --git a/pipeline/utils/adapted/__init__.py b/pipeline/utils/adapted/__init__.py
new file mode 100644
index 000000000..bbb5a5ea6
--- /dev/null
+++ b/pipeline/utils/adapted/__init__.py
@@ -0,0 +1,41 @@
+import os
+
+import yaml
+
+from .mindocr_models import MINDOCR_CONFIG_PATH, MINDOCR_MODELS
+from .mmocr_models import MMOCR_CONFIG_PATH, MMOCR_MODELS
+from .paddleocr_models import PADDLEOCR_CONFIG_PATH, PADDLEOCR_MODELS
+
+__all__ = ["get_config_by_name_for_model"]
+
+
+def get_config_by_name_for_model(model_name_or_config: str):
+    if os.path.isfile(model_name_or_config):
+        filename = model_name_or_config
+    elif model_name_or_config in MINDOCR_MODELS:
+        filename = os.path.abspath(os.path.join(MINDOCR_CONFIG_PATH, MINDOCR_MODELS[model_name_or_config]))
+    elif model_name_or_config in PADDLEOCR_MODELS:
+        filename = os.path.abspath(os.path.join(PADDLEOCR_CONFIG_PATH, PADDLEOCR_MODELS[model_name_or_config]))
+    elif model_name_or_config in MMOCR_MODELS:
+        filename = os.path.abspath(os.path.join(MMOCR_CONFIG_PATH, MMOCR_MODELS[model_name_or_config]))
+    else:
+        raise ValueError(
+            f"{model_name_or_config} must be a model name or a YAML config file path; "
+            "please check whether the file exists and whether the model name is in the supported models list."
+        )
+
+    with open(filename) as fp:
+        cfg = yaml.safe_load(fp)
+
+    try:
+        # probe for the required keys; a KeyError means the config is unusable
+        cfg["eval"]["dataset"]["transform_pipeline"]
+        cfg["postprocess"]
+    except KeyError:
+        preprocess_desc = "{eval: {dataset: {transform_pipeline: ...}}}"
+        postprocess_desc = "{postprocess: ...}"
+        raise ValueError(
+            f"The YAML config file {filename} must contain preprocess pipeline key {preprocess_desc} "
+            f"and postprocess key {postprocess_desc}."
+        )
+
+    return filename
diff --git a/pipeline/utils/adapted/mindocr_models.py b/pipeline/utils/adapted/mindocr_models.py
new file mode 100644
index 000000000..2858d7435
--- /dev/null
+++ b/pipeline/utils/adapted/mindocr_models.py
@@ -0,0 +1,16 @@
+import os
+
+# three levels up from pipeline/utils/adapted/ is the repository root
+MINDOCR_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../configs"))
+
+MINDOCR_MODELS = {
+    "en_ms_det_dbnet_resnet50": "det/dbnet/db_r50_icdar15.yaml",
+    "en_ms_det_dbnetpp_resnet50": "det/dbnet/dbpp_r50_icdar15.yaml",
+    "en_ms_det_psenet_resnet152": "det/psenet/pse_r152_icdar15.yaml",
+    "en_ms_det_psenet_resnet50": "det/psenet/pse_r50_icdar15.yaml",
+    "en_ms_det_psenet_mobilenetv3": "det/psenet/pse_mv3_icdar15.yaml",
+    "ch_ms_det_psenet_resnet152": "det/psenet/pse_r152_ctw1500.yaml",
+    "en_ms_rec_crnn_resnet34": "rec/crnn/crnn_resnet34.yaml",
+    "en_ms_det_east_resnet50": "det/east/east_r50_icdar15.yaml",
+    "en_ms_det_east_mobilenetv3": "det/east/east_mobilenetv3_icdar15.yaml",
+    "en_ms_rec_visionlan_resnet45": "rec/visionlan/visionlan_resnet45_LA.yaml",
+}
diff --git a/pipeline/utils/adapted/mmocr_models.py b/pipeline/utils/adapted/mmocr_models.py
new file mode 100644
index 000000000..2970390a6
--- /dev/null
+++ b/pipeline/utils/adapted/mmocr_models.py
@@ -0,0 +1,13 @@
+import os
+
+# pipeline-local configs directory (pipeline/configs)
+MMOCR_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../configs"))
+
+# fmt: off
+MMOCR_MODELS = {
+    "en_mm_det_dbnetpp_resnet50": "det/mmocr/dbnetpp_resnet50_fpnc_1200e_icdar2015.yaml",  # dbnet++ resnet50
+    "en_mm_det_fcenet_resnet50": "det/mmocr/fcenet_resnet50_fpn_1500e_icdar2015.yaml",  # fcenet resnet50
+    "en_mm_rec_nrtr_resnet31": "rec/mmocr/nrtr_resnet31-1by8-1by4_6e_st_mj.yaml",  # nrtr resnet31
+    "en_mm_rec_satrn_shallowcnn": "rec/mmocr/satrn_shallow_5e_st_mj.yaml",  # satrn shallow
+}
+# fmt: on
diff --git a/pipeline/utils/adapted/paddleocr_models.py b/pipeline/utils/adapted/paddleocr_models.py
new file mode 100644
index 000000000..a1a43cda5
--- /dev/null
+++ b/pipeline/utils/adapted/paddleocr_models.py
@@ -0,0 +1,42 @@
+import os
+
+# pipeline-local configs directory (pipeline/configs)
+PADDLEOCR_CONFIG_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../configs"))
+
+# fmt: off
+PADDLEOCR_MODELS = {
+    "ch_pp_det_OCRv4": "det/ppocr/ch_PP-OCRv4_det_cml.yaml",  # ch_PP-OCRv4_det
+    "ch_pp_server_det_v2.0": "det/ppocr/ch_det_res18_db_v2.0.yaml",  # ch_ppocr_server_v2.0_det
+    "ch_pp_det_OCRv3": "det/ppocr/ch_PP-OCRv3_det_cml.yaml",  # ch_PP-OCRv3_det
+    "ch_pp_server_rec_v2.0": "rec/ppocr/rec_chinese_common_v2.0.yaml",  # ch_ppocr_server_v2.0_rec
+    "ch_pp_rec_OCRv3": "rec/ppocr/ch_PP-OCRv3_rec_distillation.yaml",  # ch_PP-OCRv3_rec
+    "ch_pp_rec_OCRv4": "rec/ppocr/ch_PP-OCRv4_rec_distillation.yaml",  # ch_PP-OCRv4_rec
+    "ch_pp_mobile_cls_v2.0": "cls/ppocr/cls_mv3.yaml",  # ch_ppocr_mobile_v2.0_cls
+    "ch_pp_det_OCRv2": "det/ppocr/ch_PP-OCRv2_det_cml.yaml",  # ch_PP-OCRv2_det
+    "ch_pp_mobile_det_v2.0_slim": "det/ppocr/ch_det_mv3_db_v2.0.yaml",  # ch_ppocr_mobile_slim_v2.0_det
+    "ch_pp_mobile_det_v2.0": "det/ppocr/ch_det_mv3_db_v2.0.yaml",  # ch_ppocr_mobile_v2.0_det
+    "en_pp_det_OCRv3": "det/ppocr/ch_PP-OCRv3_det_cml.yaml",  # en_PP-OCRv3_det
+    "ml_pp_det_OCRv3": "det/ppocr/ch_PP-OCRv3_det_cml.yaml",  # ml_PP-OCRv3_det
+    "ch_pp_rec_OCRv2": "rec/ppocr/ch_PP-OCRv2_rec_distillation.yaml",  # ch_PP-OCRv2_rec
+    "ch_pp_mobile_rec_v2.0": "rec/ppocr/rec_chinese_lite_v2.0.yaml",  # ch_ppocr_mobile_v2.0_rec
+    "en_pp_rec_OCRv3": "rec/ppocr/en_PP-OCRv3_rec.yaml",  # en_PP-OCRv3_rec
+    "en_pp_mobile_rec_number_v2.0_slim": "rec/ppocr/rec_en_number_lite.yaml",  # en_number_mobile_slim_v2.0_rec
+    "en_pp_mobile_rec_number_v2.0": "rec/ppocr/rec_en_number_lite.yaml",  # en_number_mobile_v2.0_rec
+    "korean_pp_rec_OCRv3": "rec/ppocr/korean_PP-OCRv3_rec.yaml",  # korean_PP-OCRv3_rec
+    "japan_pp_rec_OCRv3": "rec/ppocr/japan_PP-OCRv3_rec.yaml",  # japan_PP-OCRv3_rec
+    "chinese_cht_pp_rec_OCRv3": "rec/ppocr/chinese_cht_PP-OCRv3_rec.yaml",  # chinese_cht_PP-OCRv3_rec
+    "te_pp_rec_OCRv3": "rec/ppocr/te_PP-OCRv3_rec.yaml",  # te_PP-OCRv3_rec
+    "ka_pp_rec_OCRv3": "rec/ppocr/ka_PP-OCRv3_rec.yaml",  # ka_PP-OCRv3_rec
+    "ta_pp_rec_OCRv3": "rec/ppocr/ta_PP-OCRv3_rec.yaml",  # ta_PP-OCRv3_rec
+    "latin_pp_rec_OCRv3": "rec/ppocr/latin_PP-OCRv3_rec.yaml",  # latin_PP-OCRv3_rec
+    "arabic_pp_rec_OCRv3": "rec/ppocr/arabic_PP-OCRv3_rec.yaml",  # arabic_PP-OCRv3_rec
+    "cyrillic_pp_rec_OCRv3": "rec/ppocr/cyrillic_PP-OCRv3_rec.yaml",  # cyrillic_PP-OCRv3_rec
+    "devanagari_pp_rec_OCRv3": "rec/ppocr/devanagari_PP-OCRv3_rec.yaml",  # devanagari_PP-OCRv3_rec
+    "en_pp_det_psenet_resnet50vd": "det/ppocr/det_r50_vd_pse.yaml",  # pse_resnet50_vd
+    "en_pp_det_dbnet_resnet50vd": "det/ppocr/det_r50_vd_db.yaml",  # dbnet resnet50_vd
+    "en_pp_det_east_resnet50vd": "det/ppocr/det_r50_vd_east.yaml",  # east resnet50_vd
+    "en_pp_det_sast_resnet50vd": "det/ppocr/det_r50_vd_sast_icdar15.yaml",  # sast resnet50_vd
+    "en_pp_rec_crnn_resnet34vd": "rec/ppocr/rec_r34_vd_none_bilstm_ctc.yaml",  # crnn resnet34_vd
+    "en_pp_rec_rosetta_resnet34vd": "rec/ppocr/rec_r34_vd_none_none_ctc.yaml",  # rosetta resnet34_vd
+    "en_pp_rec_vitstr_vitstr": "rec/ppocr/rec_vitstr_none_ce.yaml",  # vitstr
+}
+# fmt: on
diff --git a/pipeline/utils/logger.py b/pipeline/utils/logger.py
new file mode 100644
index 000000000..bd95fecba
--- /dev/null
+++ b/pipeline/utils/logger.py
@@ -0,0 +1,221 @@
+import argparse
+import logging
+import os
+import sys
+import threading
+import time
+from logging.handlers import RotatingFileHandler
+
+# log level name to number mapping
+_name_to_log_level = {
+    "ERROR": 40,
+    "WARNING": 30,
+    "INFO": 20,
+    "DEBUG": 10,
+}
+
+# MindSpore log level to level name mapping
+_ms_level_to_name = {
+    "3": "ERROR",
+    "2": "WARNING",
+    "1": "INFO",
+    "0": "DEBUG",
+}
+
+MAX_BYTES = 100 * 1024 * 1024
+BACKUP_COUNT = 10
+LOG_TYPE = "mindocr"
+LOG_ENV = "MINDOCR_LOG_LEVEL"
+INFER_INSTALL_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../")) + "/"
+
+
+class DataFormatter(logging.Formatter):
+    """Log formatter"""
+
+    def __init__(self, sub_module, fmt=None, **kwargs):
+        """
+        Initialization of DataFormatter.
+        :param sub_module: The submodule name.
+        :param fmt: Specified format pattern. Default: None.
+        :param kwargs: Additional keyword arguments passed to logging.Formatter.
+        """
+        super(DataFormatter, self).__init__(fmt=fmt, **kwargs)
+        self.sub_module = sub_module.upper()
+
+    def formatTime(self, record, datefmt=None):
+        """
+        Override formatTime for the uniform format %Y-%m-%d-%H:%M:%S.SSS.SSS
+        :param record: Log record
+        :param datefmt: Date format
+        :return: formatted timestamp
+        """
+        create_time = self.converter(record.created)
+        if datefmt:
+            return time.strftime(datefmt, create_time)
+
+        timestamp = time.strftime("%Y-%m-%d-%H:%M:%S", create_time)
+        record_msecs = str(round(record.msecs * 1000)).zfill(6)  # zero-pad so the split below stays aligned
+        # format the timestamp
+        return f"{timestamp}.{record_msecs[:3]}.{record_msecs[3:]}"
+
+    def format(self, record):
+        """
+        Apply the log format with the specified pattern.
+        :param record: Log record.
+        :return: formatted log content according to the format pattern.
+        """
+        if record.pathname.startswith(INFER_INSTALL_PATH):
+            # get the relative path
+            record.filepath = record.pathname[len(INFER_INSTALL_PATH) :]
+        elif "/" in record.pathname:
+            record.filepath = record.pathname.strip().split("/")[-1]
+        else:
+            record.filepath = record.pathname
+        record.sub_module = self.sub_module
+        return super().format(record)
+
+
+class RotatingLogFileHandler(RotatingFileHandler):
+    def _open(self):
+        return os.fdopen(os.open(self.baseFilename, os.O_RDWR | os.O_CREAT, 0o600), "a")
+
+
+def _filter_env_level():
+    log_env_level = os.getenv(LOG_ENV, "1")
+    if (
+        not isinstance(log_env_level, str)
+        or not log_env_level.isdigit()
+        or int(log_env_level) < 0
+        or int(log_env_level) > 3
+    ):
+        log_env_level = "1"
+    return log_env_level
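+
+# Note on MINDOCR_LOG_LEVEL (illustrative, not part of the original change):
+# _filter_env_level() accepts "0"-"3" ("0" DEBUG ... "3" ERROR) and falls back
+# to "1" (INFO) for anything else. The mapping is only consulted when a LOGGER
+# is constructed with log_level=None, e.g.
+#
+#     logger = LOGGER("mindocr", log_level=None)  # level taken from MINDOCR_LOG_LEVEL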
+ """ + formatter = ( + "[%(levelname)s] %(sub_module)s(%(process)d:" + "%(thread)d,%(processName)s):%(asctime)s " + "[%(filepath)s:%(lineno)d] %(message)s" + ) + return formatter + + def info(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.INFO) and os.getenv("RANK_ID", "0") == "0": + self._log(logging.INFO, msg, args, **kwargs) + + def debug(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.DEBUG) and os.getenv("RANK_ID", "0") == "0": + self._log(logging.DEBUG, msg, args, **kwargs) + + def warning(self, msg, *args, **kwargs): + if self.isEnabledFor(logging.WARNING) and os.getenv("RANK_ID", "0") == "0": + self._log(logging.WARNING, msg, args, **kwargs) + + def error(self, msg, *args, **kwargs): + rank_id = os.getenv("RANK_ID", None) + if rank_id and rank_id.isdigit() and 0 <= int(rank_id) < 8: + msg = f"[The error from this card id ({rank_id})] " + msg + if self.isEnabledFor(logging.ERROR): + self._log(logging.ERROR, msg, args, **kwargs) + + def setup_logging_file(self, log_dir, max_size=100 * 1024 * 1024, backup_cnt=10): + """Setup logging file.""" + if max_size > 1024 * 1024 * 1024 or max_size < 0: + logging.error("single log file size should more than 0, less than or equal to 1G.") + raise Exception("single log file size should more than 0, less than or equal to 1G.") + if backup_cnt > 100 or backup_cnt < 0: + logging.error("log file backup count should more than 0, less than or equal to 100") + raise Exception("log file backup count should more than 0, less than or equal to 100") + log_dir = os.path.realpath(log_dir) + if not os.path.exists(log_dir): + os.makedirs(log_dir, mode=0o750) + log_file_name = f"{self.model_name}.log" + log_fn = os.path.join(log_dir, log_file_name) + fh = RotatingLogFileHandler(log_fn, "a", max_size, backup_cnt) + fh.setFormatter(self.data_formatter) + fh.setLevel(logging.INFO) + self.addHandler(fh) + + def filter_log_str(self, msg) -> str: + def _check_str(need_check_str): + if len(need_check_str) > 10000: + self.warning("Input should be <= 10000") + return False + filter_strs = ["\r", "\n", "\\r", "\\n"] + for filter_str in filter_strs: + if filter_str in need_check_str: + self.warning("Input should not be included \\r or \\n") + return False + return True + + if isinstance(msg, str) and not _check_str(msg): + return "" + else: + return msg + + def save_args(self, args): + """ + :param args: input args param, just support argparse or dict + :return: None + """ + if isinstance(args, argparse.Namespace): + args = vars(args) + elif isinstance(args, dict): + pass + else: + logging.error("This api just support argparse or dict, please check your input type.") + raise Exception("This api just support argparse or dict, please check your input type.") + self.info("Args:") + args_copy = args.copy() + for key, value in args_copy.items(): + self.info("--> %s: %s", key, self.filter_log_str(args_copy[key])) + self.info("Finish read param") + + +class SingletonType(type): + _instance_lock = threading.Lock() + + def __call__(cls, *args, **kwargs): + if not hasattr(cls, "_instance"): + with SingletonType._instance_lock: + if not hasattr(cls, "_instance"): + cls._instance = super(SingletonType, cls).__call__(*args, **kwargs) + return cls._instance + + +class LoggerSystem(metaclass=SingletonType): + def __init__(self, model_name=LOG_TYPE, max_size=MAX_BYTES, backup_cnt=BACKUP_COUNT): + self.model_name = model_name + self.max_bytes = max_size + self.backup_count = backup_cnt + self.logger = None + + def init_logger(self, show_info_log=False, 
+
+    def init_logger(self, show_info_log=False, save_path=None):
+        self.logger = LOGGER(self.model_name, logging.INFO if show_info_log else logging.WARNING)
+        if save_path:
+            self.logger.setup_logging_file(save_path, self.max_bytes, self.backup_count)
+
+    def __getattr__(self, item):
+        return object.__getattribute__(self.logger, item)
+
+
+logger_instance = LoggerSystem(LOG_TYPE)
diff --git a/pipeline/utils/safe_utils.py b/pipeline/utils/safe_utils.py
new file mode 100644
index 000000000..a29760dec
--- /dev/null
+++ b/pipeline/utils/safe_utils.py
@@ -0,0 +1,151 @@
+import contextlib
+import json
+import os
+import re
+import shutil
+import stat
+
+from .logger import logger_instance as log
+
+
+def safe_list_writer(save_dict, save_path):
+    """
+    Append the inference results to a file, one line per image.
+    :param save_dict: mapping from image filename to its inference result
+    :param save_path: path of the result file
+    :return: None
+    """
+    flags, modes = os.O_WRONLY | os.O_CREAT | os.O_APPEND, stat.S_IWUSR | stat.S_IRUSR | stat.S_IRGRP
+    with os.fdopen(os.open(save_path, flags, modes), "w") as f:
+        if not save_dict:
+            f.write("")
+        for filename, res in save_dict.items():
+            content = os.path.basename(filename) + "\t" + json.dumps(res, ensure_ascii=False) + "\n"
+            f.write(content)
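+
+
+# Illustrative sketch (not part of the original change): safe_list_writer
+# produces one tab-separated line per image, e.g.
+#
+#     safe_list_writer(
+#         {"dir/img_1.jpg": [{"transcription": "hello",
+#                             "points": [[1, 2], [3, 2], [3, 4], [1, 4]]}]},
+#         "pipeline_results.txt",
+#     )
+#
+# appends the line:
+#     img_1.jpg\t[{"transcription": "hello", "points": [[1, 2], [3, 2], [3, 4], [1, 4]]}]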
+
+
+def safe_div(dividend, divisor):
+    try:
+        quotient = dividend / divisor
+    except ZeroDivisionError as error:
+        log.error(error)
+        quotient = 0
+    return quotient
+
+
+def verify_file_size(file_path) -> bool:
+    conf_file_size = os.path.getsize(file_path)
+    if conf_file_size > 0 and conf_file_size / 1024 / 1024 < 10:
+        return True
+    return False
+
+
+def valid_characters(pattern: str, characters: str) -> bool:
+    if re.match(r".*[\s]+", characters):
+        return False
+    if not re.match(pattern, characters):
+        return False
+    return True
+
+
+def file_base_check(file_path: str) -> None:
+    base_name = os.path.basename(file_path)
+    if not file_path or not os.path.isfile(file_path):
+        raise FileNotFoundError(f"the file:{base_name} does not exist!")
+    if not valid_characters("^[A-Za-z0-9_+-/]+$", file_path):
+        raise Exception(f"file path:{os.path.relpath(file_path)} should only include characters 'A-Za-z0-9+-_/'!")
+    if not verify_file_size(file_path):
+        raise Exception(f"{base_name}: the file size must be between [1, 10M]!")
+    if os.path.islink(file_path):
+        raise Exception(f"the file:{base_name} is a soft link, which is invalid!")
+    if not os.access(file_path, mode=os.R_OK):
+        raise FileNotFoundError(f"the file:{base_name} is unreadable!")
+
+
+def get_safe_name(path):
+    """Remove trailing path separators before retrieving the basename.
+
+    e.g. /xxx/ -> /xxx
+    """
+    return os.path.basename(os.path.abspath(path))
+
+
+def custom_islink(path):
+    """Remove trailing path separators before checking for soft links.
+
+    e.g. /xxx/ -> /xxx
+    """
+    return os.path.islink(os.path.abspath(path))
+
+
+def check_valid_dir(path):
+    name = get_safe_name(path)
+    check_valid_path(path, name)
+    if not os.path.isdir(path):
+        log.error(f"Please check if {name} is a directory.")
+        raise NotADirectoryError("Check dir failed.")
+
+
+def check_valid_path(path, name):
+    if not path or not os.path.exists(path):
+        raise FileExistsError(f"Error! {name} must exist!")
+    if custom_islink(path):
+        raise ValueError(f"Error! {name} cannot be a soft link!")
+    if not os.access(path, mode=os.R_OK):
+        raise RuntimeError(f"Error! Please check if {name} is readable.")
+
+
+def check_valid_file(path, num_gb_limit=10):
+    filename = get_safe_name(path)
+    check_valid_path(path, filename)
+    if not os.path.isfile(path):
+        log.error(f"Please check if {filename} is a file.")
+        raise ValueError("Check file failed.")
+    check_size(path, filename, num_gb_limit=num_gb_limit)
+
+
+def check_size(path, name, num_gb_limit):
+    limit = num_gb_limit * 1024 * 1024 * 1024
+    size = os.path.getsize(path)
+    if size == 0:
+        raise ValueError(f"{name} cannot be an empty file!")
+    if size >= limit:
+        raise ValueError(f"The size of {name} must be smaller than {num_gb_limit} GB!")
+
+
+def save_path_init(path, exist_ok=False):
+    if os.path.exists(path):
+        if exist_ok:
+            return
+        shutil.rmtree(path)
+    os.makedirs(path, 0o750)
+
+
+@contextlib.contextmanager
+def suppress_stdout():
+    """
+    A context manager for doing a "deep suppression" of stdout.
+    """
+    null_fds = os.open(os.devnull, os.O_RDWR)
+    save_fds = os.dup(1)
+    os.dup2(null_fds, 1)
+
+    yield
+
+    os.dup2(save_fds, 1)
+    os.close(null_fds)
+    os.close(save_fds)  # also release the duplicated fd to avoid a leak
+
+
+@contextlib.contextmanager
+def suppress_stderr():
+    """
+    A context manager for doing a "deep suppression" of stderr.
+    """
+    null_fds = os.open(os.devnull, os.O_RDWR)
+    save_fds = os.dup(2)
+    os.dup2(null_fds, 2)
+
+    yield
+
+    os.dup2(save_fds, 2)
+    os.close(null_fds)
+    os.close(save_fds)  # also release the duplicated fd to avoid a leak
diff --git a/pipeline/utils/visual_c_results.py b/pipeline/utils/visual_c_results.py
new file mode 100644
index 000000000..2b9ac0f3e
--- /dev/null
+++ b/pipeline/utils/visual_c_results.py
@@ -0,0 +1,49 @@
+import argparse
+import json
+import os
+
+import cv2
+import numpy as np
+from tqdm import tqdm
+from visual_utils import vis_bbox_text
+
+
+def img_write(path: str, img: np.ndarray):
+    filename = os.path.abspath(path)
+    cv2.imencode(os.path.splitext(filename)[1], img)[1].tofile(filename)
+
+
+def vis_results(prediction_result, vis_pipeline_save_dir, img_folder):
+    img_files = os.listdir(img_folder)
+    img_dict = {}
+    font_path = os.path.abspath("../../docs/fonts/simfang.ttf")
+    for img_name in img_files:
+        img = cv2.imread(os.path.join(img_folder, img_name))  # BGR format
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img_dict[img_name] = img
+
+    for each_pred in tqdm(prediction_result):
+        file_name, prediction = each_pred.split("\t")
+        basename = os.path.basename(file_name)
+        save_file = os.path.join(vis_pipeline_save_dir, os.path.splitext(basename)[0])
+        prediction = json.loads(prediction)  # the result file is written with json.dumps, so eval() is unnecessary
+        box_list = [np.array(x["points"]).reshape(-1, 2) for x in prediction]
+        text_list = [x["transcription"] for x in prediction]
+        box_text = vis_bbox_text(img_dict[file_name], box_list, text_list, font_path=font_path)
+        img_write(save_file + ".jpg", box_text)
+
+
+def read_prediction(prediction_folder):
+    with open(prediction_folder, "r", encoding="utf-8") as f:
+        prediction = f.readlines()
+    return prediction
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--img_folder", required=True, type=str)
+    parser.add_argument("--pred_dir", required=True, type=str)
+    parser.add_argument("--vis_dir", required=True, type=str)
+    args = parser.parse_args()
+
+    prediction = read_prediction(args.pred_dir)
+    vis_results(prediction, args.vis_dir, args.img_folder)
diff --git a/pipeline/utils/visual_utils.py b/pipeline/utils/visual_utils.py
new file mode 100644
index 000000000..bccc59f4b
--- /dev/null
+++ b/pipeline/utils/visual_utils.py
@@ -0,0 +1,129 @@
+"""
+OCR visualization methods
+"""
+import math
+import os.path
+import random
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+__all__ = ["vis_bbox", "vis_bbox_text", "vis_crop"]
+
+
+def vis_bbox(image, box_list, color, thickness):
+    """
+    Draw bounding boxes on an image.
+    :param image: input image
+    :param box_list: boxes to draw on the image
+    :param color: color of the boxes
+    :param thickness: line thickness
+    :return: image with boxes
+    """
+
+    image = image.copy()
+    for box in box_list:
+        box = box.astype(int)
+        cv2.polylines(image, [box], True, color, thickness)
+    return image
+
+
+def vis_bbox_text(image, box_list, text_list, font_path):
+    """
+    Draw bounding boxes and recognized text on an image.
+    :param image: input image
+    :param box_list: boxes to draw on the image
+    :param text_list: texts to draw on the image
+    :param font_path: path to the font file
+    :return: image with boxes and texts
+    """
+    if font_path is None:
+        # two levels up from pipeline/utils/ is the repository root
+        _font_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../docs/fonts/simfang.ttf"))
+        if os.path.isfile(_font_path):
+            font_path = _font_path
+
+    image = Image.fromarray(image)
+    h, w = image.height, image.width
+    img_left = image.copy()
+    img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
+    random.seed(0)
+
+    draw_left = ImageDraw.Draw(img_left)
+    if text_list is None or len(text_list) != len(box_list):
+        text_list = [None] * len(box_list)
+    for idx, (box, txt) in enumerate(zip(box_list, text_list)):
+        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
+        draw_left.polygon(box.astype(np.float32), fill=color)
+        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
+        pts = np.array(box, np.int32).reshape((-1, 1, 2))
+        cv2.polylines(img_right_text, [pts], True, color, 1)
+        img_right = cv2.bitwise_and(img_right, img_right_text)
+    img_left = Image.blend(image, img_left, 0.5)
+    img_show = Image.new(mode="RGB", size=(w * 2, h), color=(255, 255, 255))  # RGB or BGR doesn't matter
+    img_show.paste(img_left, (0, 0, w, h))
+    img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
+    return np.array(img_show)  # RGB or BGR is the same as the input image
+
+
+def draw_box_txt_fine(img_size, box, txt, font_path):
+    box_height = int(math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2))
+    box_width = int(math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2))
+    if box_height > 2 * box_width and box_height > 30:
+        # tall, narrow boxes hold vertical text: render horizontally, then rotate
+        img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))  # RGB or BGR doesn't matter
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_height, box_width), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+        img_text = img_text.transpose(Image.ROTATE_270)
+    else:
+        img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))  # RGB or BGR doesn't matter
+        draw_text = ImageDraw.Draw(img_text)
+        if txt:
+            font = create_font(txt, (box_width, box_height), font_path)
+            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
+    pts1 = np.float32([[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]])
+    pts2 = np.array(box, dtype=np.float32)
+    M = cv2.getPerspectiveTransform(pts1, pts2)
+
+    img_text = np.array(img_text, dtype=np.uint8)
+    img_right_text = cv2.warpPerspective(
+        img_text, M, img_size, flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=(255, 255, 255)
+    )
+    return img_right_text  # RGB or BGR is the same as the input image
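+
+
+# Illustrative usage sketch (not part of the original change); assumes an RGB
+# uint8 image and 4x2 quadrilateral boxes:
+#
+#     img = np.asarray(Image.open("example.jpg").convert("RGB"))
+#     boxes = [np.array([[10, 10], [200, 10], [200, 50], [10, 50]])]
+#     canvas = vis_bbox_text(img, boxes, ["hello"], font_path=None)
+#     crops = vis_crop(img, boxes)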
+
+
+def create_font(txt, sz, font_path):
+    font_size = int(sz[1] * 0.99)
+    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    # font.getsize() was removed in Pillow 10; getlength() is available since Pillow 8
+    length = font.getlength(txt)
+    if length > sz[0]:
+        font_size = int(font_size * sz[0] / length)
+        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
+    return font
+
+
+def vis_crop(image, box_list):
+    """
+    Generate cropped images.
+    :param image: input image
+    :param box_list: list of boxes
+    :return: list of cropped images
+    """
+    image_crop = []
+    for box in box_list:
+        if box.shape != (4, 2):
+            raise ValueError("shape of the crop box must be 4*2")
+        box = box.astype(np.float32)
+        img_crop_width = int(max(np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[2] - box[3])))
+        img_crop_height = int(max(np.linalg.norm(box[0] - box[3]), np.linalg.norm(box[1] - box[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0], [img_crop_width, img_crop_height], [0, img_crop_height]])
+        m = cv2.getPerspectiveTransform(box, pts_std)
+        dst_img = cv2.warpPerspective(
+            image, m, (img_crop_width, img_crop_height), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_CUBIC
+        )
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_width != 0 and dst_img_height / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        image_crop.append(dst_img)
+    return image_crop
diff --git a/tools/data_for_export_convert.py b/tools/data_for_export_convert.py
index bee8f7798..3019463bb 100644
--- a/tools/data_for_export_convert.py
+++ b/tools/data_for_export_convert.py
@@ -159,6 +159,12 @@
         "data_shape": "args0:[1,3,48,160];args1:[1,1,40];args2:[1,40]",
         "infer_shape_list": ["1,3,48,160:1,1,40:1,40"],
     },
+    "layout_yolov8n": {
+        "mindir_url": "https://download.mindspore.cn/toolkits/mindocr/yolov8/yolov8n-2a1f68ab.mindir",
+        "mindir_name": "yolov8n-2a1f68ab.mindir",
+        "data_shape": "args0:[1,3,800,800]",
+        "infer_shape_list": ["1,3,800,800"],
+    },
 }
@@ -279,6 +285,11 @@
         "mindir_name": "robustscanner_resnet31.mindir",
         "infer_shape_list": ["1,3,48,160:1,1,40:1,40"],
     },
+    "layout_yolov8n": {
+        "data_shape": "args0:[1,3,800,800]",
+        "mindir_name": "yolov8n-2a1f68ab.mindir",
+        "infer_shape_list": ["1,3,800,800"],
+    },
 }
@@ -403,6 +414,11 @@
         "mindir_name": "robustscanner_resnet31.mindir",
         "infer_shape_list": ["1,3,48,160"],
    },
+    "layout_yolov8n": {
+        "data_shape": "args0:[-1,3,-1,-1]",
+        "mindir_name": "yolov8n.mindir",
+        "infer_shape_list": ["1,3,800,800"],
+    },
 }
@@ -431,6 +447,7 @@
     "svtr_tiny_ch": {"model_name": "svtr_tiny_ch", "data_shape_h_w": [32, 320]},
     "visionlan_resnet45": {"model_name": "visionlan_resnet45", "data_shape_h_w": [64, 256]},
     "robustscanner_resnet31": {"model_name": "robustscanner_resnet31", "data_shape_h_w": [48, 160]},
+    "layout_yolov8n": {"model_name": "layout_yolov8n", "data_shape_h_w": [800, 800]},
 }
@@ -459,4 +476,5 @@
     "svtr_tiny_ch": {"model_name": "svtr_tiny_ch", "model_type": "rec"},
     "visionlan_resnet45": {"model_name": "visionlan_resnet45", "model_type": "rec"},
     "robustscanner_resnet31": {"model_name": "robustscanner_resnet31", "model_type": "rec"},
+    "layout_yolov8n": {"model_name": "layout_yolov8n", "model_type": "layout"},
 }
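+# Illustrative sketch (not part of the original change): parsing the
+# "data_shape" strings above into concrete shapes for a smoke test;
+# parse_data_shape is a hypothetical helper, not an existing API.
+#
+#     def parse_data_shape(data_shape: str) -> dict:
+#         """'args0:[1,3,800,800]' -> {'args0': (1, 3, 800, 800)}"""
+#         shapes = {}
+#         for part in data_shape.split(";"):
+#             name, dims = part.split(":", 1)
+#             shapes[name] = tuple(int(d) for d in dims.strip("[]").split(","))
+#         return shapes
+#
+#     parse_data_shape("args0:[1,3,48,160];args1:[1,1,40];args2:[1,40]")
+#     # -> {'args0': (1, 3, 48, 160), 'args1': (1, 1, 40), 'args2': (1, 40)}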