Source code for thuner.data.gridrad

"""Process GridRad data."""

import copy
from typing import Union
from pathlib import Path
from functools import reduce
import numpy as np
import pandas as pd
import xarray as xr
from skimage.morphology import remove_small_objects
from pydantic import Field, model_validator
import thuner.data._utils as _utils
from thuner.log import setup_logger
import thuner.grid as grid
import thuner.utils as utils

logger = setup_logger(__name__)

__all__ = [
    "GridRadSevereOptions",
    "get_gridrad_filepaths",
    "open_gridrad",
    "convert_gridrad",
    "filter",
    "remove_speckles",
    "remove_low_level_clutter",
    "remove_clutter_below_anvils",
    "remove_clutter",
]


class GridRadSevereOptions(utils.BaseDatasetOptions):
    """Options for GridRad Severe datasets."""

    # Define additional fields for GridRad
    event_start: str = Field(..., description="Event start date.")
    dataset_id: str = Field("ds841.6", description="UCAR RDA dataset ID.")
    version: str = Field("v4_2", description="GridRad version.")
    obs_thresh: int = Field(2, description="Observation count threshold for filtering.")

    def model_post_init(self, __context):
        """
        If unset by user, change default values inherited from the base class.
        """
        if "name" not in self.model_fields_set:
            self.name = "gridrad"
        if "fields" not in self.model_fields_set:
            self.fields = ["reflectivity"]
        if "parent_remote" not in self.model_fields_set:
            self.parent_remote = "https://data.rda.ucar.edu"

    def get_filepaths(self):
        """
        Get the filepaths for the GridRad dataset assuming filenames and directory
        structure match the remote location.
        """
        return get_gridrad_filepaths(self)

    def convert_dataset(self, time, filepath, track_options, grid_options):
        """Convert GridRad dataset."""
        return convert_gridrad(time, filepath, track_options, self, grid_options)

    # "after" model validators receive the validated instance. They are written as
    # instance methods (taking ``self`` and returning ``self``) rather than with a
    # ``cls, values`` signature, which pydantic treats as a classmethod.
    @model_validator(mode="after")
    def _check_times(self):
        """Check start_time isn't before beginning of GridRad record."""
        start_time = np.datetime64("2010-01-20T18:00:00")
        if np.datetime64(self.start) < start_time:
            raise ValueError(f"start must be {str(start_time)} or later.")
        return self

    @model_validator(mode="after")
    def _check_filepaths(self):
        """Check filepaths are valid, generating them if none were provided."""
        if self.filepaths is None:
            logger.info("Generating GridRad filepaths.")
            self.filepaths = get_gridrad_filepaths(self)
        if self.filepaths is None:
            raise ValueError("filepaths not provided or badly formed.")
        return self
# Names of the data variables in GridRad files that the filtering and decluttering
# functions operate on by default.
gridrad_variables = [
    "Reflectivity",
    "SpectrumWidth",
    "AzShear",
    "Divergence",
    "DifferentialReflectivity",
    "DifferentialPhase",
    "CorrelationCoefficient",
]

# Mapping from the standardised field names to the GridRad variable names.
gridrad_names_dict = {
    "reflectivity": "Reflectivity",
    "spectrum_width": "SpectrumWidth",
    "azimuthal_shear": "AzShear",
    "divergence": "Divergence",
    "differential_reflectivity": "DifferentialReflectivity",
    "differential_phase": "DifferentialPhase",
    "correlation_coefficient": "CorrelationCoefficient",
}


def get_event_directories(year, base_local=None):
    """
    Get the sorted event directories for a given year of GridRad severe data.

    Parameters
    ----------
    year : int or str
        The year subdirectory to list.
    base_local : str or pathlib.Path, optional
        The local base directory. Default is /scratch/w40/esh563/THUNER_output.

    Returns
    -------
    event_directories : list of pathlib.Path
        The subdirectories of the year directory, sorted by path.
    """
    if base_local is None:
        base_local = Path("/scratch/w40/esh563/THUNER_output")
    # Coerce to Path so that string base directories work with the "/" operator
    year_directory = Path(base_local) / f"input_data/raw/d841006/volumes/{year}"
    return sorted(p for p in year_directory.iterdir() if p.is_dir())


