#include <math.h>
#include <stdlib.h>
#include "cristoffel.h"
#include "celeste_types.h"

/* Square of X. X is expanded twice, so never pass an expression with side effects. */
#define SQ(X) ((X)*(X))
/* Gravitational constant. The value matches SI units (m^3 kg^-1 s^-2);
 * presumably masses and positions are SI as well -- TODO confirm with callers. */
#define G 6.67384E-11
/* Speed of light; the value matches SI units (m/s).
 * NOTE(review): this single-letter macro replaces every later identifier
 * named `c` in this file, so no variable, parameter or field may use that name. */
#define c 299792458.0

/*
 * Fill sys->cs with the Christoffel symbols of every body relative to the
 * central ("fat") body.
 *
 * pos : pos[i][j] is the j-th coordinate of body i (read only).
 * sys : provides dimension, nb_planets, index_fat_body and mass_fat_body.
 *       The result is written to sys->cs[i][a][b][d], where each of a, b, d
 *       runs over 0..dimension inclusive; index `dimension` is the extra
 *       (presumably time) component, the others are spatial.
 *
 * The entries of the fat body itself are left untouched.
 * No guard is applied for a body located exactly on the fat body (r == 0)
 * or exactly at the Schwarzschild radius (r == rc); both divide by zero.
 */
void compute_cristoffel_symbols(double **pos, cel_system* sys) {
	int dimension = sys->dimension;
	int nb_planets = sys->nb_planets;
	int i; 		//ith planet
	int j,k,l;	//j,k,lth coordinate

	double rc = 2*G*sys->mass_fat_body/SQ(c); //Schwarzschild radius
	for(i=0; i<nb_planets; i++) {
		if(i==sys->index_fat_body)
			continue;

		const double *p = pos[i];			//this body
		const double *s = pos[sys->index_fat_body];	//central body

		//distance central body <-> ith body, computed in place so that
		//no scratch allocation (and no allocation failure path) is needed
		double r = 0;
		for(j=0; j<dimension; j++)
			r += SQ(p[j]-s[j]);
		r = sqrt(r);
		double r_cubed = pow(r,3);

		sys->cs[i][dimension][dimension][dimension] = 0;
		for(j=0; j<dimension; j++) {
			double dj = p[j]-s[j];

			//components with two time indices and one spatial index
			sys->cs[i][dimension][dimension][j] = dj/(2*SQ(r)*(r/rc-1));
			sys->cs[i][dimension][j][dimension] = dj/(2*SQ(r)*(r/rc-1));
			sys->cs[i][j][dimension][dimension] = dj*rc*(1-rc/r)/(2*r_cubed);

			for(k=0; k<dimension; k++) {
				//components with exactly one time index are set to zero
				sys->cs[i][dimension][j][k] = 0;
				sys->cs[i][j][dimension][k] = 0;
				sys->cs[i][j][k][dimension] = 0;

				for(l=0; l<dimension; l++) {
					//1.5, not 3/2: integer division would truncate 3/2 to 1
					sys->cs[i][j][k][l] = (p[k]-s[k])*(p[l]-s[l])*(1.5*r/rc-1)/(SQ(r)*(1-r/rc));
					if(k==l)
						sys->cs[i][j][k][l] += 1;	//Kronecker delta term
					sys->cs[i][j][k][l] *= dj*rc/r_cubed;
				}
			}
		}
	}
}

/*
 * Accumulate the correction built from sys->cs into K_vel, in place.
 *
 * K_vel : K_vel[i][j] is updated for every body i other than
 *         sys->index_fat_body and every coordinate j < sys->dimension.
 * vel   : vel[i][j] is the j-th velocity component of body i.
 * sys   : supplies dimension, nb_planets, time_step, index_fat_body and
 *         the symbol table cs (index `dimension` is the extra component).
 *
 * For each (i, j) the following are applied, all scaled by time_step:
 *   - subtract  cs[i][j][k][l] * vel[i][k] * vel[i][l]            (all k, l)
 *   - add       2 * cs[i][dim][dim][k] * vel[i][j] * vel[i][k]    (all k)
 *   - subtract  cs[i][j][dim][dim] * SQ(c)
 */
void compute_K_gr_correction(double **K_vel, double **vel, cel_system* sys) {
	const int dim = sys->dimension;
	const int body_count = sys->nb_planets;
	const double dt = sys->time_step;

	for(int body=0; body<body_count; body++) {
		if(body==sys->index_fat_body)
			continue;	//the central body receives no correction

		double *k_row = K_vel[body];
		const double *v = vel[body];

		for(int a=0; a<dim; a++) {
			for(int b=0; b<dim; b++) {
				//purely spatial symbols, quadratic in the velocity
				for(int d=0; d<dim; d++) {
					k_row[a] -= dt * sys->cs[body][a][b][d] * v[b] * v[d];
				}
				//symbols carrying two extra-component indices
				k_row[a] += 2.0 * dt * sys->cs[body][dim][dim][b] * v[a] * v[b];
			}
			//velocity-independent term, scaled by the square of macro c
			k_row[a] -= dt * sys->cs[body][a][dim][dim] * SQ(c);
		}
	}
}
