# PyGVAMP Pipeline Roadmap
This document outlines the current state of the PyGVAMP pipeline, steps needed to complete it, and technical debt that should be addressed.
## Pipeline Overview
PyGVAMP is a refactored implementation of GraphVAMPNets for analyzing molecular dynamics (MD) trajectories using the Variational Approach for Markov Processes (VAMP). The project achieves an approximately 50x speedup over the original implementation through its PyTorch Geometric architecture.
## What VAMP/VAMPNets Does
- Input: MD trajectory files (.xtc, .dcd) + topology (.pdb)
- Process: Converts molecular structures to k-NN graphs, learns latent representations via message-passing neural networks
- Output: Discrete state assignments, transition probabilities, and slow collective variable identification
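The structure-to-graph conversion can be sketched in a few lines of NumPy. The function name and defaults below are illustrative, not the PyGVAMP API; note that the k-NN edges are directed, matching the intentionally asymmetric graphs noted under Design Decisions:

```python
import numpy as np

def frame_to_knn_graph(coords: np.ndarray, k: int = 3,
                       n_gaussians: int = 8, d_max: float = 1.0):
    """One MD frame -> directed k-NN graph with Gaussian-expanded edge features.

    Illustrative sketch only; name and defaults are not the PyGVAMP API.
    Edges are deliberately directed (asymmetric k-NN).
    """
    n = coords.shape[0]
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    np.fill_diagonal(dist, np.inf)                 # forbid self-edges
    nbrs = np.argsort(dist, axis=1)[:, :k]         # k nearest neighbours per atom
    src = np.repeat(np.arange(n), k)
    dst = nbrs.reshape(-1)
    edge_index = np.stack([src, dst])              # [2, n*k], PyG-style connectivity
    centers = np.linspace(0.0, d_max, n_gaussians) # Gaussian expansion centers
    width = centers[1] - centers[0]
    d = dist[src, dst][:, None]
    edge_attr = np.exp(-((d - centers) ** 2) / (2 * width ** 2))
    return edge_index, edge_attr                   # edge_attr: [n*k, n_gaussians]
```

Each frame then becomes one PyG `Data` object carrying `edge_index` and `edge_attr` alongside the node features.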
## Pipeline Architecture Diagram
┌─────────────────────────────────────────────────────────────────────────────────┐
│ PyGVAMP MASTER PIPELINE │
│ (master_pipeline.py) │
└────────────────────────────────────┬────────────────────────────────────────────┘
│
┌────────────────┴────────────────┐
▼ │
┌─────────────────────────────────────┐ │
│ CONFIGURATION LOADING │ │
│ ───────────────────────── │ │
│ • Presets (YAML/JSON) │ │
│ • CLI argument overrides │ │
│ • BaseConfig / SchNetConfig │ │
└─────────────────┬───────────────────┘ │
│ │
▼ │
┌═══════════════════════════════════════════════════════════════════════════════┐
║ PHASE 1: PREPARATION ║
║ (preparation.py) ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ ┌──────────────────┐ ┌──────────────────┐ ┌──────────────────────────┐ ║
║ │ Trajectory Files │ │ Topology File │ │ Configuration │ ║
║ │ (.xtc/.dcd) │ │ (.pdb) │ │ (selection, stride, │ ║
║ └────────┬─────────┘ └────────┬─────────┘ │ lag_time, n_neighbors) │ ║
║ │ │ └────────────┬─────────────┘ ║
║ └───────────┬───────────┘ │ ║
║ ▼ │ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ VAMPNetDataset │ ║
║ │ ┌─────────────────────────────────────────────────────────────────┐ │ ║
║ │ │ 1. Load trajectory frames (MDTraj) │ │ ║
║ │ │ 2. Select atoms (e.g., "name CA" for alpha carbons) │ │ ║
║ │ │ 3. Apply stride (every Nth frame) │ │ ║
║ │ │ 4. Calculate pairwise distances │ │ ║
║ │ │ 5. Build k-NN graphs for each frame │ │ ║
║ │ │ 6. Gaussian expansion of edge distances │ │ ║
║ │ │ 7. Create time-lagged pairs (t=0, t=lag) │ │ ║
║ │ └─────────────────────────────────────────────────────────────────┘ │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
║ │ ║
║ ▼ ║
║ ┌───────────────────────┐ ║
║ │ Cached Dataset │ ──► Hash-based caching for reuse ║
║ │ (PyG Data objects) │ ║
║ └───────────────────────┘ ║
╚═══════════════════════════════════════════════════════════════════════════════╝
│
▼
┌═══════════════════════════════════════════════════════════════════════════════┐
║ PHASE 2: TRAINING ║
║ (training.py) ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ GRID SEARCH OVER HYPERPARAMETERS │ ║
║ │ │ ║
║ │ for lag_time in [10, 20, 50, 100] ns: │ ║
║ │ for n_states in [3, 5, 7, 10]: │ ║
║ │ → Create experiment directory │ ║
║ │ → Train model │ ║
║ │ → Save best checkpoint │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
║ │ ║
║ ▼ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ VAMPNet MODEL │ ║
║ │ │ ║
║ │ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ ┌────────────┐ │ ║
║ │ │ Embedding │ → │ Encoder │ → │ Classifier │ → │ Softmax │ │ ║
║ │ │ (MLP) │ │ (SchNet/ │ │ (MLP) │ │ Output │ │ ║
║ │ │ optional │ │ Meta) │ │ │ │ │ │ ║
║ │ └─────────────┘ └─────────────┘ └─────────────┘ └────────────┘ │ ║
║ │ │ ║
║ │ Input: PyG Graph (node_feat, edge_index, edge_attr) │ ║
║ │ Output: State probabilities [batch_size, n_states] │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
║ │ ║
║ ▼ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ TRAINING LOOP │ ║
║ │ │ ║
║ │ for epoch in range(n_epochs): │ ║
║ │ for (x_t0, x_t1) in train_loader: # time-lagged pairs │ ║
║ │ χ_t0 = model(x_t0) # state probs at t=0 │ ║
║ │ χ_t1 = model(x_t1) # state probs at t=lag │ ║
║ │ loss = -VAMP_score(χ_t0, χ_t1) # maximize VAMP score │ ║
║ │ loss.backward() │ ║
║ │ optimizer.step() │ ║
║ │ │ ║
║ │ validate(test_loader) │ ║
║ │ early_stopping_check() │ ║
║ │ save_checkpoint_if_best() │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
║ │ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ VAMP SCORE CALCULATION │ ║
║ │ │ ║
║ │ C₀₀ = E[χ(t)ᵀ χ(t)] # instantaneous covariance │ ║
║ │ C₀ₜ = E[χ(t)ᵀ χ(t+τ)] # cross-covariance │ ║
║ │ Cₜₜ = E[χ(t+τ)ᵀ χ(t+τ)] # lagged covariance │ ║
║ │ │ ║
║ │ VAMP-2 = ||C₀₀^(-1/2) C₀ₜ Cₜₜ^(-1/2)||²_F + 1 │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
║ │ ║
║ ▼ ║
║ ┌───────────────────────┐ ║
║ │ Trained Models │ ║
║ │ (best_model.pt) │ ║
║ └───────────────────────┘ ║
╚═══════════════════════════════════════════════════════════════════════════════╝
│
▼
┌═══════════════════════════════════════════════════════════════════════════════┐
║ PHASE 3: ANALYSIS ║
║ (analysis.py) ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ POST-TRAINING ANALYSIS │ ║
║ │ │ ║
║ │ 1. STATE PROBABILITIES │ ║
║ │ • Run inference on all frames │ ║
║ │ • Assign each frame to most probable state │ ║
║ │ │ ║
║ │ 2. TRANSITION ANALYSIS │ ║
║ │ • Calculate transition probability matrices │ ║
║ │ • Plot transition networks (state graph) │ ║
║ │ │ ║
║ │ 3. ATTENTION MAPS │ ║
║ │ • Extract encoder attention weights │ ║
║ │ • Compute per-state residue-residue attention │ ║
║ │ • Identify key interactions per state │ ║
║ │ │ ║
║ │ 4. STATE STRUCTURES │ ║
║ │ • Extract representative PDB structures per state │ ║
║ │ • Generate ensemble PDBs │ ║
║ │ • Color by attention values │ ║
║ │ │ ║
║ │ 5. IMPLIED TIMESCALES (ITS) │ ║
║ │ • Calculate: ITS = -τ / ln|λ| │ ║
║ │ • Plot ITS vs lag time │ ║
║ │ • Identify slow processes │ ║
║ │ │ ║
║ │ 6. CHAPMAN-KOLMOGOROV TESTS │ ║
║ │ • Validate Markov assumption │ ║
║ │ • Compare predicted vs observed transitions │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
║ │ ║
║ ▼ ║
║ ┌─────────────────────────────────────────────────────────────────────────┐ ║
║ │ OUTPUT ARTIFACTS │ ║
║ │ │ ║
║ │ analysis/ │ ║
║ │ ├── transition_matrix.png │ ║
║ │ ├── state_network.png │ ║
║ │ ├── attention_maps/ │ ║
║ │ │ ├── state_0_attention.png │ ║
║ │ │ ├── state_1_attention.png │ ║
║ │ │ └── ... │ ║
║ │ ├── structures/ │ ║
║ │ │ ├── state_0_ensemble.pdb │ ║
║ │ │ └── ... │ ║
║ │ ├── its_plot.png │ ║
║ │ └── ck_test.png │ ║
║ └─────────────────────────────────────────────────────────────────────────┘ ║
╚═══════════════════════════════════════════════════════════════════════════════╝
│
▼
┌────────────────────────────────┐
│ pipeline_summary.json │
│ ──────────────────────────── │
│ • Configuration used │
│ • Models trained │
│ • Analysis completed │
│ • Timestamp │
└────────────────────────────────┘
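The VAMP-2 score in the Phase 2 box can be written out directly from the covariance formulas. The sketch below (NumPy, mean-free estimators, eigenvalue flooring for the inverse square roots) is illustrative, not the `vamp_score_v0.py` implementation:

```python
import numpy as np

def vamp2_score(chi_t0: np.ndarray, chi_t1: np.ndarray, eps: float = 1e-10) -> float:
    """VAMP-2 score from state probabilities at t and t+lag.

    Mirrors the covariance formulas in the diagram; eps floors the
    eigenvalues when forming inverse square roots (illustrative sketch).
    """
    n = chi_t0.shape[0]
    chi_t0 = chi_t0 - chi_t0.mean(axis=0)          # mean-free features
    chi_t1 = chi_t1 - chi_t1.mean(axis=0)
    c00 = chi_t0.T @ chi_t0 / n                    # C00: instantaneous covariance
    c0t = chi_t0.T @ chi_t1 / n                    # C0t: cross-covariance
    ctt = chi_t1.T @ chi_t1 / n                    # Ctt: lagged covariance

    def inv_sqrt(c: np.ndarray) -> np.ndarray:
        w, v = np.linalg.eigh(c)                   # symmetric eigendecomposition
        return v @ np.diag(np.maximum(w, eps) ** -0.5) @ v.T

    k = inv_sqrt(c00) @ c0t @ inv_sqrt(ctt)        # whitened Koopman matrix
    return float(np.sum(k ** 2) + 1.0)             # ||K||_F^2 + 1 (the +1 is the constant process)
```

Training maximizes this quantity, so the loss is its negative; the score is bounded above by `n_states`, which is reached when the states interconvert deterministically over the lag time.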
## Current Implementation Status
| Component | Status | Notes |
|---|---|---|
| VAMPNetDataset | ✅ Complete | MD→PyG graphs with caching |
| VAMPNet Model | ✅ Complete | Embedding + Encoder + Classifier |
| SchNet Encoder | ✅ Complete | With attention mechanism |
| Meta Encoder | ⚠️ Partial | Config WIP, encoder exists |
| ML3 Encoder | ⚠️ Separate | Working independently, needs pipeline integration |
| VAMP Score | ✅ Complete | VAMP1, VAMP2, VAMPE modes |
| SoftmaxMLP Classifier | ✅ Complete | |
| Master Pipeline | ✅ Complete | 3-phase orchestration |
| Preparation Phase | ✅ Complete | Dataset creation |
| Training Phase | ✅ Complete | Multi-lag, multi-state grid |
| Analysis Phase | ✅ Complete | All visualizations |
| Caching | ✅ Complete | Hash-based dataset caching |
| ITS Analysis | ✅ Complete | Implied timescales |
| CK Tests | ✅ Complete | Chapman-Kolmogorov |
| Visualization | ✅ Complete | 2133 lines in plotting.py |
| Documentation | ⚠️ Minimal | Sphinx setup exists |
| Tests | ⚠️ Incomplete | Tests exist but not in CI |
## Steps to Complete the Pipeline
### Phase 1: Critical Path (Required for Publication)
#### 1.1 Validate End-to-End Pipeline Execution
- [ ] Run complete pipeline on test dataset
- [ ] Verify all phases execute without errors
- [ ] Confirm all expected outputs are generated
- [ ] Test with multiple trajectory formats (.xtc, .dcd)
#### 1.2 Fix Known Issues
- [ ] Address the hardcoded `model.to('cuda')` in `training.py:268` (should use the device variable)
- [ ] Remove the unused `from pymol.querying import distance` import in `training.py:10`
- [ ] Handle the case where the encoder is `None` for the ML3 type in `create_model()`
- [ ] Add proper error handling when model loading fails in analysis
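The two device items reduce to computing the device once and threading it through every `.to()` call. A sketch of the intended logic, kept free of real `torch` calls:

```python
def resolve_device(cpu_flag: bool, cuda_available: bool) -> str:
    """Decide the compute device once, before the model is built.

    Sketch of the suggested fix: the caller would pass the --cpu flag and
    torch.cuda.is_available(), then reuse the result for every .to() call.
    """
    if cpu_flag or not cuda_available:
        return "cpu"
    return "cuda"
```

With this, `model.to(resolve_device(...))` replaces the hardcoded `model.to('cuda')`.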
#### 1.3 Verify VAMP Score Calculation
- [ ] Validate covariance matrix calculations against reference implementation
- [ ] Test eigenvalue truncation/regularization strategies
- [ ] Confirm score maximization (higher = better) throughout codebase
#### 1.4 Test Multi-Lag Pipeline
- [ ] Verify lag time compatibility check with trajectory timestep
- [ ] Test stride + lag_time combination edge cases
- [ ] Validate that different lag times produce different models
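The stride/lag compatibility check could be centralized in a small config object that fails fast at construction time. `LagConfig` below is a hypothetical helper, not part of the current `BaseConfig`:

```python
from dataclasses import dataclass

@dataclass
class LagConfig:
    """Hypothetical config fragment for the lag/stride compatibility check."""
    timestep_ns: float   # spacing between trajectory frames, in ns
    stride: int          # keep every Nth frame
    lag_time_ns: float   # VAMP lag time, in ns

    def __post_init__(self):
        if self.stride < 1:
            raise ValueError("stride must be >= 1")
        effective_dt = self.timestep_ns * self.stride
        steps = self.lag_time_ns / effective_dt
        if abs(steps - round(steps)) > 1e-9:
            raise ValueError(
                f"lag_time {self.lag_time_ns} ns is not a multiple of the "
                f"strided timestep {effective_dt} ns")

    @property
    def lag_frames(self) -> int:
        """Lag expressed in strided-frame units."""
        return round(self.lag_time_ns / (self.timestep_ns * self.stride))
```

This catches the stride + lag_time edge cases at configuration time rather than deep inside dataset creation.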
### Phase 2: Robustness & Quality
#### 2.1 Error Handling
- [ ] Add input validation for trajectory files (existence, format)
- [ ] Add configuration validation (compatible parameters)
- [ ] Improve error messages with actionable guidance
- [ ] Add graceful failure modes with partial results saved
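A minimal sketch of the input-validation item, assuming a hypothetical `validate_trajectory_inputs` helper that collects every problem before raising, so the error message is actionable in one pass:

```python
from pathlib import Path

SUPPORTED_FORMATS = {".xtc", ".dcd"}

def validate_trajectory_inputs(traj_paths, top_path):
    """Collect all input problems, then raise once with actionable guidance.

    Hypothetical helper; the names are illustrative, not the current code.
    """
    errors = []
    for p in map(Path, traj_paths):
        if not p.is_file():
            errors.append(f"trajectory not found: {p}")
        elif p.suffix.lower() not in SUPPORTED_FORMATS:
            errors.append(f"unsupported trajectory format {p.suffix!r}: {p} "
                          f"(expected one of {sorted(SUPPORTED_FORMATS)})")
    top = Path(top_path)
    if not top.is_file():
        errors.append(f"topology not found: {top}")
    elif top.suffix.lower() != ".pdb":
        errors.append(f"topology must be a .pdb file, got {top.suffix!r}")
    if errors:
        raise ValueError("input validation failed:\n  " + "\n  ".join(errors))
```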
#### 2.2 Gradient Flow & Numerical Stability
- [ ] Review and document gradient workarounds (jitter, skip connections)
- [ ] Add gradient monitoring/logging during training
- [ ] Investigate and fix root causes of NaN outputs
- [ ] Add automatic learning rate adjustment on gradient explosion
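The automatic learning-rate adjustment could be as simple as the sketch below; the threshold and decay factor are illustrative defaults, and the caller would feed in the global gradient norm each step:

```python
import math

def adjust_lr_on_explosion(lr: float, grad_norm: float,
                           threshold: float = 1e3,
                           factor: float = 0.5,
                           min_lr: float = 1e-6) -> float:
    """Shrink the learning rate when the global gradient norm explodes or
    goes non-finite (NaN/inf). Sketch; defaults are illustrative."""
    if not math.isfinite(grad_norm) or grad_norm > threshold:
        return max(lr * factor, min_lr)
    return lr
```

The same hook is a natural place to log gradient norms, which also serves the monitoring item above.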
#### 2.3 Testing
- [ ] Create integration tests for full pipeline
- [ ] Add unit tests for VAMP score calculation
- [ ] Add unit tests for dataset creation
- [ ] Clean up existing tests in `testdata/` and `area51_testing_grounds/`
- [ ] Set up CI/CD with automated testing
### Phase 3: Feature Completion
#### 3.1 Complete Encoder Options
- [ ] Integrate ML3 encoder into pipeline (working separately, needs integration)
- [ ] Complete Meta encoder configuration (MetaConfig)
- [ ] Add encoder selection validation
#### 3.2 Trajectory Handling
- [ ] Add a `continuous` flag for trajectory loading (currently all trajectories are concatenated)
- [ ] Implement non-continuous trajectory support (treat each file as a separate trajectory)
- [ ] Add trajectory boundary detection and handling
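Non-continuous trajectory support mostly comes down to never forming a time-lagged pair across a file boundary. A sketch, assuming frame indices into the concatenated dataset:

```python
def lagged_pair_indices(traj_lengths, lag):
    """(t, t+lag) index pairs over concatenated trajectories that never
    cross a file boundary (sketch for the continuous=False case)."""
    pairs, offset = [], 0
    for length in traj_lengths:
        pairs.extend((offset + t, offset + t + lag) for t in range(length - lag))
        offset += length                    # next file starts after this one
    return pairs
```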
#### 3.3 State Count Selection
- [ ] Implement automatic n_states selection method
- [ ] Ensure comparable state counts across different lag times
- [ ] Add state count validation/recommendation system
#### 3.4 Enhanced Analysis
- [ ] Add cross-validation for model selection
- [ ] Add ensemble model support
#### 3.5 Visualization Improvements
- [ ] Add interactive plots (Plotly/Bokeh option)
- [ ] Add comparison plots across lag times
- [ ] Add trajectory projection onto learned states
- [ ] Add video/animation generation for state transitions
### Phase 4: Documentation & Usability
#### 4.1 Documentation
- [ ] Complete Sphinx documentation
- [ ] Add tutorial notebooks (Jupyter)
- [ ] Add example configurations for common proteins
- [ ] Document hyperparameter tuning guidelines
#### 4.2 Report Generation
- [ ] Integrate HTML report generation (code exists, needs integration)
- [ ] Combine all analysis outputs into single shareable HTML file
- [ ] Add interactive elements to HTML report (collapsible sections, tooltips)
#### 4.3 CLI Improvements
- [ ] Add a `--dry-run` option to preview the pipeline
- [ ] Add a `--resume` option to continue failed runs
- [ ] Add progress persistence for long-running jobs
- [ ] Add a `--validate-only` mode for configuration checking
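`--resume` needs some form of progress persistence; one minimal sketch is a JSON file of completed phases (the helper names and file layout below are hypothetical, not the current CLI):

```python
import json
from pathlib import Path

def save_progress(path, completed_steps):
    """Persist the set of completed pipeline phases (hypothetical layout)."""
    Path(path).write_text(json.dumps({"completed": sorted(completed_steps)}))

def load_progress(path):
    """Return previously completed phases, or an empty set on a fresh run."""
    p = Path(path)
    if not p.exists():
        return set()
    return set(json.loads(p.read_text())["completed"])
```

On resume, the pipeline would skip any phase found in the loaded set.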
## Technical Debt
### High Priority (Blocking Issues)
| Issue | Location | Description | Suggested Fix |
|---|---|---|---|
| Hardcoded CUDA | `training.py:268` | `model.to('cuda')` ignores `--cpu` flag | Use device variable from args |
| Unused import | `training.py:10` | `from pymol.querying import distance` never used | Remove import |
| `None` encoder | `training.py:194` | ML3 encoder returns `None`, causes crash | Import and instantiate `GNNML3` from `pygv/encoder/ml3.py` |
| Device inconsistency | `training.py:268,276` | Model moved to CUDA before device is determined | Move `model.to(device)` after device determination |
| Broken imports | `pygv/config/__init__.py:4` | Imports `MetaConfig`, `ML3Config`, which are commented out in `base_config.py` | Either uncomment configs or remove from imports |
| Missing preset files | `pygv/config/presets/` | `medium.py` and `large.py` imported but don't exist | Create these files or update imports |
### Medium Priority (Code Quality)
| Issue | Location | Description | Suggested Fix |
|---|---|---|---|
| ~~Dual dataset files~~ | ~~`vampnet_dataset.py`, `vampnet_dataset_with_AA.py`~~ | ~~Two nearly identical files~~ | ✅ Consolidated: old files moved to `legacy/`, unified `vampnet_dataset.py` with `use_amino_acid_encoding` flag |
| One-hot node features | `vampnet_dataset.py` | Creates an N×N matrix for N atoms (memory inefficient) | Learned embeddings exist but had issues; document tradeoffs |
| Magic numbers | Various | Epsilon values, thresholds hardcoded | Move to configuration |
| NaN masking | `vampnet.py` | NaN outputs replaced with zeros | Fix root cause of NaN generation |
| Commented code | `analysis.py:266-272` | TODO comment with dead code | Remove or implement |
| Batch size sensitivity | `vamp_score_v0.py` | VAMP score varies with batch composition | Document limitation or implement batch-invariant version |
### Low Priority (Improvements)
| Issue | Location | Description | Suggested Fix |
|---|---|---|---|
| WIP configs | `base_config.py:135-169` | `MetaConfig` and `ML3Config` commented out | Complete implementations |
| Empty viz module | `viz/` | Directory exists but is empty | Either populate or remove |
| Inconsistent naming | Various | Mix of `n_states` and `n_states_list` | Standardize parameter names |
| Redundant analysis | `training.py` | Pre-analysis duplicates analysis phase | Consider consolidating (see details below) |
| Type hints | Various | Incomplete type annotations | Add comprehensive type hints |
## Pre-Analysis Duplication Details
Location of duplication:

In `training.py:373-438` (after training completes):

- Lines 375-382: `analyze_vampnet_outputs()`
- Lines 418-426: `run_ck_analysis()`
- Lines 430-437: `analyze_implied_timescales()`

In `analysis.py:212-326` (analysis phase):

- Lines 212-218: `analyze_vampnet_outputs()`
- Lines 223-228: `plot_transition_probabilities()`
- Plus all attention/structure analysis (lines 230-326)
**Decision needed:** Should `training.py` perform only quick validation (CK/ITS) while `analysis.py` produces the comprehensive output, or should all analysis be consolidated in a single location?
## Planned Features (Not Yet Implemented)
| Feature | Description | Priority |
|---|---|---|
| ~~Non-continuous trajectories~~ | ~~Add `continuous` flag to handle trajectories that aren't continuous.~~ | ✅ Done: `continuous` parameter in `vampnet_dataset.py` and `BaseConfig` |
| Automatic n_states selection | Find correct/comparable number of states for each timescale | High |
| Comparable state counts | Ensure consistent state definitions across different lag times | High |
| ~~Dataset encoding flag~~ | ~~Clean up dual dataset system (one-hot vs amino acid encoding).~~ | ✅ Done: unified `vampnet_dataset.py` with `use_amino_acid_encoding` parameter |
| Complete preset system | Add missing preset files (`medium.py`, `large.py`) and uncomment `MetaConfig`/`ML3Config` in `base_config.py`. Currently `__init__.py` imports them but they're commented out (will cause `ImportError`). | Medium |
| ML3 pipeline integration | Integrate the working ML3 encoder (`pygv/encoder/ml3.py`, `GNNML3` class) into the training pipeline. Currently `training.py:194` returns `None` for ML3. | Medium |
| HTML report generation | Combine all analysis outputs into a single HTML file for sharing | Medium |
| Unit test cleanup | Existing tests need cleanup and CI integration | Medium |
## Architecture Considerations
| Issue | Description | Suggested Approach |
|---|---|---|
| Encoder abstraction | No common base class for encoders | Create a `BaseEncoder` abstract class |
| Config validation | No schema validation for configs | Add Pydantic or dataclass validation |
| Logging | Print statements instead of proper logging | Replace with the `logging` module |
| Progress tracking | Manual print statements | Use unified progress tracking |
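For the logging item, a minimal sketch of a shared module-level logger that could replace the scattered print statements:

```python
import logging

def get_logger(name: str = "pygvamp") -> logging.Logger:
    """Shared logger to replace scattered print statements (sketch)."""
    logger = logging.getLogger(name)
    if not logger.handlers:                     # configure once, reuse everywhere
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter(
            "%(asctime)s %(name)s %(levelname)s %(message)s"))
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)
    return logger
```

Because `logging.getLogger` caches by name, every module calling `get_logger()` shares one configured logger.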
## Recommended Priority Order
- Week 1: Fix high-priority technical debt (especially CUDA hardcoding)
- Week 2: Run and validate end-to-end pipeline on test data
- Week 3: Add critical error handling and input validation
- Week 4: Complete ML3/Meta encoder implementations
- Ongoing: Add tests incrementally as issues are fixed
## Design Decisions (Resolved)
These questions have been clarified:
| Question | Decision |
|---|---|
| Graph Bidirectionality | Asymmetric k-NN graphs are intentional - edges should NOT automatically be bidirectional |
| Node Features | One-hot encoding is the starting point; learned embeddings had issues previously |
| Multiple Lag Times | Always train the full grid - behavior on different timescales is important |
| Encoder Priority | ML3 has priority (already working independently, needs pipeline integration) |
## Open Questions
- **Validation Strategy**: Should the CK test and ITS analysis run during training (current behavior) or only in the analysis phase? (See Pre-Analysis Duplication Details above.)
- **Memory Constraints**: For large trajectories, should the pipeline support streaming/chunked processing?
## `n_states` Selection Methods (Suggestions to Investigate)
Since graph2vec didn't work well for determining optimal state count, here are alternative approaches to consider:
| Method | Description | Pros | Cons |
|---|---|---|---|
| VAMP-2 Score Comparison | Train models with different n_states, compare validation VAMP-2 scores. Optimal n_states maximizes score without overfitting. | Direct optimization target; theoretically grounded | Requires training multiple models |
| Eigenvalue Gap Analysis | Analyze VAMP eigenvalues: large gap after k eigenvalues suggests k slow processes (states) | Fast; no retraining needed; MSM-standard approach | May not align with interpretable states |
| ITS Plateau Count | Count number of implied timescales that plateau (converge) across lag times | Physically meaningful; validated approach | Requires multiple lag time runs |
| Silhouette Score | Cluster state assignments, compute silhouette score for different k values | Measures cluster quality | May not capture kinetic relevance |
| Cross-Validation | Hold out trajectories, measure state assignment consistency | Tests generalization | Computationally expensive |
| Bayesian Model Selection | Use BIC/AIC to penalize model complexity | Automatic complexity penalty | Requires likelihood formulation |
Recommended approach: Start with eigenvalue gap analysis (fast, standard in MSM literature) combined with ITS plateau counting (physically interpretable). Use VAMP-2 score comparison for final validation.
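The eigenvalue-gap heuristic can be prototyped in a few lines. `states_from_spectral_gap` below is a sketch that assumes the supplied spectrum includes the trivial constant singular value (≈1):

```python
import numpy as np

def states_from_spectral_gap(singular_values) -> int:
    """Suggest n_states as the position of the largest relative gap in the
    VAMP singular value spectrum (heuristic sketch; assumes the spectrum
    includes the trivial constant singular value of ~1)."""
    sv = np.sort(np.asarray(singular_values, dtype=float))[::-1]
    ratios = sv[:-1] / np.maximum(sv[1:], 1e-12)   # gap between consecutive values
    return int(np.argmax(ratios)) + 1              # keep everything above the gap
```

For example, a spectrum like `[1.0, 0.92, 0.85, 0.10, 0.04]` has its largest relative drop after the third value, suggesting three slow processes and hence three states; the result would still be cross-checked against ITS plateaus and validation VAMP-2 scores as recommended above.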
## File Reference
Key files for understanding the pipeline:
| File | Lines | Purpose |
|---|---|---|
| `pygv/pipe/master_pipeline.py` | 431 | Pipeline orchestration |
| `pygv/pipe/preparation.py` | 197 | Data preparation |
| `pygv/pipe/training.py` | 449 | Model training |
| `pygv/pipe/analysis.py` | 352 | Post-training analysis |
| `pygv/dataset/vampnet_dataset.py` | 810 | MD trajectory → PyG graphs (unified with AA encoding & `continuous` flag) |
| `pygv/vampnet/vampnet.py` | 1184 | VAMPNet model |
| `pygv/scores/vamp_score_v0.py` | 341 | VAMP loss calculation |
| `pygv/encoder/schnet_wo_embed_v2.py` | 300+ | SchNet encoder |
| `pygv/encoder/ml3.py` | 414 | ML3/GNNML3 encoder (working, needs integration) |
| `pygv/encoder/meta_att.py` | - | Meta encoder |
| `pygv/config/base_config.py` | 169 | Configuration classes |
| `pygv/config/__init__.py` | 99 | Config registry and presets |
| `pygv/config/presets/small.py` | 45 | Small preset configurations |
| `pygv/utils/plotting.py` | 2133 | Visualization utilities |
| `run_pipeline.py` | 12 | Entry point |
Generated: 2025-12-31