Skip to content

Commit c3a7118

Browse files
authored
gh-69605: Add module autocomplete to PyREPL (#129329)
1 parent 22c9886 commit c3a7118

File tree

5 files changed

+588
-1
lines changed

5 files changed

+588
-1
lines changed

Lib/_pyrepl/_module_completer.py

+377
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,377 @@
1+
from __future__ import annotations
2+
3+
import pkgutil
4+
import sys
5+
import tokenize
6+
from io import StringIO
7+
from contextlib import contextmanager
8+
from dataclasses import dataclass
9+
from itertools import chain
10+
from tokenize import TokenInfo
11+
12+
TYPE_CHECKING = False
13+
14+
if TYPE_CHECKING:
15+
from typing import Any, Iterable, Iterator, Mapping
16+
17+
18+
def make_default_module_completer() -> ModuleCompleter:
    """Build a ModuleCompleter whose namespace names '_pyrepl' as its package."""
    # Inside pyrepl, __package__ is set to '_pyrepl'
    pyrepl_namespace = {'__package__': '_pyrepl'}
    return ModuleCompleter(namespace=pyrepl_namespace)
21+
22+
23+
class ModuleCompleter:
    """A completer for Python import statements.

    Examples:
        - import <tab>
        - import foo<tab>
        - import foo.<tab>
        - import foo as bar, baz<tab>

        - from <tab>
        - from foo<tab>
        - from foo import <tab>
        - from foo import bar<tab>
        - from foo import (bar as baz, qux<tab>
    """

    def __init__(self, namespace: Mapping[str, Any] | None = None) -> None:
        # Only the '__package__' entry of the namespace is consulted; it is
        # the anchor used to resolve relative names such as '.foo'.
        self.namespace = namespace or {}
        self._global_cache: list[pkgutil.ModuleInfo] = []
        # Snapshot of sys.path used to detect when the cache is stale.
        self._curr_sys_path: list[str] = sys.path[:]

    def get_completions(self, line: str) -> list[str]:
        """Return the next possible import completions for 'line'.

        An empty list is returned when 'line' is not a recognised
        (partial) import statement or when the lookup fails.
        """
        result = ImportParser(line).parse()
        if not result:
            return []
        try:
            return self.complete(*result)
        except Exception:
            # Some unexpected error occurred, make it look like
            # no completions are available
            return []

    def complete(self, from_name: str | None, name: str | None) -> list[str]:
        """Return completions for a parsed (from_name, name) pair.

        - from_name is None:  'import x.y.z<tab>'      -> full dotted names
        - name is None:       'from x.y.z<tab>'        -> full dotted names
        - both given:         'from x.y import z<tab>' -> bare submodule names
        """
        if from_name is not None and name is not None:
            # from x.y import z<tab>
            return self.find_modules(from_name, name)

        if from_name is None:
            # import x.y.z<tab>
            assert name is not None
            dotted_name = name
        else:
            # from x.y.z<tab>
            dotted_name = from_name
        path, prefix = self.get_path_and_prefix(dotted_name)
        return [self.format_completion(path, module)
                for module in self.find_modules(path, prefix)]

    def find_modules(self, path: str, prefix: str) -> list[str]:
        """Find all modules under 'path' that start with 'prefix'."""
        modules = self._find_modules(path, prefix)
        # Filter out invalid module names
        # (for example those containing dashes that cannot be imported with 'import')
        return [mod for mod in modules if mod.isidentifier()]

    def _find_modules(self, path: str, prefix: str) -> list[str]:
        """Unfiltered lookup behind find_modules()."""
        if not path:
            # Top-level import (e.g. `import foo<tab>` or `from foo<tab>`)
            return [name for _, name, _ in self.global_cache
                    if name.startswith(prefix)]

        if path.startswith('.'):
            # Convert relative path to absolute path.  '__package__' may be
            # missing, empty or None; all of these mean there is no parent
            # package to resolve against, so nothing can be completed.
            package = self.namespace.get('__package__') or ''
            resolved = self.resolve_relative_name(path, package)
            if resolved is None:
                return []
            path = resolved

        # Walk down the dotted path one package at a time, replacing the
        # candidate set by the submodules of the matching package(s).
        modules: Iterable[pkgutil.ModuleInfo] = self.global_cache
        for segment in path.split('.'):
            modules = [mod_info for mod_info in modules
                       if mod_info.ispkg and mod_info.name == segment]
            modules = self.iter_submodules(modules)
        return [module.name for module in modules
                if module.name.startswith(prefix)]

    def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
        """Iterate over all submodules of the given parent modules."""
        specs = [info.module_finder.find_spec(info.name, None)
                 for info in parent_modules if info.ispkg]
        search_locations = set(chain.from_iterable(
            # A spec may carry submodule_search_locations=None; treat that
            # like "no locations" instead of failing while iterating it.
            getattr(spec, 'submodule_search_locations', None) or ()
            for spec in specs if spec
        ))
        return pkgutil.iter_modules(search_locations)

    def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
        """
        Split a dotted name into an import path and a
        final prefix that is to be completed.

        Examples:
            'foo.bar' -> 'foo', 'bar'
            'foo.' -> 'foo', ''
            '.foo' -> '.', 'foo'
        """
        # Leading dots (relative import level) always stay with the path.
        stripped = dotted_name.lstrip('.')
        dots = dotted_name[:len(dotted_name) - len(stripped)]
        path, _, prefix = stripped.rpartition('.')
        return dots + path, prefix

    def format_completion(self, path: str, module: str) -> str:
        """Join 'path' and 'module' without doubling the separating dot."""
        if path == '' or path.endswith('.'):
            return f'{path}{module}'
        return f'{path}.{module}'

    def resolve_relative_name(self, name: str, package: str) -> str | None:
        """Resolve a relative module name to an absolute name.

        'name' is expected to start with one or more dots.  Return None
        when there is no package to resolve against or when the name
        climbs above the top-level package.

        Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'
        """
        if not package:
            # Without a parent package the result would still be relative
            # (e.g. '.foo'), which is not a usable absolute name.
            return None
        # taken from importlib._bootstrap
        level = len(name) - len(name.lstrip('.'))
        bits = package.rsplit('.', level - 1)
        if len(bits) < level:
            return None
        base = bits[0]
        name = name[level:]
        return f'{base}.{name}' if name else base

    @property
    def global_cache(self) -> list[pkgutil.ModuleInfo]:
        """Global module cache"""
        # Rebuild when nothing is cached yet or sys.path changed since the
        # last scan.
        if not self._global_cache or self._curr_sys_path != sys.path:
            self._curr_sys_path = sys.path[:]
            self._global_cache = list(pkgutil.iter_modules())
        return self._global_cache
164+
165+
166+
class ImportParser:
    """
    Parses incomplete import statements that are
    suitable for autocomplete suggestions.

    Examples:
        - import foo          -> Result(from_name=None, name='foo')
        - import foo.         -> Result(from_name=None, name='foo.')
        - from foo            -> Result(from_name='foo', name=None)
        - from foo import bar -> Result(from_name='foo', name='bar')
        - from .foo import (  -> Result(from_name='.foo', name='')

    Note that the parser works in reverse order, starting from the
    last token in the input string. This makes the parser more robust
    when parsing multiple statements.
    """
    _ignored_tokens = {
        tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT,
        tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER
    }
    _keywords = {'import', 'from', 'as'}
    # The tokenizer emits three consecutive dots as one '...' token, so a
    # run of dots can arrive as any mix of these two strings.
    _dot_tokens = {'.', '...'}

    def __init__(self, code: str) -> None:
        self.code = code
        tokens = []
        try:
            for t in tokenize.generate_tokens(StringIO(code).readline):
                if t.type not in self._ignored_tokens:
                    tokens.append(t)
        except tokenize.TokenError as e:
            if 'unexpected EOF' not in str(e):
                # unexpected EOF is fine, since we're parsing an
                # incomplete statement, but other errors are not
                # because we may not have all the tokens so it's
                # safer to bail out
                tokens = []
        except SyntaxError:
            tokens = []
        # Reversed so that the queue yields the last token first.
        self.tokens = TokenQueue(tokens[::-1])

    def parse(self) -> tuple[str | None, str | None] | None:
        """Return (from_name, name), or None if no import is being typed."""
        if not (res := self._parse()):
            return None
        return res.from_name, res.name

    def _parse(self) -> Result | None:
        # save_state() swallows ParseError and rewinds the queue, so each
        # 'with' block below is only reached when the previous one failed.
        with self.tokens.save_state():
            return self.parse_from_import()
        with self.tokens.save_state():
            return self.parse_import()

    def _ends_with_keyword(self, keyword: str) -> bool:
        """Tell whether the code ends with the token 'keyword' and a space.

        The token check keeps identifiers that merely end with the keyword
        (e.g. 'myimport ') and comment text from being mistaken for it.
        """
        return (self.code.endswith(' ')
                and self.code.rstrip().endswith(keyword)
                and self.tokens.peek_string(keyword))

    def _peek_dots(self) -> bool:
        """Tell whether the next token is a '.' or '...' token."""
        tok = self.tokens.peek()
        return tok is not None and tok.string in self._dot_tokens

    def _peek_module_name(self) -> bool:
        """Tell whether the next token is a NAME other than import/from/as."""
        tok = self.tokens.peek()
        return (tok is not None
                and tok.type == tokenize.NAME
                and tok.string not in self._keywords)

    def parse_import(self) -> Result:
        """Parse the tail of 'import a, b.c as d, e<tab>'.

        Raises ParseError when the code does not end in such a statement.
        """
        if self._ends_with_keyword('import'):
            return Result(name='')
        if self.tokens.peek_string(','):
            name = ''
        else:
            if self.code.endswith(' '):
                # A space after a name means that name is already complete.
                raise ParseError('parse_import')
            name = self.parse_dotted_name()
        if name.startswith('.'):
            # Relative names are only valid in 'from ... import'.
            raise ParseError('parse_import')
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_dotted_as_name()
        if self.tokens.peek_string('import'):
            return Result(name=name)
        raise ParseError('parse_import')

    def parse_from_import(self) -> Result:
        """Parse the tail of 'from a.b<tab>' or 'from a.b import (c, d<tab>'.

        Raises ParseError when the code does not end in such a statement.
        """
        if self._ends_with_keyword('import'):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self._ends_with_keyword('from'):
            return Result(from_name='')
        if self.tokens.peek_string('(') or self.tokens.peek_string(','):
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self.code.endswith(' '):
            # A space after a name means that name is already complete.
            raise ParseError('parse_from_import')
        name = self.parse_dotted_name()
        if '.' in name:
            # Dotted names cannot follow 'import' here, so this must be
            # the module part: 'from a.b<tab>'.
            self.tokens.pop_string('from')
            return Result(from_name=name)
        if self.tokens.peek_string('from'):
            return Result(from_name=name)
        from_name = self.parse_empty_from_import()
        return Result(from_name=from_name, name=name)

    def parse_empty_from_import(self) -> str:
        """Consume '[names ,] [(] import <module> from'; return the module."""
        if self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_names()
        if self.tokens.peek_string('('):
            self.tokens.pop()
        self.tokens.pop_string('import')
        return self.parse_from()

    def parse_from(self) -> str:
        """Consume '<module> from' and return the module name."""
        from_name = self.parse_dotted_name()
        self.tokens.pop_string('from')
        return from_name

    def parse_dotted_as_name(self) -> str | None:
        """Consume one 'a.b.c [as d]' item of an import list.

        Returns the dotted remainder that was consumed after the first
        name, or None when there was none (the queue is then rewound to
        just after that first name).
        """
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
        with self.tokens.save_state():
            return self.parse_dotted_name()
        return None

    def parse_dotted_name(self) -> str:
        """Consume a (possibly relative, possibly unfinished) dotted name.

        Tokens are read back to front, so the pieces are collected in
        reverse and flipped at the end.  Raises ParseError if the next
        token is neither a dot nor a non-keyword name.
        """
        name = []
        if self._peek_dots():
            # Trailing dot of an unfinished name such as 'foo.'
            name.append(self.tokens.pop().string)
        if self._peek_module_name():
            name.append(self.tokens.pop_name())
        if not name:
            raise ParseError('parse_dotted_name')
        while self._peek_dots():
            name.append(self.tokens.pop().string)
            if self._peek_module_name():
                name.append(self.tokens.pop_name())
            else:
                break

        # Remaining leading dots of a relative name ('..foo').
        while self._peek_dots():
            name.append(self.tokens.pop().string)
        return ''.join(name[::-1])

    def parse_as_names(self) -> None:
        """Consume a comma-separated list of 'name [as alias]' items."""
        self.parse_as_name()
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_name()

    def parse_as_name(self) -> None:
        """Consume a single 'name [as alias]' item."""
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
            self.tokens.pop_name()
314+
class ParseError(Exception):
    """Raised when the token stream does not match the expected import syntax."""
316+
317+
318+
@dataclass(frozen=True)
class Result:
    """Immutable outcome of parsing a partial import statement."""

    # Module named after 'from' (may be relative, e.g. '.foo');
    # None for a plain 'import ...' statement.
    from_name: str | None = None
    # Name being completed after 'import'; None while the module
    # after 'from' is still being typed.
    name: str | None = None
322+
323+
324+
class TokenQueue:
    """Cursor over a list of tokens with checkpoint/rewind helpers."""

    def __init__(self, tokens: list[TokenInfo]) -> None:
        self.tokens: list[TokenInfo] = tokens
        # Position of the next token to be read.
        self.index: int = 0
        # Positions remembered by the currently active save_state() blocks.
        self.stack: list[int] = []

    @contextmanager
    def save_state(self) -> Any:
        """Remember the position; rewind to it if the block raises ParseError.

        The ParseError is swallowed, so execution continues after the
        'with' block.  Other exceptions propagate unchanged.
        """
        self.stack.append(self.index)
        try:
            yield
        except ParseError:
            self.index = self.stack.pop()
        else:
            self.stack.pop()

    def __bool__(self) -> bool:
        """True while there are unread tokens."""
        return len(self.tokens) > self.index

    def peek(self) -> TokenInfo | None:
        """Return the next token without consuming it, or None at the end."""
        return self.tokens[self.index] if self else None

    def peek_name(self) -> bool:
        """Tell whether the next token is a NAME token."""
        upcoming = self.peek()
        return upcoming is not None and upcoming.type == tokenize.NAME

    def pop_name(self) -> str:
        """Consume the next token and return its text; it must be a NAME."""
        consumed = self.pop()
        if consumed.type == tokenize.NAME:
            return consumed.string
        raise ParseError('pop_name')

    def peek_string(self, string: str) -> bool:
        """Tell whether the next token's text equals 'string'."""
        upcoming = self.peek()
        return upcoming is not None and upcoming.string == string

    def pop_string(self, string: str) -> str:
        """Consume the next token and return its text; it must equal 'string'."""
        consumed = self.pop()
        if consumed.string == string:
            return consumed.string
        raise ParseError('pop_string')

    def pop(self) -> TokenInfo:
        """Consume and return the next token; ParseError if none are left."""
        current = self.peek()
        if current is None:
            raise ParseError('pop')
        self.index += 1
        return current

Lib/_pyrepl/completing_reader.py

+4
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,7 @@ def get_stem(self) -> str:
293293

294294
def get_completions(self, stem: str) -> list[str]:
    """Return completion candidates for 'stem'; this implementation has none."""
    return list()
296+
297+
def get_line(self) -> str:
    """Return the buffer contents from the start up to the cursor position."""
    before_cursor = self.buffer[:self.pos]
    return ''.join(before_cursor)

0 commit comments

Comments
 (0)