"""Data processing utilities."""
import os
# fcntl provides the file lock used to throttle downloads across processes; it is
# only available on unix-like systems, so it is not imported on Windows.
if os.name == "posix":
    import fcntl
else:
    # Build the message with += so both halves are printed (previously the
    # second assignment overwrote the first).
    message = "Note fcntl is not available on Windows. If you need to download data "
    message += "do so before running thuner, or use a unix based system."
    print(message)
# xESMF provides regridding; it is not supported on Windows.
if os.name == "posix":
    import xesmf as xe
else:
    message = "Warning: Windows systems cannot run xESMF for regridding. "
    message += "If you need regridding, consider using a Linux or MacOS system."
    print(message)
import multiprocessing
import threading
import subprocess
import zipfile
import time
from pathlib import Path
import requests
import cdsapi
import cv2
import numpy as np
import xarray as xr
from skimage.morphology import remove_small_objects, remove_small_holes
from scipy.ndimage import binary_dilation, binary_erosion
import thuner.log as log
import thuner.utils as utils
from thuner.config import get_outputs_directory
# Module-level logger, created through the project's logging helper at DEBUG level.
logger = log.setup_logger(__name__, level="DEBUG")
# Set the number of cv2 threads to 0 to avoid crashes.
# See https://github.com/opencv/opencv/issues/5150#issuecomment-675019390
cv2.setNumThreads(0)
[docs]
def get_demo_data(output_parent=None, remote_directory=None):
"""
Download the demo data from the AWS s3 thuner-storage bucket.
"""
if output_parent is None:
output_parent = get_outputs_directory()
if remote_directory is None:
remote_directory = "s3://thuner-storage/THUNER_output"
if not Path(output_parent).exists():
Path(output_parent).mkdir(parents=True)
if not Path(output_parent).is_dir():
raise ValueError(f"{output_parent} is not a directory.")
if not Path(output_parent).is_absolute():
raise ValueError(f"{output_parent} must be an absolute path.")
# Remove "s3://thuner-storage/" from the remote directory and append result to output parent
base_url = "s3://thuner-storage/THUNER_output/"
directory_structure = remote_directory.replace(base_url, "")
output_directory = output_parent / directory_structure
command = f"aws s3 sync {remote_directory} {output_directory}"
logger.info("Syncing directory %s. Please wait.", output_directory)
subprocess.run(command, shell=True, check=True)
class DownloadState(utils.SingletonBase):
    """
    Singleton class to manage download state across multiple processes.

    Download requests are spaced at least ``wait_time`` seconds apart. The time of
    the most recent request is stored in a shared lock file, which is read and
    updated while an exclusive ``fcntl.flock`` is held.
    """

    def _initialize(self):
        """Create the locks and the shared lock file used to throttle requests."""
        # Use both process and thread locks to prevent excess simultaneous downloads
        self.process_lock = multiprocessing.Lock()
        self.thread_lock = threading.Lock()
        # Also use a lock file to store time since last download request, to prevent
        # excess requests from all instances of thuner running on a given machine
        with self.process_lock, self.thread_lock:
            lock_directory = get_outputs_directory() / ".locks"
            utils.create_hidden_directory(lock_directory)
            self.lock_filepath = lock_directory / "download_lock"
            if not Path(self.lock_filepath).exists():
                Path(self.lock_filepath).touch()
            # Time of the last download request
            self.last_request_time = 0.0
            # Impose a 1 second wait time between download requests
            self.wait_time = 1

    def wait_for_lockfile(self):
        """
        Wait for turn to download using filelock.

        Under an exclusive ``flock``, this sleeps until ``wait_time`` has passed
        since the recorded request time, then records the current time.
        """
        # "r+" requires the file to exist; _initialize creates it.
        with open(self.lock_filepath, "r+") as lock_file:
            fcntl.flock(lock_file, fcntl.LOCK_EX)
            try:
                self._handle_existing_lockfile(lock_file)
                self._update_lockfile_timestamp(lock_file)
            finally:
                fcntl.flock(lock_file, fcntl.LOCK_UN)

    def _handle_existing_lockfile(self, lock_file):
        """Sleep until ``wait_time`` seconds have passed since the recorded request."""
        lock_file.seek(0)
        contents = lock_file.read().strip()
        try:
            last_request_time = float(contents or 0)
        except ValueError:
            # A corrupt timestamp should not abort the download; treat it as
            # "no recent request".
            last_request_time = 0.0
        elapsed_time = time.time() - last_request_time
        if elapsed_time < self.wait_time:
            # Clamp so a timestamp from the future (clock skew) never causes a
            # sleep longer than wait_time.
            remaining = min(self.wait_time, self.wait_time - elapsed_time)
            logger.info("Recent download request. Waiting %.2f seconds.", remaining)
            time.sleep(remaining)

    def _update_lockfile_timestamp(self, lock_file):
        """Record the current time in the lock file."""
        lock_file.seek(0)
        lock_file.truncate()
        lock_file.write(str(time.time()))
        # Flush while the flock is still held. Otherwise the buffered timestamp is
        # only written when the file is closed, which happens after LOCK_UN. In
        # that window another process could take the lock and read an empty or
        # stale timestamp.
        lock_file.flush()
