Skip to content

Commit ade743b

Browse files
authored
Merge pull request #2364 from SCIInstitute/sw_monai_mo
Add Support for Multi-Organ Segmentation for AI-Assisted Segmentation Feature in ShapeWorks
2 parents 5078403 + 6b70224 commit ade743b

12 files changed

+301
-14
lines changed

Studio/ShapeWorksMONAI/MonaiLabelJob.cpp

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,13 @@ py::dict MonaiLabelJob::getParamsFromConfig(std::string section,
191191
if (py::isinstance<py::list>(value)) {
192192
py::list valueList = value.cast<py::list>();
193193
if (!valueList.empty()) {
194-
result[key] = py::str(valueList[0]);
194+
// result[key] = py::str(valueList[0]);
195+
result[key] = valueList[0];
196+
195197
}
196198
} else {
197-
result[key] = py::str(value);
199+
// result[key] = py::str(value);
200+
result[key] = value;
198201
}
199202
}
200203
}
@@ -265,6 +268,7 @@ py::tuple MonaiLabelJob::infer(std::string model, std::string image_in,
265268
label_in.empty() ? py::none() : py::cast(label_in),
266269
file.empty() ? py::none() : py::cast(file),
267270
session_id.empty() ? py::none() : py::cast(session_id));
271+
// std::cout << "DEBUG | infer call successfully made " << py::repr(result).cast<std::string>() << std::endl;
268272
// SW_DEBUG("Infer response: " + py::repr(result).cast<std::string>());
269273
}
270274

@@ -375,13 +379,30 @@ void MonaiLabelJob::onRunSegmentationClicked() {
375379
SW_ERROR("Sample not uploaded yet!");
376380
return;
377381
}
382+
SW_LOG("⚙️ Processing inference on the current subject");
378383
py::dict params = getParamsFromConfig("infer", model_name_);
379384

380385
py::tuple result =
381386
infer(model_name_, currentSampleId_, params, "", "", getSessionId());
382387

383-
currentSegmentationPath_ =
384-
result[0].cast<std::string>(); // temp result for segmentation
388+
currentSegmentationPath_ = result[0].cast<std::string>();
389+
py::dict result_params = result[1].cast<py::dict>();
390+
391+
// Extract label names from result_params
392+
py::dict label_dict = result_params["label_names"].cast<py::dict>();
393+
std::map<int, std::string> organLabels;
394+
organNames_.resize(0);
395+
396+
for (auto &item : label_dict) {
397+
std::string organName = item.first.cast<std::string>();
398+
int label = item.second.cast<int>();
399+
if (label > 0) { // Exclude background (0)
400+
organLabels[label] = organName;
401+
organNames_.push_back(organName);
402+
}
403+
}
404+
405+
MonaiLabelUtils::processSegmentation(currentSegmentationPath_, organLabels, tmp_dir_, currentSampleId_, currentSegmentationPaths_);
385406

