
Mesh-Sharded Distributed Transolver-3


Scaling Transformer Solvers to Industrial-Scale Geometries (100M+ cells).

Based on the Transolver paper (ICML 2024 Spotlight) and the Transolver-3 paper.

Context

🌊 Traditional CFD solves Navier-Stokes on fine meshes using HPC clusters: a single DrivAerML car-aerodynamics run with 140M cells takes hours on hundreds of CPU cores.

🧠 Transolver replaces the iterative PDE solver with a transformer that learns the physics directly from data, predicting pressure, velocity, and other fields in a single forward pass.

🔬 Transolver-3 scales this to industrial-scale meshes (100M+ cells) through physics-aware attention in a compressed "slice domain" of only 64 slices.

🖥️ Mesh-sharded DDP distributes meshes too large for a single GPU across multiple GPUs: each GPU processes its local partition and all-reduces only the tiny slice accumulators (~514 KB/layer).

⚡ The result: 10-100× faster than classical solvers at engineering-grade accuracy.

DrivAerML pressure comparison

Key Innovations

  1. Faster Slice & Deslice: linear projections moved from the O(N) mesh domain to the O(M) slice domain via matrix-multiplication associativity
  2. Geometry Slice Tiling: input partitioned into tiles with gradient checkpointing, reducing peak memory from O(N*M) to O(N_t*M)
  3. Geometry Amortized Training: train on a random subset (100K-400K points) of the full mesh each iteration
  4. Physical State Caching: two-phase inference that builds a cache from chunks, then decodes any point
  5. Mixed Precision: full autocast + GradScaler support, halving the memory footprint
  6. Mesh-Sharded Distribution: shard meshes >100 GB across GPUs; all-reduce only the tiny slice accumulators (~514 KB/layer)
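
Innovation 1 rests on an identity worth making explicit: with soft slice weights S (N×M), point features X (N×C), and a projection W, associativity gives Sᵀ(XW) = (SᵀX)W, so the linear projection can run on M slice tokens instead of N mesh points. A minimal numpy sketch (shapes and names are illustrative, not the repo's API):

```python
import numpy as np

rng = np.random.default_rng(0)
N, M, C_in, C_out = 10_000, 64, 32, 32    # mesh points, slices, channels (illustrative)

X = rng.standard_normal((N, C_in))        # per-point features
W = rng.standard_normal((C_in, C_out))    # linear projection weights
S = rng.random((N, M))                    # soft slice-assignment weights
S /= S.sum(axis=1, keepdims=True)         # row-normalize, softmax-style

slow = S.T @ (X @ W)   # project all N points, then aggregate into M slices
fast = (S.T @ X) @ W   # aggregate into M slices first, then project only M tokens

assert np.allclose(slow, fast)            # identical result, O(M) projection cost
```

The projection cost drops from O(N·C_in·C_out) to O(M·C_in·C_out); with M = 64 and N in the millions, the projection becomes essentially free.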

Setup on Databricks

Notebook (first cell)

%pip install /Workspace/Repos/<user>/Transolver -q
dbutils.library.restartPython()

DAB deployment

databricks bundle deploy -t a10g       # 4x A10G (96 GB) β€” default
databricks bundle deploy -t a100_40    # 8x A100 40GB
databricks bundle deploy -t a100_80    # 8x A100 80GB

DAB Training Pipeline

The full pipeline runs 5 sequential tasks, each on its own cluster. MLflow is the single source of truth for model artifacts; no checkpoint files are passed between tasks.

databricks bundle deploy -t a10g
databricks bundle run transolver3_training_pipeline
| Task | Cluster | What it does |
| --- | --- | --- |
| preprocess | i3.xlarge (CPU) | Register mesh metadata + compute stats in Delta |
| train | g5.12xlarge (4x A10G) | Mesh-sharded DDP training via TorchDistributor, live MLflow metrics |
| evaluate | g5.12xlarge (4x A10G) | Load model from MLflow run, run cached inference on test set |
| register | i3.xlarge (CPU) | Promote already-logged model to UC Model Registry |
| deploy | i3.xlarge (CPU) | Create/update Model Serving endpoint with scale-to-zero |

The train task uses TorchDistributor(local_mode=True) to launch torchrun on a single multi-GPU node. Each GPU loads a disjoint 1/K shard of the mesh via mmap range reads. Gradients are all-reduced via NCCL. See SPECS/DISTRIBUTED_ARCHITECTURE.md for Mermaid diagrams of the full architecture.
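
The per-rank mmap range read can be sketched as follows; the binary file layout, helper name, and contiguous-range policy are assumptions for illustration, not the repo's actual loader:

```python
import os
import tempfile
import numpy as np

def load_shard(path, n_points, n_channels, rank, world_size, dtype=np.float32):
    """Memory-map the mesh file and materialize only this rank's 1/K point range."""
    per_rank = -(-n_points // world_size)            # ceiling division
    start = rank * per_rank
    stop = min(start + per_rank, n_points)
    mm = np.memmap(path, dtype=dtype, mode="r", shape=(n_points, n_channels))
    return np.array(mm[start:stop])                  # only these rows are read from disk

# Toy example: a 12-point, 3-channel mesh split across 4 "ranks"
path = os.path.join(tempfile.gettempdir(), "toy_mesh.bin")
np.arange(12 * 3, dtype=np.float32).reshape(12, 3).tofile(path)
shard = load_shard(path, n_points=12, n_channels=3, rank=1, world_size=4)
assert shard.shape == (3, 3) and shard[0, 0] == 9.0  # rank 1 owns points 3..5
```

Because the map is read-only and each rank slices a disjoint row range, no rank ever pages in another rank's portion of the mesh.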

Other DAB jobs

databricks bundle run gpu_memory_benchmark          # Single-GPU memory sweep
databricks bundle run distributed_sharded_test      # 2-GPU validation
databricks bundle run test_mlflow_auth              # Smoke test MLflow auth in child processes
databricks bundle run test_register_deploy          # Serverless register + deploy (fast iteration)

Claude Skills (Newcomer Guide)

Four Claude Code skills in skills/ provide step-by-step guidance for newcomers. All skills target Databricks notebooks and DABs; no local setup required.

| Skill | Purpose |
| --- | --- |
| transolver-data | Load, inspect, validate .npz meshes in UC Volumes; normalization; memory estimation |
| transolver-run | Config presets (small/medium/large), training in notebooks, 3-phase pipeline, TorchDistributor, DAB workflows |
| transolver-analyze | Loss interpretation, per-channel error stats, physical bounds checking, PSI drift detection, GPU profiling |
| transolver-deploy | MLflow tracking, UC model registration, serving endpoints, inference table monitoring, end-to-end checklist |

File Structure

transolver3/                          # Core package
├── physics_attention_v3.py           # Optimized Physics-Attention
├── transolver3_block.py              # Encoder block with tiled MLP
├── model.py                          # Transolver3 model
├── amortized_training.py             # Training (sampler, loss, scheduler, train_step)
├── inference.py                      # CachedInference + DistributedCachedInference
├── distributed.py                    # Multi-GPU mesh sharding utilities
├── normalizer.py                     # InputNormalizer, TargetNormalizer
├── profiling.py                      # Memory/latency benchmarking
├── serving.py                        # MLflow pyfunc wrapper for Model Serving
├── mlflow_utils.py                   # Experiment tracking + model logging
├── data_catalog.py                   # Delta Lake mesh metadata integration
├── databricks_training.py            # TorchDistributor launcher + Spark preprocessing
├── monitoring.py                     # Bounds checking + PSI drift detection
└── common.py                         # MLP, activations, timestep_embedding

resources/                            # DAB job definitions
├── training_workflow.yml             # 5-task pipeline (preprocess → train → evaluate → register → deploy)
├── serving_endpoint.yml              # Model Serving endpoint config
├── gpu_benchmark_job.yml             # Single-GPU memory benchmark
├── distributed_test_job.yml          # 2-GPU mesh-sharded test
├── test_mlflow_auth_job.yml          # Smoke test MLflow auth in TorchDistributor children
├── test_register_job.yml             # Serverless checkpoint inspection test
└── test_register_deploy_job.yml      # Serverless register + deploy test

scripts/                              # Entry points for DAB tasks
├── preprocess.py                     # Register mesh metadata + compute stats
├── register_model.py                 # Promote model from MLflow run to UC registry
├── deploy_endpoint.py                # Deploy to Databricks serving endpoint
├── test_mlflow_auth.py               # MLflow auth propagation smoke test
└── test_register.py                  # Checkpoint inspection + model load test

skills/                               # Claude Code skills for newcomers
├── transolver-data.md                # Mesh data management
├── transolver-run.md                 # Training & simulation
├── transolver-analyze.md             # Results analysis & drift
└── transolver-deploy.md              # Databricks deployment lifecycle

Industrial-Scale-Benchmarks/          # Experiments
├── exp_nasa_crm.py                   # NASA-CRM (~400K cells)
├── exp_ahmed_ml.py                   # AhmedML (~20M cells)
├── exp_drivaer_ml.py                 # DrivAerML (~160M cells, single GPU)
├── exp_drivaer_ml_distributed.py     # DrivAerML distributed (multi-GPU)
├── dataset/                          # Dataset loaders (with mesh sharding)
└── utils/metrics.py                  # Evaluation metrics

experiments/                          # v1 vs v3 comparison
├── compare_v1_v3_drivaer.py          # Synthetic data comparison
├── compare_v1_v3_real_drivaer.py     # Real DrivAerML VTP data comparison
├── COMPARE_v1v3.md                   # Results and analysis
└── results/                          # Pressure heatmap PNGs

benchmarks/                           # GPU benchmarking
├── gpu_memory_benchmark.py           # Sweep mesh sizes, measure all 3 phases
└── test_sharded_distributed.py       # Distributed sharding validation test

SPECS/                                # Design documentation
├── SPEC.md                           # Core v3 architecture specification
├── DISTRIBUTED.md                    # Multi-GPU distribution design
├── DISTRIBUTED_ARCHITECTURE.md       # Mermaid diagrams: pipeline, process model, data flow
├── CRITICAL_ISSUES.md                # Known issues & fixes
├── DIFFERENTIATORS.md                # Why Databricks is ideal
└── VALUEADDED.md                     # Databricks integration roadmap

tests/                                # 100 tests
├── test_transolver3.py               # Core model tests (41)
├── test_distributed.py               # Distributed sharding tests (11)
├── test_serving.py                   # Serving tests (4)
├── test_monitoring.py                # Monitoring tests (5)
├── test_data_catalog.py              # Catalog tests (6)
├── test_mlflow_utils.py              # MLflow tests (4)
└── test_databricks_training.py       # Training integration + auth propagation tests (12)

Memory Scaling

With tile_size=100K and fp16, the paper's claim of 2.9M cells on a single A100 80GB is achievable (~14 GB activations).
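
As a back-of-envelope check of that figure (the hidden width and live-tensor count below are illustrative assumptions, not values from the paper):

```python
# Rough fp16 activation estimate for the 2.9M-cell single-A100 scenario.
cells = 2_900_000      # mesh size from the claim above
hidden = 256           # assumed hidden width (illustrative)
bytes_fp16 = 2
live_tensors = 10      # assumed activation tensors kept live after checkpointing
gb = cells * hidden * bytes_fp16 * live_tensors / 1024**3
print(f"~{gb:.1f} GB of fp16 activations")  # ~13.8 GB, consistent with ~14 GB
```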

Profiling (notebook on GPU cluster)

from transolver3.profiling import benchmark_scaling, format_benchmark_table

results = benchmark_scaling(model, mesh_sizes=[1000, 10000, 100000],
    configs=[
        {'label': 'no_tiling', 'num_tiles': 0},
        {'label': 'tile_100k', 'tile_size': 100_000},
    ])
print(format_benchmark_table(results))

Multi-GPU Distribution

Transolver-3 already handles industrial-scale meshes on a single GPU via amortized subsampling and tiled attention. Multi-GPU distribution shards the mesh across GPUs to parallelize computation and reduce wall-clock time. Each GPU processes 1/K of the mesh independently; the slice accumulators s_raw (B,H,M,C) are additive and all-reduced (~514 KB/layer).
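
Because s_raw is a sum over mesh points, per-shard accumulators add up exactly to the full-mesh accumulator; that additivity is what makes the all-reduce so cheap. A numpy sketch (in practice the sum is a torch.distributed.all_reduce over NCCL; sizes here are illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
B, H, M, C, N, K = 1, 4, 64, 32, 8_000, 4    # batch, heads, slices, channels, points, GPUs

w = rng.random((B, H, N, M))                  # soft slice weights per mesh point
x = rng.standard_normal((B, H, N, C))         # per-point values
full = np.einsum('bhnm,bhnc->bhmc', w, x)     # full-mesh slice accumulator s_raw (B,H,M,C)

# Each rank accumulates over a disjoint 1/K point shard (strided here for brevity);
# summing the per-rank results is exactly what all-reduce computes.
shards = [np.einsum('bhnm,bhnc->bhmc', w[:, :, i::K], x[:, :, i::K]) for i in range(K)]
assert np.allclose(sum(shards), full)
assert full.shape == (B, H, M, C)             # only this small tensor crosses GPUs
```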

Validated on 4x NVIDIA A10G: sharded cache and decode produce zero numerical difference vs single-GPU. See SPECS/DISTRIBUTED.md for the original design.

GPU Benchmark (DAB)

Three DAB targets map to different GPU instances:

| Target | Instance | GPU | VRAM | Use case |
| --- | --- | --- | --- | --- |
| a10g (default) | g5.12xlarge | 4x NVIDIA A10G | 96 GB | Multi-GPU training, benchmarks |
| a100_40 | p4d.24xlarge | 8x NVIDIA A100 | 320 GB | Large-scale training |
| a100_80 | p4de.24xlarge | 8x NVIDIA A100 | 640 GB | Full-scale DrivAerML |
databricks bundle deploy -t a10g
databricks bundle run gpu_memory_benchmark          # Single-GPU memory sweep
databricks bundle run distributed_sharded_test      # 2-GPU validation
databricks bundle run training_workflow             # Full 5-task pipeline

The benchmark sweeps mesh sizes and measures peak GPU memory across all 3 pipeline phases (training, cache build, decode) using synthetic DrivAer ML data.

v1 vs v3 Comparison

Includes experiments comparing Transolver v1 and v3 on both synthetic and real DrivAerML data. See experiments/COMPARE_v1v3.md for full results and pressure heatmaps on the real DrivAer vehicle.

DrivAerML pressure comparison

Citation

@inproceedings{wu2024Transolver,
  title={Transolver: A Fast Transformer Solver for PDEs on General Geometries},
  author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
  booktitle={International Conference on Machine Learning},
  year={2024}
}

@article{wu2026Transolver3,
  title={Transolver++: Industrial-Scale Simulation with Transformer Solvers},
  author={Haixu Wu and Huakun Luo and Haowen Wang and Jianmin Wang and Mingsheng Long},
  journal={arXiv preprint arXiv:2602.04940},
  year={2026}
}

About

A reimplementation of the Transolver-3 paper (arXiv 2602.04940) with distributed training and inference on Databricks.
