Source code for thuner.attribute.utils

"""General utilities for object attributes."""

from pydantic import ValidationError, ConfigDict
import yaml
from pathlib import Path
import pandas as pd
import dask.dataframe as dd
import xarray as xr
from pydantic import BaseModel, model_validator
import numpy as np
from thuner.option.attribute import Attribute, AttributeGroup, AttributeType, Attributes
from thuner.log import setup_logger

logger = setup_logger(__name__)


__all__ = ["read_attribute_csv", "AttributesRecord", "time_offset"]


def get_ids(object_tracks, matched, member_object):
    """Get object ids from the match record to avoid recalculating."""
    current_attributes = object_tracks.current_attributes
    if matched:
        id_name = "universal_id"
    else:
        id_name = "id"
    if member_object is not None:
        ids = current_attributes.member_attributes[member_object]["core"][id_name]
    else:
        ids = current_attributes.attribute_types["core"][id_name]
    return ids


def get_nearest_points(
    stacked_mask: xr.DataArray | xr.Dataset,
    id_number: int,
    ds: xr.Dataset | xr.DataArray,
):
    """
    Get the nearest points in a tagging dataset to a given object id.

    Parameters
    ----------
    stacked_mask : xarray.DataArray
        mask containing object ids with latitude and longitude stacked into a new
        dimension called 'points'.
    id_number : int
        Object ID number to get from stacked_mask.
    ds : xarray.Dataset | xarray.DataArray
        Tagging dataset containing latitude and longitude.

    Returns
    -------
    list[tuple]
        List of tuples containing the latitude and longitude of the nearest points.
    """
    points = stacked_mask.where(stacked_mask == id_number, drop=True).points.values
    lats, lons = zip(*points)
    lats_da = xr.DataArray(list(lats), dims="points")
    lons_da = xr.DataArray(list(lons), dims="points") % 360
    ds_grid = ds[["latitude", "longitude"]]
    ds_points = ds_grid.sel(latitude=lats_da, longitude=lons_da, method="nearest")
    ds_lats = ds_points.latitude.values.tolist()
    ds_lons = ds_points["longitude"].values.tolist()
    return list(set(zip(ds_lats, ds_lons)))


def _init_attr_type(attribute_type: AttributeType):
    """Initialize attributes lists for a given attribute type."""
    attributes = {}
    for attr in attribute_type.attributes:
        if isinstance(attr, AttributeGroup):
            for attr_attr in attr.attributes:
                attributes[attr_attr.name] = []
        elif isinstance(attr, Attribute):
            attributes[attr.name] = []
        else:
            raise ValueError(f"Unknown type {attr.type}.")
    return attributes


