基于Spark GraphX的自定义加权网络最短路径规划

0 背景

实际工作中需要使用最短路径算法。之前一直使用Neo4j中的函数,现在想要和大数据平台结合,于是想到了Spark GraphX。由于之前基本只使用Python,不熟悉Java和Scala开发,经过多方查阅和学习,特此做个记录。

1 关于开发环境

开发环境:IntelliJ IDEA(Scala插件)+ Spark的jar包。在Scala工程中导入Spark的jar包后,即可使用Spark相关的函数。

2 网络数据准备

为了便于迁移,这里使用CSV文件存储网络的节点和边。 节点数据nodes.csv如下:

node_id,nodes
1,v1
2,v2
3,v3
4,v4
5,v5
6,v6
7,v7

边数据edges.csv如下: 

source,target,length
1,2,2
1,4,1
2,4,3
2,5,10
4,5,2
4,3,2
4,6,8
4,7,4
3,1,4
3,6,5
5,7,6
7,6,1

3 网络构建
读取节点和边的代码如下:

/** Small helpers for reading a text file line by line. */
class fileExample {

  /** Prints every line of the file at `string` to stdout.
    * The underlying file handle is always closed, even if printing fails.
    */
  def fileReader(string: String): Unit = {
    // read the csv file content
    val ofile = Source.fromFile(string)
    try ofile.getLines().foreach(println)
    finally ofile.close()
  }

  /** Reads the whole file at `string` into memory, one array element per line.
    *
    * `getLines()` is lazy, so the lines are materialized with `toArray`
    * before the file handle is closed.
    */
  def readerToArray(string: String): Array[String] = {
    val ofile = Source.fromFile(string)
    try ofile.getLines().toArray
    finally ofile.close()
  }
}


/** Loads the demo network (nodes and weighted edges) from CSV files that start with a header row. */
class getGraph {

  /** Reads the node list.
    *
    * Expects a header line followed by `node_id,name` rows; blank lines are skipped.
    *
    * @param infile path to the nodes CSV (defaults to the demo path)
    * @return (node id, node name) pairs in file order
    */
  def nodes(infile: String = "./dataset/nodes.csv"): Seq[(Long, String)] = {
    val context = new fileExample().readerToArray(infile)
    println("nodes_context:" + context.length)
    context.foreach(println)
    dataRows(context).map(cols => (cols(0).toLong, cols(1)))
  }

  /** Reads the weighted edge list.
    *
    * Expects a header line followed by `source,target,length` rows; blank lines are skipped.
    *
    * @param infile path to the edges CSV (defaults to the demo path)
    * @return (source id, target id, integer weight) triples in file order
    */
  def edges(infile: String = "./dataset/edges.csv"): Seq[(Long, Long, Int)] = {
    val context = new fileExample().readerToArray(infile)
    println("edge_context:" + context.length)
    context.foreach(println)
    dataRows(context).map(cols => (cols(0).toLong, cols(1).toLong, cols(2).toInt))
  }

  /** Drops the header line and blank lines, then splits each row on commas (fields trimmed). */
  private def dataRows(lines: Array[String]): Vector[Array[String]] =
    lines.iterator
      .drop(1)
      .filter(_.trim.nonEmpty)
      .map(_.split(",").map(_.trim))
      .toVector
}

构建Spark GraphX图的代码如下(example方法在第4节中继续):

class graphExample {
  val conf = new SparkConf().setAppName("Example").setMaster("local")
  val sc = new SparkContext(conf)

