Skip to content

[DRAFT] FEAT: Dataset Loading Changes#1451

Draft
ValbuenaVC wants to merge 7 commits intoAzure:mainfrom
ValbuenaVC:datasetloader
Draft

[DRAFT] FEAT: Dataset Loading Changes#1451
ValbuenaVC wants to merge 7 commits intoAzure:mainfrom
ValbuenaVC:datasetloader

Conversation

@ValbuenaVC
Copy link
Contributor

@ValbuenaVC ValbuenaVC commented Mar 10, 2026

Description

Features:

  • Addition of filters argument to get_all_dataset_names, which rejects datasets that don't meet filter criteria. filters has type SeedDatasetFilters.
  • SeedDatasetProviders have a new optional SeedDatasetMetadata private attribute (_metadata) that contains static metadata (dynamic metadata, like derived attributes at runtime, has been scoped out of this PR)
  • SeedDatasetMetadata dataclass contains: tags: set[str]; size: SeedDatasetSize; modalities: list[SeedDatasetModality]; source: SeedDatasetSourceType; rank: SeedDatasetLoadingRank; harm_categories: list[str]
  • Each dataset child can optionally implement the private attribute _metadata to make adding and removing datasets easier. However, datasets that do not specify _metadata are excluded from any searches with filters.

Issues: reviewers, please provide your opinion on these.

  • Most important issue: if we use a private class attribute to store metadata, we may need to migrate the local datasets (which are .prompt files) to Python classes. Alternatively, we can add a metadata attribute to the .prompt files, but I'm not sure which is the better approach long-term. I'm in favor of changing the schema so each local dataset at the top level is just dataset_name, seeds, and metadata, with metadata containing everything else of use (authors goes in tags), etc.
  • SeedDatasetFilter and SeedDatasetMetadata are almost the same. I wonder if it would be overengineering to extract a common schema object and turn those two into wrapper classes that just interact with a unified schema.
  • The necessary imports to set up metadata are very verbose (import SeedDatasetMetadata, SeedDatasetSize, ...). This could be fixed by adding constructors that build out the actual metadata enum types from primitive types, but this adds a layer of indirection.
  • Who gets responsibility of filter parsing and matching? I prefer keeping it in SeedDatasetProvider since the logic of filtering is different from the filtering fields itself, but this may be more confusing.
  • Some fields like SeedDatasetLoadingRank and SeedDatasetSourceType seem like they could be excluded.
  • Keeping metadata as an optional private class attribute seems like the best compromise between flexibility to easily add datasets and ensuring there is a unified filtering system that can be used without instantiating the dataset class, but there are issues with this approach:
    • It doesn't expose a unified interface. Expecting _metadata is not type safe.
    • It keeps us from easily adding dynamic metadata that would be derived at runtime, unless we bootstrap an _add_dynamic_attributes_to_metadata method that peeks at _metadata after the actual instance is created, which could work if we're comfortable letting metadata be incomplete at some point in the dataset's lifecycle.
    • It's not easy for users to interact with filters or metadata. The intuitive and Pythonic angle would be to pass dictionaries or typed dictionaries, but we need strong type safety.

Follow-Up PRs:

  • Dynamic Metadata including things like timestamps, exact size, and caching of changes made to remote datasets on local disk. Not possible with the class attribute and frozen dataclass approach we have currently.
  • SQL Passthrough from SeedDatasetProvider to CentralMemory to allow for complex operations across datasets. For example, consider a user that wants to get all text prompts with string "harm" from two datasets. Something like a SQL JOIN would be ideal in this situation.
  • Rich Encoding/Decoding for metadata filtering and storage. Not quite a DSL, but make it easier to convert filters and CentralMemory queries.

Tests and Documentation

  • Addition of test_seed_dataset_metadata.py under unit tests.
  • Addition of test_fetch_dataset_with_filtering under test_seed_dataset_provider_integration.py in tests.integration.datasets.

invalid_categories = {
cat for cat in harm_categories if cat not in self.HARM_CATEGORIES}
if invalid_categories:
raise ValueError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we likely still want to load these; should we use a default harm category here?


@classmethod
def get_all_dataset_names(cls) -> list[str]:
def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]:
Copy link
Contributor

@rlundeen2 rlundeen2 Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might start by going backwards from what will work well

An open-ended dictionary might be tricky to use - as a caller it's not clear how to I'd want to use it. I'd Potentially have a DatasetProviderFilter class. I'd likely make the decision that we want to filter these before fetching the datasets. But both could be valid options.

Here is what it might look like (ty copilot)

Problem

Today, SeedDatasetProvider.fetch_datasets_async() fetches all registered datasets (or a hard-coded list of names). There's no way to say "give me only text datasets" or "give me only small safety-related datasets." Every call potentially downloads 30+ datasets from HuggingFace/GitHub.

We want to add metadata to each provider and a filter object so users can query datasets like this:

# "Give me small text-only datasets tagged as default"
f = DatasetProviderFilter(
    tags={"default"},
    modalities=[DatasetModality.TEXT],
    sizes=[DatasetSize.SMALL],
)
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me everything multimodal"
f = DatasetProviderFilter(modalities=[DatasetModality.TEXT, DatasetModality.IMAGE])
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

# "Give me all datasets (no filtering)"
f = DatasetProviderFilter(tags={"all"})
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

Filtering happens before fetching — providers that don't match are skipped entirely, so nothing is downloaded unnecessarily.

There are exactly two things to build: Provider metadata and DatasetProviderFilter

1. Provider metadata — each provider declares what it is

Every SeedDatasetProvider subclass needs to declare four pieces of metadata as class-level attributes (not instance properties — we need to read them without instantiating):

