リンク

http://arc015.contest.atcoder.jp/tasks/arc015_4

問題概要

0秒からクッキーを毎秒1枚つくる。毎秒、確率pで、次のN通りの効果から確率q_iで選ばれた効果が発動する。効果は持続時間t_iと倍率x_iを持っていて、次の秒からt_i回、クッキーの生成個数をx_i倍する。T秒たったときのクッキーの総数の期待値を求めよ。

制約

T<=10^5, sum q_i=1

所感

またどこかでみたような問題をつくってしまった・・しかも旬を逃した感じの出題タイミング!一度解いてしまうと簡単に見えてしまう病を発動して、最初はC問題でいいかなーなんて思っていましたがそんなことはありませんでした。

解き方

時刻kで得られるクッキーの個数をf(k)とします。\sum_{i=0}^{T-1} f(k)が答え。
まず金色クッキーの効果が1種類、持続時間が無限の場合を考えると、重複している効果の数の分布は
↓time,効果数→ 0 1 2 3
0 1 0 0 0
1 1-p p 0 0
2 (1-p)^2 2p(1-p) p^2 0
3 (1-p)^3 3p(1-p)^2 3p^2(1-p) p^3
という感じで2項係数の分布になります。したがってk秒目で得られるクッキーの期待値は
f(k)
=1\cdot \binom{k}{0}(1-p)^k+x\cdot \binom{k}{1}(1-p)^{k-1}p+x^2\cdot \binom{k}{2}(1-p)^{k-2}p^2+\cdots +x^k\cdot \binom{k}{k}p^k
=(1-p+px)^kとなります。

次に持続時間を有限にします。たとえば持続時間を3にした場合、timeが3から4になるときにtime=0でクリックされた分は破棄されます。すると、こんなかんじになり、
(1-p)^2 2p(1-p) p^2 0
これにまた新たな金色クッキー1枚分の遷移が入るので、結局変わりません。
(1-p)^3 3p(1-p)^2 3p^2(1-p) p^3
というわけでkが持続時間t以上であれば、つねにk=tのときの結果になります。f(k)=(1-p+px)^t.

この性質は効果の種類が増えてもかわらなくて、持続時間が全部同じ(=t)場合は、
f(k)=(1-p+p(q_1x_1+q_2x_2+\cdots +q_nx_n))^{\min(k,t)}になります。

ここで持続時間をばらばらにするとき、最初にやったように前の状態に戻すのではなく、確率分布はそのままで持続時間がすぎたものの倍率を1にしてやるとよくて、t_1&lt;t_2&lt;\cdots &lt;t_nとしてt_j&lt;=k&lt;t_{j+1}なら、
f(k)=
(1-p+p(q_1x_1+q_2x_2+\cdots +q_nx_n))^{t_1}
*(1-p+p(q_1\cdot 1+q_2x_2+\cdots +q_nx_n))^{t_2}*
\cdots
*(1-p+p(q_1\cdot 1+q_2\cdot 1+\cdots +q_j\cdot 1+q_{j+1}x_{j+1}+\cdots +q_nx_n))^{k-t_j}
となります。めんどくさそうな式ですが、効果をt_i昇順にソートしておいて、k=0から順々に尺取式にかけていけばO(Nlog N+T)で求められます。

Tが大きい場合はもっとオーダーを良くできますが、本問の場合はそこまでしなくても良いです。

コード

writer解(Java)
+ ...
package arc15;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Comparator;
import java.util.InputMismatchException;
 
public class D {
	static InputStream is;
	static PrintWriter out;
	static String INPUT = "";
 
	static class Cookie
	{
		public double q;
		public int x, t;
	}
 
	static void solve()
	{
		int t = ni(), n = ni();
		double P = nd();
 
		Cookie[] cookies = new Cookie[n];
		for(int i = 0;i < n;i++){
			cookies[i] = new Cookie();
			cookies[i].q = nd();
			cookies[i].x = ni();
			cookies[i].t = ni();
		}
		Arrays.sort(cookies, new Comparator<Cookie>() {
			public int compare(Cookie a, Cookie b) {
				return a.t - b.t;
			}
		});
		double ret = 0;
		int p = 0;
		double sum = 1-P;
		for(int i = 0;i < n;i++){
			sum += cookies[i].q * P * cookies[i].x;
		}
 
		double mul = 1;
		for(int i = 0;i < t;i++){
			ret += mul;
			while(p < n && cookies[p].t <= i){
				sum -= cookies[p].q * P * cookies[p].x;
				sum += cookies[p].q * P;
				p++;
			}
			mul *= sum;
		}
		out.printf("%.9f\n", ret);
	}
 
	public static void main(String[] args) throws Exception
	{
		long S = System.currentTimeMillis();
		is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
		out = new PrintWriter(System.out);
 
		solve();
		out.flush();
		long G = System.currentTimeMillis();
		tr(G-S+"ms");
	}
 
