Source code for thuner.option.attribute

"""Classes for object attribute options."""

import importlib
import numpy as np
from pydantic import Field, model_validator
from thuner.utils import BaseOptions, Retrieval

# Public names exported by this module; Retrieval is re-exported from thuner.utils.
__all__ = ["Attribute", "AttributeGroup", "AttributeType", "Attributes", "Retrieval"]


class Attribute(BaseOptions):
    """
    Base attribute description class.

    An "attribute" will become a column of a pandas dataframe, csv file, sql
    table, etc.
    """

    name: str = Field(..., description="Name of the attribute.")
    _desc = "The function/kwargs used to retrieve the attribute."
    retrieval: Retrieval | None = Field(None, description=_desc)
    data_type: type | str = Field(..., description="Data type of the attribute.")
    _desc = "Number of decimal places for a numerical attribute."
    precision: int | None = Field(None, description=_desc)
    description: str | None = Field(None, description="Description of the attribute.")
    units: str | None = Field(None, description="Units of the attribute.")

    @model_validator(mode="after")
    def check_data_type(cls, values):
        """
        Resolve a dotted ``"module.name"`` data type string into the named object.

        For example ``"numpy.float32"`` is replaced by the object ``numpy.float32``.
        Strings without a dot (e.g. ``"float"``) and non-string data types are left
        unchanged.

        Raises
        ------
        ValueError
            If the module part cannot be imported, or the module does not define
            the named object. Raising ValueError (rather than letting the import
            error escape) lets the model report a normal validation error.
        """
        data_type = values.data_type
        if isinstance(data_type, str) and "." in data_type:
            # Split on the last dot so nested packages work, e.g. "a.b.c.Type".
            module_name, type_name = data_type.rsplit(".", 1)
            try:
                module = importlib.import_module(module_name)
                resolved = getattr(module, type_name)
            except (ImportError, AttributeError, TypeError, ValueError) as exc:
                # TypeError/ValueError cover malformed names such as "..x" or ".x".
                message = f"Could not resolve data type '{data_type}'."
                raise ValueError(message) from exc
            values.data_type = resolved
        return values
class AttributeGroup(BaseOptions):
    """
    A group of related attributes retrieved by the same method, e.g. lat/lon or u/v.
    """

    name: str = Field(..., description="Name of the attribute group.")
    attributes: list[Attribute] = Field(..., description="Attributes in the group.")
    _desc = "The function/kwargs used to retrieve the attributes in the group."
    retrieval: Retrieval | None = Field(None, description=_desc)
    _desc = "Description of the attribute group."
    description: str | None = Field(None, description=_desc)

    @model_validator(mode="after")
    def check_retrieval(cls, values):
        """
        Reconcile the group's retrieval method with those of its attributes.

        - If no attribute defines a retrieval method, the group is unchanged.
        - If the group has no retrieval method, all attributes must share the
          same retrieval method (``None`` counts as a distinct value); that
          shared method then becomes the group's retrieval method.
        - If the group does define a retrieval method, every attribute that also
          defines one must match it.

        Raises
        ------
        ValueError
            If the attributes' retrieval methods disagree with each other (when
            the group has none), or with the group's retrieval method.
        """
        retrievals = [attribute.retrieval for attribute in values.attributes]
        if all(retrieval is None for retrieval in retrievals):
            # No attribute specifies a retrieval method; nothing to reconcile.
            return values

        if values.retrieval is None:
            # Compare by equality rather than via set() so the check does not
            # require Retrieval objects to be hashable.
            shared = retrievals[0]
            if any(retrieval != shared for retrieval in retrievals[1:]):
                message = "attributes in group must have the same retrieval method."
                raise ValueError(message)
            # Adopt the common attribute retrieval method for the whole group.
            values.retrieval = shared
            return values

        # The group specifies its own retrieval method; any attribute that also
        # specifies one must agree with it.
        mismatched = [
            attribute.name
            for attribute in values.attributes
            if attribute.retrieval is not None
            and attribute.retrieval != values.retrieval
        ]
        if mismatched:
            message = (
                f"attributes {mismatched} in group {values.name} have a retrieval "
                "method that differs from the group retrieval method."
            )
            raise ValueError(message)
        return values
# Type aliases for collections mixing single attributes and attribute groups.
AttributeDict = dict[str, Attribute | AttributeGroup]
AttributeList = list[Attribute | AttributeGroup]


class AttributeType(BaseOptions):
    """
    Options describing one attribute type.

    An attribute type bundles attributes and attribute groups that together form
    a single pandas dataframe, csv file, sql table, etc.
    """

    # `_desc` is reused below as a scratch value for the Field descriptions.
    name: str = Field(..., description="Name of the attribute type.")
    _desc = "Description of the attribute type."
    description: str | None = Field(None, description=_desc)
    _desc = "Attributes and attribute groups comprising the attribute type."
    attributes: AttributeList = Field(..., description=_desc)
    # Only relevant when the attribute type belongs to a specific tagging dataset.
    _desc = "Dataset for tag attribute types (None if not applicable)."
    dataset: str | None = Field(None, description=_desc)
    _desc = "Lookup dictionary for attributes."
    _attribute_lookup = {}

    @model_validator(mode="after")
    def initialize_lookup(cls, values):
        """
        Build the name -> attribute (or attribute group) lookup table.

        Entries are keyed by each member's ``name``; if two members share a name,
        the later one wins.
        """
        values._attribute_lookup = {member.name: member for member in values.attributes}
        return values

    def attribute_by_name(self, name: str) -> Attribute | AttributeGroup:
        """
        Return the attribute or attribute group called ``name``.

        Raises
        ------
        KeyError
            If no member of this attribute type has that name.
        """
        table = self._attribute_lookup
        try:
            return table[name]
        except KeyError:
            raise KeyError(
                f"Attribute {name} not found in attribute type {self.name}."
            )
# Mapping of str keys to nested Attributes options; presumably keyed by member
# object name (verify against callers).
AttributesDict = dict[str, "Attributes"]


class Attributes(BaseOptions):
    """
    Container for every attribute option belonging to one object.
    """

    # `_desc` is reused below as a scratch value for the Field descriptions.
    name: str = Field(..., description="Name of the object.", examples=["mcs"])
    _desc = "Attribute types of the object."
    attribute_types: list[AttributeType] = Field(..., description=_desc)
    _desc = "Lookup dictionary for attribute types."
    _attribute_type_lookup = {}
    _desc = "List of object attributes for the member objects."
    member_attributes: AttributesDict | None = Field(None, description=_desc)

    @model_validator(mode="after")
    def initialize_lookup(cls, values):
        """
        Build the name -> attribute type lookup table.

        Entries are keyed by each attribute type's ``name``; if two attribute
        types share a name, the later one wins.
        """
        values._attribute_type_lookup = {
            kind.name: kind for kind in values.attribute_types
        }
        return values

    def attribute_type_by_name(self, name: str) -> AttributeType:
        """
        Return the attribute type called ``name``.

        Raises
        ------
        KeyError
            If this object has no attribute type with that name.
        """
        table = self._attribute_type_lookup
        try:
            found = table[name]
        except KeyError:
            raise KeyError(f"Attribute type {name} not found.")
        return found