def url_to_filepath(url, parent_remote, parent_local):
    """
    Convert remote URL to local file path.

    Parameters
    ----------
    url : str
        Remote URL of the file.
    parent_remote : str
        Remote prefix of ``url`` to replace. Trailing slashes are ignored.
    parent_local : str or pathlib.Path
        Local directory that replaces ``parent_remote``. Trailing slashes are
        ignored.

    Returns
    -------
    str
        ``url`` with its leading ``parent_remote`` swapped for ``parent_local``, or
        ``url`` unchanged if it does not start with ``parent_remote``.

    Raises
    ------
    TypeError
        If ``url`` is not a string.
    """
    if not isinstance(url, str):
        raise TypeError("url must be a string")
    # str() lets callers pass pathlib.Path objects.
    parent_remote = str(parent_remote).rstrip("/")
    parent_local = str(parent_local).rstrip("/")
    # Swap only the leading prefix. str.replace would also rewrite any later
    # occurrence, and an empty prefix would be inserted between every character.
    if url.startswith(parent_remote):
        return parent_local + url[len(parent_remote):]
    return url
def handle_response(response, already_downloaded, filepath, block_size=1024):
    """
    Handle the response from a HTTP request.

    Streams the response body into ``filepath.with_suffix(".part")``, logging
    progress roughly every 10 MB, then moves the partial file to ``filepath``.

    Parameters
    ----------
    response : requests.Response
        Response obtained with ``stream=True``.
    already_downloaded : int
        Bytes already present in the ``.part`` file (non-zero when resuming).
    filepath : pathlib.Path
        Final destination of the download.
    block_size : int, optional
        Chunk size in bytes passed to ``iter_content``. Default 1024.

    Raises
    ------
    ValueError
        If the status code is neither 200 (OK) nor 206 (Partial Content).
    """
    status_code = response.status_code
    if status_code not in (200, 206):
        message = f"Failed to download file to {filepath}. "
        message += f"HTTP status code: {status_code}."
        raise ValueError(message)
    partial_filepath = filepath.with_suffix(".part")
    if status_code == 200:
        # A 200 means the full body is being sent (e.g. the server ignored the
        # Range header). Existing partial data must be discarded rather than
        # appended to.
        already_downloaded = 0
        mode = "wb"
    else:
        mode = "ab"
    checkpoint_size = 10 * 1024**2  # 10 MB
    mb_downloaded = already_downloaded / 1024**2
    logger.info(f"Downloaded {mb_downloaded:.1f} MB of {filepath.name}.")
    last_checkpoint = already_downloaded
    with open(partial_filepath, mode) as f:
        for data in response.iter_content(block_size):
            if not data:
                continue  # skip keep-alive chunks
            f.write(data)
            # Count the bytes actually received; chunks (especially the last)
            # can be shorter than block_size.
            already_downloaded += len(data)
            if already_downloaded - last_checkpoint > checkpoint_size:
                mb_downloaded = already_downloaded / 1024**2
                logger.info(f"Downloaded {mb_downloaded:.1f} MB of {filepath.name}.")
                last_checkpoint = already_downloaded
    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename fails on Windows if the target exists.
    partial_filepath.replace(filepath)
    logger.info(f"Completed download of {filepath.name}.")
