diff options
Diffstat (limited to 'vendor/github.com/yaegashi/msgraph.go/msauth/msauth.go')
-rw-r--r-- | vendor/github.com/yaegashi/msgraph.go/msauth/msauth.go | 60 |
1 files changed, 47 insertions, 13 deletions
diff --git a/vendor/github.com/yaegashi/msgraph.go/msauth/msauth.go b/vendor/github.com/yaegashi/msgraph.go/msauth/msauth.go index c2a86985..014f2522 100644 --- a/vendor/github.com/yaegashi/msgraph.go/msauth/msauth.go +++ b/vendor/github.com/yaegashi/msgraph.go/msauth/msauth.go @@ -36,10 +36,6 @@ func (t *TokenError) Error() string { return fmt.Sprintf("%s: %s", t.ErrorObject, t.ErrorDescription) } -func generateKey(tenantID, clientID string) string { - return fmt.Sprintf("%s:%s", tenantID, clientID) -} - func deviceCodeURL(tenantID string) string { return fmt.Sprintf(endpointURLFormat, tenantID, "devicecode") } @@ -65,6 +61,7 @@ func (e *tokenJSON) expiry() (t time.Time) { // Manager is oauth2 token cache manager type Manager struct { mu sync.Mutex + Dirty bool TokenCache map[string]*oauth2.Token } @@ -87,27 +84,64 @@ func (m *Manager) SaveBytes() ([]byte, error) { return json.Marshal(m.TokenCache) } -// LoadFile loads token cache from file +// LoadFile loads token cache from file with dirty state control func (m *Manager) LoadFile(path string) error { - b, err := ioutil.ReadFile(path) + m.mu.Lock() + defer m.mu.Unlock() + b, err := ReadLocation(path) + if err != nil { + return err + } + err = json.Unmarshal(b, &m.TokenCache) if err != nil { return err } - return m.LoadBytes(b) + m.Dirty = false + return nil } -// SaveFile saves token cache to file +// SaveFile saves token cache to file with dirty state control func (m *Manager) SaveFile(path string) error { - b, err := m.SaveBytes() + m.mu.Lock() + defer m.mu.Unlock() + if !m.Dirty { + return nil + } + b, err := json.Marshal(m.TokenCache) + if err != nil { + return err + } + err = WriteLocation(path, b, 0644) if err != nil { return err } - return ioutil.WriteFile(path, b, 0644) + m.Dirty = false + return nil +} + +// CacheKey generates a token cache key from tenantID/clientID +func CacheKey(tenantID, clientID string) string { + return fmt.Sprintf("%s:%s", tenantID, clientID) +} + +// GetToken gets a token from token cache +func (m *Manager) GetToken(cacheKey string) (*oauth2.Token, bool) { + m.mu.Lock() + defer m.mu.Unlock() + token, ok := m.TokenCache[cacheKey] + return token, ok } -// Cache stores a token into token cache -func (m *Manager) Cache(tenantID, clientID string, token *oauth2.Token) { - m.TokenCache[generateKey(tenantID, clientID)] = token +// PutToken puts a token into token cache +func (m *Manager) PutToken(cacheKey string, token *oauth2.Token) { + m.mu.Lock() + defer m.mu.Unlock() + oldToken, ok := m.TokenCache[cacheKey] + if ok && *oldToken == *token { + return + } + m.TokenCache[cacheKey] = token + m.Dirty = true } // requestToken requests a token from the token endpoint |