[docs] class AttributesRecord(BaseModel): """ Class for storing attributes recorded during the tracking process """ # Allow arbitrary types in the class. model_config = ConfigDict(arbitrary_types_allowed=True) attribute_options: Attributes name: str = None attribute_types: dict | None = None member_attributes: dict | None = None @model_validator(mode="after") def _check_name(cls, values): if values.name is None: values.name = values.attribute_options.name elif values.name != values.attribute_options.name: raise ValueError("Name must match attribute_options name.") return values @model_validator(mode="after") def _initialize_attributes(cls, values): options = values.attribute_options if options is None: return values values.attribute_types = {} for attr_type in options.attribute_types: values.attribute_types[attr_type.name] = _init_attr_type(attr_type) if options.member_attributes is not None: values.member_attributes = {} for obj, obj_attributes in options.member_attributes.items(): obj_attr = {} for attr_type in obj_attributes.attribute_types: obj_attr[attr_type.name] = _init_attr_type(attr_type) values.member_attributes[obj] = obj_attr return values
# Mapping of string representations to actual data types string_to_data_type = { "float": float, "int": int, "datetime64[s]": "datetime64[s]", "bool": bool, "str": str, }
[docs] def time_offset(): """Convenience function to build a TimeOffset attribute.""" kwargs = {"name": "time_offset", "data_type": int, "units": "min"} _desc = "Time offset in minutes from object detection time." kwargs.update({"description": _desc}) return Attribute(**kwargs)
# class TimeOffset(Attribute): # """ # Attribute describing the time offsets to use when tagging objects using other datasets. # For instance, we may wish to tag storms with the ERA5 ambient winds 1-hour before the # storm detection time to provide an assessment of the pre-storm environment. # """ # name: str = "time_offset" # data_type: type = int # units: str = "min" # description: str = "Time offset in minutes from object detection time." def setup_interp( attribute_group: AttributeGroup, input_records, object_tracks, dataset: str, member_object: str = None, ): name = object_tracks.name excluded = ["time", "id", "universal_id", "latitude", "longitude", "altitude"] excluded += ["time_offset"] attributes = attribute_group.attributes names = [attr.name for attr in attributes if attr.name not in excluded] tag_input_records = input_records.tag current_time = object_tracks.times[-1] # Get object centers if member_object is None: core_attributes = object_tracks.current_attributes.attribute_types["core"] else: core_attributes = object_tracks.current_attributes.member_attributes core_attributes = core_attributes[member_object]["core"] ds = tag_input_records[dataset].dataset ds["longitude"] = ds["longitude"] % 360 return name, names, ds, core_attributes, current_time def get_current_mask(object_tracks, matched=False): """Get the appropriate previous mask.""" if matched: mask_type = "matched_masks" else: mask_type = "masks" mask = getattr(object_tracks, mask_type)[-1] return mask def attribute_from_core(attribute, object_tracks, member_object): """Get attribute from core object properties.""" # Check if grouped object current_attributes = object_tracks.current_attributes if member_object is not None and member_object is not object_tracks.name: member_attr = current_attributes.member_attributes attr = member_attr[member_object]["core"][attribute.name] else: core_attr = current_attributes.attribute_types["core"] attr = core_attr[attribute.name] return {attribute.name: attr} def attributes_dataframe(recorded_attributes, attribute_type): """Create a pandas DataFrame from object attributes dictionary.""" data_types = get_data_type_dict(attribute_type) data_types.pop("time") try: df = pd.DataFrame(recorded_attributes).astype(data_types) except: pass multi_index = ["time"] if "time_offset" in recorded_attributes.keys(): multi_index.append("time_offset") if "universal_id" in recorded_attributes.keys(): id_index = "universal_id" else: id_index = "id" multi_index.append(id_index) if "altitude" in recorded_attributes.keys(): multi_index.append("altitude") df.set_index(multi_index, inplace=True) df.sort_index(inplace=True) return df def read_metadata_yml(filepath): """Read metadata from a yml file.""" with open(filepath, "r") as file: kwargs = yaml.safe_load(file) try: attribute_type = AttributeType(**kwargs) except ValidationError: logger.warning("Invalid metadata file found for %s.", filepath) attribute_type = None return attribute_type def get_indexes(attribute_type: AttributeType): """Get the indexes for the attribute DataFrame.""" all_indexes = ["time", "time_offset", "event_start", "universal_id", "id"] all_indexes += ["altitude"] indexes = [] for attribute in attribute_type.attributes: if isinstance(attribute, AttributeGroup): for attr in attribute.attributes: if attr.name in all_indexes: indexes.append(attr.name) else: if attribute.name in all_indexes: indexes.append(attribute.name) return indexes
[docs] def read_attribute_csv( filepath, attribute_type=None, columns=None, times=None, dask=False ): """ Read a CSV file and return a DataFrame. Parameters ---------- filepath : str Filepath to the CSV file. Returns ------- pd.DataFrame DataFrame containing the CSV data. """ filepath = Path(filepath) data_types = None if attribute_type is None: try: meta_path = filepath.with_suffix(".yml") attribute_type = read_metadata_yml(meta_path) data_types = get_data_type_dict(attribute_type) except FileNotFoundError: logger.warning("No metadata file found for %s.", filepath) except ValidationError: logger.warning("Invalid metadata file found for %s.", filepath) except AttributeError: logger.warning("Invalid metadata file found for %s.", filepath) if attribute_type is None: message = "No metadata; loading entire dataframe and data types not enforced." logger.warning(message) kwargs = {"na_values": ["", "NA"], "keep_default_na": True} if dask: df = dd.read_csv(filepath, **kwargs) else: df = pd.read_csv(filepath, **kwargs) return df # Get attributes with np.datetime64 data type time_attrs = [] for attribute in attribute_type.attributes: if isinstance(attribute, AttributeGroup): for attr in attribute.attributes: if attr.data_type is np.datetime64: time_attrs.append(attr.name) else: if attribute.data_type is np.datetime64: time_attrs.append(attribute.name) indexes = get_indexes(attribute_type) if columns is None: columns = get_names(attribute_type) all_columns = indexes + [col for col in columns if col not in indexes] data_types = get_data_type_dict(attribute_type) # Remove time columns as pd handles these separately for name in time_attrs: data_types.pop(name, None) if times is not None and not dask: kwargs = {"usecols": ["time"], "parse_dates": time_attrs} kwargs.update({"na_values": ["", "NA"], "keep_default_na": True}) index_df = pd.read_csv(filepath, **kwargs) row_numbers = index_df[~index_df["time"].isin(times)].index.tolist() # Increment row numbers by 1 to account for header row_numbers = [i + 1 for i in row_numbers] else: if dask: logger.warning("Row skipping not yet implemented with dask dataframes.") row_numbers = None if dask: kwargs = {"dtype": data_types, "parse_dates": time_attrs} kwargs.update({"na_values": ["", "NA"], "keep_default_na": True}) df = dd.read_csv(filepath, **kwargs) message = "Index not set for dask dataframe." logger.warning(message) else: kwargs = {"usecols": all_columns, "dtype": data_types} kwargs.update({"parse_dates": time_attrs, "skiprows": row_numbers}) kwargs.update({"na_values": ["", "NA"], "keep_default_na": True}) df = pd.read_csv(filepath, **kwargs) df = df.set_index(indexes) return df
def get_names(attribute_type: AttributeType): """Get the names of the attributes in the attribute type.""" names = [] for attribute in attribute_type.attributes: if isinstance(attribute, AttributeGroup): for attr in attribute.attributes: names.append(attr.name) else: names.append(attribute.name) return names def get_precision_dict(attribute_type: AttributeType): """Get precision dictionary for attribute options.""" precision_dict = {} for attribute in attribute_type.attributes: if isinstance(attribute, AttributeGroup): for attr in attribute.attributes: if attr.data_type == float: precision_dict[attr.name] = attr.precision else: if attribute.data_type == float: precision_dict[attribute.name] = attribute.precision return precision_dict def get_data_type_dict(attribute_type: AttributeType): """Get precision dictionary for attribute options.""" data_type_dict = {} for attribute in attribute_type.attributes: if isinstance(attribute, AttributeGroup): # If the attribute is a group, get data type for each attribute in group for attr in attribute.attributes: data_type_dict[attr.name] = attr.data_type else: data_type_dict[attribute.name] = attribute.data_type return data_type_dict