From 2589646e51ee3b02fa9b4b0b368407bbd7bb0e06 Mon Sep 17 00:00:00 2001 From: mhg <gfrerer@tugraz.at> Date: Thu, 7 Nov 2024 14:57:17 +0100 Subject: [PATCH] hyper dual numbers added --- .../hyper_dual_numbers/@hyper_dual/horzcat.m | 4 + .../@hyper_dual/hyper_dual.m | 439 ++++++++++++++++++ .../hyper_dual_numbers/@hyper_dual/subsref.m | 27 ++ .../hyper_dual_numbers/@hyper_dual/vertcat.m | 16 + .../hyper_dual_numbers/test_fun_hyper_dual.m | 32 ++ matlab/hyper_dual_numbers/test_hyper_dual.m | 49 ++ 6 files changed, 567 insertions(+) create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/horzcat.m create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/subsref.m create mode 100644 matlab/hyper_dual_numbers/@hyper_dual/vertcat.m create mode 100644 matlab/hyper_dual_numbers/test_fun_hyper_dual.m create mode 100644 matlab/hyper_dual_numbers/test_hyper_dual.m diff --git a/matlab/hyper_dual_numbers/@hyper_dual/horzcat.m b/matlab/hyper_dual_numbers/@hyper_dual/horzcat.m new file mode 100644 index 0000000..5edc8ee --- /dev/null +++ b/matlab/hyper_dual_numbers/@hyper_dual/horzcat.m @@ -0,0 +1,4 @@ +function ret = horzcat(varargin) +obj = cat( 1, varargin{:} ); +ret = hyper_dual([obj.a],[obj.b],[obj.c],[obj.d]); +end diff --git a/matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m b/matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m new file mode 100644 index 0000000..496fed4 --- /dev/null +++ b/matlab/hyper_dual_numbers/@hyper_dual/hyper_dual.m @@ -0,0 +1,439 @@ +classdef hyper_dual + %HYPER_DUAL + + properties + a % the number + b + c + d + end + + methods + function obj = hyper_dual(A,B,C,D) + if nargin == 1 + obj.a = zeros(size(A)); + obj.b = zeros(size(A)); + obj.c = zeros(size(A)); + obj.d = zeros(size(A)); + obj.a = A; + elseif nargin == 4 + if isequal(size(A),size(B)) && isequal(size(B),size(C)) && isequal(size(C),size(D)) + obj.a = A; + obj.b = B; + obj.c = C; + obj.d = D; + else + error('wrong input') + end + else + error('wrong number of input parameter in hyper_dual') + end + end + + %% double + function ret = double(obj) + ret = obj; + ret.a = double(obj.a); + ret.b = double(obj.b); + ret.c = double(obj.c); + ret.d = double(obj.d); + end + %% real + function ret = real(obj) + ret = obj.a; + end + %% imag + function ret = imag(obj) + ret = abs(imag(obj.b))+abs(imag(obj.c))+abs(imag(obj.d)); + end + %% reshape + function ret = reshape(obj,varargin) + ret = hyper_dual(... + reshape(obj.a,varargin{:}),... + reshape(obj.b,varargin{:}),... + reshape(obj.c,varargin{:}),... + reshape(obj.d,varargin{:})); + end + function ret = transpose(obj) + ret = hyper_dual(obj.a.',obj.b.',obj.c.',obj.d.'); + end + + %% addition + function ret = plus(obj,A) + if isa(obj , 'hyper_dual') + ret = obj; + if isa(A , 'hyper_dual') + ret.a = bsxfun(@plus,ret.a , A.a); % if A is hyper-dual + ret.b = bsxfun(@plus,ret.b , A.b); + ret.c = bsxfun(@plus,ret.c , A.c); + ret.d = bsxfun(@plus,ret.d , A.d); + else + ret.a = bsxfun(@plus,ret.a , A); + end + else + ret = A + obj; + end + end + %% subtraction + function ret = minus(obj,A) + if isa(obj , 'hyper_dual') + ret = obj; + if isa(A , 'hyper_dual') + ret.a = bsxfun(@minus,ret.a , A.a); % if A is hyper-dual + ret.b = bsxfun(@minus,ret.b , A.b); + ret.c = bsxfun(@minus,ret.c , A.c); + ret.d = bsxfun(@minus,ret.d , A.d); + else + ret.a = bsxfun(@minus,ret.a , A); + end + else + ret = - A + obj ; + end + end + %% minus + function ret = uminus(obj) + ret = obj; + f = -obj.a; + df = -1+0*obj.a; + ddf= 0*obj.a; + ret.a = f; % if A is hyper-dual + ret.b = obj.b .* df; + ret.c = obj.c .* df; + ret.d = obj.d .* df + obj.b .* obj.c .* ddf; + end + %% Multiplication + function ret = times(obj,A) + if isa(obj , 'hyper_dual') + ret = obj; + if isa(A , 'hyper_dual') + ret.a = bsxfun(@times,obj.a , A.a); % if A is hyper-dual + ret.b = bsxfun(@times,obj.a , A.b) + bsxfun(@times,obj.b , A.a); + ret.c = bsxfun(@times,obj.a , A.c) + bsxfun(@times,obj.c , A.a); + ret.d = bsxfun(@times,obj.a , A.d) + bsxfun(@times,obj.b , A.c) + bsxfun(@times,obj.c , A.b) + bsxfun(@times,obj.d , A.a); + else + ret = obj; + ret.a = bsxfun(@times,obj.a , A); % if A is a real number + ret.b = bsxfun(@times,obj.b , A); + ret.c = bsxfun(@times,obj.c , A); + ret.d = bsxfun(@times,obj.d , A); + end + else + ret = A .* obj; + end + end + %% Matrix Multiplication + function ret = reini(obj) + for i = 1:size(obj,1) + for j = 1:size(obj,2) + s1(i,j) = size(obj(i,j).a,1); + s2(i,j) = size(obj(i,j).a,2); + end + end + i1 = 0; + for i = 1:size(obj,1) + i2 = 0; + for j = 1:size(obj,2) + for k = 1:s1(i,j) + for l = 1:s2(i,j) + a(i1+k,i2+l) = obj(i,j).a(k,l); + b(i1+k,i2+l) = obj(i,j).b(k,l); + c(i1+k,i2+l) = obj(i,j).c(k,l); + d(i1+k,i2+l) = obj(i,j).d(k,l); + end + end + i2 = i2 + s2(i,j); + end + i1 = i1 + s1(i,j); + end + ret = hyper_dual(a,b,c,d); + end + %% + function ret = mtimes(obj,A) + if isa(obj , 'hyper_dual') + if numel(obj) > 1 + obj = reini(obj); + end + else + obj = hyper_dual(obj); + end + sO1 = size(obj.a,1); + sO2 = size(obj.a,2); + if isa(A , 'hyper_dual') + if numel(A) > 1 + A = reini(A); + end + else + A = hyper_dual(A); + end + sA1 = size(A.a,1); + sA2 = size(A.a,2); + if sO2 == sA1 + aa = zeros(sO1,sA2); + bb = zeros(sO1,sA2); + cc = zeros(sO1,sA2); + dd = zeros(sO1,sA2); + for i = 1:sO1 + for j = 1:sA2 + for k = 1:sO2 + r = my_subs(obj,i,k) .* my_subs(A,k,j); + aa(i,j) = aa(i,j) + r.a; + bb(i,j) = bb(i,j) + r.b; + cc(i,j) = cc(i,j) + r.c; + dd(i,j) = dd(i,j) + r.d; + end + end + end + ret = hyper_dual(aa,bb,cc,dd); + else + error('dimension mismatch') + end + end + %% DIVIDE + % rdivide ./ + % obj is nominator + % A is denominator + function ret = rdivide(obj,A) + if isa(obj , 'hyper_dual') + ret = obj; + if isa(A , 'hyper_dual') + df = bsxfun(@rdivide,-1, A.a.^2); + ddf= bsxfun(@rdivide,2, A.a.^3); + ret.a = bsxfun(@rdivide,1, A.a); % if A is hyper-dual + ret.b = A.b .* df; + ret.c = A.c .* df; + ret.d = A.d .* df + A.b .* A.c .* ddf; + ret = obj .* ret; + else + ret = obj .* (1 ./ A); + end + else + ret = hyper_dual(obj) ./ A; + end + end + + %% Equality + function ret = eq(obj,a) + if isa(obj , 'hyper_dual') + if isa(a , 'hyper_dual') + % if a is hyper-dual + ret = (obj.a == a.a) ; + else + ret = ( obj.a == a ); % if a is a real number + end + else + ret = ( obj == a.a ); + end + end + %% sin + function ret = sin(obj) + ret = obj; + ret.a = sin( obj.a ); + ret.b = obj.b .* cos(obj.a); + ret.c = obj.c .* cos(obj.a); + ret.d = obj.d .* cos(obj.a) - obj.b .* obj.c .* sin( obj.a ); + end + %% cos + function ret = cos(obj) + ret = obj; + ret.a = cos( obj.a ); + ret.b = -obj.b .* sin(obj.a); + ret.c = -obj.c .* sin(obj.a); + ret.d = -obj.d .* sin(obj.a) - obj.b .* obj.c .* cos( obj.a ); + end + %% exp + function ret = exp(obj) + ret = obj; + ret.a = exp( obj.a ); + ret.b = obj.b .* exp(obj.a); + ret.c = obj.c .* exp(obj.a); + ret.d = obj.d .* exp(obj.a) + obj.b .* obj.c .* exp(obj.a); + end + % power + function ret = power(obj,B) + if isa(obj , 'hyper_dual') + if isa(B , 'hyper_dual') + % if B is hyper-dual + error('not implemented') + else + % if B is a real number + ret = obj; + help = B .* obj.a .^ (B-1); + ret.a = obj.a .^ B; + ret.b = obj.b .* help; + ret.c = obj.c .* help; + ret.d = obj.d .* help + obj.b .* obj.c .* B .* (B-1) .* obj.a .^ (B-2); + end + else + ret = B; % B is hyper-dual + ret.a = obj .^ B.a; + ret.b = B.b .* obj .^ B.a .* log(obj) ; + ret.c = B.c .* obj .^ B.a .* log(obj) ; + ret.d = B.d .* obj .^ B.a .* log(obj) + B.b .* B.c .* obj .^ B.a .* log(obj)^2; + end + + end + %% sqrt + function ret = sqrt(obj) + ret = obj; + Df = 1 ./ ( 2 * sqrt(obj.a) ); + ret.a = sqrt(obj.a); + ret.b = obj.b .* Df; + ret.c = obj.c .* Df; + ret.d = obj.d .* Df - obj.b .* obj.c .* Df ./ ( 2 * obj.a ); + end + %% sinh + function ret = sinh(obj) + ret = obj; + Df = cosh(obj.a) ; + ret.a = sinh(obj.a); + ret.b = obj.b .* Df; + ret.c = obj.c .* Df; + ret.d = obj.d .* Df + obj.b .* obj.c .* sinh(obj.a); + end + %% cosh + function ret = cosh(obj) + ret = obj; + Df = sinh(obj.a) ; + ret.a = cosh(obj.a); + ret.b = obj.b .* Df; + ret.c = obj.c .* Df; + ret.d = obj.d .* Df + obj.b .* obj.c .* cosh(obj.a); + end + %% abs + function ret = abs(obj) + ret = obj; + ret.a = abs(obj.a); + end + %% < + function ret = lt(obj,B) + if isa(obj , 'hyper_dual') + if isa(B , 'hyper_dual') + % obj and B is hyper-dual + ret = obj.a < B.a; + else + % if B is a real number + ret = obj.a < B; + end + else + % B is hyper-dual + ret = obj < B.a; + end + end + %% <= + function ret = le(obj,B) + if isa(obj , 'hyper_dual') + if isa(B , 'hyper_dual') + % obj and B is hyper-dual + ret = obj.a <= B.a; + else + % if B is a real number + ret = obj.a <= B; + end + else + % B is hyper-dual + ret = obj <= B.a; + end + end + %% > + function ret = gt(obj,B) + if isa(obj , 'hyper_dual') + if isa(B , 'hyper_dual') + % obj and B is hyper-dual + ret = obj.a > B.a; + else + % if B is a real number + ret = obj.a > B; + end + else + % B is hyper-dual + ret = obj > B.a; + end + end + %% >= + function ret = ge(obj,B) + if isa(obj , 'hyper_dual') + if isa(B , 'hyper_dual') + % obj and B is hyper-dual + ret = obj.a >= B.a; + else + % if B is a real number + ret = obj.a >= B; + end + else + % B is hyper-dual + ret = obj >= B.a; + end + end + %% mod + function ret = mod(obj,B) + if isa(obj , 'hyper_dual') + if isa(B , 'hyper_dual') + % obj and B is hyper-dual + error('not implemented') + ret = mod(obj.a , B.a); + else + % if B is a real number + ret = obj; + ret.a = mod(obj.a , B); + end + else + % B is hyper-dual + error('not implemented') + ret = mod(obj , B.a); + end + end + %% sign + function ret = sign(obj) + ret.a = sign(obj.a); + end +% function ret = subsasgn(obj,s,a) +% if length(s) == 1 +% if s.type == '()' +% error('not implemented') +% else +% ret = builtin('subsasgn', obj, s,a); +% end +% elseif length(s) == 2 +% if strcmp(s(1).type,'.') && strcmp(s(2).type,'()') +% ret = builtin('subsasgn', obj, s,a); +% else +% error('not implemented') +% end +% else +% error('not implemented') +% end +% end + %function C = horzcat(obj,B) + + + %end + %% + %% + function n = numArgumentsFromSubscript(obj, s, ic) + n = builtin('numArgumentsFromSubscript', obj, s, ic); + end + + end + methods (Static) + %% zeros + function ret = zeros(numX,numY) + ret = hyper_dual( zeros(numX,numY) ); + end + end + methods (Hidden) + %% zeros + function ret = zerosLike(obj,numX,numY) + if nargin == 2 + ret = hyper_dual( zeros(numX) ); + else + ret = hyper_dual( zeros(numX,numY) ); + end + end + %% my_subs() + function ret = my_subs(obj,i,j) + ret = hyper_dual(obj.a(i,j),obj.b(i,j),obj.c(i,j),obj.d(i,j)); + end + end +end + + diff --git a/matlab/hyper_dual_numbers/@hyper_dual/subsref.m b/matlab/hyper_dual_numbers/@hyper_dual/subsref.m new file mode 100644 index 0000000..d02123b --- /dev/null +++ b/matlab/hyper_dual_numbers/@hyper_dual/subsref.m @@ -0,0 +1,27 @@ +function ret = subsref(obj,s) +ret = obj; +if length(s) == 1 + if strcmp(s.type,'()') + if length(s.subs) == 1 + ret.a = obj.a(s.subs{1}); + ret.b = obj.b(s.subs{1}); + ret.c = obj.c(s.subs{1}); + ret.d = obj.d(s.subs{1}); + elseif length(s.subs) == 2 + ret.a = obj.a(s.subs{1},s.subs{2}); + ret.b = obj.b(s.subs{1},s.subs{2}); + ret.c = obj.c(s.subs{1},s.subs{2}); + ret.d = obj.d(s.subs{1},s.subs{2}); + else + ret.a = builtin('subsref',obj.a,s); + ret.b = builtin('subsref',obj.b,s); + ret.c = builtin('subsref',obj.c,s); + ret.d = builtin('subsref',obj.d,s); + end + else + ret = builtin('subsref', obj, s); + end +else + ret = builtin('subsref', obj, s); +end +end \ No newline at end of file diff --git a/matlab/hyper_dual_numbers/@hyper_dual/vertcat.m b/matlab/hyper_dual_numbers/@hyper_dual/vertcat.m new file mode 100644 index 0000000..ac5d771 --- /dev/null +++ b/matlab/hyper_dual_numbers/@hyper_dual/vertcat.m @@ -0,0 +1,16 @@ +function ret = vertcat(varargin) +obj = cat( 1, varargin{:} ); +ret = hyper_dual(cat(1,obj.a),cat(1,obj.b),cat(1,obj.c),cat(1,obj.d)); +% if isa(A , 'hyper_dual') +% if isa(B , 'hyper_dual') +% % obj and B is hyper-dual +% ret = hyper_dual([A.a;B.a],[A.b;B.b],[A.c;B.c],[A.d;B.d]); +% else +% % if B is a real number +% ret = hyper_dual([A.a;B],[A.b;0*B],[A.c;0*B],[A.d;0*B]); +% end +% else +% % B is hyper-dual +% ret = hyper_dual([A;B.a],[0*A;B.b],[0*A;B.c],[0*A;B.d]); +% end +end diff --git a/matlab/hyper_dual_numbers/test_fun_hyper_dual.m b/matlab/hyper_dual_numbers/test_fun_hyper_dual.m new file mode 100644 index 0000000..1ab2b1d --- /dev/null +++ b/matlab/hyper_dual_numbers/test_fun_hyper_dual.m @@ -0,0 +1,32 @@ +function test_fun_hyper_dual(f) +% +df = diff(f,'x'); +ddf= diff(df,'x'); +g = matlabFunction(f); +% +h1=0.000000001; +x = rand(4); +y = rand(4); +A = hyper_dual(x,ones(4)*h1, ones(4)*h1, zeros(4) ); +ret = g(A,y); + +if double(norm(ret.b / h1 - df(x,y),'fro')) < 10^-13 && ... + double(norm(ret.d / h1^2 - ddf(x,y),'fro')) < 10^-10 && ... + double(norm(ret.a - f(x,y),'fro')) < 10^-13 + disp('hyper_dual ok') +elseif norm(ret.a - f(x,y),'fro') > 10^-13 + disp('error in hyper_dual - evaluation') + disp(double(norm(ret.a - f(x,y),'fro'))) + +elseif norm(ret.b / h1 - df(x,y),'fro') > 10^-13 + disp('error in hyper_dual - first derivative') + disp(double(norm(ret.b / h1 - df(x,y),'fro'))) + +else + disp('error in hyper_dual - second derivative') + disp(double(norm(ret.d / h1^2 - ddf(x,y),'fro'))) + +end + + +end diff --git a/matlab/hyper_dual_numbers/test_hyper_dual.m b/matlab/hyper_dual_numbers/test_hyper_dual.m new file mode 100644 index 0000000..bc4186a --- /dev/null +++ b/matlab/hyper_dual_numbers/test_hyper_dual.m @@ -0,0 +1,49 @@ +clear all +close all +whos +syms x y +% +disp('set 1') +f1 = symfun( x.^2 + y.^2-1 , [x y]); +test_fun_hyper_dual(f1) +% +f11 = symfun( 1 - x.^2 , [x y]); +test_fun_hyper_dual(f11) +% +f12 = symfun( x + y , [x y]); +test_fun_hyper_dual(f12) +% +f13 = symfun( -x .* y - y .* x , [x y]); +test_fun_hyper_dual(f13) +% +f14 = symfun( x .* y + y .* x , [x y]); +test_fun_hyper_dual(f14) +% +disp('set 2') +f2 = symfun( x ./ y , [x y]); +test_fun_hyper_dual(f2) +% +f21 = symfun( -x ./ y , [x y]); +test_fun_hyper_dual(f21) +% +f22 = symfun( -y ./ (x+1) , [x y]); +test_fun_hyper_dual(f22) +% +f23 = symfun( x ./ y - y ./ x , [x y]); +test_fun_hyper_dual(f23) +% +disp('set 3') +f3 = symfun( sin(x) .* cos(x) .* exp(x) , [x y]); +test_fun_hyper_dual(f3) +% +f31 = symfun( sinh(x) .*cosh(y.*x) .* sqrt(x.^y) , [x y]); +test_fun_hyper_dual(f31) +% +disp('set 4') +f41 = symfun( x.^3 , [x y]); +test_fun_hyper_dual(f41) +% +f42 = symfun( y .* 3.^x , [x y]); +test_fun_hyper_dual(f42) + + -- GitLab