リンク
http://arc015.contest.atcoder.jp/tasks/arc015_4
問題概要
0秒からクッキーを毎秒1枚つくる。毎秒、確率pで、次のN通りの効果から確率q_iで選ばれた効果が発動する。効果は持続時間t_iと倍率x_iを持っていて、次の秒からt_i回、クッキーの生成個数をx_i倍する。T秒たったときのクッキーの総数の期待値を求めよ。
制約
T<=10^5, sum q_i=1
所感
またどこかでみたような問題をつくってしまった・・しかも旬を逃した感じの出題タイミング!一度解いてしまうと簡単に見えてしまう病を発動して、最初はC問題でいいかなーなんて思っていましたがそんなことはありませんでした。
解き方
時刻kで得られるクッキーの個数をf(k)とします。
が答え。
まず金色クッキーの効果が1種類、持続時間が無限の場合を考えると、重複している効果の数の分布は
↓time,効果数→ |
0 |
1 |
2 |
3 |
0 |
|
0 |
0 |
0 |
1 |
|
|
0 |
0 |
2 |
|
|
|
0 |
3 |
|
|
|
|
という感じで2項係数の分布になります。したがってk秒目で得られるクッキーの期待値は
となります。
次に持続時間を有限にします。たとえば持続時間を3にした場合、timeが3から4になるときにtime=0でクリックされた分は破棄されます。すると、こんなかんじになり、
これにまた新たな金色クッキー1枚分の遷移が入るので、結局変わりません。
というわけでkが持続時間t以上であれば、つねにk=tのときの結果になります。
この性質は効果の種類が増えてもかわらなくて、持続時間が全部同じ(=t)場合は、
になります。
ここで持続時間をばらばらにするとき、最初にやったように前の状態に戻すのではなく、確率分布はそのままで持続時間がすぎたものの倍率を1にしてやるとよくて、
として
なら、
となります。めんどくさそうな式ですが、効果をt_i昇順にソートしておいて、k=0から順々に尺取式にかけていけばO(Nlog N+T)で求められます。
Tが大きい場合はもっとオーダーを良くできますが、本問の場合はそこまでしなくても良いです。
コード
writer解(Java)
+
|
... |
package arc15;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.util.Arrays;
import java.util.Comparator;
import java.util.InputMismatchException;
public class D {
static InputStream is;
static PrintWriter out;
static String INPUT = "";
static class Cookie
{
public double q;
public int x, t;
}
static void solve()
{
int t = ni(), n = ni();
double P = nd();
Cookie[] cookies = new Cookie[n];
for(int i = 0;i < n;i++){
cookies[i] = new Cookie();
cookies[i].q = nd();
cookies[i].x = ni();
cookies[i].t = ni();
}
Arrays.sort(cookies, new Comparator<Cookie>() {
public int compare(Cookie a, Cookie b) {
return a.t - b.t;
}
});
double ret = 0;
int p = 0;
double sum = 1-P;
for(int i = 0;i < n;i++){
sum += cookies[i].q * P * cookies[i].x;
}
double mul = 1;
for(int i = 0;i < t;i++){
ret += mul;
while(p < n && cookies[p].t <= i){
sum -= cookies[p].q * P * cookies[p].x;
sum += cookies[p].q * P;
p++;
}
mul *= sum;
}
out.printf("%.9f\n", ret);
}
public static void main(String[] args) throws Exception
{
long S = System.currentTimeMillis();
is = INPUT.isEmpty() ? System.in : new ByteArrayInputStream(INPUT.getBytes());
out = new PrintWriter(System.out);
solve();
out.flush();
long G = System.currentTimeMillis();
tr(G-S+"ms");
}
private static boolean eof()
{
if(lenbuf == -1)return true;
int lptr = ptrbuf;
while(lptr < lenbuf)if(!isSpaceChar(inbuf[lptr++]))return false;
try {
is.mark(1000);
while(true){
int b = is.read();
if(b == -1){
is.reset();
return true;
}else if(!isSpaceChar(b)){
is.reset();
return false;
}
}
} catch (IOException e) {
return true;
}
}
private static byte[] inbuf = new byte[1024];
static int lenbuf = 0, ptrbuf = 0;
private static int readByte()
{
if(lenbuf == -1)throw new InputMismatchException();
if(ptrbuf >= lenbuf){
ptrbuf = 0;
try { lenbuf = is.read(inbuf); } catch (IOException e) { throw new InputMismatchException(); }
if(lenbuf <= 0)return -1;
}
return inbuf[ptrbuf++];
}
private static boolean isSpaceChar(int c) { return !(c >= 33 && c <= 126); }
private static int skip() { int b; while((b = readByte()) != -1 && isSpaceChar(b)); return b; }
private static double nd() { return Double.parseDouble(ns()); }
private static char nc() { return (char)skip(); }
private static String ns()
{
int b = skip();
StringBuilder sb = new StringBuilder();
while(!(isSpaceChar(b))){ // when nextLine, (isSpaceChar(b) && b != ' ')
sb.appendCodePoint(b);
b = readByte();
}
return sb.toString();
}
private static char[] ns(int n)
{
char[] buf = new char[n];
int b = skip(), p = 0;
while(p < n && !(isSpaceChar(b))){
buf[p++] = (char)b;
b = readByte();
}
return n == p ? buf : Arrays.copyOf(buf, p);
}
private static char[][] nm(int n, int m)
{
char[][] map = new char[n][];
for(int i = 0;i < n;i++)map[i] = ns(m);
return map;
}
private static int[] na(int n)
{
int[] a = new int[n];
for(int i = 0;i < n;i++)a[i] = ni();
return a;
}
private static int ni()
{
int num = 0, b;
boolean minus = false;
while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
if(b == '-'){
minus = true;
b = readByte();
}
while(true){
if(b >= '0' && b <= '9'){
num = num * 10 + (b - '0');
}else{
return minus ? -num : num;
}
b = readByte();
}
}
private static long nl()
{
long num = 0;
int b;
boolean minus = false;
while((b = readByte()) != -1 && !((b >= '0' && b <= '9') || b == '-'));
if(b == '-'){
minus = true;
b = readByte();
}
while(true){
if(b >= '0' && b <= '9'){
num = num * 10 + (b - '0');
}else{
return minus ? -num : num;
}
b = readByte();
}
}
private static void tr(Object... o) { if(INPUT.length() != 0)System.out.println(Arrays.deepToString(o)); }
}
|
tester解(C)
+
|
... |
#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<string.h>
#include<assert.h>
#define REP(i,a,b) for(i=a;i<b;i++)
#define rep(i,n) REP(i,0,n)
#define ll long long
ll read_int(ll mn, ll mx, char next){
ll c, fg = 1, res = 0;
c=getchar();
if(c=='-') fg = -1, c = getchar();
assert('0'<=c && c<='9');
res = c - '0';
for(;;){
c = getchar();
if(c==next) break;
assert(res!=0);
assert('0'<=c && c<='9');
res = res * 10 + (c - '0');
}
res *= fg;
assert(mn<=res && res<=mx);
return res;
}
double read_double(double mn, double mx, int mx_digits, char next){
int c, fg = 1, d;
double up = 0, down = 0, res, add;
c = getchar();
if(c=='-') fg = -1, c = getchar();
up = c - '0';
for(;;){
c = getchar();
if(c==next || c=='.') break;
assert(up!=0);
assert('0'<=c && c<='9');
res = res * 10 + (c - '0');
}
if(c=='.'){
d = 1; add = 0.1;
for(;;){
c = getchar();
if(c==next) break;
assert(mx_digits >= d);
assert('0'<=c && c<='9');
down += add * (c - '0');
d++; add /= 10;
}
}
res = up + down;
assert(res!=0 || fg==1);
res *= fg;
assert(mn <= res && res <= mx); /* be careful about the precision errors */
return res;
}
double q[200000]; int x[200000], t[200000];
double multi[200000];
int main(){
int T, N; double P;
int i, j, k;
double res, sum, add;
T = read_int(1, 100000, ' ');
N = read_int(1, 10000, ' ');
P = read_double(0, 1, 6, '\n');
rep(i,N){
q[i] = read_double(0, 1, 6, ' ');
x[i] = read_int(1, 1000, ' ');
t[i] = read_int(1, 100000, '\n');
}
sum = 0;
rep(i,N) sum += q[i];
assert(1 - 1e-13 < sum && sum < 1 + 1e-13);
rep(i,N) q[i] *= P;
rep(i,100102) multi[i] = 0;
rep(i,N) multi[t[i]] += q[i] * (x[i] - 1);
for(i=100100;i;i--) multi[i] += multi[i+1];
res = 0; add = 1;
rep(i,T){
add *= (1 + multi[i]);
res += add;
}
assert( res <= 1e100 );
printf("%.10f\n", res);
{
char c;
assert( scanf("%c",&c)!=1 );
}
return 0;
}
|
最終更新:2013年10月05日 22:38