def get_header(url, filepath):
    """
    Get the header for a HTTP request.

    Parameters
    ----------
    url : str
        URL being downloaded (used only for logging).
    filepath : pathlib.Path
        Final destination. A partial download is looked for at
        ``filepath.with_suffix(".part")``.

    Returns
    -------
    already_downloaded : int
        Size in bytes of the existing partial file, or 0 if there is none.
    resume_header : dict
        ``{"Range": "bytes=<already_downloaded>-"}`` when resuming, else ``{}``.
    """
    partial_filepath = filepath.with_suffix(".part")
    if not partial_filepath.exists():
        logger.info("Initiating download of %s", url)
        return 0, {}
    logger.info("Resuming download of %s", url)
    # stat() once, so the byte count and the Range offset cannot disagree if
    # the partial file changes between two calls.
    already_downloaded = partial_filepath.stat().st_size
    return already_downloaded, {"Range": f"bytes={already_downloaded}-"}
def download(url, parent_remote, parent_local, max_retries=10, retry_delay=2):
    """
    Downloads a file from the given URL and saves it to the specified directory,
    preserving the subdirectory structure of the remote filesystem.

    Partial downloads are resumed, and failed attempts are retried after
    ``retry_delay`` seconds.

    Parameters
    ----------
    url : str
        URL of the remote file.
    parent_remote, parent_local : str
        Remote prefix of ``url`` and the local directory it maps to (see
        ``url_to_filepath``).
    max_retries : int, optional
        Maximum number of download attempts. Default 10.
    retry_delay : float, optional
        Seconds to wait between attempts. Default 2.

    Returns
    -------
    str
        Local path of the downloaded (or already present) file.

    Raises
    ------
    ValueError
        If a download is needed and ``max_retries`` is less than 1.
    requests.exceptions.RequestException
        If every attempt fails. The last failure is chained as ``__cause__``.
    """
    filepath = Path(url_to_filepath(url, parent_remote, parent_local))
    filepath.parent.mkdir(parents=True, exist_ok=True)
    if filepath.exists():
        logger.info("%s already exists.", filepath)
        return str(filepath)
    if max_retries < 1:
        # Previously the retry loop was skipped and None was returned silently.
        raise ValueError("max_retries must be at least 1.")
    download_state = DownloadState()
    download_state.wait_for_lockfile()
    for attempt in range(1, max_retries + 1):
        try:
            already_downloaded, resume_header = get_header(url, filepath)
            logger.info("Sending HTTP request to %s.", url)
            # Use the response as a context manager, so the streamed connection
            # is released even when handle_response raises.
            with requests.get(
                url, headers=resume_header, stream=True, timeout=10
            ) as response:
                handle_response(response, already_downloaded, filepath)
            return str(filepath)
        except Exception as e:
            logger.error("Download attempt %d failed: %s", attempt, e)
            if attempt < max_retries:
                logger.info("Retrying in %s seconds...", retry_delay)
                time.sleep(retry_delay)
            else:
                message = "Max retries reached. Download failed."
                raise requests.exceptions.RequestException(message) from e