386407
QDir projDir(QString::fromStdString(tmp_dir_));
387408
QString destPath =
@@ -417,7 +438,7 @@ void MonaiLabelJob::onSubmitLabelClicked() {
417438
std::string label_in = currentSegmentationPath_;
418439

419440
entry[py::str("name")] = image_in;
420-
entry[py::str("idx")] = 1; // TODO: handle multi-organ label submission
441+
entry[py::str("idx")] = 1;
421442
label_info.append(entry);
422443

423444
py::dict params;
@@ -430,13 +451,14 @@ void MonaiLabelJob::onSubmitLabelClicked() {
430451

431452
//---------------------------------------------------------------------------
432453
void MonaiLabelJob::updateShapes() {
433-
if (!currentSampleId_.empty() && !currentSegmentationPath_.empty()) {
454+
if (!currentSampleId_.empty() && !currentSegmentationPaths_.empty()) {
434455
auto shapes = session_->get_shapes();
435-
456+
session_->get_project()->set_domain_names(organNames_);
436457
if (sample_number_ < shapes.size()) {
437458
auto cur_shape = shapes[sample_number_];
438459
auto cur_subject = cur_shape->get_subject();
439-
cur_subject->set_original_filenames({currentSegmentationPath_});
460+
cur_subject->set_number_of_domains(currentSegmentationPaths_.size());
461+
cur_subject->set_original_filenames(currentSegmentationPaths_);
440462
}
441463

442464
} else {

Studio/ShapeWorksMONAI/MonaiLabelJob.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ class MonaiLabelJob : public Job {
107107
int sample_number_;
108108
std::string currentSampleId_;
109109
std::string currentSegmentationPath_;
110+
std::vector<std::string> currentSegmentationPaths_;
111+
std::vector<std::string> organNames_;
110112

111113
QSharedPointer<Session> session_;
112114
ProjectHandle project_;

Studio/ShapeWorksMONAI/MonaiLabelTool.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,11 @@ void MonaiLabelTool::onConnectServer() {
105105
"establishing connection with MONAI Label server");
106106
return;
107107
}
108-
SW_LOG("Connecting to MONAI Label Server...")
108+
SW_LOG("Connecting to MONAI Label Server...")
109109
ui_->connectServerButton->setText("Connecting...");
110110
ui_->connectServerButton->setEnabled(false);
111111
loadParamsFromUi();
112112
if (model_type_ == MONAI_MODE_SEGMENTATION) {
113-
SW_LOG("Connecting to the server...");
114113
runSegmentationTool();
115114
} else {
116115
SW_ERROR(
@@ -194,7 +193,7 @@ void MonaiLabelTool::runSegmentationTool() {
194193

195194
//---------------------------------------------------------------------------
196195
void MonaiLabelTool::handleClientInitialized() {
197-
SW_LOG("Connection successfully established to the server, continue with segmentation!");
196+
SW_LOG("Connection successfully established to the server, continue with segmentation!");
198197
tool_is_running_ = true;
199198
if (session_->get_shapes().size() > 1)
200199
ui_->uploadSampleButton->setEnabled(true);
@@ -221,7 +220,7 @@ void MonaiLabelTool::handleClientInitialized() {
221220

222221
//---------------------------------------------------------------------------
223222
void MonaiLabelTool::handleUploadSampleCompleted() {
224-
SW_LOG("Upload complete! Run {} model on the uploaded sample.", model_type_);
223+
SW_LOG("Upload complete! Run {} model on the uploaded sample.", model_type_);
225224
ui_->uploadSampleButton->setEnabled(false);
226225
ui_->runSegmentationButton->setEnabled(true);
227226
ui_->submitLabelButton->setEnabled(false);
@@ -237,6 +236,9 @@ void MonaiLabelTool::handleSegmentationCompleted() {
237236
ui_->runSegmentationButton->setEnabled(false);
238237
ui_->submitLabelButton->setEnabled(true);
239238
session_->get_project()->save();
239+
SW_LOG(
240+
"✅ Segmentation for the current sample done! Submit the prediction label to server or "
241+
"proceed with next sample!");
240242
Q_EMIT progress(66);
241243
}
242244

@@ -250,7 +252,7 @@ void MonaiLabelTool::handleSubmitLabelCompleted() {
250252
ui_->uploadSampleButton->setEnabled(false);
251253
ui_->runSegmentationButton->setEnabled(false);
252254
ui_->submitLabelButton->setEnabled(false);
253-
SW_LOG("Label submitted to the server. Proceed with next source volume.")
255+
SW_LOG("Label submitted to the server. Proceed with next sample.")
254256
samples_processed_++;
255257
// Q_EMIT
256258
// progress((int)(samples_processed_/session_->get_shapes().size())*100);

Studio/ShapeWorksMONAI/MonaiLabelUtils.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,17 @@
66
#include <QFileInfo>
77
#include <QSharedPointer>
88

9+
#include <itkImageFileReader.h>
10+
#include <itkImageFileWriter.h>
11+
#include <itkBinaryThresholdImageFilter.h>
12+
#include <itkCastImageFilter.h>
13+
#include <itkImageRegionIterator.h>
914
#include <Data/Session.h>
1015

1116
using namespace shapeworks;
1217

18+
typedef float PixelType;
19+
typedef itk::Image< PixelType, 3 > ImageType;
1320
namespace monailabel {
1421

1522
bool MonaiLabelUtils::createDir(const QString& dirPath) {
@@ -45,4 +52,94 @@ std::string MonaiLabelUtils::getFeatureName(QSharedPointer<Session> session) {
4552
return feature_name;
4653
}
4754

55+
//---------------------------------------------------------------------------
56+
ImageType::Pointer MonaiLabelUtils::loadNRRD(const std::string& filePath) {
57+
using ReaderType = itk::ImageFileReader<ImageType>;
58+
ReaderType::Pointer reader = ReaderType::New();
59+
reader->SetFileName(filePath);
60+
reader->Update();
61+
return reader->GetOutput();
62+
}
63+
64+
//---------------------------------------------------------------------------
65+
bool MonaiLabelUtils::isOrganPresent(ImageType::Pointer image) {
66+
itk::ImageRegionIterator<ImageType> it(image, image->GetRequestedRegion());
67+
while (!it.IsAtEnd()) {
68+
if (it.Get() > 0) { // If any non-background pixel is found
69+
return true;
70+
}
71+
++it;
72+
}
73+
return false;
74+
}
75+
//---------------------------------------------------------------------------
76+
ImageType::Pointer MonaiLabelUtils::extractOrganSegmentation(ImageType::Pointer inputImage, int label) {
77+
using ThresholdFilterType = itk::BinaryThresholdImageFilter<ImageType, ImageType>;
78+
ThresholdFilterType::Pointer thresholdFilter = ThresholdFilterType::New();
79+
thresholdFilter->SetInput(inputImage);
80+
thresholdFilter->SetLowerThreshold(label);
81+
thresholdFilter->SetUpperThreshold(label);
82+
// thresholdFilter->SetInsideValue(label);
83+
thresholdFilter->SetInsideValue(1); // dont save as label
84+
thresholdFilter->SetOutsideValue(0); // Background set to 0
85+
thresholdFilter->Update();
86+
87+
ImageType::Pointer organImage = thresholdFilter->GetOutput();
88+
// Check if organ is present in segmentation
89+
if (thresholdFilter->GetOutput()->GetBufferedRegion().GetNumberOfPixels() == 0) {
90+
return nullptr; // Return null if the organ is not present
91+
} // save all
92+
// if (!isOrganPresent(organImage)) {
93+
// return nullptr; // Return nullptr if no valid organ pixels exist
94+
// }
95+
// return thresholdFilter->GetOutput();
96+
return organImage;
97+
}
98+
99+
//---------------------------------------------------------------------------
100+
void MonaiLabelUtils::saveNRRD(ImageType::Pointer image, const std::string& outputPath) {
101+
using WriterType = itk::ImageFileWriter<ImageType>;
102+
WriterType::Pointer writer = WriterType::New();
103+
writer->SetFileName(outputPath);
104+
writer->SetInput(image);
105+
writer->UseCompressionOn();
106+
writer->Update();
107+
}
108+
109+
//---------------------------------------------------------------------------
110+
void MonaiLabelUtils::processSegmentation(
111+
const std::string& segmentationPath,
112+
const std::map<int, std::string>& organLabels, const std::string& outputDir,
113+
const std::string& sampleId,
114+
std::vector<std::string>& organSegmentationPaths) {
115+
116+
organSegmentationPaths.resize(0);
117+
ImageType::Pointer inputImage = loadNRRD(segmentationPath);
118+
if (!inputImage) {
119+
SW_ERROR("Failed to load segmentation file: {}", segmentationPath);
120+
return;
121+
}
122+
123+
QDir projDir(QString::fromStdString(outputDir));
124+
// if (!projDir.exists()) {
125+
// projDir.mkpath(".");
126+
// }
127+
128+
// Extract and save each organ segmentation
129+
for (const auto& [label, organName] : organLabels) {
130+
ImageType::Pointer organImage = extractOrganSegmentation(inputImage, label);
131+
132+
if (!organImage) {
133+
SW_LOG("Warning: {} (Label {}) not found in segmentation.", organName, label);
134+
continue;
135+
}
136+
137+
QString destPath = projDir.filePath(
138+
QString::fromStdString(sampleId + "_" + organName + ".nrrd"));
139+
saveNRRD(organImage, destPath.toStdString());
140+
SW_LOG("✅ Saved segmented organ: {}", destPath.toStdString());
141+
organSegmentationPaths.push_back(destPath.toStdString());
142+
}
143+
}
144+
48145
} // namespace monailabel

Studio/ShapeWorksMONAI/MonaiLabelUtils.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
#include <QValidator>
88
#include <QSharedPointer>
99
#include <Data/Session.h>
10+
#include <itkImage.h>
1011

1112
namespace shapeworks {
1213
class Session;
1314
}
1415
namespace monailabel {
16+
17+
typedef float PixelType;
18+
typedef itk::Image< PixelType, 3 > ImageType;
1519

1620
class UrlValidator : public QValidator {
1721
public:
@@ -45,7 +49,19 @@ class MonaiLabelUtils {
4549
static bool copySegmentation(const QString& sourcePath,
4650
const QString& destinationPath);
4751
static bool deleteTempFile(const QString& filePath);
48-
static std::string getFeatureName(QSharedPointer<shapeworks::Session> session);
52+
static std::string getFeatureName(
53+
QSharedPointer<shapeworks::Session> session);
54+
static ImageType::Pointer loadNRRD(const std::string& filePath);
55+
static ImageType::Pointer extractOrganSegmentation(ImageType::Pointer inputImage,
56+
int label);
57+
static void saveNRRD(ImageType::Pointer image, const std::string& outputPath);
58+
static bool isOrganPresent(ImageType::Pointer image);
59+
60+
static void processSegmentation(const std::string& segmentationPath,
61+
const std::map<int, std::string>& organLabels,
62+
const std::string& outputDir,
63+
const std::string& sampleId,
64+
std::vector<std::string>& organSegmentationPaths);
4965
};
5066

5167
} // namespace monailabel
113 KB
Loading
3.74 MB
Loading
3.58 MB
Loading
3.61 MB
Loading
5.06 MB
Loading

docs/new/ai-assisted-segmentation.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
2+
# AI-Assisted Segmentation in ShapeWorks
3+
4+
## Getting Started with MONAI Label in ShapeWorks
5+
6+
[`Medical Open Network for AI (MONAI) Label`](https://monai.io/) is a deep learning framework designed for efficient annotation and segmentation of medical images.
7+
8+
## What’s New?
9+
ShapeWorks Studio now integrates MONAI Label, enabling seamless access to fully automated and interactive deep learning models for segmenting radiology images across various modalities.
10+
11+
For a detailed demo and step-by-step instructions on using MONAI Label within ShapeWorks Studio, refer to the following guide:
12+
13+
### [Getting Started with AI-Assisted Segmentation](../studio/ai-assisted-segmentation.md)

0 commit comments

Comments
 (0)