




/* TODO for univariates:
   - implement Factorize
*/


/* NormalForm(expression): rewrite UniVariate terms as explicit sums of
 * coefficient * var^power; any expression not matched by a more specific
 * rule is returned unchanged.
 */
RuleBase("NormalForm",{expression});

/* UniVariate(var,first,coefs) -> explicit sum (coefficients are themselves
   passed through NormalForm by ExpandUniVariate) */
Rule("NormalForm",1,0,Type(expression) = "UniVariate")
    ExpandUniVariate(expression[[1]],expression[[2]],expression[[3]]);

/* lowest-priority fallback: leave the expression as it is */
Rule("NormalForm",1,1000,True)
    expression;


/* ExpandUniVariate(var,first,coefs): the explicit sum
 *   NormalForm(coefs[[1]])*var^first + NormalForm(coefs[[2]])*var^(first+1) + ...
 * i.e. the ordinary expression a UniVariate(var,first,coefs) stands for.
 */
Function("ExpandUniVariate",{var,first,coefs})
[
  Local(total,power);
  total:=0;
  power:=first;   /* exponent belonging to the current coefficient */
  ForEach(coef,coefs)
  [
    total:=total+NormalForm(coef)*var^power;
    power++;
  ];
  total;
];


/* IsUniVar(expr): True exactly when expr is a UniVariate(...) term. */
Function("IsUniVar",{expr})
    Type(expr) = "UniVariate";

/* UniVariate(var,first,coefs) stands for
 *   coefs[[1]]*var^first + coefs[[2]]*var^(first+1) + ...
 * It stays unevaluated as a data structure except in the degenerate
 * cases handled below.
 */
RuleBase("UniVariate",{var,first,coefs});

/* all coefficients zero: the whole term is the number 0 */
Rule("UniVariate",3,1000,IsZeroVector(coefs))
    0;

/* a number or a vector in the variable slot: evaluate as an ordinary sum */
Rule("UniVariate",3,1000,IsComplex(var) Or IsVector(var))
    ExpandUniVariate(var,first,coefs);



/* Degree(expr): highest power of the main variable with a nonzero coefficient.
 *  - UniVariate: read off its coefficient list.
 *  - expression without any variables: 0.
 *  - anything else: converted with MakeUni first; when that conversion does
 *    not yield a UniVariate, Degree is left unevaluated instead of calling
 *    itself again on the same value forever.
 */
RuleBase("Degree",{expr});
Rule("Degree",1,0, IsUniVar(expr))
[
  Local(i,min,max);
  min:=expr[[2]];
  /* exponent of the last stored coefficient */
  max:=min+Length(expr[[3]])-1;
  i:=max;
  /* walk down past zero coefficients; NOTE(review): an all-zero list would
     yield min-1, but UniVariate collapses zero vectors to 0 before that */
  While(i >= min And IsZero(Coef(expr,i))) i--;
  i;
];
/* Constants: MakeUni would hand back the same constant, which previously
   made the fallback rule recurse without end. */
Rule("Degree",1,2, Length(VarList(expr)) = 0)
    0;
/* Only recurse when the conversion actually produced a UniVariate. */
Rule("Degree",1,1, IsUniVar(MakeUni(expr)))
    Degree(MakeUni(expr));


/* Sum of two univariates in the same variable.
 * The two coefficient lists are merged position by position, starting at the
 * smaller of the two starting powers, and wrapped in a new UniVariate.
 */
Rule("+",2,500,IsUniVar(aLeft) And
        IsUniVar(aRight) And aLeft[[1]] = aRight[[1]])