def unzip_file(filepath, directory=None):
    """
    Extract the contents of a local .zip file.

    The archive is extracted into a sub-directory of ``directory`` named after the
    archive's stem, e.g. ``data.zip`` is extracted into ``<directory>/data``.

    Parameters
    ----------
    filepath : str or pathlib.Path
        Path to the .zip file.
    directory : str or pathlib.Path, optional
        Directory in which the extraction sub-directory is created. Defaults to
        the directory containing ``filepath``.

    Returns
    -------
    extracted_filepaths : list of str
        Sorted paths of every file and directory under the extraction directory.
    dir_size : str
        Human-readable size of the extraction directory, as returned by
        ``get_directory_size``.
    """
    filepath = Path(filepath)
    if directory is None:
        directory = filepath.parent
    out_directory = Path(directory) / filepath.stem
    out_directory.mkdir(parents=True, exist_ok=True)
    # Read the archive from where it actually is. Previously it was read from
    # ``directory``, which is wrong whenever ``directory`` is given explicitly.
    with zipfile.ZipFile(filepath, "r") as zip_ref:
        zip_ref.extractall(out_directory)
    extracted_filepaths = sorted(str(path) for path in out_directory.rglob("*"))
    # Measure the extracted contents, not everything in the parent directory.
    dir_size = get_directory_size(out_directory)
    return extracted_filepaths, dir_size
def check_valid_url(url):
    """
    Check whether ``url`` is reachable.

    Sends a HEAD request that follows redirects, as the GET used by ``download``
    does, and reports whether the final response has status 200.

    Parameters
    ----------
    url : str
        URL to check.

    Returns
    -------
    bool
        True for a 200 response. False for any other status or a request error.

    Raises
    ------
    TypeError
        If ``url`` is not a string.
    """
    if not isinstance(url, str):
        raise TypeError("url must be a string")
    logger.info("Sending HTTP request to %s.", url)
    try:
        # requests.head does not follow redirects by default, so a URL that
        # redirects to a valid file would otherwise be reported invalid.
        response = requests.head(url, timeout=10, allow_redirects=True)
    except requests.RequestException:
        return False
    return response.status_code == 200
def get_directory_size(directory):
"""
Get the size of a directory in a human-readable format.
Parameters
----------
directory : str
The path to the directory.
Returns
-------
size : str
The size of the directory in a human-readable format.
"""
total = 0
for p in Path(directory).rglob("*"):
if p.is_file():
total += p.stat().st_size
# Convert size to a human-readable format
for unit in ["B", "KB", "MB", "GB", "TB"]:
if total < 1024.0:
return f"{total:.1f} {unit}"
total /= 1024.0
def consolidate_netcdf(filepaths, fields=None, concat_dim="time"):
    """
    Consolidate multiple netCDF files into a single xarray dataset.

    Parameters
    ----------
    filepaths : list of str
        List of filepaths to the netCDF files that need to be consolidated.
    fields : list of str, optional
        List of variable names to include in the consolidated dataset. If not provided,
        all variables in the first file will be included.
    concat_dim : str, optional
        Dimension along which the datasets will be concatenated. Default is "time".

    Returns
    -------
    dataset : xarray.Dataset
        The consolidated xarray dataset containing the selected variables
        from the input files.

    Raises
    ------
    ValueError
        If ``filepaths`` is empty.
    """
    if len(filepaths) == 0:
        raise ValueError("filepaths must contain at least one file.")
    if fields is None:
        # Open the first file only to read its variable names, and close it
        # again so no file handle is leaked.
        with xr.open_dataset(filepaths[0]) as first_dataset:
            fields = list(first_dataset.data_vars)
    datasets = [xr.open_dataset(filepath)[fields] for filepath in filepaths]
    logger.info("Concatenating datasets along %s.", concat_dim)
    return xr.concat(datasets, dim=concat_dim)
