Error loading SAE with SAElens library #65

@thdai2000

Hello LLAMASCOPE team, thank you for your work! I'm having trouble loading the SAE with the SAElens library. I keep getting the error below:

File "/anaconda3/envs/mechinterp/lib/python3.12/site-packages/sae_lens/sae.py", line 616, in from_pretrained
    cfg_dict, state_dict, log_sparsities = conversion_loader(
                                           ^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/mechinterp/lib/python3.12/site-packages/sae_lens/toolkit/pretrained_sae_loaders.py", line 503, in llama_scope_sae_loader
    state_dict_loaded = load_file(sae_path, device=device)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/mechinterp/lib/python3.12/site-packages/safetensors/torch.py", line 313, in load_file
    with safe_open(filename, framework="pt", device=device) as f:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
safetensors_rust.SafetensorError: device cuda is invalid

The loading script I use is:

from sae_lens import SAE
sae, cfg_dict, sparsity = SAE.from_pretrained(release="llama_scope_lxm_8x", sae_id="l16m_8x", device=device)

where device is cuda.

This problem doesn't occur with the GemmaScope SAE. Do you have any idea what might cause it? Thanks in advance!
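One thing that may be worth ruling out (an assumption on my side, not something the traceback confirms): whether `device` reaches safetensors as a `torch.device` object instead of a plain string. The error message formats the rejected value, and a `torch.device("cuda")` also prints as `cuda`, so the message alone would not distinguish the two cases. Below is a minimal sketch of normalizing the value to a string before the call; `as_device_str` is a hypothetical helper, not part of sae_lens or safetensors.

```python
def as_device_str(device) -> str:
    """Return a plain string such as "cuda" or "cuda:0".

    Hypothetical helper: accepts either a str or a torch.device-like object.
    """
    # A str is returned unchanged; for a torch.device, str() yields e.g. "cuda".
    return device if isinstance(device, str) else str(device)


# Usage with the call from this issue (requires sae_lens and a CUDA machine):
#
#   from sae_lens import SAE
#   sae, cfg_dict, sparsity = SAE.from_pretrained(
#       release="llama_scope_lxm_8x",
#       sae_id="l16m_8x",
#       device=as_device_str(device),
#   )
```

If the error persists with a plain string, the cause is likely elsewhere, for example in the installed safetensors or PyTorch build.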