def get_event_times(event_directory: Union[str, Path]):
    """
    Get start and end times, and event start date from a GridRad severe event
    directory.

    Parameters
    ----------
    event_directory : str or pathlib.Path
        The event directory, whose name has the form YYYYMMDD.

    Returns
    -------
    start : str
        The earliest file time in the directory.
    end : str
        The latest file time in the directory.
    event_start : str
        The event start date, formatted YYYY-MM-DD, taken from the directory name.
    """
    event_directory = Path(event_directory)
    name = event_directory.name
    event_start = f"{name[:4]}-{name[4:6]}-{name[6:]}"

    times = []
    for file in sorted(event_directory.iterdir()):
        # Skip files whose path contains "500Z.nc", e.g. ...T180500Z.nc; presumably
        # these are the volumes offset five minutes from the ten minute marks
        if "500Z.nc" in str(file):
            continue
        # Filenames end with a timestamp of the form YYYYMMDDTHHMMSSZ
        stamp = str(file.name).split("_")[-1].split(".")[0]
        formatted_time = f"{stamp[:4]}-{stamp[4:6]}-{stamp[6:8]}"
        formatted_time += f"T{stamp[9:11]}:{stamp[11:13]}:{stamp[13:15]}"
        times.append(np.datetime64(formatted_time))
    # Surprisingly it appears we need to sort again to ensure the times are in order
    times = sorted(times)
    start = str(times[0])
    end = str(times[-1])
    return start, end, event_start
def get_gridrad_filepaths(options):
    """
    Get the filepaths of the GridRad-Severe files (doi.org/10.5065/2B46-1A97) at ten
    minute intervals between the start and end times given in the options.

    Parameters
    ----------
    options : GridRadSevereOptions
        The dataset options; the start, end, event_start, dataset_id and version
        attributes are read.

    Returns
    -------
    filepaths : list of str
        The sorted filepaths which exist on disk.

    Raises
    ------
    ValueError
        If the directory layout for options.dataset_id is not known.
    """
    start = np.datetime64(options.start).astype("datetime64[s]")
    end = np.datetime64(options.end).astype("datetime64[s]")

    base_url = utils.get_parent(options)
    base_url += f"/{dataset_id_converter[options.dataset_id]}/volumes"
    if options.dataset_id != "ds841.6":
        message = f"Directory layout unknown for dataset_id {options.dataset_id}."
        raise ValueError(message)

    interval = np.timedelta64(10, "m")
    times = pd.DatetimeIndex(np.arange(start, end + interval, interval))

    # Note gridrad severe directories are organized by the day the event "started"
    event_start = pd.Timestamp(options.event_start)
    base_filepath = f"{base_url}/{event_start.year}/"
    base_filepath += f"{event_start.year}{event_start.month:02}{event_start.day:02}"

    filepaths = []
    for time in times:
        filepath = f"{base_filepath}/nexrad_3d_{options.version}_"
        filepath += f"{time.year}{time.month:02}{time.day:02}T"
        filepath += f"{time.hour:02}{time.minute:02}00Z.nc"
        # Only keep files which exist on the local filesystem
        if Path(filepath).exists():
            filepaths.append(filepath)
    return sorted(filepaths)
def open_gridrad(path, dataset_options):
    """
    Open a GridRad netcdf file and restore variables stored along the "Index"
    dimension to 3D grids.

    Parameters
    ----------
    path : str or pathlib.Path
        Path to the GridRad netcdf file.
    dataset_options : GridRadSevereOptions
        The dataset options; the fields attribute selects the variables to keep.

    Returns
    -------
    ds : xarray.Dataset
        The dataset containing only the requested variables and the auxiliary
        variables needed for filtering, with the "index" variable dropped.
    """
    # Requested fields, plus the variables used later for filtering
    wanted = [gridrad_names_dict[field] for field in dataset_options.fields]
    wanted.extend(["Nradobs", "Nradecho", "wReflectivity", "CorrelationCoefficient"])

    ds = xr.open_dataset(path)
    kept = [name for name in wanted if name in ds.data_vars]
    unwanted = [name for name in ds.data_vars if name not in kept]

    for name in kept:
        if name == "index" or "Index" not in ds[name].dims:
            continue
        ds = reshape_variable(ds, name)

    return ds.drop_vars(unwanted + ["index"])