[
  Local(from,result);
  Local(curl,curr,left,right);

  /* curl/curr: power of the next unconsumed coefficient of left/right */
  curl:=aLeft[[2]];
  curr:=aRight[[2]];
  left:=aLeft[[3]];
  right:=aRight[[3]];
  /* result: merged coefficients, the first one belonging to power `from` */
  result:={};
  from:=Min(curl,curr);

  /* aLeft starts at a lower power: copy its coefficients until both sides
     are at the same power */
  While(curl<curr And left != {})
  [
    DestructiveAppend(result,Head(left));
    left:=Tail(left);
    curl++;
  ];
  /* aLeft ran out before reaching aRight's start: fill the gap with zeros */
  While(curl<curr)
  [
    DestructiveAppend(result,0);
    curl++;
  ];
  /* same two steps for the case where aRight starts at a lower power */
  While(curr<curl And right != {})
  [
    DestructiveAppend(result,Head(right));
    right:=Tail(right);
    curr++;
  ];
  While(curr<curl)
  [
    DestructiveAppend(result,0);
    curr++;
  ];
  /* overlapping powers: add coefficients pairwise */
  While(left != {} And right != {})
  [
    DestructiveAppend(result,Head(left)+Head(right));
    left  := Tail(left);
    right := Tail(right);
  ];


  /* copy the rest of the longer list; at this point at least one list is
     empty, so at most one of the two loops below does any work */
  While(left != {})
  [
    DestructiveAppend(result,Head(left));
    left  := Tail(left);
  ];
  While(right != {})
  [
    DestructiveAppend(result,Head(right));
    right := Tail(right);
  ];

  UniVariate(aLeft[[1]],from,result);
];

/* number + univariate: lift the number to a degree-0 univariate in the same
 * variable so the univariate + univariate rule can take over.
 * NOTE(review): a zero number collapses to 0 via the UniVariate zero-vector
 * rule; that case relies on general "+" rules defined elsewhere.
 */
Rule("+",2,200,IsNumber(aLeft) And IsUniVar(aRight))
  UniVariate(aRight[[1]],0,{aLeft})+aRight;

/* univariate + number: swap the operands so the rule above applies */
Rule("+",2,200,IsNumber(aRight) And IsUniVar(aLeft))
  aRight+aLeft;

/* Unary minus of a univariate: keep the variable and starting power,
   negate the whole coefficient list. */
Rule("-",1,200,IsUniVar(aLeft))
  Apply("UniVariate",
        {aLeft[[1]],
         aLeft[[2]],
         -(aLeft[[3]])});

/* Difference of two univariates in the same variable.
 * Works like the univariate "+" merge, but coefficients coming from aRight
 * are negated. Univariates in different variables are not merged: merging
 * their coefficient lists would silently attach aRight's coefficients to
 * aLeft's variable.
 */
Rule("-",2,200,IsUniVar(aLeft) And
        IsUniVar(aRight) And aLeft[[1]] = aRight[[1]])
[
  Local(from,result);
  Local(curl,curr,left,right);

  /* curl/curr: power of the next unconsumed coefficient of left/right */
  curl:=aLeft[[2]];
  curr:=aRight[[2]];
  left:=aLeft[[3]];
  right:=aRight[[3]];
  result:={};
  from:=Min(curl,curr);

  /* aLeft starts lower: copy its coefficients until the powers line up */
  While(curl<curr And left != {})
  [
    DestructiveAppend(result,Head(left));
    left:=Tail(left);
    curl++;
  ];
  /* fill any gap before aRight's start with zeros */
  While(curl<curr)
  [
    DestructiveAppend(result,0);
    curl++;
  ];
  /* aRight starts lower: copy its negated coefficients */
  While(curr<curl And right != {})
  [
    DestructiveAppend(result,-Head(right));
    right:=Tail(right);
    curr++;
  ];
  While(curr<curl)
  [
    DestructiveAppend(result,0);
    curr++;
  ];
  /* overlapping powers: subtract pairwise */
  While(left != {} And right != {})
  [
    DestructiveAppend(result,Head(left)-Head(right));
    left  := Tail(left);
    right := Tail(right);
  ];

  /* remainder of the longer list (at most one of these loops does work) */
  While(left != {})
  [
    DestructiveAppend(result,Head(left));
    left  := Tail(left);
  ];
  While(right != {})
  [
    DestructiveAppend(result,-Head(right));
    right := Tail(right);
  ];

  UniVariate(aLeft[[1]],from,result);
];



/* Positive integer power of a univariate by repeated multiplication,
 * u^k = u * u^(k-1). The recursion stops at k = 1 so it never produces u^0,
 * for which no univariate-specific rule is defined here (reaching it used to
 * cost an extra scalar multiplication by the generic x^0 result).
 */
Rule("^",2,200,IsInteger(aRight) And aRight>0 And
     IsUniVar(aLeft))
     If(aRight = 1, aLeft, aLeft*(aLeft^(aRight-1)));



/*TODO this can be made twice as fast!*/
/* Scalar times univariate: a factor that does not contain the univariate's
 * variable multiplies every coefficient.
 *
 * First rule: only normalises the argument order (univariate on the right).
 * It does not fire when the second rule already applies to the arguments as
 * given. Otherwise, for two univariates in different variables, the two
 * equal-precedence rules could keep swapping the operands back and forth,
 * depending on the order in which they are tried.
 */
