Writing Custom Providers
Create custom diagnostics providers by subclassing DiagnosticsProvider.
Example
```python
import onnx_ir as ir

import onnx_doctor
from onnx_doctor import Rule

MY_RULE = Rule(
    code="CUSTOM001",
    name="large-model",
    message="Model has more than 1000 nodes.",
    default_severity="warning",
    category="spec",
    target_type="graph",
    suggestion="Consider optimizing or splitting the model.",
)


class MyProvider(onnx_doctor.DiagnosticsProvider):
    def diagnose(self, model: ir.Model) -> onnx_doctor.DiagnosticsMessageIterator:
        """Analyze the model and yield diagnostics messages."""
        # Check each graph in the model
        yield from self._check_graph(model.graph)

        # Optionally check functions
        for func in model.functions.values():
            for node in func:
                # ... check nodes in functions
                pass

    def _check_graph(self, graph: ir.Graph) -> onnx_doctor.DiagnosticsMessageIterator:
        if len(graph) > 1000:
            yield onnx_doctor.DiagnosticsMessage(
                target_type="graph",
                target=graph,
                message=f"Graph has {len(graph)} nodes.",
                severity=MY_RULE.default_severity,
                producer="MyProvider",
                error_code=MY_RULE.code,
                rule=MY_RULE,
            )

        # Recursively check subgraphs in node attributes
        for node in graph:
            for attr in node.attributes.values():
                if attr.type == ir.AttributeType.GRAPH:
                    yield from self._check_graph(attr.value)
                elif attr.type == ir.AttributeType.GRAPHS:
                    for subgraph in attr.value:
                        yield from self._check_graph(subgraph)


# Use it
model = ir.load("model.onnx")
messages = onnx_doctor.diagnose(model, [MyProvider()])
```
The diagnose Method
Each provider implements a single method:
```python
def diagnose(self, model: ir.Model) -> onnx_doctor.DiagnosticsMessageIterator:
    ...
```
The provider is responsible for walking the model structure as needed. This gives providers full control over their traversal strategy. Common patterns:
- Use ir.traversal.RecursiveGraphIterator(graph) to iterate over all nodes, including those in subgraphs, or use graph.all_nodes() or func.all_nodes().
- Manually iterate over model.graph, model.functions, node attributes, etc.
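The choice between a flat and a recursive walk changes what a rule actually measures. Below is a minimal, dependency-free sketch of that difference. Node, Graph, and all_nodes are hypothetical stand-ins, not onnx_ir classes: all_nodes descends into subgraph attributes the way graph.all_nodes() is described above, while len() counts only a graph's own nodes, matching how the example applies its len(graph) check to each subgraph separately.

```python
from __future__ import annotations

from collections import Counter
from dataclasses import dataclass, field
from typing import Iterator


# Hypothetical stand-ins for ir.Node and ir.Graph, reduced to what this sketch needs.
@dataclass
class Node:
    op_type: str
    # Graph-valued attributes by name, e.g. {"then_branch": ...} for an If node.
    subgraphs: dict[str, Graph] = field(default_factory=dict)


@dataclass
class Graph:
    nodes: list[Node] = field(default_factory=list)

    def __len__(self) -> int:
        # Counts this graph's own nodes only, not nodes inside subgraphs.
        return len(self.nodes)


def all_nodes(graph: Graph) -> Iterator[Node]:
    """Yield every node in graph, descending depth-first into subgraphs."""
    for node in graph.nodes:
        yield node
        for subgraph in node.subgraphs.values():
            yield from all_nodes(subgraph)


then_branch = Graph([Node("Relu"), Node("Add")])
else_branch = Graph([Node("Relu"), Node("Mul")])
top = Graph([Node("If", {"then_branch": then_branch, "else_branch": else_branch})])

print(len(top))  # 1: only the If node
print(sum(1 for _ in all_nodes(top)))  # 5: If plus two nodes per branch
print(Counter(node.op_type for node in all_nodes(top))["Relu"])  # 2
```

Which count a rule should use depends on whether nodes inside subgraphs ought to count toward their parent graph's total.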
Location Inference
You don’t need to set a location on messages. The driver automatically infers a human-readable location path from the target object:
- Node: graph:node[0](Relu, "my_node")
- Value: graph:input[0](X) or graph:node[0](Relu):output[0](Y)
- Graph: graph or graph:node[0](If):then_branch
- Function: function(domain:name)
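These paths follow a small grammar: a graph path, then :node[i](op_type, "name") for a node, then either :output[j](value_name) for one of its output values or :attribute_name for a subgraph it holds. The sketch below rebuilds the example strings from their parts. It is a hypothetical helper for reading reports, not the driver's implementation, and leaving out the quoted name for an unnamed node is an assumption inferred from the examples.

```python
from __future__ import annotations


def node_location(graph_path: str, index: int, op_type: str, name: str | None = None) -> str:
    """Format a node as graph_path:node[index](op_type, "name")."""
    # Assumption: the quoted name is left out when the node has no name.
    label = f'{op_type}, "{name}"' if name else op_type
    return f"{graph_path}:node[{index}]({label})"


def graph_input_location(graph_path: str, index: int, value_name: str) -> str:
    return f"{graph_path}:input[{index}]({value_name})"


def node_output_location(node_path: str, index: int, value_name: str) -> str:
    return f"{node_path}:output[{index}]({value_name})"


def subgraph_location(node_path: str, attribute_name: str) -> str:
    # A subgraph is addressed through the attribute that holds it, e.g. then_branch.
    return f"{node_path}:{attribute_name}"


def function_location(domain: str, name: str) -> str:
    return f"function({domain}:{name})"


print(node_location("graph", 0, "Relu", "my_node"))  # graph:node[0](Relu, "my_node")
print(node_output_location(node_location("graph", 0, "Relu"), 0, "Y"))  # graph:node[0](Relu):output[0](Y)
print(subgraph_location(node_location("graph", 0, "If"), "then_branch"))  # graph:node[0](If):then_branch
```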
Valid Target Types
The target field must be one of:
- ir.Model
- ir.Graph
- ir.Node
- ir.Value
- ir.Function
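Each message in the example carries both a target_type string and the target object itself, so a small helper can keep the two consistent and reject unsupported targets early. This is a hypothetical sketch: the classes are stand-ins named after the five supported types, and every target_type string except "graph" is an assumption. With onnx_ir installed, the mapping would use ir.Model, ir.Graph, and so on instead.

```python
from __future__ import annotations


# Hypothetical stand-ins named after the five supported onnx_ir classes.
class Model:
    pass


class Graph:
    pass


class Node:
    pass


class Value:
    pass


class Function:
    pass


# Assumed target_type strings; only "graph" appears in the example above.
_TARGET_TYPES: dict[type, str] = {
    Model: "model",
    Graph: "graph",
    Node: "node",
    Value: "value",
    Function: "function",
}


def target_type_of(target: object) -> str:
    """Return the target_type string for target, or raise TypeError if unsupported."""
    for cls, label in _TARGET_TYPES.items():
        if isinstance(target, cls):
            return label
    raise TypeError(f"Unsupported diagnostics target: {type(target).__name__}")


print(target_type_of(Graph()))  # graph
```

A provider could then pass target_type=target_type_of(graph) alongside target=graph, so the two fields cannot drift apart.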