def reshape_variable(ds, variable):
    """
    Unpack a variable of a GridRad dataset onto the full 3D grid. Adapted from code
    provided by Stacey Hitchcock.

    Parameters
    ----------
    ds : xarray.Dataset
        The GridRad dataset, containing "index", "Altitude", "Latitude" and
        "Longitude".
    variable : str
        Name of the variable to reshape.

    Returns
    -------
    ds : xarray.Dataset
        The dataset with the variable replaced by an (Altitude, Latitude, Longitude)
        array; grid cells not listed in "index" are zero.
    """
    packed = ds[variable]
    original_attrs = packed.attrs
    coords = {name: ds[name] for name in ("Altitude", "Latitude", "Longitude")}
    n_alt, n_lat, n_lon = (len(coord) for coord in coords.values())

    # Scatter the stored values into a flat, zero initialised grid
    flat = np.zeros(n_alt * n_lat * n_lon)
    flat[ds.index.values] = packed.values
    flat = flat.astype(packed.dtype)

    gridded = flat.reshape((n_alt, n_lat, n_lon))
    ds[variable] = xr.DataArray(gridded, dims=list(coords), coords=coords)
    ds[variable].attrs = original_attrs
    return ds
def filter(
    ds,
    weight_thresh=1.2,
    echo_frac_thresh=0.3,
    refl_thresh=0,
    obs_thresh=2,
    variables=None,
):
    """
    Filter a GridRad dataset. Based on code from the GridRad website
    https://gridrad.org/software.html and edits by Stacey Hitchcock.

    Parameters
    ----------
    ds : xarray.Dataset
        The GridRad dataset.
    weight_thresh : float, optional
        The bin weight threshold. Default is 1.2.
    echo_frac_thresh : float, optional
        The echo fraction threshold. Default is 0.3.
    refl_thresh : float, optional
        The reflectivity threshold. Default is 0.
    obs_thresh : int, optional
        The observation count threshold; cells with this many observations or fewer
        are filtered. Default is 2.
    variables : list, optional
        The variables to filter. Default is every name in gridrad_variables present
        in the dataset.

    Returns
    -------
    ds : xarray.Dataset
        The filtered GridRad dataset. Note the input dataset is modified in place.
    """
    logger.debug("Filtering GridRad data")
    if variables is None:
        variables = [v for v in gridrad_variables if v in ds.variables]

    # Calculate echo fraction efficiently using where; zero where no observations
    args = [ds["Nradobs"] > 0, ds["Nradecho"] / ds["Nradobs"], 0.0]
    echo_fraction = xr.where(*args, keep_attrs=True)
    echo_fraction = echo_fraction.astype(np.float32)

    # The comparisons below already yield boolean arrays
    low_weight = ds["wReflectivity"] < weight_thresh
    low_refl = ds["Reflectivity"] <= refl_thresh
    low_frac = echo_fraction < echo_frac_thresh
    few_obs = ds["Nradobs"] <= obs_thresh

    # Filter cells below both the weight and reflectivity thresholds
    cond_refl = low_weight & low_refl
    # Filter cells containing <= obs_thresh observations, and cells with echoes in
    # less than echo_frac_thresh of the total observations
    cond_frac = few_obs | low_frac

    # Retain values not filtered
    preserved = ~cond_refl & ~cond_frac
    for var in variables:
        ds[var] = ds[var].where(preserved)
    return ds