Rule("*",2,200,IsUniVar(aLeft) And
	Not(Contains(VarList(aRight),aLeft[[1]])) And
	Not(IsUniVar(aRight) And Not(Contains(VarList(aLeft),aRight[[1]])))
        )
    aRight*aLeft;

Rule("*",2,200,IsUniVar(aRight) And
	Not(Contains(VarList(aLeft),aRight[[1]]))
        )
[
  Local(result);
  result:={};
  /* walk the stored coefficients directly; the old index loop ran one past
     the end and appended a spurious trailing aLeft*0 coefficient */
  ForEach(coef,aRight[[3]])
    DestructiveAppend(result,aLeft*coef);
  UniVariate(aRight[[1]],aRight[[2]],result);
];

/* Global univariate order setting, initialised to 3000. */
MaxUniOrder:=3000;

/* SetOrder(order): overwrite the global MaxUniOrder with order. */
Function("SetOrder",{order})
    MaxUniOrder:=order;


/* ShiftUniVar(uni,fact,shift): the univariate fact*var^shift*uni, built by
 * scaling uni's coefficient list by fact and moving its starting power up
 * by shift.
 */
Function("ShiftUniVar",{uni,fact,shift})
  Apply("UniVariate",
        {
          uni[[1]],          /* same variable */
          uni[[2]]+shift,    /* starting power moved by shift */
          fact*(uni[[3]])    /* every coefficient scaled by fact */
        });


/* Product of two univariates in the same variable (schoolbook method):
 * every term c*var^k of aLeft contributes c*var^k*aRight, and these
 * shifted, scaled copies of aRight are summed.
 */
Rule("*",2,200,IsUniVar(aLeft) And
     IsUniVar(aRight)
      And aLeft[[1]] = aRight[[1]]
    )
[
  Local(product,power);
  product:=MakeUni(0,aLeft[[1]]);
  power:=aLeft[[2]];   /* exponent of the current aLeft coefficient */
  ForEach(term,aLeft[[3]])
  [
    product:=product+ShiftUniVar(aRight,term,power);
    power++;
  ];
  product;
];


/* Coef(uv,order): coefficient of var^order in the univariate uv.
 * Orders outside the stored range uv[[2]] .. uv[[2]]+Length(uv[[3]])-1 have
 * coefficient 0. A non-integer order or non-univariate uv matches no rule.
 */
RuleBase("Coef",{uv,order});

Rule("Coef",2,0,IsInteger(order) And IsUniVar(uv))
[
  Local(pos);
  pos:=order-uv[[2]]+1;   /* 1-based position in the coefficient list */
  If(pos < 1 Or pos > Length(uv[[3]]),
     0,
     uv[[3]][[pos]]);
];

/* Coef(expression,var,order): coefficient of var^order in expression.
 * The expression is converted with var as the outer (main) variable and the
 * remaining variables pushed into the coefficients; the coefficient found is
 * returned in NormalForm.
 */
Function("Coef",{expression,var,order})
[
  Local(uni);
  uni:=MakeUni(expression,Concat({var},VarList(expression)));
  NormalForm(Coef(uni,order));
];

/* UniTaylor(taylorfunction,taylorvariable,taylorat,taylororder):
 * coefficients c0..c(taylororder), where ck is the k-th derivative of
 * taylorfunction (via Deriv) evaluated at taylorvariable = taylorat, divided
 * by k!. They are returned as UniVariate(taylorvariable,0,{c0,c1,...}).
 *
 * NOTE(review): the coefficients are attached to powers of taylorvariable
 * itself, not of (taylorvariable - taylorat), so the result is the usual
 * Taylor polynomial only when taylorat = 0. The original computed
 * (taylorvariable-taylorat) into an unused local, which hints that a shift
 * was intended. Confirm with callers before changing this.
 */
