Skip to content

Commit c7f1e3e

Browse files
authored
gh-127183: Add _ctypes.CopyComPointer tests (GH-127184)
* Make `create_shelllink_persist` top level function. * Add `CopyComPointerTests`. * Add more tests. * Update tests. * Add assertions for `Release`'s return value.
1 parent f4f075b commit c7f1e3e

File tree

1 file changed

+115
-17
lines changed

1 file changed

+115
-17
lines changed

Lib/test/test_ctypes/test_win32_com_foreign_func.py

+115-17
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
raise unittest.SkipTest("Windows-specific test")
1010

1111

12-
from _ctypes import COMError
12+
from _ctypes import COMError, CopyComPointer
1313
from ctypes import HRESULT
1414

1515

@@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2):
7878
)
7979

8080

81+
def create_shelllink_persist(typ):
82+
ppst = typ()
83+
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
84+
ole32.CoCreateInstance(
85+
byref(CLSID_ShellLink),
86+
None,
87+
CLSCTX_SERVER,
88+
byref(IID_IPersist),
89+
byref(ppst),
90+
)
91+
return ppst
92+
93+
8194
class ForeignFunctionsThatWillCallComMethodsTests(unittest.TestCase):
8295
def setUp(self):
8396
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
@@ -88,19 +101,6 @@ def tearDown(self):
88101
ole32.CoUninitialize()
89102
gc.collect()
90103

91-
@staticmethod
92-
def create_shelllink_persist(typ):
93-
ppst = typ()
94-
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
95-
ole32.CoCreateInstance(
96-
byref(CLSID_ShellLink),
97-
None,
98-
CLSCTX_SERVER,
99-
byref(IID_IPersist),
100-
byref(ppst),
101-
)
102-
return ppst
103-
104104
def test_without_paramflags_and_iid(self):
105105
class IUnknown(c_void_p):
106106
QueryInterface = proto_query_interface()
@@ -110,7 +110,7 @@ class IUnknown(c_void_p):
110110
class IPersist(IUnknown):
111111
GetClassID = proto_get_class_id()
112112

113-
ppst = self.create_shelllink_persist(IPersist)
113+
ppst = create_shelllink_persist(IPersist)
114114

115115
clsid = GUID()
116116
hr_getclsid = ppst.GetClassID(byref(clsid))
@@ -142,7 +142,7 @@ class IUnknown(c_void_p):
142142
class IPersist(IUnknown):
143143
GetClassID = proto_get_class_id(((OUT, "pClassID"),))
144144

145-
ppst = self.create_shelllink_persist(IPersist)
145+
ppst = create_shelllink_persist(IPersist)
146146

147147
clsid = ppst.GetClassID()
148148
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
@@ -167,7 +167,7 @@ class IUnknown(c_void_p):
167167
class IPersist(IUnknown):
168168
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)
169169

170-
ppst = self.create_shelllink_persist(IPersist)
170+
ppst = create_shelllink_persist(IPersist)
171171

172172
clsid = ppst.GetClassID()
173173
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
@@ -184,5 +184,103 @@ class IPersist(IUnknown):
184184
self.assertEqual(0, ppst.Release())
185185

186186

187+
class CopyComPointerTests(unittest.TestCase):
188+
def setUp(self):
189+
ole32.CoInitializeEx(None, COINIT_APARTMENTTHREADED)
190+
191+
class IUnknown(c_void_p):
192+
QueryInterface = proto_query_interface(None, IID_IUnknown)
193+
AddRef = proto_add_ref()
194+
Release = proto_release()
195+
196+
class IPersist(IUnknown):
197+
GetClassID = proto_get_class_id(((OUT, "pClassID"),), IID_IPersist)
198+
199+
self.IUnknown = IUnknown
200+
self.IPersist = IPersist
201+
202+
def tearDown(self):
203+
ole32.CoUninitialize()
204+
gc.collect()
205+
206+
def test_both_are_null(self):
207+
src = self.IPersist()
208+
dst = self.IPersist()
209+
210+
hr = CopyComPointer(src, byref(dst))
211+
212+
self.assertEqual(S_OK, hr)
213+
214+
self.assertIsNone(src.value)
215+
self.assertIsNone(dst.value)
216+
217+
def test_src_is_nonnull_and_dest_is_null(self):
218+
# The reference count of the COM pointer created by `CoCreateInstance`
219+
# is initially 1.
220+
src = create_shelllink_persist(self.IPersist)
221+
dst = self.IPersist()
222+
223+
# `CopyComPointer` calls `AddRef` explicitly in the C implementation.
224+
# The refcount of `src` is incremented from 1 to 2 here.
225+
hr = CopyComPointer(src, byref(dst))
226+
227+
self.assertEqual(S_OK, hr)
228+
self.assertEqual(src.value, dst.value)
229+
230+
# This indicates that the refcount was 2 before the `Release` call.
231+
self.assertEqual(1, src.Release())
232+
233+
clsid = dst.GetClassID()
234+
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
235+
236+
self.assertEqual(0, dst.Release())
237+
238+
def test_src_is_null_and_dest_is_nonnull(self):
239+
src = self.IPersist()
240+
dst_orig = create_shelllink_persist(self.IPersist)
241+
dst = self.IPersist()
242+
CopyComPointer(dst_orig, byref(dst))
243+
self.assertEqual(1, dst_orig.Release())
244+
245+
clsid = dst.GetClassID()
246+
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
247+
248+
# This does NOT affects the refcount of `dst_orig`.
249+
hr = CopyComPointer(src, byref(dst))
250+
251+
self.assertEqual(S_OK, hr)
252+
self.assertIsNone(dst.value)
253+
254+
with self.assertRaises(ValueError):
255+
dst.GetClassID() # NULL COM pointer access
256+
257+
# This indicates that the refcount was 1 before the `Release` call.
258+
self.assertEqual(0, dst_orig.Release())
259+
260+
def test_both_are_nonnull(self):
261+
src = create_shelllink_persist(self.IPersist)
262+
dst_orig = create_shelllink_persist(self.IPersist)
263+
dst = self.IPersist()
264+
CopyComPointer(dst_orig, byref(dst))
265+
self.assertEqual(1, dst_orig.Release())
266+
267+
self.assertEqual(dst.value, dst_orig.value)
268+
self.assertNotEqual(src.value, dst.value)
269+
270+
hr = CopyComPointer(src, byref(dst))
271+
272+
self.assertEqual(S_OK, hr)
273+
self.assertEqual(src.value, dst.value)
274+
self.assertNotEqual(dst.value, dst_orig.value)
275+
276+
self.assertEqual(1, src.Release())
277+
278+
clsid = dst.GetClassID()
279+
self.assertEqual(TRUE, is_equal_guid(CLSID_ShellLink, clsid))
280+
281+
self.assertEqual(0, dst.Release())
282+
self.assertEqual(0, dst_orig.Release())
283+
284+
187285
if __name__ == '__main__':
188286
unittest.main()

0 commit comments

Comments
 (0)