Skip to content

Commit bfec3c3

Browse files
authored
Fix gRPC to calculate the correct class name for method request and response types (#7228)
* Fix gRPC to calculate the correct class name for method request and response types * Fix to use the correct gRPC method type when creating a method descriptor
1 parent 3dc2f5e commit bfec3c3

File tree

1 file changed

+24
-6
lines changed
  • nima/grpc/webserver/src/main/java/io/helidon/nima/grpc/webserver

1 file changed

+24
-6
lines changed

nima/grpc/webserver/src/main/java/io/helidon/nima/grpc/webserver/Grpc.java

+24-6
Original file line numberDiff line numberDiff line change
@@ -109,9 +109,6 @@ private static <ResT, ReqT> Grpc<ReqT, ResT> grpc(Descriptors.FileDescriptor pro
109109

110110
Descriptors.ServiceDescriptor svc = proto.findServiceByName(serviceName);
111111
Descriptors.MethodDescriptor mtd = svc.findMethodByName(methodName);
112-
String pkg = proto.getOptions().getJavaPackage();
113-
pkg = "".equals(pkg) ? proto.getPackage() : pkg;
114-
String outerClass = getOuterClass(proto);
115112

116113
String path = svc.getFullName() + "/" + methodName;
117114

@@ -120,20 +117,28 @@ private static <ResT, ReqT> Grpc<ReqT, ResT> grpc(Descriptors.FileDescriptor pro
120117
- to load the class
121118
- to invoke a static method on it
122119
*/
123-
Class<ReqT> requestType = load(pkg + "." + outerClass + mtd.getInputType().getName().replace('.', '$'));
124-
Class<ResT> responsetype = load(pkg + "." + outerClass + mtd.getOutputType().getName().replace('.', '$'));
120+
Class<ReqT> requestType = load(getClassName(mtd.getInputType()));
121+
Class<ResT> responsetype = load(getClassName(mtd.getOutputType()));
125122

126123
MethodDescriptor.Marshaller<ReqT> reqMarshaller = ProtoMarshaller.get(requestType);
127124
MethodDescriptor.Marshaller<ResT> resMarshaller = ProtoMarshaller.get(responsetype);
128125

129126
io.grpc.MethodDescriptor.Builder<ReqT, ResT> grpcDesc = io.grpc.MethodDescriptor.<ReqT, ResT>newBuilder()
130127
.setFullMethodName(io.grpc.MethodDescriptor.generateFullMethodName(serviceName, methodName))
131-
.setType(io.grpc.MethodDescriptor.MethodType.UNARY).setFullMethodName(path).setRequestMarshaller(reqMarshaller)
128+
.setType(getMethodType(mtd)).setFullMethodName(path).setRequestMarshaller(reqMarshaller)
132129
.setResponseMarshaller(resMarshaller).setSampledToLocalTracing(true);
133130

134131
return new Grpc<>(grpcDesc.build(), PathMatchers.exact(path), requestType, responsetype, callHandler);
135132
}
136133

134+
private static String getClassName(Descriptors.Descriptor descriptor) {
135+
Descriptors.FileDescriptor fd = descriptor.getFile();
136+
String outerClass = getOuterClass(fd);
137+
String pkg = fd.getOptions().getJavaPackage();
138+
pkg = "".equals(pkg) ? fd.getPackage() : pkg;
139+
return pkg + "." + outerClass + descriptor.getName().replace('.', '$');
140+
}
141+
137142
@SuppressWarnings("unchecked")
138143
private static <T> Class<T> load(String className) {
139144
try {
@@ -173,4 +178,17 @@ private static String getOuterClassFromFileName(String name) {
173178

174179
return sb.toString();
175180
}
181+
182+
private static io.grpc.MethodDescriptor.MethodType getMethodType(Descriptors.MethodDescriptor mtd) {
183+
if (mtd.isClientStreaming()) {
184+
if (mtd.isServerStreaming()) {
185+
return MethodDescriptor.MethodType.BIDI_STREAMING;
186+
} else {
187+
return MethodDescriptor.MethodType.CLIENT_STREAMING;
188+
}
189+
} else if (mtd.isServerStreaming()) {
190+
return MethodDescriptor.MethodType.SERVER_STREAMING;
191+
}
192+
return MethodDescriptor.MethodType.UNARY;
193+
}
176194
}

0 commit comments

Comments
 (0)