Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions climanet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,4 @@
from climanet.dataset import STDataset
from climanet.utils import regrid_to_boundary_centered_grid

__all__ = [
"STDataset",
"st_encoder_decoder",
"regrid_to_boundary_centered_grid"
]
__all__ = ["STDataset", "st_encoder_decoder", "regrid_to_boundary_centered_grid"]
60 changes: 52 additions & 8 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(
raise ValueError(f"Spatial dimension '{dim}' not found in input data")

if (
patch_size[0] > daily_da.sizes[spatial_dims[0]] or patch_size[1] > daily_da.sizes[spatial_dims[1]]
patch_size[0] > daily_da.sizes[spatial_dims[0]]
or patch_size[1] > daily_da.sizes[spatial_dims[1]]
):
raise ValueError(
f"Patch size {patch_size} is larger than data dimensions {daily_da.sizes[spatial_dims]}"
Expand All @@ -54,9 +55,6 @@ def __init__(
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
self.lon_coords = daily_da[spatial_dims[1]].to_numpy().copy()

# Store the stats of the daily data before filling NaNs
self.daily_mean, self.daily_std = calc_stats(self.daily_np)

if land_mask is not None:
lm = land_mask.to_numpy().copy()
if lm.ndim == 3:
Expand All @@ -69,8 +67,11 @@ def __init__(
# daily_mask: True where NaN (i.e. missing ocean data, not land)
self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W)

# Fill NaNs with 0 in-place
np.nan_to_num(self.daily_np, copy=False, nan=0.0)
# Stats will be set later via set_stats() and NaNs will be filled with 0 in-place
self.daily_mean = None
self.daily_std = None
self._nans_filled = False
self._warned = False

# Precompute padded_days_mask as a tensor (same for all patches)
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()
Expand All @@ -96,7 +97,7 @@ def _compute_patch_indices(self, H: int, W: int) -> list:
f"Patch size {self.patch_size} does not evenly divide image dimensions (H={H}, W={W}). "
f"Uncovered pixels: {remainder_h} in height, {remainder_w} in width. "
f"Consider adjusting patch_size or image dimensions for full coverage.",
UserWarning
UserWarning,
)

# Generate non-overlapping patch indices
Expand All @@ -105,12 +106,18 @@ def _compute_patch_indices(self, H: int, W: int) -> list:

return [(i, j) for i in i_starts for j in j_starts]


def __len__(self):
    """Dataset length: one sample per precomputed spatial patch position."""
    n_patches = len(self.patch_indices)
    return n_patches

def __getitem__(self, idx):
"""Get a spatiotemporal patch sample based on the index."""
if not self._nans_filled and not self._warned:
warnings.warn(
"NaNs have not been replaced. Call fill_nans_with_zero() before using the dataset.",
UserWarning,
)
self._warned = True

if idx < 0 or idx >= len(self.patch_indices):
raise IndexError("Index out of range")

Expand Down Expand Up @@ -159,3 +166,40 @@ def __getitem__(self, idx):
"lat_patch": lat_patch, # (H,)
"lon_patch": lon_patch, # (W,)
}

def compute_stats(self, indices: list = None) -> Tuple[np.ndarray, np.ndarray]:
    """Compute per-variable mean/std, store them, then zero-fill NaNs in-place.

    The statistics must be taken from the raw (NaN-containing) data, so this
    method computes them first and only afterwards calls
    ``fill_nans_with_zero()``, which destroys the NaN information.

    Args:
        indices: Patch indices (into ``self.patch_indices``) whose spatial
            windows contribute to the statistics. If None, the full data
            array is used.

    Returns:
        Tuple of (mean, std) arrays — presumably one entry per variable,
        shape ``(M,)`` per the comment on ``calc_stats`` usage elsewhere.

    Raises:
        ValueError: If ``indices`` is an empty list (previously this
            surfaced as an opaque ``np.concatenate`` error).
    """
    if indices is None:
        data = self.daily_np  # (M, T, H, W): use everything
    else:
        if not indices:
            raise ValueError("`indices` must be None or a non-empty list of patch indices")
        ph, pw = self.patch_size
        # Gather the selected spatial windows and stack them along the
        # width axis so calc_stats sees one contiguous array per variable.
        patches = [
            self.daily_np[:, :, i : i + ph, j : j + pw]
            for i, j in (self.patch_indices[idx] for idx in indices)
        ]
        data = np.concatenate(patches, axis=-1)

    mean, std = calc_stats(data)  # (M,)

    self.daily_mean = mean
    self.daily_std = std

    # Only now is it safe to overwrite NaNs with zeros in-place.
    self.fill_nans_with_zero()

    return mean, std

def fill_nans_with_zero(self):
    """Replace NaNs in ``daily_np`` with 0.0 in-place; idempotent via a flag."""
    if self._nans_filled:
        return  # already done — never fill twice
    np.nan_to_num(self.daily_np, copy=False, nan=0.0)
    self._nans_filled = True
59 changes: 41 additions & 18 deletions climanet/predict.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset
from climanet.st_encoder_decoder import SpatioTemporalModel
import xarray as xr
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def _setup_logging(log_dir: str) -> SummaryWriter:
    """Ensure *log_dir* exists, then open and return a TensorBoard writer on it."""
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir)
    return writer
