Skip to content

Commit 17b142a

Browse files
AndyZzzZzzZzzAndy Zhangjackraymondkevinchernrandomir
authored
Enable mock sampler selection in MockDWaveSampler (#537)
* WIP - added the option to switch sampler and relevant testing codes * removed the flux biases code * changed the initialization of mocking sampler from init to body * simplified substitude_kwargs * Fix num_qubits bug; raise mocked_parameters and dimod_sampler to class variable status * Add missing property. Make substitute_kwargs a class variable. * Minor corrections * Revert change to support virtual graph composite * num_qubits for pegasus fixed * moved class attributes to instance attributes * updated tests for mock dwave sampler * Revert "updated tests for mock dwave sampler" This reverts commit 89ade5f. * created a local dictionary substitue_kwargs to ensure each instance of mock sampler has its own copies * updated gitignore * Update BQM definition to simplify variable weights Co-authored-by: Jack Raymond <[email protected]> * Replace single-read sample with num_reads=2 in MockDWaveSampler test Co-authored-by: Jack Raymond <[email protected]> * Update test to validate second sample state in MockDWaveSampler Co-authored-by: Jack Raymond <[email protected]> * Remove redundant energy check in MockDWaveSampler test Co-authored-by: Jack Raymond <[email protected]> * Removed redundant comments * Polished formatting and removed changes in .gitignore * Renamed CustomSampler to ConstantSampler in test cases * Updated documentation of MockDWaveSampler * Bugfix: substitute sampler not working as expected * Removed files * Changed shortcircuit to None identity check * Modified None identity check slightly * Add ss.info.update * Added documentation for new parameters * Move comments to documentation * Update exact_solver_cutoff along with its documentation * Correct errors related to misnaming subtitute_* as mock_* * Delete duplicate file accidentally pushed * Add note on return of ascent sampler * Update dwave/system/testing.py Co-authored-by: Radomir Stevanovic <[email protected]> * Update dwave/system/testing.py Change plain text to Sphinx style for documentation Co-authored-by: Radomir Stevanovic <[email protected]> * Update dwave/system/testing.py Co-authored-by: Radomir Stevanovic <[email protected]> * Fixed indentation * Updated comment * Added subtests in TestMockSampler * Modified comments * Minor adjustment for assert statement * Added test for kwargs overwrite in TestMockDWaveSampler --------- Co-authored-by: Andy Zhang <[email protected]> Co-authored-by: Jack Raymond <[email protected]> Co-authored-by: Kevin Chern <[email protected]> Co-authored-by: Radomir Stevanovic <[email protected]>
1 parent 0136eca commit 17b142a

File tree

4 files changed

+170
-33
lines changed

4 files changed

+170
-33
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,4 +106,4 @@ ENV/
106106
*.sublime-project
107107
*.sublime-workspace
108108

109-
generated/
109+
generated/

dwave/system/testing.py

Lines changed: 79 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,15 @@
2828
class MockDWaveSampler(dimod.Sampler, dimod.Structured):
2929
"""Mock sampler modeled after DWaveSampler that can be used for tests.
3030
31-
Properties and topology parameters are populated qualitatively matching
31+
Properties and topology parameters are populated to qualitatively match
3232
online systems, and a placeholder sampler routine based on steepest descent
33-
is instantiated.
33+
is instantiated by default.
34+
35+
The :attr:`.EXACT_SOLVER_CUTOFF_DEFAULT` defines the problem size threshold for using the exact solver.
36+
For problems with fewer variables than this threshold, the exact ground state is computed
37+
using a brute-force solver. This provides a reproducible solution for small problem sizes.
38+
39+
For larger problems, the `SteepestDescentSampler` is used as a placeholder solver.
3440
3541
Args:
3642
nodelist (iterable of ints, optional):
@@ -73,16 +79,36 @@ class MockDWaveSampler(dimod.Sampler, dimod.Structured):
7379
parameters. By default ``initial_state`` can also be mocked, if
7480
dwave-greedy is installed. All other parameters are ignored and a
7581
warning will be raised by default.
82+
83+
substitute_sampler (:class:`~dimod.Sampler`, optional, default=SteepestDescentSampler()):
84+
The sampler to be used as a substitute when executing the mock sampler.
85+
By default, :class:`~dwave.samplers.SteepestDescentSampler` is employed, which performs a
86+
deterministic steepest descent optimization on the BQM. Supported options are
87+
any dimod-compatible sampler to customize the sampling behavior of
88+
`MockDWaveSampler()`.
89+
90+
substitute_kwargs (dict, optional, default={}):
91+
A dictionary of keyword arguments to pass to the `substitute_sampler`'s
92+
`sample` method. This allows users to configure the substitute sampler
93+
with specific parameters like `num_reads`, `initial_state`, or other
94+
sampler-specific options. If not provided, an empty dictionary is used
95+
by default.
7696
7797
exact_solver_cutoff (int, optional, default=:attr:`EXACT_SOLVER_CUTOFF_DEFAULT`):
7898
For problems smaller or equal in size to ``exact_solver_cutoff``, the
7999
first sample in any sampleset returned by the sampling routines
80100
is replaced by a reproducible ground state (determined exactly with
81-
a brute-force :class:`~dimod.ExactSolver`). Only small cut offs
101+
a brute-force :class:`~dimod.ExactSolver`). Only small cutoffs
82102
should be used since solution time increases exponentially with
83103
problem size.
104+
105+
- When ``substitute_sampler`` is not provided, the default value is
106+
``EXACT_SOLVER_CUTOFF_DEFAULT`` (e.g., 16).
107+
- When ``substitute_sampler`` is provided, the default value is
108+
``0``, disabling exact ground state calculation.
109+
84110
Set ``exact_solver_cutoff`` to zero to disable exact ground state
85-
calculation.
111+
calculation explicitly.
86112
87113
Examples
88114
The first example creates a MockSampler without reference to a
@@ -108,23 +134,43 @@ class MockDWaveSampler(dimod.Sampler, dimod.Structured):
108134
-1
109135
110136
"""
111-
# Feature suggestion - add seed as an optional input, to allow reproducibility.
112-
113137
nodelist = None
114138
edgelist = None
115139
properties = None
116140
parameters = None
117141

118-
# by default, use ExactSolver for problems up to size (inclusive):
119-
EXACT_SOLVER_CUTOFF_DEFAULT = 16
120-
121142
def __init__(self,
122143
nodelist=None, edgelist=None, properties=None,
123144
broken_nodes=None, broken_edges=None,
124145
topology_type=None, topology_shape=None,
125146
parameter_warnings=True,
126-
exact_solver_cutoff=EXACT_SOLVER_CUTOFF_DEFAULT,
147+
substitute_sampler=None,
148+
substitute_kwargs=None,
149+
exact_solver_cutoff=None,
127150
**config):
151+
152+
self.mocked_parameters={'answer_mode',
153+
'max_answers',
154+
'num_reads',
155+
'label',
156+
'initial_state'}
157+
158+
EXACT_SOLVER_CUTOFF_DEFAULT = 16
159+
160+
if substitute_sampler is None:
161+
substitute_sampler = SteepestDescentSampler()
162+
if exact_solver_cutoff is None:
163+
exact_solver_cutoff = EXACT_SOLVER_CUTOFF_DEFAULT
164+
else:
165+
if exact_solver_cutoff is None:
166+
exact_solver_cutoff = 0
167+
168+
self.substitute_sampler = substitute_sampler
169+
170+
if substitute_kwargs is None:
171+
substitute_kwargs = {}
172+
self.substitute_kwargs = substitute_kwargs
173+
128174
self.parameter_warnings = parameter_warnings
129175
self.exact_solver_cutoff = exact_solver_cutoff
130176

@@ -158,7 +204,6 @@ def __init__(self,
158204
'chip_id': 'MockDWaveSampler',
159205
'topology': {'type': topology_type, 'shape': topology_shape}
160206
}
161-
162207
#Create graph object, introduce defects per input arguments
163208
if nodelist is not None:
164209
self.nodelist = nodelist.copy()
@@ -171,6 +216,11 @@ def __init__(self,
171216
self.properties['topology']['shape'],
172217
self.nodelist, self.edgelist)
173218

219+
if topology_type == 'pegasus':
220+
m = self.properties['topology']['shape'][0]
221+
num_qubits = m*(m-1)*24 # fabric_only=True technicality
222+
else:
223+
num_qubits = len(solver_graph)
174224
if broken_nodes is None and broken_edges is None:
175225
self.nodelist = sorted(solver_graph.nodes)
176226
self.edgelist = sorted(tuple(sorted(edge))
@@ -189,7 +239,7 @@ def __init__(self,
189239
and (v, u) not in broken_edges)
190240
#Finalize yield-dependent properties:
191241
self.properties.update({
192-
'num_qubits': len(solver_graph),
242+
'num_qubits': num_qubits,
193243
'qubits': self.nodelist.copy(),
194244
'couplers': self.edgelist.copy(),
195245
'anneal_offset_ranges': [[-0.5, 0.5] if i in self.nodelist
@@ -295,6 +345,7 @@ def __init__(self,
295345
'tags': [],
296346
'category': 'qpu',
297347
'quota_conversion_rate': 1,
348+
'fast_anneal_time_range': [0.005, 83000.0],
298349
})
299350

300351
if properties is not None:
@@ -313,15 +364,9 @@ def from_qpu_sampler(cls, sampler):
313364
def sample(self, bqm, **kwargs):
314365

315366
# Check kwargs compatibility with parameters and substitute sampler:
316-
mocked_parameters={'answer_mode',
317-
'max_answers',
318-
'num_reads',
319-
'label',
320-
'initial_state',
321-
}
322367
for kw in kwargs:
323368
if kw in self.parameters:
324-
if self.parameter_warnings and kw not in mocked_parameters:
369+
if self.parameter_warnings and kw not in self.mocked_parameters:
325370
warnings.warn(f'{kw!r} parameter is valid for DWaveSampler(), '
326371
'but not mocked in MockDWaveSampler().')
327372
else:
@@ -344,19 +389,21 @@ def sample(self, bqm, **kwargs):
344389
label = kwargs.get('label')
345390
if label is not None:
346391
info.update(problem_label=label)
347-
348-
#Special handling of flux_biases, for compatibility with virtual graphs
349-
392+
393+
# Special handling of flux_biases, for compatibility with virtual graphs
350394
flux_biases = kwargs.get('flux_biases')
351395
if flux_biases is not None:
352396
self.flux_biases_flag = True
353397

354-
substitute_kwargs = {'num_reads' : kwargs.get('num_reads')}
355-
if substitute_kwargs['num_reads'] is None:
356-
substitute_kwargs['num_reads'] = 1
398+
# Create a local dictionary combining self.substitute_kwargs and relevant kwargs
399+
substitute_kwargs = self.substitute_kwargs.copy()
357400

358-
initial_state = kwargs.get('initial_state')
359-
if initial_state is not None:
401+
# Handle 'num_reads', defaulting to 1 if not provided
402+
num_reads = kwargs.get('num_reads', substitute_kwargs.get('num_reads', 1))
403+
substitute_kwargs['num_reads'] = num_reads
404+
405+
if 'initial_state' in kwargs:
406+
initial_state = kwargs['initial_state']
360407
# Initial state format is a list of (qubit,values)
361408
# value=3 denotes an unused variable (should be absent
362409
# from bqm).
@@ -366,15 +413,17 @@ def sample(self, bqm, **kwargs):
366413
if pair[1]!=3],dtype=float),
367414
[pair[0] for pair in initial_state if pair[1]!=3])
368415

369-
ss = SteepestDescentSampler().sample(bqm, **substitute_kwargs)
370-
ss.info.update(info)
416+
sampler_kwargs = kwargs.copy()
417+
sampler_kwargs.update(substitute_kwargs)
371418

419+
ss = self.substitute_sampler.sample(bqm, **sampler_kwargs)
420+
ss.info.update(info)
372421
# determine ground state exactly for small problems
373422
if 0 < len(bqm) <= self.exact_solver_cutoff and len(ss) >= 1:
374423
ground = dimod.ExactSolver().sample(bqm).truncate(1)
375424
ss.record[0].sample = ground.record[0].sample
376425
ss.record[0].energy = ground.record[0].energy
377-
426+
378427
answer_mode = kwargs.get('answer_mode')
379428
if answer_mode is None or answer_mode == 'histogram':
380429
# Default for DWaveSampler() is 'histogram'

tests/test_mock_sampler.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from dwave.system.testing import MockDWaveSampler, MockLeapHybridDQMSampler
2626
from dwave.cloud.exceptions import ConfigFileError, SolverNotFoundError
2727
from dimod import DiscreteQuadraticModel, ExtendedVartype, SampleSet
28+
from dwave.samplers import SteepestDescentSolver
2829

2930

3031
class TestMockDWaveSampler(unittest.TestCase):
@@ -216,7 +217,94 @@ def test_yield_arguments(self):
216217
broken_edges=delete_edges)
217218
self.assertTrue(len(sampler.nodelist)==4)
218219
self.assertTrue(len(sampler.edgelist)==1)
219-
220+
221+
def test_custom_substitute_sampler(self):
222+
"""Test that MockDWaveSampler uses the provided custom substitute_sampler."""
223+
224+
# Define a sampler that always returns the a constant (excited) state
225+
class SteepestAscentSolver(SteepestDescentSolver):
226+
def sample(self, bqm, **kwargs):
227+
# Return local (or global) maxima instead of local minima
228+
# NOTE: energy returned is not faithful to the original bqm (energy calculated as `-bqm`)
229+
return super().sample(-bqm, **kwargs)
230+
231+
inverted_sampler = SteepestAscentSolver()
232+
233+
# Create a simple BQM
234+
bqm = dimod.BQM({'a': 1, 'b': 1}, {}, 0.0, vartype="SPIN")
235+
236+
# Instantiate MockDWaveSampler with nodelist and edgelist including 'a' and 'b'
237+
sampler = MockDWaveSampler(
238+
substitute_sampler=inverted_sampler,
239+
nodelist=['a', 'b'],
240+
edgelist=[('a', 'b')]
241+
)
242+
243+
# First Subtest: First sample does not use ExactSampler();
244+
# Second sample does not use SteepestDescentSampler()
245+
with self.subTest("Sampler without ExactSampler"):
246+
ss = sampler.sample(bqm, num_reads=2)
247+
self.assertEqual(sampler.exact_solver_cutoff, 0)
248+
self.assertEqual(ss.record.sample.shape, (1,2), 'Unique sample expected')
249+
self.assertTrue(np.all(ss.record.sample==1), 'Excited states expected')
250+
251+
sampler = MockDWaveSampler(
252+
substitute_sampler=inverted_sampler,
253+
nodelist=['a', 'b'],
254+
edgelist=[('a', 'b')],
255+
exact_solver_cutoff=2
256+
)
257+
# Second Subtest: First sample uses ExactSampler();
258+
# Second sampler uses inverted sampler. Explicit exact_solver_cutoff overrides substitute_sampler.
259+
with self.subTest("Sampler with ExactSampler and substitute_sampler"):
260+
ss = sampler.sample(bqm, num_reads=2, answer_mode='raw')
261+
self.assertEqual(sampler.exact_solver_cutoff, 2)
262+
self.assertEqual(ss.record.sample.shape, (2,2), 'Non-unique samples expected')
263+
self.assertTrue(np.all(ss.record.sample[0,:] == -1), 'Excited states expected')
264+
self.assertTrue(np.all(ss.record.sample[1,:] == 1), 'Excited states expected')
265+
266+
def test_mocking_sampler_params(self):
267+
"""Test that substitute_kwargs are correctly passed to the substitute_sampler."""
268+
269+
# Define a constant sampler that checks for a custom parameter
270+
class ConstantSampler(dimod.Sampler):
271+
properties = {}
272+
parameters = {'custom_param': [], 'num_reads': []}
273+
274+
def sample(self, bqm, **kwargs):
275+
custom_param = kwargs.get('custom_param')
276+
num_reads = kwargs.get('num_reads')
277+
# Raise exception if parameters passed incorrectly
278+
if custom_param != 'test_value':
279+
raise ValueError("custom_param not passed correctly")
280+
if num_reads != 10:
281+
raise ValueError(f"num_reads not passed correctly, expected 10, got {num_reads}")
282+
# Return a default sample
283+
sample = {v: -1 for v in bqm.variables}
284+
return dimod.SampleSet.from_samples_bqm(sample, bqm)
285+
286+
constant_sampler = ConstantSampler()
287+
288+
# Create a simple BQM
289+
bqm = dimod.BQM({'a': 1, 'b': 1}, {('a', 'b'): 1}, 0.0, vartype="SPIN")
290+
291+
# Instantiate MockDWaveSampler with nodelist and edgelist including 'a' and 'b'
292+
sampler = MockDWaveSampler(
293+
substitute_sampler=constant_sampler,
294+
substitute_kwargs={'custom_param': 'test_value'},
295+
nodelist=['a', 'b'],
296+
edgelist=[('a', 'b')]
297+
)
298+
299+
# Sample using the MockDWaveSampler
300+
ss = sampler.sample(bqm, num_reads=10)
301+
302+
# Check that the sample returned is as expected from the custom sampler
303+
expected_sample = {'a': -1, 'b': -1}
304+
self.assertEqual(ss.first.sample, expected_sample)
305+
self.assertEqual(ss.first.energy, bqm.energy(expected_sample))
306+
307+
220308
class TestMockLeapHybridDQMSampler(unittest.TestCase):
221309
def test_sampler(self):
222310
sampler = MockLeapHybridDQMSampler()

tests/test_virtual_graph_composite.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_smoke(self):
3333

3434
# depending on how recenlty flux bias data was gathered, this may be true
3535
child_sampler.flux_biases_flag = False
36-
36+
child_sampler.mocked_parameters.add('flux_biases') # Don't raise warning
3737
if sampler.flux_biases:
3838
sampler.sample_ising({'a': -1}, {})
3939
self.assertTrue(child_sampler.flux_biases_flag) # true when some have been provided to sample_ising

0 commit comments

Comments
 (0)