def remove_speckles(ds, window_size=5, coverage_thresh=0.32, variables=None):
    """
    Remove speckles in GridRad data. Based on code from the GridRad website
    https://gridrad.org/software.html and edits by Stacey Hitchcock. Modified from
    the original to use skimage.morphology.remove_small_objects on the mask of
    non-nan reflectivity, rather than a moving window coverage test.

    Parameters
    ----------
    ds : xarray.Dataset
        The GridRad dataset.
    window_size : int, optional
        Side length used to set the minimum object size. Default is 5.
    coverage_thresh : float, optional
        Fraction of window_size**3 used as the minimum object size. Default is 0.32.
    variables : list, optional
        The variables to remove speckles from. Default is every name in
        gridrad_variables present in the dataset.

    Returns
    -------
    ds : xarray.Dataset
        The GridRad dataset with speckles removed. Note the input dataset is
        modified in place.
    """
    logger.debug("Removing speckles from the GridRad data")
    if variables is None:
        variables = [v for v in gridrad_variables if v in ds.variables]

    refl_exists = ~np.isnan(ds["Reflectivity"].values)
    # Objects of non-nan reflectivity smaller than min_size are treated as speckles
    min_size = window_size**3 * coverage_thresh
    speckle_mask = remove_small_objects(refl_exists, min_size=min_size)
    for var in variables:
        ds[var] = ds[var].where(speckle_mask)
    return ds
def remove_low_level_clutter(ds, variables=None):
    """
    Remove low level clutter from GridRad data. Based on code from the GridRad
    website https://gridrad.org/software.html and edits by Stacey Hitchcock.

    Parameters
    ----------
    ds : xarray.Dataset
        The GridRad dataset.
    variables : list, optional
        The variables to remove clutter from. Default is every name in
        gridrad_variables present in the dataset.

    Returns
    -------
    ds : xarray.Dataset
        The GridRad dataset with columns flagged as low level clutter set to nan.
        Note the input dataset is modified in place.
    """
    logger.debug("Removing low level clutter from the GridRad data")
    # Match the sibling functions; iterating over None would raise a TypeError
    if variables is None:
        variables = [v for v in gridrad_variables if v in ds.variables]

    # Determine max heights of reflectivity values above each threshold. Altitudes
    # of cells not exceeding the threshold are set to zero.
    refl_max = ds.Reflectivity.max(dim="Altitude", skipna=True)
    refl_0_alts = ds.Altitude.where(ds.Reflectivity > 0.0, 0.0)
    refl_0_max_alt = refl_0_alts.max(dim="Altitude")
    refl_0_min_alt = refl_0_alts.min(dim="Altitude")
    refl_5_max_alt = ds.Altitude.where(ds.Reflectivity > 5.0, 0.0).max(dim="Altitude")
    refl_15_alts = ds.Altitude.where(ds.Reflectivity > 15.0, 0.0)
    refl_15_max_alt = refl_15_alts.max(dim="Altitude")

    # Check for very weak echos below 4 km
    cond_1 = (refl_max < 20.0) & (refl_0_max_alt <= 4.0) & (refl_0_min_alt <= 3.0)
    # Check for very weak echos below 5 km
    cond_2 = (refl_max < 10.0) & (refl_0_max_alt <= 5.0) & (refl_0_min_alt <= 3.0)
    # Check for weak echos below 5 km. Note the > 0.0 ensures values actually exist
    cond_3 = (refl_5_max_alt <= 5.0) & (refl_5_max_alt > 0.0)
    cond_3 = cond_3 & (refl_15_max_alt <= 3.0)
    # Check for weak echos below 2 km
    cond_4 = (refl_15_max_alt < 2.0) & (refl_15_max_alt > 0.0)

    # Keep only columns matching none of the clutter conditions
    cond = np.logical_not(cond_1 | cond_2 | cond_3 | cond_4)
    for var in variables:
        ds[var] = ds[var].where(cond)
    return ds
