Hi, I’m trying to reproduce Figure 2 for the L15R-8x SAE on Llama-3.1-8B.
Setup
LM: meta-llama/Llama-3.1-8B
SAE release: llama_scope_lxr_8x, sae_id = l15r_8x
Script: run_llamascope_inference.py
Cmd (simplified):
python run_llamascope_inference.py \
  --model-name meta-llama/Llama-3.1-8B \
  --sae-release llama_scope_lxr_8x \
  --layers 15 \
  --dataset-cache-path slimpajama_test_cache.jsonl.gz \
  --dataset-max-prompts 5120 \
  --prompts-per-batch 96 \
  --max-seq-len 512
Dataset: SlimPajama cached split, 5120 prompts, max seq len 512 (~1.84M tokens)
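In case the loading path matters, this is roughly what the script does to pull layer-15 residual activations before computing any metrics. This is a minimal sketch only: run_llamascope_inference.py is my own script, the hook name is my guess for the L15R hook point, and I'm assuming a TransformerLens + SAELens setup (the SAE.from_pretrained return signature differs across SAELens versions):

```python
# Sketch of the activation-harvesting side of my script (assumptions noted above).
import torch
from transformer_lens import HookedTransformer
from sae_lens import SAE

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumption: TransformerLens accepts this checkpoint name directly.
model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B", device=device)

# Assumption: SAELens-style loading; some versions return (sae, cfg, sparsity), others just the SAE.
loaded = SAE.from_pretrained(release="llama_scope_lxr_8x", sae_id="l15r_8x", device=device)
sae = loaded[0] if isinstance(loaded, tuple) else loaded

# Assumption: L15R corresponds to the post-block residual stream at layer 15.
hook_name = "blocks.15.hook_resid_post"

def get_resid_acts(tokens: torch.Tensor) -> torch.Tensor:
    """Run a batch of token ids and return layer-15 residual activations [batch, seq, d_model]."""
    with torch.no_grad():
        _, cache = model.run_with_cache(tokens, names_filter=hook_name)
    return cache[hook_name]
```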
Observed metrics (summary JSON)
layer: L15R-8x
explained_variance ≈ 0.702
mean_l0 ≈ 32.3
tokens_evaluated ≈ 1.84M
activation_rate ≈ 9.9e-4, feature_coverage ≈ 0.983
(Δ LM loss is not implemented in my eval script yet.)
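For clarity on what the numbers above mean, this is roughly how my script computes explained_variance and mean_l0 from the harvested activations. The definitions below are mine and may not match the ones used for Fig. 2, which is part of what I'm asking:

```python
import torch

@torch.no_grad()
def sae_batch_metrics(acts: torch.Tensor, sae) -> dict:
    """Explained variance and mean L0 for a batch of residual activations acts [n_tokens, d_model]."""
    feats = sae.encode(acts)   # [n_tokens, d_sae] feature activations
    recon = sae.decode(feats)  # [n_tokens, d_model] reconstruction

    # Explained variance: 1 - ||x - x_hat||^2 / ||x - mean(x)||^2, averaged over tokens.
    resid_var = (acts - recon).pow(2).sum(dim=-1)
    total_var = (acts - acts.mean(dim=0, keepdim=True)).pow(2).sum(dim=-1)
    explained_variance = (1.0 - resid_var / total_var).mean().item()

    # L0: mean number of active (nonzero) features per token.
    mean_l0 = (feats > 0).float().sum(dim=-1).mean().item()

    return {"explained_variance": explained_variance, "mean_l0": mean_l0}
```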
Comparison to paper
From Fig. 2, my reading is that the L15R-8x (TopK/JumpReLU) point sits around:
mean L0 ≈ 50
EV ≈ 0.72
So my EV is in the right ballpark but a bit lower, while my mean L0 is noticeably below the nominal top_k = 50.
Questions
What exact eval configuration (activation scaling, dataset, token count, gating) was used to produce Fig.2 for llama_scope_lxr_8x?
Should I be using a specific ActivationScaler config from the SAE release when calling encode/decode?
Is there a reference eval script to reproduce Fig. 2, including Δ LM loss? (My rough plan for Δ LM loss is sketched below.)
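For question 3, in case it helps to compare against what I would otherwise implement myself: my plan for Δ LM loss is to splice the SAE reconstruction back into the layer-15 residual stream and compare cross-entropy against the clean run. This is a sketch under the same hook-name assumption as above, not a reference implementation:

```python
import torch

@torch.no_grad()
def delta_lm_loss(model, sae, tokens: torch.Tensor,
                  hook_name: str = "blocks.15.hook_resid_post") -> float:
    """Return (CE loss with SAE reconstruction spliced in) minus (clean CE loss) for one batch."""
    clean_loss = model(tokens, return_type="loss").item()

    def replace_with_recon(resid, hook):
        # Substitute the residual stream with its SAE reconstruction at this hook point.
        return sae.decode(sae.encode(resid)).to(resid.dtype)

    spliced_loss = model.run_with_hooks(
        tokens,
        return_type="loss",
        fwd_hooks=[(hook_name, replace_with_recon)],
    ).item()

    return spliced_loss - clean_loss
```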
Thanks!