4
4
#include < memory>
5
5
#include < string>
6
6
#include < utility>
7
- #include < vector>
8
7
9
8
#include " google/rpc/status.pb.h"
10
9
#include " absl/log/absl_log.h"
11
10
#include " absl/status/status.h"
12
11
#include " absl/status/statusor.h"
13
12
#include " absl/strings/match.h"
14
13
#include " absl/strings/str_cat.h"
15
- #include " absl/strings/str_split .h"
14
+ #include " absl/strings/str_format .h"
16
15
#include " absl/strings/string_view.h"
16
+ #include " cpp/jwt.h"
17
17
#include " curl/curl.h"
18
18
#include " curl/easy.h"
19
19
@@ -42,6 +42,24 @@ using Request =
42
42
using Response =
43
43
::google::cloud::agentcommunication::v1::StreamAgentMessagesResponse;
44
44
45
+ constexpr absl::string_view kAcsTokenEndpointGce =
46
+ " instance/service-accounts/default/"
47
+ " identity?audience=agentcommunication.googleapis.com&format=full" ;
48
+
49
+ // TODO: b/384093718 - Update the endpoint once the token endpoint is finalized.
50
+ constexpr absl::string_view kAcsTokenEndpointGke =
51
+ " instance/gke/agent-communication-service/ncclmetrics-token/"
52
+ " identity?audience=agentcommunication.googleapis.com&format=full" ;
53
+
54
+ // Internal helper struct to hold the ACS token and the parsed values from the
55
+ // token.
56
+ struct AcsToken {
57
+ std::string token;
58
+ std::string instance_id;
59
+ std::string project_number;
60
+ std::string zone;
61
+ };
62
+
45
63
std::unique_ptr<Request> MakeAck (std::string message_id) {
46
64
google::rpc::Status status;
47
65
status.set_code (0 );
@@ -143,36 +161,50 @@ absl::StatusOr<std::string> GetMetadata(absl::string_view key) {
143
161
" Metadata-Flavor: Google" );
144
162
}
145
163
146
- absl::StatusOr<AgentConnectionId> GenerateAgentConnectionId (
147
- std::string channel_id, bool regional) {
148
- absl::StatusOr<std::string> token = GetMetadata (
149
- " instance/service-accounts/default/"
150
- " identity?audience=agentcommunication.googleapis.com&format=full" );
164
+ absl::StatusOr<AcsToken> ParseAcsToken (absl::string_view endpoint) {
165
+ absl::StatusOr<std::string> token = GetMetadata (endpoint);
151
166
if (!token.ok ()) {
152
167
return token.status ();
153
168
}
154
169
ABSL_VLOG (2 ) << " Successfully got token from metadata service: " << *token;
155
-
156
- absl::StatusOr<std::string> numeric_project_id_zone =
157
- GetMetadata (" instance/zone" );
158
- if (!numeric_project_id_zone.ok ()) {
159
- return numeric_project_id_zone.status ();
170
+ absl::StatusOr<std::string> instance_id = GetValueFromTokenPayloadWithKeys (
171
+ *token, {" google" , " compute_engine" , " instance_id" });
172
+ if (!instance_id.ok ()) {
173
+ return instance_id.status ();
160
174
}
161
- std::vector<std::string> numeric_project_id_zone_vector =
162
- absl::StrSplit (*numeric_project_id_zone, ' /' );
163
- if (numeric_project_id_zone_vector.size () != 4 ) {
164
- ABSL_LOG (ERROR)
165
- << " Wrong format of numeric_project_id_zone from metadata service: "
166
- << numeric_project_id_zone;
167
- return absl::InternalError (absl::StrCat (
168
- " Wrong format of numeric_project_id_zone from metadata service: " ,
169
- *numeric_project_id_zone));
175
+ ABSL_VLOG (2 ) << " Successfully got instance_id from metadata service: "
176
+ << *instance_id;
177
+ absl::StatusOr<std::string> project_number = GetValueFromTokenPayloadWithKeys (
178
+ *token, {" google" , " compute_engine" , " project_number" });
179
+ if (!project_number.ok ()) {
180
+ return project_number.status ();
181
+ }
182
+ ABSL_VLOG (2 ) << " Successfully got project_number from metadata service: "
183
+ << *project_number;
184
+ absl::StatusOr<std::string> zone = GetValueFromTokenPayloadWithKeys (
185
+ *token, {" google" , " compute_engine" , " zone" });
186
+ if (!zone.ok ()) {
187
+ return zone.status ();
170
188
}
171
- ABSL_VLOG (2 )
172
- << " Successfully got numeric_project_id_zone from metadata service: "
173
- << *numeric_project_id_zone;
174
- const std::string& zone = numeric_project_id_zone_vector[3 ];
189
+ ABSL_VLOG (2 ) << " Successfully got zone from metadata service: " << *zone;
190
+ return AcsToken{.token = *std::move (token),
191
+ .instance_id = *std::move (instance_id),
192
+ .project_number = *std::move (project_number),
193
+ .zone = *std::move (zone)};
194
+ }
175
195
196
+ absl::StatusOr<AgentConnectionId> GenerateAgentConnectionId (
197
+ std::string channel_id, bool regional) {
198
+ absl::StatusOr<AcsToken> AcsToken = ParseAcsToken (kAcsTokenEndpointGce );
199
+ if (!AcsToken.ok ()) {
200
+ // If the token is not available from the GCE endpoint, try the GKE
201
+ // endpoint.
202
+ AcsToken = ParseAcsToken (kAcsTokenEndpointGke );
203
+ if (!AcsToken.ok ()) {
204
+ return AcsToken.status ();
205
+ }
206
+ }
207
+ const std::string& zone = AcsToken->zone ;
176
208
// Deduce the location from the zone.
177
209
// If regional is true, the location is the zone without the last two
178
210
// characters. Otherwise, the location is the zone itself.
@@ -189,16 +221,10 @@ absl::StatusOr<AgentConnectionId> GenerateAgentConnectionId(
189
221
" -agentcommunication.sandbox.googleapis.com:443" )
190
222
: absl::StrCat (location, " -agentcommunication.googleapis.com:443" );
191
223
192
- absl::StatusOr<std::string> instance_id = GetMetadata (" instance/id" );
193
- if (!instance_id.ok ()) {
194
- return instance_id.status ();
195
- }
196
- ABSL_VLOG (2 ) << " Successfully got instance_id from metadata service: "
197
- << *instance_id;
198
-
199
224
std::string resource_id =
200
- absl::StrCat (*numeric_project_id_zone, " /instances/" , *instance_id);
201
- return AgentConnectionId{.token = std::move (*token),
225
+ absl::StrFormat (" projects/%s/zones/%s/instances/%s" ,
226
+ AcsToken->project_number , zone, AcsToken->instance_id );
227
+ return AgentConnectionId{.token = std::move (AcsToken->token ),
202
228
.resource_id = std::move (resource_id),
203
229
.channel_id = std::move (channel_id),
204
230
.endpoint = std::move (endpoint),
0 commit comments