diff --git a/conda.recipe/meta.yaml b/conda.recipe/meta.yaml index 109718d..95de0bf 100644 --- a/conda.recipe/meta.yaml +++ b/conda.recipe/meta.yaml @@ -1,5 +1,5 @@ {% set name = "quera-ahs-utils" %} -{% set version = "0.0.6" %} +{% set version = "0.0.7" %} package: name: {{ name|lower }} diff --git a/docs/analysis.html b/docs/analysis.html index 6febba8..cde43f7 100644 --- a/docs/analysis.html +++ b/docs/analysis.html @@ -3,17 +3,15 @@ - + quera_ahs_utils.analysis API documentation - - - - + + + + - -
@@ -91,7 +89,7 @@

Module quera_ahs_utils.analysis

Functions

-def get_avg_density(result: braket.tasks.analog_hamiltonian_simulation_quantum_task_result.AnalogHamiltonianSimulationQuantumTaskResult) ‑> numpy.ndarray +def get_avg_density(result: braket.tasks.analog_hamiltonian_simulation_quantum_task_result.AnalogHamiltonianSimulationQuantumTaskResult) -> numpy.ndarray

Get the average Rydberg densities from the result

@@ -100,9 +98,9 @@

Args

result : AnalogHamiltonianSimulationQuantumTaskResult
The result from which the aggregated state counts are obtained
-
-

Returns: -ndarray: The average densities from the result

+
Returns
+
ndarray: The average densities from the result
+
Expand source code @@ -128,7 +126,7 @@

Args

-def get_counts(result: braket.tasks.analog_hamiltonian_simulation_quantum_task_result.AnalogHamiltonianSimulationQuantumTaskResult) ‑> Dict[str, int] +def get_counts(result: braket.tasks.analog_hamiltonian_simulation_quantum_task_result.AnalogHamiltonianSimulationQuantumTaskResult) -> Dict[str, int]

Aggregate state counts from AHS shot results

@@ -142,11 +140,11 @@

Returns

Dict[str, int]
number of times each state configuration is measured
-
-

Notes: We use the following convention to denote the state of an atom (site): -e: empty site +

Notes : We use the following convention to denote the state of an atom (site):
+
e: empty site r: Rydberg state atom -g: ground state atom

+g: ground state atom
+
Expand source code @@ -205,7 +203,9 @@

Index

+ + \ No newline at end of file diff --git a/docs/drive.html b/docs/drive.html index 80dac01..518d5c3 100644 --- a/docs/drive.html +++ b/docs/drive.html @@ -3,17 +3,15 @@ - + quera_ahs_utils.drive API documentation - - - - + + + + - -
@@ -260,7 +258,7 @@

Module quera_ahs_utils.drive

Functions

-def concatenate_drive_list(drive_list: List[braket.ahs.driving_field.DrivingField]) ‑> braket.ahs.driving_field.DrivingField +def concatenate_drive_list(drive_list: List[braket.ahs.driving_field.DrivingField]) -> braket.ahs.driving_field.DrivingField

Concatenate a list of driving fields to a single driving field

@@ -294,7 +292,7 @@

Returns

-def concatenate_drives(drive_1: braket.ahs.driving_field.DrivingField, drive_2: braket.ahs.driving_field.DrivingField) ‑> braket.ahs.driving_field.DrivingField +def concatenate_drives(drive_1: braket.ahs.driving_field.DrivingField, drive_2: braket.ahs.driving_field.DrivingField) -> braket.ahs.driving_field.DrivingField

Concatenate two driving fields to a single driving field

@@ -332,7 +330,7 @@

Returns

-def concatenate_shift_list(shift_list: List[braket.ahs.shifting_field.ShiftingField]) ‑> braket.ahs.shifting_field.ShiftingField +def concatenate_shift_list(shift_list: List[braket.ahs.shifting_field.ShiftingField]) -> braket.ahs.shifting_field.ShiftingField

Concatenate a list of shifting fields to a single driving field

@@ -366,7 +364,7 @@

Returns

-def concatenate_shifts(shift_1: braket.ahs.shifting_field.ShiftingField, shift_2: braket.ahs.shifting_field.ShiftingField) ‑> braket.ahs.shifting_field.ShiftingField +def concatenate_shifts(shift_1: braket.ahs.shifting_field.ShiftingField, shift_2: braket.ahs.shifting_field.ShiftingField) -> braket.ahs.shifting_field.ShiftingField