def get_pyart_grid_shape(grid_options):
    """
    Get the grid shape for pyart grid.

    Parameters
    ----------
    grid_options : object
        Grid options exposing ``start_z``/``end_z``, ``start_y``/``end_y`` and
        ``start_x``/``end_x`` attributes, plus a ``grid_spacing`` sequence ordered
        (z, y, x).

    Returns
    -------
    tuple of int
        The grid shape as a tuple of (nz, ny, nx), where each entry is
        ``(end - start) / spacing`` along that axis.

    Raises
    ------
    ValueError
        If a spacing does not divide its domain length to within floating-point
        rounding.
    """
    bounds = (
        (grid_options.start_z, grid_options.end_z),
        (grid_options.start_y, grid_options.end_y),
        (grid_options.start_x, grid_options.end_x),
    )
    spacings = grid_options.grid_spacing
    counts = []
    for axis, (lower, upper) in enumerate(bounds):
        ratio = (upper - lower) / spacings[axis]
        if not np.isfinite(ratio):
            raise ValueError("Grid spacings must divide domain lengths.")
        nearest = round(ratio)
        # Compare with a tolerance: e.g. (0.3 - 0.0) / 0.1 evaluates to
        # 2.9999999999999996, which an exact is_integer() check rejects.
        if not np.isclose(ratio, nearest, rtol=1e-9, atol=1e-9):
            raise ValueError("Grid spacings must divide domain lengths.")
        counts.append(int(nearest))
    # NOTE(review): pyart grids place points at both grid limits, so the number
    # of points along an axis is usually (end - start) / spacing + 1. Confirm
    # that callers account for this.
    return tuple(counts)
def get_pyart_grid_limits(grid_options):
    """
    Return the pyart grid limits for each axis, ordered (z, y, x).

    Parameters
    ----------
    grid_options : object
        Grid options exposing ``start_z``, ``end_z``, ``start_y``, ``end_y``,
        ``start_x`` and ``end_x`` attributes.

    Returns
    -------
    tuple
        ``((z_min, z_max), (y_min, y_max), (x_min, x_max))`` taken directly from
        the start/end attributes.
    """
    z_limits = (grid_options.start_z, grid_options.end_z)
    y_limits = (grid_options.start_y, grid_options.end_y)
    x_limits = (grid_options.start_x, grid_options.end_x)
    return (z_limits, y_limits, x_limits)
def cdsapi_retrieval(cds_name, request, local_path):
    """
    Perform a CDS API retrieval.

    The data is first written to ``<local_path>.part`` and only moved to
    ``local_path`` once the retrieval returns. An interrupted retrieval therefore
    never leaves a file at ``local_path`` that later calls would mistake for a
    complete download and skip.

    Parameters
    ----------
    cds_name : str
        The name argument for the cdsapi retrieval.
    request : dict
        A dictionary containing the cdsapi retrieval options.
    local_path : str or pathlib.Path
        The local file path where the retrieved data will be saved.

    Returns
    -------
    None
    """
    target = Path(local_path)
    if target.exists():
        logger.info("%s already exists.", local_path)
        return
    target.parent.mkdir(parents=True, exist_ok=True)
    partial_path = target.with_name(target.name + ".part")
    cdsc = cdsapi.Client()
    cdsc.retrieve(cds_name, request, str(partial_path))
    partial_path.replace(target)
def log_dataset_update(local_logger, name, time):
    """Log on ``local_logger`` that dataset ``name`` is being updated for ``time``.

    ``time`` is formatted with ``utils.format_time(..., filename_safe=False)``.
    """
    readable_time = utils.format_time(time, filename_safe=False)
    message = f"Updating {name} dataset for {readable_time}."
    local_logger.info(message)
def log_convert(local_logger, name, filepath):
    """Log on ``local_logger`` that ``name`` data from ``filepath`` is being converted.

    Only the final component (file name) of ``filepath`` appears in the message.
    """
    source_name = Path(filepath).name
    local_logger.info("Converting %s data from %s", name, source_name)
