
Export the Best Model to ONNX Format

This notebook walks through exporting the best-performing model tracked in MLflow to ONNX format. The converted model is saved in the inference Python module, where it is used to develop the inference application and to build the inference Docker image that is then published to the designated container registry.

Note: This export has already been performed. Repeat it only if you want to export one of your own candidate models.

Install dependencies

pip install tf2onnx onnxmltools onnxruntime onnx mlflow tensorflow
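tf2onnx supports a limited range of TensorFlow versions, so a conversion error can come from a version mismatch rather than from the model itself. Recording which versions the pip command actually installed makes such failures easier to diagnose. A minimal sketch using only the standard library's importlib.metadata; installed_versions is a hypothetical helper name:

```python
from importlib.metadata import PackageNotFoundError, version

# Distributions installed by the pip command above
PACKAGES = ["tf2onnx", "onnxmltools", "onnxruntime", "onnx", "mlflow", "tensorflow"]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name:12s} {ver or 'NOT INSTALLED'}")
```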

Import dependencies

import mlflow
import tensorflow as tf
import tf2onnx
import keras

Save Model in ONNX Format

The cells below select the best run from the MLflow tracking server, download its Keras model artifact, and convert it to ONNX format.

Note: If no run exceeds desired_test_accuracy, the search returns no rows and selecting the first run fails. Lower desired_test_accuracy in that case.
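Instead of lowering the threshold by hand, the search can be retried with a descending list of thresholds until one returns runs. A minimal sketch, assuming a hypothetical helper search_with_relaxation that accepts any search callable, so the logic can be tried without an MLflow server:

```python
def search_with_relaxation(search, thresholds):
    """Return (threshold, runs) for the first threshold whose search yields any runs.

    `search` is any callable that takes a threshold and returns a sized collection
    of runs (for example, the pandas DataFrame returned by mlflow.search_runs).
    Returns (None, None) if no threshold produces a result.
    """
    for threshold in thresholds:
        runs = search(threshold)
        if len(runs) > 0:
            return threshold, runs
    return None, None


# Example with MLflow (not executed here):
# threshold, runs = search_with_relaxation(
#     lambda t: mlflow.search_runs(
#         experiment_names=["EuroSAT_classification"],
#         filter_string=f"metrics.test_accuracy > {t}",
#     ),
#     thresholds=[0.85, 0.80, 0.75],
# )
```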

params = {
    "MLFLOW_TRACKING_URI": "http://localhost:5000/",
    "experiment_name": "EuroSAT_classification",
}
desired_test_accuracy = 0.85

# Connect the MLflow client to the tracking server defined above
mlflow.set_tracking_uri(params["MLFLOW_TRACKING_URI"])

# Search for the best run: highest test accuracy, ties broken by test precision
best_run = (
    mlflow.search_runs(
        experiment_names=[params["experiment_name"]],
        filter_string=f"metrics.test_accuracy > {desired_test_accuracy}",
    )
    .sort_values(by=["metrics.test_accuracy", "metrics.test_precision"], ascending=False)
    .iloc[0]
)
run_id = best_run["run_id"]
print(f"Selected run_id: {run_id}")
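Sorting on two columns with ascending=False orders runs lexicographically: test accuracy decides first, and test precision only breaks ties. A pure-Python sketch of the same ordering, using hypothetical run records (dicts keyed like the search_runs columns) and a hypothetical helper rank_runs:

```python
def rank_runs(runs, threshold):
    """Keep runs above `threshold`, sorted by accuracy then precision, both descending."""
    eligible = [r for r in runs if r["metrics.test_accuracy"] > threshold]
    return sorted(
        eligible,
        key=lambda r: (r["metrics.test_accuracy"], r["metrics.test_precision"]),
        reverse=True,
    )


# Hypothetical run records for illustration only
runs = [
    {"run_id": "a", "metrics.test_accuracy": 0.91, "metrics.test_precision": 0.88},
    {"run_id": "b", "metrics.test_accuracy": 0.93, "metrics.test_precision": 0.85},
    {"run_id": "c", "metrics.test_accuracy": 0.93, "metrics.test_precision": 0.90},
    {"run_id": "d", "metrics.test_accuracy": 0.80, "metrics.test_precision": 0.95},
]
best = rank_runs(runs, threshold=0.85)[0]
print(best["run_id"])  # c: ties with b on accuracy, wins on precision
```

The same ordering can also be pushed to the tracking server: mlflow.search_runs accepts order_by=["metrics.test_accuracy DESC", "metrics.test_precision DESC"] together with max_results=1, which avoids fetching every matching run.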

# Download just the .keras file
model_uri = f"runs:/{run_id}/model/model.keras/data/model.keras"
keras_path = mlflow.artifacts.download_artifacts(artifact_uri=model_uri)
print(f"Downloaded Keras file path: {keras_path}")
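If the artifact URI points at a directory rather than the .keras file, download_artifacts returns a local directory path and load_model then fails with a less obvious error. Keras v3 .keras files are zip archives with a config.json entry at the root, so a quick sanity check before loading can be sketched as follows (check_keras_archive is a hypothetical helper):

```python
import zipfile


def check_keras_archive(path):
    """Return True if `path` is a zip archive with a root-level config.json (Keras v3 layout)."""
    if not zipfile.is_zipfile(path):  # also False for directories and missing files
        return False
    with zipfile.ZipFile(path) as archive:
        return "config.json" in archive.namelist()


# Usage in the notebook:
# assert check_keras_archive(keras_path), f"{keras_path} does not look like a .keras file"
```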

# Load the Keras v3 model
keras_model = keras.models.load_model(keras_path)

# Define the input signature: a batch of 64x64 inputs with 12 channels (must match the model's input shape)
input_signature = [tf.TensorSpec([None, 64, 64, 12], tf.float32, name="input")]
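The signature hard-codes [None, 64, 64, 12]; if that disagrees with the loaded model's input, tracing or conversion fails. For a built Functional or Sequential model with a single input, Keras exposes keras_model.input_shape, so the two can be compared first. A minimal sketch with a hypothetical helper shapes_compatible that treats None (the batch axis) as a wildcard:

```python
def shapes_compatible(expected, actual):
    """True if both shapes have the same rank and agree wherever both dimensions are known.

    None acts as a wildcard, e.g. for the variable batch dimension.
    """
    if len(expected) != len(actual):
        return False
    return all(e is None or a is None or e == a for e, a in zip(expected, actual))


# Usage in the notebook:
# assert shapes_compatible(
#     tuple(input_signature[0].shape.as_list()), tuple(keras_model.input_shape)
# ), "TensorSpec does not match the model's input shape"
```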

@tf.function(input_signature=input_signature)
def model_func(x):
    return keras_model(x)

# Convert to ONNX
onnx_model, _ = tf2onnx.convert.from_function(
    model_func,
    input_signature=input_signature,
    opset=13,
    output_path="model.onnx"
)

print("✅ Successfully saved model.onnx")
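A quick way to check that the conversion preserved the model's behavior is to run the same random batch through Keras and through ONNX Runtime and compare the outputs. This is an optional sketch, not part of the original workflow: run_onnx and max_abs_diff are hypothetical helpers, the batch shape mirrors the TensorSpec above, and the tolerance you accept for float32 discrepancies is a judgment call for your model.

```python
def max_abs_diff(reference, candidate):
    """Largest element-wise absolute difference between two equal-length flat sequences."""
    if len(reference) != len(candidate):
        raise ValueError(f"length mismatch: {len(reference)} vs {len(candidate)}")
    return max((abs(a - b) for a, b in zip(reference, candidate)), default=0.0)


def run_onnx(path, batch):
    """Run the exported model on a NumPy batch with ONNX Runtime and return the first output."""
    import onnxruntime as ort  # imported here so max_abs_diff works without onnxruntime

    session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name  # read from the model instead of hard-coding it
    return session.run(None, {input_name: batch})[0]


# Usage in the notebook (assumes numpy as np and keras_model from the cells above):
# x = np.random.rand(4, 64, 64, 12).astype(np.float32)
# reference = keras_model.predict(x, verbose=0).ravel().tolist()
# exported = run_onnx("model.onnx", x).ravel().tolist()
# print(f"max |keras - onnx| = {max_abs_diff(reference, exported):.2e}")
```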