Skip to content

Commit 532c399

Browse files
authored
make explicit the symbols that imports from the wrapper module into the friendly module (#469)
* add `Library` to generated modules' `__all__`. because that symbol is public but not included. * add `typelib_path` to generated modules' `__all__`. because that symbol is public but not included. * make `ModuleGenerator` class that encapsulates `CodeGenerator` instance. * rename to `generate_wrapper_code` from `generate_code` * add `generate_friendly_code` * add type annotations to `generate_wrapper_code` * add docstring * add `get_symbols` methods to `DeclaredNamespaces` and `ImportedNamespaces` * update imporing symbols * add type annotation to return value for `__init__` * change to using f-string in `generate_friendly_code`
1 parent e1ee6f0 commit 532c399

File tree

2 files changed

+138
-57
lines changed

2 files changed

+138
-57
lines changed

comtypes/client/_generate.py

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType:
121121
pathname = None
122122
tlib = _load_tlib(tlib)
123123
logger.debug("GetModule(%s)", tlib.GetLibAttr())
124-
# create and import the real typelib wrapper module
125-
mod = _create_wrapper_module(tlib, pathname)
126-
# try to get the friendly-name, if not, returns the real typelib wrapper module
127-
modulename = codegenerator.name_friendly_module(tlib)
128-
if modulename is None:
129-
return mod
130-
# create and import the friendly-named module
131-
return _create_friendly_module(tlib, modulename)
124+
return ModuleGenerator().generate(tlib, pathname)
132125

133126

134127
def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
@@ -193,52 +186,64 @@ def _create_module_in_memory(modulename: str, code: str) -> types.ModuleType:
193186
return mod
194187

195188

196-
def _create_friendly_module(
197-
tlib: typeinfo.ITypeLib, modulename: str
198-
) -> types.ModuleType:
199-
"""helper which creates and imports the friendly-named module."""
200-
try:
201-
mod = _my_import(modulename)
202-
except Exception as details:
203-
logger.info("Could not import %s: %s", modulename, details)
204-
else:
205-
return mod
206-
# the module is always regenerated if the import fails
207-
logger.info("# Generating %s", modulename)
208-
# determine the Python module name
209-
modname = codegenerator.name_wrapper_module(tlib).split(".")[-1]
210-
code = "from comtypes.gen import %s\n" % modname
211-
code += "globals().update(%s.__dict__)\n" % modname
212-
code += "__name__ = '%s'" % modulename
213-
if comtypes.client.gen_dir is None:
214-
return _create_module_in_memory(modulename, code)
215-
return _create_module_in_file(modulename, code)
216-
217-
218-
def _create_wrapper_module(
219-
tlib: typeinfo.ITypeLib, pathname: Optional[str]
220-
) -> types.ModuleType:
221-
"""helper which creates and imports the real typelib wrapper module."""
222-
modulename = codegenerator.name_wrapper_module(tlib)
223-
if modulename in sys.modules:
224-
return sys.modules[modulename]
225-
try:
226-
return _my_import(modulename)
227-
except Exception as details:
228-
logger.info("Could not import %s: %s", modulename, details)
229-
# generate the module since it doesn't exist or is out of date
230-
logger.info("# Generating %s", modulename)
231-
p = tlbparser.TypeLibParser(tlib)
232-
if pathname is None:
233-
pathname = tlbparser.get_tlib_filename(tlib)
234-
items = list(p.parse().values())
235-
codegen = codegenerator.CodeGenerator(_get_known_symbols())
236-
code = codegen.generate_code(items, filename=pathname)
237-
for ext_tlib in codegen.externals: # generates dependency COM-lib modules
238-
GetModule(ext_tlib)
239-
if comtypes.client.gen_dir is None:
240-
return _create_module_in_memory(modulename, code)
241-
return _create_module_in_file(modulename, code)
189+
class ModuleGenerator(object):
190+
def __init__(self) -> None:
191+
self.codegen = codegenerator.CodeGenerator(_get_known_symbols())
192+
193+
def generate(
194+
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
195+
) -> types.ModuleType:
196+
# create and import the real typelib wrapper module
197+
mod = self._create_wrapper_module(tlib, pathname)
198+
# try to get the friendly-name, if not, returns the real typelib wrapper module
199+
modulename = codegenerator.name_friendly_module(tlib)
200+
if modulename is None:
201+
return mod
202+
# create and import the friendly-named module
203+
return self._create_friendly_module(tlib, modulename)
204+
205+
def _create_friendly_module(
206+
self, tlib: typeinfo.ITypeLib, modulename: str
207+
) -> types.ModuleType:
208+
"""helper which creates and imports the friendly-named module."""
209+
try:
210+
mod = _my_import(modulename)
211+
except Exception as details:
212+
logger.info("Could not import %s: %s", modulename, details)
213+
else:
214+
return mod
215+
# the module is always regenerated if the import fails
216+
logger.info("# Generating %s", modulename)
217+
# determine the Python module name
218+
modname = codegenerator.name_wrapper_module(tlib)
219+
code = self.codegen.generate_friendly_code(modname)
220+
if comtypes.client.gen_dir is None:
221+
return _create_module_in_memory(modulename, code)
222+
return _create_module_in_file(modulename, code)
223+
224+
def _create_wrapper_module(
225+
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
226+
) -> types.ModuleType:
227+
"""helper which creates and imports the real typelib wrapper module."""
228+
modulename = codegenerator.name_wrapper_module(tlib)
229+
if modulename in sys.modules:
230+
return sys.modules[modulename]
231+
try:
232+
return _my_import(modulename)
233+
except Exception as details:
234+
logger.info("Could not import %s: %s", modulename, details)
235+
# generate the module since it doesn't exist or is out of date
236+
logger.info("# Generating %s", modulename)
237+
p = tlbparser.TypeLibParser(tlib)
238+
if pathname is None:
239+
pathname = tlbparser.get_tlib_filename(tlib)
240+
items = list(p.parse().values())
241+
code = self.codegen.generate_wrapper_code(items, filename=pathname)
242+
for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules
243+
GetModule(ext_tlib)
244+
if comtypes.client.gen_dir is None:
245+
return _create_module_in_memory(modulename, code)
246+
return _create_module_in_file(modulename, code)
242247

243248

244249
def _get_known_symbols() -> Dict[str, str]:

comtypes/tools/codegenerator.py

Lines changed: 79 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,17 @@
77
import os
88
import sys
99
import textwrap
10-
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union as _UnionT
10+
from typing import (
11+
Any,
12+
Dict,
13+
Iterator,
14+
List,
15+
Optional,
16+
Sequence,
17+
Set,
18+
Tuple,
19+
Union as _UnionT,
20+
)
1121
import io
1222

1323
import comtypes
@@ -495,9 +505,20 @@ def _generate_typelib_path(self, filename):
495505
os.path.abspath(os.path.join(comtypes.gen.__path__[0], path))
496506
)
497507
assert os.path.isfile(p)
508+
self.names.add("typelib_path")
509+
510+
def generate_wrapper_code(
511+
self, tdescs: Sequence[Any], filename: Optional[str]
512+
) -> str:
513+
"""Returns the code for the COM type library wrapper module.
498514
499-
def generate_code(self, items, filename):
515+
The returned `Python` code string is containing definitions of interfaces,
516+
coclasses, constants, and structures.
500517
518+
The module will have long name that is derived from the type library guid, lcid
519+
and version numbers.
520+
Such as `comtypes.gen._xxxxxxxx_xxxx_xxxx_xxxx_xxxxxxxxxxxx_l_M_m`.
521+
"""
501522
tlib_mtime = None
502523

