Skip to content

Describe a trained machine learning model

This Notebook leverages the capabilities of STAC to provide a comprehensive and standardized description of a trained ML model. This is done with a STAC Item file that encapsulates the relevant metadata (e.g. model name and version, description of the model architecture and training process, specifications of input and output data formats, etc.).

This Notebook covers the following steps:

  • Import Libraries (e.g. pystac, boto3)
  • Option to either create a STAC Item with pystac, or upload an existing STAC Item into the Notebook. The STAC Item will contain all related ML-model-specific properties, the related STAC extensions and the hyperparameters.
  • Create interlinked STAC Item, Catalog and Collection, and the STAC folder structure

Objective: By the end of this Notebook, the user will have published a STAC Item, Collection and Catalog to the STAC endpoint and tested their search functionality via query parameters.

Table of Contents:

1) Import Libraries
2) Create STAC Item, Catalog and Collection

1) Import Libraries

from datetime import datetime
import os
import json
import requests
from pathlib import Path
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from tqdm import tqdm

import pystac
from pystac import read_file
from pystac.extensions.version import ItemVersionExtension
from pystac.extensions.eo import EOExtension
from pystac_client import Client
from pystac.stac_io import DefaultStacIO, StacIO
from pystac.extensions.eo import Band, EOExtension
from pystac.extensions.file import FileExtension


from loguru import logger
from urllib.parse import urljoin, urlparse

import boto3
import botocore

from utils import (
    UserSettings,
    StarsCopyWrapper,
    read_url,
    ingest_items,
    get_headers,
    CustomStacIO,
    getTemporalExtent,
    getGeom,
)

2) Create STAC Item, Catalog and Collection

# Names used for the on-disk STAC layout. Nothing is created here; the
# folders are written later when the catalog is saved.
ITEM_ID = "Tile-based-ML-Models"  # id of the STAC Item describing the model
CATALOG_DIR = "ML_Catalog"  # root folder of the catalog
COLLECTION_NAME = "ML-Models_EO"  # collection id, also its sub-folder name
SUB_DIR = os.path.join(CATALOG_DIR, COLLECTION_NAME)  # <catalog>/<collection>

STAC Item

NOTE: Please execute either section 2.1) Create STAC Item or section 2.2) Upload STAC Item below, according to the following:

  • execute section 2.1) Create STAC Item if you want to create a STAC Item from scratch using pystac within this Notebook; or
  • execute section 2.2) Upload STAC Item if you have already created a STAC Item (i.e. a .json/.geojson file) and want to upload it into this Notebook.

2.1) Create STAC Item

from datetime import timezone

# Bounding box of the Item, ordered [west, south, east, north] in lon/lat degrees
bbox = [-121.87680832296513, 36.93063805399626, -120.06532070709298, 38.84330548198025]

# pystac serialises a naive datetime as if it were UTC (it appends "Z"), so a
# naive local `datetime.now()` would be mislabelled whenever the machine is not
# on UTC. Use an explicit, timezone-aware UTC timestamp instead.
item = pystac.Item(
    id=ITEM_ID,
    bbox=bbox,
    geometry=getGeom(bbox),  # footprint built from the bbox by utils.getGeom
    datetime=datetime.now(timezone.utc),
    properties={},
)
item

Adding Item properties

In the following section, the user provides the properties used to create the STAC Item.

# Add standard properties: the time range covered by the Item and a
# free-text description of the model.
item.properties["start_datetime"] = "2023-06-13T00:00:00Z"
item.properties["end_datetime"] = "2023-06-18T23:59:59Z"
# Adjacent string literals are joined at compile time, so the stored
# description contains no embedded newlines or source indentation (a
# triple-quoted literal keeps both).
# NOTE(review): the original text ended mid-sentence ("can be used for");
# it is completed here with "land cover mapping" — confirm the intended wording.
item.properties["description"] = (
    "Tile based classifier using CNNs for land cover classification. "
    "The model is trained on the Sentinel-2 dataset and is capable of classifying "
    "land cover types such as water, forest, urban, and agriculture. "
    "The model is designed to work with Sentinel-2 imagery and can be used for "
    "land cover mapping."
)

item
# Add ML Model extension ("ml-model:*") properties
item.properties.update(
    {
        "ml-model:type": "ml-model",
        "ml-model:learning_approach": "supervised",
        "ml-model:prediction_type": "classification",
        "ml-model:architecture": "ResNet-18",
        "ml-model:training-processor-type": "cpu",
        "ml-model:training-os": "linux",
    }
)