Concatenate two driving fields to a single driving field

@@ -403,7 +401,7 @@

Returns

-def concatenate_time_series(time_series_1: braket.timings.time_series.TimeSeries, time_series_2: braket.timings.time_series.TimeSeries) ‑> braket.timings.time_series.TimeSeries +def concatenate_time_series(time_series_1: braket.timings.time_series.TimeSeries, time_series_2: braket.timings.time_series.TimeSeries) -> braket.timings.time_series.TimeSeries

Concatenate two time series to a single time series

@@ -448,7 +446,7 @@

Returns

-def constant_time_series(other_time_series: braket.timings.time_series.TimeSeries, constant: float = 0.0) ‑> braket.timings.time_series.TimeSeries +def constant_time_series(other_time_series: braket.timings.time_series.TimeSeries, constant: float = 0.0) -> braket.timings.time_series.TimeSeries

Obtain a constant time series with the same time points as the given time series

@@ -482,7 +480,7 @@

Returns

-def get_drive(times: List[float], amplitude_values: List[float], detuning_values: List[float], phase_values: List[float]) ‑> braket.ahs.driving_field.DrivingField +def get_drive(times: List[float], amplitude_values: List[float], detuning_values: List[float], phase_values: List[float]) -> braket.ahs.driving_field.DrivingField

Get the driving field from a set of time points and values of the fields

@@ -547,7 +545,7 @@

Returns

-def get_shift(times: List[float], values: List[float], pattern: List[float]) ‑> braket.ahs.shifting_field.ShiftingField +def get_shift(times: List[float], values: List[float], pattern: List[float]) -> braket.ahs.shifting_field.ShiftingField

Get the shifting field from a set of time points, values and pattern

@@ -591,7 +589,7 @@

Returns

-def rabi_pulse(rabi_pulse_area: float, omega_max: float, omega_slew_rate_max: float) ‑> Tuple[List[float], List[float]] +def rabi_pulse(rabi_pulse_area: float, omega_max: float, omega_slew_rate_max: float) -> Tuple[List[float], List[float]]

Get a time series for Rabi frequency with specified Rabi phase, maximum amplitude @@ -610,9 +608,9 @@

Returns

Tuple[List[float], List[float]]
A tuple containing the time points and values of the time series for the time dependent Rabi frequency
-
-

Notes: By Rabi phase, it means the integral of the amplitude of a time-dependent -Rabi frequency, \int_0^T\Omega(t)dt, where T is the duration.

+
Notes : By Rabi phase, it means the integral of the amplitude of a time-dependent
+
Rabi frequency, \int_0^T\Omega(t)dt, where T is the duration.
+
Expand source code @@ -685,7 +683,9 @@

Index

+ + \ No newline at end of file diff --git a/docs/index.html b/docs/index.html index 617aeff..cd67362 100644 --- a/docs/index.html +++ b/docs/index.html @@ -3,17 +3,15 @@ - + quera_ahs_utils API documentation - - - - + + + + - -
@@ -74,7 +72,9 @@

Index

+ + \ No newline at end of file diff --git a/docs/ir.html b/docs/ir.html index 97c3499..308a105 100644 --- a/docs/ir.html +++ b/docs/ir.html @@ -3,17 +3,15 @@ - + quera_ahs_utils.ir API documentation - - - - + + + + - -
@@ -55,6 +53,7 @@

Module quera_ahs_utils.ir