def call_ncks(input_filepath, output_filepath, start, end, lat_range, lon_range):
    """
    Call ncks to subset a large netcdf file.

    Parameters
    ----------
    input_filepath, output_filepath : str or pathlib.Path
        Source file and destination of the subset.
    start, end : datetime-like
        Time bounds. They are clamped to the time range of the input file.
    lat_range, lon_range : sequence of float
        ``[min, max]`` bounds. Longitudes are normalised to [-180, 180).

    Raises
    ------
    subprocess.CalledProcessError
        If ncks exits with a non-zero status (the return code is logged first).
    """
    # Read metadata lazily, and close the file once the time axis is known.
    with xr.open_dataset(input_filepath, chunks={}) as ds:
        # The time coordinate is named either "valid_time" or "time"; use
        # whichever is present in the ncks hyperslab argument.
        time_var = "valid_time" if "valid_time" in ds else "time"
        times = ds[time_var].values
    # Ensure start and end times are within the dataset time range.
    if start < times[0]:
        start = times[0]
    if end > times[-1]:
        end = times[-1]
    lon_range = [(lon + 180) % 360 - 180 for lon in lon_range]
    # Pass an argument list (no shell). Paths with spaces, and timestamps whose
    # string form contains a space (e.g. "2020-01-01 00:00:00"), previously
    # got split into separate shell words.
    command = [
        "ncks",
        "-d",
        f"{time_var},{start},{end}",
        "-d",
        f"latitude,{lat_range[0]},{lat_range[1]}",
        "-d",
        f"longitude,{lon_range[0]},{lon_range[1]}",
        str(input_filepath),
        str(output_filepath),
    ]
    # ncks output is not captured, so its own diagnostics still reach the
    # console. The previous returncode branch after check=True was unreachable.
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as error:
        logger.error("ncks failed with return code %d.", error.returncode)
        raise
def apply_mask(ds, grid_options):
    """
    Apply a domain mask to an xr dataset.

    Each data variable that spans both horizontal dimensions of the grid is
    masked with ``ds["domain_mask"]``. The variables ``gridcell_area``,
    ``domain_mask`` and ``boundary_mask`` are left untouched. Floating/complex
    variables become NaN outside the domain; integer/boolean variables become 0.
    ``ds`` is modified in place and also returned.

    Raises
    ------
    ValueError
        If ``grid_options.name`` is neither "cartesian" nor "geographic", or a
        variable's dtype is neither float-like nor int-like.
    """
    domain_mask = ds["domain_mask"]
    horizontal_dims_by_grid = {
        "cartesian": ["y", "x"],
        "geographic": ["latitude", "longitude"],
    }
    if grid_options.name not in horizontal_dims_by_grid:
        raise ValueError("Grid name must be 'cartesian' or 'geographic'.")
    required_dims = set(horizontal_dims_by_grid[grid_options.name])
    nan_fill_types = [np.floating, np.complexfloating]
    zero_fill_types = [np.integer, np.bool_]
    skipped = ["gridcell_area", "domain_mask", "boundary_mask"]
    # The set difference is computed once up front, so reassigning ds[name]
    # inside the loop is safe.
    for name in ds.data_vars.keys() - skipped:
        variable = ds[name]
        if not required_dims <= set(variable.dims):
            continue
        mask = domain_mask.broadcast_like(variable)
        dtype = variable.dtype
        if any(np.issubdtype(dtype, kind) for kind in nan_fill_types):
            ds[name] = variable.where(mask)
        elif any(np.issubdtype(dtype, kind) for kind in zero_fill_types):
            ds[name] = variable.where(mask, 0)
        else:
            message = f"Cannot apply domain mask to {name}. Unknown data type."
            raise ValueError(message)
    return ds
def mask_from_input_record(
    track_input_records, dataset_options, object_options, grid_options
):
    """
    Get a domain mask from the input record. This function is used if a single domain
    mask applies to all objects/times in the dataset.

    The record is looked up by ``dataset_options.name``. ``object_options`` and
    ``grid_options`` are accepted for interface compatibility but are not used.

    Returns
    -------
    tuple
        ``(domain_mask, boundary_coordinates)`` of that input record.
    """
    record = track_input_records[dataset_options.name]
    return record.domain_mask, record.boundary_coordinates