# Add Machine Learning Model extension ("mlm:*") properties.
# NOTE(review): "mlm:architecture" (RandomForestClassifier) contradicts
# "ml-model:architecture" (ResNet-18) and the CNN description. "mlm:framework"
# is tensorflow while the model asset added later is ONNX — confirm.
item.properties.update(
    {
        "mlm:name": "Tile-Based Classifier",
        "mlm:architecture": "RandomForestClassifier",
        "mlm:framework": "tensorflow",
        "mlm:framework_version": "1.4.2",
        "mlm:tasks": ["classification"],
        "mlm:compiled": False,
        "mlm:accelerator": "amd64",
        "mlm:accelerator_constrained": False,
    }
)

# Add training hyperparameters; every value below is an example placeholder.
item.properties["mlm:hyperparameters"] = dict(
    learning_rate=0.001,
    batch_size=32,
    number_of_epochs=50,
    optimizer="adam",
    momentum=0.9,
    dropout_rate=0.5,
    number_of_convolutional_layers=3,
    filter_size="3x3",
    number_of_filters=64,
    activation_function="relu",
    pooling_layers="max",
    learning_rate_scheduler="step_decay",
    l2_regularization=1e-4,
)

item.properties
{'datetime': '2025-04-04T12:30:39.002998Z',
 'start_datetime': '2023-06-13T00:00:00Z',
 'end_datetime': '2023-06-18T23:59:59Z',
 'description': 'Tile based classifier using CNNs for land cover classification. \n    The model is trained on the Sentinel-2 dataset and is capable of classifying \n    land cover types such as water, forest, urban, and agriculture. \n    The model is designed to work with Sentinel-2 imagery and can be used for ',
 'ml-model:type': 'ml-model',
 'ml-model:learning_approach': 'supervised',
 'ml-model:prediction_type': 'classification',
 'ml-model:architecture': 'ResNet-18',
 'ml-model:training-processor-type': 'cpu',
 'ml-model:training-os': 'linux',
 'mlm:name': 'Tile-Based Classifier',
 'mlm:architecture': 'RandomForestClassifier',
 'mlm:framework': 'tensorflow',
 'mlm:framework_version': '1.4.2',
 'mlm:tasks': ['classification'],
 'mlm:compiled': False,
 'mlm:accelerator': 'amd64',
 'mlm:accelerator_constrained': False,
 'mlm:hyperparameters': {'learning_rate': 0.001,
  'batch_size': 32,
  'number_of_epochs': 50,
  'optimizer': 'adam',
  'momentum': 0.9,
  'dropout_rate': 0.5,
  'number_of_convolutional_layers': 3,
  'filter_size': '3x3',
  'number_of_filters': 64,
  'activation_function': 'relu',
  'pooling_layers': 'max',
  'learning_rate_scheduler': 'step_decay',
  'l2_regularization': 0.0001}}

Model inputs

The properties of the model inputs can be populated in the cell below.

# Describe the model input ("mlm:input"): the Sentinel-2 bands fed to the
# model, the input tensor layout/dtype, and the normalisation applied.
# NOTE(review): 13 bands are listed but the tensor shape declares 3 channels
# — confirm which one matches the trained model.
item.properties["mlm:input"] = [
    dict(
        name="EO Data",
        bands=[
            "B01", "B02", "B03", "B04", "B05", "B06", "B07",
            "B08", "B8A", "B09", "B10", "B11", "B12",
        ],
        input=dict(
            shape=[-1, 3, 64, 64],
            dim_order=["batch", "channel", "height", "width"],
            data_type="float32",
        ),
        norm_type="z-score",
    )
]

Model outputs

# Class name -> integer label produced by the model.
class_map = {
    "Annual Crop": 0,
    "Forest": 1,
    "Herbaceous Vegetation": 2,
    "Highway": 3,
    "Industrial Buildings": 4,
    "Pasture": 5,
    "Permanent Crop": 6,
    "Residential Buildings": 7,
    "River": 8,
    "SeaLake": 9,
}

# Integer label -> RGBA display colour (the alpha channel is not used below).
color_map = {
    0: (34, 139, 34, 255),  # AnnualCrop: Forest Green
    1: (0, 100, 0, 255),  # Forest: Dark Green
    2: (144, 238, 144, 255),  # HerbaceousVegetation: Light Green
    3: (128, 128, 128, 255),  # Highway: Gray
    4: (169, 169, 169, 255),  # Industrial: Dark Gray
    5: (85, 107, 47, 255),  # Pasture: Olive Green
    6: (60, 179, 113, 255),  # PermanentCrop: Medium Sea Green
    7: (139, 69, 19, 255),  # Residential: Saddle Brown
    8: (30, 144, 255, 255),  # River: Dodger Blue
    9: (0, 0, 255, 255),  # SeaLake: Blue
}

