Python_infer (Python Inference) Module

deploy.python_infer

Predictor

Initializes the inference engine with the given parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pdmodel_path` | `Optional[str]` | Path to the PaddlePaddle model file. Defaults to `None`. | `None` |
| `pdiparams_path` | `Optional[str]` | Path to the PaddlePaddle model parameters file. Defaults to `None`. | `None` |
| `device` | `Literal['gpu', 'cpu', 'npu', 'xpu']` | Device to use for inference. Defaults to `"cpu"`. | `'cpu'` |
| `engine` | `Literal['native', 'tensorrt', 'onnx', 'mkldnn']` | Inference engine to use. Defaults to `"native"`. | `'native'` |
| `precision` | `Literal['fp32', 'fp16', 'int8']` | Precision to use for inference. Defaults to `"fp32"`. | `'fp32'` |
| `onnx_path` | `Optional[str]` | Path to the ONNX model file. Defaults to `None`. | `None` |
| `ir_optim` | `bool` | Whether to use IR optimization. Defaults to `True`. | `True` |
| `min_subgraph_size` | `int` | Minimum subgraph size for IR optimization. Defaults to 15. | `15` |
| `gpu_mem` | `int` | Initial size of the GPU memory pool in MB. Defaults to 500. | `500` |
| `gpu_id` | `int` | GPU ID to use. Defaults to 0. | `0` |
| `max_batch_size` | `int` | Maximum batch size for inference. Defaults to 10. | `10` |
| `num_cpu_threads` | `int` | Number of CPU threads to use. Defaults to 10. | `10` |
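A minimal, hypothetical construction sketch: the model/parameter paths below are placeholders for files exported with `paddle.jit.save` (see the `PINNPredictor` example further down), and the keyword arguments follow the signature documented above.

```python
from deploy.python_infer import base

# Sketch (placeholder paths): build a base Predictor that runs an exported
# PaddlePaddle model natively on CPU with fp32 precision.
predictor = base.Predictor(
    pdmodel_path="./inference.pdmodel",
    pdiparams_path="./inference.pdiparams",
    device="cpu",
    engine="native",
    precision="fp32",
    num_cpu_threads=4,
)
# Predictor.predict is abstract; subclasses such as PINNPredictor implement it.
```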
Source code in deploy/python_infer/base.py
class Predictor:
    """
    Initializes the inference engine with the given parameters.

    Args:
        pdmodel_path (Optional[str]): Path to the PaddlePaddle model file. Defaults to None.
        pdiparams_path (Optional[str]): Path to the PaddlePaddle model parameters file. Defaults to None.
        device (Literal["gpu", "cpu", "npu", "xpu"], optional): Device to use for inference. Defaults to "cpu".
        engine (Literal["native", "tensorrt", "onnx", "mkldnn"], optional): Inference engine to use. Defaults to "native".
        precision (Literal["fp32", "fp16", "int8"], optional): Precision to use for inference. Defaults to "fp32".
        onnx_path (Optional[str], optional): Path to the ONNX model file. Defaults to None.
        ir_optim (bool, optional): Whether to use IR optimization. Defaults to True.
        min_subgraph_size (int, optional): Minimum subgraph size for IR optimization. Defaults to 15.
        gpu_mem (int, optional): Initial size of the GPU memory pool in MB. Defaults to 500.
        gpu_id (int, optional): GPU ID to use. Defaults to 0.
        max_batch_size (int, optional): Maximum batch size for inference. Defaults to 10.
        num_cpu_threads (int, optional): Number of CPU threads to use. Defaults to 10.
    """

    def __init__(
        self,
        pdmodel_path: Optional[str] = None,
        pdiparams_path: Optional[str] = None,
        *,
        device: Literal["gpu", "cpu", "npu", "xpu"] = "cpu",
        engine: Literal["native", "tensorrt", "onnx", "mkldnn"] = "native",
        precision: Literal["fp32", "fp16", "int8"] = "fp32",
        onnx_path: Optional[str] = None,
        ir_optim: bool = True,
        min_subgraph_size: int = 15,
        gpu_mem: int = 500,
        gpu_id: int = 0,
        max_batch_size: int = 10,
        num_cpu_threads: int = 10,
    ):
        self.pdmodel_path = pdmodel_path
        self.pdiparams_path = pdiparams_path

        self._check_device(device)
        self.device = device
        self._check_engine(engine)
        self.engine = engine
        self._check_precision(precision)
        self.precision = precision
        self._compatibility_check()

        self.onnx_path = onnx_path
        self.ir_optim = ir_optim
        self.min_subgraph_size = min_subgraph_size
        self.gpu_mem = gpu_mem
        self.gpu_id = gpu_id
        self.max_batch_size = max_batch_size
        self.num_cpu_threads = num_cpu_threads

        if self.engine == "onnx":
            self.predictor, self.config = self._create_onnx_predictor()
        else:
            self.predictor, self.config = self._create_paddle_predictor()

        logger.message(
            f"Inference with engine: {self.engine}, precision: {self.precision}, "
            f"device: {self.device}."
        )

    def predict(self, input_dict):
        raise NotImplementedError

    def _create_paddle_predictor(
        self,
    ) -> Tuple[paddle_inference.Predictor, paddle_inference.Config]:
        if not osp.exists(self.pdmodel_path):
            raise FileNotFoundError(
                f"Given 'pdmodel_path': {self.pdmodel_path} does not exist. "
                "Please check if it is correct."
            )
        if not osp.exists(self.pdiparams_path):
            raise FileNotFoundError(
                f"Given 'pdiparams_path': {self.pdiparams_path} does not exist. "
                "Please check if it is correct."
            )

        config = paddle_inference.Config(self.pdmodel_path, self.pdiparams_path)
        if self.device == "gpu":
            config.enable_use_gpu(self.gpu_mem, self.gpu_id)
            if self.engine == "tensorrt":
                if self.precision == "fp16":
                    precision = paddle_inference.Config.Precision.Half
                elif self.precision == "int8":
                    precision = paddle_inference.Config.Precision.Int8
                else:
                    precision = paddle_inference.Config.Precision.Float32
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=self.max_batch_size,
                    min_subgraph_size=self.min_subgraph_size,
                    use_calib_mode=False,
                )
                # collect shape
                pdmodel_dir = osp.dirname(self.pdmodel_path)
                trt_shape_path = osp.join(pdmodel_dir, "trt_dynamic_shape.txt")

                if not osp.exists(trt_shape_path):
                    config.collect_shape_range_info(trt_shape_path)
                    logger.message(
                        f"Save collected dynamic shape info to: {trt_shape_path}"
                    )
                try:
                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_path, True)
                except Exception as e:
                    logger.warning(e)
                    logger.warning(
                        "TRT dynamic shape is disabled for your paddlepaddle < 2.3.0"
                    )

        elif self.device == "npu":
            config.enable_custom_device("npu")
        elif self.device == "xpu":
            config.enable_xpu(10 * 1024 * 1024)
        else:
            config.disable_gpu()
            if self.engine == "mkldnn":
                # 'set_mkldnn_cache_capatity' is not available on macOS
                if platform.system() != "Darwin":
                    ...
                    # cache 10 different shapes for mkldnn to avoid memory leak
                    # config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()

                if self.precision == "fp16":
                    config.enable_mkldnn_bfloat16()

                config.set_cpu_math_library_num_threads(self.num_cpu_threads)

        # enable memory optim
        config.enable_memory_optim()
        # config.disable_glog_info()
        # enable zero copy
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(self.ir_optim)

        predictor = paddle_inference.create_predictor(config)
        return predictor, config

    def _create_onnx_predictor(
        self,
    ) -> Tuple["onnxruntime.InferenceSession", "onnxruntime.SessionOptions"]:
        if not osp.exists(self.onnx_path):
            raise FileNotFoundError(
                f"Given 'onnx_path' {self.onnx_path} does not exist. "
                "Please check if it is correct."
            )

        try:
            import onnxruntime as ort
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "Please install onnxruntime with `pip install onnxruntime`."
            )

        # set config for onnx predictor
        config = ort.SessionOptions()
        config.intra_op_num_threads = self.num_cpu_threads
        if self.ir_optim:
            config.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # instantiate onnx predictor
        providers = (
            ["CUDAExecutionProvider", "CPUExecutionProvider"]
            if self.device != "cpu"
            else ["CPUExecutionProvider"]
        )
        predictor = ort.InferenceSession(
            self.onnx_path, sess_options=config, providers=providers
        )
        return predictor, config

    def _check_device(self, device: str):
        if device not in ["gpu", "cpu", "npu", "xpu"]:
            raise ValueError(
                "Inference only supports 'gpu', 'cpu', 'npu' and 'xpu' devices, "
                f"but got {device}."
            )

    def _check_engine(self, engine: str):
        if engine not in ["native", "tensorrt", "onnx", "mkldnn"]:
            raise ValueError(
                "Inference only supports 'native', 'tensorrt', 'onnx' and 'mkldnn' "
                f"engines, but got {engine}."
            )

    def _check_precision(self, precision: str):
        if precision not in ["fp32", "fp16", "int8"]:
            raise ValueError(
                "Inference only supports 'fp32', 'fp16' and 'int8' "
                f"precision, but got {precision}."
            )

    def _compatibility_check(self):
        if self.engine == "onnx":
            if not (
                importlib.util.find_spec("onnxruntime")
                or importlib.util.find_spec("onnxruntime-gpu")
            ):
                raise ModuleNotFoundError(
                    "\nPlease install onnxruntime first when engine is 'onnx'\n"
                    "* For CPU inference, use `pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple`\n"
                    "* For GPU inference, use `pip install onnxruntime-gpu -i https://pypi.tuna.tsinghua.edu.cn/simple`"
                )
            import onnxruntime as ort

            if self.device == "gpu" and ort.get_device() != "GPU":
                raise RuntimeError(
                    "Please install onnxruntime-gpu with `pip install onnxruntime-gpu`"
                    " when device is set to 'gpu'\n"
                )
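Selecting `engine="onnx"` switches the backend to ONNX Runtime via `_create_onnx_predictor`. A hypothetical sketch, assuming `onnxruntime` is installed and an ONNX model has already been exported (for example with paddle2onnx) to the placeholder path `./inference.onnx`:

```python
from deploy.python_infer import base

# Sketch (placeholder path): use the ONNX Runtime backend; CUDAExecutionProvider
# is requested automatically when device != "cpu".
predictor = base.Predictor(
    onnx_path="./inference.onnx",
    device="cpu",
    engine="onnx",
    precision="fp32",
)
```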

GeneralPredictor

Bases: PINNPredictor

Use PINNPredictor as GeneralPredictor.

Source code in deploy/python_infer/__init__.py
class GeneralPredictor(PINNPredictor):
    """Use PINNPredictor as GeneralPredictor."""

    pass
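Since `GeneralPredictor` is a thin alias of `PINNPredictor`, it is constructed from the same running configuration. A short sketch (the `cfg` below mirrors the one in the `PINNPredictor` example that follows):

```python
from omegaconf import DictConfig
from deploy.python_infer import GeneralPredictor

# Sketch: GeneralPredictor accepts the same cfg.INFER fields as PINNPredictor.
cfg = DictConfig(
    {
        "log_freq": 10,
        "INFER": {
            "pdmodel_path": "./inference.pdmodel",
            "pdiparams_path": "./inference.pdiparams",
            "device": "cpu",
            "engine": "native",
            "precision": "fp32",
            "onnx_path": None,
            "ir_optim": True,
            "min_subgraph_size": 15,
            "gpu_mem": 500,
            "gpu_id": 0,
            "max_batch_size": 10,
            "num_cpu_threads": 10,
        },
    }
)
predictor = GeneralPredictor(cfg)
```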

PINNPredictor

Bases: Predictor

General predictor for PINN-based models.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `DictConfig` | Running configuration. | *required* |

Examples:

>>> import numpy as np
>>> import paddle
>>> from omegaconf import DictConfig
>>> from paddle.static import InputSpec
>>> import ppsci
>>> from deploy.python_infer import pinn_predictor
>>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16)
>>> static_model = paddle.jit.to_static(
...     model,
...     input_spec=[
...         {
...             key: InputSpec([None, 1], "float32", name=key)
...             for key in model.input_keys
...         },
...     ],
... )
>>> paddle.jit.save(static_model, "./inference")
>>> cfg = DictConfig(
...     {
...         "log_freq": 10,
...         "INFER": {
...             "pdmodel_path": "./inference.pdmodel",
...             "pdiparams_path": "./inference.pdiparams",
...             "device": "cpu",
...             "engine": "native",
...             "precision": "fp32",
...             "onnx_path": None,
...             "ir_optim": True,
...             "min_subgraph_size": 15,
...             "gpu_mem": 500,
...             "gpu_id": 0,
...             "max_batch_size": 10,
...             "num_cpu_threads": 10,
...         }
...     }
... )
>>> predictor = pinn_predictor.PINNPredictor(cfg)
>>> pred = predictor.predict(
...     {
...         "x": np.random.randn(4, 1).astype("float32"),
...         "y": np.random.randn(4, 1).astype("float32"),
...     },
...     batch_size=2,
... )
>>> for k, v in pred.items():
...     print(k, v.shape)
save_infer_model/scale_0.tmp_0 (4, 1)
save_infer_model/scale_1.tmp_0 (4, 1)
save_infer_model/scale_2.tmp_0 (4, 1)
Source code in deploy/python_infer/pinn_predictor.py
class PINNPredictor(base.Predictor):
    """General predictor for PINN-based models.

    Args:
        cfg (DictConfig): Running configuration.

    Examples:
        >>> import numpy as np
        >>> import paddle
        >>> from omegaconf import DictConfig
        >>> from paddle.static import InputSpec
        >>> import ppsci
        >>> from deploy.python_infer import pinn_predictor
        >>> model = ppsci.arch.MLP(("x", "y"), ("u", "v", "p"), 3, 16)
        >>> static_model = paddle.jit.to_static(
        ...     model,
        ...     input_spec=[
        ...         {
        ...             key: InputSpec([None, 1], "float32", name=key)
        ...             for key in model.input_keys
        ...         },
        ...     ],
        ... )
        >>> paddle.jit.save(static_model, "./inference")
        >>> cfg = DictConfig(
        ...     {
        ...         "log_freq": 10,
        ...         "INFER": {
        ...             "pdmodel_path": "./inference.pdmodel",
        ...             "pdiparams_path": "./inference.pdiparams",
        ...             "device": "cpu",
        ...             "engine": "native",
        ...             "precision": "fp32",
        ...             "onnx_path": None,
        ...             "ir_optim": True,
        ...             "min_subgraph_size": 15,
        ...             "gpu_mem": 500,
        ...             "gpu_id": 0,
        ...             "max_batch_size": 10,
        ...             "num_cpu_threads": 10,
        ...         }
        ...     }
        ... )
        >>> predictor = pinn_predictor.PINNPredictor(cfg) # doctest: +SKIP
        >>> pred = predictor.predict(
        ...     {
        ...         "x": np.random.randn(4, 1).astype("float32"),
        ...         "y": np.random.randn(4, 1).astype("float32"),
        ...     },
        ...     batch_size=2,
        ... ) # doctest: +SKIP
        >>> for k, v in pred.items(): # doctest: +SKIP
        ...     print(k, v.shape) # doctest: +SKIP
        save_infer_model/scale_0.tmp_0 (4, 1)
        save_infer_model/scale_1.tmp_0 (4, 1)
        save_infer_model/scale_2.tmp_0 (4, 1)
    """

    def __init__(
        self,
        cfg: DictConfig,
    ):
        super().__init__(
            cfg.INFER.pdmodel_path,
            cfg.INFER.pdiparams_path,
            device=cfg.INFER.device,
            engine=cfg.INFER.engine,
            precision=cfg.INFER.precision,
            onnx_path=cfg.INFER.onnx_path,
            ir_optim=cfg.INFER.ir_optim,
            min_subgraph_size=cfg.INFER.min_subgraph_size,
            gpu_mem=cfg.INFER.gpu_mem,
            gpu_id=cfg.INFER.gpu_id,
            max_batch_size=cfg.INFER.max_batch_size,
            num_cpu_threads=cfg.INFER.num_cpu_threads,
        )
        self.log_freq = cfg.log_freq

    def predict(
        self,
        input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
        batch_size: int = 64,
    ) -> Dict[str, np.ndarray]:
        """
        Predicts the output of the model for the given input.

        Args:
            input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]):
                A dictionary containing the input data.
            batch_size (int, optional): The batch size to use for prediction.
                Defaults to 64.

        Returns:
            Dict[str, np.ndarray]: A dictionary containing the predicted output.
        """
        if batch_size > self.max_batch_size:
            logger.warning(
                f"batch_size({batch_size}) is larger than "
                f"max_batch_size({self.max_batch_size}), which may occur error."
            )

        if self.engine != "onnx":
            # prepare input handle(s)
            input_handles = {
                name: self.predictor.get_input_handle(name) for name in input_dict
            }
            # prepare output handle(s)
            output_handles = {
                name: self.predictor.get_output_handle(name)
                for name in self.predictor.get_output_names()
            }
        else:
            # input_names = [node_arg.name for node_arg in self.predictor.get_inputs()]
            output_names: List[str] = [
                node_arg.name for node_arg in self.predictor.get_outputs()
            ]

        num_samples = len(next(iter(input_dict.values())))
        batch_num = (num_samples + (batch_size - 1)) // batch_size
        pred_dict = misc.Prettydefaultdict(list)

        # inference by batch
        for batch_id in range(1, batch_num + 1):
            if batch_id == 1 or batch_id % self.log_freq == 0 or batch_id == batch_num:
                logger.info(f"Predicting batch {batch_id}/{batch_num}")

            # prepare batch input dict
            st = (batch_id - 1) * batch_size
            ed = min(num_samples, batch_id * batch_size)
            batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict}

            # send batch input data to input handle(s)
            if self.engine != "onnx":
                for name, handle in input_handles.items():
                    handle.copy_from_cpu(batch_input_dict[name])

            # run predictor
            if self.engine != "onnx":
                self.predictor.run()
                # receive batch output data from output handle(s)
                batch_output_dict = {
                    name: output_handles[name].copy_to_cpu() for name in output_handles
                }
            else:
                batch_outputs = self.predictor.run(
                    output_names=output_names,
                    input_feed=batch_input_dict,
                )
                batch_output_dict = {
                    name: output for (name, output) in zip(output_names, batch_outputs)
                }

            # collect batch output data
            for key, batch_output in batch_output_dict.items():
                pred_dict[key].append(batch_output)

        # concatenate local predictions
        pred_dict = {key: np.concatenate(value) for key, value in pred_dict.items()}

        return pred_dict
predict(input_dict, batch_size=64)

Predicts the output of the model for the given input.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `input_dict` | `Dict[str, Union[ndarray, Tensor]]` | A dictionary containing the input data. | *required* |
| `batch_size` | `int` | The batch size to use for prediction. Defaults to 64. | `64` |

Returns:

| Type | Description |
| --- | --- |
| `Dict[str, ndarray]` | A dictionary containing the predicted output. |
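The keys of the returned dictionary are the fetch names of the exported graph (e.g. `save_infer_model/scale_0.tmp_0`), not the original model's output keys. A hedged sketch, assuming `predictor` and `model` come from the `PINNPredictor` example above and that the fetch order matches `model.output_keys`:

```python
import numpy as np

# Sketch: run prediction, then map fetch names back to ("u", "v", "p") by position.
pred = predictor.predict(
    {
        "x": np.random.randn(4, 1).astype("float32"),
        "y": np.random.randn(4, 1).astype("float32"),
    },
    batch_size=2,
)
pred = {key: value for key, value in zip(model.output_keys, pred.values())}
```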

Source code in deploy/python_infer/pinn_predictor.py
def predict(
    self,
    input_dict: Dict[str, Union[np.ndarray, paddle.Tensor]],
    batch_size: int = 64,
) -> Dict[str, np.ndarray]:
    """
    Predicts the output of the model for the given input.

    Args:
        input_dict (Dict[str, Union[np.ndarray, paddle.Tensor]]):
            A dictionary containing the input data.
        batch_size (int, optional): The batch size to use for prediction.
            Defaults to 64.

    Returns:
        Dict[str, np.ndarray]: A dictionary containing the predicted output.
    """
    if batch_size > self.max_batch_size:
        logger.warning(
            f"batch_size({batch_size}) is larger than "
            f"max_batch_size({self.max_batch_size}), which may occur error."
        )

    if self.engine != "onnx":
        # prepare input handle(s)
        input_handles = {
            name: self.predictor.get_input_handle(name) for name in input_dict
        }
        # prepare output handle(s)
        output_handles = {
            name: self.predictor.get_output_handle(name)
            for name in self.predictor.get_output_names()
        }
    else:
        # input_names = [node_arg.name for node_arg in self.predictor.get_inputs()]
        output_names: List[str] = [
            node_arg.name for node_arg in self.predictor.get_outputs()
        ]

    num_samples = len(next(iter(input_dict.values())))
    batch_num = (num_samples + (batch_size - 1)) // batch_size
    pred_dict = misc.Prettydefaultdict(list)

    # inference by batch
    for batch_id in range(1, batch_num + 1):
        if batch_id == 1 or batch_id % self.log_freq == 0 or batch_id == batch_num:
            logger.info(f"Predicting batch {batch_id}/{batch_num}")

        # prepare batch input dict
        st = (batch_id - 1) * batch_size
        ed = min(num_samples, batch_id * batch_size)
        batch_input_dict = {key: input_dict[key][st:ed] for key in input_dict}

        # send batch input data to input handle(s)
        if self.engine != "onnx":
            for name, handle in input_handles.items():
                handle.copy_from_cpu(batch_input_dict[name])

        # run predictor
        if self.engine != "onnx":
            self.predictor.run()
            # receive batch output data from output handle(s)
            batch_output_dict = {
                name: output_handles[name].copy_to_cpu() for name in output_handles
            }
        else:
            batch_outputs = self.predictor.run(
                output_names=output_names,
                input_feed=batch_input_dict,
            )
            batch_output_dict = {
                name: output for (name, output) in zip(output_names, batch_outputs)
            }

        # collect batch output data
        for key, batch_output in batch_output_dict.items():
            pred_dict[key].append(batch_output)

    # concatenate local predictions
    pred_dict = {key: np.concatenate(value) for key, value in pred_dict.items()}

    return pred_dict