diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..e8f30b7 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,33 @@ +name: Tests + +on: + pull_request: + push: + branches: [main] + workflow_dispatch: + +jobs: + test-suite: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install deps + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Run unit, integration, and e2e tests + run: pytest -m "unit or integration or e2e" --junitxml=pytest-all.xml + + - name: Upload test + coverage artifacts + uses: actions/upload-artifact@v4 + with: + name: test-suite-reports + path: | + pytest-all.xml + coverage.xml diff --git a/.gitignore b/.gitignore index 3ed32e6..63809da 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,35 @@ +# Python cache and local environments __pycache__/ -.ipynb_checkpoints +.ipynb_checkpoints/ .env .envrc .venv env/ venv/ ENV/ + +# Editor settings .vscode/ + +# Local build and smoke-test output smoketest/ .build_pyz/ +build/ +dist/ +pip-wheel-metadata/ + +# Test runner and coverage artifacts +.pytest_cache/ +.pytest_tmp*/ +.coverage +.coverage.* +coverage.xml +htmlcov/ +pytest-*.xml + +# Project-generated local data/artifacts +*.egg-info/ +.eggs/ *.fasta *.csv *.xlsx @@ -18,4 +39,5 @@ smoketest/ *.txt *.pyz *.png -*.metadata \ No newline at end of file +*.metadata +*.json diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..3f29a2c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,65 @@ +[build-system] +requires = ["setuptools>=69", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "pepseqpred" +version = "1.0.0rc1" +description = "Residue-level epitope prediction pipeline for peptide/protein workflows." +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "numpy>=2.3,<3", + "pandas>=2.3,<3", + "torch>=2.4,<3", + "fair-esm==2.0.0", + "scikit-learn>=1.5,<2", + "optuna>=3.5,<5" +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-cov>=5.0", + "pytest-mock>=3.14", + "ruff>=0.6" +] + +[tool.pytest.ini_options] +minversion = "8.0" +addopts = "-ra --strict-markers --cov=pepseqpred --cov-report=term-missing --cov-report=xml --cov-fail-under=75" +testpaths = ["tests"] +pythonpath = ["src"] +markers = [ + "unit: fast isolated tests", + "integration: component interaction tests", + "e2e: end-to-end pipeline tests", + "slow: longer-running tests" +] + +[tool.coverage.run] +branch = true +source = ["pepseqpred"] +omit = [ + "tests/*" +] + +[tool.coverage.report] +show_missing = true +skip_empty = true +precision = 2 + +[project.scripts] +pepseqpred-esm = "pepseqpred.apps.esm_cli:main" +pepseqpred-labels = "pepseqpred.apps.labels_cli:main" +pepseqpred-predict = "pepseqpred.apps.prediction_cli:main" +pepseqpred-preprocess = "pepseqpred.apps.preprocess_cli:main" +pepseqpred-train-ffnn = "pepseqpred.apps.train_ffnn_cli:main" +pepseqpred-train-ffnn-optuna = "pepseqpred.apps.train_ffnn_optuna_cli:main" + +[tool.setuptools] +package-dir = {"" = "src"} + +[tool.setuptools.packages.find] +where = ["src"] +include = ["pepseqpred*"] diff --git a/requirements.txt b/requirements.txt index dc1e4f2..89569ec 100644 Binary files a/requirements.txt and b/requirements.txt differ diff --git a/src/pepseqpred/apps/prediction_cli.py b/src/pepseqpred/apps/prediction_cli.py index 19ba751..68ad75f 100644 --- a/src/pepseqpred/apps/prediction_cli.py +++ b/src/pepseqpred/apps/prediction_cli.py @@ -213,7 +213,8 @@ def main() -> None: layer = esm_model.num_layers # load model from disk - checkpoint = torch.load(args.checkpoint, map_location="cpu") + checkpoint = torch.load( + args.checkpoint, map_location="cpu", weights_only=True) cli_model_cfg = _build_cli_model_config(args) psp_model, model_cfg, model_cfg_src = build_model_from_checkpoint( checkpoint, diff --git a/src/pepseqpred/apps/preprocess_cli.py b/src/pepseqpred/apps/preprocess_cli.py index 3dbc415..ff9bf01 100644 --- a/src/pepseqpred/apps/preprocess_cli.py +++ b/src/pepseqpred/apps/preprocess_cli.py @@ -75,7 +75,7 @@ def main() -> None: help="Prefix for subject column labels in z-score reactivity data.") parser.add_argument("--save", action="store_true", - dest="save_path", + dest="save", default=False, help="Store results in a .tsv output file to be used in model training.") diff --git a/src/pepseqpred/core/labels/builder.py b/src/pepseqpred/core/labels/builder.py index 0fb498c..1aabcb6 100644 --- a/src/pepseqpred/core/labels/builder.py +++ b/src/pepseqpred/core/labels/builder.py @@ -187,7 +187,7 @@ def _find_pt_path(self, protein_id: str) -> Path: def _load_embedding_length(self, protein_id: str) -> int: """Finds .pt path, loads embedding as tensor, and returns the length (number of amino acids).""" pt_path = self._find_pt_path(protein_id) - embedding = torch.load(pt_path, map_location="cpu") + embedding = torch.load(pt_path, map_location="cpu", weights_only=True) if not isinstance(embedding, torch.Tensor) or embedding.dim() != 2: raise ValueError( f"Expected 2D tensor embedding for '{protein_id}', got {type(embedding)}") diff --git a/src/pepseqpred/core/train/metrics.py b/src/pepseqpred/core/train/metrics.py index e6d1c29..03723aa 100644 --- a/src/pepseqpred/core/train/metrics.py +++ b/src/pepseqpred/core/train/metrics.py @@ -6,7 +6,8 @@ from labels, predictions, and probabilities. """ -from typing import Dict, Any +from typing import Dict, Any, Union, Sequence +import numpy as np import torch from sklearn.metrics import (precision_recall_fscore_support, average_precision_score, @@ -16,7 +17,16 @@ auc) -def compute_eval_metrics(y_true: torch.Tensor, y_pred: torch.Tensor, y_prob: torch.Tensor) -> Dict[str, Any]: +ArrayLike1D = Union[torch.Tensor, np.ndarray, Sequence[float], Sequence[int]] + + +def _to_numpy_1d(x: ArrayLike1D) -> np.ndarray: + if isinstance(x, torch.Tensor): + return x.detach().cpu().numpy().reshape(-1) + return np.asarray(x).reshape(-1) + + +def compute_eval_metrics(y_true: ArrayLike1D, y_pred: ArrayLike1D, y_prob: ArrayLike1D) -> Dict[str, Any]: """ Computes evaluation metrics given true lables, predicted labels, and predicted probabilities. @@ -36,31 +46,48 @@ def compute_eval_metrics(y_true: torch.Tensor, y_pred: torch.Tensor, y_prob: tor """ metrics: Dict[str, Any] = {} - # calculate precesion, recall, f1, and mcc + y_true_np = _to_numpy_1d(y_true).astype(np.int64, copy=False) + y_pred_np = _to_numpy_1d(y_pred).astype(np.int64, copy=False) + y_prob_np = _to_numpy_1d(y_prob).astype(np.float64, copy=False) + + # calculate precision, recall, and f1 precision, recall, f1, _ = precision_recall_fscore_support( - y_true, y_pred, average="binary", zero_division=0) + y_true_np, y_pred_np, average="binary", zero_division=0) metrics["precision"] = float(precision) metrics["recall"] = float(recall) metrics["f1"] = float(f1) - metrics["mcc"] = matthews_corrcoef(y_true, y_pred) + + # Avoid sklearn warning when both tensors contain only one shared label. + if np.unique(np.concatenate((y_true_np, y_pred_np))).size < 2: + metrics["mcc"] = 0.0 + else: + metrics["mcc"] = float(matthews_corrcoef(y_true_np, y_pred_np)) + + has_both_classes = np.unique(y_true_np).size >= 2 + if not has_both_classes: + only_class = int(y_true_np[0]) if y_true_np.size > 0 else 0 + metrics["auc"] = float("nan") + metrics["pr_auc"] = 1.0 if only_class == 1 else 0.0 + metrics["auc10"] = float("nan") + return metrics # ROC AUC try: - metrics["auc"] = float(roc_auc_score(y_true, y_prob)) + metrics["auc"] = float(roc_auc_score(y_true_np, y_prob_np)) except Exception: metrics["auc"] = float("nan") # PR AUC try: - metrics["pr_auc"] = float(average_precision_score(y_true, y_prob)) + metrics["pr_auc"] = float(average_precision_score(y_true_np, y_prob_np)) except Exception: metrics["pr_auc"] = float("nan") # AUC10 calculation] try: - fpr, tpr, _ = roc_curve(y_true, y_prob) + fpr, tpr, _ = roc_curve(y_true_np, y_prob_np) mask = fpr <= 0.10 if mask.sum() >= 2: metrics["auc10"] = float(auc(fpr[mask], tpr[mask]) / 0.10) diff --git a/src/pepseqpred/core/train/weights.py b/src/pepseqpred/core/train/weights.py index 30d2924..167511c 100644 --- a/src/pepseqpred/core/train/weights.py +++ b/src/pepseqpred/core/train/weights.py @@ -100,5 +100,15 @@ def pos_weight_from_label_shards(label_shards: List[Path]) -> float: f"{shard} missing class_stats (rebuild labels with --calc-pos-weight)" ) total_pos += int(stats["pos_count"]) - total_neg += int(stats["neg_counts"]) + # Prefer the canonical key written by labels.builder, but support + # legacy/pluralized payloads for backwards compatibility. + if "neg_count" in stats: + total_neg += int(stats["neg_count"]) + elif "neg_counts" in stats: + total_neg += int(stats["neg_counts"]) + else: + raise ValueError( + f"{shard} class_stats missing negative count key " + "(expected 'neg_count' or 'neg_counts')" + ) return float(total_neg / max(1, total_pos)) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..13a7078 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,36 @@ +from pathlib import Path +import pytest +import torch + + +@pytest.fixture +def training_artifacts(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir(parents=True, exist_ok=True) + label_shard = tmp_path / "labels_000.pt" + + labels = {} + pos = 0 + neg = 0 + + for protein_id, family in [ + ("P001", "111"), ("P002", "111"), ("P003", "222"), ("P004", "222") + ]: + x = torch.randn(6, 4, dtype=torch.float32) + torch.save(x, emb_dir / f"{protein_id}-{family}.pt") + + y = torch.tensor([1, 0, 0, 1, 0, 0], dtype=torch.float32) + labels[protein_id] = y + pos += int((y == 1).sum().item()) + neg += int((y == 0).sum().item()) + + payload = { + "labels": labels, + "class_stats": { + "pos_count": pos, + "neg_count": neg + } + } + torch.save(payload, label_shard) + + return {"embedding_dir": emb_dir, "label_shard": label_shard} diff --git a/tests/e2e/test_train_to_predict_e2e.py b/tests/e2e/test_train_to_predict_e2e.py new file mode 100644 index 0000000..87806c3 --- /dev/null +++ b/tests/e2e/test_train_to_predict_e2e.py @@ -0,0 +1,137 @@ +import os +from pathlib import Path +import subprocess +import sys +import types +import pytest +import torch +import pepseqpred.apps.prediction_cli as prediction_cli + +pytestmark = [pytest.mark.e2e, pytest.mark.slow] + + +class FakeAlphabet: + def get_batch_converter(self): + def _batch_converter(pairs): + labels = [name for name, _seq in pairs] + seqs = [seq for _name, seq in pairs] + max_len = max((len(seq) for seq in seqs), default=0) + tokens = torch.zeros((len(seqs), max_len + 2), dtype=torch.long) + for i, seq in enumerate(seqs): + seq_len = len(seq) + tokens[i, 1:1 + seq_len] = 1 + tokens[i, 1 + seq_len] = 2 + return labels, seqs, tokens + + return _batch_converter + + +class FakeESMModel: + num_layers = 1 + + def eval(self): + return self + + def to(self, _device): + return self + + def __call__(self, batch_tokens, repr_layers, return_contacts=False): + _ = return_contacts + batch_size, token_len = batch_tokens.shape + # append_seq_len -> final emb dim is 4 (matches training fixture) + rep_dim = 3 + reps = torch.ones((batch_size, token_len, rep_dim), + dtype=torch.float32) + return {"representations": {repr_layers[0]: reps}} + + +def test_train_then_predict_e2e(training_artifacts, tmp_path: Path, monkeypatch): + repo_root = Path(__file__).resolve().parents[2] + src_path = str(repo_root / "src") + + env = os.environ.copy() + current_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + src_path + if not current_pythonpath + else f"{src_path}{os.pathsep}{current_pythonpath}" + ) + + save_dir = tmp_path / "out" + + train_cmd = [ + sys.executable, + "-m", + "pepseqpred.apps.train_ffnn_cli", + "--embedding-dirs", + str(training_artifacts["embedding_dir"]), + "--label-shards", + str(training_artifacts["label_shard"]), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--val-frac", + "0.5", + "--split-seeds", + "11", + "--train-seeds", + "101", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv"), + ] + + proc = subprocess.run( + train_cmd, + capture_output=True, + text=True, + cwd=repo_root, + env=env, + ) + assert proc.returncode == 0, proc.stderr + + run_dirs = sorted(save_dir.glob("run_*")) + assert run_dirs + checkpoint = run_dirs[0] / "fully_connected.pt" + assert checkpoint.exists() + + fake_pretrained = types.SimpleNamespace( + fake_model=lambda: (FakeESMModel(), FakeAlphabet()) + ) + monkeypatch.setattr(prediction_cli.esm, "pretrained", fake_pretrained) + + fasta = tmp_path / "input.fasta" + fasta.write_text(">protein_e2e\nACDEFG\n", encoding="utf-8") + output_fasta = tmp_path / "predictions.fasta" + + monkeypatch.setattr( + sys, + "argv", + [ + "prediction_cli.py", + str(checkpoint), + str(fasta), + "--output-fasta", + str(output_fasta), + "--model-name", + "fake_model", + "--threshold", + "0.5", + ], + ) + + prediction_cli.main() + + lines = [line.strip() for line in output_fasta.read_text( + encoding="utf-8").splitlines() if line.strip()] + assert lines[0] == ">protein_e2e" + assert len(lines[1]) == 6 + assert set(lines[1]).issubset({"0", "1"}) diff --git a/tests/integration/test_prediction_cli_smoke.py b/tests/integration/test_prediction_cli_smoke.py new file mode 100644 index 0000000..8a1420b --- /dev/null +++ b/tests/integration/test_prediction_cli_smoke.py @@ -0,0 +1,107 @@ +import sys +import types +from pathlib import Path +import pytest +import torch +import pepseqpred.apps.prediction_cli as prediction_cli +from pepseqpred.core.models.ffnn import PepSeqFFNN + +pytestmark = pytest.mark.integration + + +class FakeAlphabet: + def get_batch_converter(self): + def _batch_converter(pairs): + labels = [name for name, _seq in pairs] + seqs = [seq for _name, seq in pairs] + max_len = max((len(seq) for seq in seqs), default=0) + tokens = torch.zeros((len(seqs), max_len + 2), dtype=torch.long) + for i, seq in enumerate(seqs): + seq_len = len(seq) + tokens[i, 1:1 + seq_len] = 1 + tokens[i, 1 + seq_len] = 2 + return labels, seqs, tokens + + return _batch_converter + + +class FakeESMModel: + num_layers = 1 + + def eval(self): + return self + + def to(self, _device): + return self + + def __call__(self, batch_tokens, repr_layers, return_contacts=False): + _ = return_contacts + batch_size, token_len = batch_tokens.shape + rep_dim = 3 # append_seq_len -> final emb dim is 4 + reps = torch.ones((batch_size, token_len, rep_dim), + dtype=torch.float32) + return {"representations": {repr_layers[0]: reps}} + + +def _write_checkpoint(path: Path) -> None: + model = PepSeqFFNN( + emb_dim=4, + hidden_sizes=(3,), + dropouts=(0.0,), + use_layer_norm=False, + use_residual=False, + num_classes=1, + ) + for param in model.parameters(): + torch.nn.init.constant_(param, 0.0) + + torch.save( + { + "model_state_dict": model.state_dict(), + "metrics": {"threshold": 0.5}, + }, + path, + ) + + +def test_prediction_cli_smoke(monkeypatch, tmp_path: Path): + fake_pretrained = types.SimpleNamespace( + fake_model=lambda: (FakeESMModel(), FakeAlphabet()) + ) + monkeypatch.setattr(prediction_cli.esm, "pretrained", fake_pretrained) + + checkpoint = tmp_path / "model.pt" + _write_checkpoint(checkpoint) + + fasta = tmp_path / "input.fasta" + fasta.write_text( + ">protein_1\nACDEFG\n>protein_2\nLMNPQ\n", encoding="utf-8") + + output_fasta = tmp_path / "predictions.fasta" + + monkeypatch.setattr( + sys, + "argv", + [ + "prediction_cli.py", + str(checkpoint), + str(fasta), + "--output-fasta", + str(output_fasta), + "--model-name", + "fake_model", + "--threshold", + "0.5", + ], + ) + + prediction_cli.main() + + lines = [line.strip() for line in output_fasta.read_text( + encoding="utf-8").splitlines() if line.strip()] + assert lines[0] == ">protein_1" + assert len(lines[1]) == 6 + assert set(lines[1]).issubset({"0", "1"}) + assert lines[2] == ">protein_2" + assert len(lines[3]) == 5 + assert set(lines[3]).issubset({"0", "1"}) diff --git a/tests/integration/test_train_clis_inprocess.py b/tests/integration/test_train_clis_inprocess.py new file mode 100644 index 0000000..897c575 --- /dev/null +++ b/tests/integration/test_train_clis_inprocess.py @@ -0,0 +1,106 @@ +import sys +from pathlib import Path +import pytest +import pepseqpred.apps.train_ffnn_cli as train_cli +import pepseqpred.apps.train_ffnn_optuna_cli as optuna_cli + +pytestmark = pytest.mark.integration + + +def test_train_ffnn_cli_main_inprocess(training_artifacts, tmp_path: Path, monkeypatch): + save_dir = tmp_path / "train_out" + + monkeypatch.setattr( + sys, + "argv", + [ + "train_ffnn_cli.py", + "--embedding-dirs", + str(training_artifacts["embedding_dir"]), + "--label-shards", + str(training_artifacts["label_shard"]), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--val-frac", + "0.5", + "--split-seeds", + "11", + "--train-seeds", + "101", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv") + ] + ) + + train_cli.main() + + run_dirs = list(save_dir.glob("run_*")) + assert run_dirs + assert (run_dirs[0] / "fully_connected.pt").exists() + assert (save_dir / "runs.csv").exists() + assert (save_dir / "multi_run_summary.json").exists() + + +@pytest.mark.slow +def test_train_ffnn_optuna_cli_main_inprocess( + training_artifacts, tmp_path: Path, monkeypatch +): + save_dir = tmp_path / "optuna_out" + csv_path = save_dir / "trials.csv" + + monkeypatch.setattr( + sys, + "argv", + [ + "train_ffnn_optuna_cli.py", + "--embedding-dirs", + str(training_artifacts["embedding_dir"]), + "--label-shards", + str(training_artifacts["label_shard"]), + "--n-trials", + "1", + "--epochs", + "1", + "--val-frac", + "0.5", + "--subset", + "4", + "--batch-sizes", + "2", + "--num-workers", + "0", + "--metric", + "auc", + "--arch-mode", + "flat", + "--depth-min", + "1", + "--depth-max", + "1", + "--width-min", + "64", + "--width-max", + "64", + "--save-path", + str(save_dir), + "--csv-path", + str(csv_path), + "--study-name", + "test_smoke" + ] + ) + + optuna_cli.main() + + assert (save_dir / "best_trial.json").exists() + assert csv_path.exists() diff --git a/tests/integration/test_train_ffnn_cli_smoke.py b/tests/integration/test_train_ffnn_cli_smoke.py new file mode 100644 index 0000000..856b304 --- /dev/null +++ b/tests/integration/test_train_ffnn_cli_smoke.py @@ -0,0 +1,60 @@ +import os +from pathlib import Path +import subprocess +import sys +import pytest + +pytestmark = pytest.mark.integration + + +def test_train_ffnn_cli_smoke(training_artifacts, tmp_path): + repo_root = Path(__file__).resolve().parents[2] + src_path = str(repo_root / "src") + env = os.environ.copy() + current_pythonpath = env.get("PYTHONPATH") + env["PYTHONPATH"] = ( + src_path + if not current_pythonpath + else f"{src_path}{os.pathsep}{current_pythonpath}" + ) + + save_dir = tmp_path / "out" + cmd = [ + sys.executable, + "-m", + "pepseqpred.apps.train_ffnn_cli", + "--embedding-dirs", + str(training_artifacts["embedding_dir"]), + "--label-shards", + str(training_artifacts["label_shard"]), + "--epochs", + "1", + "--batch-size", + "2", + "--num-workers", + "0", + "--hidden-sizes", + "8", + "--dropouts", + "0.1", + "--val-frac", + "0.5", + "--split-seeds", + "11", + "--train-seeds", + "101", + "--save-path", + str(save_dir), + "--results-csv", + str(save_dir / "runs.csv"), + ] + proc = subprocess.run( + cmd, capture_output=True, text=True, cwd=repo_root, env=env + ) + assert proc.returncode == 0, proc.stderr + + run_dirs = list(save_dir.glob("run_*")) + assert run_dirs + assert (run_dirs[0] / "fully_connected.pt").exists() + assert (save_dir / "runs.csv").exists() + assert (save_dir / "multi_run_summary.json").exists() diff --git a/tests/unit/apps/test_cli_wrappers.py b/tests/unit/apps/test_cli_wrappers.py new file mode 100644 index 0000000..ebef2d2 --- /dev/null +++ b/tests/unit/apps/test_cli_wrappers.py @@ -0,0 +1,169 @@ +import argparse +import logging +from pathlib import Path +import pandas as pd +import pytest +import pepseqpred.apps.esm_cli as esm_cli +import pepseqpred.apps.labels_cli as labels_cli +import pepseqpred.apps.preprocess_cli as preprocess_cli + +pytestmark = pytest.mark.unit + + +def test_labels_cli_invokes_builder(monkeypatch, tmp_path: Path): + captured = {} + + class DummyBuilder: + def __init__(self, **kwargs): + captured["init"] = kwargs + + def build(self, save_path): + captured["save_path"] = save_path + return {} + + ns = argparse.Namespace( + meta_path=tmp_path / "meta.tsv", + save_path=tmp_path / "labels.pt", + emb_dirs=[tmp_path / "emb"], + restrict_to_embeddings=False, + calc_pos_weight=True, + embedding_key_delim="-" + ) + + monkeypatch.setattr( + labels_cli.argparse.ArgumentParser, "parse_args", lambda self: ns + ) + monkeypatch.setattr( + labels_cli, + "setup_logger", + lambda **kwargs: logging.getLogger("labels_cli_test") + ) + monkeypatch.setattr(labels_cli, "ProteinLabelBuilder", DummyBuilder) + + labels_cli.main() + + assert captured["init"]["meta_path"] == ns.meta_path + assert captured["save_path"] == ns.save_path + + +def test_preprocess_cli_invokes_preprocess(monkeypatch, tmp_path: Path): + captured = {} + + def fake_preprocess(*args, **kwargs): + captured["args"] = args + captured["kwargs"] = kwargs + return pd.DataFrame([{"CodeName": "x"}]) + + ns = argparse.Namespace( + meta_file=tmp_path / "meta.tsv", + z_file=tmp_path / "z.tsv", + fname_col="FullName", + code_col="CodeName", + is_epi_z_min=20.0, + is_epi_min_subs=4, + not_epi_z_max=10.0, + not_epi_max_subs=0, + subject_prefix="VW_", + save=True + ) + + monkeypatch.setattr( + preprocess_cli.argparse.ArgumentParser, "parse_args", lambda self: ns + ) + monkeypatch.setattr( + preprocess_cli, + "setup_logger", + lambda **kwargs: logging.getLogger("preprocess_cli_test") + ) + monkeypatch.setattr(preprocess_cli, "preprocess", fake_preprocess) + + preprocess_cli.main() + + assert captured["kwargs"]["fname_col"] == "FullName" + assert captured["kwargs"]["code_col"] == "CodeName" + assert captured["kwargs"]["save_path"] is not None + + +def test_esm_cli_id_mode_invokes_embedding_pipeline(monkeypatch, tmp_path: Path): + captured = {} + + def fake_esm_embeddings_from_fasta(*args, **kwargs): + captured["kwargs"] = kwargs + return pd.DataFrame([{"id": "P1"}]), [] + + ns = argparse.Namespace( + log_dir=Path("logs"), + log_level="INFO", + log_json=False, + per_seq_dir=Path("artifacts/pts"), + idx_csv_path=Path("artifacts/index.csv"), + out_dir=tmp_path, + fasta_file=tmp_path / "in.fasta", + metadata_file=None, + metadata_name_col="Name", + metadata_family_col="Family", + id_col="ID", + seq_col="Sequence", + embedding_key_mode="id", + key_delimiter="-", + model_name="fake_model", + max_tokens=16, + batch_size=2, + num_shards=1, + shard_id=0 + ) + + monkeypatch.setattr(esm_cli.argparse.ArgumentParser, + "parse_args", lambda self: ns) + monkeypatch.setattr( + esm_cli, "setup_logger", lambda **kwargs: logging.getLogger( + "esm_cli_test") + ) + monkeypatch.setattr( + esm_cli, + "read_fasta", + lambda _p: pd.DataFrame([{"ID": "P1", "Sequence": "ACD"}]) + ) + monkeypatch.setattr(esm_cli, "esm_embeddings_from_fasta", + fake_esm_embeddings_from_fasta) + monkeypatch.setattr(esm_cli.torch.cuda, "is_available", lambda: False) + monkeypatch.setattr(esm_cli.torch.cuda, "device_count", lambda: 0) + + esm_cli.main() + + assert captured["kwargs"]["key_mode"] == "id" + assert captured["kwargs"]["id_col"] == "ID" + + +def test_esm_cli_id_family_requires_metadata(monkeypatch, tmp_path: Path): + ns = argparse.Namespace( + log_dir=Path("logs"), + log_level="INFO", + log_json=False, + per_seq_dir=Path("artifacts/pts"), + idx_csv_path=Path("artifacts/index.csv"), + out_dir=tmp_path, + fasta_file=tmp_path / "in.fasta", + metadata_file=None, + metadata_name_col="Name", + metadata_family_col="Family", + id_col="ID", + seq_col="Sequence", + embedding_key_mode="id-family", + key_delimiter="-", + model_name="fake_model", + max_tokens=16, + batch_size=2, + num_shards=1, + shard_id=0 + ) + + monkeypatch.setattr(esm_cli.argparse.ArgumentParser, + "parse_args", lambda self: ns) + monkeypatch.setattr( + esm_cli, "setup_logger", lambda **kwargs: logging.getLogger( + "esm_cli_test") + ) + + with pytest.raises(ValueError, match="Metadata file is required"): + esm_cli.main() diff --git a/tests/unit/apps/test_prediction_cli_config.py b/tests/unit/apps/test_prediction_cli_config.py new file mode 100644 index 0000000..494c7a0 --- /dev/null +++ b/tests/unit/apps/test_prediction_cli_config.py @@ -0,0 +1,60 @@ +import argparse +import pytest +from pepseqpred.apps.prediction_cli import _build_cli_model_config + +pytestmark = pytest.mark.unit + + +def _args(**overrides): + base = { + "emb_dim": None, + "hidden_sizes": None, + "dropouts": None, + "use_layer_norm": None, + "use_residual": None, + "num_classes": None + } + base.update(overrides) + return argparse.Namespace(**base) + + +def test_build_cli_model_config_returns_none_if_no_explicit_flags(): + assert _build_cli_model_config(_args()) is None + + +def test_build_cli_model_config_requires_all_fields_when_any_explicit(): + with pytest.raises(ValueError, match="provide all required values"): + _build_cli_model_config(_args(emb_dim=4)) + + +def test_build_cli_model_config_rejects_length_mismatch(): + with pytest.raises(ValueError, match="same length"): + _build_cli_model_config( + _args( + emb_dim=4, + hidden_sizes="8,4", + dropouts="0.1", + use_layer_norm=True, + use_residual=False, + num_classes=1 + ) + ) + + +def test_build_cli_model_config_happy_path(): + cfg = _build_cli_model_config( + _args( + emb_dim=4, + hidden_sizes="8,4", + dropouts="0.1,0.2", + use_layer_norm=True, + use_residual=False, + num_classes=1 + ) + ) + assert cfg.emb_dim == 4 + assert cfg.hidden_sizes == (8, 4) + assert cfg.dropouts == (0.1, 0.2) + assert cfg.use_layer_norm is True + assert cfg.use_residual is False + assert cfg.num_classes == 1 diff --git a/tests/unit/core/data/test_proteindataset.py b/tests/unit/core/data/test_proteindataset.py new file mode 100644 index 0000000..bf28162 --- /dev/null +++ b/tests/unit/core/data/test_proteindataset.py @@ -0,0 +1,57 @@ +from pathlib import Path +import pytest +import torch +from pepseqpred.core.data.proteindataset import ProteinDataset, _iter_windows, pad_collate + +pytestmark = pytest.mark.unit + + +def test_iter_windows_disabled(): + assert list(_iter_windows( + length=5, window_size=None, stride=1)) == [(0, 5)] + + +def test_pad_collate_shapes(): + x1 = torch.ones(2, 3) + y1 = torch.tensor([1.0, 0.0]) + m1 = torch.tensor([1, 1]) + + x2 = torch.ones(4, 3) + y2 = torch.tensor([1.0, 0.0, 1.0, 0.0]) + m2 = torch.tensor([1, 1, 1, 0]) + + x, y, m = pad_collate([(x1, y1, m1), (x2, y2, m2)]) + assert x.shape == (2, 4, 3) + assert y.shape == (2, 4) + assert m.shape == (2, 4) + + +def test_dataset_masks_uncertain_and_padding(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + shard = tmp_path / "labels.pt" + + torch.save(torch.randn(5, 4), emb_dir / "P001-111.pt") + labels = { + "P001": torch.tensor( + [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1]], + dtype=torch.uint8 + ) + } + torch.save({"labels": labels, "class_stats": { + "pos_count": 2, "neg_counts": 2}}, shard) + + ds = ProteinDataset( + embedding_dirs=[emb_dir], + label_shards=[shard], + window_size=4, + stride=4, + collapse_labels=True, + pad_last_window=True + ) + samples = list(ds) + assert len(samples) == 2 + _, _, m0 = samples[0] + _, _, m1 = samples[1] + assert m0.tolist() == [1, 0, 1, 1] + assert m1.tolist() == [1, 0, 0, 0] diff --git a/tests/unit/core/data/test_proteindataset_edge_cases.py b/tests/unit/core/data/test_proteindataset_edge_cases.py new file mode 100644 index 0000000..8e53a99 --- /dev/null +++ b/tests/unit/core/data/test_proteindataset_edge_cases.py @@ -0,0 +1,126 @@ +from pathlib import Path +import pytest +import torch +from pepseqpred.core.data.proteindataset import ( + ProteinDataset, + _build_embedding_index, + _build_label_index, + _slice_ids_contiguous +) + +pytestmark = pytest.mark.unit + + +def _write_embedding(path: Path, length: int = 5, dim: int = 4) -> None: + torch.save(torch.randn(length, dim, dtype=torch.float32), path) + + +def _write_label_shard(path: Path, labels: dict[str, torch.Tensor]) -> None: + pos = sum(int((y == 1).sum().item()) for y in labels.values()) + neg = sum(int((y == 0).sum().item()) for y in labels.values()) + torch.save( + { + "labels": labels, + "class_stats": {"pos_count": pos, "neg_count": neg} + }, + path + ) + + +def test_slice_ids_contiguous_balanced_chunks(): + ids = ["a", "b", "c", "d", "e", "f", "g"] + assert _slice_ids_contiguous(ids, worker_id=0, num_workers=3) == [ + "a", "b", "c"] + assert _slice_ids_contiguous(ids, worker_id=1, num_workers=3) == ["d", "e"] + assert _slice_ids_contiguous(ids, worker_id=2, num_workers=3) == ["f", "g"] + + +def test_build_label_index_requires_labels_key(tmp_path: Path): + bad_shard = tmp_path / "bad.pt" + torch.save({"not_labels": {}}, bad_shard) + + with pytest.raises(TypeError, match="'labels' key"): + _build_label_index([bad_shard]) + + +def test_build_embedding_index_detects_duplicate_ids(tmp_path: Path): + d1 = tmp_path / "emb1" + d2 = tmp_path / "emb2" + d1.mkdir() + d2.mkdir() + + _write_embedding(d1 / "P001-111.pt") + _write_embedding(d2 / "P001-111.pt") + + with pytest.raises(ValueError, match="Duplicate embedding"): + _build_embedding_index([d1, d2]) + + +def test_dataset_rejects_invalid_label_cache_mode(): + with pytest.raises(ValueError, match="label_cache_mode"): + ProteinDataset( + embedding_dirs=[], + label_shards=[], + embedding_index={}, + label_index={}, + label_cache_mode="invalid" # type: ignore[arg-type] + ) + + +def test_dataset_raises_on_embedding_label_length_mismatch(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + shard = tmp_path / "labels.pt" + + _write_embedding(emb_dir / "P001-111.pt", length=5, dim=4) + _write_label_shard(shard, {"P001": torch.tensor( + [1, 0, 1, 0], dtype=torch.float32)}) + + ds = ProteinDataset( + embedding_dirs=[emb_dir], + label_shards=[shard], + window_size=None, + stride=1 + ) + + with pytest.raises(ValueError, match="same size at dim 0"): + list(ds) + + +def test_label_cache_mode_all_matches_current_outputs(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + shard1 = tmp_path / "labels_1.pt" + shard2 = tmp_path / "labels_2.pt" + + _write_embedding(emb_dir / "P001-111.pt", length=5, dim=4) + _write_embedding(emb_dir / "P002-222.pt", length=4, dim=4) + + _write_label_shard(shard1, {"P001": torch.tensor( + [1, 0, 1, 0, 0], dtype=torch.float32)}) + _write_label_shard(shard2, {"P002": torch.tensor( + [0, 1, 0, 1], dtype=torch.float32)}) + + ds_current = ProteinDataset( + embedding_dirs=[emb_dir], + label_shards=[shard1, shard2], + window_size=None, + stride=1, + label_cache_mode="current" + ) + ds_all = ProteinDataset( + embedding_dirs=[emb_dir], + label_shards=[shard1, shard2], + window_size=None, + stride=1, + label_cache_mode="all" + ) + + out_current = list(ds_current) + out_all = list(ds_all) + + assert len(out_current) == len(out_all) == 2 + for (x1, y1, m1), (x2, y2, m2) in zip(out_current, out_all): + assert torch.equal(x1, x2) + assert torch.equal(y1, y2) + assert torch.equal(m1, m2) diff --git a/tests/unit/core/embeddings/test_esm2.py b/tests/unit/core/embeddings/test_esm2.py new file mode 100644 index 0000000..82a8c42 --- /dev/null +++ b/tests/unit/core/embeddings/test_esm2.py @@ -0,0 +1,166 @@ +import logging +import types +from pathlib import Path +import pandas as pd +import pytest +import torch +import pepseqpred.core.embeddings.esm2 as esm2 + +pytestmark = pytest.mark.unit + + +class FakeAlphabet: + def get_batch_converter(self): + def _convert(pairs): + labels = [name for name, _ in pairs] + seqs = [seq for _, seq in pairs] + max_len = max((len(s) for s in seqs), default=0) + tokens = torch.zeros((len(seqs), max_len + 2), dtype=torch.long) + for i, seq in enumerate(seqs): + seq_len = len(seq) + tokens[i, 1:1 + seq_len] = 1 + tokens[i, 1 + seq_len] = 2 + return labels, seqs, tokens + + return _convert + + +class FakeModel(torch.nn.Module): + def __init__(self, embed_dim: int = 3, expose_embed_dim: bool = True): + super().__init__() + self.p = torch.nn.Parameter(torch.zeros(1)) + self._embed_dim = embed_dim + if expose_embed_dim: + self.embed_dim = embed_dim + + def forward(self, batch_tokens, repr_layers, return_contacts=False): + _ = return_contacts + batch_size, token_len = batch_tokens.shape + rep = torch.ones( + (batch_size, token_len, self._embed_dim), + dtype=torch.float32, + device=batch_tokens.device + ) + return {"representations": {repr_layers[0]: rep}} + + +def test_clean_seq_token_batches_and_append_len(): + assert esm2.clean_seq("acdx-*\n") == "ACDX" + + batches = list( + esm2.token_packed_batches( + [("a", "A" * 3), ("b", "B" * 3), ("c", "C" * 9)], + max_tokens=10, + max_tokens_per_seq=1022 + ) + ) + assert len(batches) == 2 + assert [x[0] for x in batches[0]] == ["a", "b"] + + arr = torch.ones((5, 3), dtype=torch.float32).numpy() + out = esm2.append_seq_len(arr, 5) + assert out.shape == (5, 4) + assert (out[:, -1] == 5).all() + + +def test_compute_window_embedding_cpu_paths(): + token = torch.tensor([[0, 1, 1, 1, 1, 1, 2]], dtype=torch.long) + + model_1 = FakeModel(embed_dim=3, expose_embed_dim=True) + out_1 = esm2.compute_window_embedding( + token, model_1, layer=1, device="cpu", window_size=3, stride=2 + ) + assert out_1.shape == (5, 3) + assert torch.allclose(out_1, torch.ones_like(out_1)) + + model_2 = FakeModel(embed_dim=2, expose_embed_dim=False) + out_2 = esm2.compute_window_embedding( + token, model_2, layer=1, device="cpu", window_size=3, stride=2 + ) + assert out_2.shape == (5, 2) + assert torch.allclose(out_2, torch.ones_like(out_2)) + + +def test_esm_embeddings_from_fasta_short_and_long(monkeypatch, tmp_path: Path): + fake_pretrained = types.SimpleNamespace( + fake_model=lambda: (FakeModel(embed_dim=3), FakeAlphabet()) + ) + monkeypatch.setattr(esm2.esm, "pretrained", fake_pretrained) + + df = pd.DataFrame( + [ + {"ID": "P1", "Sequence": "ACD", "viral_family": "111"}, + {"ID": "P2", "Sequence": "ACDEFGH", "viral_family": "222"} + ] + ) + + per_seq = tmp_path / "pts" + idx_csv = tmp_path / "idx.csv" + + index_df, failed = esm2.esm_embeddings_from_fasta( + df, + id_col="ID", + seq_col="Sequence", + family_col="viral_family", + model_name="fake_model", + max_tokens=6, + batch_size=2, + per_seq_dir=per_seq, + index_csv_path=idx_csv, + key_mode="id-family", + key_delimiter="-", + logger=logging.getLogger("esm2_test") + ) + + assert failed == [] + assert len(index_df) == 2 + assert set(index_df["handle"]) == {"short", "long"} + for key in index_df["id"].tolist(): + assert (per_seq / f"{key}.pt").exists() + + +def test_esm_embeddings_from_fasta_key_validation(monkeypatch, tmp_path: Path): + fake_pretrained = types.SimpleNamespace( + fake_model=lambda: (FakeModel(embed_dim=3), FakeAlphabet()) + ) + monkeypatch.setattr(esm2.esm, "pretrained", fake_pretrained) + + df = pd.DataFrame([{"ID": "P1", "Sequence": "ACD"}]) + + with pytest.raises(ValueError, match="Unsupported key_mode"): + esm2.esm_embeddings_from_fasta( + df, + model_name="fake_model", + key_mode="bad", + per_seq_dir=tmp_path / "pts1", + index_csv_path=tmp_path / "idx1.csv", + logger=logging.getLogger("esm2_test") + ) + + with pytest.raises(ValueError, match="Missing required family column"): + esm2.esm_embeddings_from_fasta( + df, + model_name="fake_model", + key_mode="id-family", + per_seq_dir=tmp_path / "pts2", + index_csv_path=tmp_path / "idx2.csv", + logger=logging.getLogger("esm2_test") + ) + + df_conflict = pd.DataFrame( + [ + {"ID": "P1", "Sequence": "AAA"}, + {"ID": "P1", "Sequence": "CCC"} + ] + ) + with pytest.raises( + ValueError, match="Conflicting sequences map to the same embedding key" + ): + esm2.esm_embeddings_from_fasta( + df_conflict, + model_name="fake_model", + key_mode="id", + per_seq_dir=tmp_path / "pts3", + index_csv_path=tmp_path / "idx3.csv", + logger=logging.getLogger("esm2_test") + ) diff --git a/tests/unit/core/io/test_keys.py b/tests/unit/core/io/test_keys.py new file mode 100644 index 0000000..9654036 --- /dev/null +++ b/tests/unit/core/io/test_keys.py @@ -0,0 +1,76 @@ +from pathlib import Path +import pytest +from pepseqpred.core.io.keys import ( + build_emb_stem, + build_id_to_family_from_metadata, + normalize_family_value, + parse_emb_stem, + parse_family_from_oxx, + parse_fullname +) + +pytestmark = pytest.mark.unit + + +def test_parse_fullname_and_family(): + fullname = "ID=P001 AC=A1 OXX=10,20,333" + protein_id, ac, oxx, family = parse_fullname(fullname) + assert protein_id == "P001" + assert ac == "A1" + assert oxx == "10,20,333" + assert family == "333" + assert parse_family_from_oxx(oxx) == "333" + + +def test_normalize_family_value_variants(): + assert normalize_family_value(None) == "" + assert normalize_family_value("nan") == "" + assert normalize_family_value("123.0") == "123" + assert normalize_family_value(" 456 ") == "456" + + +def test_build_and_parse_emb_stem_roundtrip(): + stem = build_emb_stem("P001", "111", delimiter="-") + assert stem == "P001-111" + + protein_id, family, scheme = parse_emb_stem(stem, delimiter="-") + assert protein_id == "P001" + assert family == "111" + assert scheme == "id-family" + + protein_id2, family2, scheme2 = parse_emb_stem("P002", delimiter="-") + assert protein_id2 == "P002" + assert family2 is None + assert scheme2 == "id" + + +def test_build_emb_stem_rejects_non_numeric_family(): + with pytest.raises(ValueError, match="must be numeric"): + build_emb_stem("P001", "abc") + + +def test_build_id_to_family_from_metadata_happy_path(tmp_path: Path): + meta = tmp_path / "meta.tsv" + meta.write_text( + "Name\tFamily\n" + "ID=P001 AC=A1 OXX=10,20,333\t333\n" + "ID=P002 AC=A2 OXX=10,20,444\t444\n", + encoding="utf-8" + ) + + mapping, duplicate_same = build_id_to_family_from_metadata(meta) + assert mapping == {"P001": "333", "P002": "444"} + assert duplicate_same == 0 + + +def test_build_id_to_family_from_metadata_conflict_raises(tmp_path: Path): + meta = tmp_path / "meta.tsv" + meta.write_text( + "Name\tFamily\n" + "ID=P001 AC=A1 OXX=10,20,333\t333\n" + "ID=P001 AC=A1 OXX=10,20,333\t444\n", + encoding="utf-8" + ) + + with pytest.raises(ValueError, match="conflicts"): + build_id_to_family_from_metadata(meta) diff --git a/tests/unit/core/io/test_read_csv.py b/tests/unit/core/io/test_read_csv.py new file mode 100644 index 0000000..72c69ad --- /dev/null +++ b/tests/unit/core/io/test_read_csv.py @@ -0,0 +1,22 @@ +import pytest +from pepseqpred.core.io.read import parse_int_csv, parse_float_csv + +pytestmark = pytest.mark.unit + + +def test_parse_int_csv_success(): + assert parse_int_csv("11,22,33", "--split-seeds") == [11, 22, 33] + + +def test_parse_int_csv_invalid(): + with pytest.raises(ValueError, match="CSV list of integers"): + parse_int_csv("11,abc", "--split-seeds") + + +def test_parse_float_csv_success(): + assert parse_float_csv("0.1,1,2.5", "--dropouts") == [0.1, 1.0, 2.5] + + +def test_parse_float_csv_empty(): + with pytest.raises(ValueError, match="cannot be empty"): + parse_float_csv(" , ", "--dropouts") diff --git a/tests/unit/core/io/test_read_write_preprocess.py b/tests/unit/core/io/test_read_write_preprocess.py new file mode 100644 index 0000000..8e9bfc3 --- /dev/null +++ b/tests/unit/core/io/test_read_write_preprocess.py @@ -0,0 +1,148 @@ +import logging +from pathlib import Path +import pandas as pd +import pytest +from pepseqpred.core.io.read import read_fasta, read_metadata, read_zscores +from pepseqpred.core.io.write import append_csv_row +from pepseqpred.core.preprocess.pv1 import preprocess +from pepseqpred.core.preprocess.zscores import ( + apply_z_threshold, + merge_zscores_metadata +) + +pytestmark = pytest.mark.unit + + +def test_read_fasta_standard_and_full_name(tmp_path: Path): + fasta = tmp_path / "in.fasta" + fasta.write_text( + ">ID=P001 AC=A1 OXX=11,22,33\nACDE\n>ID=P002 AC=A2 OXX=44,55,66\nFGH\n", + encoding="utf-8" + ) + + df_std = read_fasta(fasta, full_name=False) + assert list(df_std.columns) == ["ID", "AC", "OXX", "Sequence"] + assert df_std["ID"].tolist() == ["P001", "P002"] + assert df_std["Sequence"].tolist() == ["ACDE", "FGH"] + + df_full = read_fasta(fasta, full_name=True) + assert list(df_full.columns) == ["FullName", "Sequence"] + assert df_full["FullName"].iloc[0].startswith("ID=P001") + + +def test_read_fasta_bad_header_raises(tmp_path: Path): + fasta = tmp_path / "bad.fasta" + fasta.write_text(">not pv1 style\nACD\n", encoding="utf-8") + with pytest.raises(ValueError, match="Header does not match expected format"): + read_fasta(fasta, full_name=False) + + +def test_read_metadata_extracts_align_indices_and_drops(tmp_path: Path): + meta = tmp_path / "meta.tsv" + pd.DataFrame( + [ + { + "CodeName": "pep1", + "Category": "SetCover", + "SpeciesID": "1", + "Species": "X", + "Protein": "Y", + "FullName": "ID=P001 AC=A1 OXX=11,22,33_2_5", + "Peptide": "ACD", + "Encoding": "enc", + }, + { + "CodeName": "pep2", + "Category": "Other", + "SpeciesID": "1", + "Species": "X", + "Protein": "Y", + "FullName": "ID=P002 AC=A2 OXX=44,55,66_1_3", + "Peptide": "FGH", + "Encoding": "enc" + } + ] + ).to_csv(meta, sep="\t", index=False) + + out = read_metadata( + meta, + drop_cols=["Category", "SpeciesID", "Protein", "Encoding"] + ) + assert len(out) == 1 + assert int(out["AlignStart"].iloc[0]) == 2 + assert int(out["AlignStop"].iloc[0]) == 5 + assert out["FullName"].iloc[0] == "ID=P001 AC=A1 OXX=11,22,33" + + +def test_read_zscores_rename(tmp_path: Path): + z = tmp_path / "z.tsv" + pd.DataFrame( + [ + {"Sequence name": "pep1", "VW_001": 1.0, "VW_002": 2.0}, + {"Sequence name": "pep2", "VW_001": 3.0, "VW_002": 4.0} + ] + ).to_csv(z, sep="\t", index=False) + + out = read_zscores(z) + assert "CodeName" in out.columns + assert "Sequence name" not in out.columns + + +def test_append_csv_row_appends_without_duplicate_header(tmp_path: Path): + csv_path = tmp_path / "runs" / "r.csv" + append_csv_row(csv_path, {"a": 1, "b": "x"}) + append_csv_row(csv_path, {"a": 2, "b": "y"}) + + df = pd.read_csv(csv_path) + assert df.shape == (2, 2) + assert df["a"].tolist() == [1, 2] + + +def test_apply_threshold_merge_and_preprocess(tmp_path: Path): + z_df = pd.DataFrame( + [ + {"CodeName": "pep1", "VW_001": 30, "VW_002": 25}, + {"CodeName": "pep2", "VW_001": 1, "VW_002": 2}, + {"CodeName": "pep3", "VW_001": 15, "VW_002": 5} + ] + ) + labeled = apply_z_threshold( + z_df.copy(), + is_epitope_z_min=20, + is_epitope_min_subjects=1, + not_epitope_z_max=10 + ) + assert labeled[["Def epitope", "Uncertain", "Not epitope"]].sum( + axis=1).eq(1).all() + + meta_df = pd.DataFrame( + {"CodeName": ["pep1", "pep2", "pep3"], "x": [1, 2, 3]}) + merged = merge_zscores_metadata(labeled, meta_df) + assert {"Def epitope", "Uncertain", "Not epitope"}.issubset(merged.columns) + + meta_path = tmp_path / "meta.tsv" + z_path = tmp_path / "z.tsv" + save_path = tmp_path / "out.tsv" + + pd.DataFrame( + [ + { + "CodeName": "pep1", + "Category": "SetCover", + "SpeciesID": "1", + "Species": "X", + "Protein": "Y", + "FullName": "ID=P001 AC=A1 OXX=11,22,33_2_5", + "Peptide": "AAA", + "Encoding": "enc" + } + ] + ).to_csv(meta_path, sep="\t", index=False) + pd.DataFrame( + [{"Sequence name": "pep1", "VW_001": 50, "VW_002": 0}] + ).to_csv(z_path, sep="\t", index=False) + + logger = logging.getLogger("test_preprocess") + out = preprocess(meta_path, z_path, save_path=save_path, logger=logger) + assert len(out) == 1 + assert save_path.exists() diff --git a/tests/unit/core/labels/test_builder.py b/tests/unit/core/labels/test_builder.py new file mode 100644 index 0000000..fe6e40d --- /dev/null +++ b/tests/unit/core/labels/test_builder.py @@ -0,0 +1,156 @@ +import logging +from pathlib import Path +import pandas as pd +import pytest +import torch +from pepseqpred.core.labels.builder import ( + ProteinLabelBuilder, + parse_id_from_fullname, + parse_taxonomy_from_fullname +) + +pytestmark = pytest.mark.unit + + +def _write_embedding(path: Path, length: int = 6, dim: int = 4) -> None: + torch.save(torch.randn(length, dim, dtype=torch.float32), path) + + +def _write_metadata(path: Path, rows: list[dict]) -> None: + pd.DataFrame(rows).to_csv(path, sep="\t", index=False) + + +def test_parse_helpers(): + full = "ID=P001 AC=A1 OXX=11,22,33" + assert parse_id_from_fullname(full) == "P001" + tax = parse_taxonomy_from_fullname(full) + assert tax["protein_id"] == "P001" + assert parse_taxonomy_from_fullname("bad") == {"fullname": "bad"} + + with pytest.raises(ValueError, match="Could not parse ID"): + parse_id_from_fullname("not valid") + + +def test_builder_build_id_mode(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + _write_embedding(emb_dir / "P001.pt", length=6, dim=4) + _write_embedding(emb_dir / "P002.pt", length=4, dim=4) + + meta = tmp_path / "meta.tsv" + _write_metadata( + meta, + [ + { + "CodeName": "pep1", + "Peptide": "AAA", + "AlignStart": 0, + "AlignStop": 3, + "Def epitope": 1, + "Uncertain": 0, + "Not epitope": 0, + "FullName": "ID=P001 AC=A1 OXX=11,22,33" + }, + { + "CodeName": "pep2", + "Peptide": "BBB", + "AlignStart": 3, + "AlignStop": 5, + "Def epitope": 0, + "Uncertain": 0, + "Not epitope": 1, + "FullName": "ID=P001 AC=A1 OXX=11,22,33" + }, + { + "CodeName": "pep3", + "Peptide": "CCC", + "AlignStart": 0, + "AlignStop": 2, + "Def epitope": 0, + "Uncertain": 1, + "Not epitope": 0, + "FullName": "ID=P002 AC=A2 OXX=44,55,66" + } + ] + ) + + out_pt = tmp_path / "labels.pt" + builder = ProteinLabelBuilder( + meta_path=meta, + emb_dirs=[emb_dir], + logger=logging.getLogger("builder_test"), + calc_pos_weight=True, + embedding_key_delim="" + ) + payload = builder.build(out_pt) + + assert out_pt.exists() + assert set(payload.keys()) >= {"labels", "proteins", "class_stats"} + p1 = payload["labels"]["P001"] + assert p1.shape == (6, 3) + assert p1[:3, 0].sum().item() == 3 + assert p1[3:5, 2].sum().item() == 2 + assert "tax_info" in payload["proteins"]["P001"] + + +def test_builder_restrict_to_embeddings_raises_if_empty(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + _write_embedding(emb_dir / "P001.pt") + + meta = tmp_path / "meta.tsv" + _write_metadata( + meta, + [ + { + "CodeName": "pepX", + "Peptide": "AAA", + "AlignStart": 0, + "AlignStop": 2, + "Def epitope": 1, + "Uncertain": 0, + "Not epitope": 0, + "FullName": "ID=P999 AC=A9 OXX=1,2,3" + } + ] + ) + + with pytest.raises(ValueError, match="0 rows after --restrict-to-embeddings"): + ProteinLabelBuilder( + meta_path=meta, + emb_dirs=[emb_dir], + logger=logging.getLogger("builder_test"), + restrict_to_embeddings=True + ) + + +def test_builder_id_family_duplicate_ids_raise(tmp_path: Path): + emb_dir = tmp_path / "emb" + emb_dir.mkdir() + _write_embedding(emb_dir / "P001-111.pt") + _write_embedding(emb_dir / "P001-222.pt") + + meta = tmp_path / "meta.tsv" + _write_metadata( + meta, + [ + { + "CodeName": "pep1", + "Peptide": "AAA", + "AlignStart": 0, + "AlignStop": 2, + "Def epitope": 1, + "Uncertain": 0, + "Not epitope": 0, + "FullName": "ID=P001 AC=A1 OXX=11,22,33", + } + ] + ) + + with pytest.raises(ValueError, match="Duplicate ID-family embeddings"): + ProteinLabelBuilder( + meta_path=meta, + emb_dirs=[emb_dir], + logger=logging.getLogger("builder_test"), + embedding_key_delim="-" + ) diff --git a/tests/unit/core/predict/test_inference.py b/tests/unit/core/predict/test_inference.py new file mode 100644 index 0000000..7f459c6 --- /dev/null +++ b/tests/unit/core/predict/test_inference.py @@ -0,0 +1,96 @@ +import pytest +import torch +from pepseqpred.core.models.ffnn import PepSeqFFNN +from pepseqpred.core.predict.inference import ( + FFNNModelConfig, + build_model_from_checkpoint, + infer_decision_threshold, + infer_model_config_from_state, + normalize_state_dict_keys +) + +pytestmark = pytest.mark.unit + + +def _make_checkpoint( + emb_dim: int = 4, + hidden_sizes: tuple[int, ...] = (8,), + dropouts: tuple[float, ...] = (0.0,), + use_layer_norm: bool = False, + use_residual: bool = False +): + model = PepSeqFFNN( + emb_dim=emb_dim, + hidden_sizes=hidden_sizes, + dropouts=dropouts, + num_classes=1, + use_layer_norm=use_layer_norm, + use_residual=use_residual, + ) + return {"model_state_dict": model.state_dict(), "metrics": {"threshold": 0.37}} + + +def test_normalize_state_dict_keys_strips_module_prefix(): + ckpt = _make_checkpoint() + state = ckpt["model_state_dict"] + prefixed = {f"module.{k}": v for k, v in state.items()} + out = normalize_state_dict_keys(prefixed) + assert set(out.keys()) == set(state.keys()) + + +def test_infer_model_config_from_state_happy_path(): + ckpt = _make_checkpoint(emb_dim=4, hidden_sizes=(8,), dropouts=(0.0,)) + cfg = infer_model_config_from_state(ckpt["model_state_dict"]) + assert cfg.emb_dim == 4 + assert cfg.hidden_sizes == (8,) + assert cfg.num_classes == 1 + assert cfg.use_layer_norm is False + assert cfg.use_residual is False + + +def test_infer_model_config_from_state_ambiguous_residual_raises(): + ckpt = _make_checkpoint(emb_dim=4, hidden_sizes=( + 4,), dropouts=(0.0,), use_residual=False) + with pytest.raises(ValueError, match="Cannot infer use_residual"): + infer_model_config_from_state(ckpt["model_state_dict"]) + + +def test_build_model_from_checkpoint_handles_ddp_prefixed_state(): + ckpt = _make_checkpoint() + prefixed_state = {f"module.{k}": v for k, + v in ckpt["model_state_dict"].items()} + model, cfg, cfg_src = build_model_from_checkpoint( + {"model_state_dict": prefixed_state}, + device="cpu", + ) + assert cfg_src == "state_dict" + assert cfg.emb_dim == 4 + + x = torch.randn(2, 5, 4) + with torch.inference_mode(): + y = model(x) + assert y.shape == (2, 5) + + +def test_build_model_from_checkpoint_rejects_invalid_num_classes(): + ckpt = _make_checkpoint() + bad_cfg = FFNNModelConfig( + emb_dim=4, + hidden_sizes=(8,), + dropouts=(0.0,), + num_classes=2, + use_layer_norm=False, + use_residual=False + ) + with pytest.raises(ValueError, match="num_classes=1"): + build_model_from_checkpoint(ckpt, device="cpu", model_config=bad_cfg) + + +def test_infer_decision_threshold_uses_default_on_invalid(): + ckpt = {"metrics": {"threshold": 1.2}} + assert infer_decision_threshold(ckpt, default=0.5) == 0.5 + + +def test_infer_decision_threshold_reads_checkpoint_metric(): + ckpt = {"metrics": {"threshold": 0.42}} + assert infer_decision_threshold(ckpt, default=0.5) == pytest.approx(0.42) diff --git a/tests/unit/core/train/test_split.py b/tests/unit/core/train/test_split.py new file mode 100644 index 0000000..95d0143 --- /dev/null +++ b/tests/unit/core/train/test_split.py @@ -0,0 +1,25 @@ +import pytest +from pepseqpred.core.train.split import split_ids_grouped, partition_ids_weighted + +pytestmark = pytest.mark.unit + + +def test_split_ids_grouped_keeps_groups_intact(): + ids = ["a1", "a2", "b1", "b2", "c1"] + groups = {"a1": "A", "a2": "A", "b1": "B", "b2": "B", "c1": "C"} + train_ids, val_ids = split_ids_grouped( + ids, val_frac=0.4, seed=7, groups=groups) + + train_groups = {groups[i] for i in train_ids} + val_groups = {groups[i] for i in val_ids} + assert train_groups.isdisjoint(val_groups) + + +def test_partition_ids_weighted_non_empty(): + ids = ["p1", "p2", "p3", "p4"] + weights = {"p1": 100.0, "p2": 90.0, "p3": 10.0, "p4": 9.0} + parts = partition_ids_weighted( + ids, world_size=2, weights=weights, ensure_non_empty=True) + assert len(parts) == 2 + assert all(len(p) > 0 for p in parts) + assert sorted([x for part in parts for x in part]) == sorted(ids) diff --git a/tests/unit/core/train/test_support_modules.py b/tests/unit/core/train/test_support_modules.py new file mode 100644 index 0000000..2c8c645 --- /dev/null +++ b/tests/unit/core/train/test_support_modules.py @@ -0,0 +1,106 @@ +import math +import random +import warnings +from pathlib import Path +import numpy as np +import pytest +import torch +import pepseqpred.core.train.ddp as ddp_mod +from pepseqpred.core.train.embedding import infer_emb_dim +from pepseqpred.core.train.metrics import compute_eval_metrics +from pepseqpred.core.train.seed import set_all_seeds + +pytestmark = pytest.mark.unit + + +def test_compute_eval_metrics_happy_and_single_class(): + y_true = torch.tensor([0, 1, 0, 1]) + y_pred = torch.tensor([0, 1, 1, 1]) + y_prob = torch.tensor([0.1, 0.9, 0.7, 0.8], dtype=torch.float32) + + out = compute_eval_metrics(y_true, y_pred, y_prob) + assert {"precision", "recall", "f1", "mcc", "auc", "pr_auc", "auc10"}.issubset( + out + ) + + y_true_2 = torch.tensor([1, 1, 1, 1]) + y_pred_2 = torch.tensor([1, 1, 1, 1]) + y_prob_2 = torch.tensor([0.9, 0.8, 0.7, 0.6], dtype=torch.float32) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("error") + out_2 = compute_eval_metrics(y_true_2, y_pred_2, y_prob_2) + assert not caught + assert math.isnan(out_2["auc"]) + assert out_2["pr_auc"] == pytest.approx(1.0) + assert math.isnan(out_2["auc10"]) + + +def test_set_all_seeds_reproducible(): + set_all_seeds(123) + a = torch.rand(3) + b = np.random.rand(3) + c = random.random() + + set_all_seeds(123) + a_2 = torch.rand(3) + b_2 = np.random.rand(3) + c_2 = random.random() + + assert torch.allclose(a, a_2) + assert np.allclose(b, b_2) + assert c == c_2 + + +def test_infer_emb_dim_success_and_errors(tmp_path: Path): + emb = tmp_path / "x.pt" + torch.save(torch.randn(5, 7), emb) + assert infer_emb_dim({"P1": emb}) == 7 + + with pytest.raises(ValueError, match="No embedding found"): + infer_emb_dim({}) + + bad = tmp_path / "bad.pt" + torch.save(torch.randn(5), bad) + with pytest.raises(ValueError, match="Expected embedding tensor"): + infer_emb_dim({"P2": bad}) + + +def test_ddp_helpers_no_ddp(monkeypatch): + monkeypatch.setattr(ddp_mod.dist, "is_available", lambda: False) + monkeypatch.setattr(ddp_mod.dist, "is_initialized", lambda: False) + + t = torch.tensor([1.0, 2.0]) + out = ddp_mod.ddp_all_reduce_sum(t.clone()) + assert torch.equal(out, t) + + gathered, sizes = ddp_mod.ddp_gather_all_1d( + torch.tensor([1, 2, 3]), torch.device("cpu") + ) + assert sizes == [3] + assert torch.equal(gathered[0], torch.tensor([1, 2, 3])) + + +def test_init_ddp_enabled(monkeypatch): + monkeypatch.setenv("RANK", "2") + monkeypatch.setenv("LOCAL_RANK", "1") + monkeypatch.setenv("PEPSEQPRED_DDP_TIMEOUT_MIN", "5") + + calls = {} + + monkeypatch.setattr( + ddp_mod.dist, + "init_process_group", + lambda backend, timeout: calls.update({"backend": backend}) + ) + monkeypatch.setattr(ddp_mod.dist, "get_rank", lambda: 2) + monkeypatch.setattr(ddp_mod.dist, "get_world_size", lambda: 4) + monkeypatch.setattr( + ddp_mod.torch.cuda, + "set_device", + lambda idx: calls.update({"local_rank": idx}) + ) + + out = ddp_mod.init_ddp() + assert out == {"rank": 2, "world_size": 4, "local_rank": 1} + assert calls["backend"] == "nccl" + assert calls["local_rank"] == 1 diff --git a/tests/unit/core/train/test_threshold.py b/tests/unit/core/train/test_threshold.py new file mode 100644 index 0000000..3885e11 --- /dev/null +++ b/tests/unit/core/train/test_threshold.py @@ -0,0 +1,23 @@ +import numpy as np +import pytest +from pepseqpred.core.train.threshold import find_threshold_max_recall_min_precision + +pytestmark = pytest.mark.unit + + +def test_threshold_no_valid_residues(): + out = find_threshold_max_recall_min_precision( + np.array([], dtype=np.int64), + np.array([], dtype=np.float64), + min_precision=0.25, + ) + assert out["status"] == "no_valid_residues" + + +def test_threshold_respects_min_precision(): + y_true = np.array([1, 1, 0, 0], dtype=np.int64) + y_prob = np.array([0.9, 0.8, 0.4, 0.1], dtype=np.float64) + out = find_threshold_max_recall_min_precision( + y_true, y_prob, min_precision=0.5) + assert out["status"] == "ok" + assert out["precision"] >= 0.5 diff --git a/tests/unit/core/train/test_trainer_batch_step.py b/tests/unit/core/train/test_trainer_batch_step.py new file mode 100644 index 0000000..94adbea --- /dev/null +++ b/tests/unit/core/train/test_trainer_batch_step.py @@ -0,0 +1,76 @@ +import logging +import pytest +import torch +import torch.nn as nn +from pepseqpred.core.models.ffnn import PepSeqFFNN +from pepseqpred.core.train.trainer import Trainer, TrainerConfig + +pytestmark = pytest.mark.unit + + +class BadShapeModel(nn.Module): + def __init__(self): + super().__init__() + self.p = nn.Parameter(torch.tensor(1.0)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Intentionally wrong shape (B, L+1) instead of (B, L) + return torch.zeros((x.size(0), x.size(1) + 1), device=x.device) + (self.p * 0.0) + + +def _make_trainer(model: nn.Module) -> Trainer: + return Trainer( + model=model, + train_loader=[], + logger=logging.getLogger("test_trainer"), + val_loader=None, + config=TrainerConfig(epochs=1, batch_size=2, device="cpu") + ) + + +def test_batch_step_zero_mask_returns_zero_n(): + model = PepSeqFFNN( + emb_dim=4, + hidden_sizes=(8,), + dropouts=(0.0,), + num_classes=1, + use_layer_norm=False, + use_residual=False + ) + trainer = _make_trainer(model) + + x = torch.randn(1, 3, 4) + y = torch.tensor([[1.0, 0.0, 1.0]], dtype=torch.float32) + mask = torch.zeros((1, 3), dtype=torch.long) + + out = trainer._batch_step((x, y, mask), train=True) + assert out["n"] == 0 + assert out["loss"] == pytest.approx(0.0, abs=1e-12) + + +def test_batch_step_rejects_invalid_y_dim(): + model = PepSeqFFNN( + emb_dim=4, + hidden_sizes=(8,), + dropouts=(0.0,), + num_classes=1, + use_layer_norm=False, + use_residual=False + ) + trainer = _make_trainer(model) + + x = torch.randn(1, 3, 4) + y_bad = torch.randn(1, 3, 1) + + with pytest.raises(ValueError, match="Expected y_onehot shape"): + trainer._batch_step((x, y_bad), train=False) + + +def test_batch_step_rejects_logit_shape_mismatch(): + trainer = _make_trainer(BadShapeModel()) + + x = torch.randn(2, 4, 4) + y = torch.zeros((2, 4), dtype=torch.float32) + + with pytest.raises(ValueError, match="Expected logits shape"): + trainer._batch_step((x, y), train=False) diff --git a/tests/unit/core/train/test_trainer_fit.py b/tests/unit/core/train/test_trainer_fit.py new file mode 100644 index 0000000..c01ab62 --- /dev/null +++ b/tests/unit/core/train/test_trainer_fit.py @@ -0,0 +1,91 @@ +import logging +from pathlib import Path +import optuna +import pytest +import torch +from pepseqpred.core.models.ffnn import PepSeqFFNN +from pepseqpred.core.train.trainer import Trainer, TrainerConfig + +pytestmark = pytest.mark.unit + + +def _make_batches( + n_batches: int = 2, + batch_size: int = 2, + length: int = 6, + emb_dim: int = 4, + mask_value: int = 1 +): + batches = [] + base_y = torch.tensor([1, 0, 1, 0, 1, 0], dtype=torch.float32).repeat( + batch_size, 1 + ) + for _ in range(n_batches): + x = torch.randn(batch_size, length, emb_dim) + y = base_y.clone() + m = torch.full((batch_size, length), mask_value, dtype=torch.long) + batches.append((x, y, m)) + return batches + + +def _make_trainer(train_loader, val_loader=None, emb_dim: int = 4, epochs: int = 2): + model = PepSeqFFNN( + emb_dim=emb_dim, + hidden_sizes=(8,), + dropouts=(0.0,), + num_classes=1, + use_layer_norm=False, + use_residual=False + ) + return Trainer( + model=model, + train_loader=train_loader, + val_loader=val_loader, + logger=logging.getLogger("trainer_fit_test"), + config=TrainerConfig( + epochs=epochs, batch_size=2, learning_rate=1e-2, device="cpu" + ) + ) + + +def test_fit_with_validation_saves_checkpoint(tmp_path: Path): + trainer = _make_trainer( + _make_batches(), _make_batches(), emb_dim=4, epochs=2) + summary = trainer.fit(save_dir=tmp_path, score_key="loss") + + assert summary["best_epoch"] >= 0 + assert (tmp_path / "fully_connected.pt").exists() + assert isinstance(summary["best_metrics"], dict) + + +def test_fit_without_validation_saves_no_val_checkpoint(tmp_path: Path): + trainer = _make_trainer(_make_batches(), None, emb_dim=4, epochs=1) + summary = trainer.fit(save_dir=tmp_path, score_key="loss") + + assert summary["best_epoch"] == -1 + assert (tmp_path / "fully_connected_no_val.pt").exists() + + +def test_run_epoch_eval_no_valid_residues(): + train_loader = _make_batches(n_batches=1, mask_value=1) + val_loader = _make_batches(n_batches=1, mask_value=0) + trainer = _make_trainer(train_loader, val_loader, emb_dim=4, epochs=1) + + out = trainer._run_epoch(0, train=False) + assert out["eval_metrics"]["threshold_status"] == "no_valid_residues" + + +class _AlwaysPruneTrial: + def report(self, value, step): + _ = (value, step) + + def should_prune(self): + return True + + +def test_fit_optuna_prune_path(tmp_path: Path): + trainer = _make_trainer( + _make_batches(), _make_batches(), emb_dim=4, epochs=2) + with pytest.raises(optuna.TrialPruned): + trainer.fit_optuna(save_dir=tmp_path, + trial=_AlwaysPruneTrial(), score_key="f1") diff --git a/tests/unit/core/train/test_weights.py b/tests/unit/core/train/test_weights.py new file mode 100644 index 0000000..db77748 --- /dev/null +++ b/tests/unit/core/train/test_weights.py @@ -0,0 +1,26 @@ +from pathlib import Path +import pytest +import torch +from pepseqpred.core.train.weights import pos_weight_from_label_shards + +pytestmark = pytest.mark.unit + + +def _write_label_shard(path: Path, class_stats: dict) -> None: + torch.save({"labels": {}, "class_stats": class_stats}, path) + + +def test_pos_weight_from_label_shards_accepts_neg_count(tmp_path: Path): + shard = tmp_path / "labels.pt" + _write_label_shard( + shard, {"pos_count": 2, "neg_count": 6} + ) + assert pos_weight_from_label_shards([shard]) == pytest.approx(3.0) + + +def test_pos_weight_from_label_shards_accepts_neg_counts(tmp_path: Path): + shard = tmp_path / "labels.pt" + _write_label_shard( + shard, {"pos_count": 2, "neg_counts": 8} + ) + assert pos_weight_from_label_shards([shard]) == pytest.approx(4.0)