Args: js (dict): data to be serialaized. json_filename (str): filename to output to. + **json_options: options that get passed to the json serializer as `json.dump(...,**json_options)` """ with open(json_filename,"w") as IO: @@ -187,7 +186,7 @@

Module quera_ahs_utils.ir

Functions

-def braket_sdk_to_quera_json(ahs: braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation, shots: int = 1) ‑> dict +def braket_sdk_to_quera_json(ahs: braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation, shots: int = 1) -> dict

Translates Braket AHS IR program to Quera-compatible JSON.

@@ -257,7 +256,7 @@

Returns

-def from_json_file(json_filename: str) ‑> dict +def from_json_file(json_filename: str) -> dict

deserialize a json file.

@@ -291,7 +290,7 @@

Returns

-def quera_json_to_ahs(js: dict) ‑> Tuple[int, braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation] +def quera_json_to_ahs(js: dict) -> Tuple[int, braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation]

Convert a QuEra compatible program to a braket AHS program.

@@ -367,7 +366,7 @@

Returns

-def to_json_file(js: dict, json_filename: str, **json_options) ‑> NoReturn +def to_json_file(js: dict, json_filename: str, **json_options) -> NoReturn

prints out a dictionary to a json file.

@@ -377,6 +376,8 @@

Args

data to be serialaized.
json_filename : str
filename to output to.
+
**json_options
+
options that get passed to the json serializer as json.dump(...,**json_options)
@@ -388,6 +389,7 @@

Args

Args: js (dict): data to be serialaized. json_filename (str): filename to output to. + **json_options: options that get passed to the json serializer as `json.dump(...,**json_options)` """ with open(json_filename,"w") as IO: @@ -422,7 +424,9 @@

Index

+ + \ No newline at end of file diff --git a/docs/parallelize.html b/docs/parallelize.html index 44749cb..f74559f 100644 --- a/docs/parallelize.html +++ b/docs/parallelize.html @@ -3,17 +3,15 @@ - + quera_ahs_utils.parallelize API documentation - - - - + + + + - -
@@ -333,7 +331,7 @@

Module quera_ahs_utils.parallelize

Functions

-def generate_parallel_register(register: braket.ahs.atom_arrangement.AtomArrangement, qpu: braket.aws.aws_device.AwsDevice, interproblem_distance: Union[float, decimal.Decimal]) ‑> Tuple[braket.ahs.atom_arrangement.AtomArrangement, dict] +def generate_parallel_register(register: braket.ahs.atom_arrangement.AtomArrangement, qpu: braket.aws.aws_device.AwsDevice, interproblem_distance: Union[float, decimal.Decimal]) -> Tuple[braket.ahs.atom_arrangement.AtomArrangement, dict]

generate grid of parallel registers from a single register.

@@ -416,7 +414,7 @@

Returns

-def get_shots_braket_sdk_results(results: braket.task_result.analog_hamiltonian_simulation_task_result_v1.AnalogHamiltonianSimulationTaskResult, batch_mapping: Optional[dict] = None, post_select: Optional[bool] = True) ‑>  +def get_shots_braket_sdk_results(results: braket.task_result.analog_hamiltonian_simulation_task_result_v1.AnalogHamiltonianSimulationTaskResult, batch_mapping: Optional[dict] = None, post_select: Optional[bool] = True) -> 

get the shot results from a braket-sdk task results type.

@@ -471,7 +469,7 @@

Returns

-def get_shots_quera_results(results_json: dict, batch_mapping: Optional[dict] = None, post_select: Optional[bool] = True) ‑>  +def get_shots_quera_results(results_json: dict, batch_mapping: Optional[dict] = None, post_select: Optional[bool] = True) -> 

Get the shots out of a QuEra programming

@@ -529,7 +527,7 @@

Returns

-def parallelize_ahs(ahs: braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation, qpu: braket.aws.aws_device.AwsDevice, interproblem_distance: Union[float, decimal.Decimal]) ‑> braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation +def parallelize_ahs(ahs: braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation, qpu: braket.aws.aws_device.AwsDevice, interproblem_distance: Union[float, decimal.Decimal]) -> braket.ahs.analog_hamiltonian_simulation.AnalogHamiltonianSimulation

Generate parallel ahs program.

@@ -576,7 +574,7 @@

Returns

-def parallelize_field(field: braket.ahs.field.Field, batch_mapping: dict) ‑> braket.ahs.field.Field +def parallelize_field(field: braket.ahs.field.Field, batch_mapping: dict) -> braket.ahs.field.Field

Generate parallel field from a batch_mapping

@@ -624,7 +622,7 @@

Returns

-def parallelize_hamiltonian(driving_field: braket.ahs.driving_field.DrivingField, batch_mapping: dict) ‑> braket.ahs.driving_field.DrivingField +def parallelize_hamiltonian(driving_field: braket.ahs.driving_field.DrivingField, batch_mapping: dict) -> braket.ahs.driving_field.DrivingField

Generate the parallel driving fields from a batch_mapping.

