Writing Custom Providers
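Create a custom diagnostics provider by subclassing `DiagnosticsProvider` and overriding one or more of its check methods. The example below overrides `check_graph` and yields a `DiagnosticsMessage` for each problem it finds, attaching a `Rule` that describes the check.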
Example
```python
import onnx_ir as ir

import onnx_doctor
from onnx_doctor import Rule

# Describe the check once; the provider reuses its code and default severity.
MY_RULE = Rule(
    code="CUSTOM001",
    name="large-model",
    message="Model has more than 1000 nodes.",
    default_severity="warning",
    category="spec",
    target_type="graph",
    suggestion="Consider optimizing or splitting the model.",
)


class MyProvider(onnx_doctor.DiagnosticsProvider):
    def check_graph(self, graph: ir.Graph):
        # Iterating an ir.Graph yields its nodes.
        node_count = sum(1 for _ in graph)
        if node_count > 1000:
            yield onnx_doctor.DiagnosticsMessage(
                target_type="graph",
                target=graph,
                message=f"Graph has {node_count} nodes.",
                severity=MY_RULE.default_severity,
                producer="MyProvider",
                error_code=MY_RULE.code,
                rule=MY_RULE,
            )


# Load a model and run the custom provider on it.
model = ir.load("model.onnx")
messages = onnx_doctor.diagnose(model, [MyProvider()])
```
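Because `check_graph` in the example is a generator, one call can yield zero or more messages, a graph at or below the threshold yields none, and nothing runs until the caller iterates. The library-free sketch below mimics that control flow so it can be run without onnx_doctor or a model file. `ToyMessage` and `LargeGraphCheck` are hypothetical stand-ins, the graph is modeled as a plain list of node names, and the configurable `max_nodes` threshold is an illustrative choice rather than part of the original example.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Iterator


@dataclass
class ToyMessage:
    # Hypothetical stand-in for a subset of onnx_doctor.DiagnosticsMessage fields.
    error_code: str
    severity: str
    message: str


class LargeGraphCheck:
    """Hypothetical stand-in that mirrors MyProvider.check_graph's control flow."""

    def __init__(self, max_nodes: int = 1000) -> None:
        # Threshold made configurable for illustration; the example hard-codes 1000.
        self.max_nodes = max_nodes

    def check_graph(self, graph: list[str]) -> Iterator[ToyMessage]:
        # The "graph" is just a list of node names here, not an ir.Graph.
        node_count = len(graph)
        if node_count > self.max_nodes:
            yield ToyMessage("CUSTOM001", "warning", f"Graph has {node_count} nodes.")


check = LargeGraphCheck()
print(list(check.check_graph(["Relu"] * 10)))    # [] (at or below the threshold)
print(list(check.check_graph(["Relu"] * 1001)))  # one CUSTOM001 warning
```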
Available Check Methods
Override any of these methods in your provider:
- `check_model(model)`: called once per model.
- `check_graph(graph)`: called for each graph.
- `check_function(function)`: called for each function.
- `check_node(node)`: called for each node.
- `check_value(value)`: called for each value.
- `check_tensor(tensor)`: called for each tensor.
- `check_attribute(attribute)`: called for each attribute.
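The list above implies a walk over the model in which each kind of element is handed to its matching hook. The self-contained sketch below models that dispatch with hypothetical toy types (`ToyModel`, `ToyGraph`, `ToyNode`, `ToyProvider`, `toy_diagnose`) and plain strings as messages. It is not onnx_doctor's implementation: the real traversal order, how values and tensors are discovered, and the message type may all differ. Here, as a simplifying assumption, values are node outputs, tensors are graph initializers, attributes are `(name, value)` pairs, and function bodies reuse `ToyGraph`.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable, Iterator


# Hypothetical toy element types; these are NOT onnx_ir classes.
@dataclass
class ToyNode:
    op_type: str
    attributes: dict[str, object] = field(default_factory=dict)
    outputs: list[str] = field(default_factory=list)  # names of output values


@dataclass
class ToyGraph:
    nodes: list[ToyNode] = field(default_factory=list)
    initializers: dict[str, list[float]] = field(default_factory=dict)  # "tensors"


@dataclass
class ToyModel:
    graph: ToyGraph
    functions: list[ToyGraph] = field(default_factory=list)  # simplified function bodies


class ToyProvider:
    """Hypothetical base class: every hook yields nothing unless overridden."""

    def check_model(self, model: ToyModel) -> Iterator[str]:
        yield from ()

    def check_graph(self, graph: ToyGraph) -> Iterator[str]:
        yield from ()

    def check_function(self, function: ToyGraph) -> Iterator[str]:
        yield from ()

    def check_node(self, node: ToyNode) -> Iterator[str]:
        yield from ()

    def check_value(self, value: str) -> Iterator[str]:
        yield from ()

    def check_tensor(self, tensor: list[float]) -> Iterator[str]:
        yield from ()

    def check_attribute(self, attribute: tuple[str, object]) -> Iterator[str]:
        yield from ()


def _walk_nodes(provider: ToyProvider, nodes: Iterable[ToyNode]) -> Iterator[str]:
    # Each node is checked, then its attributes, then the values it produces.
    for node in nodes:
        yield from provider.check_node(node)
        for attribute in node.attributes.items():
            yield from provider.check_attribute(attribute)
        for value in node.outputs:
            yield from provider.check_value(value)


def toy_diagnose(model: ToyModel, providers: list[ToyProvider]) -> list[str]:
    """Hypothetical driver: routes each element to the matching hook."""
    messages: list[str] = []
    for provider in providers:
        messages.extend(provider.check_model(model))  # once per model
        messages.extend(provider.check_graph(model.graph))
        for tensor in model.graph.initializers.values():
            messages.extend(provider.check_tensor(tensor))
        messages.extend(_walk_nodes(provider, model.graph.nodes))
        for function in model.functions:
            messages.extend(provider.check_function(function))
            messages.extend(_walk_nodes(provider, function.nodes))
    return messages


class IdentityNodeProvider(ToyProvider):
    """Overrides only check_node; every other hook keeps the no-op default."""

    def check_node(self, node: ToyNode) -> Iterator[str]:
        if node.op_type == "Identity":
            yield f"Identity node (outputs: {', '.join(node.outputs)})"


model = ToyModel(ToyGraph([ToyNode("Relu", outputs=["y"]), ToyNode("Identity", outputs=["z"])]))
print(toy_diagnose(model, [IdentityNodeProvider()]))  # ['Identity node (outputs: z)']
```

Because the toy base class supplies no-op defaults, `IdentityNodeProvider` only has to define `check_node`. Whether onnx_doctor's `DiagnosticsProvider` provides equivalent defaults is an assumption drawn from the "override any of these methods" wording.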