|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import pkgutil |
| 4 | +import sys |
| 5 | +import tokenize |
| 6 | +from io import StringIO |
| 7 | +from contextlib import contextmanager |
| 8 | +from dataclasses import dataclass |
| 9 | +from itertools import chain |
| 10 | +from tokenize import TokenInfo |
| 11 | + |
# Lightweight stand-in for typing.TYPE_CHECKING: keeps 'typing' from being
# imported at runtime.  The names below are only needed by static type
# checkers (annotations are lazy thanks to 'from __future__ import annotations').
TYPE_CHECKING = False

if TYPE_CHECKING:
    from typing import Any, Iterable, Iterator, Mapping
| 16 | + |
| 17 | + |
def make_default_module_completer() -> ModuleCompleter:
    """Build a ModuleCompleter configured for use inside pyrepl.

    Inside pyrepl, __package__ is set to '_pyrepl', so relative imports
    typed at the prompt resolve against that package.
    """
    default_namespace = {'__package__': '_pyrepl'}
    return ModuleCompleter(namespace=default_namespace)
| 21 | + |
| 22 | + |
class ModuleCompleter:
    """A completer for Python import statements.

    Examples:
        - import <tab>
        - import foo<tab>
        - import foo.<tab>
        - import foo as bar, baz<tab>

        - from <tab>
        - from foo<tab>
        - from foo import <tab>
        - from foo import bar<tab>
        - from foo import (bar as baz, qux<tab>
    """

    def __init__(self, namespace: Mapping[str, Any] | None = None) -> None:
        # 'namespace' supplies globals (notably __package__) used to
        # resolve relative imports; defaults to an empty mapping.
        self.namespace = namespace or {}
        # Cached result of pkgutil.iter_modules(); rebuilt when sys.path
        # changes (see the 'global_cache' property below).
        self._global_cache: list[pkgutil.ModuleInfo] = []
        self._curr_sys_path: list[str] = sys.path[:]

    def get_completions(self, line: str) -> list[str]:
        """Return the next possible import completions for 'line'."""
        result = ImportParser(line).parse()
        if not result:
            # 'line' is not an (incomplete) import statement.
            return []
        try:
            return self.complete(*result)
        except Exception:
            # Some unexpected error occurred, make it look like
            # no completions are available
            return []

    def complete(self, from_name: str | None, name: str | None) -> list[str]:
        """Dispatch on the (from_name, name) pair produced by ImportParser."""
        if from_name is None:
            # import x.y.z<tab>
            assert name is not None
            path, prefix = self.get_path_and_prefix(name)
            modules = self.find_modules(path, prefix)
            return [self.format_completion(path, module) for module in modules]

        if name is None:
            # from x.y.z<tab>
            path, prefix = self.get_path_and_prefix(from_name)
            modules = self.find_modules(path, prefix)
            return [self.format_completion(path, module) for module in modules]

        # from x.y import z<tab>
        return self.find_modules(from_name, name)

    def find_modules(self, path: str, prefix: str) -> list[str]:
        """Find all modules under 'path' that start with 'prefix'."""
        modules = self._find_modules(path, prefix)
        # Filter out invalid module names
        # (for example those containing dashes that cannot be imported with 'import')
        return [mod for mod in modules if mod.isidentifier()]

    def _find_modules(self, path: str, prefix: str) -> list[str]:
        """Return candidate module names under 'path' matching 'prefix'."""
        if not path:
            # Top-level import (e.g. 'import foo<tab>' or 'from foo<tab>')
            return [name for _, name, _ in self.global_cache
                    if name.startswith(prefix)]

        if path.startswith('.'):
            # Convert relative path to absolute path
            package = self.namespace.get('__package__', '')
            path = self.resolve_relative_name(path, package)  # type: ignore[assignment]
            if path is None:
                # The relative import climbs above the top-level package.
                return []

        # Walk the dotted path one segment at a time, narrowing the
        # candidate set to the submodules of each matched package.
        modules: Iterable[pkgutil.ModuleInfo] = self.global_cache
        for segment in path.split('.'):
            modules = [mod_info for mod_info in modules
                       if mod_info.ispkg and mod_info.name == segment]
            modules = self.iter_submodules(modules)
        return [module.name for module in modules
                if module.name.startswith(prefix)]

    def iter_submodules(self, parent_modules: list[pkgutil.ModuleInfo]) -> Iterator[pkgutil.ModuleInfo]:
        """Iterate over all submodules of the given parent modules."""
        specs = [info.module_finder.find_spec(info.name, None)
                 for info in parent_modules if info.ispkg]
        # Collect every package's search locations, then scan them for modules.
        search_locations = set(chain.from_iterable(
            getattr(spec, 'submodule_search_locations', [])
            for spec in specs if spec
        ))
        return pkgutil.iter_modules(search_locations)

    def get_path_and_prefix(self, dotted_name: str) -> tuple[str, str]:
        """
        Split a dotted name into an import path and a
        final prefix that is to be completed.

        Examples:
            'foo.bar' -> 'foo', 'bar'
            'foo.' -> 'foo', ''
            '.foo' -> '.', 'foo'
        """
        if '.' not in dotted_name:
            return '', dotted_name
        if dotted_name.startswith('.'):
            # Keep the leading dots (the relative-import level) on the
            # path side of the split.
            stripped = dotted_name.lstrip('.')
            dots = '.' * (len(dotted_name) - len(stripped))
            if '.' not in stripped:
                return dots, stripped
            path, prefix = stripped.rsplit('.', 1)
            return dots + path, prefix
        path, prefix = dotted_name.rsplit('.', 1)
        return path, prefix

    def format_completion(self, path: str, module: str) -> str:
        """Join 'path' and 'module' without doubling the dot separator."""
        # A path ending in '.' happens for relative imports such as '.foo'.
        if path == '' or path.endswith('.'):
            return f'{path}{module}'
        return f'{path}.{module}'

    def resolve_relative_name(self, name: str, package: str) -> str | None:
        """Resolve a relative module name to an absolute name.

        Example: resolve_relative_name('.foo', 'bar') -> 'bar.foo'

        Returns None when 'name' has more leading dots than 'package'
        has dotted levels.
        """
        # taken from importlib._bootstrap
        level = 0
        for character in name:
            if character != '.':
                break
            level += 1
        bits = package.rsplit('.', level - 1)
        if len(bits) < level:
            return None
        base = bits[0]
        name = name[level:]
        return f'{base}.{name}' if name else base

    @property
    def global_cache(self) -> list[pkgutil.ModuleInfo]:
        """Global module cache, rebuilt whenever sys.path changes."""
        if not self._global_cache or self._curr_sys_path != sys.path:
            self._curr_sys_path = sys.path[:]
            self._global_cache = list(pkgutil.iter_modules())
        return self._global_cache