503524
if filename is not None:
@@ -520,7 +541,7 @@ def generate_code(self, items, filename):
520541
self.declarations.add("_lcid", "0", "change this if required")
521542
self._generate_typelib_path(filename)
522543

523-
items = set(items)
544+
items = set(tdescs)
524545
loops = 0
525546
while items:
526547
loops += 1
@@ -557,6 +578,39 @@ def generate_code(self, items, filename):
557578
print("_check_version(%r, %f)" % (version, tlib_mtime), file=output)
558579
return output.getvalue()
559580

581+
def generate_friendly_code(self, modname: str) -> str:
582+
"""Returns the code for the COM type library friendly module.
583+
584+
The returned `Python` code string is containing `from {modname} import
585+
DefinedInWrapper, ...` and `__all__ = ['DefinedInWrapper', ...]`
586+
The `modname` is the wrapper module name like `comtypes.gen._xxxx..._x_x_x`.
587+
588+
The module will have shorter name that is derived from the type library name.
589+
Such as "comtypes.gen.stdole" and "comtypes.gen.Excel".
590+
"""
591+
output = io.StringIO()
592+
txtwrapper = textwrap.TextWrapper(
593+
subsequent_indent=" ", initial_indent=" ", break_long_words=False
594+
)
595+
importing_symbols = set(self.names)
596+
importing_symbols.update(self.imports.get_symbols())
597+
importing_symbols.update(self.declarations.get_symbols())
598+
joined_names = ", ".join(str(n) for n in importing_symbols)
599+
symbols = f"from {modname} import {joined_names}"
600+
if len(symbols) > 80:
601+
wrapped_names = "\n".join(txtwrapper.wrap(joined_names))
602+
symbols = f"from {modname} import (\n{wrapped_names}\n)"
603+
print(symbols, file=output)
604+
print(file=output)
605+
print(file=output)
606+
quoted_names = ", ".join(repr(str(n)) for n in self.names)
607+
dunder_all = f"__all__ = [{quoted_names}]"
608+
if len(dunder_all) > 80:
609+
wrapped_quoted_names = "\n".join(txtwrapper.wrap(quoted_names))
610+
dunder_all = f"__all__ = [\n{wrapped_quoted_names}\n]"
611+
print(dunder_all, file=output)
612+
return output.getvalue()
613+
560614
def need_VARIANT_imports(self, value):
561615
text = repr(value)
562616
if "Decimal(" in text:
@@ -876,6 +930,7 @@ def TypeLib(self, lib: typedesc.TypeLib) -> None:
876930
)
877931
print(file=self.stream)
878932
print(file=self.stream)
933+
self.names.add("Library")
879934