def remove_clutter_below_anvils(ds, variables=None):
    """
    Remove clutter below anvils in GridRad data. Based on code from the GridRad
    website https://gridrad.org/software.html and edits by Stacey Hitchcock.

    Parameters
    ----------
    ds : xarray.Dataset
        The GridRad dataset.
    variables : list, optional
        The variables to remove clutter from. Default is every name in
        gridrad_variables present in the dataset.

    Returns
    -------
    ds : xarray.Dataset
        The GridRad dataset with the flagged columns set to nan. Note the input
        dataset is modified in place.
    """
    logger.debug("Removing clutter below anvils from the GridRad data")
    # Match the sibling functions; iterating over None would raise a TypeError
    if variables is None:
        variables = [v for v in gridrad_variables if v in ds.variables]

    # Check if reflectivity exists at, above and below 4 km
    exists = np.isfinite(ds.Reflectivity)
    exists_above_4 = exists.where(ds.Altitude >= 4.0, drop=True)
    # The first retained level, i.e. the first level with altitude >= 4 km
    exists_4 = exists_above_4.isel(Altitude=0)
    exists_above_4 = exists_above_4.sum(dim="Altitude") > 0
    exists_below_4 = exists.where(ds.Altitude < 4.0, drop=True)
    exists_below_4 = exists_below_4.sum(dim="Altitude") > 0

    # A column is masked only when echoes exist both above and below 4 km, but not
    # at the 4 km level itself
    cond = exists_4 | ~exists_above_4 | ~exists_below_4
    for var in variables:
        ds[var] = ds[var].where(cond)
    return ds
def remove_clutter(ds, variables=None, low_level=True, below_anvil=False):
    """
    Remove clutter from GridRad data. Based on code from the GridRad website
    https://gridrad.org/software.html and edits by Stacey Hitchcock.

    Parameters
    ----------
    ds : xarray.Dataset
        The GridRad dataset.
    variables : list, optional
        The variables to remove clutter from. Default is every name in
        gridrad_variables present in the dataset.
    low_level : bool, optional
        Whether to call remove_low_level_clutter. Default is True.
    below_anvil : bool, optional
        Whether to call remove_clutter_below_anvils. Default is False.

    Returns
    -------
    ds : xarray.Dataset
        The GridRad dataset with clutter removed.
    """
    logger.debug("Removing clutter from the GridRad data")
    if variables is None:
        variables = [v for v in gridrad_variables if v in ds.variables]

    # Remove low reflectivity low level clutter
    cond = (ds.Reflectivity >= 10.0) | (ds.Altitude > 4.0)
    for var in variables:
        ds[var] = ds[var].where(cond)

    # Attempt correlation based clutter removal if relevant variables exist
    correlation_var_list = ["DifferentialReflectivity", "CorrelationCoefficient"]
    if all(corr_var in ds.variables for corr_var in correlation_var_list):
        # Each comparison must be parenthesised as | binds more tightly than >=
        refl = ds["Reflectivity"]
        corr = ds["CorrelationCoefficient"]
        # Require either high correlation or reflectivity
        cond1 = (refl >= 40.0) | (corr >= 0.9)
        # Require moderate reflectivity or high correlation or low altitude
        cond2 = (refl >= 25.0) | (corr >= 0.95)
        cond2 = cond2 | (ds["Altitude"] < 10.0)
        # Require both conditions above be met
        for var in variables:
            ds[var] = ds[var].where(cond1 & cond2)

    # First pass at speckle removal
    ds = remove_speckles(ds, variables=variables)

    if low_level:
        # Remove low level clutter. Note this can remove some low level cloud/drizzle
        ds = remove_low_level_clutter(ds, variables=variables)

    if below_anvil:
        # Remove clutter below anvils
        ds = remove_clutter_below_anvils(ds, variables=variables)

    # Second pass at speckle removal
    ds = remove_speckles(ds, variables=variables)

    return ds