class _HarmBenchDataset(_RemoteDatasetLoader):
    """HarmBench: 504 harmful behaviors across safety categories."""

    harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"]
    modalities: list[DatasetModality] = [DatasetModality.TEXT]
    size: DatasetSize = DatasetSize.LARGE          # 504 seeds
    tags: set[str] = {"default", "safety"}         # "default" means included in curated set

    @property
    def dataset_name(self) -> str:
        return "harmbench"
    # ...
Attribute Type Purpose
harm_categories list[str] Free-form strings like "violence", "cybercrime". No enum — each dataset uses its own vocabulary.
modalities list[DatasetModality] TEXT, IMAGE, AUDIO, VIDEO. Indicates what data types the seeds contain.
size DatasetSize SMALL (<50 seeds), MEDIUM (50–500), LARGE (>500). Self-declared by the provider author.
tags set[str] Flexible labels. "default" means “include me in the curated default set.” Anything else is free-form.
source_type str "local" or "remote". Set once on the base classes: _LocalDatasetLoader returns "local" and _RemoteDatasetLoader returns "remote".

2. DatasetProviderFilter — the user-facing filter

@dataclass
class DatasetProviderFilter:
    """
    Filters dataset providers based on their declared metadata.

    All fields are optional. None means "don't filter on this axis."
    Across axes: AND (all specified conditions must match).
    Within each axis: OR (provider needs at least one overlap).

    Special tag behavior:
    - tags={"all"} → skip tag filtering entirely, return everything
    - tags={"default"} → only providers that have "default" in their tags
    - tags=None → no tag filtering (same as "all")
    """

    harm_categories: Optional[list[str]] = None
    source_type: Optional[Literal["local", "remote"]] = None
    modalities: Optional[list[DatasetModality]] = None
    sizes: Optional[list[DatasetSize]] = None
    tags: Optional[set[str]] = None

    def matches(self, *, provider: SeedDatasetProvider) -> bool:
        """Return True if the provider passes all filter conditions."""

        # Tags: "all" means skip tag check
        if self.tags is not None and "all" not in self.tags:
            if not self.tags & provider.tags:  # set intersection — need at least one overlap
                return False

        # Harm categories: provider must have at least one matching category
        if self.harm_categories is not None:
            if not set(self.harm_categories) & set(provider.harm_categories):
                return False

        # Source type
        if self.source_type is not None:
            if provider.source_type != self.source_type:
                return False

        # Modalities: provider must support at least one requested modality
        if self.modalities is not None:
            if not set(self.modalities) & set(provider.modalities):
                return False

        # Size
        if self.sizes is not None:
            if provider.size not in self.sizes:
                return False

        return True

Matching logic in plain English:

Each specified filter field must be satisfied (AND)
Within a field, the provider only needs to overlap on one value (OR)
None on any field = don't care about that field
tags={"all"} = special: return everything regardless of tags
tags=None = also returns everything (no tag filter applied)

Fetch Datasets Update

In seed_dataset_provider.py, the existing method gains two new parameter, filter and max_seeds. Before creating tasks, datasets are narrowed:

# Apply filter to narrow down which providers to even consider
providers = cls._registry
if filter:
    providers = {
        name: pclass for name, pclass in providers.items()
        if filter.matches(provider=pclass())
    }

max_seeds needs to be implemented in the dataset classes, but would allow us to limit the number of seeds retrieved. This way we can still have integration tests for all datasets.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with these points. A few things to work on in my opinion:

Caller Interface
A dictionary literal is definitely too ambiguous like you said. Users have no way of intuiting what is or isn't a valid filter. But I think calling the enums directly is verbose, and where those live and why they work isn't obvious to the user. Users who just want to grab all small datasets have to invoke a new class and several custom types with that approach, and they shouldn't have to dive into the type system of DatasetMetadata to do it.

I think we have a few options to narrow it down. We could use a typed dictionary, something like this:

f: DatasetFilters = {
    "sizes": "small",
    "modalities": ["text", "image"]
}
datasets = await SeedDatasetProvider.fetch_datasets_async(filter=f)

This lets us constrain the types allowed in the filters without making it cumbersome for the user. I'll keep iterating on this, but I think it's a good start.

Class Attributes and Instantiating
My first instinct was to just add the metadata as fields to each SeedDatasetProvider child, but there are some issues with this. The first is that our default implementation instantiates the class in SeedDatasetProvider.__init_subclass__ anyway, which makes that seem like the more natural injection point. The second is that derived attributes like exact size cannot be used as metadata if they're kept as class attributes. And the third is that we run the risk of having out-of-date metadata. We could just do a one-time scan of each dataset and store its size, but for remote datasets especially I feel that drift is an issue.

I don't have a good solution to this, so I'm leaning towards scoping derived attributes out of the PR, but worth thinking about tradeoffs.

Dataset-Specific Metadata
One issue I ran into early was whether or not we need each dataset to explicitly define its own metadata. The answer can definitely be a yes, but I wanted to do it in a way that would make it easier for users to add or remove datasets without spending too much time on it.

The first attempt I did was a factory method that produced a metadata object. That seemed more cumbersome. The second was a private class attribute, which has hidden state, but is more convenient. Neither approach is great.

Where I'm leaning now is private class attribute for metadata that has custom tags as an attribute. But I don't like that approach very much. I'll keep iterating, but I think we should try to keep metadata together as a single object.

"""

@staticmethod
def populate_metadata() -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Open to suggestions on how to avoid this without changing the source files, which I really don't like. If we kept metadata outside of them, this would be easier, but that feels like overengineering.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants