Source code for thuner.track.track

"""Track storm objects in a dataset."""

import shutil
import copy
from typing import Iterable
import numpy as np
from pathlib import Path
from thuner.log import setup_logger
import thuner.data._update as _update
import thuner.detect.detect as detect
import thuner.group.group as group
import thuner.visualize.runtime as runtime
import thuner.visualize.visualize as visualize
import thuner.match.match as match
from thuner.config import get_outputs_directory
import thuner.utils as utils
import thuner.write as write
import thuner.attribute.attribute as attribute
import thuner.option as option
from thuner.track._utils import InputRecords, Tracks


# Module-level logger obtained from thuner's logging helper.
logger = setup_logger(__name__)


# Names exported by ``from thuner.track.track import *``.
__all__ = ["track"]


def consolidate_options(data_options, grid_options, track_options, visualize_options):
    """Bundle the four option objects of a run into a single dictionary.

    The returned dictionary maps the fixed keys ``"data_options"``,
    ``"grid_options"``, ``"track_options"`` and ``"visualize_options"`` (in
    that insertion order) to the corresponding arguments, unchanged.
    """
    return {
        "data_options": data_options,
        "grid_options": grid_options,
        "track_options": track_options,
        "visualize_options": visualize_options,
    }


def _prepare_output_directory(data_options, output_directory):
    """Clear previous run output and set default regridder weight filepaths."""
    # Remove masks, attributes and records left in this directory so new
    # output is not mixed with old files.
    for subdirectory in ("masks", "attributes", "records"):
        stale_directory = output_directory / subdirectory
        if stale_directory.exists():
            shutil.rmtree(stale_directory)

    # Initialize the paths to save xesmf regridder weights
    for dataset_options in data_options.datasets:
        if not dataset_options.reuse_regridder:
            continue
        if dataset_options.weights_filepath is None:
            filepath = output_directory / "records/regridder_weights"
            dataset_options.weights_filepath = filepath / f"{dataset_options.name}.nc"


def track(
    times: Iterable[np.datetime64],
    data_options: option.data.DataOptions,
    grid_options: option.grid.GridOptions,
    track_options: option.track.TrackOptions,
    visualize_options: option.visualize.VisualizeOptions = None,
    output_directory: str | Path = None,
):
    """
    Track objects described in track_options, in the datasets described in
    data_options, using the grid described in grid_options.

    Parameters
    ----------
    times : Iterable[np.datetime64]
        The times to track the objects.
    data_options : DataOptions
        The data options.
    grid_options : GridOptions
        The grid options.
    track_options : TrackOptions
        The track options.
    visualize_options : VisualizeOptions, optional
        The runtime visualization options for visualizing the tracking process.
        Defaults to None.
    output_directory : str | Path, optional
        The directory in which to save the output. Strings are converted to
        ``Path``. If None, a run directory below
        ``thuner.config.get_outputs_directory() / "runs"`` is chosen when the
        first time is processed; its name combines ``utils.now_str()`` with the
        first eight characters of a hash of the options. Defaults to None.
    """
    # Every later use relies on the ``/`` path operator, so accept plain strings.
    if output_directory is not None:
        output_directory = Path(output_directory)

    logger.info("Beginning thuner tracking. Saving output to %s.", output_directory)
    tracks = Tracks(track_options=track_options)
    input_records = InputRecords(data_options=data_options)

    # Keyword arguments keep each options object under its matching key.
    consolidated_options = consolidate_options(
        data_options=data_options,
        grid_options=grid_options,
        track_options=track_options,
        visualize_options=visualize_options,
    )

    # With an explicit directory the clean-up can happen immediately. Without
    # one, it has to wait until the default is resolved from the first time.
    directory_prepared = output_directory is not None
    if directory_prepared:
        _prepare_output_directory(data_options, output_directory)

    current_time = None
    for next_time in times:
        if output_directory is None:
            consolidated_options["start_time"] = str(next_time)
            hash_str = utils.hash_dictionary(consolidated_options)
            output_directory = get_outputs_directory() / "runs"
            output_directory = output_directory / f"{utils.now_str()}_{hash_str[:8]}"
            logger.info("Saving output to %s.", output_directory)
        if not directory_prepared:
            _prepare_output_directory(data_options, output_directory)
            directory_prepared = True

        logger.info("Processing %s.", utils.format_time(next_time, filename_safe=False))
        args = [next_time, input_records.track, track_options, data_options]
        args += [grid_options, output_directory]
        _update.update_track_input_records(*args)

        # Record track input filepaths
        for name in input_records.track.keys():
            input_record = input_records.track[name]
            # The input record is passed twice, as in the original call.
            args = [next_time, input_record, input_record]
            if write.utils.write_interval_reached(*args):
                write.filepath.write(input_record, output_directory)

        # Tag records are updated with the previous time (None on the first pass).
        args = [current_time, input_records.tag, track_options, data_options]
        args += [grid_options]
        _update.update_tag_input_records(*args)

        # loop over levels
        for level_index in range(len(track_options.levels)):
            logger.info("Processing hierarchy level %s.", level_index)
            track_level_args = [next_time, level_index, tracks, input_records]
            track_level_args += [data_options, grid_options, track_options]
            track_level_args += [visualize_options, output_directory]
            track_level(*track_level_args)

        current_time = next_time

    if output_directory is None:
        # No times were supplied and no directory was given, so there is
        # nowhere to write final output.
        logger.warning("No times provided and no output directory set; nothing to do.")
        return

    # Write final data to file
    # write.mask.write_final(tracks, track_options, output_directory)
    write.attribute.write_final(tracks, track_options, output_directory)
    write.filepath.write_final(input_records.track, output_directory)
    # Aggregate files previously written to file
    # write.mask.aggregate(track_options, output_directory)
    write.attribute.aggregate(track_options, output_directory)
    write.filepath.aggregate(input_records.track, output_directory)
    # Animate the relevant figures
    visualize.animate_all(visualize_options, output_directory)


