@@ -216,7 +216,79 @@ def test_yield_arguments(self):
216
216
broken_edges = delete_edges )
217
217
self .assertTrue (len (sampler .nodelist )== 4 )
218
218
self .assertTrue (len (sampler .edgelist )== 1 )
219
-
219
+
220
+ def test_custom_mock_sampler (self ):
221
+ """Test that MockDWaveSampler uses the provided custom mocking_sampler."""
222
+
223
+ # Define a custom sampler that always returns the same sample
224
+ class CustomSampler (dimod .Sampler ):
225
+ properties = {}
226
+ parameters = {}
227
+
228
+ def sample (self , bqm , ** kwargs ):
229
+ # Return a sample where all variables are set to 1
230
+ sample = {v : 1 for v in bqm .variables }
231
+ energy = bqm .energy (sample )
232
+ return dimod .SampleSet .from_samples_bqm (sample , bqm )
233
+
234
+ custom_sampler = CustomSampler ()
235
+
236
+ # Create a simple BQM
237
+ bqm = dimod .BQM ({'a' : - 1 , 'b' : - 1 }, {('a' , 'b' ): - 1 }, 0.0 , vartype = 'SPIN' )
238
+
239
+ # Instantiate MockDWaveSampler with nodelist and edgelist including 'a' and 'b'
240
+ sampler = MockDWaveSampler (
241
+ mocking_sampler = custom_sampler ,
242
+ nodelist = ['a' , 'b' ],
243
+ edgelist = [('a' , 'b' )]
244
+ )
245
+
246
+ # Sample using the MockDWaveSampler with the custom sampler
247
+ ss = sampler .sample (bqm )
248
+
249
+ # Check that the sample returned is as expected from the custom sampler
250
+ expected_sample = {'a' : 1 , 'b' : 1 }
251
+ self .assertEqual (ss .first .sample , expected_sample )
252
+ self .assertEqual (ss .first .energy , bqm .energy (expected_sample ))
253
+
254
+ def test_mocking_sampler_params (self ):
255
+ """Test that mocking_sampler_params are correctly passed to the mocking_sampler."""
256
+
257
+ # Define a custom sampler that checks for a custom parameter
258
+ class CustomSampler (dimod .Sampler ):
259
+ properties = {}
260
+ parameters = {'custom_param' : []}
261
+
262
+ def sample (self , bqm , custom_param = None , ** kwargs ):
263
+ # Assert that custom_param is passed correctly
264
+ assert custom_param == 'test_value' , "custom_param not passed correctly"
265
+ # Return a default sample
266
+ sample = {v : - 1 for v in bqm .variables }
267
+ energy = bqm .energy (sample )
268
+ return dimod .SampleSet .from_samples_bqm (sample , bqm )
269
+
270
+ custom_sampler = CustomSampler ()
271
+
272
+ # Create a simple BQM
273
+ bqm = dimod .BQM ({'a' : 1 , 'b' : 1 }, {('a' , 'b' ): 1 }, 0.0 , vartype = 'SPIN' )
274
+
275
+ # Instantiate MockDWaveSampler with nodelist and edgelist including 'a' and 'b'
276
+ sampler = MockDWaveSampler (
277
+ mocking_sampler = custom_sampler ,
278
+ mocking_sampler_params = {'custom_param' : 'test_value' },
279
+ nodelist = ['a' , 'b' ],
280
+ edgelist = [('a' , 'b' )]
281
+ )
282
+
283
+ # Sample using the MockDWaveSampler
284
+ ss = sampler .sample (bqm )
285
+
286
+ # Check that the sample returned is as expected from the custom sampler
287
+ expected_sample = {'a' : - 1 , 'b' : - 1 }
288
+ self .assertEqual (ss .first .sample , expected_sample )
289
+ self .assertEqual (ss .first .energy , bqm .energy (expected_sample ))
290
+
291
+
220
292
class TestMockLeapHybridDQMSampler (unittest .TestCase ):
221
293
def test_sampler (self ):
222
294
sampler = MockLeapHybridDQMSampler ()
0 commit comments