  def example(): Unit = {
    println("start")

    val graph = new getGraph()
    val nodes = graph.nodes()
    nodes.foreach(println)
    println("->nodes\n->edges")

    val edges = graph.edges()
    edges.foreach(println)

    // Vertex attribute is (name, id). Every parsed node is used; there is
    // no sentinel entry to strip, so no node can be accidentally dropped.
    val gnodes: RDD[(Long, (String, Long))] =
      sc.parallelize(nodes.map { case (id, name) => (id, (name, id)) })

    val gedges: RDD[Edge[Int]] =
      sc.parallelize(edges.map { case (src, dst, w) => Edge(src, dst, w) })

    val gx: Graph[(String, Long), Int] = Graph(gnodes, gedges)
    // Sanity check on the built graph: count edges with weight > 3.
    val tmp = gx.edges.filter { case Edge(_, _, w) => w > 3 }.count
    println("tmp:" + tmp)

4 路径查询

基于上面构建的GraphX图,使用Pregel API进行单源最短路径查询。每个顶点保存(到源点的最短距离, 对应路径),初始时只有源点的距离为0、其余为无穷大。每一轮沿出边向邻居发送更短的候选距离,直到不再产生消息为止。过程如下:

    // Initialize the graph: each vertex holds (distance from sourceId, path as a vertex list).
    val sourceId: VertexId = 1L
    val initialGraph: Graph[(Double, List[VertexId]), Int] = gx.mapVertices((id, _) =>
      if (id == sourceId) (0.0, List[VertexId](sourceId))
      else (Double.PositiveInfinity, List[VertexId]()))

    // Iteration cap is effectively unbounded; Pregel stops once no messages are sent,
    // which happens here as long as edge weights are non-negative.
    val sssp = initialGraph.pregel((Double.PositiveInfinity, List[VertexId]()), Int.MaxValue, EdgeDirection.Out)(
      // Vertex program: keep whichever of (current state, incoming message) is shorter.
      (_, dist, newDist) => if (dist._1 < newDist._1) dist else newDist,
      // Send message: offer the destination a path through the source only if it is strictly shorter.
      triplet => {
        if (triplet.srcAttr._1 < triplet.dstAttr._1 - triplet.attr) {
          Iterator((triplet.dstId, (triplet.srcAttr._1 + triplet.attr, triplet.srcAttr._2 :+ triplet.dstId)))
        } else {
          Iterator.empty
        }
      },
      // Merge message: keep the candidate with the smaller distance.
      (a, b) => if (a._1 < b._1) a else b)
    println(sssp.vertices.collect.mkString("\n"))
    val end_ID = 6L
    println(end_ID)
    // Filter on the cluster first so only the target vertex is shipped to the driver.
    println(sssp.vertices.filter { case (id, _) => id == end_ID }.collect.mkString("\n"))

5 完整代码

整个DEMO的完整代码如下:

import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext
import org.apache.spark.SparkConf
import org.apache.spark.graphx.lib.ShortestPaths
import scala.io.Source

object graphExample {
  /** Entry point: runs the shortest-path demo and always shuts down the SparkContext afterwards. */
  def main(args: Array[String]): Unit = {
    val exam = new graphExample()
    try exam.example()
    finally exam.sc.stop()
  }
}

/** Demo driver: builds a weighted GraphX graph from CSV files and runs single-source shortest paths. */
class graphExample {
  val conf = new SparkConf().setAppName("Example").setMaster("local")
  val sc = new SparkContext(conf)

  /** Builds the graph, prints it, and prints the shortest paths from `sourceId`.
    *
    * @param sourceId vertex the search starts from (defaults to 1, the original demo value)
    * @param endId    vertex whose result is printed separately at the end (defaults to 6)
    */
  def example(sourceId: VertexId = 1L, endId: VertexId = 6L): Unit = {
    println("start")

    val graph = new getGraph()
    val nodes = graph.nodes()
    nodes.foreach(println)
    println("->nodes\n->edges")

    val edges = graph.edges()
    edges.foreach(println)

    // Vertex attribute is (name, id). Every parsed node is used; there is
    // no sentinel entry to strip, so no node can be accidentally dropped.
    val gnodes: RDD[(Long, (String, Long))] =
      sc.parallelize(nodes.map { case (id, name) => (id, (name, id)) })

    val gedges: RDD[Edge[Int]] =
      sc.parallelize(edges.map { case (src, dst, w) => Edge(src, dst, w) })

    val gx: Graph[(String, Long), Int] = Graph(gnodes, gedges)
    // Sanity check on the built graph: count edges with weight > 3.
    val tmp = gx.edges.filter { case Edge(_, _, w) => w > 3 }.count
    println("tmp:" + tmp)

    val sssp = shortestPaths(gx, sourceId)
    println(sssp.vertices.collect.mkString("\n"))
    println(endId)
    // Filter on the cluster first so only the target vertex is shipped to the driver.
    println(sssp.vertices.filter { case (id, _) => id == endId }.collect.mkString("\n"))
  }

  /** Single-source shortest paths over integer edge weights using Pregel.
    *
    * Each vertex ends up holding (distance from `sourceId`, path as a vertex list);
    * unreachable vertices keep (+Infinity, Nil). The iteration cap is effectively
    * unbounded: Pregel stops once no messages are sent, which happens as long as
    * edge weights are non-negative.
    */
  private def shortestPaths(
      gx: Graph[(String, Long), Int],
      sourceId: VertexId): Graph[(Double, List[VertexId]), Int] = {
    val initialGraph: Graph[(Double, List[VertexId]), Int] = gx.mapVertices((id, _) =>
      if (id == sourceId) (0.0, List[VertexId](sourceId))
      else (Double.PositiveInfinity, List[VertexId]()))

    initialGraph.pregel((Double.PositiveInfinity, List[VertexId]()), Int.MaxValue, EdgeDirection.Out)(
      // Vertex program: keep whichever of (current state, incoming message) is shorter.
      (_, dist, newDist) => if (dist._1 < newDist._1) dist else newDist,
      // Send message: offer the destination a path through the source only if it is strictly shorter.
      triplet =>
        if (triplet.srcAttr._1 < triplet.dstAttr._1 - triplet.attr)
          Iterator((triplet.dstId, (triplet.srcAttr._1 + triplet.attr, triplet.srcAttr._2 :+ triplet.dstId)))
        else
          Iterator.empty,
      // Merge message: keep the candidate with the smaller distance.
      (a, b) => if (a._1 < b._1) a else b)
  }
}

/** Small helpers for reading a text file line by line. */
class fileExample {

  /** Prints every line of the file at `string` to stdout.
    * The underlying file handle is always closed, even if printing fails.
    */
  def fileReader(string: String): Unit = {
    // read the file content
    val ofile = Source.fromFile(string)
    try ofile.getLines().foreach(println)
    finally ofile.close()
  }

  /** Reads the whole file at `string` into memory, one array element per line.
    *
    * `getLines()` is lazy, so the lines are materialized with `toArray`
    * before the file handle is closed.
    */
  def readerToArray(string: String): Array[String] = {
    val ofile = Source.fromFile(string)
    try ofile.getLines().toArray
    finally ofile.close()
  }
}


/** Loads the demo network (nodes and weighted edges) from CSV files that start with a header row. */
class getGraph {

  /** Reads the node list.
    *
    * Expects a header line followed by `node_id,name` rows; blank lines are skipped.
    *
    * @param infile path to the nodes CSV (defaults to the demo path)
    * @return (node id, node name) pairs in file order
    */
  def nodes(infile: String = "./dataset/nodes.csv"): Seq[(Long, String)] = {
    val context = new fileExample().readerToArray(infile)
    println("nodes_context:" + context.length)
    context.foreach(println)
    dataRows(context).map(cols => (cols(0).toLong, cols(1)))
  }

  /** Reads the weighted edge list.
    *
    * Expects a header line followed by `source,target,length` rows; blank lines are skipped.
    *
    * @param infile path to the edges CSV (defaults to the demo path)
    * @return (source id, target id, integer weight) triples in file order
    */
  def edges(infile: String = "./dataset/edges.csv"): Seq[(Long, Long, Int)] = {
    val context = new fileExample().readerToArray(infile)
    println("edge_context:" + context.length)
    context.foreach(println)
    dataRows(context).map(cols => (cols(0).toLong, cols(1).toLong, cols(2).toInt))
  }

  /** Drops the header line and blank lines, then splits each row on commas (fields trimmed). */
  private def dataRows(lines: Array[String]): Vector[Array[String]] =
    lines.iterator
      .drop(1)
      .filter(_.trim.nonEmpty)
      .map(_.split(",").map(_.trim))
      .toVector
}
原文地址:https://www.cnblogs.com/ddzhen/p/15324179.html