"""Classes for managing tracking related options."""
from typing import List, Annotated, Literal, Callable
from pydantic import Field, model_validator, ValidationError
from thuner.log import setup_logger
from thuner.option.attribute import Attributes
from thuner.utils import BaseOptions, Retrieval
from thuner.detect.preprocess import vertical_max
__all__ = [
"TintOptions",
"MintOptions",
"MaskOptions",
"BaseObjectOptions",
"DetectionOptions",
"DetectedObjectOptions",
"GroupingOptions",
"GroupedObjectOptions",
"LevelOptions",
"TrackOptions",
]
logger = setup_logger(__name__)
[docs]
class TintOptions(BaseOptions):
"""
Options for the TINT tracking algorithm. See the following publications
"""
name: str = "tint"
_desc = "Margin in km for object matching. Does not affect flow vectors."
search_margin: float = Field(10.0, description=_desc, gt=0)
_desc = "Margin in km around object for phase correlation."
local_flow_margin: float = Field(10.0, description=_desc, gt=0)
_desc = "Margin in km around object for global flow vectors."
global_flow_margin: float = Field(150.0, description=_desc, gt=0)
_desc = "If True, create unique global flow vectors for each object."
unique_global_flow: bool = Field(True, description=_desc)
_desc = "Maximum allowable matching cost. Units of km."
max_cost: float = Field(2e2, description=_desc, gt=0, lt=1e3)
_desc = "Maximum allowable shift velocity magnitude. Units of m/s."
max_velocity_mag: float = Field(60.0, description=_desc, gt=0)
_desc = "Maximum allowable shift difference. Units of m/s."
max_velocity_diff: float = Field(60.0, description=_desc, gt=0)
_desc = "Name of object used for matching/tracking."
matched_object: str | None = Field(None, description=_desc)
[docs]
class MintOptions(TintOptions):
"""
Options for the MINT tracking algorithm.
"""
name: str = "mint"
_desc = "Margin in km for object matching. Does not affect flow vectors."
search_margin: float = Field(25.0, description=_desc, gt=0)
_desc = "Margin in km around object for phase correlation."
local_flow_margin: float = Field(35.0, description=_desc, gt=0)
_desc = "Alternative max shift difference used by MINT."
max_velocity_diff_alt: float = Field(25.0, description=_desc, gt=0)
[docs]
class MaskOptions(BaseOptions):
"""
Options for saving and loading masks. Note thuner uses .zarr format for saving
masks, which is great for sparse, chunked arrays.
"""
save: bool = Field(True, description="If True, save masks as .zarr files.")
load: bool = Field(False, description="If True, load masks from .zarr files.")
[docs]
class BaseObjectOptions(BaseOptions):
"""Base class for object options."""
name: str = Field(..., description="Name of the object.")
_desc = "Level of the object in the hierachy. Higher level objects may depend on "
_desc += "lower level objects."
hierarchy_level: int = Field(0, description=_desc, ge=0)
_desc = "Method used to obtain the object, i.e. detect or group."
method: Literal["detect", "group"] = Field("detect", description=_desc)
_desc = "Name of the dataset used for detection if applicable."
dataset: str = Field(..., description=_desc, examples=["cpol", "gridrad"])
_desc = "Length of the deque used for tracking."
deque_length: int = Field(2, description=_desc, gt=0, lt=10)
_desc = "Options for saving and loading masks."
mask_options: MaskOptions = Field(MaskOptions(), description=_desc)
_desc = "Interval in hours for writing objects to disk."
write_interval: int = Field(1, description=_desc, gt=0, lt=24 * 60)
_desc = "Allowed gap in minutes between consecutive times when tracking."
allowed_gap: int = Field(30, description=_desc, gt=0, lt=6 * 60)
_desc = "Options for object attributes."
attributes: Attributes | None = Field(None, description=_desc)
[docs]
class DetectionOptions(BaseOptions):
"""Options for object detection."""
_desc = "Method used to detect the object."
method: Literal["steiner", "threshold"] = Field(..., description=_desc)
_desc = "Altitudes over which to detect objects."
altitudes: List[int] = Field([], description=_desc)
_desc = "Method used to flatten the grid before detection if relevant."
flatten_method: Retrieval | None = Field(
Retrieval(function=vertical_max), description=_desc
)
_desc = "Minimum area of the object in km squared."
min_area: int = Field(10, description=_desc)
_desc = "Threshold used for detection if required."
threshold: int | None = Field(None, description=_desc)
_desc = "Threshold type, i.e. a minima or maxima threshold."
threshold_type: Literal["minima", "maxima"] = Field("minima", description=_desc)
@model_validator(mode="after")
def _check_threshold(cls, values):
"""Check threshold value is provided if applicable."""
if values.method == "detect" and values.threshold is None:
raise ValueError("Threshold not provided for detection method.")
return values
def _check_mask_values(values):
"""Check if masks saved if tracking options provided."""
if values.tracking is not None and not values.mask_options.save:
message = "Masks must be saved when objects are being tracked."
raise ValueError(message)
return values
AnyTrackingOptions = TintOptions | MintOptions
[docs]
class DetectedObjectOptions(BaseObjectOptions):
"""Options for detected objects."""
object_type: Literal["detected"] = Field("detected", description="Type of object.")
_desc = "Variable to use for detection."
variable: str = Field("reflectivity", description=_desc)
_desc = "Method used to detect the object."
detection: DetectionOptions = Field(
DetectionOptions(method="steiner"), description=_desc
)
_desc = "Options for tracking the object."
tracking: AnyTrackingOptions | None = Field(TintOptions(), description=_desc)
@model_validator(mode="after")
def _check_mask(cls, values):
"""Check if masks saved if tracking options provided."""
return _check_mask_values(values)
# Define a custom type with constraints
PositiveFloat = Annotated[float, Field(gt=0)]
NonNegativeInt = Annotated[int, Field(ge=0)]
[docs]
class GroupingOptions(BaseOptions):
"""Options class for grouping lower level objects into higher level objects."""
method: str = Field("graph", description="Method used to group objects.")
member_objects: List[str] = Field([], description="Names of objects to group")
_desc = "Hierarchy levels of objects to group."
member_levels: List[NonNegativeInt] = Field([], description=_desc)
_desc = "Minimum area of each member object in km squared."
member_min_areas: List[PositiveFloat] = Field([], description=_desc)
# Check lists are the same length.
@model_validator(mode="after")
def _check_list_length(cls, values):
"""Check list lengths are consistent."""
member_objects = values.member_objects
member_levels = values.member_levels
member_min_areas = values.member_min_areas
lengths = [len(member_objects), len(member_levels), len(member_min_areas)]
if len(set(lengths)) != 1:
message = "Member objects, levels, and areas must have the same length."
raise ValueError(message)
return values
[docs]
class GroupedObjectOptions(BaseObjectOptions):
"""Options for grouped objects."""
object_type: Literal["grouped"] = Field("grouped", description="Type of object.")
_desc = "Options for grouping objects."
grouping: GroupingOptions = Field(GroupingOptions(), description=_desc)
_desc = "Options for tracking the object."
tracking: AnyTrackingOptions | None = Field(MintOptions(), description=_desc)
@model_validator(mode="after")
def _check_mask(cls, values):
"""Check if masks saved if tracking options provided."""
return _check_mask_values(values)
# Unclear why an additional discriminator is needed here. Perhaps due to the list.
AnyObjectOptions = Annotated[
DetectedObjectOptions | GroupedObjectOptions, Field(discriminator="object_type")
]
[docs]
class LevelOptions(BaseOptions):
"""
Options for a tracking hierachy level. Objects identified at lower levels are
used to define objects at higher levels.
"""
_desc = "Options for each object in the level."
objects: List[AnyObjectOptions] = Field([], description=_desc)
_object_lookup = {}
_desc = "Names of the objects comprising this tracking level."
object_names: List[str] = Field([], description=_desc)
[docs]
@model_validator(mode="after")
def initialize_object_lookup(cls, values):
"""Initialize object lookup dictionary."""
values._object_lookup = {obj.name: obj for obj in values.objects}
values.object_names = [obj.name for obj in values.objects]
if len(values.object_names) != len(set(values.object_names)):
message = "Object names must be unique to facilitate name based lookup."
raise ValueError(message)
return values
[docs]
def object_by_name(self, obj_name: str) -> BaseObjectOptions:
"""Return object options by name."""
return self._object_lookup.get(obj_name)
def _check_grouped_object(object_options, object_levels):
"""
Helper function to check a grouped object lists member object heierachy level
correctly.
"""
for i, member_name in enumerate(object_options.grouping.member_objects):
if member_name not in object_levels:
message = f"Grouped object '{object_options.name}' references member "
message += f"object '{member_name}' which doesn't exist in track options."
raise ValidationError(message)
member_level = object_levels[member_name]
expected_level = object_options.grouping.member_levels[i]
if member_level != expected_level:
message = f"Grouped object '{object_options.name}' expects member "
message += f"'{member_name}' at level {expected_level}, "
message += f"but it's actually at level {member_level}."
raise ValidationError(message)
if member_level >= object_options.hierarchy_level:
message = f"Grouped object '{object_options.name}' at level "
message += f"{object_options.hierarchy_level} cannot reference member "
message += f"'{member_name}' at level {member_level}. "
message += f"Member objects must be at lower hierarchy levels."
raise ValidationError(message)
[docs]
class TrackOptions(BaseOptions):
"""
Options for the levels of a tracking hierarchy.
"""
levels: List[LevelOptions] = Field([], description="Hierachy levels.")
_object_lookup = {}
object_names: List[str] = Field([], description="Names of the objects.")
[docs]
@model_validator(mode="after")
def initialize_object_lookup(cls, values):
"""Initialize object lookup dictionary."""
object_names = []
lookup_dicts = []
for level in values.levels:
lookup_dicts.append(level._object_lookup)
object_names += level._object_lookup.keys()
if len(object_names) != len(set(object_names)):
message = "Object names must be unique to facilitate name based lookup."
raise ValueError(message)
for lookup_dict in lookup_dicts:
values._object_lookup.update(lookup_dict)
values.object_names = object_names
return values
[docs]
def object_by_name(self, obj_name: str) -> BaseObjectOptions:
"""Return object options by name."""
try:
return self._object_lookup.get(obj_name)
except KeyError:
message = f"Object {obj_name} not found in object lookup."
raise KeyError(message)
@model_validator(mode="after")
def _validate_grouped_objects(cls, values):
"""
Validate that the tracking hierarchy is consistent - grouped objects should
only reference objects from lower hierarchy levels.
"""
# Build a mapping of object names to their hierarchy levels
object_levels = {}
for level in values.levels:
for obj in level.objects:
object_levels[obj.name] = obj.hierarchy_level
# Check grouped objects reference appropriate member objects
for level in values.levels:
for obj in level.objects:
if hasattr(obj, "grouping") and obj.grouping:
_check_grouped_object(obj, object_levels)
return values