diff --git a/pyrefly/lib/export/exports.rs b/pyrefly/lib/export/exports.rs index bdf80e8fc2..8c96738a66 100644 --- a/pyrefly/lib/export/exports.rs +++ b/pyrefly/lib/export/exports.rs @@ -20,6 +20,7 @@ use pyrefly_types::callable::Deprecation; use ruff_python_ast::Stmt; use ruff_python_ast::name::Name; use ruff_text_size::TextRange; +use ruff_text_size::TextSize; use starlark_map::small_map::SmallMap; use starlark_map::small_set::SmallSet; @@ -179,6 +180,24 @@ impl Exports { self.0.docstring_range } + /// If `position` is inside a user-specified `__all__` string entry, return its range and name. + pub fn dunder_all_name_at(&self, position: TextSize) -> Option<(TextRange, Name)> { + if self.0.definitions.dunder_all.kind != DunderAllKind::Specified { + return None; + } + self.0 + .definitions + .dunder_all + .entries + .iter() + .find_map(|entry| match entry { + DunderAllEntry::Name(range, name) if range.contains_inclusive(position) => { + Some((*range, name.clone())) + } + _ => None, + }) + } + pub fn is_submodule_imported_implicitly(&self, name: &Name) -> bool { self.0 .definitions diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 75ec0a4a4c..53f79f4dda 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -1474,6 +1474,41 @@ impl<'a> Transaction<'a> { }) } + fn find_definition_for_dunder_all_entry( + &self, + handle: &Handle, + position: TextSize, + preference: FindPreference, + ) -> Option { + let module_info = self.get_module_info(handle)?; + let exports = self.get_exports_data(handle); + let (_entry_range, name) = exports.dunder_all_name_at(position)?; + + if let Some((definition_handle, export)) = + self.resolve_named_import(handle, module_info.name(), name.clone(), preference) + { + let definition_module = self.get_module_info(&definition_handle)?; + return Some(FindDefinitionItemWithDocstring { + metadata: DefinitionMetadata::VariableOrAttribute(export.symbol_kind), + definition_range: export.location, + module: definition_module, + docstring_range: export.docstring_range, + display_name: Some(name.to_string()), + }); + } + + if module_info.path().is_init() { + let submodule = module_info.name().append(&name); + if let Some(definition) = + self.find_definition_for_imported_module(handle, submodule, preference) + { + return Some(definition); + } + } + + None + } + fn find_definition_for_keyword_argument( &self, handle: &Handle, @@ -1560,6 +1595,15 @@ impl<'a> Transaction<'a> { }; let covering_nodes = Ast::locate_node(&mod_module, position); + if covering_nodes + .iter() + .any(|node| matches!(node, AnyNodeRef::ExprStringLiteral(_))) + && let Some(definition) = + self.find_definition_for_dunder_all_entry(handle, position, preference) + { + return vec![definition]; + } + match Self::identifier_from_covering_nodes(&covering_nodes) { Some(IdentifierWithContext { identifier: id, diff --git a/pyrefly/lib/state/state.rs b/pyrefly/lib/state/state.rs index 05195d6fd4..1b1bc91777 100644 --- a/pyrefly/lib/state/state.rs +++ b/pyrefly/lib/state/state.rs @@ -1927,6 +1927,11 @@ impl<'a> Transaction<'a> { .exports(&self.lookup(module_data)) } + pub(crate) fn get_exports_data(&self, handle: &Handle) -> Exports { + let module_data = self.get_module(handle); + self.lookup_export(&module_data) + } + pub fn get_module_docstring_range(&self, handle: &Handle) -> Option { let module_data = self.get_module(handle); self.lookup_export(&module_data).docstring_range() diff --git a/pyrefly/lib/test/lsp/definition.rs b/pyrefly/lib/test/lsp/definition.rs index b499d1f378..9af268d10e 100644 --- a/pyrefly/lib/test/lsp/definition.rs +++ b/pyrefly/lib/test/lsp/definition.rs @@ -1434,6 +1434,50 @@ Definition Result: ); } +#[test] +fn dunder_all_entry_definition_test() { + let pkg = r#" +from pkg.bar import Bar + +class Baz: + pass + +__all__ = ( + "Bar", +# ^ + "Baz", +# ^ +) +"#; + let bar = r#" +class Bar: + pass +"#; + let report = + get_batched_lsp_operations_report(&[("pkg", pkg), ("pkg.bar", bar)], get_test_report); + assert_eq!( + r#" +# pkg.py +8 | "Bar", + ^ +Definition Result: +2 | class Bar: + ^^^ + +10 | "Baz", + ^ +Definition Result: +4 | class Baz: + ^^^ + + +# pkg.bar.py +"# + .trim(), + report.trim(), + ); +} + #[test] fn renamed_reexport() { let lib2 = r#"