【GDOI2020模拟03.08】圣痕(二分+几何性质+线段树):

题目大意:

有n条直线,求它们两两之间的交点到(p,q)前m近的距离和。

(n le 50000, m le 10^7)

题解:

二分答案r,肯定的。

接着就是求有多少个交点到(p,q)的距离<=r

每条直线可以转换成一条线段,相当于求线段交点个数。

我所知道的最优做法扫描线+平衡树+堆,复杂度:(O((m+n)logn)),其中m为交点个数。

这个这么做显然是TlE的。

注意到这些线段全部都是以(p,q)为圆心,r为半径的圆的弦。

两条弦相交相当于对应的弧真相交(相交且不包含),这里的弧不管是优弧还是劣弧都可以。

解一元二次方程求交点,atan2求出几角区间,离散化一下,就是线段树能做的事了。

注意到这题要求的是距离和。

先二分出求出最大的r使得count(r)<=m。

再考虑求出r以内的所有交点,和(p,q)求距离加入答案,最后加上(m-count(r))*r即可。

求出所有交点:

考虑再线段树的每个点上多记一个vector,表示该区间内的直线有哪些,查询时搞一下就好了。

时间复杂度:(O(n ~ log ~n ~ log ~r + (m + n ~ log ~n)))

Code:

#include<bits/stdc++.h>
#define fo(i, x, y) for(int i = x, _b = y; i <= _b; i ++)
#define ff(i, x, y) for(int i = x, _b = y; i <  _b; i ++)
#define fd(i, x, y) for(int i = x, _b = y; i >= _b; i --)
#define ll long long
#define pp printf
#define hh pp("
")
using namespace std;

#define db double

const db eps = 1e-8;

const db pi = acos(-1);

const int N = 50005;

int n, m; db p, q;
db a[N], b[N];

db sqr(db x) { return x * x;}

db r;

db dis(db x, db y, db u, db v) {
	return sqrt(sqr(x - u) + sqr(y - v));
}

struct P {
	db x; int i;
	P(db _x = 0, int _i = 0) { x = _x, i = _i;}
};
P d[N * 2]; int d0;
int cmpd(P a, P b) {
	if(a.x == b.x) return a.i > b.i;
	return a.x < b.x;
}

int u[N], v[N];

void build() {
	d0 = 0;
	fo(i, 1, n) {
		db a0 = a[i] * a[i] + 1;
		db b0 = -2 * p + 2 * a[i] * (b[i] - q);
		db c0 = sqr(p) + sqr(b[i] - q) - sqr(r);
		db t = sqr(b0) - 4 * a0 * c0;
		if(t < 0) continue;
		t = sqrt(t);
		db x0 = (-b0 + t) / 2 / a0, x1 = (-b0 - t) / 2 / a0;
		db u = atan2(a[i] * x0 + b[i] - q, x0 - p);
		db v = atan2(a[i] * x1 + b[i] - q, x1 - p);
		if(u > v) swap(u, v);
		d[++ d0] = P(u, i);
		d[++ d0] = P(v, -i);
	}
	sort(d + 1, d + d0 + 1, cmpd);
	fo(i, 1, d0) if(d[i].i > 0)
		u[d[i].i] = i; else v[-d[i].i] = i;
}

int kq;

db sum;

void solve(int x, int y) {
	db x0 = (b[y] - b[x]) / (a[x] - a[y]);
	sum += dis(x0, a[x] * x0 + b[x], p, q);
}

#define i0 i + i
#define i1 i + i + 1
int t[N * 8], pl, pr, px, py;
vector<int> g[N * 8];
void add(int i, int x, int y) {
	if(kq) g[i].push_back(py);
	t[i] ++;
	if(x == y) return;
	int m = x + y >> 1;
	if(pl <= m) add(i0, x, m); else add(i1, m + 1, y);
	t[i] = t[i0] + t[i1];
}
void ft(int i, int x, int y) {
	if(y < pl || x > pr) return;
	if(x >= pl && y <= pr) { 
		px += t[i];
		if(kq) {
			ff(j, 0, g[i].size())
				solve(py, g[i][j]);
		}
		return;
	}
	int m = x + y >> 1;
	ft(i0, x, m); ft(i1, m + 1, y);
}

ll count(db _r) {
	r = _r;
	build();
	ll ans = 0;	
	fo(i, 1, d0 * 4) t[i] = 0;
	fo(i, 1, d0) if(d[i].i > 0) {
		pl = u[d[i].i], pr = v[d[i].i]; px = 0; py = d[i].i;
		ft(1, 1, d0);
		ans += px;
		
		pl = pr = v[d[i].i];
		add(1, 1, d0);
	}
	return ans;
}

int main() {
	freopen("stigmata.in", "r", stdin);
	freopen("stigmata.out", "w", stdout);
	scanf("%d %lf %lf %d", &n, &p, &q, &m);
	p /= 1000, q /= 1000;
	fo(i, 1, n) {
		scanf("%lf %lf", &a[i], &b[i]);
		a[i] /= 1000, b[i] /= 1000;
	}
	db ans = 0;
	for(db v = 1e6; v > 1e-11; v /= 2)
		if(count(ans + v) <= m) ans += v;
	kq = 1; int cnt = count(ans);
	pp("%.9lf
", sum + ans * (m - cnt));
}
原文地址:https://www.cnblogs.com/coldchair/p/12450700.html