P6498 Zamjene 题解

[COCI2016-2017#2] Zamjene

对于可交换集中的元素,可以用类似于冒泡的方法将他们排序,则如果一个交换集中的所有元素与其对应位置的 q 完全重合,则这个交换集可以对 p 的排序做出贡献。反之,若有一个交换集无法满足对应关系,则无法对 p 排序。

想要快速判断交换集中元素是否与 q 中对应位置元素相等,可以用哈希。数 xx 对应的哈希值为 xp1x^{p_1},模数为 p2p_2,然后一个集合内的元素哈希值加进来即为集合的哈希值。

对于每一个交换集,记录 valival_i 表示集合中元素哈希值,preipre_i 表示其在 q 中对应位置的哈希值。

对于操作 3,记录一个 numnum 表示 prei=valipre_i=val_i 的个数,每合并一个集合就把 numnum 加 1,num=nnum=n 则表示可以完成对 p 的排序。

对于操作 4,记录 ansans 表示答案,并把无序数对转为有序因为好处理。我们要求的是 prei+prej=vali+valjpre_i+pre_j=val_i+val_jvaliprei,valjprejval_i\neq pre_i,val_j\neq pre_j 的个数。移项,可得 preivali=valjprejpre_i-val_i=val_j-pre_j,可以把每个 preivalipre_i-val_i 丢进 map 里面查询。

每次操作前把 a,ba,b 所在交换集对 numnumansans 的影响减掉,注意还要乘 2。若 a,ba,b 已在一个集合内,则直接跳过。当 prea+preb=vala+valbpre_a+pre_b=val_a+val_b 时,还要减掉重复算的。注意 map 加入或减掉的是集合的大小而不是 1。

把上面的操作搞完后,我得到了 68 分好成绩。应该是我运气太好了,模数取烂了。我搞了个双哈希就过了。

codecode

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e6+50,ba=10000019,baa=10000079,mod=1000000513;

struct node
{
	int a,b;
	
	friend node operator+(node a,node b)
	{
		return (node){(a.a+b.a)%mod,(a.b+b.b)%mod};
	}
	
	friend node operator-(node a,node b)
	{
		return (node){(a.a-b.a+mod)%mod,(a.b-b.b+mod)%mod};
	}
	
	friend bool operator==(node a,node b)
	{
		return (a.a==b.a)&(a.b==b.b);
	}
	
	friend bool operator!=(node a,node b)
	{
		return (a.a!=b.a)|(a.b!=b.b);
	}
	
	friend bool operator<(node a,node b)
	{
		return a.a==b.a?a.b<b.b:a.a<b.a;
	}
}val[N],pre[N];

int n,Q,p[N],q[N],ro[N],num,len[N],ans;
map<node,int>mp;

int ksm(int a,int b)
{
	int ans=1;
	while(b)
	{
		if(b&1)ans=ans*a%mod;
		a=a*a%mod;
		b>>=1;
	}
	return ans;
}

int find(int x)
{
	return x==ro[x]?x:(ro[x]=find(ro[x]));
}

main()
{
	cin>>n>>Q;
	for(int i=1;i<=n;i++)cin>>p[i],q[i]=p[i],ro[i]=i,val[i]=(node){ksm(p[i],ba),ksm(p[i],baa)},len[i]=1;
	sort(q+1,q+1+n);
	for(int i=1;i<=n;i++)pre[i]=(node){ksm(q[i],ba),ksm(q[i],baa)},num+=(pre[i]==val[i]),mp[pre[i]-val[i]]++;
	for(int i=1;i<=n;i++)if(pre[i]!=val[i])ans+=mp[val[i]-pre[i]];
	int opt,a,b;
	while(Q--)
	{
		cin>>opt;
		if(opt==1)
		{
			cin>>a>>b;
			int ra=find(a),rb=find(b);
			if(ra==rb){swap(p[a],p[b]);continue;}
			num-=(pre[ra]==val[ra])+(pre[rb]==val[rb]);
			if(pre[ra]!=val[ra])ans-=2*len[ra]*mp[val[ra]-pre[ra]];
			if(pre[rb]!=val[rb])ans-=2*len[rb]*mp[val[rb]-pre[rb]];
			if(pre[ra]+pre[rb]==val[ra]+val[rb]&&pre[ra]!=val[ra])ans+=2*len[ra]*len[rb];
			mp[pre[ra]-val[ra]]-=len[ra];mp[pre[rb]-val[rb]]-=len[rb];
			val[ra]=val[ra]-(node){ksm(p[a],ba),ksm(p[a],baa)};
			val[rb]=val[rb]+(node){ksm(p[a],ba),ksm(p[a],baa)};
			val[ra]=val[ra]+(node){ksm(p[b],ba),ksm(p[b],baa)};
			val[rb]=val[rb]-(node){ksm(p[b],ba),ksm(p[b],baa)};
			swap(p[a],p[b]);
			num+=(pre[ra]==val[ra])+(pre[rb]==val[rb]);
			mp[pre[ra]-val[ra]]+=len[ra];mp[pre[rb]-val[rb]]+=len[rb];
			if(pre[ra]!=val[ra])ans+=2*len[ra]*mp[val[ra]-pre[ra]];
			if(pre[rb]!=val[rb])ans+=2*len[rb]*mp[val[rb]-pre[rb]];
			if(pre[ra]+pre[rb]==val[ra]+val[rb]&&pre[ra]!=val[ra])ans-=2*len[ra]*len[rb];
			continue;
		}
		if(opt==2)
		{
			cin>>a>>b;
			a=find(a),b=find(b);
			if(a==b)continue;
			num-=(pre[a]==val[a])+(pre[b]==val[b]);num++;
			if(pre[a]!=val[a])ans-=2*len[a]*mp[val[a]-pre[a]];
			if(pre[b]!=val[b])ans-=2*len[b]*mp[val[b]-pre[b]];
			if(pre[a]+pre[b]==val[a]+val[b]&&pre[a]!=val[a])ans+=2*len[a]*len[b];
			mp[pre[a]-val[a]]-=len[a];mp[pre[b]-val[b]]-=len[b];
			val[a]=val[a]+val[b];
			pre[a]=pre[a]+pre[b];
			ro[a]=ro[b]=a;len[a]+=len[b];
			num+=(pre[a]==val[a]);
			mp[pre[a]-val[a]]+=len[a];
			if(pre[a]!=val[a])ans+=2*len[a]*mp[val[a]-pre[a]];
			continue;
		}
		if(opt==3)
		{
			if(num==n)puts("DA");
			else puts("NE");
			continue;
		}
		cout<<ans/2<<'\n';
	}
}