Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Support updating all attributes of databricks_model_serving #4575

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

### Bug Fixes

* Support updating all attributes for `databricks_model_serving` ([#4575](https://github.com/databricks/terraform-provider-databricks/pull/4575)).

### Documentation

### Exporter
Expand Down
24 changes: 24 additions & 0 deletions serving/model_serving_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ func TestAccModelServing(t *testing.T) {
}
}
}
tags {
key = "key1"
value = "value-should-not-change"
}
tags {
key = "key2"
value = "value-should-change"
}
tags {
key = "key3"
value = "should-be-deleted"
}
}

data "databricks_serving_endpoints" "all" {}
Expand Down Expand Up @@ -79,6 +91,18 @@ func TestAccModelServing(t *testing.T) {
}
}
}
tags {
key = "key1"
value = "value-should-not-change"
}
tags {
key = "key2"
value = "value-should-change-to-something-new"
}
tags {
key = "key4"
value = "should-be-added"
}
}
data "databricks_serving_endpoints" "all" {}
`, name),
Expand Down
102 changes: 97 additions & 5 deletions serving/resource_model_serving.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"log"
"time"

"github.com/databricks/databricks-sdk-go"
"github.com/databricks/databricks-sdk-go/service/serving"
"github.com/databricks/terraform-provider-databricks/common"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
Expand All @@ -13,6 +14,85 @@ import (
const DefaultProvisionTimeout = 45 * time.Minute
const deleteCallTimeout = 10 * time.Second

// updateConfig updates the configuration of the provided serving endpoint to the provided config.
func updateConfig(ctx context.Context, w *databricks.WorkspaceClient, name string, e *serving.EndpointCoreConfigInput, d *schema.ResourceData) error {
e.Name = name
waiter, err := w.ServingEndpoints.UpdateConfig(ctx, *e)
if err != nil {
return err
}
_, err = waiter.GetWithTimeout(d.Timeout(schema.TimeoutUpdate))
if err != nil {
return err
}
return nil
}

// updateTags updates the tags of the provided serving endpoint to the given tags. Any tags not present on the existing
// endpoint will be removed, any tags absent on the endpoint will be added, existing tags will be updated, and unchanged
// tags will remain as-is.
func updateTags(ctx context.Context, w *databricks.WorkspaceClient, name string, newTags []serving.EndpointTag, d *schema.ResourceData) error {
currentEndpoint, err := w.ServingEndpoints.Get(ctx, serving.GetServingEndpointRequest{
Name: name,
})
oldTags := currentEndpoint.Tags
if err != nil {
return err
}
req := serving.PatchServingEndpointTags{
Name: name,
}
for _, newTag := range newTags {
found := false
for _, oldTag := range oldTags {
if oldTag.Key == newTag.Key && oldTag.Value == newTag.Value {
found = true
break
}
}
if !found {
req.AddTags = append(req.AddTags, newTag)
}
}
for _, oldTag := range oldTags {
found := false
for _, newTag := range newTags {
if oldTag.Key == newTag.Key {
found = true
break
}
}
if !found {
req.DeleteTags = append(req.DeleteTags, oldTag.Key)
}
}
if _, err := w.ServingEndpoints.Patch(ctx, req); err != nil {
return err
}
return nil
}

// Update the rate limit configuration for a model serving endpoint.
func updateRateLimits(ctx context.Context, w *databricks.WorkspaceClient, name string, newRateLimits []serving.RateLimit, d *schema.ResourceData) error {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this endpoint is now deprecated, you can remove it
This API is deprecated for your Foundation Model endpoint. Please use AI Gateway to manage rate limits.

_, err := w.ServingEndpoints.Put(ctx, serving.PutRequest{
Name: name,
RateLimits: newRateLimits,
})
return err
}

// Update the AI Gateway configuration for a model serving endpoint.
func updateAiGateway(ctx context.Context, w *databricks.WorkspaceClient, name string, newAiGateway serving.AiGatewayConfig, d *schema.ResourceData) error {
_, err := w.ServingEndpoints.PutAiGateway(ctx, serving.PutAiGatewayRequest{
Name: name,
Guardrails: newAiGateway.Guardrails,
InferenceTableConfig: newAiGateway.InferenceTableConfig,
RateLimits: newAiGateway.RateLimits,
UsageTrackingConfig: newAiGateway.UsageTrackingConfig,
})
return err
}

func ResourceModelServing() common.Resource {
s := common.StructToSchema(
serving.CreateServingEndpoint{},
Expand Down Expand Up @@ -43,6 +123,9 @@ func ResourceModelServing() common.Resource {
common.MustSchemaPath(m, "config", "served_entities", "workload_size").Computed = true
common.MustSchemaPath(m, "config", "served_entities", "workload_type").Computed = true

// route_optimized cannot be updated.
common.MustSchemaPath(m, "route_optimized").ForceNew = true

m["serving_endpoint_id"] = &schema.Schema{
Computed: true,
Type: schema.TypeString,
Expand Down Expand Up @@ -111,13 +194,22 @@ func ResourceModelServing() common.Resource {
var e serving.CreateServingEndpoint
common.DataToStructPointer(d, s, &e)
if d.HasChange("config") {
e.Config.Name = e.Name
waiter, err := w.ServingEndpoints.UpdateConfig(ctx, *e.Config)
if err != nil {
if err := updateConfig(ctx, w, e.Name, e.Config, d); err != nil {
return err
}
}
if d.HasChange("tags") {
if err := updateTags(ctx, w, e.Name, e.Tags, d); err != nil {
return err
}
}
if d.HasChange("rate_limits") {
if err := updateRateLimits(ctx, w, e.Name, e.RateLimits, d); err != nil {
return err
}
_, err = waiter.GetWithTimeout(d.Timeout(schema.TimeoutUpdate))
if err != nil {
}
if d.HasChange("ai_gateway") {
if err := updateAiGateway(ctx, w, e.Name, *e.AiGateway, d); err != nil {
return err
}
}
Expand Down
Loading