From 146dd3e63901b973e0f8f51c8858ada4a8922968 Mon Sep 17 00:00:00 2001 From: Guillaume Tardif Date: Wed, 13 May 2020 23:33:16 +0200 Subject: [PATCH] Fix tokenStore not creating ~/.azure folder if not exist --- azure/aci.go | 8 ++--- azure/backend.go | 14 +++++---- azure/login/login.go | 51 ++++++++++++++++++-------------- azure/login/login_test.go | 15 +++++----- azure/login/tokenStore.go | 31 +++++++++++++++++-- azure/login/tokenStore_test.go | 54 ++++++++++++++++++++++++++++++++++ 6 files changed, 132 insertions(+), 41 deletions(-) create mode 100644 azure/login/tokenStore_test.go diff --git a/azure/aci.go b/azure/aci.go index 252a80e2..79907d3b 100644 --- a/azure/aci.go +++ b/azure/aci.go @@ -235,7 +235,7 @@ func getACIContainerLogs(ctx context.Context, aciContext store.AciContext, conta } func getContainerGroupsClient(subscriptionID string) (containerinstance.ContainerGroupsClient, error) { - auth, err := login.NewAzureLoginService().NewAuthorizerFromLogin() + auth, err := login.NewAuthorizerFromLogin() if err != nil { return containerinstance.ContainerGroupsClient{}, err } @@ -248,7 +248,7 @@ func getContainerGroupsClient(subscriptionID string) (containerinstance.Containe } func getContainerClient(subscriptionID string) (containerinstance.ContainerClient, error) { - auth, err := login.NewAzureLoginService().NewAuthorizerFromLogin() + auth, err := login.NewAuthorizerFromLogin() if err != nil { return containerinstance.ContainerClient{}, err } @@ -259,7 +259,7 @@ func getContainerClient(subscriptionID string) (containerinstance.ContainerClien func getSubscriptionsClient() subscription.SubscriptionsClient { subc := subscription.NewSubscriptionsClient() - authorizer, _ := login.NewAzureLoginService().NewAuthorizerFromLogin() + authorizer, _ := login.NewAuthorizerFromLogin() subc.Authorizer = authorizer return subc } @@ -267,7 +267,7 @@ func getSubscriptionsClient() subscription.SubscriptionsClient { // GetGroupsClient ... func GetGroupsClient(subscriptionID string) resources.GroupsClient { groupsClient := resources.NewGroupsClient(subscriptionID) - authorizer, _ := login.NewAzureLoginService().NewAuthorizerFromLogin() + authorizer, _ := login.NewAuthorizerFromLogin() groupsClient.Authorizer = authorizer return groupsClient } diff --git a/azure/backend.go b/azure/backend.go index 8cda586e..1819b99b 100644 --- a/azure/backend.go +++ b/azure/backend.go @@ -52,14 +52,18 @@ func New(ctx context.Context) (backend.Service, error) { } aciContext, _ := metadata.Metadata.Data.(store.AciContext) - auth, _ := login.NewAzureLoginService().NewAuthorizerFromLogin() + auth, _ := login.NewAuthorizerFromLogin() containerGroupsClient := containerinstance.NewContainerGroupsClient(aciContext.SubscriptionID) containerGroupsClient.Authorizer = auth - return getAciAPIService(containerGroupsClient, aciContext), nil + return getAciAPIService(containerGroupsClient, aciContext) } -func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.AciContext) *aciAPIService { +func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store.AciContext) (*aciAPIService, error) { + service, err := login.NewAzureLoginService() + if err != nil { + return nil, err + } return &aciAPIService{ aciContainerService: aciContainerService{ containerGroupsClient: cgc, @@ -69,9 +73,9 @@ func getAciAPIService(cgc containerinstance.ContainerGroupsClient, aciCtx store. ctx: aciCtx, }, aciCloudService: aciCloudService{ - loginService: login.NewAzureLoginService(), + loginService: service, }, - } + }, nil } type aciAPIService struct { diff --git a/azure/login/login.go b/azure/login/login.go index f14ddf93..a011a413 100644 --- a/azure/login/login.go +++ b/azure/login/login.go @@ -68,25 +68,27 @@ type AzureLoginService struct { apiHelper apiHelper } -const tokenFilename = "dockerAccessToken.json" +const tokenStoreFilename = "dockerAccessToken.json" func getTokenStorePath() string { cliPath, _ := cli.AccessTokensPath() - return filepath.Join(filepath.Dir(cliPath), tokenFilename) + return filepath.Join(filepath.Dir(cliPath), tokenStoreFilename) } // NewAzureLoginService creates a NewAzureLoginService -func NewAzureLoginService() AzureLoginService { +func NewAzureLoginService() (AzureLoginService, error) { return newAzureLoginServiceFromPath(getTokenStorePath(), azureAPIHelper{}) } -func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) AzureLoginService { - return AzureLoginService{ - tokenStore: tokenStore{ - filePath: tokenStorePath, - }, - apiHelper: helper, +func newAzureLoginServiceFromPath(tokenStorePath string, helper apiHelper) (AzureLoginService, error) { + store, err := newTokenStore(tokenStorePath) + if err != nil { + return AzureLoginService{}, err } + return AzureLoginService{ + tokenStore: store, + apiHelper: helper, + }, nil } type apiHelper interface { @@ -229,20 +231,21 @@ func queryHandler(queryCh chan url.Values) func(w http.ResponseWriter, r *http.R return queryHandler } -func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (token azureToken, err error) { +func (helper azureAPIHelper) queryToken(data url.Values, tenantID string) (azureToken, error) { res, err := http.Post(fmt.Sprintf(tokenEndpoint, tenantID), "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) if err != nil { - return token, err + return azureToken{}, err } if res.StatusCode != 200 { - return token, errors.Errorf("error while renewing access token, status : %s", res.Status) + return azureToken{}, errors.Errorf("error while renewing access token, status : %s", res.Status) } bits, err := ioutil.ReadAll(res.Body) if err != nil { - return token, err + return azureToken{}, err } + token := azureToken{} if err := json.Unmarshal(bits, &token); err != nil { - return token, err + return azureToken{}, err } return token, nil } @@ -259,7 +262,11 @@ func toOAuthToken(token azureToken) oauth2.Token { } // NewAuthorizerFromLogin creates an authorizer based on login access token -func (login AzureLoginService) NewAuthorizerFromLogin() (autorest.Authorizer, error) { +func NewAuthorizerFromLogin() (autorest.Authorizer, error) { + login, err := NewAzureLoginService() + if err != nil { + return nil, err + } oauthToken, err := login.GetValidToken() if err != nil { return nil, err @@ -278,28 +285,28 @@ func (login AzureLoginService) NewAuthorizerFromLogin() (autorest.Authorizer, er } // GetValidToken returns an access token. Refresh token if needed -func (login AzureLoginService) GetValidToken() (token oauth2.Token, err error) { +func (login AzureLoginService) GetValidToken() (oauth2.Token, error) { loginInfo, err := login.tokenStore.readToken() if err != nil { - return token, err + return oauth2.Token{}, err } - token = loginInfo.Token + token := loginInfo.Token if token.Valid() { return token, nil } tenantID := loginInfo.TenantID token, err = login.refreshToken(token.RefreshToken, tenantID) if err != nil { - return token, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.") + return oauth2.Token{}, errors.Wrap(err, "access token request failed. Maybe you need to login to azure again.") } err = login.tokenStore.writeLoginInfo(TokenInfo{TenantID: tenantID, Token: token}) if err != nil { - return token, err + return oauth2.Token{}, err } return token, nil } -func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauthToken oauth2.Token, err error) { +func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID string) (oauth2.Token, error) { data := url.Values{ "grant_type": []string{"refresh_token"}, "client_id": []string{clientID}, @@ -308,7 +315,7 @@ func (login AzureLoginService) refreshToken(currentRefreshToken string, tenantID } token, err := login.apiHelper.queryToken(data, tenantID) if err != nil { - return oauthToken, err + return oauth2.Token{}, err } return toOAuthToken(token), nil diff --git a/azure/login/login_test.go b/azure/login/login_test.go index 0295d505..81ea2234 100644 --- a/azure/login/login_test.go +++ b/azure/login/login_test.go @@ -8,8 +8,6 @@ import ( "testing" "time" - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/mock" "github.com/stretchr/testify/suite" @@ -27,17 +25,18 @@ type LoginSuiteTest struct { func (suite *LoginSuiteTest) BeforeTest(suiteName, testName string) { dir, err := ioutil.TempDir("", "test_store") - require.Nil(suite.T(), err) + Expect(err).To(BeNil()) suite.dir = dir suite.mockHelper = MockAzureHelper{} //nolint copylocks - suite.azureLogin = newAzureLoginServiceFromPath(filepath.Join(dir, tokenFilename), suite.mockHelper) + suite.azureLogin, err = newAzureLoginServiceFromPath(filepath.Join(dir, tokenStoreFilename), suite.mockHelper) + Expect(err).To(BeNil()) } func (suite *LoginSuiteTest) AfterTest(suiteName, testName string) { err := os.RemoveAll(suite.dir) - require.Nil(suite.T(), err) + Expect(err).To(BeNil()) } func (suite *LoginSuiteTest) TestRefreshInValidToken() { @@ -55,8 +54,10 @@ func (suite *LoginSuiteTest) TestRefreshInValidToken() { }, nil) //nolint copylocks - suite.azureLogin = newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenFilename), suite.mockHelper) - err := suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{ + azureLogin, err := newAzureLoginServiceFromPath(filepath.Join(suite.dir, tokenStoreFilename), suite.mockHelper) + Expect(err).To(BeNil()) + suite.azureLogin = azureLogin + err = suite.azureLogin.tokenStore.writeLoginInfo(TokenInfo{ TenantID: "123456", Token: oauth2.Token{ AccessToken: "accessToken", diff --git a/azure/login/tokenStore.go b/azure/login/tokenStore.go index 8e0b7dac..d6a7c59b 100644 --- a/azure/login/tokenStore.go +++ b/azure/login/tokenStore.go @@ -2,7 +2,10 @@ package login import ( "encoding/json" + "errors" "io/ioutil" + "os" + "path/filepath" "golang.org/x/oauth2" ) @@ -17,6 +20,27 @@ type TokenInfo struct { TenantID string `json:"tenantId"` } +func newTokenStore(path string) (tokenStore, error) { + parentFolder := filepath.Dir(path) + dir, err := os.Stat(parentFolder) + if os.IsNotExist(err) { + err = os.MkdirAll(parentFolder, 0700) + if err != nil { + return tokenStore{}, err + } + dir, err = os.Stat(parentFolder) + } + if err != nil { + return tokenStore{}, err + } + if !dir.Mode().IsDir() { + return tokenStore{}, errors.New("cannot use path " + path + " ; " + parentFolder + " already exists and is not a directory") + } + return tokenStore{ + filePath: path, + }, nil +} + func (store tokenStore) writeLoginInfo(info TokenInfo) error { bytes, err := json.MarshalIndent(info, "", " ") if err != nil { @@ -25,13 +49,14 @@ func (store tokenStore) writeLoginInfo(info TokenInfo) error { return ioutil.WriteFile(store.filePath, bytes, 0644) } -func (store tokenStore) readToken() (loginInfo TokenInfo, err error) { +func (store tokenStore) readToken() (TokenInfo, error) { bytes, err := ioutil.ReadFile(store.filePath) if err != nil { - return loginInfo, err + return TokenInfo{}, err } + loginInfo := TokenInfo{} if err := json.Unmarshal(bytes, &loginInfo); err != nil { - return loginInfo, err + return TokenInfo{}, err } return loginInfo, nil } diff --git a/azure/login/tokenStore_test.go b/azure/login/tokenStore_test.go new file mode 100644 index 00000000..cd1818fd --- /dev/null +++ b/azure/login/tokenStore_test.go @@ -0,0 +1,54 @@ +package login + +import ( + "errors" + "io/ioutil" + "os" + "path/filepath" + "testing" + + . "github.com/onsi/gomega" + "github.com/stretchr/testify/suite" +) + +type tokenStoreTestSuite struct { + suite.Suite +} + +func (suite *tokenStoreTestSuite) TestCreateStoreFromExistingFolder() { + existingDir, err := ioutil.TempDir("", "test_store") + Expect(err).To(BeNil()) + + storePath := filepath.Join(existingDir, tokenStoreFilename) + store, err := newTokenStore(storePath) + Expect(err).To(BeNil()) + Expect((store.filePath)).To(Equal(storePath)) +} + +func (suite *tokenStoreTestSuite) TestCreateStoreFromNonExistingFolder() { + existingDir, err := ioutil.TempDir("", "test_store") + Expect(err).To(BeNil()) + + storePath := filepath.Join(existingDir, "new", tokenStoreFilename) + store, err := newTokenStore(storePath) + Expect(err).To(BeNil()) + Expect((store.filePath)).To(Equal(storePath)) + + newDir, err := os.Stat(filepath.Join(existingDir, "new")) + Expect(err).To(BeNil()) + Expect(newDir.Mode().IsDir()).To(BeTrue()) +} + +func (suite *tokenStoreTestSuite) TestErrorIfParentFolderIsAFile() { + existingDir, err := ioutil.TempFile("", "test_store") + Expect(err).To(BeNil()) + + storePath := filepath.Join(existingDir.Name(), tokenStoreFilename) + _, err = newTokenStore(storePath) + Expect(err).To(MatchError(errors.New("cannot use path " + storePath + " ; " + existingDir.Name() + " already exists and is not a directory"))) +} + +func TestTokenStoreSuite(t *testing.T) { + RegisterTestingT(t) + suite.Run(t, new(tokenStoreTestSuite)) +}