Skip to content

Commit 89ade5f

Browse files
author
Andy Zhang
committed
updated tests for mock dwave sampler
1 parent 95907a8 commit 89ade5f

File tree

1 file changed

+73
-1
lines changed

1 file changed

+73
-1
lines changed

tests/test_mock_sampler.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,79 @@ def test_yield_arguments(self):
216216
broken_edges=delete_edges)
217217
self.assertTrue(len(sampler.nodelist)==4)
218218
self.assertTrue(len(sampler.edgelist)==1)
219-
219+
220+
def test_custom_mock_sampler(self):
221+
"""Test that MockDWaveSampler uses the provided custom mocking_sampler."""
222+
223+
# Define a custom sampler that always returns the same sample
224+
class CustomSampler(dimod.Sampler):
225+
properties = {}
226+
parameters = {}
227+
228+
def sample(self, bqm, **kwargs):
229+
# Return a sample where all variables are set to 1
230+
sample = {v: 1 for v in bqm.variables}
231+
energy = bqm.energy(sample)
232+
return dimod.SampleSet.from_samples_bqm(sample, bqm)
233+
234+
custom_sampler = CustomSampler()
235+
236+
# Create a simple BQM
237+
bqm = dimod.BQM({'a': -1, 'b': -1}, {('a', 'b'): -1}, 0.0, vartype='SPIN')
238+
239+
# Instantiate MockDWaveSampler with nodelist and edgelist including 'a' and 'b'
240+
sampler = MockDWaveSampler(
241+
mocking_sampler=custom_sampler,
242+
nodelist=['a', 'b'],
243+
edgelist=[('a', 'b')]
244+
)
245+
246+
# Sample using the MockDWaveSampler with the custom sampler
247+
ss = sampler.sample(bqm)
248+
249+
# Check that the sample returned is as expected from the custom sampler
250+
expected_sample = {'a': 1, 'b': 1}
251+
self.assertEqual(ss.first.sample, expected_sample)
252+
self.assertEqual(ss.first.energy, bqm.energy(expected_sample))
253+
254+
def test_mocking_sampler_params(self):
255+
"""Test that mocking_sampler_params are correctly passed to the mocking_sampler."""
256+
257+
# Define a custom sampler that checks for a custom parameter
258+
class CustomSampler(dimod.Sampler):
259+
properties = {}
260+
parameters = {'custom_param': []}
261+
262+
def sample(self, bqm, custom_param=None, **kwargs):
263+
# Assert that custom_param is passed correctly
264+
assert custom_param == 'test_value', "custom_param not passed correctly"
265+
# Return a default sample
266+
sample = {v: -1 for v in bqm.variables}
267+
energy = bqm.energy(sample)
268+
return dimod.SampleSet.from_samples_bqm(sample, bqm)
269+
270+
custom_sampler = CustomSampler()
271+
272+
# Create a simple BQM
273+
bqm = dimod.BQM({'a': 1, 'b': 1}, {('a', 'b'): 1}, 0.0, vartype='SPIN')
274+
275+
# Instantiate MockDWaveSampler with nodelist and edgelist including 'a' and 'b'
276+
sampler = MockDWaveSampler(
277+
mocking_sampler=custom_sampler,
278+
mocking_sampler_params={'custom_param': 'test_value'},
279+
nodelist=['a', 'b'],
280+
edgelist=[('a', 'b')]
281+
)
282+
283+
# Sample using the MockDWaveSampler
284+
ss = sampler.sample(bqm)
285+
286+
# Check that the sample returned is as expected from the custom sampler
287+
expected_sample = {'a': -1, 'b': -1}
288+
self.assertEqual(ss.first.sample, expected_sample)
289+
self.assertEqual(ss.first.energy, bqm.energy(expected_sample))
290+
291+
220292
class TestMockLeapHybridDQMSampler(unittest.TestCase):
221293
def test_sampler(self):
222294
sampler = MockLeapHybridDQMSampler()

0 commit comments

Comments
 (0)