# One "classification:classes" entry per class. "color_hint" is the RGB part
# of the colour as 6 lowercase hex digits without a leading "#".
# A comprehension keeps its loop variables out of the notebook's global
# namespace. The previous for-loop rebound the builtin `id` at module level,
# breaking any later `id(...)` call.
tmp_dict = [
    {
        "name": class_name,
        "value": class_id,
        "description": f"{class_name} tile",
        "color_hint": "{:02x}{:02x}{:02x}".format(*color_map[class_id][:3]),
    }
    for class_name, class_id in class_map.items()
]
# Describe the model output ("mlm:output"): a uint8 map laid out as
# (batch, height, width), whose values are the class labels in `tmp_dict`.
# NOTE(review): these output tasks are not a subset of
# item.properties["mlm:tasks"] (["classification"]) — confirm which is intended.
item.properties["mlm:output"] = [
    {
        "name": "CLASSIFICATION",
        "tasks": ["segmentation", "semantic-segmentation"],
        "result": dict(
            shape=[-1, 10980, 10980],
            dim_order=["batch", "height", "width"],
            data_type="uint8",
        ),
        "post_processing_function": None,
        "classification:classes": tmp_dict,
    }
]

item

The user adds the raster bands to the Item's properties.

# Add "raster:bands" properties
def add_prop_RasterBands(name, cname, nd, dt, bps, res, scale, offset, unit):
    """Build one "raster:bands" entry for a band with full pixel metadata.

    Args:
        name: band name (e.g. "B02").
        cname: common band name (e.g. "blue").
        nd: nodata value.
        dt: data type string (e.g. "float32").
        bps: bits per sample.
        res: spatial resolution.
        scale: multiplicative scale applied to pixel values.
        offset: additive offset applied to pixel values.
        unit: unit string stored under "unit".

    Returns:
        A new dict with the keys name, common_name, nodata, data_type,
        bits_per_sample, spatial_resolution, scale, offset, unit (in that order).
    """
    return dict(
        name=name,
        common_name=cname,
        nodata=nd,
        data_type=dt,
        bits_per_sample=bps,
        spatial_resolution=res,
        scale=scale,
        offset=offset,
        unit=unit,
    )


def add_prop_RasterBands_Expression(name, cname, nd, dt, exp):
    """Build a "raster:bands" entry for a band defined by a processing expression.

    Args:
        name: band name.
        cname: common band name.
        nd: nodata value.
        dt: data type string.
        exp: expression stored under "processing:expression".

    Returns:
        A new dict with the keys name, common_name, nodata, data_type,
        processing:expression (in that order).
    """
    keys = ("name", "common_name", "nodata", "data_type", "processing:expression")
    return dict(zip(keys, (name, cname, nd, dt, exp)))


# One raster:bands entry per Sentinel-2 band, given as
# (band name, common name, spatial resolution). All other fields are shared.
# NOTE(review): a few shared values look inconsistent and should be confirmed:
#   * bits_per_sample=15 does not match data_type="float32" (32 bits);
#   * in the raster extension "unit" describes the pixel values, while the
#     resolution already goes in "spatial_resolution" — "m" may not be the
#     intended unit of the pixel values;
#   * B05, B06, B07 and B10 appear in the "mlm:input" bands but not here.
item.properties["raster:bands"] = [
    add_prop_RasterBands(
        name=band_name,
        cname=common_name,
        nd=0,
        dt="float32",
        bps=15,
        res=resolution,
        scale=0.0001,
        offset=0,
        unit="m",
    )
    for band_name, common_name, resolution in (
        ("B01", "coastal", 60),
        ("B02", "blue", 10),
        ("B03", "green", 10),
        ("B04", "red", 10),
        ("B08", "nir", 10),
        ("B8A", "nir08", 20),
        ("B09", "nir09", 60),
        ("B11", "swir16", 20),
        ("B12", "swir22", 20),
    )
]