from climanet.utils import setup_logging, compute_masked_loss


def _save_netcdf(predictions: np.ndarray, dataset: Dataset, save_dir: str):
"""Helper function to convert predictions to xarray and save as netCDF."""
B, M, H, W = predictions.shape

lats = dataset.monthly_da.coords["lat"].values
lons = dataset.monthly_da.coords["lon"].values
times = dataset.monthly_da.coords["time"].values
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset
indices = dataset.indices if hasattr(dataset, "indices") else range(len(dataset))

full_predictions = np.empty((M, len(lats), len(lons)), dtype=predictions.dtype)
for i, (lat_start, lon_start) in enumerate(dataset.patch_indices):
lats = base_dataset.monthly_da.coords["lat"].values
lons = base_dataset.monthly_da.coords["lon"].values
times = base_dataset.monthly_da.coords["time"].values

full_predictions = np.full(
(M, len(lats), len(lons)), np.nan, dtype=predictions.dtype
)
for i, patch_idx in enumerate(indices):
lat_start, lon_start = base_dataset.patch_indices[patch_idx]
full_predictions[:, lat_start : lat_start + H, lon_start : lon_start + W] = (
predictions[i]
)
Expand Down Expand Up @@ -61,6 +59,7 @@ def predict_monthly_var(
batch_size: int = 2,
return_numpy: bool = True,
save_predictions: bool = True,
return_loss: bool = False,
device: str = "cpu",
run_dir: str = ".",
verbose: bool = True,
Expand All @@ -76,11 +75,13 @@ def predict_monthly_var(
Otherwise, returns a PyTorch tensor.
save_predictions: If True, convert the predictions to xarray and
save to disk as netCDF files and return the xarray Dataset.
return_loss: If True, also return the average loss over the dataset.
device: The device to run the predictions on (e.g., 'cpu' or 'cuda').
run_dir: Directory to save log files and predictions.
verbose: If True, prints progress information during prediction.
Returns:
A NumPy array, PyTorch tensor, or xarray Dataset containing the predicted values.
If return_loss is True, it also returns the average loss over the dataset.
"""
# Load the model if a path is provided
if isinstance(model, str):
Expand All @@ -95,15 +96,19 @@ def predict_monthly_var(
)

# Initialize an empty list to store predictions
M = dataset.monthly_np.shape[0]
H, W = dataset.patch_size
base_dataset = dataset.dataset if hasattr(dataset, "dataset") else dataset
base_dataset.fill_nans_with_zero() # Ensure NaNs are filled before prediction

M = base_dataset.monthly_np.shape[0]
H, W = base_dataset.patch_size
all_predictions = torch.empty(len(dataset), M, H, W)

# Set up logging
writer = _setup_logging(run_dir)
writer = setup_logging(run_dir)

with torch.no_grad():
idx = 0
average_loss = 0.0
for i, batch in enumerate(dataloader):
# Move batch to the appropriate device
predictions = model(
Expand All @@ -112,14 +117,29 @@ def predict_monthly_var(
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda),
)

# Compute masked loss
loss = compute_masked_loss(
predictions, batch["monthly_patch"], batch["land_mask_patch"]
)
average_loss += loss.item()

all_predictions[idx : idx + predictions.size(0)] = predictions.cpu()
idx += predictions.size(0)

if verbose:
print(f"Processed batch {i + 1}/{len(dataloader)}")
print(
f"Processed batch {i + 1}/{len(dataloader)}, with loss: {loss.item():.4f}"
)

writer.add_scalar("Progress/Batch", i + 1, idx)

average_loss = average_loss / len(dataloader)

if verbose:
print(f"Average loss over all batches: {average_loss:.4f}")
writer.add_scalar("Loss/Average", average_loss)

if return_numpy:
all_predictions = all_predictions.numpy()

Expand All @@ -136,4 +156,7 @@ def predict_monthly_var(
# Close the writer when done
writer.close()

if return_loss:
all_predictions = (all_predictions, average_loss)

return all_predictions
Loading