Training Models¶
Overview¶
This chapter documents the training stack used by AINPP-PB-LATAM from the point of view of the current codebase. It focuses on what is actually configurable today through Hydra and what that means for scientific benchmark experiments.
The training workflow is built around:
- `main.py` as the CLI entry point,
- Hydra for model, dataset, loss, and runtime composition,
- `AINPPPBLATAMDataset` for Zarr-based sequence sampling,
- `run_training` for supervised training,
- `run_gan_training` for adversarial training,
- model definitions under `src/ainpp_pb_latam/models/`.
End-to-End Training Flow¶
At runtime the project follows this sequence:
- `main.py` loads `conf/config.yaml`.
- Hydra composes the selected `model`, `dataset`, `training`, `loss`, `evaluation`, and `inference` groups.
- The dataset object is instantiated from `conf/dataset/gsmap.yaml`.
- The model is instantiated from the selected file in `conf/model/`.
- The loss is instantiated from the selected file in `conf/loss/`.
- The optimizer is created from the `training` config.
- `run_training` performs the epoch loop, validation, checkpointing, and early stopping.
Typical command:
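python main.py task=train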
Configuration Topology¶
The root training-related defaults are defined in conf/config.yaml:
defaults:
- _self_
- model: unet/direct
- discriminator: patchgan
- training: gan
- dataset: gsmap
- loss: hybrid_mse_ssim
- inference: default
- evaluation: default
Important implications:
- The root file defines shared dimensions such as `input_timesteps`, `output_timesteps`, `input_channels`, `hidden_channels`, and `kernel_size`.
- Several model configs interpolate these shared values instead of redefining them.
- You can modify the global temporal configuration once and let multiple configs inherit it.
- Some models expose additional constructor parameters that are not listed in the YAML by default; Hydra still allows overriding them explicitly.
Example:
python main.py task=train \
model=afno/direct \
input_timesteps=12 \
output_timesteps=6 \
+model.embed_dim=384 \
+model.depth=12
Tensor Shapes and Training Contract¶
The dataset returns:
- input tensor: `(B, Tin, C, H, W)`
- target tensor: `(B, Tout, C, H, W)`
With the current default dataset and model assumptions:
- `Tin = 12`
- `Tout = 6`
- `C = 1`
- `H = patch_height`
- `W = patch_width`
For the default GSMaP setup, this means inputs of shape `(B, 12, 1, H, W)` and targets of shape `(B, 6, 1, H, W)`, with `H` and `W` set by the configured patch geometry.
This contract is central. If you change `input_timesteps`, `output_timesteps`, patch size, or the number of channels, the model and dataset must still agree on the same shape semantics.
Dataset Configuration in Detail¶
The dataset configuration lives in conf/dataset/gsmap.yaml and instantiates `AINPPPBLATAMDataset`.
Core Temporal Parameters¶
- `input_timesteps`: number of historical frames fed into the model
- `output_timesteps`: number of future frames the model must predict
- `stride`: temporal stride between valid samples for each split
- `steps_per_epoch`: if defined, enables random sampling mode instead of deterministic traversal
The current split overrides are:
- `train`: `group=train`, `stride=1`, `steps_per_epoch=500`
- `validation`: `group=validation`, `stride=6`, `steps_per_epoch=500`
This means:
- training samples are dense in time and randomly drawn,
- validation samples are sparser in time,
- both training and validation are capped by a fixed number of sampled items per epoch instead of scanning the full split.
Spatial Parameters¶
- `patch_height`
- `patch_width`
- `patch_stride_h`
- `patch_stride_w`
Current default:
When a stride is null, the dataset uses the patch size itself. That yields non-overlapping tiles unless the domain edge requires a final snapped patch for full coverage.
Data Variables¶
The dataset class supports:
- `input_var`, default `gsmap_nrt`
- `target_var`, default `gsmap_mvk`
- `group`, default `train`
- `dtype`, default `float32`
- `consolidated`, default `true`
- `return_metadata`, default `false`
These fields are useful when experimenting with alternate variables, storage conventions, or metadata-aware debugging.
Changing Input and Target Lengths¶
To modify the historical window and forecast horizon:
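python main.py task=train \
input_timesteps=24 \
output_timesteps=12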
Because the dataset and most model configs interpolate these root values, this is the preferred way to change sequence lengths.
Changing Spatial Crop Size¶
To train on full-domain inputs:
python main.py task=train \
model=unet/direct \
dataset.dataset.patch_height=null \
dataset.dataset.patch_width=null
To train on smaller patches:
python main.py task=train \
model=unet/direct \
dataset.dataset.patch_height=256 \
dataset.dataset.patch_width=256
To create overlapping windows:
python main.py task=train \
model=unet/direct \
dataset.dataset.patch_height=320 \
dataset.dataset.patch_width=320 \
dataset.dataset.patch_stride_h=160 \
dataset.dataset.patch_stride_w=160
Changing Sampling Density¶
To use fewer random samples per epoch for fast prototyping:
python main.py task=train \
model=unet/direct \
dataset.overrides.train.steps_per_epoch=100 \
dataset.overrides.validation.steps_per_epoch=50
To force deterministic validation over the full split, unset the sampling cap (per the sampling rule above, an undefined `steps_per_epoch` restores deterministic traversal):
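python main.py task=train \
model=unet/direct \
dataset.overrides.validation.steps_per_epoch=null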
Changing Input and Target Variables¶
If the Zarr store contains alternative variable names:
python main.py task=train \
model=unet/direct \
dataset.dataset.input_var=gsmap_nrt \
dataset.dataset.target_var=gsmap_mvk
Dataloader Parameters¶
The dataset config also controls:
- `train_loader.batch_size`
- `train_loader.num_workers`
- `train_loader.prefetch_factor`
- `train_loader.pin_memory`
- `val_loader.batch_size`
- `val_loader.num_workers`
- `val_loader.pin_memory`
Example:
python main.py task=train \
model=unet/direct \
dataset.train_loader.batch_size=4 \
dataset.train_loader.num_workers=8 \
dataset.val_loader.batch_size=4
Supervised Training Configuration¶
The default supervised profile is conf/training/default.yaml.
Key fields:
- `mode: supervised`
- `epochs: 50`
- `lr: 0.001`
- `batch_size: 16`
- `scheduler.patience`
- `scheduler.factor`
- `checkpoint.enabled`
- `checkpoint.dir`
- `checkpoint.interval`
- `checkpoint.save_best`
- `early_stopping.enabled`
- `early_stopping.patience`
- `early_stopping.delta`
- `early_stopping.mode`
Notes from the current implementation:
- The optimizer is always Adam through `build_optimizer`.
- `lr` is used unless the config only defines `lr_g`.
- `beta1` and `beta2` are also read if present in the training config.
- The scheduler block exists in YAML, but no scheduler is currently attached inside `run_training`.
Example of a slower, more conservative run:
python main.py task=train \
model=unet/direct \
training=default \
training.epochs=100 \
training.lr=0.0003 \
training.early_stopping.patience=20 \
training.checkpoint.interval=10
GAN Training Configuration¶
The adversarial profile is conf/training/gan.yaml.
Key fields:
- `mode: gan`
- `epochs: 100`
- `lr_g`
- `lr_d`
- `beta1`
- `beta2`
- `lambda_pixel`
- checkpoint settings
- early stopping settings
The intended logic in `run_gan_training` is:
- generator predicts future rainfall,
- discriminator sees the concatenated history and future sequence,
- discriminator learns to distinguish real future targets from generated future targets,
- generator balances adversarial realism and pixel-level accuracy.
The paired discriminator config is:
_target_: ainpp_pb_latam.models.gan.discriminator.PatchDiscriminator3D
input_channels: 1
ndf: 64
n_layers: 1
norm_type: "instance"
A typical adversarial run would conceptually look like:
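python main.py task=train \
model=unet/direct \
training=gan \
discriminator=patchgan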
However, note one practical detail: the current `main.py` path dispatches `task=train` through `run_training`. If you want full GAN training behavior, the CLI path still needs to branch into `run_gan_training` when `training.mode=gan`.
Model Catalog¶
This section summarizes the models currently wired into Hydra under conf/model/.
UNet Direct¶
Config file:
_target_: ainpp_pb_latam.models.unet.forecaster.UNetMultiHorizon
input_timesteps: ${input_timesteps}
input_channels: 1
output_timesteps: ${output_timesteps}
output_channels: 1
features: [64, 128, 256, 512]
kernel_size: 3
bilinear: true
nonnegativity: "relu"
Training behavior:
- flattens the temporal dimension into channels,
- applies a 2D U-Net to the stacked history,
- predicts all future horizons in one forward pass,
- reshapes the output back to `(B, Tout, C, H, W)`,
- applies a non-negativity constraint at the end.
Best when:
- you want a strong baseline,
- you need explicit control over encoder depth,
- you want predictable behavior on patch-based training.
Important parameters:
- `features`: controls encoder and decoder width at each level
- `kernel_size`: spatial receptive field per block
- `bilinear`: chooses bilinear upsampling instead of transposed convolution
- `nonnegativity`: `relu`, `softplus`, or `none`
Examples:
python main.py task=train \
model=unet/direct \
model.features=[32,64,128,256] \
model.kernel_size=5 \
model.bilinear=false \
model.nonnegativity=softplus
UNet Autoregressive¶
Config file:
_target_: ainpp_pb_latam.models.unet.forecaster.UNetAutoRegressive
input_timesteps: 12
input_channels: 1
output_timesteps: 6
features: [64, 128, 256, 512]
kernel_size: 3
bilinear: true
nonnegativity: "relu"
Training behavior:
- predicts one future frame at a time,
- appends each prediction back into the context window,
- rolls forward until the requested forecast horizon is reached.
Best when:
- your scientific question values sequential dependency between horizons,
- you want the model to explicitly learn rollout dynamics,
- error propagation across future steps is acceptable or desired to study.
Important caution:
- this file currently hardcodes `input_timesteps=12` and `output_timesteps=6` rather than interpolating root values,
- if you change the global root dimensions, also override the model fields explicitly for the autoregressive U-Net.
Example:
python main.py task=train \
model=unet/autoregressive \
model.input_timesteps=18 \
model.output_timesteps=12 \
dataset.dataset.input_timesteps=18 \
dataset.dataset.output_timesteps=12
ConvLSTM¶
Config file:
_target_: ainpp_pb_latam.models.convlstm.forecaster.ConvLSTMMultiHorizon
input_channels: ${input_channels}
hidden_channels: ${hidden_channels}
kernel_size: ${kernel_size}
output_timesteps: ${output_timesteps}
Training behavior:
- uses a ConvLSTM encoder-decoder,
- processes the input sequence recurrently,
- decodes future steps autoregressively from latent state,
- maps hidden features into one precipitation channel through a small output head.
Best when:
- temporal recurrence is central to the experiment,
- you want a sequence model without flattening time into channels,
- you want to study the effect of hidden-state depth and recurrent receptive field.
Important parameters:
- `hidden_channels`: number and width of ConvLSTM layers
- `kernel_size`: convolution kernel used inside recurrent cells
- `output_timesteps`: rollout horizon
Examples:
python main.py task=train \
model=convlstm/direct \
input_channels=1 \
output_timesteps=12 \
dataset.dataset.output_timesteps=12
AFNO¶
Config file:
_target_: ainpp_pb_latam.models.afno.forecaster.AFNO2D
input_timesteps: ${input_timesteps}
output_timesteps: ${output_timesteps}
Additional constructor parameters supported by the code:
- `img_size`
- `input_channels`
- `output_channels`
- `embed_dim`
- `depth`
- `patch_size`
- `num_blocks`
Training behavior:
- flattens time into channels,
- embeds the spatial field into non-overlapping patches,
- applies a stack of Fourier blocks in latent space,
- reconstructs the output with a transposed convolution head.
Best when:
- global spatial coupling matters,
- you want a spectral model,
- patch-token processing is more attractive than deep CNN decoding.
Critical caveat:
- `AFNO2D` defaults to `img_size=(880, 970)` and `patch_size=10`,
- if your dataset crops are `320 x 320`, you should override `model.img_size` accordingly,
- `patch_size` must divide both image dimensions used by the model.
Examples:
python main.py task=train \
model=afno/direct \
dataset.dataset.patch_height=320 \
dataset.dataset.patch_width=320 \
+model.img_size=[320,320] \
+model.patch_size=10 \
+model.embed_dim=384 \
+model.depth=12 \
+model.num_blocks=12
python main.py task=train \
model=afno/direct \
dataset.dataset.patch_height=256 \
dataset.dataset.patch_width=256 \
+model.img_size=[256,256] \
+model.patch_size=16
ResNet50¶
Config file:
_target_: ainpp_pb_latam.models.resnet50.forecaster.ResNet50MultiHorizon
input_timesteps: ${input_timesteps}
output_timesteps: ${output_timesteps}
Additional constructor parameter supported by the code:
- `pretrained`
Training behavior:
- collapses time into the channel dimension,
- uses `timm` `resnet50d` as a feature extractor,
- predicts all future steps jointly.
Best when:
- you want an ImageNet-style convolutional encoder,
- transfer learning from pretrained image backbones is acceptable,
- the experiment benefits from a robust CNN feature hierarchy.
Examples:
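python main.py task=train \
model=resnet50/direct \
+model.pretrained=true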
Because `in_chans=input_timesteps`, changing `input_timesteps` changes the first convolution shape in the `timm` backbone.
InceptionV4¶
Config file:
_target_: ainpp_pb_latam.models.inceptionv4.forecaster.InceptionV4MultiHorizon
input_timesteps: ${input_timesteps}
output_timesteps: ${output_timesteps}
Additional constructor parameter supported by the code:
- `pretrained`
Training behavior:
- uses a `timm` Inception-V4 encoder with `features_only=True`,
- decodes the multiscale feature pyramid back to full resolution,
- predicts all horizons in one shot,
- enforces non-negative rainfall through `relu`.
Best when:
- you want a deeper multi-branch CNN encoder,
- you want strong spatial feature extraction with pretrained weights,
- you are benchmarking classic computer vision backbones against nowcasting-specific designs.
Example:
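python main.py task=train \
model=inceptionv4/direct \
+model.pretrained=true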
Xception¶
Config file:
_target_: ainpp_pb_latam.models.xception.forecaster.XceptionMultiHorizon
input_timesteps: ${input_timesteps}
output_timesteps: ${output_timesteps}
Additional constructor parameter supported by the code:
- `pretrained`
Training behavior:
- uses a `timm` Xception encoder adapted to `in_chans=input_timesteps`,
- decodes through skip-connected upsampling blocks,
- predicts all future frames simultaneously,
- applies `relu` to the output before restoring the channel dimension.
Best when:
- you want depthwise-separable convolutional features,
- you want a lighter alternative to some classical heavy backbones,
- you want to compare pretrained encoder transfer against U-Net and ConvLSTM baselines.
Example:
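python main.py task=train \
model=xception/direct \
+model.pretrained=false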
Loss Functions¶
Loss functions are instantiated from conf/loss/.
weighted_mse¶
Config (a sketch; the `alpha` and `threshold` values below mirror the `WeightedMSELoss` entry in the hybrid config, not necessarily the file's exact defaults):
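# Sketch; check conf/loss/weighted_mse.yaml for the exact defaults.
_target_: ainpp_pb_latam.losses.WeightedMSELoss
alpha: 2.0
threshold: 0.0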
Use this when:
- you want to upweight heavy-rain pixels,
- standard MSE underestimates intense precipitation,
- the benchmark should favor amplitude accuracy in high-value regions.
Higher `alpha` increases the emphasis on rain cores; `threshold` defines the target intensity above which the extra weighting applies.
Example:
python main.py task=train \
model=unet/direct \
loss=weighted_mse \
loss.alpha=10.0 \
loss.threshold=1.0
huber¶
Robust against outliers and often a safer regression baseline than plain MSE on noisy fields.
Example:
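python main.py task=train \
model=unet/direct \
loss=huber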
logcosh¶
Smooth transition between L2-like and L1-like behavior. Useful when you want stable optimization but less sensitivity to extreme residuals than MSE.
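Example:
python main.py task=train \
model=unet/direct \
loss=logcosh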
dice¶
Optimizes overlap on a rain mask rather than continuous intensity values.
Important caveat:
- this is best for event extent, not calibrated rainfall amplitude,
- it binarizes the target using the configured threshold.
Example:
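python main.py task=train \
model=unet/direct \
loss=dice \
loss.threshold=0.1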
focal¶
Binary focal loss over thresholded rainfall occurrence.
Important caveat:
- the implementation expects logits and internally applies `binary_cross_entropy_with_logits`,
- this is more appropriate for event detection than direct rainfall regression.
Example:
python main.py task=train \
model=unet/direct \
loss=focal \
loss.threshold=0.1 \
loss.alpha=0.25 \
loss.gamma=2.0
spectral¶
Frequency-domain loss to preserve structure and reduce blur.
Example:
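python main.py task=train \
model=unet/direct \
loss=spectral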
torrential¶
Tiered intensity weighting for severe rainfall events.
Example:
python main.py task=train \
model=unet/direct \
loss=torrential \
loss.thresholds=[5.0,20.0,50.0] \
loss.weights=[2.0,5.0,10.0]
hybrid_mse_ssim¶
Current config:
_target_: ainpp_pb_latam.losses.HybridLoss
weights: [1.0, 0.2]
losses:
- _target_: ainpp_pb_latam.losses.WeightedMSELoss
alpha: 2.0
threshold: 0.0
- _target_: ainpp_pb_latam.losses.SSIMLoss
window_size: 11
in_channels: 1
This is a practical default when you want:
- amplitude fidelity,
- structural coherence,
- reduced blur compared with pure pixel losses.
Example:
python main.py task=train \
model=unet/direct \
loss=hybrid_mse_ssim \
loss.weights=[1.0,0.1] \
loss.losses[0].alpha=4.0 \
loss.losses[1].window_size=7
sota¶
Current config combines:
- `AdvancedTorrentialLoss`
- `SpectralLoss`
- `PerceptualLoss`
This profile is the most ambitious option in the repository because it combines amplitude weighting, frequency structure, and image-feature realism.
Important caution:
- `PerceptualLoss` attempts to load VGG16 pretrained weights,
- if weights are unavailable, it falls back to MSE-like behavior,
- this may make runs less reproducible across environments if internet or cached weights differ.
Recommended Override Patterns¶
Change Only the Model Family¶
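python main.py task=train model=convlstm/direct
python main.py task=train model=resnet50/direct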
Change Forecast Horizon¶
python main.py task=train \
model=unet/direct \
input_timesteps=24 \
output_timesteps=12 \
dataset.dataset.input_timesteps=24 \
dataset.dataset.output_timesteps=12
Change Patch Geometry¶
python main.py task=train \
model=resnet50/direct \
dataset.dataset.patch_height=256 \
dataset.dataset.patch_width=384 \
dataset.dataset.patch_stride_h=128 \
dataset.dataset.patch_stride_w=192
Change Rain-Event Emphasis¶
python main.py task=train \
model=unet/direct \
loss=weighted_mse \
loss.alpha=8.0 \
loss.threshold=1.0
Disable Pretrained Weights¶
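python main.py task=train \
model=resnet50/direct \
+model.pretrained=false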
Increase Recurrent Capacity¶
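For example, widening the ConvLSTM stack through the root `hidden_channels` value (the widths shown are illustrative):
python main.py task=train \
model=convlstm/direct \
hidden_channels=[64,128,128]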
Practical Constraints and Caveats¶
These are important when designing experiments:
- `UNetAutoRegressive` does not interpolate root dimensions in its YAML by default.
- `AFNO2D` requires `img_size` and `patch_size` compatibility with the actual crop size.
- The scheduler block exists in config but is not currently consumed by the supervised engine.
- The current `task=train` path in `main.py` does not yet branch into `run_gan_training` automatically.
- `timm` backbones depend on the installed deep learning extras and pretrained weight availability.
- Large patch sizes and deep CNN encoders will increase memory pressure quickly on HPC jobs.
Experiment Design Suggestions¶
For a reliable benchmark progression:
- Start with `unet/direct` and `loss=hybrid_mse_ssim`.
- Tune patch size and `steps_per_epoch` until I/O and GPU memory are stable.
- Sweep rain-emphasis losses such as `weighted_mse` and `torrential`.
- Benchmark recurrent behavior with `convlstm/direct`.
- Benchmark transfer-learning encoders with `resnet50/direct`, `inceptionv4/direct`, and `xception/direct`.
- Use `afno/direct` only after aligning `model.img_size` with the actual crop geometry.
Minimal Reproducible Training Recipes¶
Strong Baseline¶
python main.py task=train \
model=unet/direct \
training=default \
loss=hybrid_mse_ssim \
dataset.train_loader.batch_size=4 \
dataset.val_loader.batch_size=4
Heavy-Rain Baseline¶
python main.py task=train \
model=unet/direct \
training=default \
loss=torrential \
loss.thresholds=[10.0,30.0,50.0] \
loss.weights=[2.0,5.0,10.0]
Recurrent Baseline¶
python main.py task=train \
model=convlstm/direct \
training=default \
loss=weighted_mse \
hidden_channels=[32,64,64]