diff --git a/pkg/db/v1beta1/common/const.go b/pkg/db/v1beta1/common/const.go index 3bdb09da66e..9754f9a65dd 100644 --- a/pkg/db/v1beta1/common/const.go +++ b/pkg/db/v1beta1/common/const.go @@ -46,4 +46,6 @@ const ( DefaultPostgreSQLDatabase = "katib" DefaultPostgreSQLHost = "katib-postgres" DefaultPostgreSQLPort = "5432" + + SkipDbInitializationEnvName = "SKIP_DB_INITIALIZATION" ) diff --git a/pkg/db/v1beta1/mysql/init.go b/pkg/db/v1beta1/mysql/init.go index 30468062ea7..fe2a2987f0b 100644 --- a/pkg/db/v1beta1/mysql/init.go +++ b/pkg/db/v1beta1/mysql/init.go @@ -20,20 +20,34 @@ import ( "fmt" "k8s.io/klog" + + "github.com/kubeflow/katib/pkg/db/v1beta1/common" + "github.com/kubeflow/katib/pkg/util/v1beta1/env" ) func (d *dbConn) DBInit() { db := d.db - klog.Info("Initializing v1beta1 DB schema") + skipDbInitialization := env.GetEnvOrDefault(common.SkipDbInitializationEnvName, "false") + + if skipDbInitialization == "false" { + klog.Info("Initializing v1beta1 DB schema") - _, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs + _, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs (trial_name VARCHAR(255) NOT NULL, id INT AUTO_INCREMENT PRIMARY KEY, time DATETIME(6), metric_name VARCHAR(255) NOT NULL, value TEXT NOT NULL)`) - if err != nil { - klog.Fatalf("Error creating observation_logs table: %v", err) + if err != nil { + klog.Fatalf("Error creating observation_logs table: %v", err) + } + } else { + klog.Info("Skipping v1beta1 DB schema initialization.") + + _, err := db.Query(`SELECT trial_name, id, time, metric_name, value FROM observation_logs LIMIT 1`) + if err != nil { + klog.Fatalf("Error validating observation_logs table: %v", err) + } } } diff --git a/pkg/db/v1beta1/postgres/init.go b/pkg/db/v1beta1/postgres/init.go index 3ebfad40a46..71c573128c2 100644 --- a/pkg/db/v1beta1/postgres/init.go +++ b/pkg/db/v1beta1/postgres/init.go @@ -20,20 +20,34 @@ import ( "fmt" "k8s.io/klog" + + "github.com/kubeflow/katib/pkg/db/v1beta1/common" + "github.com/kubeflow/katib/pkg/util/v1beta1/env" ) func (d *dbConn) DBInit() { db := d.db - klog.Info("Initializing v1beta1 DB schema") + skipDbInitialization := env.GetEnvOrDefault(common.SkipDbInitializationEnvName, "false") + + if skipDbInitialization == "false" { + klog.Info("Initializing v1beta1 DB schema") - _, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs + _, err := db.Exec(`CREATE TABLE IF NOT EXISTS observation_logs (trial_name VARCHAR(255) NOT NULL, id serial PRIMARY KEY, time TIMESTAMP(6), metric_name VARCHAR(255) NOT NULL, value TEXT NOT NULL)`) - if err != nil { - klog.Fatalf("Error creating observation_logs table: %v", err) + if err != nil { + klog.Fatalf("Error creating observation_logs table: %v", err) + } + } else { + klog.Info("Skipping v1beta1 DB schema initialization.") + + _, err := db.Query(`SELECT trial_name, id, time, metric_name, value FROM observation_logs LIMIT 1`) + if err != nil { + klog.Fatalf("Error validating observation_logs table: %v", err) + } } } diff --git a/pkg/util/v1beta1/env/env.go b/pkg/util/v1beta1/env/env.go index 114e58ed4b0..d9d7516b328 100644 --- a/pkg/util/v1beta1/env/env.go +++ b/pkg/util/v1beta1/env/env.go @@ -16,7 +16,9 @@ limitations under the License. package env -import "os" +import ( + "os" +) func GetEnvOrDefault(key string, fallback string) string { if value, ok := os.LookupEnv(key); ok {