| 164 | + |
| 165 | + |
class ImportParser:
    """
    Parses incomplete import statements that are
    suitable for autocomplete suggestions.

    Examples:
        - import foo -> Result(from_name=None, name='foo')
        - import foo. -> Result(from_name=None, name='foo.')
        - from foo -> Result(from_name='foo', name=None)
        - from foo import bar -> Result(from_name='foo', name='bar')
        - from .foo import ( -> Result(from_name='.foo', name='')

    Note that the parser works in reverse order, starting from the
    last token in the input string. This makes the parser more robust
    when parsing multiple statements.
    """
    # Token types that carry no information for import completion.
    _ignored_tokens = {
        tokenize.INDENT, tokenize.DEDENT, tokenize.COMMENT,
        tokenize.NL, tokenize.NEWLINE, tokenize.ENDMARKER
    }
    _keywords = {'import', 'from', 'as'}

    def __init__(self, code: str) -> None:
        self.code = code
        tokens = []
        try:
            for t in tokenize.generate_tokens(StringIO(code).readline):
                if t.type not in self._ignored_tokens:
                    tokens.append(t)
        except tokenize.TokenError as e:
            if 'unexpected EOF' not in str(e):
                # unexpected EOF is fine, since we're parsing an
                # incomplete statement, but other errors are not
                # because we may not have all the tokens so it's
                # safer to bail out
                tokens = []
        except SyntaxError:
            tokens = []
        # The queue holds the tokens REVERSED: all parse_* methods below
        # consume from the end of the input towards the beginning.
        self.tokens = TokenQueue(tokens[::-1])

    def parse(self) -> tuple[str | None, str | None] | None:
        """Return (from_name, name) for the trailing import, or None."""
        if not (res := self._parse()):
            return None
        return res.from_name, res.name

    def _parse(self) -> Result | None:
        # save_state() swallows ParseError and rewinds the token queue,
        # so a failed 'from ... import ...' attempt falls through to the
        # plain 'import ...' attempt; if both fail, None is returned.
        with self.tokens.save_state():
            return self.parse_from_import()
        with self.tokens.save_state():
            return self.parse_import()

    def parse_import(self) -> Result:
        # Handles 'import <tab>' and 'import a.b, c<tab>' forms.
        # Remember: tokens are consumed starting from the END of the input.
        if self.code.rstrip().endswith('import') and self.code.endswith(' '):
            return Result(name='')
        if self.tokens.peek_string(','):
            # 'import foo,<tab>' -- completing a fresh (empty) name.
            name = ''
        else:
            if self.code.endswith(' '):
                raise ParseError('parse_import')
            name = self.parse_dotted_name()
            if name.startswith('.'):
                # Relative names are only valid after 'from'.
                raise ParseError('parse_import')
        # Skip over any already-complete 'x as y' items before the cursor.
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_dotted_as_name()
        if self.tokens.peek_string('import'):
            return Result(name=name)
        raise ParseError('parse_import')

    def parse_from_import(self) -> Result:
        # Handles every 'from x import y' form (see the class docstring).
        stripped = self.code.rstrip()
        if stripped.endswith('import') and self.code.endswith(' '):
            # 'from x import <tab>'
            return Result(from_name=self.parse_empty_from_import(), name='')
        if stripped.endswith('from') and self.code.endswith(' '):
            # 'from <tab>'
            return Result(from_name='')
        if self.tokens.peek_string('(') or self.tokens.peek_string(','):
            # 'from x import (<tab>' or 'from x import a,<tab>'
            return Result(from_name=self.parse_empty_from_import(), name='')
        if self.code.endswith(' '):
            raise ParseError('parse_from_import')
        name = self.parse_dotted_name()
        if '.' in name:
            # A dotted name at the cursor can only be the module part:
            # 'from x.y<tab>'.
            self.tokens.pop_string('from')
            return Result(from_name=name)
        if self.tokens.peek_string('from'):
            # 'from x<tab>'
            return Result(from_name=name)
        # 'from x import y<tab>'
        from_name = self.parse_empty_from_import()
        return Result(from_name=from_name, name=name)

    def parse_empty_from_import(self) -> str:
        # Consume (in reverse) any completed imported names back through
        # the 'import' keyword, then parse the 'from <module>' part.
        if self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_names()
        if self.tokens.peek_string('('):
            self.tokens.pop()
        self.tokens.pop_string('import')
        return self.parse_from()

    def parse_from(self) -> str:
        # In reverse order this reads '<module> from'.
        from_name = self.parse_dotted_name()
        self.tokens.pop_string('from')
        return from_name

    def parse_dotted_as_name(self) -> str:
        # '<name> [as <alias>]' in reverse: the alias (or lone name) is
        # popped first, then 'as', then the dotted name proper.
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
        with self.tokens.save_state():
            # A failing parse here is swallowed by save_state; the method
            # then falls off the end and returns None.
            return self.parse_dotted_name()

    def parse_dotted_name(self) -> str:
        # Collect the pieces of 'a.b.c' (optionally with leading dots) in
        # reverse token order; the list is reversed again before joining.
        name = []
        if self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        if (self.tokens.peek_name()
                and (tok := self.tokens.peek())
                and tok.string not in self._keywords):
            name.append(self.tokens.pop_name())
        if not name:
            raise ParseError('parse_dotted_name')
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
            if (self.tokens.peek_name()
                    and (tok := self.tokens.peek())
                    and tok.string not in self._keywords):
                name.append(self.tokens.pop_name())
            else:
                break

        # The leading dots of a relative name ('..foo') come last when
        # reading backwards; gather them all here.
        while self.tokens.peek_string('.'):
            name.append('.')
            self.tokens.pop()
        return ''.join(name[::-1])

    def parse_as_names(self) -> None:
        # A comma-separated list of 'x [as y]' items, consumed in reverse.
        self.parse_as_name()
        while self.tokens.peek_string(','):
            self.tokens.pop()
            self.parse_as_name()

    def parse_as_name(self) -> None:
        # '<name> [as <alias>]' in reverse: alias first, then 'as', then name.
        self.tokens.pop_name()
        if self.tokens.peek_string('as'):
            self.tokens.pop()
            self.tokens.pop_name()