@@ -665,7 +663,7 @@

Returns

-def parallelize_quera_json(input_json: dict, interproblem_distance: float, qpu_width: float, qpu_height: float, n_site_max: int) ‑> Tuple[dict, dict] +def parallelize_quera_json(input_json: dict, interproblem_distance: float, qpu_width: float, qpu_height: float, n_site_max: int) -> Tuple[dict, dict]

Generate a parallel QuEra json program from a single program.

@@ -815,7 +813,9 @@

Index

+ + \ No newline at end of file diff --git a/docs/plotting.html b/docs/plotting.html index c2bf408..474ef9d 100644 --- a/docs/plotting.html +++ b/docs/plotting.html @@ -3,17 +3,15 @@ - + quera_ahs_utils.plotting API documentation - - - - + + + + - -
@@ -55,10 +53,9 @@

Module quera_ahs_utils.plotting

Args: register (AtomArrangement): A given register - blockade_radius (float): The blockade radius for the register. Default is 0 - what_to_draw (str): Either "bond" or "circle" to indicate the blockade region. - Default is "bond" - show_atom_index (bool): Whether showing the indices of the atoms. Default is True + blockade_radius (float): Default is 0. The blockade radius for the register. + what_to_draw (str): Default is "bond". Either "bond" or "circle" to indicate the blockade region. + show_atom_index (bool): Default is True. Whether showing the indices of the atoms. """ filled_sites = [site.coordinate for site in register if site.site_type == SiteType.FILLED] @@ -93,7 +90,7 @@

Module quera_ahs_utils.plotting

"""Plot the driving field Args: drive (DrivingField): The driving field to be plot - axes: matplotlib axis to draw on + axes: Default is None. matplotlib axis to draw on **plot_ops: options passed to matplitlib.pyplot.plot """ @@ -192,15 +189,20 @@

Module quera_ahs_utils.plotting

register (AtomArrangement): The register used in creating the Hamiltonian. - with_labels (Boolean): Choose if each atom's index is displayed over the atom itself in the resulting figure. + with_labels (Boolean): Default is True. Choose if each atom's index is displayed over the atom itself in the resulting figure. Default is True. - custom_axes (matplotlib.axes.Axes): If argument is given, the plot will use the supplied + custom_axes (matplotlib.axes.Axes): Default is None. If argument is given, the plot will use the supplied axis for displaying data and the function will not return anything. Otherwise, a new matplotlib Figure and - Axes will be generated and returned. Default is None. - - cmap (matplotlib.colors.Colormap): Defines the colormap that the plot uses to map the average density values - to the colors of each plotted atom. + Axes will be generated and returned. + + cmap (matplotlib.colors.Colormap): Default is None. Defines the colormap that the plot uses to map the average density values + to the colors of each plotted atom. When Default value is used a the resulting plot uses a Colormap that is given by + `matplotlib.pyplot.cm.bwr` which is gradient from red to blue with white in the middle. + + Returns: + Tuple[Optional[matplotlib.figure.Figure],matplotlib.axes.Axes]]: returns the Figure and the Axes object used to create the plot if `custom_axes` + is not given, otherwise the function returns None """ # get atom coordinates @@ -222,9 +224,11 @@

Module quera_ahs_utils.plotting

# construct plot if custom_axes is None: + return_fig = True fig, ax = plt.subplots() else: ax = custom_axes + return_fig = False nx.draw(g, pos, @@ -263,7 +267,12 @@

Module quera_ahs_utils.plotting

cbar_label = "Rydberg Density" - plt.colorbar(sm, ax=ax, label=cbar_label) + plt.colorbar(sm, ax=ax, label=cbar_label) + + if return_fig: + return fig,ax + else: + return None,ax
@@ -286,15 +295,23 @@

Args

register : AtomArrangement
The register used in creating the Hamiltonian.
with_labels : Boolean
-
Choose if each atom's index is displayed over the atom itself in the resulting figure. +
Default is True. Choose if each atom's index is displayed over the atom itself in the resulting figure. Default is True.
custom_axes : matplotlib.axes.Axes
-
If argument is given, the plot will use the supplied +
Default is None. If argument is given, the plot will use the supplied axis for displaying data and the function will not return anything. Otherwise, a new matplotlib Figure and -Axes will be generated and returned. Default is None.
+Axes will be generated and returned.
cmap : matplotlib.colors.Colormap
-
Defines the colormap that the plot uses to map the average density values -to the colors of each plotted atom.
+
Default is None. Defines the colormap that the plot uses to map the average density values +to the colors of each plotted atom. When Default value is used a the resulting plot uses a Colormap that is given by +matplotlib.pyplot.cm.bwr which is gradient from red to blue with white in the middle. +
+ +