Function("UniTaylor",{taylorfunction,taylorvariable,taylorat,taylororder})
[
  /* Locals carry a "taylor" prefix on purpose. Eval() below evaluates the
     user's expression in this frame, so short names such as n, dif or result
     that also occur in taylorfunction would be replaced by loop state. If the
     expansion variable itself were named n, the divisor n! would become 0!. */
  Local(taylorindex,taylorcoefs,taylordiff);
  taylorcoefs:={};
  [
    /* bind the expansion variable to the expansion point, local to this block */
    MacroLocal(taylorvariable);
    MacroSet(taylorvariable,taylorat);
    DestructiveAppend(taylorcoefs,Eval(taylorfunction));
  ];
  taylordiff:=taylorfunction;
  For(taylorindex:=1,taylorindex<=taylororder,taylorindex++)
  [
    /* next derivative is taken before the variable is bound in this block */
    taylordiff:= Deriv(taylorvariable) taylordiff;
    MacroLocal(taylorvariable);
    MacroSet(taylorvariable,taylorat);
    DestructiveAppend(taylorcoefs,(Eval(taylordiff)/taylorindex!));
  ];
  UniVariate(taylorvariable,0,taylorcoefs);
];


/* MakeUni(expression): convert with respect to every variable occurring in
 * expression, in the order returned by VarList.
 */
Function("MakeUni",{expression}) MakeUni(expression,VarList(expression));

/* MakeUni(expression,var): convert an expression (assumed to be in normal
 * form) to UniVariate form in var. Handled building blocks: constants in var,
 * var itself, sums and differences, unary plus and minus, products, integer
 * powers, and division by an expression not involving var.
 */
RuleBase("MakeUni",{expression,var});

/* var is a list: convert for each variable in turn. The first variable becomes
   the outer one; later conversions act on the coefficients (UniVariate rule). */
Rule("MakeUni",2,1,IsList(var))
[
  Local(result,item);
  result:=expression;
  ForEach(item,var)
  [
    result:=MakeUni(result,item);
  ];
  result;
];

/* Already a univariate: keep its variable and starting power, and convert the
   coefficients that still depend on var. */
Rule("MakeUni",2,10,Type(expression) = "UniVariate")
[
  Local(reslist,item);
  reslist:={};
  ForEach(item,expression[[3]])
  [
    If(IsFreeOf(item,var),
      DestructiveAppend(reslist,item),
      DestructiveAppend(reslist,MakeUni(item,var))
      );
  ];
  Apply("UniVariate",{expression[[1]],expression[[2]],reslist});
];

/* Constant with respect to var: degree-0 univariate. */
Rule("MakeUni",2,10,IsFreeOf(expression,var)) UniVariate(var,0,{expression});

/* The variable itself: 1*var^1. */
Rule("MakeUni",2,10,expression=var)       UniVariate(var,1,{1});

/* Plus: the binary form has two operands; the unary form has only one and
   must not reach the binary rule, which would index a missing operand. */
Rule("MakeUni",2,10,Type(expression) = "+" And NrArgs(expression) = 2)
  MakeUni(expression[[1]],var)+MakeUni(expression[[2]],var);
Rule("MakeUni",2,10,Type(expression) = "+" And NrArgs(expression) = 1)
  MakeUni(expression[[1]],var);
Rule("MakeUni",2,10,Type(expression) = "*")
  MakeUni(expression[[1]],var)*MakeUni(expression[[2]],var);
/* Integer powers; NOTE(review): the univariate "^" rule only expands
   positive exponents, other integers are left as an unevaluated power. */
Rule("MakeUni",2,10,Type(expression) = "^" And IsInteger(expression[[2]]))
  MakeUni(expression[[1]],var)^expression[[2]];
Rule("MakeUni",2,10,Type(expression) = "-" And NrArgs(expression) = 1)
     -(MakeUni(expression[[1]],var));

/* Division by an expression that does not involve var: multiply by its reciprocal. */
Rule("MakeUni",2,10,Type(expression) = "/" And
     Not(Contains(VarList(expression[[2]]),var)))
  MakeUni(expression[[1]],var)*(1/expression[[2]]);

Rule("MakeUni",2,10,Type(expression) = "-" And NrArgs(expression) = 2)
     MakeUni(expression[[1]],var)-MakeUni(expression[[2]],var);

/* Div(n,m) and Mod(n,m): quotient and remainder of univariates in the same
 * variable.
 * Shortcut: when deg n < deg m the quotient is 0 and the remainder is n
 * itself (not m). Univariates in different variables are not handled here.
 */
Rule("Div",2,0,IsUniVar(n) And IsUniVar(m) And
     n[[1]] = m[[1]] And
     Degree(n) < Degree(m)) 0;

