Skip to content

Attempt to statically collect types in typeshed (ahead of time) #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 137 additions & 41 deletions src/docstub/_analysis.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Collect type information."""

import builtins
import importlib
import json
import logging
Expand Down Expand Up @@ -226,27 +225,6 @@ def _is_type(value):
return is_type


def _builtin_types():
"""Return known imports for all builtins (in the current runtime).

Returns
-------
known_imports : dict[str, KnownImport]
"""
known_builtins = set(dir(builtins))

known_imports = {}
for name in known_builtins:
if name.startswith("_"):
continue
value = getattr(builtins, name)
if not _is_type(value):
continue
known_imports[name] = KnownImport(builtin_name=name)

return known_imports


def _runtime_types_in_module(module_name):
module = importlib.import_module(module_name)
types = {}
Expand Down Expand Up @@ -277,18 +255,20 @@ def common_known_types():
Examples
--------
>>> types = common_known_types()
>>> types["str"]
<KnownImport str (builtin)>
>>> types["Iterable"]
<KnownImport 'from collections.abc import Iterable'>
>>> types["builtins.str"]
<KnownImport 'from builtins import str'>
>>> types["typing.Iterable"]
<KnownImport 'from typing import Iterable'>
>>> types["collections.abc.Iterable"]
<KnownImport 'from collections.abc import Iterable'>
"""
known_imports = _builtin_types()
known_imports |= _runtime_types_in_module("typing")
# Overrides containers from typing
known_imports |= _runtime_types_in_module("collections.abc")
return known_imports
from ._stdlib_types import stdlib_types

types = {
f"{module}.{type_name}": KnownImport(import_path=module, import_name=type_name)
for module, type_name in stdlib_types
}
return types


class TypeCollector(cst.CSTVisitor):
Expand Down Expand Up @@ -334,7 +314,7 @@ def collect(cls, file):

Returns
-------
collected : dict[str, KnownImport]
collected_types : dict[str, KnownImport]
"""
file = Path(file)
with file.open("r") as fo:
Expand All @@ -343,7 +323,7 @@ def collect(cls, file):
tree = cst.parse_module(source)
collector = cls(module_name=module_name_from_path(file))
tree.visit(collector)
return collector.known_imports
return collector.collected_types

def __init__(self, *, module_name):
"""Initialize type collector.
Expand All @@ -354,7 +334,7 @@ def __init__(self, *, module_name):
"""
self.module_name = module_name
self._stack = []
self.known_imports = {}
self.collected_types = {}

def visit_ClassDef(self, node: cst.ClassDef) -> bool:
self._stack.append(node.name.value)
Expand Down Expand Up @@ -396,9 +376,104 @@ def _collect_type_annotation(self, stack):
stack : Iterable[str]
A list of names that form the path to the collected type.
"""
qualname = ".".join([self.module_name, *stack])
known_import = KnownImport(import_path=self.module_name, import_name=stack[0])
self.known_imports[qualname] = known_import

qualname = f"{self.module_name}.{'.'.join(stack)}"
scoped_name = f"{self.module_name}:{'.'.join(stack)}"
self.collected_types[qualname] = known_import
self.collected_types[scoped_name] = known_import


class StubTypeCollector(TypeCollector):

def __init__(self, *, module_name):
"""Initialize type collector.

Parameters
----------
module_name : str
"""
super().__init__(module_name=module_name)
self.collected_types = set()
self.dunder_all = set()

@classmethod
def collect(cls, file):
"""Collect importable type annotations in given file.

Parameters
----------
file : Path