def track_level(
    next_time,
    level_index,
    tracks,
    input_records,
    data_options: option.data.DataOptions,
    grid_options,
    track_options: option.track.TrackOptions,
    visualize_options,
    output_directory,
):
    """Track every object of one hierarchy level and return that level's tracks.

    For each object name in ``tracks.levels[level_index].objects`` the matching
    object options are looked up. If those options declare a ``dataset`` field
    the corresponding dataset options are fetched from ``data_options``;
    otherwise ``None`` is passed. ``track_object`` is then called.
    """
    level_tracks = tracks.levels[level_index]
    level_options = track_options.levels[level_index]

    for obj in level_tracks.objects.keys():
        logger.info("Tracking %s.", obj)
        object_options = level_options.object_by_name(obj)
        has_dataset = "dataset" in object_options.__class__.model_fields
        if has_dataset:
            dataset_options = data_options.dataset_by_name(object_options.dataset)
        else:
            dataset_options = None
        track_object(
            next_time,
            level_index,
            obj,
            tracks,
            input_records,
            dataset_options,
            grid_options,
            track_options,
            visualize_options,
            output_directory,
        )

    return level_tracks


def track_object(
    next_time,
    level_index,
    obj,
    tracks,
    input_records,
    dataset_options,
    grid_options,
    track_options,
    visualize_options,
    output_directory,
):
    """Advance the tracks of a single object to ``next_time``.

    Steps, in order: roll the object's time bookkeeping forward, write masks
    and attributes when required, obtain the objects at the new time (by
    grouping or detection, depending on the fields of the object options),
    match them, run the runtime visualization, and record attributes.

    Raises
    ------
    ValueError
        If the object options declare neither a ``grouping`` nor a
        ``detection`` field.
    """
    # Get the object options
    object_options = track_options.levels[level_index].object_by_name(obj)
    object_tracks = tracks.levels[level_index].objects[obj]
    track_input_records = input_records.track

    # The pending "next" time is stored in the time history before being replaced
    if object_tracks.next_time is not None:
        finished_time = copy.deepcopy(object_tracks.next_time)
        object_tracks.times.append(finished_time)
    object_tracks.next_time = next_time

    # Write masks to zarr file
    if object_options.mask_options.save:
        write.mask.write(object_tracks, object_options, output_directory)

    # Write existing data to file if necessary
    write_due = write.utils.write_interval_reached(
        next_time, object_tracks, object_options
    )
    if write_due:
        write.attribute.write(object_tracks, object_options, output_directory)
        object_tracks._last_write_time = next_time

    # Detect objects at next_time; grouping takes precedence over detection
    if "grouping" in object_options.__class__.model_fields:
        get_objects = group.group
    elif "detection" in object_options.__class__.model_fields:
        get_objects = detect.detect
    else:
        raise ValueError("No known method for obtaining objects provided.")
    get_objects(
        track_input_records,
        tracks,
        level_index,
        obj,
        dataset_options,
        object_options,
        grid_options,
    )

    match.match(object_tracks, object_options, grid_options)

    # Visualize the operation of the algorithm
    runtime.visualize(
        track_input_records,
        tracks,
        level_index,
        obj,
        track_options,
        grid_options,
        visualize_options,
        output_directory,
    )

    # Update the lists used to periodically write data to file
    if object_tracks.times[-1] is not None:
        attribute.record(input_records, tracks, object_options, grid_options)


# Maps a retrieval method name to the function implementing it.
get_objects_dispatcher = dict(detect=detect.detect, group=group.group)