880935
def External(self, ext: typedesc.External) -> None:
881936
modname = name_wrapper_module(ext.tlib)
@@ -1329,6 +1384,10 @@ def add(self, name1, name2=None, symbols=None):
13291384
IUnknown
13301385
)
13311386
import ctypes.wintypes
1387+
>>> assert imports.get_symbols() == {
1388+
... 'Decimal', 'GUID', 'COMMETHOD', 'DISPMETHOD', 'IUnknown',
1389+
... 'dispid', 'CoClass', 'BSTR', 'DISPPROPERTY'
1390+
... }
13321391
>>> print(imports.getvalue(for_stub=True))
13331392
from ctypes import *
13341393
import datetime
@@ -1381,6 +1440,14 @@ def __contains__(self, item):
13811440
return self.data[import_] == from_
13821441
return False
13831442

1443+
def get_symbols(self) -> Set[str]:
1444+
names = set()
1445+
for key, val in self.data.items():
1446+
if val is None or key == "*":
1447+
continue
1448+
names.add(key)
1449+
return names
1450+
13841451
def _make_line(self, from_, imports, for_stub):
13851452
if for_stub:
13861453
import_ = ", ".join("%s as %s" % (n, n) for n in imports)
@@ -1432,9 +1499,18 @@ def add(self, alias, definition, comment=None):
14321499
>>> print(declarations.getvalue())
14331500
STRING = c_char_p
14341501
_lcid = 0 # change this if required
1502+
>>> assert declarations.get_symbols() == {
1503+
... 'STRING', '_lcid'
1504+
... }
14351505
"""
14361506
self.data[(alias, definition)] = comment
14371507

1508+
def get_symbols(self) -> Set[str]:
1509+
names = set()
1510+
for alias, _ in self.data.keys():
1511+
names.add(alias)
1512+
return names
1513+
14381514
def getvalue(self):
14391515
lines = []
14401516
for (alias, definition), comment in self.data.items():

0 commit comments

Comments
 (0)