Skip to content

Commit 6e9c529

Browse files
committed
feat: implement unified plot saving mechanism across experiments; enhance visualization utilities for multiple formats
1 parent 7db0506 commit 6e9c529

File tree

10 files changed

+238
-85
lines changed

10 files changed

+238
-85
lines changed

src/oqtopus_experiments/core/base_experiment.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,51 @@ def save_experiment_summary(self) -> str:
133133
"""Save experiment summary"""
134134
return self.data_manager.summary()
135135

136+
def save_plot_figure(
137+
self,
138+
figure,
139+
experiment_name: str,
140+
save_formats: list[str] | None = None,
141+
) -> str | None:
142+
"""
143+
Save plot figure using data manager with unified .results/ organization
144+
145+
Args:
146+
figure: Plotly figure object to save
147+
experiment_name: Name of the experiment for filename
148+
save_formats: List of formats to save (default: ["png"])
149+
150+
Returns:
151+
Path to saved plot file (primary format) or None if saving failed
152+
153+
Note:
154+
Saves plots to .results/experiment_TIMESTAMP/plots/ directory
155+
via data manager for consistent file organization.
156+
"""
157+
if figure is None:
158+
return None
159+
160+
try:
161+
from ..utils.visualization import save_plotly_figure
162+
163+
# Use data manager to get plots directory
164+
plots_dir = self.data_manager.get_plots_directory()
165+
166+
# Default to PNG format
167+
if save_formats is None:
168+
save_formats = ["png"]
169+
170+
# Save figure using visualization utility
171+
saved_path = save_plotly_figure(
172+
figure, name=experiment_name, images_dir=plots_dir, formats=save_formats
173+
)
174+
175+
return saved_path
176+
177+
except Exception as e:
178+
print(f"Warning: Failed to save plot figure: {e}")
179+
return None
180+
136181
# Template method: overall experiment flow
137182
def run(
138183
self,

src/oqtopus_experiments/core/data_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ def get_data_directory(self) -> str:
153153
"""
154154
return f"{self.session_dir}/data"
155155

156+
def get_plots_directory(self) -> str:
157+
"""
158+
Get the plots directory path
159+
160+
Returns:
161+
Plots directory path
162+
"""
163+
return f"{self.session_dir}/plots"
164+
156165
def _convert_for_json(self, obj):
157166
"""Helper for JSON conversion"""
158167
if isinstance(obj, dict):

src/oqtopus_experiments/experiments/deutsch_jozsa.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,22 @@ def analyze(
102102

103103
# Optional actions
104104
if plot:
105-
self._create_plot(analysis_result, save_image)
105+
self._create_plot(analysis_result, save_image=False)
106106
if save_data:
107107
self._save_results(analysis_result)
108108

109+
# Handle plot saving via unified architecture
110+
if (
111+
save_image
112+
and hasattr(analysis_result, "_plot_figure")
113+
and analysis_result._plot_figure is not None
114+
):
115+
self.save_plot_figure(
116+
analysis_result._plot_figure,
117+
f"deutsch_jozsa_{self.n_qubits}qubits_{self.oracle_type}",
118+
save_formats=["png"],
119+
)
120+
109121
return df
110122

111123
def circuits(self, **kwargs: Any) -> list["QuantumCircuit"]:
@@ -310,7 +322,6 @@ def _create_plot(
310322
apply_experiment_layout,
311323
get_experiment_colors,
312324
get_plotly_config,
313-
save_plotly_figure,
314325
setup_plotly_environment,
315326
show_plotly_figure,
316327
)
@@ -464,18 +475,8 @@ def _create_plot(
464475
annotation_position="right",
465476
)
466477

467-
# Save and show
468-
if save_image:
469-
images_dir = (
470-
getattr(self.data_manager, "session_dir", "./images") + "/plots"
471-
)
472-
save_plotly_figure(
473-
fig,
474-
name=f"deutsch_jozsa_{self.n_qubits}qubits_{self.oracle_type}",
475-
images_dir=images_dir,
476-
width=1000,
477-
height=500,
478-
)
478+
# Store figure for experiment class to handle saving
479+
analysis_result._plot_figure = fig
479480

480481
config = get_plotly_config(
481482
f"deutsch_jozsa_{self.n_qubits}qubits", width=1000, height=500
@@ -486,14 +487,16 @@ def _create_plot(
486487
print(f"Failed to create plot: {e}")
487488

488489
def _save_results(self, analysis_result: DeutschJozsaAnalysisResult):
489-
"""Save analysis results to CSV file"""
490+
"""Save analysis results using data manager"""
490491
try:
491-
# Save to CSV using the DataFrame
492-
filename = (
493-
f"deutsch_jozsa_{self.n_qubits}qubits_{self.oracle_type}_results.csv"
492+
# Save to unified data structure via data manager
493+
data_dict = analysis_result.dataframe.to_dict(orient="records")
494+
metadata = analysis_result.metadata.copy()
495+
496+
self.save_experiment_data(
497+
data_dict, metadata=metadata, experiment_type="deutsch_jozsa"
494498
)
495-
analysis_result.dataframe.to_csv(filename, index=False)
496-
print(f"Results saved to {filename}")
499+
print(f"Results saved to {self.data_manager.session_dir}")
497500
except Exception as e:
498501
print(f"Failed to save results: {e}")
499502

src/oqtopus_experiments/experiments/randomized_benchmarking.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,13 +412,25 @@ def analyze(
412412
metadata={"experiment_type": "randomized_benchmarking"},
413413
)
414414

415-
# Perform analysis
415+
# Perform analysis (models return plot figure in metadata)
416416
df = rb_result.analyze(
417417
plot=plot,
418418
save_data=False,
419-
save_image=save_image, # Disable model's direct save
419+
save_image=False, # Don't let model save directly
420420
)
421421

422+
# Handle plot saving via experiment class
423+
if (
424+
save_image
425+
and hasattr(rb_result, "_plot_figure")
426+
and rb_result._plot_figure is not None
427+
):
428+
self.save_plot_figure(
429+
rb_result._plot_figure,
430+
"randomized_benchmarking_decay",
431+
save_formats=["png"],
432+
)
433+
422434
# Use standard experiment data saving if requested
423435
if save_data:
424436
self._save_rb_analysis(df, rb_result.data.fitting_result)

src/oqtopus_experiments/models/chsh_models.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,10 +163,15 @@ def analyze(
163163
# Update data with analysis results
164164
self.data.analysis_result = analysis_result
165165

166+
# Store plot figure for experiment classes
167+
self._plot_figure = None
168+
166169
# Generate plots if requested
167170
if plot:
168171
try:
169-
self._create_chsh_plot(analysis_result, save_image)
172+
self._plot_figure = self._create_chsh_plot(
173+
analysis_result, save_image
174+
)
170175
except Exception as e:
171176
print(f"Warning: Plot generation failed: {e}")
172177

@@ -203,6 +208,7 @@ def analyze(
203208
metadata={
204209
"bell_violation": analysis_result.bell_violation,
205210
"statistical_significance": f"{analysis_result.significance:.2f}σ",
211+
"plot_figure": getattr(self, "_plot_figure", None),
206212
},
207213
)
208214
return result.to_legacy_dataframe()
@@ -405,14 +411,21 @@ def _validate_analysis_quality(
405411
def _create_chsh_plot(
406412
self, analysis_result: CHSHAnalysisResult, save_image: bool = True
407413
):
408-
"""Create CHSH analysis plot following plot_settings.md guidelines"""
414+
"""
415+
Create CHSH analysis plot following plot_settings.md guidelines
416+
417+
Note: This method returns the figure object for experiment classes to handle saving.
418+
Model classes should not perform I/O operations directly.
419+
420+
Returns:
421+
plotly.graph_objects.Figure: The created figure object
422+
"""
409423
try:
410424
import plotly.graph_objects as go
411425

412426
from ..utils.visualization import (
413427
get_experiment_colors,
414428
get_plotly_config,
415-
save_plotly_figure,
416429
setup_plotly_environment,
417430
show_plotly_figure,
418431
)
@@ -514,15 +527,15 @@ def _create_chsh_plot(
514527
borderwidth=1,
515528
)
516529

517-
# Save and show plot
518-
if save_image:
519-
save_plotly_figure(fig, name="chsh_analysis", images_dir="./plots")
520-
530+
# Show plot only - saving handled by experiment classes
521531
config = get_plotly_config("chsh_analysis", 1000, 500)
522532
show_plotly_figure(fig, config)
523533

534+
return fig
535+
524536
except Exception as e:
525537
print(f"Warning: CHSH plot generation failed: {e}")
538+
return None
526539

527540

528541
class CHSHCircuitParams(BaseModel):

src/oqtopus_experiments/models/hadamard_test_models.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,11 @@ def analyze(
168168
# Generate plots if requested
169169
if plot:
170170
try:
171-
self._create_hadamard_test_plot(fitting_result, save_image)
171+
self._plot_figure = self._create_hadamard_test_plot(
172+
fitting_result, save_image
173+
)
172174
except Exception as e:
175+
self._plot_figure = None
173176
print(f"Warning: Plot generation failed: {e}")
174177

175178
# Create successful result
@@ -193,7 +196,11 @@ def analyze(
193196
result = AnalysisResult(
194197
success=True,
195198
data=analysis_data,
196-
metadata={"analysis_quality": "good", "error_info": None},
199+
metadata={
200+
"analysis_quality": "good",
201+
"error_info": None,
202+
"plot_figure": getattr(self, "_plot_figure", None),
203+
},
197204
)
198205
return result.to_legacy_dataframe()
199206

@@ -468,15 +475,22 @@ def _calculate_fitted_probabilities(
468475
def _create_hadamard_test_plot(
469476
self, fitting_result: HadamardTestFittingResult, save_image: bool = True
470477
):
471-
"""Create Hadamard Test plot following plot_settings.md guidelines"""
478+
"""
479+
Create Hadamard Test plot following plot_settings.md guidelines
480+
481+
Note: This method returns the figure object for experiment classes to handle saving.
482+
Model classes should not perform I/O operations directly.
483+
484+
Returns:
485+
plotly.graph_objects.Figure: The created figure object
486+
"""
472487
try:
473488
import numpy as np
474489
import plotly.graph_objects as go
475490

476491
from ..utils.visualization import (
477492
get_experiment_colors,
478493
get_plotly_config,
479-
save_plotly_figure,
480494
setup_plotly_environment,
481495
show_plotly_figure,
482496
)
@@ -632,14 +646,12 @@ def _create_hadamard_test_plot(
632646
borderwidth=1,
633647
)
634648

635-
# Save and show plot
636-
if save_image:
637-
save_plotly_figure(
638-
fig, name="hadamard_test_analysis", images_dir="./plots"
639-
)
640-
649+
# Show plot only - saving handled by experiment classes
641650
config = get_plotly_config("hadamard_test_analysis", 1000, 500)
642651
show_plotly_figure(fig, config)
643652

653+
return fig
654+
644655
except Exception as e:
645656
print(f"Warning: Hadamard Test plot generation failed: {e}")
657+
return None

src/oqtopus_experiments/models/ramsey_models.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,11 @@ def analyze(
148148
# Generate plots if requested
149149
if plot:
150150
try:
151-
self._create_ramsey_plot(fitting_result, save_image)
151+
self._plot_figure = self._create_ramsey_plot(
152+
fitting_result, save_image
153+
)
152154
except Exception as e:
155+
self._plot_figure = None
153156
print(f"Warning: Plot generation failed: {e}")
154157

155158
# Create successful result
@@ -171,7 +174,11 @@ def analyze(
171174
result = AnalysisResult(
172175
success=True,
173176
data=analysis_data,
174-
metadata={"fitting_quality": "good", "error_info": None},
177+
metadata={
178+
"fitting_quality": "good",
179+
"error_info": None,
180+
"plot_figure": getattr(self, "_plot_figure", None),
181+
},
175182
)
176183
return result.to_legacy_dataframe()
177184

@@ -423,7 +430,15 @@ def ramsey_func(t, amplitude, t2_star, offset, phase, frequency):
423430
def _create_ramsey_plot(
424431
self, fitting_result: RamseyFittingResult, save_image: bool = True
425432
):
426-
"""Create Ramsey fringe plot following plot_settings.md guidelines"""
433+
"""
434+
Create Ramsey fringe plot following plot_settings.md guidelines
435+
436+
Note: This method returns the figure object for experiment classes to handle saving.
437+
Model classes should not perform I/O operations directly.
438+
439+
Returns:
440+
plotly.graph_objects.Figure: The created figure object
441+
"""
427442
try:
428443
import numpy as np
429444
import plotly.graph_objects as go
@@ -432,7 +447,6 @@ def _create_ramsey_plot(
432447
apply_experiment_layout,
433448
get_experiment_colors,
434449
get_plotly_config,
435-
save_plotly_figure,
436450
setup_plotly_environment,
437451
show_plotly_figure,
438452
)
@@ -518,15 +532,15 @@ def _create_ramsey_plot(
518532
borderwidth=1,
519533
)
520534

521-
# Save and show plot
522-
if save_image:
523-
save_plotly_figure(fig, name="ramsey_fringes", images_dir="./plots")
524-
535+
# Show plot only - saving handled by experiment classes
525536
config = get_plotly_config("ramsey_fringes", 1000, 500)
526537
show_plotly_figure(fig, config)
527538

539+
return fig
540+
528541
except Exception as e:
529542
print(f"Warning: Ramsey plot generation failed: {e}")
543+
return None
530544

531545

532546
class RamseyCircuitParams(BaseModel):

0 commit comments

Comments
 (0)