Logo ROOT  
Reference Guide
 
Loading...
Searching...
No Matches
CladDerivator.h
Go to the documentation of this file.
1/// \file CladDerivator.h
2///
3/// \brief The file is a bridge between ROOT and clad automatic differentiation
4/// plugin.
5///
6/// \author Vassil Vassilev <vvasilev@cern.ch>
7///
8/// \date July, 2018
9
10/*************************************************************************
11 * Copyright (C) 1995-2018, Rene Brun and Fons Rademakers. *
12 * All rights reserved. *
13 * *
14 * For the licensing terms see $ROOTSYS/LICENSE. *
15 * For the list of contributors see $ROOTSYS/README/CREDITS. *
16 *************************************************************************/
17
18#ifndef CLAD_DERIVATOR
19#define CLAD_DERIVATOR
20
21#ifndef __CLING__
22#error "This file must not be included by compiled programs."
23#endif //__CLING__
24
25#include <plugins/include/clad/Differentiator/Differentiator.h>
26#include "TMath.h"
27
28// For the digamma function, that is the derivative of lgamma. We get it via
29// mathmore from the GSL, so the pullbacks that use digamma are only available
30// with mathmore=ON.
31#ifdef R__HAS_MATHMORE
33#endif
34
35#include <stdexcept>
36
37namespace clad {
38namespace custom_derivatives {
39namespace TMath {
40template <typename T>
42{
43 return {::TMath::Abs(x), ((x < 0) ? -1 : 1) * d_x};
44}
45
46template <typename T>
48{
49 return {::TMath::ACos(x), (-1. / ::TMath::Sqrt(1 - x * x)) * d_x};
50}
51
52template <typename T>
54{
55 return {::TMath::ACosH(x), (1. / ::TMath::Sqrt(x * x - 1)) * d_x};
56}
57
58template <typename T>
60{
61 return {::TMath::ASin(x), (1. / ::TMath::Sqrt(1 - x * x)) * d_x};
62}
63
64template <typename T>
66{
67 return {::TMath::ASinH(x), (1. / ::TMath::Sqrt(x * x + 1)) * d_x};
68}
69
70template <typename T>
72{
73 return {::TMath::ATan(x), (1. / (x * x + 1)) * d_x};
74}
75
76template <typename T>
78{
79 return {::TMath::ATanH(x), (1. / (1 - x * x)) * d_x};
80}
81
82template <typename T>
87
88template <typename T>
93
94template <typename T>
99
100template <typename T>
102{
103 return {::TMath::Erfc(x), -Erf_pushforward(x, d_x).pushforward};
104}
105
106#ifdef R__HAS_MATHMORE
107
108template <typename T>
110{
112}
113
114#endif
115
116template <typename T>
121
122template <typename T>
127
128template <typename T, typename U>
129void Hypot_pullback(T x, T y, U p, clad::array_ref<T> d_x, clad::array_ref<T> d_y)
130{
131 T h = ::TMath::Hypot(x, y);
132 *d_x += x / h * p;
133 *d_y += y / h * p;
134}
135
136template <typename T>
138{
139 return {::TMath::Log(x), (1. / x) * d_x};
140}
141
142template <typename T>
144{
145 return {::TMath::Log10(x), (1.0 / (x * ::TMath::Ln10())) * d_x};
146}
147
148template <typename T>
150{
151 return {::TMath::Log2(x), (1.0 / (x * ::TMath::Log(2.0))) * d_x};
152}
153
154template <typename T>
156{
157 T derivative = 0;
158 if (x >= y)
159 derivative = d_x;
160 else
161 derivative = d_y;
162 return {::TMath::Max(x, y), derivative};
163}
164
165template <typename T, typename U>
166void Max_pullback(T a, T b, U p, clad::array_ref<T> d_a, clad::array_ref<T> d_b)
167{
168 if (a >= b)
169 *d_a += p;
170 else
171 *d_b += p;
172}
173
174template <typename T>
176{
177 T derivative = 0;
178 if (x <= y)
179 derivative = d_x;
180 else
181 derivative = d_y;
182 return {::TMath::Min(x, y), derivative};
183}
184
185template <typename T, typename U>
186void Min_pullback(T a, T b, U p, clad::array_ref<T> d_a, clad::array_ref<T> d_b)
187{
188 if (a <= b)
189 *d_a += p;
190 else
191 *d_b += p;
192}
193
194template <typename T>
196{
197 T pushforward = y * ::TMath::Power(x, y - 1) * d_x;
198 if (d_y) {
200 }
201 return {::TMath::Power(x, y), pushforward};
202}
203
204template <typename T, typename U>
205void Power_pullback(T x, T y, U p, clad::array_ref<T> d_x, clad::array_ref<T> d_y)
206{
207 auto t = pow_pushforward(x, y, 1, 0);
208 *d_x += t.pushforward * p;
209 t = pow_pushforward(x, y, 0, 1);
210 *d_y += t.pushforward * p;
211}
212
213template <typename T>
218
219template <typename T>
224
225template <typename T>
227{
228 return {::TMath::Sq(x), 2 * x * d_x};
229}
230
231template <typename T>
236
237template <typename T>
242
243template <typename T>
248
249#ifdef WIN32
250// Additional custom derivatives that can be removed
251// after Issue #12108 in ROOT is resolved
252// constexpr is removed
254{
255 return {3.1415926535897931, 0.};
256}
257// constexpr is removed
259{
260 return {2.3025850929940459, 0.};
261}
262#endif
263} // namespace TMath
264
265namespace ROOT {
266namespace Math {
267
268inline void landau_pdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
269{
270 if (xi <= 0) {
271 return;
272 }
273 // clang-format off
274 static double p1[5] = {0.4259894875,-0.1249762550, 0.03984243700, -0.006298287635, 0.001511162253};
275 static double q1[5] = {1.0 ,-0.3388260629, 0.09594393323, -0.01608042283, 0.003778942063};
276
277 static double p2[5] = {0.1788541609, 0.1173957403, 0.01488850518, -0.001394989411, 0.0001283617211};
278 static double q2[5] = {1.0 , 0.7428795082, 0.3153932961, 0.06694219548, 0.008790609714};
279
280 static double p3[5] = {0.1788544503, 0.09359161662,0.006325387654, 0.00006611667319,-0.000002031049101};
281 static double q3[5] = {1.0 , 0.6097809921, 0.2560616665, 0.04746722384, 0.006957301675};
282
283 static double p4[5] = {0.9874054407, 118.6723273, 849.2794360, -743.7792444, 427.0262186};
284 static double q4[5] = {1.0 , 106.8615961, 337.6496214, 2016.712389, 1597.063511};
285
286 static double p5[5] = {1.003675074, 167.5702434, 4789.711289, 21217.86767, -22324.94910};
287 static double q5[5] = {1.0 , 156.9424537, 3745.310488, 9834.698876, 66924.28357};
288
289 static double p6[5] = {1.000827619, 664.9143136, 62972.92665, 475554.6998, -5743609.109};
290 static double q6[5] = {1.0 , 651.4101098, 56974.73333, 165917.4725, -2815759.939};
291
292 static double a1[3] = {0.04166666667,-0.01996527778, 0.02709538966};
293
294 static double a2[2] = {-1.845568670,-4.284640743};
295 // clang-format on
296 const double _const0 = 0.3989422803;
297 double v = (x - x0) / xi;
298 double _d_v = 0;
299 double _d_denlan = 0;
300 if (v < -5.5) {
301 double u = ::std::exp(v + 1.);
302 double _d_u = 0;
303 if (u >= 1.e-10) {
304 const double ue = ::std::exp(-1 / u);
305 const double us = ::std::sqrt(u);
306 double _t3;
307 double _d_ue = 0;
308 double _d_us = 0;
309 double denlan = _const0 * (ue / us) * (1 + (a1[0] + (a1[1] + a1[2] * u) * u) * u);
310 _d_denlan += d_out / xi;
311 *d_xi += d_out * -(denlan / (xi * xi));
312 denlan = _t3;
313 double _r_d3 = _d_denlan;
314 _d_denlan -= _r_d3;
315 _d_ue += _const0 * _r_d3 * (1 + (a1[0] + (a1[1] + a1[2] * u) * u) * u) / us;
316 double _r5 = _const0 * _r_d3 * (1 + (a1[0] + (a1[1] + a1[2] * u) * u) * u) * -(ue / (us * us));
317 _d_us += _r5;
318 _d_u += a1[2] * _const0 * (ue / us) * _r_d3 * u * u;
319 _d_u += (a1[1] + a1[2] * u) * _const0 * (ue / us) * _r_d3 * u;
320 _d_u += (a1[0] + (a1[1] + a1[2] * u) * u) * _const0 * (ue / us) * _r_d3;
321 double _r_d2 = _d_us;
322 _d_us -= _r_d2;
323 double _r4 = 0;
324 _r4 += _r_d2 * clad::custom_derivatives::sqrt_pushforward(u, 1.).pushforward;
325 _d_u += _r4;
326 double _r_d1 = _d_ue;
327 _d_ue -= _r_d1;
328 double _r2 = 0;
329 _r2 += _r_d1 * ::std::exp(-1 / u);
330 double _r3 = _r2 * -(-1 / (u * u));
331 _d_u += _r3;
332 }
333 double _r_d0 = _d_u;
334 _d_u -= _r_d0;
335 double _r1 = 0;
336 _r1 += _r_d0 * ::std::exp(v + 1.);
337 _d_v += _r1;
338 } else if (v < -1) {
339 double _t4;
340 double u = ::std::exp(-v - 1);
341 double _d_u = 0;
342 double _t5;
343 double _t8 = ::std::exp(-u);
344 double _t7 = ::std::sqrt(u);
345 double _t6 = (q1[0] + (q1[1] + (q1[2] + (q1[3] + q1[4] * v) * v) * v) * v);
346 double denlan = _t8 * _t7 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) / _t6;
347 _d_denlan += d_out / xi;
348 *d_xi += d_out * -(denlan / (xi * xi));
349 denlan = _t5;
350 double _r_d5 = _d_denlan;
351 _d_denlan -= _r_d5;
352 double _r7 = 0;
353 _r7 += _r_d5 / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) * _t7 * ::std::exp(-u);
354 _d_u += -_r7;
355 double _r8 = 0;
356 _r8 += _t8 * _r_d5 / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) *
357 clad::custom_derivatives::sqrt_pushforward(u, 1.).pushforward;
358 _d_u += _r8;
359 _d_v += p1[4] * _t8 * _t7 * _r_d5 / _t6 * v * v * v;
360 _d_v += (p1[3] + p1[4] * v) * _t8 * _t7 * _r_d5 / _t6 * v * v;
361 _d_v += (p1[2] + (p1[3] + p1[4] * v) * v) * _t8 * _t7 * _r_d5 / _t6 * v;
362 _d_v += (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * _t8 * _t7 * _r_d5 / _t6;
363 double _r9 = _r_d5 * -(_t8 * _t7 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) / (_t6 * _t6));
364 _d_v += q1[4] * _r9 * v * v * v;
365 _d_v += (q1[3] + q1[4] * v) * _r9 * v * v;
366 _d_v += (q1[2] + (q1[3] + q1[4] * v) * v) * _r9 * v;
367 _d_v += (q1[1] + (q1[2] + (q1[3] + q1[4] * v) * v) * v) * _r9;
368 u = _t4;
369 double _r_d4 = _d_u;
370 _d_u -= _r_d4;
371 double _r6 = 0;
372 _r6 += _r_d4 * ::std::exp(-v - 1);
373 _d_v += -_r6;
374 } else if (v < 1) {
375 double _t9;
376 double _t10 = (q2[0] + (q2[1] + (q2[2] + (q2[3] + q2[4] * v) * v) * v) * v);
377 double denlan = (p2[0] + (p2[1] + (p2[2] + (p2[3] + p2[4] * v) * v) * v) * v) / _t10;
378 _d_denlan += d_out / xi;
379 *d_xi += d_out * -(denlan / (xi * xi));
380 denlan = _t9;
381 double _r_d6 = _d_denlan;
382 _d_denlan -= _r_d6;
383 _d_v += p2[4] * _r_d6 / _t10 * v * v * v;
384 _d_v += (p2[3] + p2[4] * v) * _r_d6 / _t10 * v * v;
385 _d_v += (p2[2] + (p2[3] + p2[4] * v) * v) * _r_d6 / _t10 * v;
386 _d_v += (p2[1] + (p2[2] + (p2[3] + p2[4] * v) * v) * v) * _r_d6 / _t10;
387 double _r10 = _r_d6 * -((p2[0] + (p2[1] + (p2[2] + (p2[3] + p2[4] * v) * v) * v) * v) / (_t10 * _t10));
388 _d_v += q2[4] * _r10 * v * v * v;
389 _d_v += (q2[3] + q2[4] * v) * _r10 * v * v;
390 _d_v += (q2[2] + (q2[3] + q2[4] * v) * v) * _r10 * v;
391 _d_v += (q2[1] + (q2[2] + (q2[3] + q2[4] * v) * v) * v) * _r10;
392 } else if (v < 5) {
393 double _t11;
394 double _t12 = (q3[0] + (q3[1] + (q3[2] + (q3[3] + q3[4] * v) * v) * v) * v);
395 double denlan = (p3[0] + (p3[1] + (p3[2] + (p3[3] + p3[4] * v) * v) * v) * v) / _t12;
396 _d_denlan += d_out / xi;
397 *d_xi += d_out * -(denlan / (xi * xi));
398 denlan = _t11;
399 double _r_d7 = _d_denlan;
400 _d_denlan -= _r_d7;
401 _d_v += p3[4] * _r_d7 / _t12 * v * v * v;
402 _d_v += (p3[3] + p3[4] * v) * _r_d7 / _t12 * v * v;
403 _d_v += (p3[2] + (p3[3] + p3[4] * v) * v) * _r_d7 / _t12 * v;
404 _d_v += (p3[1] + (p3[2] + (p3[3] + p3[4] * v) * v) * v) * _r_d7 / _t12;
405 double _r11 = _r_d7 * -((p3[0] + (p3[1] + (p3[2] + (p3[3] + p3[4] * v) * v) * v) * v) / (_t12 * _t12));
406 _d_v += q3[4] * _r11 * v * v * v;
407 _d_v += (q3[3] + q3[4] * v) * _r11 * v * v;
408 _d_v += (q3[2] + (q3[3] + q3[4] * v) * v) * _r11 * v;
409 _d_v += (q3[1] + (q3[2] + (q3[3] + q3[4] * v) * v) * v) * _r11;
410 } else if (v < 12) {
411 double u = 1 / v;
412 double _d_u = 0;
413 double _t14;
414 double _t15 = (q4[0] + (q4[1] + (q4[2] + (q4[3] + q4[4] * u) * u) * u) * u);
415 double denlan = u * u * (p4[0] + (p4[1] + (p4[2] + (p4[3] + p4[4] * u) * u) * u) * u) / _t15;
416 _d_denlan += d_out / xi;
417 *d_xi += d_out * -(denlan / (xi * xi));
418 denlan = _t14;
419 double _r_d9 = _d_denlan;
420 _d_denlan -= _r_d9;
421 _d_u += _r_d9 / _t15 * (p4[0] + (p4[1] + (p4[2] + (p4[3] + p4[4] * u) * u) * u) * u) * u;
422 _d_u += u * _r_d9 / _t15 * (p4[0] + (p4[1] + (p4[2] + (p4[3] + p4[4] * u) * u) * u) * u);
423 _d_u += p4[4] * u * u * _r_d9 / _t15 * u * u * u;
424 _d_u += (p4[3] + p4[4] * u) * u * u * _r_d9 / _t15 * u * u;
425 _d_u += (p4[2] + (p4[3] + p4[4] * u) * u) * u * u * _r_d9 / _t15 * u;
426 _d_u += (p4[1] + (p4[2] + (p4[3] + p4[4] * u) * u) * u) * u * u * _r_d9 / _t15;
427 double _r13 = _r_d9 * -(u * u * (p4[0] + (p4[1] + (p4[2] + (p4[3] + p4[4] * u) * u) * u) * u) / (_t15 * _t15));
428 _d_u += q4[4] * _r13 * u * u * u;
429 _d_u += (q4[3] + q4[4] * u) * _r13 * u * u;
430 _d_u += (q4[2] + (q4[3] + q4[4] * u) * u) * _r13 * u;
431 _d_u += (q4[1] + (q4[2] + (q4[3] + q4[4] * u) * u) * u) * _r13;
432 double _r_d8 = _d_u;
433 _d_u -= _r_d8;
434 double _r12 = _r_d8 * -(1 / (v * v));
435 _d_v += _r12;
436 } else if (v < 50) {
437 double u = 1 / v;
438 double _d_u = 0;
439 double _t17;
440 double _t18 = (q5[0] + (q5[1] + (q5[2] + (q5[3] + q5[4] * u) * u) * u) * u);
441 double denlan = u * u * (p5[0] + (p5[1] + (p5[2] + (p5[3] + p5[4] * u) * u) * u) * u) / _t18;
442 _d_denlan += d_out / xi;
443 *d_xi += d_out * -(denlan / (xi * xi));
444 denlan = _t17;
445 double _r_d11 = _d_denlan;
446 _d_denlan -= _r_d11;
447 _d_u += _r_d11 / _t18 * (p5[0] + (p5[1] + (p5[2] + (p5[3] + p5[4] * u) * u) * u) * u) * u;
448 _d_u += u * _r_d11 / _t18 * (p5[0] + (p5[1] + (p5[2] + (p5[3] + p5[4] * u) * u) * u) * u);
449 _d_u += p5[4] * u * u * _r_d11 / _t18 * u * u * u;
450 _d_u += (p5[3] + p5[4] * u) * u * u * _r_d11 / _t18 * u * u;
451 _d_u += (p5[2] + (p5[3] + p5[4] * u) * u) * u * u * _r_d11 / _t18 * u;
452 _d_u += (p5[1] + (p5[2] + (p5[3] + p5[4] * u) * u) * u) * u * u * _r_d11 / _t18;
453 double _r15 = _r_d11 * -(u * u * (p5[0] + (p5[1] + (p5[2] + (p5[3] + p5[4] * u) * u) * u) * u) / (_t18 * _t18));
454 _d_u += q5[4] * _r15 * u * u * u;
455 _d_u += (q5[3] + q5[4] * u) * _r15 * u * u;
456 _d_u += (q5[2] + (q5[3] + q5[4] * u) * u) * _r15 * u;
457 _d_u += (q5[1] + (q5[2] + (q5[3] + q5[4] * u) * u) * u) * _r15;
458 double _r_d10 = _d_u;
459 _d_u -= _r_d10;
460 double _r14 = _r_d10 * -(1 / (v * v));
461 _d_v += _r14;
462 } else if (v < 300) {
463 double _t19;
464 double u = 1 / v;
465 double _d_u = 0;
466 double _t20;
467 double _t21 = (q6[0] + (q6[1] + (q6[2] + (q6[3] + q6[4] * u) * u) * u) * u);
468 double denlan = u * u * (p6[0] + (p6[1] + (p6[2] + (p6[3] + p6[4] * u) * u) * u) * u) / _t21;
469 _d_denlan += d_out / xi;
470 *d_xi += d_out * -(denlan / (xi * xi));
471 denlan = _t20;
472 double _r_d13 = _d_denlan;
473 _d_denlan -= _r_d13;
474 _d_u += _r_d13 / _t21 * (p6[0] + (p6[1] + (p6[2] + (p6[3] + p6[4] * u) * u) * u) * u) * u;
475 _d_u += u * _r_d13 / _t21 * (p6[0] + (p6[1] + (p6[2] + (p6[3] + p6[4] * u) * u) * u) * u);
476 _d_u += p6[4] * u * u * _r_d13 / _t21 * u * u * u;
477 _d_u += (p6[3] + p6[4] * u) * u * u * _r_d13 / _t21 * u * u;
478 _d_u += (p6[2] + (p6[3] + p6[4] * u) * u) * u * u * _r_d13 / _t21 * u;
479 _d_u += (p6[1] + (p6[2] + (p6[3] + p6[4] * u) * u) * u) * u * u * _r_d13 / _t21;
480 double _r17 = _r_d13 * -(u * u * (p6[0] + (p6[1] + (p6[2] + (p6[3] + p6[4] * u) * u) * u) * u) / (_t21 * _t21));
481 _d_u += q6[4] * _r17 * u * u * u;
482 _d_u += (q6[3] + q6[4] * u) * _r17 * u * u;
483 _d_u += (q6[2] + (q6[3] + q6[4] * u) * u) * _r17 * u;
484 _d_u += (q6[1] + (q6[2] + (q6[3] + q6[4] * u) * u) * u) * _r17;
485 u = _t19;
486 double _r_d12 = _d_u;
487 _d_u -= _r_d12;
488 double _r16 = _r_d12 * -(1 / (v * v));
489 _d_v += _r16;
490 } else {
491 double _t22;
492 double _t25 = ::std::log(v);
493 double _t24 = (v + 1);
494 double _t23 = (v - v * _t25 / _t24);
495 double u = 1 / _t23;
496 double _d_u = 0;
497 double _t26;
498 double denlan = u * u * (1 + (a2[0] + a2[1] * u) * u);
499 _d_denlan += d_out / xi;
500 *d_xi += d_out * -(denlan / (xi * xi));
501 denlan = _t26;
502 double _r_d15 = _d_denlan;
503 _d_denlan -= _r_d15;
504 _d_u += _r_d15 * (1 + (a2[0] + a2[1] * u) * u) * u;
505 _d_u += u * _r_d15 * (1 + (a2[0] + a2[1] * u) * u);
506 _d_u += a2[1] * u * u * _r_d15 * u;
507 _d_u += (a2[0] + a2[1] * u) * u * u * _r_d15;
508 u = _t22;
509 double _r_d14 = _d_u;
510 _d_u -= _r_d14;
511 double _r18 = _r_d14 * -(1 / (_t23 * _t23));
512 _d_v += _r18;
513 _d_v += -_r18 / _t24 * _t25;
514 double _r19 = 0;
515 _r19 += v * -_r18 / _t24 / v;
516 _d_v += _r19;
517 double _r20 = -_r18 * -(v * _t25 / (_t24 * _t24));
518 _d_v += _r20;
519 }
520 *d_x += _d_v / xi;
521 *d_x0 += -_d_v / xi;
522 double _r0 = _d_v * -((x - x0) / (xi * xi));
523 *d_xi += _r0;
524}
525
526inline void landau_cdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
527{
528 // clang-format off
529 static double p1[5] = {0.2514091491e+0,-0.6250580444e-1, 0.1458381230e-1,-0.2108817737e-2, 0.7411247290e-3};
530 static double q1[5] = {1.0 ,-0.5571175625e-2, 0.6225310236e-1,-0.3137378427e-2, 0.1931496439e-2};
531
532 static double p2[4] = {0.2868328584e+0, 0.3564363231e+0, 0.1523518695e+0, 0.2251304883e-1};
533 static double q2[4] = {1.0 , 0.6191136137e+0, 0.1720721448e+0, 0.2278594771e-1};
534
535 static double p3[4] = {0.2868329066e+0, 0.3003828436e+0, 0.9950951941e-1, 0.8733827185e-2};
536 static double q3[4] = {1.0 , 0.4237190502e+0, 0.1095631512e+0, 0.8693851567e-2};
537
538 static double p4[4] = {0.1000351630e+1, 0.4503592498e+1, 0.1085883880e+2, 0.7536052269e+1};
539 static double q4[4] = {1.0 , 0.5539969678e+1, 0.1933581111e+2, 0.2721321508e+2};
540
541 static double p5[4] = {0.1000006517e+1, 0.4909414111e+2, 0.8505544753e+2, 0.1532153455e+3};
542 static double q5[4] = {1.0 , 0.5009928881e+2, 0.1399819104e+3, 0.4200002909e+3};
543
544 static double p6[4] = {0.1000000983e+1, 0.1329868456e+3, 0.9162149244e+3,-0.9605054274e+3};
545 static double q6[4] = {1.0 , 0.1339887843e+3, 0.1055990413e+4, 0.5532224619e+3};
546
547 static double a1[4] = {0 ,-0.4583333333e+0, 0.6675347222e+0,-0.1641741416e+1};
548 static double a2[4] = {0 , 1.0 ,-0.4227843351e+0,-0.2043403138e+1};
549 // clang-format on
550
551 const double v = (x - x0) / xi;
552 double _d_v = 0;
553 if (v < -5.5) {
554 double _d_u = 0;
555 const double _const0 = 0.3989422803;
556 double u = ::std::exp(v + 1);
557 double _t3 = ::std::exp(-1. / u);
558 double _t2 = ::std::sqrt(u);
559 double _r2 = 0;
560 _r2 += _const0 * d_out * (1 + (a1[1] + (a1[2] + a1[3] * u) * u) * u) * _t2 * ::std::exp(-1. / u);
561 double _r3 = _r2 * -(-1. / (u * u));
562 _d_u += _r3;
563 double _r4 = 0;
564 _r4 += _const0 * _t3 * d_out * (1 + (a1[1] + (a1[2] + a1[3] * u) * u) * u) *
565 clad::custom_derivatives::sqrt_pushforward(u, 1.).pushforward;
566 _d_u += _r4;
567 _d_u += a1[3] * _const0 * _t3 * _t2 * d_out * u * u;
568 _d_u += (a1[2] + a1[3] * u) * _const0 * _t3 * _t2 * d_out * u;
569 _d_u += (a1[1] + (a1[2] + a1[3] * u) * u) * _const0 * _t3 * _t2 * d_out;
570 _d_v += _d_u * ::std::exp(v + 1);
571 } else if (v < -1) {
572 double _d_u = 0;
573 double u = ::std::exp(-v - 1);
574 double _t8 = ::std::exp(-u);
575 double _t7 = ::std::sqrt(u);
576 double _t6 = (q1[0] + (q1[1] + (q1[2] + (q1[3] + q1[4] * v) * v) * v) * v);
577 double _r6 = 0;
578 _r6 += d_out / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) / _t7 * ::std::exp(-u);
579 _d_u += -_r6;
580 double _r7 = d_out / _t6 * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) * -(_t8 / (_t7 * _t7));
581 double _r8 = 0;
582 _r8 += _r7 * clad::custom_derivatives::sqrt_pushforward(u, 1.).pushforward;
583 _d_u += _r8;
584 _d_v += p1[4] * (_t8 / _t7) * d_out / _t6 * v * v * v;
585 _d_v += (p1[3] + p1[4] * v) * (_t8 / _t7) * d_out / _t6 * v * v;
586 _d_v += (p1[2] + (p1[3] + p1[4] * v) * v) * (_t8 / _t7) * d_out / _t6 * v;
587 _d_v += (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * (_t8 / _t7) * d_out / _t6;
588 double _r9 = d_out * -((_t8 / _t7) * (p1[0] + (p1[1] + (p1[2] + (p1[3] + p1[4] * v) * v) * v) * v) / (_t6 * _t6));
589 _d_v += q1[4] * _r9 * v * v * v;
590 _d_v += (q1[3] + q1[4] * v) * _r9 * v * v;
591 _d_v += (q1[2] + (q1[3] + q1[4] * v) * v) * _r9 * v;
592 _d_v += (q1[1] + (q1[2] + (q1[3] + q1[4] * v) * v) * v) * _r9;
593 _d_v += -_d_u * ::std::exp(-v - 1);
594 } else if (v < 1) {
595 double _t10 = (q2[0] + (q2[1] + (q2[2] + q2[3] * v) * v) * v);
596 _d_v += p2[3] * d_out / _t10 * v * v;
597 _d_v += (p2[2] + p2[3] * v) * d_out / _t10 * v;
598 _d_v += (p2[1] + (p2[2] + p2[3] * v) * v) * d_out / _t10;
599 double _r10 = d_out * -((p2[0] + (p2[1] + (p2[2] + p2[3] * v) * v) * v) / (_t10 * _t10));
600 _d_v += q2[3] * _r10 * v * v;
601 _d_v += (q2[2] + q2[3] * v) * _r10 * v;
602 _d_v += (q2[1] + (q2[2] + q2[3] * v) * v) * _r10;
603 } else if (v < 4) {
604 double _t12 = (q3[0] + (q3[1] + (q3[2] + q3[3] * v) * v) * v);
605 _d_v += p3[3] * d_out / _t12 * v * v;
606 _d_v += (p3[2] + p3[3] * v) * d_out / _t12 * v;
607 _d_v += (p3[1] + (p3[2] + p3[3] * v) * v) * d_out / _t12;
608 double _r11 = d_out * -((p3[0] + (p3[1] + (p3[2] + p3[3] * v) * v) * v) / (_t12 * _t12));
609 _d_v += q3[3] * _r11 * v * v;
610 _d_v += (q3[2] + q3[3] * v) * _r11 * v;
611 _d_v += (q3[1] + (q3[2] + q3[3] * v) * v) * _r11;
612 } else if (v < 12) {
613 double _d_u = 0;
614 double u = 1. / v;
615 double _t15 = (q4[0] + (q4[1] + (q4[2] + q4[3] * u) * u) * u);
616 _d_u += p4[3] * d_out / _t15 * u * u;
617 _d_u += (p4[2] + p4[3] * u) * d_out / _t15 * u;
618 _d_u += (p4[1] + (p4[2] + p4[3] * u) * u) * d_out / _t15;
619 double _r13 = d_out * -((p4[0] + (p4[1] + (p4[2] + p4[3] * u) * u) * u) / (_t15 * _t15));
620 _d_u += q4[3] * _r13 * u * u;
621 _d_u += (q4[2] + q4[3] * u) * _r13 * u;
622 _d_u += (q4[1] + (q4[2] + q4[3] * u) * u) * _r13;
623 double _r12 = _d_u * -(1. / (v * v));
624 _d_v += _r12;
625 } else if (v < 50) {
626 double _d_u = 0;
627 double u = 1. / v;
628 double _t18 = (q5[0] + (q5[1] + (q5[2] + q5[3] * u) * u) * u);
629 _d_u += p5[3] * d_out / _t18 * u * u;
630 _d_u += (p5[2] + p5[3] * u) * d_out / _t18 * u;
631 _d_u += (p5[1] + (p5[2] + p5[3] * u) * u) * d_out / _t18;
632 double _r15 = d_out * -((p5[0] + (p5[1] + (p5[2] + p5[3] * u) * u) * u) / (_t18 * _t18));
633 _d_u += q5[3] * _r15 * u * u;
634 _d_u += (q5[2] + q5[3] * u) * _r15 * u;
635 _d_u += (q5[1] + (q5[2] + q5[3] * u) * u) * _r15;
636 double _r14 = _d_u * -(1. / (v * v));
637 _d_v += _r14;
638 } else if (v < 300) {
639 double _d_u = 0;
640 double u = 1. / v;
641 double _t21 = (q6[0] + (q6[1] + (q6[2] + q6[3] * u) * u) * u);
642 _d_u += p6[3] * d_out / _t21 * u * u;
643 _d_u += (p6[2] + p6[3] * u) * d_out / _t21 * u;
644 _d_u += (p6[1] + (p6[2] + p6[3] * u) * u) * d_out / _t21;
645 double _r17 = d_out * -((p6[0] + (p6[1] + (p6[2] + p6[3] * u) * u) * u) / (_t21 * _t21));
646 _d_u += q6[3] * _r17 * u * u;
647 _d_u += (q6[2] + q6[3] * u) * _r17 * u;
648 _d_u += (q6[1] + (q6[2] + q6[3] * u) * u) * _r17;
649 double _r16 = _d_u * -(1. / (v * v));
650 _d_v += _r16;
651 } else {
652 double _d_u = 0;
653 double _t25 = ::std::log(v);
654 double _t24 = (v + 1);
655 double _t23 = (v - v * _t25 / _t24);
656 double u = 1. / _t23;
657 double _t26;
658 _d_u += a2[3] * -d_out * u * u;
659 _d_u += (a2[2] + a2[3] * u) * -d_out * u;
660 _d_u += (a2[1] + (a2[2] + a2[3] * u) * u) * -d_out;
661 double _r18 = _d_u * -(1. / (_t23 * _t23));
662 _d_v += _r18;
663 _d_v += -_r18 / _t24 * _t25;
664 double _r19 = 0;
665 _r19 += v * -_r18 / _t24 / v;
666 _d_v += _r19;
667 double _r20 = -_r18 * -(v * _t25 / (_t24 * _t24));
668 _d_v += _r20;
669 }
670
671 *d_x += _d_v / xi;
672 *d_x0 += -_d_v / xi;
673 *d_xi += _d_v * -((x - x0) / (xi * xi));
674}
675
676#ifdef R__HAS_MATHMORE
677
678inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x);
679
680inline void inc_gamma_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
681{
682 // Synced with SpecFuncCephes.h
683 constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
684 constexpr double kMAXLOG = 709.782712893383973096206318587;
685 constexpr double kMINLOG = -708.396418532264078748994506896;
686 constexpr double kMAXSTIR = 108.116855767857671821730036754;
687 constexpr double kMAXLGM = 2.556348e305;
688 constexpr double kBig = 4.503599627370496e15;
689 constexpr double kBiginv = 2.22044604925031308085e-16;
690
691 double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_r = 0;
692 double _t1;
693 double _t2;
694 double _t3;
695 double _t4;
696 double _t5;
697 clad::tape<double> _t7 = {};
698 clad::tape<double> _t8 = {};
699 clad::tape<double> _t9 = {};
700 double ans, ax, c, r;
701 if (a <= 0)
702 return;
703 if (x <= 0)
704 return;
705 if ((x > 1.) && (x > a)) {
706 double _r0 = 0;
707 double _r1 = 0;
709 *_d_a += _r0;
710 *_d_x += _r1;
711 return;
712 }
713 _t1 = ::std::log(x);
714 ax = a * _t1 - x - ::std::lgamma(a);
715 if (ax < -kMAXLOG) {
716 *_d_x += (a * _d_ax / x) - _d_ax;
717 *_d_a +=
718 _d_ax *
719 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
720 _d_ax = 0.;
721 return;
722 }
723 _t2 = ax;
724 ax = ::std::exp(ax);
725 _t3 = r;
726 r = a;
727 _t4 = c;
728 c = 1.;
729 _t5 = ans;
730 ans = 1.;
731 unsigned long _t6 = 0;
732 do {
733 _t6++;
734 clad::push(_t7, r);
735 r += 1.;
736 clad::push(_t8, c);
737 c *= x / r;
738 clad::push(_t9, ans);
739 ans += c;
740 } while (c / ans > kMACHEP);
741 {
742 _d_ans += _d_y / a * ax;
743 _d_ax += ans * _d_y / a;
744 double _r6 = _d_y * -(ans * ax / (a * a));
745 *_d_a += _r6;
746 }
747 do {
748 {
749 {
750 ans = clad::pop(_t9);
751 double _r_d7 = _d_ans;
752 _d_c += _r_d7;
753 }
754 {
755 c = clad::pop(_t8);
756 double _r_d6 = _d_c;
757 _d_c -= _r_d6;
758 _d_c += _r_d6 * x / r;
759 *_d_x += c * _r_d6 / r;
760 double _r5 = c * _r_d6 * -(x / (r * r));
761 _d_r += _r5;
762 }
763 {
764 r = clad::pop(_t7);
765 double _r_d5 = _d_r;
766 }
767 }
768 _t6--;
769 } while (_t6);
770 {
771 ans = _t5;
772 double _r_d4 = _d_ans;
773 _d_ans -= _r_d4;
774 }
775 {
776 c = _t4;
777 double _r_d3 = _d_c;
778 _d_c -= _r_d3;
779 }
780 {
781 r = _t3;
782 double _r_d2 = _d_r;
783 _d_r -= _r_d2;
784 *_d_a += _r_d2;
785 }
786 {
787 ax = _t2;
788 double _r_d1 = _d_ax;
789 _d_ax -= _r_d1;
790 double _r4 = 0;
791 _r4 += _r_d1 * ::std::exp(ax);
792 _d_ax += _r4;
793 }
794 {
795 *_d_x += (a * _d_ax / x) - _d_ax;
796 *_d_a +=
797 _d_ax *
798 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
799 _d_ax = 0.;
800 }
801}
802
803inline void inc_gamma_c_pullback(double a, double x, double _d_y, double *_d_a, double *_d_x)
804{
805 // Synced with SpecFuncCephes.h
806 constexpr double kMACHEP = 1.11022302462515654042363166809e-16;
807 constexpr double kMAXLOG = 709.782712893383973096206318587;
808 constexpr double kMINLOG = -708.396418532264078748994506896;
809 constexpr double kMAXSTIR = 108.116855767857671821730036754;
810 constexpr double kMAXLGM = 2.556348e305;
811 constexpr double kBig = 4.503599627370496e15;
812 constexpr double kBiginv = 2.22044604925031308085e-16;
813
814 double _d_ans = 0, _d_ax = 0, _d_c = 0, _d_yc = 0, _d_r = 0, _d_t = 0, _d_y0 = 0, _d_z = 0;
815 double _d_pk = 0, _d_pkm1 = 0, _d_pkm2 = 0, _d_qk = 0, _d_qkm1 = 0, _d_qkm2 = 0;
816 double _t1;
817 double _t2;
818 double _t3;
819 double _t4;
820 double _t5;
821 double _t6;
822 double _t7;
823 double _t8;
824 double _t9;
825 double _t10;
826 unsigned long _t11;
827 clad::tape<double> _t12 = {};
828 clad::tape<double> _t13 = {};
829 clad::tape<double> _t14 = {};
830 clad::tape<double> _t15 = {};
831 clad::tape<double> _t16 = {};
832 clad::tape<double> _t17 = {};
833 clad::tape<double> _t19 = {};
834 clad::tape<double> _t20 = {};
835 clad::tape<double> _t21 = {};
836 clad::tape<double> _t22 = {};
837 clad::tape<double> _t23 = {};
838 clad::tape<double> _t24 = {};
839 clad::tape<double> _t25 = {};
840 clad::tape<double> _t26 = {};
841 clad::tape<double> _t27 = {};
842 clad::tape<bool> _t29 = {};
843 clad::tape<double> _t30 = {};
844 clad::tape<double> _t31 = {};
845 clad::tape<double> _t32 = {};
846 clad::tape<double> _t33 = {};
847 double ans, ax, c, yc, r, t, y, z;
848 double pk, pkm1, pkm2, qk, qkm1, qkm2;
849 if (a <= 0)
850 return;
851 if (x <= 0)
852 return;
853 if ((x < 1.) || (x < a)) {
854 double _r0 = 0;
855 double _r1 = 0;
857 *_d_a += _r0;
858 *_d_x += _r1;
859 return;
860 }
861 _t1 = ::std::log(x);
862 ax = a * _t1 - x - ::std::lgamma(a);
863 if (ax < -kMAXLOG) {
864 *_d_x += a * _d_ax / x - _d_ax;
865 *_d_a +=
866 _d_ax *
867 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
868 _d_ax = 0.;
869 return;
870 }
871 _t2 = ax;
872 ax = ::std::exp(ax);
873 _t3 = y;
874 y = 1. - a;
875 _t4 = z;
876 z = x + y + 1.;
877 _t5 = c;
878 c = 0.;
879 _t6 = pkm2;
880 pkm2 = 1.;
881 _t7 = qkm2;
882 qkm2 = x;
883 _t8 = pkm1;
884 pkm1 = x + 1.;
885 _t9 = qkm1;
886 qkm1 = z * x;
887 _t10 = ans;
888 ans = pkm1 / qkm1;
889 _t11 = 0;
890 do {
891 _t11++;
892 clad::push(_t12, c);
893 c += 1.;
894 clad::push(_t13, y);
895 y += 1.;
896 clad::push(_t14, z);
897 z += 2.;
898 clad::push(_t15, yc);
899 yc = y * c;
900 clad::push(_t16, pk);
901 pk = pkm1 * z - pkm2 * yc;
902 clad::push(_t17, qk);
903 qk = qkm1 * z - qkm2 * yc;
904 double _t18 = qk;
905 {
906 if (_t18) {
907 clad::push(_t20, r);
908 r = pk / qk;
909 clad::push(_t21, t);
910 t = ::std::abs((ans - r) / r);
911 clad::push(_t22, ans);
912 ans = r;
913 } else {
914 clad::push(_t23, t);
915 t = 1.;
916 }
917 clad::push(_t19, _t18);
918 }
919 clad::push(_t24, pkm2);
920 pkm2 = pkm1;
921 clad::push(_t25, pkm1);
922 pkm1 = pk;
923 clad::push(_t26, qkm2);
924 qkm2 = qkm1;
925 clad::push(_t27, qkm1);
926 qkm1 = qk;
927 bool _t28 = ::std::abs(pk) > kBig;
928 {
929 if (_t28) {
930 clad::push(_t30, pkm2);
931 pkm2 *= kBiginv;
932 clad::push(_t31, pkm1);
933 pkm1 *= kBiginv;
934 clad::push(_t32, qkm2);
935 qkm2 *= kBiginv;
936 clad::push(_t33, qkm1);
937 qkm1 *= kBiginv;
938 }
939 clad::push(_t29, _t28);
940 }
941 } while (t > kMACHEP);
942 {
943 _d_ans += _d_y * ax;
944 _d_ax += ans * _d_y;
945 }
946 do {
947 {
948 if (clad::pop(_t29)) {
949 {
950 qkm1 = clad::pop(_t33);
951 double _r_d27 = _d_qkm1;
952 _d_qkm1 -= _r_d27;
953 _d_qkm1 += _r_d27 * kBiginv;
954 }
955 {
956 qkm2 = clad::pop(_t32);
957 double _r_d26 = _d_qkm2;
958 _d_qkm2 -= _r_d26;
959 _d_qkm2 += _r_d26 * kBiginv;
960 }
961 {
962 pkm1 = clad::pop(_t31);
963 double _r_d25 = _d_pkm1;
964 _d_pkm1 -= _r_d25;
965 _d_pkm1 += _r_d25 * kBiginv;
966 }
967 {
968 pkm2 = clad::pop(_t30);
969 double _r_d24 = _d_pkm2;
970 _d_pkm2 -= _r_d24;
971 _d_pkm2 += _r_d24 * kBiginv;
972 }
973 }
974 {
975 qkm1 = clad::pop(_t27);
976 double _r_d23 = _d_qkm1;
977 _d_qkm1 -= _r_d23;
978 _d_qk += _r_d23;
979 }
980 {
981 qkm2 = clad::pop(_t26);
982 double _r_d22 = _d_qkm2;
983 _d_qkm2 -= _r_d22;
984 _d_qkm1 += _r_d22;
985 }
986 {
987 pkm1 = clad::pop(_t25);
988 double _r_d21 = _d_pkm1;
989 _d_pkm1 -= _r_d21;
990 _d_pk += _r_d21;
991 }
992 {
993 pkm2 = clad::pop(_t24);
994 double _r_d20 = _d_pkm2;
995 _d_pkm2 -= _r_d20;
996 _d_pkm1 += _r_d20;
997 }
998 if (clad::pop(_t19)) {
999 {
1000 ans = clad::pop(_t22);
1001 double _r_d18 = _d_ans;
1002 _d_ans -= _r_d18;
1003 _d_r += _r_d18;
1004 }
1005 {
1006 t = clad::pop(_t21);
1007 double _r_d17 = _d_t;
1008 _d_t -= _r_d17;
1009 double _r7 = 0;
1010 _r7 += _r_d17 * clad::custom_derivatives::std::abs_pushforward((ans - r) / r, 1.).pushforward;
1011 _d_ans += _r7 / r;
1012 _d_r += -_r7 / r;
1013 double _r8 = _r7 * -((ans - r) / (r * r));
1014 _d_r += _r8;
1015 }
1016 {
1017 r = clad::pop(_t20);
1018 double _r_d16 = _d_r;
1019 _d_r -= _r_d16;
1020 _d_pk += _r_d16 / qk;
1021 double _r6 = _r_d16 * -(pk / (qk * qk));
1022 _d_qk += _r6;
1023 }
1024 } else {
1025 t = clad::pop(_t23);
1026 double _r_d19 = _d_t;
1027 _d_t -= _r_d19;
1028 }
1029 {
1030 qk = clad::pop(_t17);
1031 double _r_d15 = _d_qk;
1032 _d_qk -= _r_d15;
1033 _d_qkm1 += _r_d15 * z;
1034 _d_z += qkm1 * _r_d15;
1035 _d_qkm2 += -_r_d15 * yc;
1036 _d_yc += qkm2 * -_r_d15;
1037 }
1038 {
1039 pk = clad::pop(_t16);
1040 double _r_d14 = _d_pk;
1041 _d_pk -= _r_d14;
1042 _d_pkm1 += _r_d14 * z;
1043 _d_z += pkm1 * _r_d14;
1044 _d_pkm2 += -_r_d14 * yc;
1045 _d_yc += pkm2 * -_r_d14;
1046 }
1047 {
1048 yc = clad::pop(_t15);
1049 double _r_d13 = _d_yc;
1050 _d_yc -= _r_d13;
1051 _d_y0 += _r_d13 * c;
1052 _d_c += y * _r_d13;
1053 }
1054 {
1055 z = clad::pop(_t14);
1056 double _r_d12 = _d_z;
1057 }
1058 {
1059 y = clad::pop(_t13);
1060 double _r_d11 = _d_y0;
1061 }
1062 {
1063 c = clad::pop(_t12);
1064 double _r_d10 = _d_c;
1065 }
1066 }
1067 _t11--;
1068 } while (_t11);
1069 {
1070 ans = _t10;
1071 double _r_d9 = _d_ans;
1072 _d_ans -= _r_d9;
1073 _d_pkm1 += _r_d9 / qkm1;
1074 double _r5 = _r_d9 * -(pkm1 / (qkm1 * qkm1));
1075 _d_qkm1 += _r5;
1076 }
1077 {
1078 qkm1 = _t9;
1079 double _r_d8 = _d_qkm1;
1080 _d_qkm1 -= _r_d8;
1081 _d_z += _r_d8 * x;
1082 *_d_x += z * _r_d8;
1083 }
1084 {
1085 pkm1 = _t8;
1086 double _r_d7 = _d_pkm1;
1087 _d_pkm1 -= _r_d7;
1088 *_d_x += _r_d7;
1089 }
1090 {
1091 qkm2 = _t7;
1092 double _r_d6 = _d_qkm2;
1093 _d_qkm2 -= _r_d6;
1094 *_d_x += _r_d6;
1095 }
1096 {
1097 pkm2 = _t6;
1098 double _r_d5 = _d_pkm2;
1099 _d_pkm2 -= _r_d5;
1100 }
1101 {
1102 c = _t5;
1103 double _r_d4 = _d_c;
1104 _d_c -= _r_d4;
1105 }
1106 {
1107 z = _t4;
1108 double _r_d3 = _d_z;
1109 _d_z -= _r_d3;
1110 *_d_x += _r_d3;
1111 _d_y0 += _r_d3;
1112 }
1113 {
1114 y = _t3;
1115 double _r_d2 = _d_y0;
1116 _d_y0 -= _r_d2;
1117 *_d_a += -_r_d2;
1118 }
1119 {
1120 ax = _t2;
1121 double _r_d1 = _d_ax;
1122 _d_ax -= _r_d1;
1123 double _r4 = _r_d1 * ::std::exp(ax);
1124 _d_ax += _r4;
1125 }
1126 {
1127 *_d_x += a * _d_ax / x - _d_ax;
1128 *_d_a +=
1129 _d_ax *
1130 (_t1 - ::ROOT::Math::digamma(a)); // numerical_diff::forward_central_difference(::std::lgamma, a, 0, 0, a);
1131 _d_ax = 0.;
1132 }
1133}
1134
1135#endif // R__HAS_MATHMORE
1136
1137} // namespace Math
1138} // namespace ROOT
1139
1140} // namespace custom_derivatives
1141} // namespace clad
1142
1143// Forward declare BLAS functions.
1144extern "C" void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k,
1145 const float *alpha, const float *A, const int *lda, const float *B, const int *ldb,
1146 const float *beta, float *C, const int *ldc);
1147
1148namespace clad::custom_derivatives {
1149
1151
1152inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha,
1153 const float *A, const float *B, float beta, const float *C, float *_d_output, bool *,
1154 bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta,
1155 float *_d_C)
1156{
1157 // TODO:
1158 // - handle transa and transb cases correctly
1159 if (transa || transb) {
1160 return;
1161 }
1162
1163 char ct = 't';
1164 char cn = 'n';
1165
1166 // beta needs to be one because we want to add to _d_A and _d_B instead of
1167 // overwriting it.
1168 float one = 1.;
1169
1170 // _d_A, _d_B
1171 // note: beta needs to be one because we want to add to _d_A and _d_B instead of overwriting it.
1172 ::sgemm_(&cn, &ct, &m, &k, &n, &alpha, _d_output, &m, B, &k, &one, _d_A, &m);
1173 ::sgemm_(&ct, &cn, &k, &n, &m, &alpha, A, &m, _d_output, &m, &one, _d_B, &k);
1174
1175 // _d_alpha, _d_beta, _d_C
1176 int sizeC = n * m;
1177 for (int i = 0; i < sizeC; ++i) {
1178 *_d_alpha += _d_output[i] * (output[i] - beta * C[i]);
1179 *_d_beta += _d_output[i] * C[i];
1180 _d_C[i] += _d_output[i] * beta;
1181 }
1182}
1183
1184} // namespace TMVA::Experimental::SOFIE
1185
1186} // namespace clad::custom_derivatives
1187
1188#endif // CLAD_DERIVATOR
void sgemm_(const char *transa, const char *transb, const int *m, const int *n, const int *k, const float *alpha, const float *A, const int *lda, const float *B, const int *ldb, const float *beta, float *C, const int *ldc)
#define b(i)
Definition RSha256.hxx:100
#define c(i)
Definition RSha256.hxx:101
#define a(i)
Definition RSha256.hxx:99
#define h(i)
Definition RSha256.hxx:106
#define kMACHEP
#define kMAXLOG
#define kMAXLGM
#define kMAXSTIR
#define kMINLOG
ROOT::Detail::TRangeCast< T, true > TRangeDynCast
TRangeDynCast is an adapter class that allows the typed iteration through a TCollection.
winID h TVirtualViewer3D TVirtualGLPainter p
Option_t Option_t TPoint TPoint const char GetTextMagnitude GetFillStyle GetLineColor GetLineWidth GetMarkerStyle GetTextAlign GetTextColor GetTextSize void char Point_t Rectangle_t WindowAttributes_t Float_t r
double digamma(double x)
Double_t y[n]
Definition legend1.C:17
Double_t x[n]
Definition legend1.C:17
const Int_t n
Definition legend1.C:16
Namespace for new Math classes and functions.
tbb::task_arena is an alias of tbb::interface7::task_arena, which doesn't allow to forward declare tb...
TMath.
Definition TMathBase.h:35
Double_t CosH(Double_t)
Returns the hyperbolic cosine of x.
Definition TMath.h:616
Double_t ACos(Double_t)
Returns the principal value of the arc cosine of x, expressed in radians.
Definition TMath.h:636
Short_t Max(Short_t a, Short_t b)
Returns the largest of a and b.
Definition TMathBase.h:250
Double_t ASin(Double_t)
Returns the principal value of the arc sine of x, expressed in radians.
Definition TMath.h:628
Double_t Log2(Double_t x)
Returns the binary (base-2) logarithm of x.
Definition TMath.cxx:107
Double_t Exp(Double_t x)
Returns the base-e exponential function of x, which is e raised to the power x.
Definition TMath.h:713
Double_t Erf(Double_t x)
Computation of the error function erf(x).
Definition TMath.cxx:190
Double_t ATan(Double_t)
Returns the principal value of the arc tangent of x, expressed in radians.
Definition TMath.h:644
Double_t ASinH(Double_t)
Returns the area hyperbolic sine of x.
Definition TMath.cxx:67
Double_t TanH(Double_t)
Returns the hyperbolic tangent of x.
Definition TMath.h:622
Double_t ACosH(Double_t)
Returns the nonnegative area hyperbolic cosine of x.
Definition TMath.cxx:81
Double_t Log(Double_t x)
Returns the natural logarithm of x.
Definition TMath.h:760
Double_t Erfc(Double_t x)
Computes the complementary error function erfc(x).
Definition TMath.cxx:199
Double_t Sq(Double_t x)
Returns x*x.
Definition TMath.h:660
Double_t Sqrt(Double_t x)
Returns the square root of x.
Definition TMath.h:666
LongDouble_t Power(LongDouble_t x, LongDouble_t y)
Returns x raised to the power y.
Definition TMath.h:725
Short_t Min(Short_t a, Short_t b)
Returns the smallest of a and b.
Definition TMathBase.h:198
constexpr Double_t Ln10()
Natural log of 10 (to convert log to ln)
Definition TMath.h:100
Double_t Hypot(Double_t x, Double_t y)
Returns sqrt(x*x + y*y)
Definition TMath.cxx:59
Double_t Cos(Double_t)
Returns the cosine of an angle of x radians.
Definition TMath.h:598
constexpr Double_t Pi()
Definition TMath.h:37
Double_t LnGamma(Double_t z)
Computation of ln[gamma(z)] for all z.
Definition TMath.cxx:509
Double_t Sin(Double_t)
Returns the sine of an angle of x radians.
Definition TMath.h:592
Double_t Tan(Double_t)
Returns the tangent of an angle of x radians.
Definition TMath.h:604
Double_t ATanH(Double_t)
Returns the area hyperbolic tangent of x.
Definition TMath.cxx:95
Double_t Log10(Double_t x)
Returns the common (base-10) logarithm of x.
Definition TMath.h:766
Short_t Abs(Short_t d)
Returns the absolute value of parameter Short_t d.
Definition TMathBase.h:123
Double_t SinH(Double_t)
Returns the hyperbolic sine of x.
Definition TMath.h:610
void landau_pdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
void landau_cdf_pullback(double x, double xi, double x0, double d_out, double *d_x, double *d_xi, double *d_x0)
void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, int n, int k, float alpha, const float *A, const float *B, float beta, const float *C, float *_d_output, bool *, bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta, float *_d_C)
ValueAndPushforward< T, T > CosH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Abs_pushforward(T x, T d_x)
void Min_pullback(T a, T b, U p, clad::array_ref< T > d_a, clad::array_ref< T > d_b)
ValueAndPushforward< T, T > Sq_pushforward(T x, T d_x)
void Max_pullback(T a, T b, U p, clad::array_ref< T > d_a, clad::array_ref< T > d_b)
ValueAndPushforward< T, T > Erf_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Erfc_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sin_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Max_pushforward(T x, T y, T d_x, T d_y)
ValueAndPushforward< T, T > Hypot_pushforward(T x, T y, T d_x, T d_y)
ValueAndPushforward< T, T > ASinH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ACosH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ASin_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Cos_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Sqrt_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Tan_pushforward(T x, T d_x)
void Hypot_pullback(T x, T y, U p, clad::array_ref< T > d_x, clad::array_ref< T > d_y)
ValueAndPushforward< T, T > Power_pushforward(T x, T y, T d_x, T d_y)
void Power_pullback(T x, T y, U p, clad::array_ref< T > d_x, clad::array_ref< T > d_y)
ValueAndPushforward< T, T > Min_pushforward(T x, T y, T d_x, T d_y)
ValueAndPushforward< T, T > Log_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log10_pushforward(T x, T d_x)
ValueAndPushforward< T, T > TanH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ACos_pushforward(T x, T d_x)
ValueAndPushforward< T, T > SinH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Exp_pushforward(T x, T d_x)
ValueAndPushforward< T, T > Log2_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ATanH_pushforward(T x, T d_x)
ValueAndPushforward< T, T > ATan_pushforward(T x, T d_x)
TMarker m
Definition textangle.C:8
static void output()