#!/usr/bin/env python3

import sys
import glob
from pathlib import Path
import numpy
from PIL import Image
from collections import defaultdict

day = sys.argv[1]

total_fn = 0
total_tp = 0
total_tn = 0
total_fp = 0
total_moving_inter = 0
total_moving_union = 0
total_static_inter = 0
total_static_union = 0
for num in [int(n) for n in sys.argv[2:]]:
	# sometimes, there exists no point cloud for an input png
	# we skip those in kitti2scan so we have to re-obtain the mapping here
	modmaskdir = Path(".") / "Extended_MOD_Masks" / day / ("%s_drive_%04d_sync"%(day, num)) / "image_02"
	modmaskmap = dict()
	currscan = 0
	for i, mask in enumerate(sorted(modmaskdir.glob("*.png"))):
		assert mask.stem == ("%010d"%i)
		velodir = Path(".") / day / ("%s_drive_%04d_sync"%(day, num)) / "velodyne_points" / "data"
		if not (velodir / (mask.stem + ".bin")).exists():
			continue
		modmaskmap[currscan] = i
		currscan += 1

	path = Path(".") / day / ("%s_drive_%04d_sync"%(day, num)) / "changedetection"

	fn = 0
	tp = 0
	tn = 0
	fp = 0
	moving_inter = 0
	moving_union = 0
	static_inter = 0
	static_union = 0
	for j,scan in enumerate(sorted(path.glob("scan*.3d"))):
		assert scan.stem == ("scan%03d"%j)
		mask = scan.parent / "pplremover" / "masks" / (scan.stem + ".mask")
		imgmap = scan.parent / (scan.stem + ".mask")
		img = numpy.array(Image.open("Extended_MOD_Masks/%s/%s_drive_%04d_sync/image_02/%010d.png" % (day, day, num, modmaskmap[j])).convert("1"), dtype=numpy.bool)
		#imgres = numpy.zeros_like(img)
		with scan.open() as s, mask.open() as m, imgmap.open() as im:
			for line1, line2, line3 in zip(s, m, im):
				_, _, _, r = line1.strip().split()
				groundtruth = int(r)
				result = int(line2.strip())
				x, y, gt2 = line3.strip().split()
				assert int(gt2) == groundtruth
				assert img[int(y), int(x)] == groundtruth
				#if result == 1:
				#	imgres[int(y), int(x)] = 1
				if groundtruth == 1 and result == 1:
					tp += 1
					moving_inter += 1
					moving_union += 1
				elif groundtruth == 0 and result == 0:
					tn += 1
					static_inter += 1
					static_union += 1
				elif groundtruth == 1 and result == 0:
					fn += 1
					moving_union += 1
					static_union += 1
				elif groundtruth == 0 and result == 1:
					fp += 1
					moving_union += 1
					static_union += 1
				else:
					raise Exception("logic error")

		#moving_inter += (img*imgres).sum()
		#moving_union += (img+imgres).sum()
		#static_inter += (numpy.invert(img)*numpy.invert(imgres)).sum()
		#static_union += (numpy.invert(img)+numpy.invert(imgres)).sum()

	if tp+fp > 0:
		p = tp/(tp+fp)
	else:
		p = 0
	if tp+fn > 0:
		r = tp/(tp+fn)
	else:
		r = 0
	if p+r > 0:
		f1 = (2*p*r)/(p+r)
	else:
		f1 = 0
	print("%d %.04f %.04f %.04f"%(num, f1, moving_inter/moving_union if moving_union != 0 else 1, static_inter/static_union if static_union != 0 else 1))
	total_fn += fn
	total_tp += tp
	total_tn += tn
	total_fp += fp
	total_moving_inter += moving_inter
	total_moving_union += moving_union
	total_static_inter += static_inter
	total_static_union += static_union
if len(sys.argv[2:]) > 1:
	if total_tp+total_fp > 0:
		total_p = total_tp/(total_tp+total_fp)
	else:
		total_p = 0
	if total_tp+total_fn > 0:
		total_r = total_tp/(total_tp+total_fn)
	else:
		total_r = 0
	if total_p+total_r > 0:
		total_f1 = (2*total_p*total_r)/(total_p+total_r)
	else:
		total_f1 = 0
	print("total %.04f %.04f %.04f"%(total_f1, total_moving_inter/total_moving_union, total_static_inter/total_static_union))
