
Schema for Model Metadata #2880

Open
@isaaccorley

Description


Summary

We currently use the Torchvision Weights API, which allows a metadata dict to be attached to a model's weights so that necessary fields (e.g. in_channels, num_classes, license, model type) travel with the model:

    LANDSAT_TM_TOA_SIMCLR = Weights(
        url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth',
        transforms=_ssl4eo_l_transforms,
        meta={
            'dataset': 'SSL4EO-L',
            'in_chans': 7,
            'model': 'resnet18',
            'publication': 'https://arxiv.org/abs/2306.09424',
            'repo': 'https://github.com/microsoft/torchgeo',
            'ssl_method': 'simclr',
            'bands': _landsat_tm_toa_bands,
        },
    )

However, we currently have no way to enforce that a minimum set of fields is filled out. It would be ideal to start enforcing this so that when anyone lists all model metadata, the field names, structure, and types are consistent across models.

The stac_model package provides pydantic classes that can be used for this validation. For example, given a model's metadata, we can check that the minimal required fields are present (additional fields are fine to add and do not break the schema).

Note: I'm writing the metadata as YAML here, but it could just as well be a dict or a JSON file; the format doesn't matter as long as it's passed through the pydantic class for validation.

name: Unet
architecture: Unet
artifact_type: torch.export.save
framework: torch
framework_version: 2.7.0
accelerator: cuda
total_parameters: 1234567
tasks: [semantic-segmentation]
input:
  - name: Imagery
    bands: [red, blue, green]
    input:
      shape: [-1, 3, 512, 512]
      dim_order: [batch, channel, height, width]
      data_type: float32
      crs: EPSG:4326
      res: 0.1
    pre_processing_function:
      format: pt2-exported-program # 'expression' is interpreted as a key referencing an exported PyTorch file that can be loaded to run the function
      expression: transforms
output:
  - name: segmentation-output
    tasks: ["semantic-segmentation"]
    result:
      shape: [-1, 6, 512, 512]
      dim_order: [batch, classes, height, width]
      data_type: float32
    classification:classes:
      - value: 0
        name: Class 0
      - value: 1
        name: Class 1
      - value: 2
        name: Class 2
      - value: 3
        name: Class 3
      - value: 4
        name: Class 4
      - value: 5
        name: Class 5
import yaml
from stac_model.schema import MLModelProperties

with open("path/to/metadata/yaml", "r") as f:
    metadata = yaml.safe_load(f)

MLModelProperties(**metadata)  # this raises errors if required fields are missing or do not have the correct types
# or
MLModelProperties.model_validate(metadata)
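
A failed validation surfaces as a pydantic ValidationError, so callers can catch and report it; the incomplete dict below is purely illustrative:

    from pydantic import ValidationError

    try:
        # deliberately incomplete metadata, only to show the failure mode
        MLModelProperties.model_validate({'name': 'Unet'})
    except ValidationError as err:
        print(err)  # lists every missing or mistyped required field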

Some benefits of standardizing on a schema:

  • Adopting a standard that is gaining traction in the community: several organizations (e.g. Regrow, Taco Foundation) already use the MLM STAC Extension for their models
  • Inference automation: if we can rely on a minimal set of fields being packaged with each model or inside its checkpoint file, inference code can be automated to load the model and run it on a downstream task (a sketch follows this list)
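
As a rough illustration of the inference automation point, here is a sketch of metadata-driven model loading; the model_metadata checkpoint key and the attribute access pattern are assumptions that simply mirror the YAML fields above, not an existing TorchGeo API:

    import torch
    from stac_model.schema import MLModelProperties

    ckpt = torch.load('model.ckpt', map_location='cpu')
    meta = MLModelProperties.model_validate(ckpt['model_metadata'])  # assumed checkpoint key

    # Because the schema guarantees these fields exist, generic inference code can
    # build a correctly shaped dummy input without per-model special cases.
    shape = meta.input[0].input.shape   # e.g. [-1, 3, 512, 512]
    dummy = torch.zeros(1, *shape[1:])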

Suggestions (I can take on these tasks btw):

  • We should add stac_model as a dependency for performing this metadata validation
  • We should backfill and validate all current model metadata
  • We should enforce that the minimal set of fields is filled out whenever a new model is added to TorchGeo
  • We should add a feature for packaging model metadata so it is exported with model checkpoints trained by our existing trainers (see the sketch below)
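
To make the last bullet concrete, here is a minimal sketch of how a Lightning-style trainer could stash validated metadata in every checkpoint it writes; the model_metadata argument, the checkpoint key, and the task class itself are hypothetical stand-ins, not current TorchGeo behavior:

    from typing import Any

    from lightning.pytorch import LightningModule
    from stac_model.schema import MLModelProperties

    class SegmentationTask(LightningModule):
        def __init__(self, model_metadata: dict[str, Any], **kwargs: Any) -> None:
            super().__init__()
            # validate once at construction time so bad metadata fails fast
            self.model_metadata = MLModelProperties.model_validate(model_metadata)

        def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
            # serialize the validated metadata into every saved checkpoint
            checkpoint['model_metadata'] = self.model_metadata.model_dump()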

CC: @rbavery @ljstrnadiii @jiayuasu @fmigneault @adamjstewart @calebrob6