Returns

+
+
Tuple[Optional[matplotlib.figure.Figure],matplotlib.axes.Axes]]
+
returns the Figure and the Axes object used to create the plot if custom_axes +is not given, otherwise the function returns None
@@ -309,15 +326,20 @@

Args

register (AtomArrangement): The register used in creating the Hamiltonian. - with_labels (Boolean): Choose if each atom's index is displayed over the atom itself in the resulting figure. + with_labels (Boolean): Default is True. Choose if each atom's index is displayed over the atom itself in the resulting figure. Default is True. - custom_axes (matplotlib.axes.Axes): If argument is given, the plot will use the supplied + custom_axes (matplotlib.axes.Axes): Default is None. If argument is given, the plot will use the supplied axis for displaying data and the function will not return anything. Otherwise, a new matplotlib Figure and - Axes will be generated and returned. Default is None. - - cmap (matplotlib.colors.Colormap): Defines the colormap that the plot uses to map the average density values - to the colors of each plotted atom. + Axes will be generated and returned. + + cmap (matplotlib.colors.Colormap): Default is None. Defines the colormap that the plot uses to map the average density values + to the colors of each plotted atom. When Default value is used a the resulting plot uses a Colormap that is given by + `matplotlib.pyplot.cm.bwr` which is gradient from red to blue with white in the middle. + + Returns: + Tuple[Optional[matplotlib.figure.Figure],matplotlib.axes.Axes]]: returns the Figure and the Axes object used to create the plot if `custom_axes` + is not given, otherwise the function returns None """ # get atom coordinates @@ -339,9 +361,11 @@

Args

# construct plot if custom_axes is None: + return_fig = True fig, ax = plt.subplots() else: ax = custom_axes + return_fig = False nx.draw(g, pos, @@ -380,7 +404,12 @@

Args

cbar_label = "Rydberg Density" - plt.colorbar(sm, ax=ax, label=cbar_label) + plt.colorbar(sm, ax=ax, label=cbar_label) + + if return_fig: + return fig,ax + else: + return None,ax
@@ -472,7 +501,7 @@

Args

drive : DrivingField
The driving field to be plot
axes
-
matplotlib axis to draw on
+
Default is None. matplotlib axis to draw on
**plot_ops
options passed to matplitlib.pyplot.plot
@@ -484,7 +513,7 @@

Args

"""Plot the driving field Args: drive (DrivingField): The driving field to be plot - axes: matplotlib axis to draw on + axes: Default is None. matplotlib axis to draw on **plot_ops: options passed to matplitlib.pyplot.plot """ @@ -550,12 +579,11 @@

Args

register : AtomArrangement
A given register
blockade_radius : float
-
The blockade radius for the register. Default is 0
+
Default is 0. The blockade radius for the register.
what_to_draw : str
-
Either "bond" or "circle" to indicate the blockade region. -Default is "bond"
+
Default is "bond". Either "bond" or "circle" to indicate the blockade region.
show_atom_index : bool
-
Whether showing the indices of the atoms. Default is True
+
Default is True. Whether showing the indices of the atoms.
@@ -571,10 +599,9 @@

Args

Args: register (AtomArrangement): A given register - blockade_radius (float): The blockade radius for the register. Default is 0 - what_to_draw (str): Either "bond" or "circle" to indicate the blockade region. - Default is "bond" - show_atom_index (bool): Whether showing the indices of the atoms. Default is True + blockade_radius (float): Default is 0. The blockade radius for the register. + what_to_draw (str): Default is "bond". Either "bond" or "circle" to indicate the blockade region. + show_atom_index (bool): Default is True. Whether showing the indices of the atoms. """ filled_sites = [site.coordinate for site in register if site.site_type == SiteType.FILLED] @@ -635,7 +662,9 @@

Index

