[Python] Database Queries & Batch File Downloads

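I've gotten a bit fed up with reading books these past few days, and I'm not very familiar with network programming either. So I'll start by writing a small tool for my teacher's request, just for fun _(:3」∠)_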

  • Requirement: batch-download code files by game
  • Python version: 3.5.2
  • Database: MongoDB
  • File transfer: batch download over SFTP

Lately my code feels a bit ugly and not very concise; I should read more high-quality code _(:3」∠)_

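My teacher asked for a batch download of cpp source files. Each cpp file is named after a database id (an ObjectId). The files are organized as a 4-level, 16-ary trie: the directory path is built from the last 4 characters of the id, taken from right to left, each one a hex digit (0, 1, 2, ..., D, E, F).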
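
The directory rule can be sketched on its own. This is a minimal sketch, assuming the id is passed in its string form and the remote root ends with '/'; `trie_dir`, the sample root, and the sample id are hypothetical and used only for illustration.

```python
def trie_dir(root, obj_id, depth=4):
    """Return the trie directory for obj_id under root.

    Level i of the path is the i-th character counted from the right end of the id.
    """
    if len(obj_id) < depth:
        raise ValueError('id is shorter than the trie depth')
    levels = [obj_id[-i] for i in range(1, depth + 1)]
    return root + ''.join(level + '/' for level in levels)


# Hypothetical root and id, for illustration only
print(trie_dir('/data/codes/', '5e9f8b2a1c4d3e0012ab34cd'))  # /data/codes/d/c/4/3/
```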

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Author: Katrin
# Create Time: 2020/04/22

from bson import ObjectId
from pymongo import MongoClient
import os
import paramiko

def get_group():
    # group id
    group = [
        'group id 1',
        'group id 2'
    ]
    return [ObjectId(i) for i in group]

def get_queries(group_list):
    queries = []
    for group in group_list:
        query = {
            'group': group,
            'ranked': True
        }
        queries.append(query)
    return queries

def get_ssh():
    user = 'user'
    ip = 'ip'
    passwd = 'passwd'

    ssh = paramiko.SSHClient()
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    ssh.connect(ip, 22, user, passwd)
    transport = ssh.get_transport()
    assert transport is not None and transport.is_active(), 'SSH failed'
    return ssh

def get_sftp(ssh):
    sftp = paramiko.SFTPClient.from_transport(ssh.get_transport())
    # Alternative: sftp = ssh.open_sftp()
    return sftp

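# Each SSHClient.exec_command call runs in its own session, so state such as
# the working directory set by `cd` does not carry over to the next call
def simple_pyshell(ssh, cmd):
    # Use && so that ls only runs when cmd (e.g. cd) succeeded
    stdin, stdout, stderr = ssh.exec_command(cmd + ' && ls')
    stdout_lines = stdout.readlines()
    stderr_lines = stderr.readlines()
    #print(stdout_lines)
    #print(stderr_lines)
    return stdout_lines, stderr_lines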

def mkdir_valid(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)   
    return path

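# If you already know file_list, you can skip the database query
# and batch-download directly
def download_file(ssh, sftp, download_path, local_path, obj_file):
    # Remote directory: last 4 characters of the id, right to left, one per level
    remote_path = download_path + ''.join([obj_file[-i] + '/' for i in range(1, 5)])
    cmd = 'cd ' + remote_path
    stdout_lines, _ = simple_pyshell(ssh, cmd)
    # Keep only the .cpp entries whose name contains this id
    matches = [_file for _file in stdout_lines if '.cpp' in _file and obj_file in _file]
    if not matches:
        # local_path ends with the group id, so it identifies the group here
        print('obj_file {fileid} does not exist (target dir: {target})'.format(
            fileid=obj_file, target=local_path))
        return False
    # Each line of ls output ends with a newline; strip it to get the file name
    file_name = matches[0].strip()
    sftp.get(remote_path + file_name, local_path + file_name)
    return True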

if __name__ == "__main__":
    group_list = get_group()
    queries = get_queries(group_list)
    
    client = MongoClient('mongodb://user:passwd@ip:port/')
    db = client.prod

    ssh = get_ssh()
    sftp = get_sftp(ssh)
    
    download_path = 'path'
    save_path = mkdir_valid('codes/')

    for query in queries:
        # {'versions': 1} means the result only contains the '_id' and 'versions' fields
        res = db.files.find(query, {'versions': 1})
        files = [str(fid['versions'][-1]) for fid in res]
        local_path = mkdir_valid(save_path + str(query['group']) + '/')
        for obj_file in files:
            download_file(ssh, sftp, download_path, local_path, obj_file)
    sftp.close()
    ssh.close()
    client.close()
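
The `cd` plus `ls` round trip through `exec_command` could also be replaced by listing the directory over SFTP itself. A minimal sketch, assuming `sftp` is a `paramiko.SFTPClient` (only its `listdir` method is used) and a matching rule similar to the one in `download_file`, except that the name must end in `.cpp`; `find_remote_cpp` is a hypothetical helper, not part of the script above.

```python
def find_remote_cpp(sftp, remote_dir, obj_file):
    """Return the first .cpp file name in remote_dir that contains obj_file, or None."""
    try:
        names = sftp.listdir(remote_dir)
    except IOError:
        # paramiko raises IOError when the directory is missing or unreadable
        return None
    # listdir returns names in arbitrary order; sort for a deterministic pick
    for name in sorted(names):
        if name.endswith('.cpp') and obj_file in name:
            return name
    return None
```

The returned name could then be passed on as `sftp.get(remote_dir + name, local_path + name)`, which avoids parsing shell output altogether.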

(Note: I replaced some of the names in the code. If you spot a mistake, please let me know in the comments below. Thanks!)

Original post: https://www.cnblogs.com/zhouys96/p/12753595.html