diff --git a/engine/schema/src/com/cloud/upgrade/dao/Upgrade41000to41100.java b/engine/schema/src/com/cloud/upgrade/dao/Upgrade41000to41100.java index fbe9d784432..6afd976e7e2 100644 --- a/engine/schema/src/com/cloud/upgrade/dao/Upgrade41000to41100.java +++ b/engine/schema/src/com/cloud/upgrade/dao/Upgrade41000to41100.java @@ -27,6 +27,7 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.apache.commons.codec.binary.Base64; import org.apache.log4j.Logger; import com.cloud.hypervisor.Hypervisor; @@ -64,9 +65,49 @@ public class Upgrade41000to41100 implements DbUpgrade { @Override public void performDataMigration(Connection conn) { + validateUserDataInBase64(conn); updateSystemVmTemplates(conn); } + private void validateUserDataInBase64(Connection conn) { + try (final PreparedStatement selectStatement = conn.prepareStatement("SELECT `id`, `user_data` FROM `cloud`.`user_vm` WHERE `user_data` IS NOT NULL;"); + final ResultSet selectResultSet = selectStatement.executeQuery()) { + while (selectResultSet.next()) { + final Long userVmId = selectResultSet.getLong(1); + final String userData = selectResultSet.getString(2); + if (Base64.isBase64(userData)) { + final String newUserData = Base64.encodeBase64String(Base64.decodeBase64(userData.getBytes())); + if (!userData.equals(newUserData)) { + try (final PreparedStatement updateStatement = conn.prepareStatement("UPDATE `cloud`.`user_vm` SET `user_data` = ? WHERE `id` = ? ;")) { + updateStatement.setString(1, newUserData); + updateStatement.setLong(2, userVmId); + updateStatement.executeUpdate(); + } catch (SQLException e) { + LOG.error("Failed to update cloud.user_vm user_data for id:" + userVmId + " with exception: " + e.getMessage()); + throw new CloudRuntimeException("Exception while updating cloud.user_vm for id " + userVmId, e); + } + } + } else { + // Update to NULL since it's invalid + LOG.warn("Removing user_data for vm id " + userVmId + " because it's invalid"); + LOG.warn("Removed data was: " + userData); + try (final PreparedStatement updateStatement = conn.prepareStatement("UPDATE `cloud`.`user_vm` SET `user_data` = NULL WHERE `id` = ? ;")) { + updateStatement.setLong(1, userVmId); + updateStatement.executeUpdate(); + } catch (SQLException e) { + LOG.error("Failed to update cloud.user_vm user_data for id:" + userVmId + " to NULL with exception: " + e.getMessage()); + throw new CloudRuntimeException("Exception while updating cloud.user_vm for id " + userVmId + " to NULL", e); + } + } + } + } catch (SQLException e) { + throw new CloudRuntimeException("Exception while validating existing user_vm table's user_data column to be base64 valid with padding", e); + } + if (LOG.isDebugEnabled()) { + LOG.debug("Done validating base64 content of user data"); + } + } + @SuppressWarnings("serial") private void updateSystemVmTemplates(final Connection conn) { LOG.debug("Updating System Vm template IDs"); diff --git a/server/src/com/cloud/vm/UserVmManagerImpl.java b/server/src/com/cloud/vm/UserVmManagerImpl.java index c4edf4fa112..72c47931057 100644 --- a/server/src/com/cloud/vm/UserVmManagerImpl.java +++ b/server/src/com/cloud/vm/UserVmManagerImpl.java @@ -2525,7 +2525,7 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir if (userData != null) { // check and replace newlines userData = userData.replace("\\n", ""); - validateUserData(userData, httpMethod); + userData = validateUserData(userData, httpMethod); // update userData on domain router. updateUserdata = true; } else { @@ -3396,7 +3396,7 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir _accountMgr.checkAccess(owner, AccessType.UseEntry, false, template); // check if the user data is correct - validateUserData(userData, httpmethod); + userData = validateUserData(userData, httpmethod); // Find an SSH public key corresponding to the key pair name, if one is // given @@ -3943,7 +3943,7 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir } } - private void validateUserData(String userData, HTTPMethod httpmethod) { + protected String validateUserData(String userData, HTTPMethod httpmethod) { byte[] decodedUserData = null; if (userData != null) { if (!Base64.isBase64(userData)) { @@ -3971,7 +3971,10 @@ public class UserVmManagerImpl extends ManagerBase implements UserVmManager, Vir if (decodedUserData == null || decodedUserData.length < 1) { throw new InvalidParameterValueException("User data is too short"); } + // Re-encode so that the '=' paddings are added if necessary since 'isBase64' does not require it, but python does on the VR. + return Base64.encodeBase64String(decodedUserData); } + return null; } @Override diff --git a/server/test/com/cloud/vm/UserVmManagerTest.java b/server/test/com/cloud/vm/UserVmManagerTest.java index 1bab84cc36c..89555a2c8c8 100644 --- a/server/test/com/cloud/vm/UserVmManagerTest.java +++ b/server/test/com/cloud/vm/UserVmManagerTest.java @@ -48,6 +48,7 @@ import com.cloud.user.User; import com.cloud.event.dao.UsageEventDao; import com.cloud.uservm.UserVm; import org.junit.Assert; +import org.apache.cloudstack.api.BaseCmd; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -1056,4 +1057,16 @@ public class UserVmManagerTest { _userVmMgr.persistDeviceBusInfo(_vmMock, "lsilogic"); verify(_vmDao, times(1)).saveDetails(any(UserVmVO.class)); } + + @Test + public void testValideBase64WithoutPadding() { + // fo should be encoded in base64 either as Zm8 or Zm8= + String encodedUserdata = "Zm8"; + String encodedUserdataWithPadding = "Zm8="; + + // Verify that we accept both but return the padded version + assertTrue("validate return the value with padding", encodedUserdataWithPadding.equals(_userVmMgr.validateUserData(encodedUserdata, BaseCmd.HTTPMethod.GET))); + assertTrue("validate return the value with padding", encodedUserdataWithPadding.equals(_userVmMgr.validateUserData(encodedUserdataWithPadding, BaseCmd.HTTPMethod.GET))); + } + }