From 2589646e51ee3b02fa9b4b0b368407bbd7bb0e06 Mon Sep 17 00:00:00 2001
From: mhg <gfrerer@tugraz.at>
Date: Thu, 7 Nov 2024 14:57:17 +0100
Subject: [PATCH] hyper dual numbers added

---
 .../hyper_dual_numbers/@hyper_dual/horzcat.m  |   4 +
 .../@hyper_dual/hyper_dual.m                  | 439 ++++++++++++++++++
 .../hyper_dual_numbers/@hyper_dual/subsref.m  |  27 ++
 .../hyper_dual_numbers/@hyper_dual/vertcat.m  |  16 +
 .../hyper_dual_numbers/test_fun_hyper_dual.m  |  32 ++
 matlab/hyper_dual_numbers/test_hyper_dual.m   |  49 ++
 6 files changed, 567 insertions(+)
 create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/horzcat.m
 create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m
 create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/subsref.m
 create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/vertcat.m
 create mode 100644 matlab/hyper_dual_numbers/test_fun_hyper_dual.m
 create mode 100644 matlab/hyper_dual_numbers/test_hyper_dual.m

diff --git a/matlab/hyper_dual_numbers/@hyper_dual/horzcat.m b/matlab/hyper_dual_numbers/@hyper_dual/horzcat.m
new file mode 100644
index 0000000..5edc8ee
--- /dev/null
+++ b/matlab/hyper_dual_numbers/@hyper_dual/horzcat.m
@@ -0,0 +1,4 @@
+function ret = horzcat(varargin)
+obj = cat( 1, varargin{:} );
+ret = hyper_dual([obj.a],[obj.b],[obj.c],[obj.d]);
+end
diff --git a/matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m b/matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m
new file mode 100644
index 0000000..496fed4
--- /dev/null
+++ b/matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m
@@ -0,0 +1,439 @@
+classdef hyper_dual
+    %HYPER_DUAL
+    
+    properties
+        a % the number
+        b
+        c
+        d
+    end
+    
+    methods
+        function obj = hyper_dual(A,B,C,D)
+            if nargin == 1
+                obj.a = zeros(size(A));
+                obj.b = zeros(size(A));
+                obj.c = zeros(size(A));
+                obj.d = zeros(size(A));
+                obj.a = A;
+            elseif nargin == 4
+                if isequal(size(A),size(B)) && isequal(size(B),size(C)) && isequal(size(C),size(D))
+                obj.a = A;
+                obj.b = B;
+                obj.c = C;
+                obj.d = D;
+                else
+                    error('wrong input')
+                end
+            else
+                error('wrong number of input parameter in hyper_dual')
+            end
+        end
+        
+        %% double
+        function ret = double(obj)
+            ret = obj;
+            ret.a = double(obj.a);
+            ret.b = double(obj.b);
+            ret.c = double(obj.c);
+            ret.d = double(obj.d);
+        end
+        %% real
+        function ret = real(obj)
+            ret = obj.a;
+        end
+        %% imag
+        function ret = imag(obj)
+            ret = abs(imag(obj.b))+abs(imag(obj.c))+abs(imag(obj.d));
+        end
+        %% reshape
+        function ret = reshape(obj,varargin)
+            ret = hyper_dual(...
+                reshape(obj.a,varargin{:}),...
+                reshape(obj.b,varargin{:}),...
+                reshape(obj.c,varargin{:}),...
+                reshape(obj.d,varargin{:}));
+        end
+        function ret = transpose(obj)
+            ret = hyper_dual(obj.a.',obj.b.',obj.c.',obj.d.');
+        end
+            
+        %% addition
+        function ret = plus(obj,A)
+            if isa(obj , 'hyper_dual')
+                ret = obj;
+                if isa(A , 'hyper_dual')
+                    ret.a = bsxfun(@plus,ret.a , A.a); % if A is hyper-dual
+                    ret.b = bsxfun(@plus,ret.b , A.b);
+                    ret.c = bsxfun(@plus,ret.c , A.c);
+                    ret.d = bsxfun(@plus,ret.d , A.d);
+                else
+                    ret.a = bsxfun(@plus,ret.a , A); 
+                end
+            else
+                ret = A + obj;
+            end
+        end
+        %% subtraction
+        function ret = minus(obj,A)
+            if isa(obj , 'hyper_dual')
+                ret = obj;
+                if isa(A , 'hyper_dual')
+                    ret.a = bsxfun(@minus,ret.a , A.a); % if A is hyper-dual
+                    ret.b = bsxfun(@minus,ret.b , A.b);
+                    ret.c = bsxfun(@minus,ret.c , A.c);
+                    ret.d = bsxfun(@minus,ret.d , A.d);
+                else
+                    ret.a = bsxfun(@minus,ret.a , A);
+                 end
+             else
+                 ret = - A + obj ;
+             end
+         end
+         %% minus
+         function ret = uminus(obj)
+             ret = obj;
+             f  = -obj.a;
+             df = -1+0*obj.a;
+             ddf= 0*obj.a;
+             ret.a = f; % if A is hyper-dual
+             ret.b = obj.b .* df;
+             ret.c = obj.c .* df;
+             ret.d = obj.d .* df + obj.b .* obj.c .* ddf;
+         end
+         %% Multiplication
+         function ret = times(obj,A)
+             if isa(obj , 'hyper_dual')
+                 ret = obj;
+                 if isa(A , 'hyper_dual')
+                     ret.a = bsxfun(@times,obj.a , A.a); % if A is hyper-dual
+                     ret.b = bsxfun(@times,obj.a , A.b) + bsxfun(@times,obj.b , A.a);
+                     ret.c = bsxfun(@times,obj.a , A.c) + bsxfun(@times,obj.c , A.a);
+                     ret.d = bsxfun(@times,obj.a , A.d) + bsxfun(@times,obj.b , A.c) + bsxfun(@times,obj.c , A.b) + bsxfun(@times,obj.d , A.a);
+                 else
+                     ret = obj;
+                     ret.a = bsxfun(@times,obj.a , A);  % if A is a real number
+                     ret.b = bsxfun(@times,obj.b , A);
+                     ret.c = bsxfun(@times,obj.c , A);
+                     ret.d = bsxfun(@times,obj.d , A);
+                 end
+             else
+                 ret = A .* obj;
+             end
+         end
+         %% Matrix Multiplication
+         function ret = reini(obj)
+             for i = 1:size(obj,1)
+                 for j = 1:size(obj,2)
+                     s1(i,j) = size(obj(i,j).a,1);
+                     s2(i,j) = size(obj(i,j).a,2);
+                 end
+             end
+             i1 = 0;
+             for i = 1:size(obj,1)
+                 i2 = 0;
+                 for j = 1:size(obj,2)
+                     for k = 1:s1(i,j)
+                        for l = 1:s2(i,j) 
+                            a(i1+k,i2+l) = obj(i,j).a(k,l);
+                            b(i1+k,i2+l) = obj(i,j).b(k,l);
+                            c(i1+k,i2+l) = obj(i,j).c(k,l);
+                            d(i1+k,i2+l) = obj(i,j).d(k,l);
+                        end
+                     end
+                     i2 = i2 + s2(i,j);
+                 end
+                 i1 = i1 + s1(i,j);
+             end
+             ret = hyper_dual(a,b,c,d);
+         end
+         %%
+         function ret = mtimes(obj,A)
+             if isa(obj , 'hyper_dual')
+                 if numel(obj) > 1
+                     obj = reini(obj);
+                 end
+             else
+                 obj = hyper_dual(obj);
+             end
+             sO1 = size(obj.a,1);
+             sO2 = size(obj.a,2);   
+             if isa(A , 'hyper_dual')
+                 if numel(A) > 1
+                     A = reini(A);
+                 end
+             else
+                 A = hyper_dual(A);     
+             end
+             sA1 = size(A.a,1);
+             sA2 = size(A.a,2);
+             if sO2 == sA1
+                 aa = zeros(sO1,sA2);
+                 bb = zeros(sO1,sA2);
+                 cc = zeros(sO1,sA2);
+                 dd = zeros(sO1,sA2);
+                 for i = 1:sO1
+                     for j = 1:sA2
+                         for k = 1:sO2
+                             r = my_subs(obj,i,k) .* my_subs(A,k,j);
+                             aa(i,j) = aa(i,j) + r.a;
+                             bb(i,j) = bb(i,j) + r.b;
+                             cc(i,j) = cc(i,j) + r.c;
+                             dd(i,j) = dd(i,j) + r.d;
+                         end
+                     end
+                 end
+                 ret = hyper_dual(aa,bb,cc,dd);
+             else
+                 error('dimension mismatch')
+             end
+         end
+         %% DIVIDE
+         % rdivide ./
+         % obj is nominator
+         % A is denominator
+         function ret = rdivide(obj,A)
+             if isa(obj , 'hyper_dual')
+                 ret = obj;
+                 if isa(A , 'hyper_dual')
+                     df = bsxfun(@rdivide,-1, A.a.^2);
+                     ddf= bsxfun(@rdivide,2, A.a.^3);
+                     ret.a = bsxfun(@rdivide,1, A.a); % if A is hyper-dual
+                     ret.b = A.b .* df;
+                     ret.c = A.c .* df;
+                     ret.d = A.d .* df + A.b .* A.c .* ddf;
+                     ret = obj .* ret;
+                 else
+                     ret = obj .* (1 ./ A);
+                 end
+             else
+                 ret = hyper_dual(obj) ./ A;
+             end
+         end
+                 
+         %% Equality
+         function ret = eq(obj,a)
+             if isa(obj , 'hyper_dual')
+                 if isa(a , 'hyper_dual')
+                     % if a is hyper-dual
+                     ret = (obj.a == a.a) ;
+                 else
+                     ret = ( obj.a == a ); % if a is a real number
+                 end
+             else
+                 ret = ( obj == a.a );
+             end
+         end       
+        %% sin
+        function ret = sin(obj)
+            ret = obj;
+            ret.a = sin( obj.a );
+            ret.b = obj.b .* cos(obj.a);
+            ret.c = obj.c .* cos(obj.a);
+            ret.d = obj.d .* cos(obj.a) - obj.b .* obj.c .* sin( obj.a );
+        end
+        %% cos
+        function ret = cos(obj)
+            ret = obj;
+            ret.a = cos( obj.a );
+            ret.b = -obj.b .* sin(obj.a);
+            ret.c = -obj.c .* sin(obj.a);
+            ret.d = -obj.d .* sin(obj.a) - obj.b .* obj.c .* cos( obj.a );
+        end
+        %% exp
+        function ret = exp(obj)
+            ret = obj;
+            ret.a = exp( obj.a );
+            ret.b = obj.b .* exp(obj.a);
+            ret.c = obj.c .* exp(obj.a);
+            ret.d = obj.d .* exp(obj.a) + obj.b .* obj.c .* exp(obj.a);
+        end
+         % power
+         function ret = power(obj,B)
+            if isa(obj , 'hyper_dual')
+                if isa(B , 'hyper_dual')
+                    % if B is hyper-dual
+                    error('not implemented')
+                else
+                    % if B is a real number
+                    ret = obj;
+                    help = B .* obj.a .^ (B-1);
+                    ret.a = obj.a .^ B;
+                    ret.b = obj.b .* help;
+                    ret.c = obj.c .* help;
+                    ret.d = obj.d .* help + obj.b .* obj.c .* B .* (B-1) .* obj.a .^ (B-2);
+                end
+            else
+                ret = B; % B is hyper-dual
+                ret.a = obj .^ B.a;
+                ret.b = B.b .* obj .^ B.a .* log(obj) ;
+                ret.c = B.c .* obj .^ B.a .* log(obj) ;
+                ret.d = B.d .* obj .^ B.a .* log(obj) + B.b .* B.c .* obj .^ B.a .* log(obj)^2;
+            end
+            
+        end
+        %% sqrt
+        function ret = sqrt(obj)
+            ret = obj;
+            Df = 1 ./ ( 2 * sqrt(obj.a) );
+            ret.a = sqrt(obj.a);
+            ret.b = obj.b .* Df;
+            ret.c = obj.c .* Df;
+            ret.d = obj.d .* Df - obj.b .* obj.c .* Df ./ ( 2 * obj.a );
+        end
+        %% sinh
+        function ret = sinh(obj)
+            ret = obj;
+            Df = cosh(obj.a) ;
+            ret.a = sinh(obj.a);
+            ret.b = obj.b .* Df;
+            ret.c = obj.c .* Df;
+            ret.d = obj.d .* Df + obj.b .* obj.c .* sinh(obj.a);
+        end
+        %% cosh
+        function ret = cosh(obj)
+            ret = obj;
+            Df = sinh(obj.a) ;
+            ret.a = cosh(obj.a);
+            ret.b = obj.b .* Df;
+            ret.c = obj.c .* Df;
+            ret.d = obj.d .* Df + obj.b .* obj.c .* cosh(obj.a);
+        end
+        %% abs
+        function ret = abs(obj)
+            ret = obj;
+            ret.a = abs(obj.a);
+        end
+        %% <
+        function ret = lt(obj,B)
+            if isa(obj , 'hyper_dual')
+                if isa(B , 'hyper_dual')
+                    % obj and B is hyper-dual
+                    ret = obj.a < B.a;
+                else
+                    % if B is a real number
+                    ret = obj.a < B;
+                end
+            else
+                % B is hyper-dual
+                ret = obj < B.a;
+            end
+        end
+        %% <=
+        function ret = le(obj,B)
+            if isa(obj , 'hyper_dual')
+                if isa(B , 'hyper_dual')
+                    % obj and B is hyper-dual
+                    ret = obj.a <= B.a;
+                else
+                    % if B is a real number
+                    ret = obj.a <= B;
+                end
+            else
+                % B is hyper-dual
+                ret = obj <= B.a;
+            end
+        end
+        %% >
+        function ret = gt(obj,B)
+            if isa(obj , 'hyper_dual')
+                if isa(B , 'hyper_dual')
+                    % obj and B is hyper-dual
+                    ret = obj.a > B.a;
+                else
+                    % if B is a real number
+                    ret = obj.a > B;
+                end
+            else
+                % B is hyper-dual
+                ret = obj > B.a;
+            end
+        end
+        %% >=
+        function ret = ge(obj,B)
+            if isa(obj , 'hyper_dual')
+                if isa(B , 'hyper_dual')
+                    % obj and B is hyper-dual
+                    ret = obj.a >= B.a;
+                else
+                    % if B is a real number
+                    ret = obj.a >= B;
+                end
+            else
+                % B is hyper-dual
+                ret = obj >= B.a;
+            end
+        end
+        %% mod
+        function ret = mod(obj,B)
+            if isa(obj , 'hyper_dual')
+                if isa(B , 'hyper_dual')
+                    % obj and B is hyper-dual
+                    error('not implemented')
+                    ret = mod(obj.a , B.a);
+                else
+                    % if B is a real number
+                    ret = obj;
+                    ret.a = mod(obj.a , B);
+                end
+            else
+                % B is hyper-dual
+                error('not implemented')
+                ret = mod(obj , B.a);
+            end
+        end
+        %% sign
+        function ret = sign(obj)
+            ret.a = sign(obj.a);
+        end
+%         function ret = subsasgn(obj,s,a)
+%             if length(s) == 1
+%                 if s.type == '()'
+%                     error('not implemented')
+%                 else
+%                     ret = builtin('subsasgn', obj, s,a);
+%                 end
+%             elseif length(s) == 2  
+%                 if strcmp(s(1).type,'.') && strcmp(s(2).type,'()')
+%                     ret = builtin('subsasgn', obj, s,a);
+%                 else
+%                     error('not implemented')
+%                 end
+%             else
+%                 error('not implemented')
+%             end
+%         end
+        %function C = horzcat(obj,B)
+            
+            
+        %end
+        %%
+        %%
+        function n = numArgumentsFromSubscript(obj, s, ic)
+        n = builtin('numArgumentsFromSubscript', obj, s, ic);
+        end
+        
+    end
+    methods (Static)
+        %% zeros
+        function ret = zeros(numX,numY)
+            ret = hyper_dual( zeros(numX,numY) );
+        end
+    end
+    methods (Hidden)
+        %% zeros
+        function ret = zerosLike(obj,numX,numY)
+            if nargin == 2
+                ret = hyper_dual( zeros(numX) );
+            else
+                ret = hyper_dual( zeros(numX,numY) );
+            end
+        end
+        %% my_subs()
+        function ret = my_subs(obj,i,j)
+            ret = hyper_dual(obj.a(i,j),obj.b(i,j),obj.c(i,j),obj.d(i,j));
+        end
+    end
+end
+
+
diff --git a/matlab/hyper_dual_numbers/@hyper_dual/subsref.m b/matlab/hyper_dual_numbers/@hyper_dual/subsref.m
new file mode 100644
index 0000000..d02123b
--- /dev/null
+++ b/matlab/hyper_dual_numbers/@hyper_dual/subsref.m
@@ -0,0 +1,27 @@
+function ret = subsref(obj,s)
+ret = obj;
+if length(s) == 1
+    if strcmp(s.type,'()')
+        if length(s.subs) == 1
+            ret.a = obj.a(s.subs{1});
+            ret.b = obj.b(s.subs{1});
+            ret.c = obj.c(s.subs{1});
+            ret.d = obj.d(s.subs{1});
+        elseif length(s.subs) == 2
+            ret.a = obj.a(s.subs{1},s.subs{2});
+            ret.b = obj.b(s.subs{1},s.subs{2});
+            ret.c = obj.c(s.subs{1},s.subs{2});
+            ret.d = obj.d(s.subs{1},s.subs{2});
+        else
+            ret.a = builtin('subsref',obj.a,s);
+            ret.b = builtin('subsref',obj.b,s);
+            ret.c = builtin('subsref',obj.c,s);
+            ret.d = builtin('subsref',obj.d,s);
+        end
+    else
+        ret = builtin('subsref', obj, s);
+    end
+else
+    ret = builtin('subsref', obj, s);
+end
+end
\ No newline at end of file
diff --git a/matlab/hyper_dual_numbers/@hyper_dual/vertcat.m b/matlab/hyper_dual_numbers/@hyper_dual/vertcat.m
new file mode 100644
index 0000000..ac5d771
--- /dev/null
+++ b/matlab/hyper_dual_numbers/@hyper_dual/vertcat.m
@@ -0,0 +1,16 @@
+function ret = vertcat(varargin)
+obj = cat( 1, varargin{:} );
+ret = hyper_dual(cat(1,obj.a),cat(1,obj.b),cat(1,obj.c),cat(1,obj.d));
+% if isa(A , 'hyper_dual')
+%     if isa(B , 'hyper_dual')
+%         % obj and B is hyper-dual
+%         ret = hyper_dual([A.a;B.a],[A.b;B.b],[A.c;B.c],[A.d;B.d]);
+%     else
+%         % if B is a real number
+%         ret = hyper_dual([A.a;B],[A.b;0*B],[A.c;0*B],[A.d;0*B]);
+%     end
+% else
+%     % B is hyper-dual
+%     ret = hyper_dual([A;B.a],[0*A;B.b],[0*A;B.c],[0*A;B.d]);
+% end
+end
diff --git a/matlab/hyper_dual_numbers/test_fun_hyper_dual.m b/matlab/hyper_dual_numbers/test_fun_hyper_dual.m
new file mode 100644
index 0000000..1ab2b1d
--- /dev/null
+++ b/matlab/hyper_dual_numbers/test_fun_hyper_dual.m
@@ -0,0 +1,32 @@
+function test_fun_hyper_dual(f)
+%
+df = diff(f,'x');
+ddf= diff(df,'x');
+g = matlabFunction(f);
+%
+h1=0.000000001;
+x = rand(4);
+y = rand(4);
+A = hyper_dual(x,ones(4)*h1, ones(4)*h1, zeros(4) );
+ret = g(A,y);
+
+if double(norm(ret.b / h1 - df(x,y),'fro')) < 10^-13 && ...
+    double(norm(ret.d / h1^2 - ddf(x,y),'fro')) < 10^-10 && ...
+    double(norm(ret.a - f(x,y),'fro')) < 10^-13
+    disp('hyper_dual ok')
+elseif norm(ret.a - f(x,y),'fro') > 10^-13
+    disp('error in hyper_dual - evaluation') 
+    disp(double(norm(ret.a - f(x,y),'fro')))
+    
+elseif norm(ret.b / h1 - df(x,y),'fro') > 10^-13
+    disp('error in hyper_dual - first derivative')
+    disp(double(norm(ret.b / h1 - df(x,y),'fro')))
+    
+else
+    disp('error in hyper_dual - second derivative') 
+    disp(double(norm(ret.d / h1^2 - ddf(x,y),'fro')))
+    
+end
+
+
+end
diff --git a/matlab/hyper_dual_numbers/test_hyper_dual.m b/matlab/hyper_dual_numbers/test_hyper_dual.m
new file mode 100644
index 0000000..bc4186a
--- /dev/null
+++ b/matlab/hyper_dual_numbers/test_hyper_dual.m
@@ -0,0 +1,49 @@
+clear all
+close all
+whos 
+syms x y
+%
+disp('set 1')
+f1 = symfun( x.^2 + y.^2-1 , [x y]);
+test_fun_hyper_dual(f1)
+%
+f11 = symfun( 1 - x.^2 , [x y]);
+test_fun_hyper_dual(f11)
+%
+f12 = symfun( x + y , [x y]);
+test_fun_hyper_dual(f12)
+%
+f13 = symfun( -x .* y - y .* x , [x y]);
+test_fun_hyper_dual(f13)
+%
+f14 = symfun( x .* y + y .* x , [x y]);
+test_fun_hyper_dual(f14)
+%
+disp('set 2')
+f2 = symfun( x ./ y  , [x y]);
+test_fun_hyper_dual(f2)
+%
+f21 = symfun( -x ./ y , [x y]);
+test_fun_hyper_dual(f21)
+%
+f22 = symfun( -y ./ (x+1) , [x y]);
+test_fun_hyper_dual(f22)
+%
+f23 = symfun( x ./ y - y ./ x , [x y]);
+test_fun_hyper_dual(f23)
+%
+disp('set 3')
+f3 = symfun( sin(x) .* cos(x) .* exp(x) , [x y]);
+test_fun_hyper_dual(f3)
+%
+f31 = symfun( sinh(x) .*cosh(y.*x) .* sqrt(x.^y) , [x y]);
+test_fun_hyper_dual(f31)
+%
+disp('set 4')
+f41 = symfun( x.^3 , [x y]);
+test_fun_hyper_dual(f41)
+%
+f42 = symfun( y .* 3.^x , [x y]);
+test_fun_hyper_dual(f42)
+
+
-- 
GitLab