Rule("Mod",2,0,IsUniVar(n) And IsUniVar(m) And
     n[[1]] = m[[1]] And
     Degree(n) < Degree(m)) n;

/* General case: prefix both coefficient lists with zeros for the powers below
 * their starting power, so they start at power 0 (the form UniDivide works
 * on). Then divide and wrap the quotient (element 1) or remainder (element 2)
 * as a UniVariate starting at power 0.
 * NOTE(review): assumes starting powers n[[2]], m[[2]] are >= 0; ZeroVector
 * of a negative length is presumably an error - verify.
 */
Rule("Div",2,0,IsUniVar(n) And IsUniVar(m) And
     n[[1]] = m[[1]] And
     Degree(n) >= Degree(m))
    UniVariate(n[[1]],0,
               UniDivide(Concat(ZeroVector(n[[2]]),n[[3]]),
                         Concat(ZeroVector(m[[2]]),m[[3]]))[[1]]);
Rule("Mod",2,0,IsUniVar(n) And IsUniVar(m) And
     n[[1]] = m[[1]] And
     Degree(n) >= Degree(m))
    UniVariate(n[[1]],0,
               UniDivide(Concat(ZeroVector(n[[2]]),n[[3]]),
                         Concat(ZeroVector(m[[2]]),m[[3]]))[[2]]);



/* UniDivide(u,v): polynomial long division of u by v.
 * u and v are coefficient lists starting at power 0 (element 1 is the
 * constant term). Returns {q,r}: q holds the quotient coefficients, r the
 * remainder coefficients, truncated to n entries where n is the degree of v.
 * The caller is expected to ensure deg u >= deg v.
 */
Function("UniDivide",{u,v})
[
  Local(m,n,q,r,k,j);
  /* m, n: degrees of u and v, skipping zero coefficients at the top */
  m := Length(u)-1;
  n := Length(v)-1;
  While (m>0 And IsZero(u[[m+1]])) m--;
  While (n>0 And IsZero(v[[n+1]])) n--;
  q := ZeroVector(m-n+1);
  /* r: working copy of u, updated in place to become the remainder */
  r := FlatCopy(u);  /*  (m should be >= n) */
  /* quotient coefficients from the highest power (k = m-n) down to 0 */
  For(k:=m-n,k>=0,k--)
  [
    q[[k+1]] := r[[n+k+1]]/v[[n+1]];
    /* subtract q_k * x^k * v from the lower coefficients; the top entry
       r[[n+k+1]] is not cleared, it is dropped by the truncation below */
    For (j:=n+k-1,j>=k,j--)
    [
      r[[j+1]] := r[[j+1]] - q[[k+1]]*v[[j-k+1]];
    ];
  ];
  /* keep only the n lowest coefficients: the remainder has degree < n */
  While (Length(r)>n) DestructiveDelete(r,Length(r));
  {q,r};
];


/* UniGcd(u,v): greatest common divisor of two coefficient lists (power 0
 * first) by the Euclidean algorithm, using UniDivide for each step.
 * Expects deg u >= deg v. The result is not normalised (for example, it is
 * not made monic).
 */
Function("UniGcd",{u,v})
[
  Local(rem,top);
  rem:=UniDivide(u,v)[[2]];

  /* top: index of the highest nonzero remainder coefficient, 0 if none */
  top := Length(rem);
  While (top>0 And IsZero(rem[[top]])) top--;

  /* Stop only on a zero remainder (v divides u). A nonzero constant remainder
     means the gcd is a constant, which the next step produces; the old
     "length <= 1" test wrongly returned v in that case, e.g. gcd(x^2+1,x) = x. */
  If(top = 0,
     v,
     UniGcd(v,rem));
];



/* Gcd of two univariates in the same variable.
 * When the first argument has the lower degree, swap the arguments so that
 * the higher-degree polynomial comes first.
 */
Rule("Gcd",2,10,IsUniVar(n) And IsUniVar(m) And
     n[[1]] = m[[1]] And
     Degree(n) < Degree(m))
    Gcd(m,n);

/* Prefix both coefficient lists with zeros so they start at power 0 (the
 * form UniGcd works on), and wrap the resulting list as a UniVariate that
 * starts at power 0.
 */
Rule("Gcd",2,11,IsUniVar(n) And IsUniVar(m) And
     n[[1]] = m[[1]] And
     Degree(n) >= Degree(m))
    UniVariate(n[[1]], 0,
               UniGcd(Concat(ZeroVector(n[[2]]), n[[3]]),
                      Concat(ZeroVector(m[[2]]), m[[3]])));


