diff --git a/plugins/user-authenticators/oauth2/src/main/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticator.java b/plugins/user-authenticators/oauth2/src/main/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticator.java index 8484a5ef798..3a6d2bd6fae 100644 --- a/plugins/user-authenticators/oauth2/src/main/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticator.java +++ b/plugins/user-authenticators/oauth2/src/main/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticator.java @@ -32,16 +32,18 @@ import org.apache.log4j.Logger; import javax.inject.Inject; import java.util.Map; +import static org.apache.cloudstack.oauth2.OAuth2AuthManager.OAuth2IsPluginEnabled; + public class OAuth2UserAuthenticator extends AdapterBase implements UserAuthenticator { public static final Logger s_logger = Logger.getLogger(OAuth2UserAuthenticator.class); @Inject - private UserAccountDao _userAccountDao; + private UserAccountDao userAccountDao; @Inject - private UserDao _userDao; + private UserDao userDao; @Inject - private OAuth2AuthManager _userOAuth2mgr; + private OAuth2AuthManager userOAuth2mgr; @Override public Pair authenticate(String username, String password, Long domainId, Map requestParameters) { @@ -49,12 +51,20 @@ public class OAuth2UserAuthenticator extends AdapterBase implements UserAuthenti s_logger.debug("Trying OAuth2 auth for user: " + username); } - final UserAccount userAccount = _userAccountDao.getUserAccount(username, domainId); + if (!isOAuthPluginEnabled()) { + s_logger.debug("OAuth2 plugin is disabled"); + return new Pair(false, null); + } else if (requestParameters == null) { + s_logger.debug("Request parameters are null"); + return new Pair(false, null); + } + + final UserAccount userAccount = userAccountDao.getUserAccount(username, domainId); if (userAccount == null) { s_logger.debug("Unable to find user with " + username + " in domain " + domainId + ", or user source is not OAUTH2"); return new Pair(false, null); } else { - User user = _userDao.getUser(userAccount.getId()); + User user = userDao.getUser(userAccount.getId()); final String[] provider = (String[])requestParameters.get(ApiConstants.PROVIDER); final String[] emailArray = (String[])requestParameters.get(ApiConstants.EMAIL); final String[] secretCodeArray = (String[])requestParameters.get(ApiConstants.SECRET_CODE); @@ -62,7 +72,7 @@ public class OAuth2UserAuthenticator extends AdapterBase implements UserAuthenti String email = ((emailArray == null) ? null : emailArray[0]); String secretCode = ((secretCodeArray == null) ? null : secretCodeArray[0]); - UserOAuth2Authenticator authenticator = _userOAuth2mgr.getUserOAuth2AuthenticationProvider(oauthProvider); + UserOAuth2Authenticator authenticator = userOAuth2mgr.getUserOAuth2AuthenticationProvider(oauthProvider); if (user != null && authenticator.verifyUser(email, secretCode)) { return new Pair(true, null); } @@ -75,4 +85,8 @@ public class OAuth2UserAuthenticator extends AdapterBase implements UserAuthenti public String encode(String password) { return null; } + + protected boolean isOAuthPluginEnabled() { + return OAuth2IsPluginEnabled.value(); + } } diff --git a/plugins/user-authenticators/oauth2/src/test/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticatorTest.java b/plugins/user-authenticators/oauth2/src/test/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticatorTest.java index 06aa04d729c..9a6e8812679 100644 --- a/plugins/user-authenticators/oauth2/src/test/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticatorTest.java +++ b/plugins/user-authenticators/oauth2/src/test/java/org/apache/cloudstack/oauth2/OAuth2UserAuthenticatorTest.java @@ -24,23 +24,32 @@ import com.cloud.user.dao.UserAccountDao; import com.cloud.user.dao.UserDao; import com.cloud.utils.Pair; import org.apache.cloudstack.auth.UserOAuth2Authenticator; +import org.junit.After; import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class OAuth2UserAuthenticatorTest { @Mock @@ -53,13 +62,22 @@ public class OAuth2UserAuthenticatorTest { private OAuth2AuthManager userOAuth2mgr; @InjectMocks + @Spy private OAuth2UserAuthenticator authenticator; + private AutoCloseable closeable; @Before public void setUp() { - MockitoAnnotations.initMocks(this); + closeable = MockitoAnnotations.openMocks(this); + doReturn(true).when(authenticator).isOAuthPluginEnabled(); } + @After + public void tearDown() throws Exception { + closeable.close(); + } + + @Test public void testAuthenticateWithValidCredentials() { String username = "testuser"; @@ -84,13 +102,13 @@ public class OAuth2UserAuthenticatorTest { Pair result = authenticator.authenticate(username, null, domainId, requestParameters); + assertTrue(result.first()); + assertNull(result.second()); + verify(userAccountDao).getUserAccount(username, domainId); verify(userDao).getUser(userAccount.getId()); verify(userOAuth2mgr).getUserOAuth2AuthenticationProvider(provider[0]); verify(userOAuth2Authenticator).verifyUser(email[0], secretCode[0]); - - assertEquals(true, result.first().booleanValue()); - assertEquals(null, result.second()); } @Test @@ -106,7 +124,7 @@ public class OAuth2UserAuthenticatorTest { UserOAuth2Authenticator userOAuth2Authenticator = mock(UserOAuth2Authenticator.class); when(userAccountDao.getUserAccount(username, domainId)).thenReturn(userAccount); - when(userDao.getUser(userAccount.getId())).thenReturn( user); + when(userDao.getUser(userAccount.getId())).thenReturn(user); when(userOAuth2mgr.getUserOAuth2AuthenticationProvider(provider[0])).thenReturn(userOAuth2Authenticator); when(userOAuth2Authenticator.verifyUser(email[0], secretCode[0])).thenReturn(false); @@ -117,13 +135,13 @@ public class OAuth2UserAuthenticatorTest { Pair result = authenticator.authenticate(username, null, domainId, requestParameters); + assertFalse(result.first()); + assertEquals(OAuth2UserAuthenticator.ActionOnFailedAuthentication.INCREMENT_INCORRECT_LOGIN_ATTEMPT_COUNT, result.second()); + verify(userAccountDao).getUserAccount(username, domainId); verify(userDao).getUser(userAccount.getId()); verify(userOAuth2mgr).getUserOAuth2AuthenticationProvider(provider[0]); verify(userOAuth2Authenticator).verifyUser(email[0], secretCode[0]); - - assertEquals(false, result.first().booleanValue()); - assertEquals(OAuth2UserAuthenticator.ActionOnFailedAuthentication.INCREMENT_INCORRECT_LOGIN_ATTEMPT_COUNT, result.second()); } @Test @@ -143,11 +161,11 @@ public class OAuth2UserAuthenticatorTest { Pair result = authenticator.authenticate(username, null, domainId, requestParameters); + assertFalse(result.first()); + assertNull(result.second()); + verify(userAccountDao).getUserAccount(username, domainId); verify(userDao, never()).getUser(anyLong()); verify(userOAuth2mgr, never()).getUserOAuth2AuthenticationProvider(anyString()); - - assertEquals(false, result.first().booleanValue()); - assertEquals(null, result.second()); } } diff --git a/server/src/main/java/com/cloud/user/AccountManagerImpl.java b/server/src/main/java/com/cloud/user/AccountManagerImpl.java index 86a359a3348..3eed429ed21 100644 --- a/server/src/main/java/com/cloud/user/AccountManagerImpl.java +++ b/server/src/main/java/com/cloud/user/AccountManagerImpl.java @@ -332,6 +332,7 @@ public class AccountManagerImpl extends ManagerBase implements AccountManager, M private List _securityCheckers; private int _cleanupInterval; + private static final String OAUTH2_PROVIDER_NAME = "oauth2"; private List apiNameList; protected static Map userTwoFactorAuthenticationProvidersMap = new HashMap<>(); @@ -2650,7 +2651,8 @@ public class AccountManagerImpl extends ManagerBase implements AccountManager, M continue; } } - if (secretCode != null && !authenticator.getName().equals("oauth2")) { + if ((secretCode != null && !authenticator.getName().equals(OAUTH2_PROVIDER_NAME)) + || (secretCode == null && authenticator.getName().equals(OAUTH2_PROVIDER_NAME))) { continue; } Pair result = authenticator.authenticate(username, password, domainId, requestParameters); diff --git a/server/src/test/java/com/cloud/user/AccountManagerImplTest.java b/server/src/test/java/com/cloud/user/AccountManagerImplTest.java index 6d9211dd526..9014af523fd 100644 --- a/server/src/test/java/com/cloud/user/AccountManagerImplTest.java +++ b/server/src/test/java/com/cloud/user/AccountManagerImplTest.java @@ -203,6 +203,7 @@ public class AccountManagerImplTest extends AccountManagetImplTestBase { Mockito.when(userAuthenticator.authenticate("test", "fail", 1L, new HashMap<>())).thenReturn(failureAuthenticationPair); Mockito.lenient().when(userAuthenticator.authenticate("test", null, 1L, new HashMap<>())).thenReturn(successAuthenticationPair); Mockito.lenient().when(userAuthenticator.authenticate("test", "", 1L, new HashMap<>())).thenReturn(successAuthenticationPair); + Mockito.when(userAuthenticator.getName()).thenReturn("test"); //Test for incorrect password. authentication should fail UserAccount userAccount = accountManagerImpl.authenticateUser("test", "fail", 1L, InetAddress.getByName("127.0.0.1"), new HashMap<>());