# Display
item.properties["raster:bands"]
[{'name': 'B01',
  'common_name': 'coastal',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 60,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B02',
  'common_name': 'blue',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 10,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B03',
  'common_name': 'green',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 10,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B04',
  'common_name': 'red',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 10,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B08',
  'common_name': 'nir',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 10,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B8A',
  'common_name': 'nir08',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 20,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B09',
  'common_name': 'nir09',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 60,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B11',
  'common_name': 'swir16',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 20,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'},
 {'name': 'B12',
  'common_name': 'swir22',
  'nodata': 0,
  'data_type': 'float32',
  'bits_per_sample': 15,
  'spatial_resolution': 20,
  'scale': 0.0001,
  'offset': 0,
  'unit': 'm'}]
# Release version of the CWL workflows referenced by the assets below
app_version = "0.0.2"

# Add Asset - training workflow (CWL file attached to a GitHub release)
asset = pystac.Asset(
    title="Workflow for tile-based training",
    href=f"https://github.com/parham-membari-terradue/machine-learning-process/releases/download/{app_version}/tile-sat-training.{app_version}.cwl",
    media_type="application/cwl+yaml",
    roles=["ml-model:training-runtime", "runtime", "mlm:training-runtime"],
)
item.add_asset("tile-based-training", asset)

# Add Asset - inference workflow (CWL file attached to a GitHub release)
asset = pystac.Asset(
    title="Workflow for tile-based inference",
    href=f"https://github.com/parham-membari-terradue/machine-learning-process/releases/download/{app_version}/tile-sat-inference.{app_version}.cwl",
    media_type="application/cwl+yaml",
    roles=["ml-model:inference-runtime", "runtime", "mlm:inference-runtime"],
)
item.add_asset("tile-based-inference", asset)

# Add Asset - ML model (ONNX file).
# A GitHub "/blob/" URL serves the HTML file-viewer page, not the file itself.
# "/raw/" returns the file bytes, which is what a client downloading this
# asset needs.
asset = pystac.Asset(
    title="ONNX Model",
    href="https://github.com/parham-membari-terradue/machine-learning-process/raw/main/inference/make-inference/src/make_inference/model/model.onnx",
    media_type="application/octet-stream; framework=onnx; profile=onnx",
    roles=["mlm:model"],
)
item.add_asset("model", asset)

item.assets
{'tile-based-training': <Asset href=https://github.com/parham-membari-terradue/machine-learning-process/releases/download/0.0.2/tile-sat-training.0.0.2.cwl>,
 'tile-based-inference': <Asset href=https://github.com/parham-membari-terradue/machine-learning-process/releases/download/0.0.2/tile-sat-inference.0.0.2.cwl>,
 'model': <Asset href=https://github.com/parham-membari-terradue/machine-learning-process/blob/main/inference/make-inference/src/make_inference/model/model.onnx>}
# Add links
rel_path = f"./{SUB_DIR}/{item.id}/{item.id}.json"
item.set_self_href(rel_path)
item.links
[<Link rel=self target=/home/t2/Desktop/p/argo/machine-learning-process/MLM/ML_Catalog/ML-Models_EO/Tile-based-ML-Models/Tile-based-ML-Models.json>]

In addition to the EO STAC Extension, the user can add the "ML Model" STAC Extension (ml-model) to the STAC Item.

There is an upcoming extension for ML models under development, which will allow storing more details and information related to the ML model: the "Machine Learning Model" STAC Extension (mlm).

# Add Extensions. EO is registered through pystac's extension API.
EOExtension.ext(item, add_if_missing=True)

# Register the remaining extension schemas by URL. A schema is appended only
# when no URL already in item.stac_extensions contains its marker text, so
# re-running this cell does not add duplicates.
for _marker, _schema_url in (
    ("ml-model", "https://stac-extensions.github.io/ml-model/v1.0.0/schema.json"),
    ("mlm-extension", "https://crim-ca.github.io/mlm-extension/v1.2.0/schema.json"),
    ("raster", "https://stac-extensions.github.io/raster/v1.1.0/schema.json"),
    ("file", "https://stac-extensions.github.io/file/v2.1.0/schema.json"),
):
    if all(_marker not in url for url in item.stac_extensions):
        item.stac_extensions.append(_schema_url)
del _marker, _schema_url  # keep loop variables out of the notebook namespace
item.stac_extensions
['https://stac-extensions.github.io/eo/v1.1.0/schema.json',
 'https://stac-extensions.github.io/ml-model/v1.0.0/schema.json',
 'https://crim-ca.github.io/mlm-extension/v1.2.0/schema.json',
 'https://stac-extensions.github.io/raster/v1.1.0/schema.json',
 'https://stac-extensions.github.io/file/v2.1.0/schema.json']
# Display the Item
item
# Validate STAC Item against the core Item schema and every schema listed in
# item.stac_extensions; the returned list names the schemas that were checked.
item.validate()
['https://schemas.stacspec.org/v1.1.0/item-spec/json-schema/item.json',
 'https://stac-extensions.github.io/eo/v1.1.0/schema.json',
 'https://stac-extensions.github.io/ml-model/v1.0.0/schema.json',
 'https://crim-ca.github.io/mlm-extension/v1.2.0/schema.json',
 'https://stac-extensions.github.io/raster/v1.1.0/schema.json',
 'https://stac-extensions.github.io/file/v2.1.0/schema.json']

2.3) STAC Objects

STAC Catalog

Check if a catalog.json exists already. If not, create it, otherwise read existing catalog and add STAC Item to it.

cat_path = os.path.join(CATALOG_DIR, "catalog.json")

# Reuse an existing catalog.json so the new Item joins it; otherwise start
# a fresh, empty catalog.
if os.path.exists(cat_path):
    print("Catalog exists already. Reading it")
    catalog = pystac.read_file(cat_path)
else:
    print("Catalog does not exist. Creating it")
    catalog = pystac.Catalog(
        id="ML-Model_EO",
        title="ML Models",
        description="A catalog to describe ML models",
    )
catalog.validate()
catalog
Catalog does not exist. Creating it

STAC Collection

Check if a collection.json exists already. If not, create it, otherwise read existing collection and add STAC Item to it.

coll_path = os.path.join(CATALOG_DIR, COLLECTION_NAME, "collection.json")

if os.path.exists(coll_path):
    # Reuse the existing Collection so the new Item is added next to its items
    print("Collection exists already. Reading it")
    collection = read_file(coll_path)
else:
    print("Collection does not exist. Creating it")

    # Spatial extent: the whole globe
    bbox_world = [-180, -90, 180, 90]
    # Temporal extent. end_date=None is passed straight to getTemporalExtent
    # (presumably meaning an open-ended interval — confirm in utils); set it
    # to e.g. "2024-04-29T13:23:32.741484+00:00" to close the interval.
    start_date = "2015-06-27T00:00:01.000000+00:00"
    end_date = None

    collection = pystac.Collection(
        id=COLLECTION_NAME,
        title=COLLECTION_NAME,
        description="A collection for ML Models",
        license="proprietary",
        keywords=[],
        extent=pystac.Extent(
            spatial=pystac.SpatialExtent(bbox_world),
            temporal=getTemporalExtent(start_date, end_date),
        ),
        providers=[
            pystac.Provider(
                name="AI-Extensions Project",
                roles=["producer"],
                url="https://ai-extensions.github.io/docs",
            )
        ],
    )

collection
Collection does not exist. Creating it

Note: In order to add a STAC Item to the Collection, ensure that a STAC Item is not already present in the Collection.

If it is, first open the collection.json file and delete the Item from it, then open the catalog.json file and delete the Collection from it.

This will ensure that both collection.json and catalog.json files are updated correctly.

# Add the STAC Item to the Collection unless one of the Collection's "item"
# links already has an href containing the Item id. This is a substring
# match, so an id contained in another item's href also counts as present.
if all(item.id not in link.href for link in collection.links if link.rel == "item"):
    print("Adding item")
    collection.add_item(item=item)
collection
Adding item
# Add Collection to the Catalog.
print("Adding Collection")
# A catalog read from disk may already link to this Collection (e.g. when the
# notebook is re-run). Remove that stale child link first, so add_child() does
# not create a duplicate and the catalog points at the in-memory Collection
# that now holds the new Item.
if catalog.get_child(collection.id) is not None:
    catalog.remove_child(collection.id)
# add_child() also sets the Collection's parent and root, so a separate
# collection.set_parent(catalog) call is unnecessary.
catalog.add_child(collection)
catalog
Adding Collection

Normalise the catalog and save its files under the catalog folder (CATALOG_DIR).

# Rewrite all hrefs under CATALOG_DIR and write the catalog, collection and
# item JSON files as a self-contained catalog.
catalog.normalize_and_save(
    root_href=CATALOG_DIR, catalog_type=pystac.CatalogType.SELF_CONTAINED
)
# validate_all() checks the catalog and every descendant (collection and
# items). catalog.validate() checks only the catalog object itself, and the
# collection is not validated anywhere else in the cells above.
catalog.validate_all()
catalog
# Check that collection and item have been included in the catalog
catalog.describe()
* <Catalog id=ML-Model_EO>
    * <Collection id=ML-Models_EO>
      * <Item id=Tile-based-ML-Models>