+ + \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e9b8e13..0c7c72a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "quera-ahs-utils" -version = "0.0.6" +version = "0.0.7" authors = [ { name="Phillip Weinberg", email="pweinberg@quera.com" }, { name="John Long", email="jlong@quera.com" } diff --git a/src/quera_ahs_utils/_version.py b/src/quera_ahs_utils/_version.py index fa9c4ec..2792152 100644 --- a/src/quera_ahs_utils/_version.py +++ b/src/quera_ahs_utils/_version.py @@ -1 +1 @@ -__version__ = '0.0.6' +__version__ = '0.0.7' diff --git a/src/quera_ahs_utils/ir.py b/src/quera_ahs_utils/ir.py index 72185c2..3aac6f7 100644 --- a/src/quera_ahs_utils/ir.py +++ b/src/quera_ahs_utils/ir.py @@ -27,6 +27,7 @@ def to_json_file(js:dict,json_filename:str,**json_options) -> NoReturn: Args: js (dict): data to be serialaized. json_filename (str): filename to output to. + **json_options: options that get passed to the json serializer as `json.dump(...,**json_options)` """ with open(json_filename,"w") as IO: diff --git a/src/quera_ahs_utils/plotting.py b/src/quera_ahs_utils/plotting.py index e81008f..51c0c6e 100644 --- a/src/quera_ahs_utils/plotting.py +++ b/src/quera_ahs_utils/plotting.py @@ -27,10 +27,9 @@ def show_register( Args: register (AtomArrangement): A given register - blockade_radius (float): The blockade radius for the register. Default is 0 - what_to_draw (str): Either "bond" or "circle" to indicate the blockade region. - Default is "bond" - show_atom_index (bool): Whether showing the indices of the atoms. Default is True + blockade_radius (float): Default is 0. The blockade radius for the register. + what_to_draw (str): Default is "bond". Either "bond" or "circle" to indicate the blockade region. + show_atom_index (bool): Default is True. Choose if each atom's index is displayed over the atom itself in the resulting figure. """ filled_sites = [site.coordinate for site in register if site.site_type == SiteType.FILLED] @@ -65,7 +64,7 @@ def show_global_drive(drive, axes=None, **plot_ops): """Plot the driving field Args: drive (DrivingField): The driving field to be plot - axes: matplotlib axis to draw on + axes: Default is None. matplotlib axis to draw on **plot_ops: options passed to matplitlib.pyplot.plot """ @@ -164,15 +163,20 @@ def plot_avg_density(densities, register, with_labels = True, custom_axes = None register (AtomArrangement): The register used in creating the Hamiltonian. - with_labels (Boolean): Choose if each atom's index is displayed over the atom itself in the resulting figure. + with_labels (Boolean): Default is True. Choose if each atom's index is displayed over the atom itself in the resulting figure. Default is True. - custom_axes (matplotlib.axes.Axes): If argument is given, the plot will use the supplied + custom_axes (matplotlib.axes.Axes): Default is None. If argument is given, the plot will use the supplied axis for displaying data and the function will not return anything. Otherwise, a new matplotlib Figure and - Axes will be generated and returned. Default is None. - - cmap (matplotlib.colors.Colormap): Defines the colormap that the plot uses to map the average density values - to the colors of each plotted atom. + Axes will be generated and returned. + + cmap (matplotlib.colors.Colormap): Default is None. Defines the colormap that the plot uses to map the average density values + to the colors of each plotted atom. When Default value is used a the resulting plot uses a Colormap that is given by + `matplotlib.pyplot.cm.bwr` which is gradient from red to blue with white in the middle. + + Returns: + Tuple[Optional[matplotlib.figure.Figure],matplotlib.axes.Axes]]: returns the Figure and the Axes object used to create the plot if `custom_axes` + is not given, otherwise the function returns None """ # get atom coordinates @@ -194,9 +198,11 @@ def plot_avg_density(densities, register, with_labels = True, custom_axes = None # construct plot if custom_axes is None: + return_fig = True fig, ax = plt.subplots() else: ax = custom_axes + return_fig = False nx.draw(g, pos, @@ -236,4 +242,9 @@ def plot_avg_density(densities, register, with_labels = True, custom_axes = None cbar_label = "Rydberg Density" plt.colorbar(sm, ax=ax, label=cbar_label) + + if return_fig: + return fig,ax + else: + return None,ax