From 16d5aeac7c2de010d30cddc90c5755ac5b989b2b Mon Sep 17 00:00:00 2001 From: Duco van Amstel Date: Tue, 13 Nov 2018 22:30:56 +0000 Subject: Make config.Config more unit-test friendly (#586) --- bridge/bridge.go | 37 +++++++++------- bridge/config/config.go | 112 +++++++++++++++++++++++++++++++++++++----------- 2 files changed, 109 insertions(+), 40 deletions(-) (limited to 'bridge') diff --git a/bridge/bridge.go b/bridge/bridge.go index 0436eeb6..debe2d62 100644 --- a/bridge/bridge.go +++ b/bridge/bridge.go @@ -22,7 +22,7 @@ type Bridge struct { Channels map[string]config.ChannelInfo Joined map[string]bool Log *log.Entry - Config *config.Config + Config config.Config General *config.Protocol } @@ -69,36 +69,41 @@ func (b *Bridge) joinChannels(channels map[string]config.ChannelInfo, exists map } func (b *Bridge) GetBool(key string) bool { - if b.Config.GetBool(b.Account + "." + key) { - return b.Config.GetBool(b.Account + "." + key) + val, ok := b.Config.GetBool(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetBool("general." + key) } - return b.Config.GetBool("general." + key) + return val } func (b *Bridge) GetInt(key string) int { - if b.Config.GetInt(b.Account+"."+key) != 0 { - return b.Config.GetInt(b.Account + "." + key) + val, ok := b.Config.GetInt(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetInt("general." + key) } - return b.Config.GetInt("general." + key) + return val } func (b *Bridge) GetString(key string) string { - if b.Config.GetString(b.Account+"."+key) != "" { - return b.Config.GetString(b.Account + "." + key) + val, ok := b.Config.GetString(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetString("general." + key) } - return b.Config.GetString("general." + key) + return val } func (b *Bridge) GetStringSlice(key string) []string { - if len(b.Config.GetStringSlice(b.Account+"."+key)) != 0 { - return b.Config.GetStringSlice(b.Account + "." + key) + val, ok := b.Config.GetStringSlice(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetStringSlice("general." + key) } - return b.Config.GetStringSlice("general." + key) + return val } func (b *Bridge) GetStringSlice2D(key string) [][]string { - if len(b.Config.GetStringSlice2D(b.Account+"."+key)) != 0 { - return b.Config.GetStringSlice2D(b.Account + "." + key) + val, ok := b.Config.GetStringSlice2D(b.Account + "." + key) + if !ok { + val, _ = b.Config.GetStringSlice2D("general." + key) } - return b.Config.GetStringSlice2D("general." + key) + return val } diff --git a/bridge/config/config.go b/bridge/config/config.go index 503a8dee..258401db 100644 --- a/bridge/config/config.go +++ b/bridge/config/config.go @@ -2,7 +2,9 @@ package config import ( "bytes" + "fmt" "io/ioutil" + "os" "strings" "sync" "time" @@ -177,13 +179,23 @@ type ConfigValues struct { SameChannelGateway []SameChannelGateway } -type Config struct { +type Config interface { + ConfigValues() *ConfigValues + GetBool(key string) (bool, bool) + GetInt(key string) (int, bool) + GetString(key string) (string, bool) + GetStringSlice(key string) ([]string, bool) + GetStringSlice2D(key string) ([][]string, bool) +} + +type config struct { v *viper.Viper - *ConfigValues sync.RWMutex + + cv *ConfigValues } -func NewConfig(cfgfile string) *Config { +func NewConfig(cfgfile string) Config { log.SetFormatter(&prefixed.TextFormatter{PrefixPadding: 13, DisableColors: true, FullTimestamp: false}) flog := log.WithFields(log.Fields{"prefix": "config"}) viper.SetConfigFile(cfgfile) @@ -191,9 +203,9 @@ func NewConfig(cfgfile string) *Config { if err != nil { log.Fatal(err) } - mycfg := NewConfigFromString(input) - if mycfg.ConfigValues.General.MediaDownloadSize == 0 { - mycfg.ConfigValues.General.MediaDownloadSize = 1000000 + mycfg := newConfigFromString(input) + if mycfg.cv.General.MediaDownloadSize == 0 { + mycfg.cv.General.MediaDownloadSize = 1000000 } viper.WatchConfig() viper.OnConfigChange(func(e fsnotify.Event) { @@ -211,8 +223,11 @@ func getFileContents(filename string) ([]byte, error) { return input, nil } -func NewConfigFromString(input []byte) *Config { - var cfg ConfigValues +func NewConfigFromString(input []byte) Config { + return newConfigFromString(input) +} + +func newConfigFromString(input []byte) *config { viper.SetConfigType("toml") viper.SetEnvPrefix("matterbridge") viper.AddConfigPath(".") @@ -222,45 +237,51 @@ func NewConfigFromString(input []byte) *Config { if err != nil { log.Fatal(err) } - err = viper.Unmarshal(&cfg) + + cfg := &ConfigValues{} + err = viper.Unmarshal(cfg) if err != nil { log.Fatal(err) } - mycfg := new(Config) - mycfg.v = viper.GetViper() - mycfg.ConfigValues = &cfg - return mycfg + return &config{ + v: viper.GetViper(), + cv: cfg, + } +} + +func (c *config) ConfigValues() *ConfigValues { + return c.cv } -func (c *Config) GetBool(key string) bool { +func (c *config) GetBool(key string) (bool, bool) { c.RLock() defer c.RUnlock() // log.Debugf("getting bool %s = %#v", key, c.v.GetBool(key)) - return c.v.GetBool(key) + return c.v.GetBool(key), c.v.IsSet(key) } -func (c *Config) GetInt(key string) int { +func (c *config) GetInt(key string) (int, bool) { c.RLock() defer c.RUnlock() // log.Debugf("getting int %s = %d", key, c.v.GetInt(key)) - return c.v.GetInt(key) + return c.v.GetInt(key), c.v.IsSet(key) } -func (c *Config) GetString(key string) string { +func (c *config) GetString(key string) (string, bool) { c.RLock() defer c.RUnlock() // log.Debugf("getting String %s = %s", key, c.v.GetString(key)) - return c.v.GetString(key) + return c.v.GetString(key), c.v.IsSet(key) } -func (c *Config) GetStringSlice(key string) []string { +func (c *config) GetStringSlice(key string) ([]string, bool) { c.RLock() defer c.RUnlock() // log.Debugf("getting StringSlice %s = %#v", key, c.v.GetStringSlice(key)) - return c.v.GetStringSlice(key) + return c.v.GetStringSlice(key), c.v.IsSet(key) } -func (c *Config) GetStringSlice2D(key string) [][]string { +func (c *config) GetStringSlice2D(key string) ([][]string, bool) { c.RLock() defer c.RUnlock() result := [][]string{} @@ -272,9 +293,9 @@ func (c *Config) GetStringSlice2D(key string) [][]string { } result = append(result, result2) } - return result + return result, true } - return result + return result, false } func GetIconURL(msg *Message, iconURL string) string { @@ -286,3 +307,46 @@ func GetIconURL(msg *Message, iconURL string) string { iconURL = strings.Replace(iconURL, "{PROTOCOL}", protocol, -1) return iconURL } + +type TestConfig struct { + Config + + Overrides map[string]interface{} +} + +func (c *TestConfig) GetBool(key string) (bool, bool) { + val, ok := c.Overrides[key] + fmt.Fprintln(os.Stderr, "DEBUG:", c.Overrides, key, ok, val) + if ok { + return val.(bool), true + } + return c.Config.GetBool(key) +} + +func (c *TestConfig) GetInt(key string) (int, bool) { + if val, ok := c.Overrides[key]; ok { + return val.(int), true + } + return c.Config.GetInt(key) +} + +func (c *TestConfig) GetString(key string) (string, bool) { + if val, ok := c.Overrides[key]; ok { + return val.(string), true + } + return c.Config.GetString(key) +} + +func (c *TestConfig) GetStringSlice(key string) ([]string, bool) { + if val, ok := c.Overrides[key]; ok { + return val.([]string), true + } + return c.Config.GetStringSlice(key) +} + +func (c *TestConfig) GetStringSlice2D(key string) ([][]string, bool) { + if val, ok := c.Overrides[key]; ok { + return val.([][]string), true + } + return c.Config.GetStringSlice2D(key) +} -- cgit v1.2.3