9
9
raise unittest .SkipTest ("Windows-specific test" )
10
10
11
11
12
- from _ctypes import COMError
12
+ from _ctypes import COMError , CopyComPointer
13
13
from ctypes import HRESULT
14
14
15
15
@@ -78,6 +78,19 @@ def is_equal_guid(guid1, guid2):
78
78
)
79
79
80
80
81
+ def create_shelllink_persist (typ ):
82
+ ppst = typ ()
83
+ # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
84
+ ole32 .CoCreateInstance (
85
+ byref (CLSID_ShellLink ),
86
+ None ,
87
+ CLSCTX_SERVER ,
88
+ byref (IID_IPersist ),
89
+ byref (ppst ),
90
+ )
91
+ return ppst
92
+
93
+
81
94
class ForeignFunctionsThatWillCallComMethodsTests (unittest .TestCase ):
82
95
def setUp (self ):
83
96
# https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-coinitializeex
@@ -88,19 +101,6 @@ def tearDown(self):
88
101
ole32 .CoUninitialize ()
89
102
gc .collect ()
90
103
91
- @staticmethod
92
- def create_shelllink_persist (typ ):
93
- ppst = typ ()
94
- # https://learn.microsoft.com/en-us/windows/win32/api/combaseapi/nf-combaseapi-cocreateinstance
95
- ole32 .CoCreateInstance (
96
- byref (CLSID_ShellLink ),
97
- None ,
98
- CLSCTX_SERVER ,
99
- byref (IID_IPersist ),
100
- byref (ppst ),
101
- )
102
- return ppst
103
-
104
104
def test_without_paramflags_and_iid (self ):
105
105
class IUnknown (c_void_p ):
106
106
QueryInterface = proto_query_interface ()
@@ -110,7 +110,7 @@ class IUnknown(c_void_p):
110
110
class IPersist (IUnknown ):
111
111
GetClassID = proto_get_class_id ()
112
112
113
- ppst = self . create_shelllink_persist (IPersist )
113
+ ppst = create_shelllink_persist (IPersist )
114
114
115
115
clsid = GUID ()
116
116
hr_getclsid = ppst .GetClassID (byref (clsid ))
@@ -142,7 +142,7 @@ class IUnknown(c_void_p):
142
142
class IPersist (IUnknown ):
143
143
GetClassID = proto_get_class_id (((OUT , "pClassID" ),))
144
144
145
- ppst = self . create_shelllink_persist (IPersist )
145
+ ppst = create_shelllink_persist (IPersist )
146
146
147
147
clsid = ppst .GetClassID ()
148
148
self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
@@ -167,7 +167,7 @@ class IUnknown(c_void_p):
167
167
class IPersist (IUnknown ):
168
168
GetClassID = proto_get_class_id (((OUT , "pClassID" ),), IID_IPersist )
169
169
170
- ppst = self . create_shelllink_persist (IPersist )
170
+ ppst = create_shelllink_persist (IPersist )
171
171
172
172
clsid = ppst .GetClassID ()
173
173
self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
@@ -184,5 +184,103 @@ class IPersist(IUnknown):
184
184
self .assertEqual (0 , ppst .Release ())
185
185
186
186
187
+ class CopyComPointerTests (unittest .TestCase ):
188
+ def setUp (self ):
189
+ ole32 .CoInitializeEx (None , COINIT_APARTMENTTHREADED )
190
+
191
+ class IUnknown (c_void_p ):
192
+ QueryInterface = proto_query_interface (None , IID_IUnknown )
193
+ AddRef = proto_add_ref ()
194
+ Release = proto_release ()
195
+
196
+ class IPersist (IUnknown ):
197
+ GetClassID = proto_get_class_id (((OUT , "pClassID" ),), IID_IPersist )
198
+
199
+ self .IUnknown = IUnknown
200
+ self .IPersist = IPersist
201
+
202
+ def tearDown (self ):
203
+ ole32 .CoUninitialize ()
204
+ gc .collect ()
205
+
206
+ def test_both_are_null (self ):
207
+ src = self .IPersist ()
208
+ dst = self .IPersist ()
209
+
210
+ hr = CopyComPointer (src , byref (dst ))
211
+
212
+ self .assertEqual (S_OK , hr )
213
+
214
+ self .assertIsNone (src .value )
215
+ self .assertIsNone (dst .value )
216
+
217
+ def test_src_is_nonnull_and_dest_is_null (self ):
218
+ # The reference count of the COM pointer created by `CoCreateInstance`
219
+ # is initially 1.
220
+ src = create_shelllink_persist (self .IPersist )
221
+ dst = self .IPersist ()
222
+
223
+ # `CopyComPointer` calls `AddRef` explicitly in the C implementation.
224
+ # The refcount of `src` is incremented from 1 to 2 here.
225
+ hr = CopyComPointer (src , byref (dst ))
226
+
227
+ self .assertEqual (S_OK , hr )
228
+ self .assertEqual (src .value , dst .value )
229
+
230
+ # This indicates that the refcount was 2 before the `Release` call.
231
+ self .assertEqual (1 , src .Release ())
232
+
233
+ clsid = dst .GetClassID ()
234
+ self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
235
+
236
+ self .assertEqual (0 , dst .Release ())
237
+
238
+ def test_src_is_null_and_dest_is_nonnull (self ):
239
+ src = self .IPersist ()
240
+ dst_orig = create_shelllink_persist (self .IPersist )
241
+ dst = self .IPersist ()
242
+ CopyComPointer (dst_orig , byref (dst ))
243
+ self .assertEqual (1 , dst_orig .Release ())
244
+
245
+ clsid = dst .GetClassID ()
246
+ self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
247
+
248
+ # This does NOT affects the refcount of `dst_orig`.
249
+ hr = CopyComPointer (src , byref (dst ))
250
+
251
+ self .assertEqual (S_OK , hr )
252
+ self .assertIsNone (dst .value )
253
+
254
+ with self .assertRaises (ValueError ):
255
+ dst .GetClassID () # NULL COM pointer access
256
+
257
+ # This indicates that the refcount was 1 before the `Release` call.
258
+ self .assertEqual (0 , dst_orig .Release ())
259
+
260
+ def test_both_are_nonnull (self ):
261
+ src = create_shelllink_persist (self .IPersist )
262
+ dst_orig = create_shelllink_persist (self .IPersist )
263
+ dst = self .IPersist ()
264
+ CopyComPointer (dst_orig , byref (dst ))
265
+ self .assertEqual (1 , dst_orig .Release ())
266
+
267
+ self .assertEqual (dst .value , dst_orig .value )
268
+ self .assertNotEqual (src .value , dst .value )
269
+
270
+ hr = CopyComPointer (src , byref (dst ))
271
+
272
+ self .assertEqual (S_OK , hr )
273
+ self .assertEqual (src .value , dst .value )
274
+ self .assertNotEqual (dst .value , dst_orig .value )
275
+
276
+ self .assertEqual (1 , src .Release ())
277
+
278
+ clsid = dst .GetClassID ()
279
+ self .assertEqual (TRUE , is_equal_guid (CLSID_ShellLink , clsid ))
280
+
281
+ self .assertEqual (0 , dst .Release ())
282
+ self .assertEqual (0 , dst_orig .Release ())
283
+
284
+
187
285
if __name__ == '__main__' :
188
286
unittest .main ()
0 commit comments