diff --git a/utils/src/main/java/com/cloud/utils/UriUtils.java b/utils/src/main/java/com/cloud/utils/UriUtils.java index 8805891cd2a..4d42e3cadd9 100644 --- a/utils/src/main/java/com/cloud/utils/UriUtils.java +++ b/utils/src/main/java/com/cloud/utils/UriUtils.java @@ -35,8 +35,6 @@ import java.util.List; import java.util.ListIterator; import java.util.StringTokenizer; -import javax.net.ssl.HttpsURLConnection; - import org.apache.commons.httpclient.Credentials; import org.apache.commons.httpclient.HttpClient; import org.apache.commons.httpclient.HttpException; @@ -202,39 +200,47 @@ public class UriUtils { } // Get the size of a file from URL response header. - public static Long getRemoteSize(String url) { - Long remoteSize = (long)0; - HttpURLConnection httpConn = null; - HttpsURLConnection httpsConn = null; - try { - URI uri = new URI(url); - if (uri.getScheme().equalsIgnoreCase("http")) { + public static long getRemoteSize(String url) { + long remoteSize = 0L; + final String[] methods = new String[]{"HEAD", "GET"}; + IllegalArgumentException exception = null; + // Attempting first a HEAD request to avoid downloading the whole file. If + // it fails (for example with S3 presigned URL), fallback on a standard GET + // request. + for (String method : methods) { + HttpURLConnection httpConn = null; + try { + URI uri = new URI(url); httpConn = (HttpURLConnection)uri.toURL().openConnection(); - if (httpConn != null) { - httpConn.setConnectTimeout(2000); - httpConn.setReadTimeout(5000); - String contentLength = httpConn.getHeaderField("content-length"); - if (contentLength != null) { - remoteSize = Long.parseLong(contentLength); + httpConn.setRequestMethod(method); + httpConn.setConnectTimeout(2000); + httpConn.setReadTimeout(5000); + String contentLength = httpConn.getHeaderField("Content-Length"); + if (contentLength != null) { + remoteSize = Long.parseLong(contentLength); + } else if (method.equals("GET") && httpConn.getResponseCode() < 300) { + // Calculate the content size based on the input stream content + byte[] buf = new byte[1024]; + int length; + while ((length = httpConn.getInputStream().read(buf, 0, buf.length)) != -1) { + remoteSize += length; } + } + return remoteSize; + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Invalid URL " + url); + } catch (IOException e) { + exception = new IllegalArgumentException("Unable to establish connection with URL " + url); + } finally { + if (httpConn != null) { httpConn.disconnect(); } - } else if (uri.getScheme().equalsIgnoreCase("https")) { - httpsConn = (HttpsURLConnection)uri.toURL().openConnection(); - if (httpsConn != null) { - String contentLength = httpsConn.getHeaderField("content-length"); - if (contentLength != null) { - remoteSize = Long.parseLong(contentLength); - } - httpsConn.disconnect(); - } } - } catch (URISyntaxException e) { - throw new IllegalArgumentException("Invalid URL " + url); - } catch (IOException e) { - throw new IllegalArgumentException("Unable to establish connection with URL " + url); } - return remoteSize; + if (exception != null) { + throw exception; + } + return 0L; } public static Pair validateUrl(String url) throws IllegalArgumentException {