diff --git a/_doc/api/light_api.rst b/_doc/api/light_api.rst index 889a70b..544b35f 100644 --- a/_doc/api/light_api.rst +++ b/_doc/api/light_api.rst @@ -19,10 +19,15 @@ translate Classes for the Light API ========================= -ProtoType -+++++++++ +domain +++++++ -.. autoclass:: onnx_array_api.light_api.model.ProtoType +..autofunction:: onnx_array_api.light_api.domain + +BaseVar ++++++++ + +.. autoclass:: onnx_array_api.light_api.var.BaseVar :members: OnnxGraph @@ -31,10 +36,16 @@ OnnxGraph .. autoclass:: onnx_array_api.light_api.OnnxGraph :members: -BaseVar -+++++++ +ProtoType ++++++++++ -.. autoclass:: onnx_array_api.light_api.var.BaseVar +.. autoclass:: onnx_array_api.light_api.model.ProtoType + :members: + +SubDomain ++++++++++ + +.. autoclass:: onnx_array_api.light_api.var.SubDomain :members: Var diff --git a/_doc/tutorial/light_api.rst b/_doc/tutorial/light_api.rst index 4e18793..35474fa 100644 --- a/_doc/tutorial/light_api.rst +++ b/_doc/tutorial/light_api.rst @@ -76,3 +76,32 @@ operator `+` to be available as well and that the case. They are defined in class :class:`Var ` or :class:`Vars ` depending on the number of inputs they require. Their name starts with a lower letter. + +Other domains +============= + +The following example uses operator *Normalizer* from domain +*ai.onnx.ml*. The operator name is called with the syntax +`.`. The domain may have dots in its name +but it must follow the python definition of a variable. +The operator *Normalizer* becomes `ai.onnx.ml.Normalizer`. + +.. runpython:: + :showcode: + + import numpy as np + from onnx_array_api.light_api import start + from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + model = ( + start(opset=19, opsets={"ai.onnx.ml": 3}) + .vin("X") + .reshape((-1, 1)) + .rename("USE") + .ai.onnx.ml.Normalizer(norm="MAX") + .rename("Y") + .vout() + .to_onnx() + ) + + print(onnx_simple_text_plot(model)) diff --git a/_unittests/ut_light_api/test_light_api.py b/_unittests/ut_light_api/test_light_api.py index 98dd64d..f6ae051 100644 --- a/_unittests/ut_light_api/test_light_api.py +++ b/_unittests/ut_light_api/test_light_api.py @@ -1,3 +1,4 @@ +import inspect import unittest from typing import Callable, Optional import numpy as np @@ -12,6 +13,7 @@ from onnx.reference import ReferenceEvaluator from onnx_array_api.ext_test_case import ExtTestCase, skipif_ci_windows from onnx_array_api.light_api import start, OnnxGraph, Var, g +from onnx_array_api.light_api.var import SubDomain from onnx_array_api.light_api._op_var import OpsVar from onnx_array_api.light_api._op_vars import OpsVars @@ -472,7 +474,43 @@ def test_if(self): got = ref.run(None, {"X": -x}) self.assertEqualArray(np.array([0], dtype=np.int64), got[0]) + def test_domain(self): + onx = start(opsets={"ai.onnx.ml": 3}).vin("X").reshape((-1, 1)).rename("USE") + + class A: + def g(self): + return True + + def ah(self): + return True + + setattr(A, "h", ah) + + self.assertTrue(A().h()) + self.assertIn("(self)", str(inspect.signature(A.h))) + self.assertTrue(issubclass(onx._ai, SubDomain)) + self.assertIsInstance(onx.ai, SubDomain) + self.assertIsInstance(onx.ai.parent, Var) + self.assertTrue(issubclass(onx._ai._onnx, SubDomain)) + self.assertIsInstance(onx.ai.onnx, SubDomain) + self.assertIsInstance(onx.ai.onnx.parent, Var) + self.assertTrue(issubclass(onx._ai._onnx._ml, SubDomain)) + self.assertIsInstance(onx.ai.onnx.ml, SubDomain) + self.assertIsInstance(onx.ai.onnx.ml.parent, Var) + self.assertIn("(self,", str(inspect.signature(onx._ai._onnx._ml.Normalizer))) + onx = onx.ai.onnx.ml.Normalizer(norm="MAX") + onx = onx.rename("Y").vout().to_onnx() + self.assertIsInstance(onx, ModelProto) + self.assertIn("Normalizer", str(onx)) + self.assertIn('domain: "ai.onnx.ml"', str(onx)) + self.assertIn('input: "USE"', str(onx)) + ref = ReferenceEvaluator(onx) + a = np.arange(10).astype(np.float32) + got = ref.run(None, {"X": a})[0] + expected = (a > 0).astype(int).astype(np.float32).reshape((-1, 1)) + self.assertEqualArray(expected, got) + if __name__ == "__main__": - TestLightApi().test_if() + TestLightApi().test_domain() unittest.main(verbosity=2) diff --git a/_unittests/ut_light_api/test_translate.py b/_unittests/ut_light_api/test_translate.py index 794839f..c2b2c70 100644 --- a/_unittests/ut_light_api/test_translate.py +++ b/_unittests/ut_light_api/test_translate.py @@ -185,6 +185,39 @@ def test_export_if(self): self.maxDiff = None self.assertEqual(expected, code) + def test_aionnxml(self): + onx = ( + start(opset=19, opsets={"ai.onnx.ml": 3}) + .vin("X") + .reshape((-1, 1)) + .rename("USE") + .ai.onnx.ml.Normalizer(norm="MAX") + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx) + expected = dedent( + """ + ( + start(opset=19, opsets={'ai.onnx.ml': 3}) + .cst(np.array([-1, 1], dtype=np.int64)) + .rename('r') + .vin('X', elem_type=TensorProto.FLOAT) + .bring('X', 'r') + .Reshape() + .rename('USE') + .bring('USE') + .ai.onnx.ml.Normalizer(norm='MAX') + .rename('Y') + .bring('Y') + .vout(elem_type=TensorProto.FLOAT) + .to_onnx() + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + if __name__ == "__main__": TestTranslate().test_export_if() diff --git a/_unittests/ut_light_api/test_translate_classic.py b/_unittests/ut_light_api/test_translate_classic.py index afdee8d..cb7d6a4 100644 --- a/_unittests/ut_light_api/test_translate_classic.py +++ b/_unittests/ut_light_api/test_translate_classic.py @@ -252,6 +252,72 @@ def test_fft(self): ) raise AssertionError(f"ERROR {e}\n{new_code}") + def test_aionnxml(self): + onx = ( + start(opset=19, opsets={"ai.onnx.ml": 3}) + .vin("X") + .reshape((-1, 1)) + .rename("USE") + .ai.onnx.ml.Normalizer(norm="MAX") + .rename("Y") + .vout() + .to_onnx() + ) + code = translate(onx, api="onnx") + print(code) + expected = dedent( + """ + opset_imports = [ + make_opsetid('', 19), + make_opsetid('ai.onnx.ml', 3), + ] + inputs = [] + outputs = [] + nodes = [] + initializers = [] + sparse_initializers = [] + functions = [] + initializers.append( + from_array( + np.array([-1, 1], dtype=np.int64), + name='r' + ) + ) + inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[])) + nodes.append( + make_node( + 'Reshape', + ['X', 'r'], + ['USE'] + ) + ) + nodes.append( + make_node( + 'Normalizer', + ['USE'], + ['Y'], + domain='ai.onnx.ml', + norm='MAX' + ) + ) + outputs.append(make_tensor_value_info('Y', TensorProto.FLOAT, shape=[])) + graph = make_graph( + nodes, + 'light_api', + inputs, + outputs, + initializers, + sparse_initializer=sparse_initializers, + ) + model = make_model( + graph, + functions=functions, + opset_imports=opset_imports + )""" + ).strip("\n") + self.maxDiff = None + self.assertEqual(expected, code) + if __name__ == "__main__": # TestLightApi().test_topk() diff --git a/onnx_array_api/light_api/__init__.py b/onnx_array_api/light_api/__init__.py index 3ebb413..be6e9dd 100644 --- a/onnx_array_api/light_api/__init__.py +++ b/onnx_array_api/light_api/__init__.py @@ -1,5 +1,6 @@ from typing import Dict, Optional from onnx import ModelProto +from .annotations import domain from .model import OnnxGraph, ProtoType from .translate import Translater from .var import Var, Vars diff --git a/onnx_array_api/light_api/_op_var.py b/onnx_array_api/light_api/_op_var.py index c685437..8a995b3 100644 --- a/onnx_array_api/light_api/_op_var.py +++ b/onnx_array_api/light_api/_op_var.py @@ -1,4 +1,5 @@ from typing import List, Optional, Union +from .annotations import AI_ONNX_ML, domain class OpsVar: @@ -319,6 +320,10 @@ def Transpose(self, perm: Optional[List[int]] = None) -> "Var": perm = perm or [] return self.make_node("Transpose", self, perm=perm) + @domain(AI_ONNX_ML) + def Normalizer(self, norm: str = "MAX"): + return self.make_node("Normalizer", self, norm=norm, domain=AI_ONNX_ML) + def _complete(): ops_to_add = [ diff --git a/onnx_array_api/light_api/annotations.py b/onnx_array_api/light_api/annotations.py index c975dab..3fe7973 100644 --- a/onnx_array_api/light_api/annotations.py +++ b/onnx_array_api/light_api/annotations.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np from onnx import FunctionProto, GraphProto, ModelProto, TensorProto, TensorShapeProto from onnx.helper import np_dtype_to_tensor_dtype @@ -9,12 +9,47 @@ VAR_CONSTANT_TYPE = Union["Var", TensorProto, np.ndarray] GRAPH_PROTO = Union[FunctionProto, GraphProto, ModelProto] +AI_ONNX_ML = "ai.onnx.ml" + ELEMENT_TYPE_NAME = { getattr(TensorProto, k): k for k in dir(TensorProto) if isinstance(getattr(TensorProto, k), int) and "_" not in k } + +class SubDomain: + pass + + +def domain(domain: str, op_type: Optional[str] = None) -> Callable: + """ + Registers one operator into a sub domain. It should be used as a + decorator. One example: + + .. code-block:: python + + @domain("ai.onnx.ml") + def Normalizer(self, norm: str = "MAX"): + return self.make_node("Normalizer", self, norm=norm, domain="ai.onnx.ml") + """ + names = [op_type] + + def decorate(op_method: Callable) -> Callable: + if names[0] is None: + names[0] = op_method.__name__ + + def wrapper(self, *args: List[Any], **kwargs: Dict[str, Any]) -> Any: + return op_method(self.parent, *args, **kwargs) + + wrapper.__qual__name__ = f"[{domain}]{names[0]}" + wrapper.__name__ = f"[{domain}]{names[0]}" + wrapper.__domain__ = domain + return wrapper + + return decorate + + _type_numpy = { np.float32: TensorProto.FLOAT, np.float64: TensorProto.DOUBLE, diff --git a/onnx_array_api/light_api/emitter.py b/onnx_array_api/light_api/emitter.py index c52acfc..a1b0e40 100644 --- a/onnx_array_api/light_api/emitter.py +++ b/onnx_array_api/light_api/emitter.py @@ -241,7 +241,7 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: outputs = kwargs["outputs"] if kwargs.get("domain", "") != "": domain = kwargs["domain"] - raise NotImplementedError(f"domain={domain!r} not supported yet.") + op_type = f"{domain}.{op_type}" atts = kwargs.get("atts", {}) args = [] for k, v in atts.items(): diff --git a/onnx_array_api/light_api/inner_emitter.py b/onnx_array_api/light_api/inner_emitter.py index a2173e0..f5d5e4d 100644 --- a/onnx_array_api/light_api/inner_emitter.py +++ b/onnx_array_api/light_api/inner_emitter.py @@ -120,7 +120,6 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]: outputs = kwargs["outputs"] if kwargs.get("domain", "") != "": domain = kwargs["domain"] - raise NotImplementedError(f"domain={domain!r} not supported yet.") before_lines = [] lines = [ diff --git a/onnx_array_api/light_api/model.py b/onnx_array_api/light_api/model.py index 7391e0b..67fc18e 100644 --- a/onnx_array_api/light_api/model.py +++ b/onnx_array_api/light_api/model.py @@ -248,6 +248,9 @@ def make_node( node = make_node(op_type, input_names, output_names, domain=domain, **kwargs) self.nodes.append(node) + if domain != "": + if not self.opsets or domain not in self.opsets: + raise RuntimeError(f"No opset value was given for domain {domain!r}.") return node def cst(self, value: np.ndarray, name: Optional[str] = None) -> "Var": diff --git a/onnx_array_api/light_api/var.py b/onnx_array_api/light_api/var.py index ddcc7f5..882dcb7 100644 --- a/onnx_array_api/light_api/var.py +++ b/onnx_array_api/light_api/var.py @@ -1,3 +1,4 @@ +import inspect from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from onnx import TensorProto @@ -16,6 +17,26 @@ from ._op_vars import OpsVars +class SubDomain: + """ + Declares a domain or a piece of it (if it contains '.' in its name). + """ + + def __init__(self, var: "BaseVar"): + if not isinstance(var, BaseVar): + raise TypeError(f"Unexpected type {type(var)}.") + self.parent = var + + +def _getclassattr_(self, name): + if not hasattr(self.__class__, name): + raise TypeError( + f"Unable to find {name!r} in class {self.__class__.__name__!r}, " + f"available {dir(self.__class__)}." + ) + return getattr(self.__class__, name) + + class BaseVar: """ Represents an input, an initializer, a node, an output, @@ -24,6 +45,88 @@ class BaseVar: :param parent: the graph containing the Variable """ + def __new__(cls, *args, **kwargs): + """ + If called for the first instantiation of a BaseVar, it process + all methods declared with decorator :func:`onnx_array_api.light_api.domain` + so that it can be called with a syntax `v..`. + """ + res = super().__new__(cls) + res.__init__(*args, **kwargs) + if getattr(cls, "__incomplete", True): + for k in dir(cls): + att = getattr(cls, k, None) + if not att: + continue + name = getattr(att, "__name__", None) + if not name or name[0] != "[": + continue + + # A function with a domain name + if not inspect.isfunction(att): + raise RuntimeError(f"{cls.__name__}.{k} is not a function.") + domain, op_type = name[1:].split("]") + if "." in domain: + spl = domain.split(".", maxsplit=1) + dname = f"_{spl[0]}" + if not hasattr(cls, dname): + d = type( + f"{cls.__name__}{dname}", (SubDomain,), {"name": dname[1:]} + ) + setattr(cls, dname, d) + setattr( + cls, + spl[0], + property( + lambda self, _name_=dname: _getclassattr_(self, _name_)( + self + ) + ), + ) + else: + d = getattr(cls, dname) + suffix = spl[0] + for p in spl[1].split("."): + dname = f"_{p}" + suffix += dname + if not hasattr(d, dname): + sd = type( + f"{cls.__name__}_{suffix}", + (SubDomain,), + {"name": suffix}, + ) + setattr(d, dname, sd) + setattr( + d, + p, + property( + lambda self, _name_=dname: _getclassattr_( + self, _name_ + )(self.parent) + ), + ) + d = sd + else: + d = getattr(d, dname) + elif not hasattr(cls, domain): + dname = f"_{domain}" + d = type(f"{cls.__name__}{dname}", (SubDomain,), {"name": domain}) + setattr(cls, dname, d) + setattr( + cls, + domain, + property( + lambda self, _name_=dname: _getclassattr_(self, _name_)( + self + ) + ), + ) + + setattr(d, op_type, att) + setattr(cls, "__incomplete", False) + + return res + def __init__( self, parent: OnnxGraph,