Returns
-------
collected_types : dict[str, KnownImport]
"""
file = Path(file)
with file.open("r") as fo:
source = fo.read()

tree = cst.parse_module(source)
collector = cls(module_name=module_name_from_path(file))
tree.visit(collector)
return collector.collected_types, collector.dunder_all

def visit_ImportFrom(self, node):
# https://typing.python.org/en/latest/spec/distributing.html#import-conventions

if cstm.matches(node, cstm.ImportFrom(names=cstm.ImportStar())):
module_names = cstm.findall(node.module, cstm.Name())
module = "_".join(name.value for name in module_names)
stack = [*self._stack, f"<Reference: {module}.*>"]
self._collect_type_annotation(stack)

names = cstm.findall(node, cstm.AsName())
for name in names:
if cstm.matches(name, cstm.AsName(name=cstm.Name())):
value = name.name.value
assert value
if value == "__all__":
continue

stack = [*self._stack, value]
self._collect_type_annotation(stack)

def visit_AugAssign(self, node):
is_add_assign_to_dunder_all = cstm.matches(
node,
cstm.AugAssign(
target=cstm.Name(value="__all__"), operator=cstm.AddAssign()
),
)
is_assign_list = cstm.matches(node.value, cstm.List())
if is_add_assign_to_dunder_all and is_assign_list:
strings = cstm.findall(node.value, cstm.SimpleString())
for string in strings:
self._collect_dunder_all(string.value)

def visit_Assign(self, node):
is_assign_to_dunder_all = cstm.matches(
node,
cstm.Assign(targets=[cstm.AssignTarget(target=cstm.Name(value="__all__"))]),
)
is_assign_list = cstm.matches(node.value, cstm.List())
if is_assign_to_dunder_all and is_assign_list:
strings = cstm.findall(node.value, cstm.SimpleString())
for string in strings:
self._collect_dunder_all(string.value)

def _collect_type_annotation(self, stack):
"""Collect an importable type annotation.

Parameters
----------
stack : Iterable[str]
A list of names that form the path to the collected type.
"""
self.collected_types.add((self.module_name, ".".join(stack)))

def _collect_dunder_all(self, value):
self.dunder_all.add((self.module_name, value.strip("'\"")))


class TypeMatcher:
Expand Down Expand Up @@ -427,6 +502,7 @@ def __init__(
types=None,
type_prefixes=None,
type_nicknames=None,
implicit_modules=("collections.abc", "typing", "_typeshed"),
):
"""
Parameters
Expand All @@ -438,6 +514,7 @@ def __init__(
self.types = types or common_known_types()
self.type_prefixes = type_prefixes or {}
self.type_nicknames = type_nicknames or {}
self.implicit_modules = implicit_modules
self.successful_queries = 0
self.unknown_qualnames = []

Expand Down Expand Up @@ -492,20 +569,39 @@ def match(self, search_name):
# Replace alias
search_name = self.type_nicknames.get(search_name, search_name)

if type_origin is None and self.current_module:
# Try scope of current module
module_name = module_name_from_path(self.current_module)
try_qualname = f"{module_name}.{search_name}"
if type_origin is None:
# Try builtin
try_qualname = f"builtins.{search_name}"
type_origin = self.types.get(try_qualname)
if type_origin:
type_name = search_name

if type_origin is None and search_name in self.types:
# Direct match
type_name = search_name
type_origin = self.types[search_name]

if type_origin is None and self.current_module:
# Try scope of current module
for sep in [".", ":"]:
try_qualname = f"{self.current_module}{sep}{search_name}"
type_origin = self.types.get(try_qualname)
if type_origin:
type_name = search_name
break

if type_origin is None and self.implicit_modules:
# Try implicit modules
for module in self.implicit_modules:
try_qualname = f"{module}.{search_name}"
type_origin = self.types.get(try_qualname)
if type_origin:
type_name = search_name
break

if type_origin is None:
# Try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
# Try matching with module prefix,
# try a subset of the qualname (first 'a.b.c', then 'a.b' and 'a')
for partial_qualname in reversed(accumulate_qualname(search_name)):
type_origin = self.type_prefixes.get(partial_qualname)
if type_origin:
Expand Down
4 changes: 2 additions & 2 deletions src/docstub/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ def _collect_types(root_path):
-------
types : dict[str, ~.KnownImport]
"""
types = common_known_types()

types = {}
collect_cached_types = FileCache(
func=TypeCollector.collect,
serializer=TypeCollector.ImportSerializer(),
cache_dir=Path.cwd() / ".docstub_cache",
name=f"{__version__}/collected_types",
)

if root_path.is_dir():
for source_path in walk_python_package(root_path):
logger.info("collecting types in %s", source_path)
Expand Down
Loading
Loading