/* PSolve(uni): roots of a univariate of degree 1 or 2.
 * Degree 1 yields a single value; degree 2 yields a two-element list.
 */
RuleBase("PSolve",{uni});

/* a1*x + a0 = 0  ->  x = -a0/a1 */
Rule("PSolve",1,1,IsUniVar(uni) And Degree(uni) = 1)
    -Coef(uni,0)/Coef(uni,1);

/* a2*x^2 + a1*x + a0 = 0  ->  quadratic formula */
Rule("PSolve",1,1,IsUniVar(uni) And Degree(uni) = 2)
[
  Local(a0,a1,a2,disc);
  a0:=Coef(uni,0);
  a1:=Coef(uni,1);
  a2:=Coef(uni,2);
  disc:=a1*a1-4*a2*a0;
  {(-a1+Sqrt(disc))/(2*a2),(-a1-Sqrt(disc))/(2*a2)};
];

/* PSolve(expression,var): convert to a univariate in var, then solve. */
Function("PSolve",{uni,var})
    PSolve(MakeUni(uni,var));



/* CanBeUni(expression): whether the expression can be converted to a
 * univariate with respect to each of its variables. The rules below are
 * intended to mirror the cases MakeUni handles.
 */
Function("CanBeUni",{expression}) CanBeUni(expression,VarList(expression));

/* CanBeUni(expression,var): check convertibility with respect to var, which is
 * either a single variable or a list of variables (all must succeed).
 */
RuleBase("CanBeUni",{expression,var});

Rule("CanBeUni",2,1,IsList(var))
[
  Local(result,item);
  result:=True;
  ForEach(item,var)
  [
    /* test the expression itself against each variable; passing the running
       boolean (an atom) would always answer True */
    result:=result And CanBeUni(expression,item);
  ];
  result;
];

/* Accepting an expression as being convertible to univariate */
Rule("CanBeUni",2,2,IsAtom(expression)) True;
Rule("CanBeUni",2,2,Type(expression)="+" And NrArgs(expression)=2)
    CanBeUni(expression[[1]],var) And CanBeUni(expression[[2]],var);
Rule("CanBeUni",2,2,Type(expression)="-" And NrArgs(expression)=2)
    CanBeUni(expression[[1]],var) And CanBeUni(expression[[2]],var);
Rule("CanBeUni",2,2,Type(expression)="*")
    CanBeUni(expression[[1]],var) And CanBeUni(expression[[2]],var);
/* division only by something that does not involve var */
Rule("CanBeUni",2,2,Type(expression)="/")
    CanBeUni(expression[[1]],var) And
    Not(Contains(VarList(expression[[2]]),var));
/* integer power of a convertible base, or a power not involving var at all.
   NOTE(review): negative integer exponents are accepted here although the
   univariate "^" rule only expands positive ones - confirm. */
Rule("CanBeUni",2,2,Type(expression)="^")
    (CanBeUni(expression[[1]],var) And IsInteger(expression[[2]])) Or
    (Not(Contains(VarList(expression[[1]]),var)) And
     Not(Contains(VarList(expression[[2]]),var)));
Rule("CanBeUni",2,2,Type(expression)="+" And NrArgs(expression)=1)
    CanBeUni(expression[[1]],var);
Rule("CanBeUni",2,2,Type(expression)="-" And NrArgs(expression)=1)
    CanBeUni(expression[[1]],var);

/* Otherwise, only possible if the arguments don't depend on the variable. */
Rule("CanBeUni",2,3,True) Not(Contains(VarList(expression),var));

/* InverseTaylor(var,degree,func): NormalForm of the antiderivative (with
 * respect to var) of UniTaylor's order-`degree` expansion of 1/D(var)func
 * around var = 0.
 * NOTE(review): this integrates 1/f'(x) in the same variable x, which is not
 * in general the series of the inverse function. For f = Exp(x)-1 it gives
 * x - x^2/2 + x^3/6, whereas the inverse Ln(1+y) = y - y^2/2 + y^3/3.
 * Confirm intent.
 */
Function("InverseTaylor",{var,degree,func})
  NormalForm(
    AntiDeriv(
      UniTaylor(1/(Apply("D",{var,func})),var,0,degree),
      var
    )
  );