	private static boolean eof()
	{
		if(lenbuf == -1)return true;
		int lptr = ptrbuf;
		while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))return false;
 
		try {
			is.mark(1000);
			while(true){
				int b = is.read();
				if(b == -1){
					is.reset();
					return true;
				}else if(!isSpaceChar(b)){
					is.reset();
					return false;
				}
			}
		} catch (IOException e) {
			return true;
		}
	}
 
	private static byte[] inbuf = new byte[1024];
	static int lenbuf = 0, ptrbuf = 0;
 
	private static int readByte()
	{
		if(lenbuf == -1)throw new InputMismatchException();
		if(ptrbuf >= lenbuf){
			ptrbuf = 0;
			try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
			if(lenbuf <= 0)return -1;
		}
		return inbuf[ptrbuf++];
	}
 
	private static boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
	private static int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
 
	private static double nd() { return Double.parseDouble(ns()); }
	private static char nc() { return (char)skip(); }
 
	private static String ns()
	{
		int b = skip();
		StringBuilder sb = new StringBuilder();
		while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
			sb.appendCodePoint(b);
			b = readByte();
		}
		return sb.toString();
	}
 
	private static char[] ns(int n)
	{
		char[] buf = new char[n];
		int b = skip(), p = 0;
		while(p < n && !(isSpaceChar(b))){
			buf[p++] = (char)b;
			b = readByte();
		}
		return n == p ? buf : Arrays.copyOf(buf, p);
	}
 
	private static char[][] nm(int n, int m)
	{
		char[][] map = new char[n][];
		for(int i = 0;i < n;i++)map[i] = ns(m);
		return map;
	}
 
	private static int[] na(int n)
	{
		int[] a = new int[n];
		for(int i = 0;i < n;i++)a[i] = ni();
		return a;
	}
 
	private static int ni()
	{
		int num = 0, b;
		boolean minus = false;
		while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
		if(b == '-'){
			minus = true;
			b = readByte();
		}
 
		while(true){
			if(b >= '0' && b <= '9'){
				num = num * 10 + (b - '0');
			}else{
				return minus ? -num : num;
			}
			b = readByte();
		}
	}
 
	private static long nl()
	{
		long num = 0;
		int b;
		boolean minus = false;
		while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
		if(b == '-'){
			minus = true;
			b = readByte();
		}
 
		while(true){
			if(b >= '0' && b <= '9'){
				num = num * 10 + (b - '0');
			}else{
				return minus ? -num : num;
			}
			b = readByte();
		}
	}
 
	private static void tr(Object... o) { if(INPUT.length() != 0)System.out.println(Arrays.deepToString(o)); }
}
 

tester解(C)
+ ...
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<string.h>
#include<assert.h>
#define REP(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) REP(i,0,n)
 
#define ll long long
 
ll read_int(ll mn, ll mx, char next){
  ll c, fg = 1, res = 0;
  c=getchar();
  if(c=='-') fg = -1, c = getchar();
  assert('0'<=c && c<='9');
  res = c - '0';
  for(;;){
    c = getchar();
    if(c==next) break;
    assert(res!=0);
    assert('0'<=c && c<='9');
    res = res * 10 + (c - '0');
  }
  res *= fg;
  assert(mn<=res && res<=mx);
  return res;
}
 
double read_double(double mn, double mx, int mx_digits, char next){
  int c, fg = 1, d;
  double up = 0, down = 0, res, add;
 
  c = getchar();
  if(c=='-') fg = -1, c = getchar();
 
  up = c - '0';
  for(;;){
    c = getchar();
    if(c==next || c=='.') break;
    assert(up!=0);
    assert('0'<=c && c<='9');
    res = res * 10 + (c - '0');
  }
 
  if(c=='.'){
    d = 1; add = 0.1;
    for(;;){
      c = getchar();
      if(c==next) break;
      assert(mx_digits >= d);
      assert('0'<=c && c<='9');
      down += add * (c - '0');
      d++; add /= 10;
    }
  }
 
  res = up + down;
  assert(res!=0 || fg==1);
  res *= fg;
  assert(mn <= res && res <= mx); /* be careful about the precision errors */
  return res;
}
 
 
 
 
double q[200000]; int x[200000], t[200000];
double multi[200000];
 
int main(){
  int T, N; double P;
  int i, j, k;
  double res, sum, add;
 
  T = read_int(1, 100000, ' ');
  N = read_int(1, 10000, ' ');
  P = read_double(0, 1, 6, '\n');
 
  rep(i,N){
    q[i] = read_double(0, 1, 6, ' ');
    x[i] = read_int(1, 1000, ' ');
    t[i] = read_int(1, 100000, '\n');
  }
 
  sum = 0;
  rep(i,N) sum += q[i];
  assert(1 - 1e-13 < sum && sum < 1 + 1e-13);
 
  rep(i,N) q[i] *= P;
 
  rep(i,100102) multi[i] = 0;
  rep(i,N) multi[t[i]] += q[i] * (x[i] - 1);
  for(i=100100;i;i--) multi[i] += multi[i+1];
 
  res = 0; add = 1;
  rep(i,T){
    add *= (1 + multi[i]);
    res += add;
  }
 
  assert( res <= 1e100 );
  printf("%.10f\n", res);
 
  {
    char c;
    assert( scanf("%c",&c)!=1 );
  }
 
  return 0;
}
 

タグ:

ARC 解説
最終更新:2013年10月05日 22:38