7,947 changes: 5,295 additions & 2,652 deletions notebooks/uptake_model.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -13,6 +13,7 @@ dependencies = [
"pandas>=2.2.2",
"xarray>=2024.7.0",
"pydantic>=2.10.6",
"blackjax>=1.2.5",
]

[project.optional-dependencies]
102 changes: 100 additions & 2 deletions src/vaxflux/_util.py
@@ -3,11 +3,21 @@

import re
from collections.abc import Callable
from typing import Annotated, Any, overload
from typing import Annotated, Any, Final, overload

import pandas as pd
from pandas.api.types import is_datetime64_any_dtype
from pydantic import BeforeValidator

_CLEAN_TEXT_REGEX = re.compile(r"[^a-zA-Z0-9]")
_CLEAN_TEXT_REGEX: Final = re.compile(r"[^a-zA-Z0-9]")
_OBSERVATION_TYPE_CATEGORIES: Final = ("incidence", "prevalence")
_REQUIRED_OBSERVATION_COLUMNS: Final = {
"season",
"start_date",
"end_date",
"type",
"value",
}


def _clean_text(text: str) -> str:
@@ -260,3 +270,91 @@ def _coord_index(
)
except ValueError:
return None


@overload
def _validate_and_format_observations(observations: pd.DataFrame) -> pd.DataFrame: ...


@overload
def _validate_and_format_observations(observations: None) -> None: ...


def _validate_and_format_observations(
observations: pd.DataFrame | None,
) -> pd.DataFrame | None:
"""
Validate and format user provided observations DataFrames.

Args:
observations: The observations DataFrame to validate and format or `None`.

Returns:
The validated and formatted observations DataFrame or `None` if given `None`.

Raises:
NotImplementedError: If the observations DataFrame contains differing report
dates, nowcasting is not yet supported.
ValueError: If the 'type' column contains values other than 'incidence', other
values are not yet supported.
ValueError: If the observations DataFrame is empty.
ValueError: If the observations DataFrame is missing required columns: 'season',
'start_date', 'end_date', 'type', 'value'.
ValueError: If the observations DataFrame contains invalid values in the 'value'
column, must be numeric.
ValueError: If the observations DataFrame contains negative values in the
'value' column.
ValueError: If the observations DataFrame contains invalid values in the 'type'
column, must be one of 'incidence', 'prevalence'.
"""
if observations is None:
return None
if not len(observations):
raise ValueError("No observations provided.")
observation_columns = set(observations.columns)
if missing_columns := _REQUIRED_OBSERVATION_COLUMNS - observation_columns:
raise ValueError(
"The observations DataFrame is missing "
f"required columns: {missing_columns}."
)
observations = observations.copy()
observations["season"] = observations["season"].astype(str)
observations["value"] = pd.to_numeric(observations["value"])
if observations["value"].isna().any():
raise ValueError(
"The observations DataFrame contains invalid values in the 'value' column."
)
if observations["value"].lt(0).any():
raise ValueError(
"The observations DataFrame contains negative values in the 'value' column."
)
observations["type"] = pd.Categorical(
observations["type"].astype(str), categories=_OBSERVATION_TYPE_CATEGORIES
)
if observations["type"].isna().any():
raise ValueError(
"The observations DataFrame contains invalid values in the "
f"'type' column, must be one of {_OBSERVATION_TYPE_CATEGORIES}."
)
if {"incidence"} != set(observations["type"].unique().tolist()):
raise NotImplementedError(
"Only 'incidence' data is supported, 'prevalence' and count equivalents "
"are planned."
)
for col in {"start_date", "end_date", "report_date"}.intersection(
observation_columns
):
if not is_datetime64_any_dtype(observations[col]):
observations[col] = pd.to_datetime(observations[col])
if "report_date" not in observations.columns:
observations["report_date"] = observations["end_date"].copy()
unique_start_end = observations.drop_duplicates(["start_date", "end_date"])
unique_start_end_report = observations.drop_duplicates(
["start_date", "end_date", "report_date"]
)
if len(unique_start_end) != len(unique_start_end_report):
raise NotImplementedError(
"Observations with differing report dates were provided, "
"nowcasting is not currently supported but planned."
)
return observations
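
As a usage illustration of the new helper, here is a minimal sketch; the DataFrame contents and the direct import of the private `_validate_and_format_observations` are assumptions for demonstration only:

```python
# Minimal sketch of the validator's behavior; values are illustrative only.
import pandas as pd

from vaxflux._util import _validate_and_format_observations

observations = pd.DataFrame(
    {
        "season": ["2023/24", "2023/24"],
        "start_date": ["2023-10-01", "2023-10-08"],
        "end_date": ["2023-10-07", "2023-10-14"],
        "type": ["incidence", "incidence"],
        "value": [0.12, 0.15],
    }
)

formatted = _validate_and_format_observations(observations)
# Date columns are coerced to datetime64, and a 'report_date' column,
# defaulting to 'end_date', is added when absent.
print(formatted.dtypes)

# None passes through unchanged; non-'incidence' types raise NotImplementedError.
assert _validate_and_format_observations(None) is None
```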
2 changes: 1 addition & 1 deletion src/vaxflux/covariates.py
@@ -218,7 +218,7 @@ def _infer_covariate_categories_from_observations(
for covariate in covariates
if covariate.covariate is not None
}
if observations:
if observations is not None:
# Only observations
if not covariate_categories:
if missing_columns := covariate_names - set(observations.columns):
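
The `if observations:` to `if observations is not None:` change above matters because pandas refuses to truth-test a DataFrame. A small sketch of the difference (the DataFrame contents are illustrative):

```python
# Truth-testing a DataFrame raises, so the old guard fails for any DataFrame.
import pandas as pd

observations = pd.DataFrame({"season": ["2023/24"]})

try:
    if observations:  # old guard
        pass
except ValueError as error:
    print(error)  # "The truth value of a DataFrame is ambiguous. ..."

if observations is not None:  # new guard: explicit None check
    print("observations provided")
```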
6 changes: 3 additions & 3 deletions src/vaxflux/dates.py
@@ -148,14 +148,14 @@ def _infer_ranges_from_observations(
ValueError: If the observed season ranges are not consistent with the explicit
season ranges, only applicable for the season mode.
"""
if not observations and not ranges:
if observations is None and not ranges:
raise ValueError("At least one of `observations` or `ranges` is required.")
cls = DateRange if mode == "date" else SeasonRange
columns = {
"date": {"season", "start_date", "end_date", "report_date"},
"season": {"season", "season_start_date", "season_end_date"},
}[mode]
if observations:
if observations is not None:
# Only observations
if not ranges:
if missing_columns := columns - set(observations.columns):
@@ -168,7 +168,7 @@
.sort_values(list(columns), ignore_index=True)
)
return [
cls(row._asdict()) # type: ignore
cls.model_validate(row._asdict()) # type: ignore
for row in observations_ranges.itertuples(
index=False, name=f"Observation{mode.capitalize()}Row"
)
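
The switch from `cls(row._asdict())` to `cls.model_validate(row._asdict())` reflects how pydantic v2 models consume mappings: `BaseModel` does not accept a positional dict, while `model_validate` parses and validates one. A sketch with an illustrative stand-in model (not the real `DateRange`/`SeasonRange` definitions):

```python
# Illustrative stand-in model; field names mirror the 'date' mode columns
# used above but are not the real DateRange/SeasonRange definitions.
import datetime

from pydantic import BaseModel


class DemoRange(BaseModel):
    season: str
    start_date: datetime.date
    end_date: datetime.date


row_dict = {"season": "2023/24", "start_date": "2023-10-01", "end_date": "2024-03-31"}

# DemoRange(row_dict) raises a TypeError because BaseModel takes no positional
# arguments; model_validate accepts and validates the mapping instead.
demo = DemoRange.model_validate(row_dict)
print(demo.season, demo.start_date, demo.end_date)
```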