def mask_from_observations(dataset, dataset_options, object_options=None):
    """
    Create domain mask based on number of observations in each cell.

    A cell is inside the domain if, at any altitude in the selected range,
    ``dataset["number_of_observations"]`` exceeds ``dataset_options.obs_thresh``.

    Parameters
    ----------
    dataset : xarray.Dataset
        Must contain ``number_of_observations`` with an ``altitude`` dimension.
    dataset_options : object
        Provides ``obs_thresh``.
    object_options : object, optional
        Its ``detection.altitudes`` gives the ``[min, max]`` altitude range. If
        ``object_options`` is None, or the altitudes are None or empty, the full
        altitude range of ``dataset`` is used.

    Returns
    -------
    xarray.DataArray
        Boolean mask over the non-altitude dimensions.
    """
    altitudes = None
    if object_options is not None:
        altitudes = object_options.detection.altitudes
    # Use len() rather than comparing to [], so empty tuples/arrays also mean
    # "no altitude range given" (and arrays don't trigger elementwise ==).
    if altitudes is None or len(altitudes) == 0:
        altitudes = [dataset.altitude.values.min(), dataset.altitude.values.max()]
    num_obs = dataset["number_of_observations"].sel(altitude=slice(*altitudes))
    mask = (num_obs > dataset_options.obs_thresh).any(dim="altitude")
    return mask.astype(bool)
def smooth_mask(mask):
    """
    Smooth a binary mask using morphological operations.

    The steps, in order:
    1. Remove small objects and fill small holes.
    2. Edge-pad the array, then erode and dilate with a 20x20 element (opening).
    3. Dilate with a 5x5 element and erode with a 6x6 element, then crop the
       padding.
    4. Clean small objects and holes again.

    Parameters
    ----------
    mask : xarray-like
        Object with a writable ``.values`` holding a 2D boolean array (2D because
        of the 2D structuring elements and the two-axis crop below).

    Returns
    -------
    The same ``mask`` object. Its ``.values`` are overwritten in place.
    """
    # Remove objects smaller than a 150 km radius region.
    # 0.02 lat/lon per pixel and ~100 km per lat/lon gives ~2 km per pixel, so a
    # 75 pixel radius and a min area of np.pi*75**2, or ~1.8e4 pixels.
    mask_values = remove_small_objects(mask.values, min_size=1.8e4)
    # Fill holes smaller than 200 pixels
    mask_values = remove_small_holes(mask_values, area_threshold=2e2)
    # Pad the mask before dilation/erosion to avoid edge effects
    # NOTE(review): scipy's binary_erosion treats out-of-array pixels as False
    # by default, so a 3 pixel pad may not fully shield the edges from a 20x20
    # element — confirm this is intended.
    pad_width = 3
    mask_values = np.pad(mask_values, pad_width, mode="edge")
    # Erode and dilate with large element to remove small objects and fill holes
    mask_values = binary_erosion(mask_values, structure=np.ones((20, 20)))
    mask_values = binary_dilation(mask_values, structure=np.ones((20, 20)))
    # Repeat but with a smaller element, and applying dilation first to close lines
    mask_values = binary_dilation(mask_values, structure=np.ones((5, 5)))
    # Erode one more pixel than dilated to ensure objects don't artificially avoid
    # touching the boundary
    mask_values = binary_erosion(mask_values, structure=np.ones((6, 6)))
    # Crop the padding added above
    mask_values = mask_values[pad_width:-pad_width, pad_width:-pad_width]
    # Another pass at hole filling
    mask_values = remove_small_objects(mask_values, min_size=1.8e4)
    mask_values = remove_small_holes(mask_values, area_threshold=2e2)
    mask.values = mask_values
    return mask
