Java多线程文件下载

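以下代码有点问题,会发生阻塞。原因:工作线程只在下载成功后才调用 workDownLatch.countDown(),任何一个分段失败(例如响应码不是206或读取异常)时,主线程都会在 workDownLatch.await() 上永久等待;另外没有设置读超时(setReadTimeout),连接卡住时 in.read() 也会一直阻塞: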

package com.test.service;

import java.io.File;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.concurrent.CountDownLatch;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;

/**
 * Multi-threaded HTTP file download that fetches several byte ranges in parallel to speed up
 * large files (currently unused).
 *
 * <p>The remote file is preallocated locally, split into at most {@code threadSize} inclusive
 * byte ranges, and each range is written in place by a {@link DowloadRunnable} on its own thread.
 *
 * <p>Why the earlier version could block forever: a worker counts the completion latch down only
 * after its segment succeeded, but the caller awaited that latch unconditionally, so one failed
 * segment (non-206 reply, I/O error) left {@link #downloadFile} waiting forever. The latches were
 * also single-use instance fields of a singleton bean and were sized from {@code threadSize}
 * before {@code @Value} injection ran. This version joins the worker threads instead and creates
 * all per-download state inside each call.
 */
@Component
public class MulitThreadDownload {

    private static Logger logger = LoggerFactory.getLogger(MulitThreadDownload.class);

    @Value("${onair.download.threadsize:5}")
    private int threadSize = 5;

    /** Connect/read timeout in milliseconds for the initial request; 0 (not injected) means no timeout. */
    @Value("${onair.download.timeout:5000}")
    private int downloadTimeout;

    /** Result of the most recent {@link #downloadFile} call; kept for backward compatibility. */
    static volatile boolean flag = true;

    /** Workers of the most recent download; read by {@link #getDownloadLength()} for progress. */
    private volatile DowloadRunnable[] dowloadRunnables = new DowloadRunnable[0];

    public static void main(String[] args) {
        // The backslash must be escaped: in "G:\123.mp4" the sequence \123 is an octal escape ('S').
        new MulitThreadDownload().downloadFile("", "G:\\123.mp4");
    }

    /**
     * Downloads {@code url} into {@code filePath} using up to {@code threadSize} parallel range
     * requests.
     *
     * @param url      HTTP URL of the file; the plain GET must answer 200 and each range request 206
     * @param filePath local target path; the file preallocated by this call is deleted on failure
     * @return {@code true} only if every segment was written completely
     */
    public boolean downloadFile(String url, String filePath) {
        logger.debug("下载地址:{},目标文件路径:{}", url, filePath);
        File target = new File(filePath);
        boolean touchedTarget = false;
        boolean ok = false;
        HttpURLConnection conn = null;
        try {
            conn = (HttpURLConnection) new URL(url).openConnection();
            conn.setConnectTimeout(downloadTimeout);
            // Without a read timeout a server that accepts the connection but never answers blocks forever.
            conn.setReadTimeout(downloadTimeout);
            conn.setRequestMethod("GET");

            int status = conn.getResponseCode();
            if (status == 200) { // 200 = whole resource; the workers expect 206 = partial content
                long contentLength = conn.getContentLengthLong();
                logger.info("获取文件大小:{}", contentLength);
                // Only the headers were needed; release the connection instead of leaving the body unread.
                conn.disconnect();
                conn = null;

                if (contentLength < 0) {
                    logger.error("无法获取文件大小,不能分段下载");
                } else if (contentLength > Integer.MAX_VALUE) {
                    // DowloadRunnable addresses bytes with int offsets.
                    logger.error("文件过大({}字节),超出分段下载支持范围", contentLength);
                } else {
                    touchedTarget = true;
                    preallocate(target, (int) contentLength);
                    ok = downloadSegments(url, filePath, (int) contentLength);
                }
            } else {
                logger.error("文件下载失败,状态码:{}", status);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.error("文件下载被中断:" + url, e);
        } catch (Throwable e) {
            logger.error("文件下载失败:" + e.getMessage(), e);
        } finally {
            if (conn != null) {
                conn.disconnect();
            }
        }

        if (!ok && touchedTarget && target.exists()) {
            target.delete(); // download failed: remove the partially written file
        }
        flag = ok;
        return ok;
    }

    /** Creates (or resizes) the target file to exactly {@code length} bytes so workers can write their ranges in place. */
    private static void preallocate(File target, int length) throws java.io.IOException {
        try (RandomAccessFile raf = new RandomAccessFile(target, "rwd")) {
            raf.setLength(length);
        }
    }

    /**
     * Splits {@code [0, length)} into inclusive ranges, runs one worker thread per range and waits
     * for all of them to terminate.
     *
     * @return {@code true} if every worker counted the completion latch down (i.e. succeeded)
     */
    private boolean downloadSegments(String url, String filePath, int length) throws InterruptedException {
        if (length == 0) {
            return true; // nothing to fetch; the empty file already exists
        }
        // Never more workers than bytes, and at least one even if threadSize is misconfigured.
        int workers = Math.max(1, Math.min(threadSize, length));
        int blockSize = length / workers;

        CountDownLatch startLatch = new CountDownLatch(1);      // start signal for all workers
        CountDownLatch doneLatch = new CountDownLatch(workers); // counted down once per successful segment
        DowloadRunnable[] runnables = new DowloadRunnable[workers];
        Thread[] threads = new Thread[workers];

        for (int i = 1; i <= workers; i++) {
            int startIndex = blockSize * (i - 1);
            // HTTP byte ranges are inclusive, so the last segment ends at length - 1.
            int endIndex = (i == workers) ? length - 1 : blockSize * i - 1;
            logger.info("线程:{}下载文件开始点:{}结束点:{}", i, startIndex, endIndex);
            runnables[i - 1] = new DowloadRunnable(url, filePath, startLatch, doneLatch, i, startIndex, endIndex);
            Thread thread = new Thread(runnables[i - 1], "download-" + i);
            // Must be installed before start(); installing it afterwards can miss an early failure.
            thread.setUncaughtExceptionHandler((t, e) -> logger.error("catch到异常", e));
            threads[i - 1] = thread;
        }
        dowloadRunnables = runnables;

        try {
            for (Thread thread : threads) {
                thread.start();
            }
        } finally {
            // Release the workers even if a start() failed, so no started worker waits forever.
            startLatch.countDown();
        }

        logger.debug("主线程阻塞,等待工作线程完成任务");
        // join() returns whether a worker succeeded or failed; awaiting doneLatch would hang on any failure.
        for (Thread thread : threads) {
            thread.join();
        }
        logger.debug("工作线程完成任务,主线程继续");
        return doneLatch.getCount() == 0;
    }

    /** Logs download progress every 10 s on a daemon thread until {@code length} bytes are reported or it is interrupted. */
    private void moniterLength(int length) {
        Thread monitor = new Thread(() -> {
            int done;
            while ((done = getDownloadLength()) < length) {
                logger.debug("文件大小:{},目前下载大小:{},进度{}", length, done, done * 1.0 / (long) length);
                try {
                    Thread.sleep(10000);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    return;
                }
            }
        }, "download-monitor");
        // A daemon thread cannot keep the JVM alive if a failed worker never reaches the full length.
        monitor.setDaemon(true);
        monitor.start();
    }

    /**
     * Sums the {@code downloadLength} reported by the workers of the most recent download.
     *
     * @return bytes reported so far, or 0 before the first download
     */
    public int getDownloadLength() {
        int length = 0;
        for (DowloadRunnable runnable : dowloadRunnables) {
            if (runnable != null) {
                length += runnable.downloadLength;
            }
        }
        return length;
    }

}

// Download worker
/**
 * Downloads the inclusive byte range {@code [startIndex, endIndex]} of {@code url} and writes it
 * into {@code filePath} at offset {@code startIndex}.
 *
 * <p>The worker first waits on {@code msgDownLatch} (the start signal), then sends a ranged GET.
 * {@code workDownLatch} is counted down only after the whole segment was written successfully; on
 * any failure it is left untouched, so callers must not await it unconditionally (join the thread
 * and inspect {@code getCount()} instead).
 */
class DowloadRunnable implements Runnable {

    private static Logger logger = LoggerFactory.getLogger(DowloadRunnable.class);

    /** Connect and read timeout in ms; the read timeout stops a stalled server from blocking read() forever. */
    private static final int TIMEOUT_MS = 5000;

    private CountDownLatch msgDownLatch;

    private CountDownLatch workDownLatch;

    private int threadIndex;

    private int startIndex;

    private int endIndex;

    private String url;

    private String filePath;

    /** Bytes written so far; single writer (this worker), may be read by progress monitors on other threads. */
    public volatile int downloadLength;

    public DowloadRunnable(String url, String filePath,
             CountDownLatch msgDownLatch,  CountDownLatch workDownLatch,
             int threadIndex, int startIndex, int endIndex) {
        this.url = url;
        this.filePath = filePath;
        this.msgDownLatch = msgDownLatch;
        this.workDownLatch = workDownLatch;
        this.threadIndex = threadIndex;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
    }

    @Override
    public void run() {
        HttpURLConnection conn = null;
        try {
            // Wait for the coordinator's start signal (msgDownLatch.countDown()).
            msgDownLatch.await();
            logger.info("线程{}任务开始", threadIndex);

            conn = (HttpURLConnection) new URL(url).openConnection();
            conn.setConnectTimeout(TIMEOUT_MS);
            conn.setReadTimeout(TIMEOUT_MS);
            conn.setRequestProperty("Range", "bytes=" + startIndex + "-" + endIndex);
            conn.setRequestMethod("GET");

            int status = conn.getResponseCode();
            logger.debug("线程{}请求返回的responseCode:{}", threadIndex, status);
            if (status != 206) {
                // Anything but 206 (e.g. 200 = Range ignored, whole file) must not be written at startIndex.
                logger.error("文件下载失败,状态码:" + status);
                return;
            }

            long expected = conn.getContentLengthLong();
            long written = 0;
            // "rw" instead of "rwd": rwd pushes every buffer synchronously to the device, which is very slow for big files.
            try (InputStream in = conn.getInputStream();
                 RandomAccessFile raf = new RandomAccessFile(filePath, "rw")) {
                raf.seek(startIndex);
                byte[] buffer = new byte[8192];
                int length;
                logger.debug("线程{}开始写数据,开始点{}", threadIndex, startIndex);
                while ((length = in.read(buffer)) != -1) {
                    raf.write(buffer, 0, length);
                    written += length;
                    downloadLength += length;
                }
            }
            if (expected >= 0 && written != expected) {
                // Connection closed early: do not report a truncated segment as success.
                logger.error("线程{}数据不完整,期望{}字节,实际{}字节", threadIndex, expected, written);
                return;
            }

            logger.info("线程{}任务完成", threadIndex);
            // Signal success only; failures leave the latch count above zero.
            workDownLatch.countDown();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.error("线程{}被中断", threadIndex, e);
        } catch (Throwable e) {
            logger.error(e.getMessage(), e);
        } finally {
            if (conn != null) {
                conn.disconnect();
            }
        }
    }
}
MulitThreadDownload.java

看不出来啥问题,先记下来!

单独提取出下载功能的代码,大文件下载还是有问题

package com.test.service;

import java.io.InputStream;
import java.io.RandomAccessFile;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.concurrent.CyclicBarrier;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// Download thread
/**
 * Thread that downloads the inclusive byte range {@code [startIndex, endIndex]} of {@code url} and
 * writes it into {@code filePath} at offset {@code startIndex}.
 *
 * <p>After {@code join()}, {@link #isCompleted()} tells whether the segment was written completely.
 */
class DowloadThread extends Thread {

    private static Logger logger = LoggerFactory.getLogger(DowloadThread.class);

    /** Connect and read timeout in ms; the read timeout stops a stalled server from hanging the thread. */
    private static final int TIMEOUT_MS = 5000;

    private int threadIndex;

    private int startIndex;

    private int endIndex;

    private String url;

    private String filePath;

    // NOTE(review): stored by the second constructor but never awaited in run(); confirm the intent.
    CyclicBarrier barrier;

    /** Bytes written so far; single writer (this thread), may be read by progress monitors. */
    public volatile int downloadLength;

    /** Set once the whole segment has been written. */
    private volatile boolean completed;

    public DowloadThread(String url, String filePath,
            int threadIndex, int startIndex, int endIndex) {
        this(url, filePath, threadIndex, startIndex, endIndex, null);
    }

    public DowloadThread(String url, String filePath,
            int threadIndex, int startIndex, int endIndex,
            final CyclicBarrier barrier) {
        this.url = url;
        this.filePath = filePath;
        this.threadIndex = threadIndex;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
        this.barrier = barrier;
    }

    /** @return {@code true} if {@link #run()} wrote the whole segment successfully */
    public boolean isCompleted() {
        return completed;
    }

    @Override
    public void run() {
        HttpURLConnection conn = null;
        try {
            logger.info("线程{}任务开始", threadIndex);
            conn = (HttpURLConnection) new URL(url).openConnection();
            conn.setConnectTimeout(TIMEOUT_MS);
            conn.setReadTimeout(TIMEOUT_MS);
            conn.setRequestProperty("Range", "bytes=" + startIndex + "-" + endIndex);
            conn.setRequestMethod("GET");

            int status = conn.getResponseCode();
            logger.debug("线程{}请求返回的responseCode:{}", threadIndex, status);
            // 206 is the normal answer to a Range request. A 200 means the server ignored Range and is
            // sending the whole file, which is only correct to write when this segment starts at 0.
            boolean usable = status == 206 || (status == 200 && startIndex == 0);
            if (!usable) {
                logger.error("文件下载失败,状态码:" + status);
                return;
            }

            long expected = conn.getContentLengthLong();
            long written = 0;
            // "rw" instead of "rwd": rwd pushes every buffer synchronously to the device, which is very slow for big files.
            try (InputStream in = conn.getInputStream();
                 RandomAccessFile raf = new RandomAccessFile(filePath, "rw")) {
                raf.seek(startIndex);
                byte[] buffer = new byte[8192];
                int length;
                logger.debug("线程{}开始写数据,开始点{}", threadIndex, startIndex);
                while ((length = in.read(buffer)) != -1) {
                    raf.write(buffer, 0, length);
                    written += length;
                    downloadLength += length;
                }
            }
            if (expected >= 0 && written != expected) {
                logger.error("线程{}数据不完整,期望{}字节,实际{}字节", threadIndex, expected, written);
                return;
            }
            completed = true;
            logger.info("线程{}任务完成", threadIndex);
        } catch (Throwable e) {
            logger.error(e.getMessage(), e);
        } finally {
            if (conn != null) {
                conn.disconnect();
            }
        }
    }
}

 下面的代码相对来说好一些:

package com.test;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.SequenceInputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;

import com.test.DownloadCallable;

/**
 * Multi-threaded download that writes each segment into its own temp file
 * ({@code filePath + "_tmp" + i}, i = 1..threadSize); {@link #addFile(String)} merges them.
 * Currently unused.
 */
public class DownloadService {

    private static Logger logger = LoggerFactory.getLogger(DownloadService.class);

    @Value("${onair.download.timeout:10000}")
    private int downloadTimeout = 10000;

    @Value("${onair.download.threadSize:5}")
    private int threadSize = 5;

    // NOTE(review): the initializer (true) disagrees with the @Value default (false), so resume is on
    // when this class is constructed directly and off when the default is injected; confirm intent.
    @Value("${onair.download.ddxc:false}")
    private boolean ddxc = true;

    // The three fields below describe the most recent doWork() call. They are re-created per call,
    // because field initializers run before @Value injection and an executor cannot be reused after shutdown.
    ExecutorService executorService = Executors.newFixedThreadPool(threadSize);

    DownloadCallable[] callables = new DownloadCallable[threadSize];

    List<Future<String>> futures = new ArrayList<>();

    public static void main(String[] args) {


    }

    /**
     * Downloads {@code url} in {@code threadSize} segments into the temp files
     * {@code filePath + "_tmp1"} ... {@code filePath + "_tmpN"}; call {@link #addFile(String)}
     * afterwards to merge them into {@code filePath}.
     *
     * @return {@code true} if every segment finished and the downloaded byte count equals the remote size
     */
    public boolean doWork(String url, String filePath) {
        logger.debug("源地址:{},目标地址:{}", url, filePath);
        ExecutorService pool = null;
        try {
            int length = fetchContentLength(url);
            if (length < 0) {
                return false;
            }
            if (length < threadSize) {
                // blockSize would be 0 and the first segments would send invalid ranges such as "bytes=0--1".
                logger.error("文件大小{}小于线程数{},无法分段下载", length, threadSize);
                return false;
            }
            int blockSize = length / threadSize;

            pool = Executors.newFixedThreadPool(threadSize);
            DownloadCallable[] tasks = new DownloadCallable[threadSize];
            List<Future<String>> submitted = new ArrayList<>();
            executorService = pool;
            callables = tasks;
            futures = submitted;

            for (int i = 1; i <= threadSize; i++) {
                int startIndex = blockSize * (i - 1);
                // HTTP byte ranges are inclusive, so the last segment ends at length - 1.
                int endIndex = (i == threadSize) ? length - 1 : startIndex + blockSize - 1;
                tasks[i - 1] = new DownloadCallable(i, startIndex, endIndex, url,
                        downloadTimeout, filePath, ddxc);
                submitted.add(pool.submit(tasks[i - 1]));
            }

            boolean allSucceeded = awaitAll(submitted, length, blockSize);
            if (allSucceeded && getDownloadSize(blockSize) == length) {
                logger.info("下载完成没报错");
                return true;
            }
            logger.error("下载完成报错了");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.error("下载被中断:" + url, e);
        } catch (Throwable e) {
            logger.error(e.getMessage(), e);
        } finally {
            if (pool != null) {
                // Interrupts workers still running when we leave early; harmless once all have finished.
                pool.shutdownNow();
            }
        }
        return false;
    }

    /** Sends the initial GET and returns its Content-Length, or -1 if the status is not 200 or no length was reported. */
    private int fetchContentLength(String url) throws IOException {
        HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection();
        try {
            conn.setConnectTimeout(downloadTimeout);
            // Without a read timeout a server that never answers blocks getResponseCode() forever.
            conn.setReadTimeout(downloadTimeout);
            conn.setRequestMethod("GET");
            int status = conn.getResponseCode();
            if (status != 200) {
                logger.error("获取文件信息失败,状态码:{}", status);
                return -1;
            }
            int length = conn.getContentLength();
            if (length < 0) {
                logger.error("服务器未返回文件大小");
            }
            return length;
        } finally {
            // Only the headers are needed; do not leave the body pending on this connection.
            conn.disconnect();
        }
    }

    /**
     * Waits for every segment, logging progress every 30 s while waiting. Unlike a fixed sleep,
     * {@code Future.get(timeout)} returns as soon as a segment finishes.
     *
     * @return {@code true} only if every segment returned {@code "success"}
     */
    private boolean awaitAll(List<Future<String>> submitted, int length, int blockSize)
            throws InterruptedException {
        boolean allSucceeded = true;
        for (Future<String> future : submitted) {
            while (true) {
                try {
                    if (!"success".equals(future.get(30, TimeUnit.SECONDS))) {
                        allSucceeded = false;
                    }
                    break;
                } catch (TimeoutException stillRunning) {
                    int size = getDownloadSize(blockSize);
                    logger.debug("文件总大小:" + length + "==已下载:" + size + "进度:" + (float) size * 1.0 / (float) length);
                } catch (ExecutionException e) {
                    logger.error("分段下载异常", e.getCause());
                    allSucceeded = false;
                    break;
                }
            }
        }
        return allSucceeded;
    }

    /**
     * Concatenates {@code filePath + "_tmp1"} ... {@code "_tmpN"} (N = threadSize) into {@code filePath}
     * and deletes the temp files afterwards.
     *
     * @return {@code true} on success
     * @throws Throwable if a temp file is missing or an I/O error occurs (temp files are kept then)
     */
    public boolean addFile(String filePath) throws Throwable {
        List<File> listFile = new ArrayList<>();
        List<FileInputStream> list = new ArrayList<>();
        try {
            // Open every part before touching the target, so a missing part leaves the target untouched.
            for (int i = 1; i <= threadSize; i++) {
                File part = new File(filePath + "_tmp" + i);
                listFile.add(part);
                list.add(new FileInputStream(part));
            }
        } catch (Throwable e) {
            for (FileInputStream opened : list) {
                try {
                    opened.close();
                } catch (IOException ignored) {
                    // best effort: the original failure is rethrown below
                }
            }
            throw e;
        }

        // Closing the SequenceInputStream also closes every part stream not yet consumed.
        Enumeration<FileInputStream> parts = Collections.enumeration(list);
        try (SequenceInputStream sis = new SequenceInputStream(parts);
             FileOutputStream fos = new FileOutputStream(new File(filePath))) {
            byte[] by = new byte[8192];
            int len;
            while ((len = sis.read(by)) != -1) {
                fos.write(by, 0, len);
            }
            fos.flush();
        }
        logger.info("合并完成!");

        for (File file : listFile) {
            if (!file.delete()) {
                logger.warn("临时文件删除失败:{}", file);
            }
        }
        logger.info("删除文件完成!");

        return true;
    }

    /**
     * Sums {@code downloadSize} of the segments of the most recent {@link #doWork} call.
     *
     * @param blockSize unused; kept for signature compatibility
     * @return bytes reported so far, 0 if no segment exists yet
     */
    public int getDownloadSize(int blockSize) {
        DownloadCallable[] current = callables;
        if (current == null) {
            return 0;
        }
        int length = 0;
        for (DownloadCallable callable : current) {
            if (callable != null) {
                length += callable.downloadSize;
            }
        }
        return length;
    }
}
DownloadService.java
package com.test;

import java.io.File;
import java.io.InputStream;
import java.io.RandomAccessFile;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.concurrent.Callable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Downloads one segment of {@code url} into the temp file {@code filePath + "_tmp" + threadIndex}.
 *
 * <p>With {@code threadIndex > 0} the inclusive byte range {@code [startIndex, endIndex]} is
 * requested; otherwise the whole resource is fetched. With {@code ddxc} (resume) enabled, an
 * existing temp file is kept and only the missing bytes are requested; without it the temp file
 * is deleted and the segment is downloaded from scratch.
 *
 * <p>{@link #call()} returns {@code "success"} when the segment is complete and {@code null} on failure.
 */
public class DownloadCallable implements Callable<String> {

    private static Logger logger = LoggerFactory.getLogger(DownloadCallable.class);

    int threadIndex;

    int startIndex;

    int endIndex;

    int timeout;

    String url;

    String filePath;

    boolean ddxc;

    /** Bytes of this segment present in the temp file (resumed + newly downloaded); single writer, read by monitors. */
    public volatile int downloadSize = 0;

    public DownloadCallable() {

    }

    public DownloadCallable(int threadIndex, int startIndex, int endIndex,
            String url,int timeout, String filePath,boolean ddxc) {
        this.threadIndex = threadIndex;
        this.startIndex = startIndex;
        this.endIndex = endIndex;
        this.url = url;
        this.filePath = filePath;
        this.timeout = timeout;
        this.ddxc = ddxc;
    }

    public DownloadCallable(int threadIndex,String url,int timeout, String filePath) {
        this.url = url;
        this.filePath = filePath;
        this.threadIndex = threadIndex;
        this.timeout = timeout;
    }

    public DownloadCallable(int threadIndex,String url,int timeout, String filePath,boolean ddxc) {
        this.url = url;
        this.filePath = filePath;
        this.threadIndex = threadIndex;
        this.timeout = timeout;
        this.ddxc = ddxc;
    }

    @Override
    public String call() throws Exception {
        HttpURLConnection conn = null;
        try {
            File tmpFile = new File(filePath + "_tmp" + threadIndex);
            // Resume: keep the bytes of an earlier run only when ddxc is enabled.
            long existing = 0;
            if (tmpFile.exists()) {
                if (ddxc) {
                    existing = tmpFile.length();
                } else {
                    tmpFile.delete();
                }
            }
            downloadSize = (int) existing;

            boolean ranged = threadIndex > 0; // multi-threaded segment download
            if (ranged && existing >= (long) endIndex - startIndex + 1) {
                return "success"; // segment already complete from an earlier run
            }

            conn = (HttpURLConnection) new URL(url).openConnection();
            conn.setConnectTimeout(timeout);
            // Per-read inactivity timeout, not a limit on the whole transfer; without it a stalled
            // connection blocks read() forever.
            conn.setReadTimeout(timeout);
            conn.setRequestMethod("GET");
            conn.setRequestProperty("User-Agent", "Mozilla/5.0 (X11; U; Linux i686; en-US; rv:1.9.0.3) Gecko/2008092510 Ubuntu/8.04 (hardy) Firefox/3.0.3");
            if (ranged) {
                // Continue after the bytes already in the temp file instead of re-requesting the whole
                // segment (re-requesting it and writing at the resume offset corrupted the temp file).
                conn.setRequestProperty("Range", "bytes=" + (startIndex + existing) + "-" + endIndex);
            } else if (existing > 0) {
                conn.setRequestProperty("Range", "bytes=" + existing + "-");
            }

            int status = conn.getResponseCode();
            long bodyLength = conn.getContentLengthLong();
            long writeOffset;
            if (status == 206) {
                writeOffset = existing;
            } else if (status == 200
                    && (!ranged || (startIndex == 0 && bodyLength >= 0 && bodyLength <= (long) endIndex + 1))) {
                // Whole resource from byte 0: valid when no range was needed, or when the whole
                // resource fits inside this segment starting at 0. Restart the temp file.
                writeOffset = 0;
            } else if (status == 416 && existing > 0) {
                // The resume offset is at/after the end of the resource: nothing is missing.
                return "success";
            } else {
                logger.error("下载失败,状态码:" + status + "当前线程:" + threadIndex);
                return null;
            }

            logger.debug("线程:" + threadIndex + "==下载开始节点:" + writeOffset + "=需下载大小::" + bodyLength);
            long received = 0;
            try (InputStream in = conn.getInputStream();
                 RandomAccessFile raf = new RandomAccessFile(tmpFile, "rw")) {
                raf.setLength(writeOffset); // drops stale bytes when a 200 restarts the file at 0
                raf.seek(writeOffset);
                downloadSize = (int) writeOffset;
                byte[] buffer = new byte[8192];
                int length;
                while ((length = in.read(buffer)) != -1) {
                    raf.write(buffer, 0, length);
                    received += length;
                    downloadSize += length;
                }
            }
            if (bodyLength >= 0 && received != bodyLength) {
                // Connection closed early: report failure instead of a silently truncated segment.
                logger.error("数据不完整,期望:" + bodyLength + "实际:" + received + "当前线程:" + threadIndex);
                return null;
            }
            return "success";
        } catch (Throwable e) {
            logger.error(e.getMessage() + "当前线程:" + threadIndex, e);
            return null;
        } finally {
            // The old finally closed in/raf unconditionally and threw NullPointerException on every early return.
            if (conn != null) {
                conn.disconnect();
            }
        }
    }

}
DownloadCallable.java

还不完美!

原文地址:https://www.cnblogs.com/liangblog/p/7246881.html