using Plots, Statistics, Distributions, Printf
"""
    sim(n; ezx = x -> cdf(Normal(), x), Δ = x -> x^2, covde = 1, vare = 2)

Simulate `n` observations from a binary-treatment model with a binary instrument `z`.

Data-generating process (independent standard-normal draws unless stated):

- `x = xd + N(0,1)`, where `xd ~ N(0,1)` also enters treatment take-up.
- `z = 1` with probability `ezx(x)` (via `rand(n) .< ezx.(x)`).
- Take-up index `xd + derr + z + de`; a unit is treated (`d`) iff the index is `> 0`.
  `d1` / `d0` are the take-up decisions with `z` set to 1 / 0.
- Individual treatment effect `Δ(x) + de`. The gain shock `de` also enters take-up.
- Outcome `y = (Δ(x) + de) * d + ϵ`, with `ϵ = covde*de + sqrt(vare - covde^2) * N(0,1)`,
  so `Var(ϵ) = vare` and `cov(ϵ, de) = covde`.

# Keywords
- `ezx`: function of `x` used as the probability that `z = 1`.
- `Δ`: x-dependent part of the treatment effect. It is evaluated exactly once per
  observation, so a random or stateful `Δ` yields returned effects consistent with `y`.
- `covde`: loading of `ϵ` on the gain shock `de`.
- `vare`: variance of `ϵ`. It must satisfy `vare ≥ covde^2`, otherwise `sqrt` throws.

# Returns
NamedTuple `(y, x, z, d, Δ, d0, d1)`, where `Δ` is the vector of individual treatment
effects `Δ.(x) .+ de` actually used to generate `y`.
"""
function sim(n; ezx = x->cdf(Normal(),x), Δ = x->x^2, covde=1, vare=2)
    # The draw order below is kept fixed so a given seed reproduces the same sample.
    xd = randn(n)                  # shifter shared by x and take-up
    x = randn(n) + xd
    de = randn(n)                  # unobserved gain shock
    z = rand(n) .< ezx.(x)         # z = 1 with probability ezx(x)
    derr = randn(n)                # take-up noise
    d = (xd + derr + z + de .> 0)
    d1 = (xd + derr .+ 1 + de .> 0)
    d0 = (xd + derr .+ 0 + de .> 0)
    ϵ = de*covde + randn(n)*sqrt(vare-covde^2)
    # Evaluate Δ once: the returned effects must be exactly the ones that built y,
    # even if Δ is random or stateful.
    te = Δ.(x) + de
    y = te .* d + ϵ
    return (y=y, x=x, z=z, d=d, Δ=te, d0=d0, d1=d1)
end
"""
    bols(y, d, x)

Ordinary least squares of `y` on an intercept, `d` and `x`.

The regressors must not be collinear, because the normal equations `X'X β = X'y`
are solved directly.

Returns the coefficient vector ordered as `[intercept, d, x...]`.
"""
function bols(y,d,x)
    n = length(y)
    X = hcat(ones(n), d, x)
    # `A \ B*y` parses as `(A \ B)*y`. Parenthesise X'y so we solve against a single
    # right-hand side instead of all n columns of X'.
    return (X'*X) \ (X'*y)
end
"""
    b2sls(y, d, x, z)

Two-stage least squares of `y` on `[1 d x]`, instrumenting `d` with `z`.

The instrument matrix is `[1 z x]`, so the intercept and `x` act as their own
instruments. The estimator solves

    (X'Z (Z'Z)⁻¹ Z'X) β = X'Z (Z'Z)⁻¹ Z'y

Returns `β` ordered as `[intercept, d, x...]`.
"""
function b2sls(y,d,x,z)
    n = length(y)
    Z = hcat(ones(n), z, x)
    X = hcat(ones(n), d, x)
    XZ = X'*Z
    # P = (Z'Z)⁻¹ Z'X via a linear solve rather than an explicit inverse. Because Z'Z
    # is symmetric, P' = X'Z (Z'Z)⁻¹, which also gives the right-hand side.
    P = (Z'*Z) \ XZ'
    return (XZ*P) \ (P'*(Z'*y))
end
"""
    plotTE(y, d, x, z, Δ, d0, d1; ezx = x -> cdf(Normal(), x))

Build a four-panel figure summarising a simulated IV design, such as the output of `sim`.

Panels, in the order passed to `plot`:
1. "Observed Data": `y` against `x`, grouped by treatment `d`.
2. "Treatment Effects": individual effects `Δ` against `x`, grouped by the take-up pair `(d0, d1)`.
3. "P(Z|X)": the curve `ezx(x)` with the fitted linear projection of `z` on `[1 x]` overlaid.
   Pass the same `ezx` that generated `z` so this curve is the true propensity.
4. Text: `mean(Δ[d1 .> d0])` (labelled E[y1-y0|d1>d0]) and the coefficient on `d`
   from `bols` (OLS) and `b2sls` (IV).

Returns the combined plot object produced by `plot(xy, te, pz, numbers)`.
"""
function plotTE(y,d,x,z,Δ,d0,d1; ezx=x->cdf(Normal(),x))
    # `xlabel!`, `ylabel!`, `title!`, `scatter!` and `annotate!` act on the *current*
    # plot, so each panel is decorated immediately after it is created.
    te = scatter(x, Δ; group = [(t0, t1) for (t0, t1) in zip(d0, d1)],
                 alpha = 1.0, markersize = 1, markerstrokewidth = 0)
    xlabel!("x")
    ylabel!("Treatment Effect")
    title!("Treatment Effects")

    xy = scatter(x, y; group = d, markersize = 1, markerstrokewidth = 0)
    xlabel!("x")
    ylabel!("y")
    title!("Observed Data")

    xs = sort(x)
    pz = plot(xs, ezx.(xs); xlabel = "x", ylabel = "P(Z=1|X)", title = "P(Z|X)",
              legend = :none)
    n = length(z)
    X = hcat(ones(n), x)
    # Fitted values of the linear projection of z on [1 x]. Solve the 2×2 normal
    # equations first, so the n×n hat matrix X(X'X)⁻¹X' is never materialised.
    lzx = X * ((X'*X) \ (X'*z))
    scatter!(x, lzx; label = "L[Z|X]", markersize = 1, markerstrokewidth = 0, alpha = 0.5)

    bo = bols(y, d, x)[2]          # OLS coefficient on d
    bi = b2sls(y, d, x, z)[2]      # IV coefficient on d
    LATE = mean(Δ[d1 .> d0])       # mean effect among units with d1 > d0
    numbers = plot(xlims = (0, 1), ylims = (0, 1), axis = ([], false))
    annotate!([(0, 0.8, (@sprintf("E[y1-y0|d1>d0] = %.2f", LATE), :left)),
               (0, 0.6, (@sprintf("βols = %.2f", bo), :left)),
               (0, 0.4, (@sprintf("βiv = %.2f", bi), :left))])
    return plot(xy, te, pz, numbers)
end
# Demo run: the x-dependent part of the treatment effect is held at 1 for every x.
y, x, z, d, Δ, d0, d1 = sim(5_000; Δ = _ -> 1)
plotTE(y, d, x, z, Δ, d0, d1)