
Commit d0c3b13

refactor(app): split app into smaller modules
1 parent 5604e3d commit d0c3b13

File tree

12 files changed: 1320 additions & 1288 deletions


server/app.py

Lines changed: 13 additions & 1288 deletions
Large diffs are not rendered by default.

server/config.py

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
import os

import torch

from lm_saes.config import MongoDBConfig
from lm_saes.database import MongoClient

device = "cuda" if torch.cuda.is_available() else "cpu"
client = MongoClient(MongoDBConfig())
sae_series = os.environ.get("SAE_SERIES", "default")
tokenizer_only = os.environ.get("TOKENIZER_ONLY", "false").lower() == "true"
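
Everything in server/config.py is resolved once at import time, so deployment behavior is controlled entirely through environment variables. A minimal usage sketch (the series name is made up, and a reachable MongoDB instance is assumed, since client is constructed as an import side effect):

# Hypothetical usage sketch: server.config reads its environment at import
# time, so the variables must be set before the module is first imported.
import os

os.environ["SAE_SERIES"] = "demo-series"  # made-up series name
os.environ["TOKENIZER_ONLY"] = "true"

from server import config  # constructs the MongoDB client as a side effect

print(config.sae_series)      # "demo-series"
print(config.tokenizer_only)  # True
print(config.device)          # "cuda" if a GPU is visible, else "cpu"

Note that TOKENIZER_ONLY is parsed case-insensitively, but only the literal string "true" enables it; any other value falls back to False.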

server/logic/__init__.py

Whitespace-only changes.

server/logic/loaders.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
from functools import lru_cache

from datasets import Dataset

from lm_saes.abstract_sae import AbstractSparseAutoEncoder
from lm_saes.backend import LanguageModel
from lm_saes.config import BaseSAEConfig
from lm_saes.resource_loaders import load_dataset_shard, load_model
from server.config import client, device, sae_series, tokenizer_only
from server.utils.common import synchronized


@synchronized
@lru_cache(maxsize=8)
def get_model(*, name: str) -> LanguageModel:
    """Load and cache a language model."""
    cfg = client.get_model_cfg(name)
    if cfg is None:
        raise ValueError(f"Model {name} not found")
    cfg.tokenizer_only = tokenizer_only
    cfg.device = device
    return load_model(cfg)


@synchronized
@lru_cache(maxsize=16)
def get_dataset(*, name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
    """Load and cache a dataset shard."""
    cfg = client.get_dataset_cfg(name)
    assert cfg is not None, f"Dataset {name} not found"
    return load_dataset_shard(cfg, shard_idx, n_shards)


@synchronized
@lru_cache(maxsize=8)
def get_sae(*, name: str) -> AbstractSparseAutoEncoder:
    """Load and cache a sparse autoencoder."""
    path = client.get_sae_path(name, sae_series)
    assert path is not None, f"SAE {name} not found"
    cfg = BaseSAEConfig.from_pretrained(path)
    cfg.device = device
    sae = AbstractSparseAutoEncoder.from_config(cfg)
    sae.eval()
    return sae
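
All three loaders stack @synchronized on top of @lru_cache. lru_cache alone is not enough here: two concurrent requests that both miss the cache would load the same multi-gigabyte model twice. server/utils/common.py is not shown in this diff, so the following is only a minimal sketch of what such a decorator could look like, assuming a single process-wide lock:

# Minimal sketch of a synchronized decorator, assuming one global lock;
# the real server/utils/common.py is not shown here and may differ.
import threading
from functools import wraps

_lock = threading.Lock()


def synchronized(fn):
    """Serialize calls to fn so concurrent cache misses in a wrapped
    lru_cache cannot trigger duplicate, expensive loads."""

    @wraps(fn)
    def wrapper(*args, **kwargs):
        with _lock:
            return fn(*args, **kwargs)

    return wrapper

Because @synchronized sits above @lru_cache, the lock is taken even on cache hits; that costs a little latency per lookup but guarantees that a model, dataset shard, or SAE is never loaded more than once.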

server/logic/samples.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
from typing import Any, Generator

import numpy as np

from lm_saes.database import FeatureAnalysisSampling
from server.logic.loaders import get_dataset, get_model


def extract_samples(
    sampling: FeatureAnalysisSampling,
    start: int | None = None,
    end: int | None = None,
    visible_range: int | None = None,
) -> list[dict[str, Any]]:
    def process_sample(
        *,
        sparse_feature_acts: tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray | None],
        context_idx: int,
        dataset_name: str,
        model_name: str,
        shard_idx: int | None = None,
        n_shards: int | None = None,
    ):
        model = get_model(name=model_name)
        data = get_dataset(name=dataset_name, shard_idx=shard_idx, n_shards=n_shards)[context_idx]

        origins = model.trace({k: [v] for k, v in data.items()})[0]

        (
            feature_acts_indices,
            feature_acts_values,
            z_pattern_indices,
            z_pattern_values,
        ) = sparse_feature_acts

        assert origins is not None and feature_acts_indices is not None and feature_acts_values is not None, (
            "Origins and feature acts must not be None"
        )

        token_offset = 0
        if visible_range is not None:  # Drop tokens before and after the highest activating token
            if len(feature_acts_indices) == 0:
                max_feature_act_index = 0
            else:
                max_feature_act_index = int(feature_acts_indices[np.argmax(feature_acts_values).item()].item())

            feature_acts_mask = np.logical_and(
                feature_acts_indices > max_feature_act_index - visible_range,
                feature_acts_indices < max_feature_act_index + visible_range,
            )
            feature_acts_indices = feature_acts_indices[feature_acts_mask]
            feature_acts_values = feature_acts_values[feature_acts_mask]

            if z_pattern_indices is not None and z_pattern_values is not None:
                z_pattern_mask = np.logical_and(
                    z_pattern_indices > max_feature_act_index - visible_range,
                    z_pattern_indices < max_feature_act_index + visible_range,
                ).all(axis=0)
                z_pattern_indices = z_pattern_indices[:, z_pattern_mask]
                z_pattern_values = z_pattern_values[z_pattern_mask]

            token_offset = max(0, max_feature_act_index - visible_range)

            origins = origins[token_offset : max_feature_act_index + visible_range]

        text_offset = None
        if "text" in data:
            text_ranges = [origin["range"] for origin in origins if origin is not None and origin["key"] == "text"]
            if text_ranges:
                max_text_origin = max(text_ranges, key=lambda x: x[1])
                data["text"] = data["text"][: max_text_origin[1]]
                if visible_range is not None:
                    text_offset = min(text_ranges, key=lambda x: x[0])[0]
                    data["text"] = data["text"][text_offset:]

        return {
            **data,
            "token_offset": token_offset,
            "text_offset": text_offset,
            "origins": origins,
            "feature_acts_indices": feature_acts_indices,
            "feature_acts_values": feature_acts_values,
            "z_pattern_indices": z_pattern_indices,
            "z_pattern_values": z_pattern_values,
        }

    def index_select(
        indices: np.ndarray,
        values: np.ndarray,
        i: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Select i-th sample from sparse tensor indices and values."""
        mask = indices[0] == i
        return indices[1:, mask], values[mask]

    def process_sparse_feature_acts(
        feature_acts_indices: np.ndarray,
        feature_acts_values: np.ndarray,
        z_pattern_indices: np.ndarray | None,
        z_pattern_values: np.ndarray | None,
        start: int,
        end: int,
    ) -> Generator[tuple[np.ndarray, np.ndarray, np.ndarray | None, np.ndarray | None], Any, None]:
        for i in range(start, end):
            feature_acts_indices_i, feature_acts_values_i = index_select(feature_acts_indices, feature_acts_values, i)
            if z_pattern_indices is not None and z_pattern_values is not None:
                z_pattern_indices_i, z_pattern_values_i = index_select(z_pattern_indices, z_pattern_values, i)
            else:
                z_pattern_indices_i, z_pattern_values_i = None, None
            yield feature_acts_indices_i[0], feature_acts_values_i, z_pattern_indices_i, z_pattern_values_i

    start = start if start is not None else 0
    end = end if end is not None else len(sampling.context_idx)

    return [
        process_sample(
            sparse_feature_acts=sparse_feature_acts,
            context_idx=context_idx,
            dataset_name=dataset_name,
            model_name=model_name,
            shard_idx=shard_idx,
            n_shards=n_shards,
        )
        for sparse_feature_acts, context_idx, dataset_name, model_name, shard_idx, n_shards in zip(
            process_sparse_feature_acts(
                sampling.feature_acts_indices,
                sampling.feature_acts_values,
                sampling.z_pattern_indices,
                sampling.z_pattern_values,
                start,
                end,
            ),
            sampling.context_idx[start:end],
            sampling.dataset_name[start:end],
            sampling.model_name[start:end],
            sampling.shard_idx[start:end] if sampling.shard_idx is not None else [0] * (end - start),
            sampling.n_shards[start:end] if sampling.n_shards is not None else [1] * (end - start),
        )
    ]
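
extract_samples expects activations in COO-style sparse form: row 0 of feature_acts_indices holds the sample index and row 1 the token position, with a parallel values array. A self-contained toy illustration of the slicing that index_select performs (the numbers are invented):

# Toy illustration of the COO selection done by index_select;
# the arrays below are made-up data, not from the real pipeline.
import numpy as np

# Row 0: sample index; row 1: token position within that sample.
indices = np.array([[0, 0, 1, 1, 1],
                    [2, 7, 0, 3, 9]])
values = np.array([0.5, 1.2, 0.1, 3.4, 0.8])

i = 1
mask = indices[0] == i  # entries belonging to sample 1
sub_indices, sub_values = indices[1:, mask], values[mask]

print(sub_indices)  # [[0 3 9]] -> token positions for sample 1
print(sub_values)   # [0.1 3.4 0.8]

process_sparse_feature_acts yields sub_indices[0] and sub_values for each sample in [start, end); the visible_range branch then windows each sample to the positions around its highest-activating token, recording token_offset so the client can map the trimmed positions back to the original context.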

server/routers/__init__.py

Whitespace-only changes.
