Summary
We currently use the Torchvision Weights API, which allows a metadata dict to be attached to a model's weights to couple necessary fields to the model, e.g. in_channels, num_classes, license, model type, etc.
LANDSAT_TM_TOA_SIMCLR = Weights(
    url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/resnet18_landsat_tm_toa_simclr-d2d38ace.pth',
    transforms=_ssl4eo_l_transforms,
    meta={
        'dataset': 'SSL4EO-L',
        'in_chans': 7,
        'model': 'resnet18',
        'publication': 'https://arxiv.org/abs/2306.09424',
        'repo': 'https://github.com/microsoft/torchgeo',
        'ssl_method': 'simclr',
        'bands': _landsat_tm_toa_bands,
    },
)
However, we currently have no way to enforce that a minimum set of fields is filled out. It would be ideal to start enforcing this so that, when anyone lists all model metadata, the field names, structure, and types are consistent across models.
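To make the problem concrete, here is a rough sketch (the specific keys queried are just examples): because meta is a plain dict, nothing enforces which keys exist or how they are named, and a missing field only surfaces at the point of use.

```python
from torchgeo.models import ResNet18_Weights

# meta is a plain dict: nothing enforces which keys exist or how they are
# spelled (e.g. 'in_chans' vs. 'in_channels'), so problems only appear downstream.
meta = ResNet18_Weights.LANDSAT_TM_TOA_SIMCLR.meta
print(meta.get('in_chans'))      # 7
print(meta.get('num_classes'))   # None if the key was never added; no validation error
```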
The stac_model package has pydantic classes that can be used for this validation. For example, given a model's metadata, we can validate that the minimal set of required fields is filled out (additional fields are fine to add and do not break the schema).
Note: I'm writing the metadata as YAML here, but it can just as well be a dict or a JSON file; it doesn't matter as long as it's passed through the pydantic class for validation.
name: Unet
architecture: Unet
artifact_type: torch.export.save
framework: torch
framework_version: 2.7.0
accelerator: cuda
total_parameters: 1234567
tasks: [semantic-segmentation]
input:
  - name: Imagery
    bands: [red, blue, green]
    input:
      shape: [-1, 3, 512, 512]
      dim_order: [batch, channel, height, width]
      data_type: float32
    crs: EPSG:4326
    res: 0.1
    pre_processing_function:
      format: pt2-exported-program  # indicates that expression is interpreted as a key referencing exported pytorch file that can be loaded to run the function
      expression: transforms
output:
  - name: segmentation-output
    tasks: ["semantic-segmentation"]
    result:
      shape: [-1, 6, 512, 512]
      dim_order: [batch, classes, height, width]
      data_type: float32
    classification:classes:
      - value: 0
        name: Class 0
      - value: 1
        name: Class 1
      - value: 2
        name: Class 2
      - value: 3
        name: Class 3
      - value: 4
        name: Class 4
      - value: 5
        name: Class 5
import yaml

from stac_model.schema import MLModelProperties

with open("path/to/metadata.yaml", "r") as f:
    metadata = yaml.safe_load(f)

MLModelProperties(**metadata)  # raises an error if required fields are missing or do not have the correct types
# or
MLModelProperties.model_validate(metadata)
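To show what that failure looks like in practice, here is a minimal sketch of catching the validation error for an incomplete metadata dict (the exact set of required fields depends on the stac_model version):

```python
from pydantic import ValidationError

from stac_model.schema import MLModelProperties

incomplete = {"name": "Unet", "framework": "torch"}  # leaves out fields such as architecture, tasks, input, output

try:
    MLModelProperties.model_validate(incomplete)
except ValidationError as err:
    # pydantic reports every missing or mistyped field in a single error
    print(err)
```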
Some benefits of standardizing on a schema:
- Adopting a standard that is gaining traction in the community: several organizations already utilize the MLM STAC Extension for their models (e.g. Regrow, Taco Foundation, etc.)
- Inference automation: if we can expect a minimal set of fields to come packaged with each model or in a model's checkpoint file, inference code can be better automated to load a model and use it for a downstream task (see the sketch below)
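As a rough sketch of that second point (assuming the pydantic field names mirror the YAML keys above; the helper name and the use of segmentation_models_pytorch are illustrative, not an existing TorchGeo API), validated metadata could drive model construction directly:

```python
import segmentation_models_pytorch as smp
from stac_model.schema import MLModelProperties


def build_model_from_metadata(metadata: dict) -> smp.Unet:
    """Hypothetical helper: instantiate a model from validated MLM metadata."""
    props = MLModelProperties.model_validate(metadata)
    # shape is [batch, channel, height, width], so index 1 is the channel/class count
    in_channels = props.input[0].input.shape[1]    # 3 in the example above
    num_classes = props.output[0].result.shape[1]  # 6 in the example above
    return smp.Unet(in_channels=in_channels, classes=num_classes)
```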
Suggestion (I can take on these tasks, btw):
- We should add stac_model as a dependency for performing this metadata validation
- We should backfill and validate all current model metadata
- We should enforce that the minimal set of fields is filled out when a new model is added to TorchGeo
- We should add a feature for packaging model metadata so it is exported with model checkpoints trained by our existing trainers (see the sketch below)
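For that last point, a minimal sketch of what the packaging could look like (the callback name and wiring are hypothetical; on_save_checkpoint is the standard Lightning callback hook, and our trainers already run under Lightning):

```python
from lightning.pytorch.callbacks import Callback
from stac_model.schema import MLModelProperties


class MLMMetadataCallback(Callback):
    """Hypothetical callback that bundles validated MLM metadata into checkpoints."""

    def __init__(self, metadata: dict) -> None:
        # Validate up front so a malformed metadata dict fails at trainer setup, not at export time.
        self.properties = MLModelProperties.model_validate(metadata)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint: dict) -> None:
        # Store the validated metadata next to the model weights in the checkpoint dict.
        checkpoint["mlm_metadata"] = self.properties.model_dump()
```

A trainer would then pass callbacks=[MLMMetadataCallback(metadata)] to the Lightning Trainer, and the metadata would travel with every checkpoint it writes.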
CC: @rbavery @ljstrnadiii @jiayuasu @fmigneault @adamjstewart @calebrob6