def mask_from_range(dataset, dataset_options, grid_options):
    """
    Create a domain mask that is True for gridcells within range of the central point.

    Parameters
    ----------
    dataset : xarray.Dataset
        Supplies the output coordinates. For geographic grids it must also have
        ``origin_latitude``/``origin_longitude`` attributes.
    dataset_options : object
        Provides ``range`` and ``range_units`` ("m" or "km").
    grid_options : object
        Provides ``name`` and the grid coordinates.

    Returns
    -------
    xarray.DataArray
        Boolean mask, True where the distance to the centre is at most the range.

    Raises
    ------
    ValueError
        If ``grid_options.name`` is neither "cartesian" nor "geographic".
    """
    if grid_options.name == "cartesian":
        X, Y = np.meshgrid(grid_options.x, grid_options.y)
        # Distance from the grid origin (0, 0), taken as the central point.
        distances = np.sqrt(X**2 + Y**2)
        dims = ["y", "x"]
        coords = {"y": dataset.y, "x": dataset.x}
    elif grid_options.name == "geographic":
        lons = grid_options.longitude
        lats = grid_options.latitude
        origin_longitude = float(dataset.attrs["origin_longitude"])
        origin_latitude = float(dataset.attrs["origin_latitude"])
        LON, LAT = np.meshgrid(lons, lats)
        distances = utils.haversine(LAT, LON, origin_latitude, origin_longitude)
        dims = ["latitude", "longitude"]
        coords = {"latitude": dataset.latitude, "longitude": dataset.longitude}
    else:
        raise ValueError("Grid name must be 'cartesian' or 'geographic'.")
    # Convert the range to metres. NOTE(review): this assumes the distances
    # above are in metres — confirm for utils.haversine.
    units_dict = {"m": 1, "km": 1e3}
    max_range = dataset_options.range * units_dict[dataset_options.range_units]
    within_range = (distances <= max_range).astype(bool)
    return xr.DataArray(within_range, coords=coords, dims=dims)
def get_encoding(ds):
    """Return a netCDF encoding that zlib-compresses (level 5) every variable of ``ds``.

    Each variable name maps to its own, independent settings dict.
    """
    return {name: {"zlib": True, "complevel": 5} for name in ds.variables}
def get_geographic_regridder(
    dataset, grid_options, dataset_options, latitude=None, longitude=None
):
    """
    Return a bilinear xESMF regridder from ``dataset`` onto a latitude/longitude grid.

    The target grid is ``latitude``/``longitude`` when both are given, otherwise
    ``grid_options.latitude``/``grid_options.longitude``. If
    ``dataset_options.weights_filepath`` exists, the weights are loaded from it.
    Otherwise they are computed, and saved there when
    ``dataset_options.reuse_regridder`` is true.
    """
    weights_filepath = dataset_options.weights_filepath
    if latitude is None or longitude is None:
        latitude, longitude = grid_options.latitude, grid_options.longitude
    target_grid = xr.Dataset(
        {
            "latitude": (["latitude"], latitude),
            "longitude": (["longitude"], longitude),
        }
    )
    regrid_kwargs = {"periodic": False, "extrap_method": None}
    if Path(weights_filepath).exists():
        logger.info("Loading regridder from file.")
        return xe.Regridder(
            dataset, target_grid, "bilinear", weights=weights_filepath, **regrid_kwargs
        )
    logger.info("Building regridder; this can take a while for large grids.")
    regridder = xe.Regridder(dataset, target_grid, "bilinear", **regrid_kwargs)
    if dataset_options.reuse_regridder:
        Path(weights_filepath).parent.mkdir(parents=True, exist_ok=True)
        # Once saved, later calls take the load-from-file branch above.
        regridder.to_netcdf(weights_filepath)
    return regridder
def copy_attributes(ds, old_ds):
    """
    Copy attributes from one xarray dataset to another.

    Attributes of data variables and coordinates are copied for names present
    in ``old_ds``. The global attributes of ``old_ds`` are merged into
    ``ds.attrs``, and a note with the current time is appended to ``history``.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset to update (modified in place).
    old_ds : xarray.Dataset
        Source of the attributes.

    Returns
    -------
    xarray.Dataset
        ``ds``.
    """
    for var in ds.data_vars:
        if var in old_ds.data_vars:
            ds[var].attrs = old_ds[var].attrs
    for coord in ds.coords:
        # Regridding can introduce coordinates that old_ds lacks; skip them
        # instead of raising KeyError.
        if coord in old_ds.variables:
            ds[coord].attrs = old_ds[coord].attrs
    ds.attrs.update(old_ds.attrs)
    note = f"regridded using xesmf on {np.datetime64('now')}"
    # Start a history entry if none exists, rather than raising KeyError.
    history = ds.attrs.get("history", "")
    ds.attrs["history"] = f"{history}, {note}" if history else note
    return ds