diff --git a/README.md b/README.md index 5cc7528..754e744 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,11 @@ typed fields in a struct, and supports required and optional vars with defaults. go-envvar is inspired by the javascript library https://github.com/plaid/envvar. +go-envvar supports fields of most primative types (e.g. int, string, bool, +float64) as well as any type which implements the +[encoding.TextUnmarshaler](https://golang.org/pkg/encoding/#TextUnmarshaler) +interface. + ## Example Usage ```go @@ -28,9 +33,10 @@ type serverEnvVars struct { MaxConns uint `envvar:"MAX_CONNECTIONS" default:"100"` // Similar to GO_PORT, HOST_NAME is required. HostName string `envvar:"HOST_NAME"` - // envvar struct tag is not required if the field name - // matches the envvar name. - HOST_NAME string + // Time values are also supported. Parse uses the UnmarshalText method of + // time.Time in order to set the value of the field. In this case, the + // UnmarshalText method expects the string value to be in RFC 3339 format. + StartTime time.Time `envvar:"START_TIME" default:"2017-10-31T14:18:00Z"` } func main() { diff --git a/envvar/envvar.go b/envvar/envvar.go index 8b8e473..514faaf 100644 --- a/envvar/envvar.go +++ b/envvar/envvar.go @@ -4,6 +4,7 @@ package envvar import ( + "encoding" "fmt" "reflect" "strconv" @@ -33,6 +34,9 @@ import ( // Parse will return an UnsetVariableError if a required environment variable // was not set. It will also return an error if there was a problem converting // environment variable values to the proper type or setting the fields of v. +// +// If a field of v implements the encoding.TextUnmarshaler interface, Parse will +// call the UnmarshalText method on the field in order to set its value. func Parse(v interface{}) error { // Make sure the type of v is what we expect. typ := reflect.TypeOf(v) @@ -147,6 +151,42 @@ func (e ErrorList) Error() string { // setFieldVal first converts v to the type of structField, then uses reflection // to set the field to the converted value. func setFieldVal(structField reflect.Value, name string, v string) error { + + // Check if the struct field type implements the encoding.TextUnmarshaler + // interface. + if structField.Type().Implements(reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()) { + // Call the UnmarshalText method using reflection. + results := structField.MethodByName("UnmarshalText").Call([]reflect.Value{reflect.ValueOf([]byte(v))}) + if !results[0].IsNil() { + err := results[0].Interface().(error) + return InvalidVariableError{name, v, err} + } + return nil + } + + // Check if *a pointer to* the struct field type implements the + // encoding.TextUnmarshaler interface. If it does and the struct value is + // addressable, call the UnmarshalText method using reflection. + if reflect.PtrTo(structField.Type()).Implements(reflect.TypeOf([]encoding.TextUnmarshaler{}).Elem()) { + // CanAddr tells us if reflect is able to get a pointer to the struct field + // value. This should always be true, because the Parse method is strict + // about accepting a pointer to a struct type. However, if it's not true the + // Addr() call will panic, so it is good practice to leave this check in + // place. (In the reflect package, a struct field is considered addressable + // if we originally received a pointer to the struct type). + if structField.CanAddr() { + results := structField.Addr().MethodByName("UnmarshalText").Call([]reflect.Value{reflect.ValueOf([]byte(v))}) + if !results[0].IsNil() { + err := results[0].Interface().(error) + return InvalidVariableError{name, v, err} + } + return nil + } + } + + // If the field type does not implement the encoding.TextUnmarshaler + // interface, we can try decoding some basic primitive types and setting the + // value of the struct field with reflection. switch structField.Kind() { case reflect.String: structField.SetString(v) diff --git a/envvar/envvar_test.go b/envvar/envvar_test.go index 856fd22..042f9ac 100644 --- a/envvar/envvar_test.go +++ b/envvar/envvar_test.go @@ -1,12 +1,15 @@ package envvar import ( + "errors" "fmt" "os" "reflect" "regexp" "runtime" + "strings" "testing" + "time" ) func TestParse(t *testing.T) { @@ -25,6 +28,9 @@ func TestParse(t *testing.T) { "FLOAT32": "0.001234", "FLOAT64": "23.7", "BOOL": "true", + "TIME": "2017-10-31T14:18:00Z", + "CUSTOM": "foo,bar,baz", + "WRAPPER": "a,b,c", } expected := typedVars{ STRING: "foo", @@ -41,8 +47,24 @@ func TestParse(t *testing.T) { FLOAT32: 0.001234, FLOAT64: 23.7, BOOL: true, + TIME: time.Date(2017, 10, 31, 14, 18, 0, 0, time.UTC), + CUSTOM: customUnmarshaler{ + strings: []string{"foo", "bar", "baz"}, + }, + WRAPPER: customUnmarshalerWrapper{ + um: &customUnmarshaler{ + strings: []string{"a", "b", "c"}, + }, + }, + } + // Note that we have to initialize the WRAPPER type so that its field is + // non-nil. No other types need to be initialized. + holder := &typedVars{ + WRAPPER: customUnmarshalerWrapper{ + um: &customUnmarshaler{}, + }, } - testParse(t, vars, &typedVars{}, expected) + testParse(t, vars, holder, expected) } func TestParseCustomNames(t *testing.T) { @@ -77,8 +99,24 @@ func TestParseDefaultVals(t *testing.T) { FLOAT32: 0.001234, FLOAT64: 23.7, BOOL: true, + TIME: time.Date(1992, 9, 29, 0, 0, 0, 0, time.UTC), + CUSTOM: customUnmarshaler{ + strings: []string{"one", "two", "three"}, + }, + WRAPPER: customUnmarshalerWrapper{ + um: &customUnmarshaler{ + strings: []string{"apple", "banana", "cranberry"}, + }, + }, } - testParse(t, nil, &defaultVars{}, expected) + // Note that we have to initialize the WRAPPER type so that its field is + // non-nil. No other types need to be initialized. + holder := &defaultVars{ + WRAPPER: customUnmarshalerWrapper{ + um: &customUnmarshaler{}, + }, + } + testParse(t, nil, holder, expected) } func TestParseCustomNameAndDefaultVal(t *testing.T) { @@ -121,7 +159,7 @@ func TestParseRequiredVars(t *testing.T) { } } -func TestParseWithInvalidArgs(t *testing.T) { +func TestParseErrors(t *testing.T) { testCases := []struct { holder interface{} expectedError string @@ -228,6 +266,77 @@ func expectInvalidVariableError(t *testing.T, err error) { } } +func TestUnmarshalTextError(t *testing.T) { + holder := &alwaysErrorVars{} + err := setFieldVal(reflect.ValueOf(holder).Elem().Field(0), "alwaysError", "") + if err == nil { + t.Errorf("Expected InvalidVariableError, but got nil error") + } else if _, ok := err.(InvalidVariableError); !ok { + t.Errorf("Expected InvalidVariableError, but got %s", err.Error()) + } +} + +func TestUnmarshalTextErrorPtr(t *testing.T) { + holder := &alwaysErrorVarsPtr{} + err := setFieldVal(reflect.ValueOf(holder).Elem().Field(0), "alwaysErrorPtr", "") + if err == nil { + t.Errorf("Expected InvalidVariableError, but got nil error") + } else if _, ok := err.(InvalidVariableError); !ok { + t.Errorf("Expected InvalidVariableError, but got %s", err.Error()) + } +} + +// customUnmarshaler implements the UnmarshalText method. +type customUnmarshaler struct { + strings []string +} + +// UnmarshalText simply splits the text by the separator: ",". +func (cu *customUnmarshaler) UnmarshalText(text []byte) error { + cu.strings = strings.Split(string(text), ",") + return nil +} + +// customUnmarshalerWrapper also implements the UnmarshalText method by calling +// it on its own *customUnmarshaler. +type customUnmarshalerWrapper struct { + um *customUnmarshaler +} + +// UnmarshalText simply calls um.UnmarshalText. Note that here we use a +// non-pointer receiver. It still works because the um field is a pointer. We +// just need to be sure to check if um is nil first. +func (cuw customUnmarshalerWrapper) UnmarshalText(text []byte) error { + if cuw.um == nil { + return nil + } + return cuw.um.UnmarshalText(text) +} + +// alwaysErrorUnmarshaler implements the UnmarshalText method by always +// returning an error. +type alwaysErrorUnmarshaler struct{} + +func (aeu alwaysErrorUnmarshaler) UnmarshalText(text []byte) error { + return errors.New("this function always returns an error") +} + +type alwaysErrorVars struct { + AlwaysError alwaysErrorUnmarshaler +} + +// alwaysErrorUnmarshalerPtr is like alwaysErrorUnmarshaler but implements +// the UnmarshalText method with a pointer receiver. +type alwaysErrorUnmarshalerPtr struct{} + +func (aue *alwaysErrorUnmarshalerPtr) UnmarshalText(text []byte) error { + return errors.New("this function always returns an error") +} + +type alwaysErrorVarsPtr struct { + AlwaysErrorPtr alwaysErrorUnmarshalerPtr +} + type typedVars struct { STRING string INT int @@ -243,6 +352,9 @@ type typedVars struct { FLOAT32 float32 FLOAT64 float64 BOOL bool + TIME time.Time + CUSTOM customUnmarshaler + WRAPPER customUnmarshalerWrapper } type customNamedVars struct { @@ -253,20 +365,23 @@ type customNamedVars struct { } type defaultVars struct { - STRING string `default:"foo"` - INT int `default:"272309480983"` - INT8 int8 `default:"-4"` - INT16 int16 `default:"15893"` - INT32 int32 `default:"-230984"` - INT64 int64 `default:"12"` - UINT uint `default:"42"` - UINT8 uint8 `default:"13"` - UINT16 uint16 `default:"1337"` - UINT32 uint32 `default:"348904"` - UINT64 uint64 `default:"12093803"` - FLOAT32 float32 `default:"0.001234"` - FLOAT64 float64 `default:"23.7"` - BOOL bool `default:"true"` + STRING string `default:"foo"` + INT int `default:"272309480983"` + INT8 int8 `default:"-4"` + INT16 int16 `default:"15893"` + INT32 int32 `default:"-230984"` + INT64 int64 `default:"12"` + UINT uint `default:"42"` + UINT8 uint8 `default:"13"` + UINT16 uint16 `default:"1337"` + UINT32 uint32 `default:"348904"` + UINT64 uint64 `default:"12093803"` + FLOAT32 float32 `default:"0.001234"` + FLOAT64 float64 `default:"23.7"` + BOOL bool `default:"true"` + TIME time.Time `default:"1992-09-29T00:00:00Z"` + CUSTOM customUnmarshaler `default:"one,two,three"` + WRAPPER customUnmarshalerWrapper `default:"apple,banana,cranberry"` } type customNameAndDefaultVars struct {