查看模型能让你知道一些关键的信息、比如模型的构成、还有更重要的一点是拿到 input_name 来塑造我们的 shape_dict,就比如用到的 SR 模型,输入的 name 是‘1’,那 shape_dict 就应该是{‘1’: x.shape}
接下来看frontend.fromonnx,可以看到 relay 前端提供了很多框架模型的接口
1 2 3 4 5 6 7 8 9 10 11
from .mxnet import from_mxnet from .mxnet_qnn_op_utils import quantize_conv_bias_mkldnn_from_var from .keras import from_keras from .onnx import from_onnx from .tflite import from_tflite from .coreml import from_coreml from .caffe2 import from_caffe2 from .tensorflow import from_tensorflow from .darknet import from_darknet from .pytorch import from_pytorch from .caffe import from_caffe
这里插个题外话,科普一下 Python 小知识。在框架源代码中经常使用 from . import A 或者 from .A import B 的操作是什么意思? 首先,.的意思是当前目录,..的意思是上级目录。 当碰到 from . import A,python 回去找当前目录下的 __init__.py文件,从里面去找 A,如果是..就是上级文件夹。 如果当前目录下没有__init__.py,则需要 from .A import B,回到当前目录下的A.py里去寻找 B,如果是..就是上级文件夹。
deffrom_onnx(model, shape=None, dtype="float32", opset=None): try: import onnx ifhasattr(onnx.checker, 'check_model'): # try use onnx's own model checker before converting any model try: onnx.checker.check_model(model) except onnx.onnx_cpp2py_export.checker.ValidationError as e: import warnings # the checker is a bit violent about errors, so simply print warnings here warnings.warn(str(e)) except ImportError: pass g = GraphProto(shape, dtype) graph = model.graph if opset isNone: try: opset = model.opset_import[0].version if model.opset_import else1 except AttributeError: opset = 1 mod, params = g.from_onnx(graph, opset) g = None return mod, params
for i in graph.input: # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' i_name = self._parse_value_proto(i) d_type = self._parse_dtype(i, "float32") if i_name in self._params: # i is a param instead of input self._num_param += 1 self._params[i_name] = self._params.pop(i_name) self._nodes[i_name] = new_var( i_name, shape=self._params[i_name].shape, dtype=self._params[i_name].dtype ) else: self._num_input += 1 if i_name in self._shape: tshape = self._shape[i_name] else: raise ValueError("Must provide an input shape for `{0}`.".format(i_name)) ifisinstance(self._dtype, dict): dtype = self._dtype[i_name] if i_name in self._dtype else d_type else: dtype = d_type self._nodes[i_name] = new_var(i_name, shape=tshape, dtype=dtype) self._inputs[i_name] = self._nodes[i_name]
for node in graph.node: op_name = node.op_type if ( op_name notin convert_map and op_name != "Constant" and op_name notin _identity_list ): unsupported_ops.add(op_name)
Comments