def convert_gridrad(time, filepath, track_options, dataset_options, grid_options):
    """
    Convert gridrad data to the standard format.

    Parameters
    ----------
    time : numpy.datetime64
        The time expected to be in the file.
    filepath : str or pathlib.Path
        Path to the GridRad netcdf file.
    track_options
        The track options, used to build the domain mask.
    dataset_options : GridRadSevereOptions
        The dataset options; the fields and obs_thresh attributes are read.
    grid_options
        The grid options, passed to utils.infer_grid_options and the mask helpers.

    Returns
    -------
    ds : xarray.Dataset
        The converted, masked dataset.
    boundary_coords
        Boundary coordinates as returned by utils.get_mask_boundary.
    simple_boundary_coords
        Simplified boundary coordinates as returned by utils.get_mask_boundary.

    Raises
    ------
    ValueError
        If time is not in the dataset.
    """
    logger.debug(f"Converting GridRad dataset at time {time}.")

    # Open the dataset and perform preliminary filtering and decluttering
    ds = open_gridrad(filepath, dataset_options)
    ds = filter(ds, obs_thresh=dataset_options.obs_thresh)
    ds = remove_clutter(ds)

    # Ensure the intended time is in the dataset
    if time not in ds.time.values:
        raise ValueError(f"{time} not in {filepath}")

    # Restructure the dataset
    names_dict = {
        "Latitude": "latitude",
        "Longitude": "longitude",
        "Altitude": "altitude",
        "Reflectivity": "reflectivity",
        "Nradobs": "number_of_observations",
        "Nradecho": "number_of_echoes",
    }
    ds = ds.rename(names_dict)
    ds["altitude"] = ds["altitude"] * 1000  # Convert to meters

    kept_fields = dataset_options.fields + ["number_of_observations"]
    kept_fields += ["number_of_echoes"]
    dropped_fields = [f for f in ds.data_vars if f not in kept_fields]
    ds = ds.drop_vars(dropped_fields)
    for field in dataset_options.fields:
        ds[field] = ds[field].expand_dims("time")
        ds[field].attrs["long_name"] = field
    ds["longitude"] = ds["longitude"] % 360

    # Set the coordinate names only after the arithmetic above; the altitude and
    # longitude coordinates are rebuilt by those operations, and depending on
    # xarray's keep_attrs setting attrs assigned earlier may be discarded
    for dim in ["latitude", "longitude", "altitude"]:
        ds[dim].attrs["standard_name"] = dim
        ds[dim].attrs["long_name"] = dim

    utils.infer_grid_options(ds, grid_options)

    # Get the domain mask associated with the given object
    # Note the relevant domain mask is a function of how the object is detected, e.g.
    # which levels!
    domain_mask = get_domain_mask(ds, track_options, dataset_options)
    all_coords = utils.get_mask_boundary(domain_mask, grid_options)
    boundary_coords, simple_boundary_coords, boundary_mask = all_coords
    ds["domain_mask"] = domain_mask
    ds["boundary_mask"] = boundary_mask

    # Don't mask the gridcell areas
    cell_areas = grid.get_cell_areas(grid_options)
    ds["gridcell_area"] = (["latitude", "longitude"], cell_areas)
    area_attrs = {"units": "km^2", "standard_name": "area", "valid_min": 0}
    ds["gridcell_area"].attrs.update(area_attrs)

    # Apply the domain mask to the current grid
    ds = _utils.apply_mask(ds, grid_options)
    ds = ds.drop_vars(["number_of_observations", "number_of_echoes"])

    return ds, boundary_coords, simple_boundary_coords
def get_domain_mask(ds, track_options, dataset_options):
    """
    Build the domain mask for a GridRad dataset by combining the observation masks
    of every detected object that uses the dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        The converted GridRad dataset.
    track_options
        The track options; each object in each level is inspected.
    dataset_options : GridRadSevereOptions
        The dataset options; the name attribute identifies the dataset.

    Returns
    -------
    domain_mask
        The product of the individual masks, smoothed with _utils.smooth_mask.

    Raises
    ------
    ValueError
        If no detected object uses the dataset.
    """
    name = dataset_options.name
    all_objects = (
        obj for level in track_options.levels for obj in level.objects
    )

    masks = []
    for obj in all_objects:
        is_detected = "detection" in obj.__class__.model_fields
        is_from_dataset = name == obj.dataset
        if not (is_detected and is_from_dataset):
            continue
        masks.append(_utils.mask_from_observations(ds, dataset_options, obj))

    if not masks:
        message = f"{name} not used for object detection. Check track_options."
        logger.debug(message)
        raise ValueError(message)

    # Combine the masks by taking their product, then smooth the result
    combined = reduce(lambda left, right: left * right, masks)
    combined = _utils.smooth_mask(combined)
    logger.debug(f"Got domain mask for {name}.")
    return combined


# Mapping from UCAR RDA dataset IDs to the corresponding directory names.
dataset_id_converter = {"ds841.6": "d841006"}