| 313 | + |
class ParseError(Exception):
    """Signals that the token stream cannot be read as an import statement."""
| 316 | + |
| 317 | + |
@dataclass(frozen=True)
class Result:
    """Parsed form of an (incomplete) import statement.

    from_name: the module named after 'from', or None for a plain
        'import ...' statement.
    name: the (possibly empty) name being completed, or None when the
        cursor sits right after 'from'.
    """
    from_name: str | None = None
    name: str | None = None
| 322 | + |
| 323 | + |
class TokenQueue:
    """Provides helper functions for working with a sequence of tokens."""

    def __init__(self, tokens: list[TokenInfo]) -> None:
        self.tokens: list[TokenInfo] = tokens
        self.index: int = 0
        # Saved positions for backtracking (see save_state).
        self.stack: list[int] = []

    @contextmanager
    def save_state(self) -> Any:
        # Remember the current position; if a ParseError escapes the
        # 'with' body, swallow it and rewind so the caller can try an
        # alternative parse.  On success the saved position is discarded.
        self.stack.append(self.index)
        try:
            yield
        except ParseError:
            self.index = self.stack.pop()
        else:
            self.stack.pop()

    def __bool__(self) -> bool:
        # True while unconsumed tokens remain.
        return self.index < len(self.tokens)

    def peek(self) -> TokenInfo | None:
        """Return the next token without consuming it, or None at the end."""
        return self.tokens[self.index] if self else None

    def peek_name(self) -> bool:
        """True if the next token is a NAME token."""
        tok = self.peek()
        return tok is not None and tok.type == tokenize.NAME

    def pop_name(self) -> str:
        """Consume the next token, which must be a NAME, and return its text."""
        tok = self.pop()
        if tok.type != tokenize.NAME:
            raise ParseError('pop_name')
        return tok.string

    def peek_string(self, string: str) -> bool:
        """True if the next token's text equals 'string'."""
        tok = self.peek()
        return tok is not None and tok.string == string

    def pop_string(self, string: str) -> str:
        """Consume the next token, which must spell 'string', and return it."""
        tok = self.pop()
        if tok.string != string:
            raise ParseError('pop_string')
        return tok.string

    def pop(self) -> TokenInfo:
        """Consume and return the next token; ParseError when exhausted."""
        if not self:
            raise ParseError('pop')
        tok = self.tokens[self.index]
        self.index += 1
        return tok
0 commit comments