Skip to content

Commit

Permalink
removed bug in deriv for sinus
Browse files Browse the repository at this point in the history
  • Loading branch information
konrad.kraemer committed Jun 5, 2024
1 parent d8073b3 commit b27fb92
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 56 deletions.
6 changes: 5 additions & 1 deletion include/etr_bits/Coca.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "Core.hpp"
#include "Core/Concepts.hpp"
#include "Core/Reflection.hpp"
#include "Core/Traits.hpp"
#include "Core/Types.hpp"
#include <type_traits>

Expand Down Expand Up @@ -175,7 +176,10 @@ inline auto coca(AV &av, Args &&...args) {
},
args...);

Vec<double, VarPointer<decltype(av), Idx, -1>, ConstantTypeTrait> ret(av);
Vec<double, VarPointer<decltype(av), Idx, -1, ConstantTypeTrait>,
ConstantTypeTrait>
ret(av); // TODO: add ConstantTypeTrait to each VarPointer in fct such as
// coca
return ret;
}

Expand Down
10 changes: 10 additions & 0 deletions include/etr_bits/Core/Concepts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ concept IsVecRorCalc = requires {
};

// NOTE: Concepts for Derivs
template <typename T>
concept IsVariableType = requires {
typename T::TypeTrait;
requires std::is_same_v<typename T::TypeTrait, VariableTypeTrait>;
};

template <typename T>
concept IsVariableTypeTrait = requires(T t) {
typename std::remove_reference<decltype(t)>::type::CaseTrait;
Expand Down Expand Up @@ -292,6 +298,10 @@ concept IsConstant = requires(T t) {
ConstantTypeTrait>::value;
};

template <typename T>
concept IsExpression =
requires(T t) { requires !IsConstant<T> && !IsVariableType<T>; };

/*
issue:
Extract type for Buffer, Borrow, BorrowSEXP, BinaryOp, UnaryOp
Expand Down
5 changes: 1 addition & 4 deletions include/etr_bits/Core/Traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,7 @@ struct UnEqualDerivTrait {

struct SinusDerivTrait {
template <typename L> static inline L f(L l) { return sin(l); }
template <typename L> static inline L fDeriv(L l) {
std::cout << l << std::endl;
return cos(l);
}
template <typename L> static inline L fDeriv(L l) { return cos(l); }
};

struct PlusTrait {
Expand Down
18 changes: 7 additions & 11 deletions include/etr_bits/Derivatives/DerivTypes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,13 @@ struct BinaryType {
using typeTraitRDeriv = RDeriv;
using TypeTrait = Trait;
using Op = OpTrait;

template <typename AV> static std::size_t getSize(AV &av) {
return LDeriv::getSize(av) > RDeriv::getSize(av) ? LDeriv::getSize(av)
: RDeriv::getSize(av);
}

template <typename AV> static auto getVal(AV &av, std::size_t idx) {
return Op::f(LDeriv::getVal(av, idx), RDeriv::getVal(av, idx));
}

template <typename AV> static auto getDeriv(AV &av, std::size_t idx) {
return Op::fDeriv(LDeriv::getDeriv(av, idx), RDeriv::getDeriv(av, idx));
}
Expand All @@ -69,17 +66,19 @@ produceBinaryType() {
template <typename Deriv, typename Trait, typename OpTrait> struct UnaryType {
using typeTraitObj = Deriv;
using TypeTrait = Trait;

using Op = OpTrait;
template <typename AV> static std::size_t getSize(AV &av) {
return Deriv::getSize(av);
}

template <typename AV> static auto getVal(AV &av, std::size_t idx) {
return sin(Deriv::getVal(av, idx));
return Op::f(Deriv::getVal(av, idx));
}

template <typename AV> static auto getDeriv(AV &av, std::size_t idx) {
return cos(Deriv::getVal(av, idx));
if constexpr (IsConstant<Deriv> || IsVariableType<Deriv>) {
return Op::fDeriv(Deriv::getVal(av, idx));
} else {
return Op::fDeriv(Deriv::getDeriv(av, idx));
}
}
};

Expand All @@ -92,12 +91,10 @@ template <typename T, typename Trait> struct VariableType {
using Type = T;
using RetType = T;
using TypeTrait = Trait;

template <typename AV> static std::size_t getSize(AV &av) {
using Ty = typename std::remove_reference<Type>::type;
return Ty::template getSize<AV>(av);
}

template <typename AV> static auto getVal(AV &av, std::size_t VecIdx) {
using Ty = typename std::remove_reference<Type>::type;
if constexpr (IsBinary<Ty>) {
Expand All @@ -106,7 +103,6 @@ template <typename T, typename Trait> struct VariableType {
return Ty::template getVal<AV>(av, VecIdx);
}
}

template <typename AV> static auto getDeriv(AV &av, std::size_t VecIdx) {
using Ty = typename std::remove_reference<Type>::type;
return Ty::template getDeriv<AV>(av, VecIdx);
Expand Down
61 changes: 21 additions & 40 deletions include/etr_bits/Vector/DerivativeCalc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ produceBinaryType() {

template <typename I, typename TraitTD, typename OpTrait>
static constexpr UnaryType<I, TraitTD, OpTrait> produceUnaryType() {
return UnaryType<I, TraitTD, OpTrait>(); // issue: wrong wrong wrong
return UnaryType<I, TraitTD, OpTrait>();
}

template <typename TRaw> static constexpr auto produceVariableType() {
Expand Down Expand Up @@ -53,6 +53,26 @@ static constexpr auto walkTD() {
produceVariableType<typename TDMultiplication::typeTraitL>();
constexpr auto RT =
produceVariableType<typename TDMultiplication::typeTraitR>();
if constexpr (IsVariableType<decltype(LT)> && IsVariableType<decltype(RT)>) {
} else if constexpr (IsVariableType<decltype(LT)> &&
IsConstant<decltype(RT)>) {
} else if constexpr (IsVariableType<decltype(LT)> &&
IsExpression<decltype(RT)>) {
} else if constexpr (IsConstant<decltype(LT)> && IsConstant<decltype(RT)>) {

} else if constexpr (IsConstant<decltype(LT)> &&
IsVariableType<decltype(RT)>) {

} else if constexpr (IsConstant<decltype(RT)> && IsExpression<decltype(RT)>) {

} else if constexpr (IsExpression<decltype(LT)> && IsConstant<decltype(RT)>) {

} else if constexpr (IsExpression<decltype(LT)> &&
IsVariableType<decltype(RT)>) {

} else if constexpr (IsExpression<decltype(RT)> &&
IsExpression<decltype(RT)>) {
}
constexpr auto LDeriv = walkTD<typename TDMultiplication::typeTraitL>();
constexpr auto RDeriv = walkTD<typename TDMultiplication::typeTraitR>();
return produceQuarternyType<decltype(LT), decltype(RT), decltype(LDeriv),
Expand Down Expand Up @@ -168,43 +188,4 @@ static constexpr auto walkTD() {
return res;
}

// TODO: finish this
template <typename T2, typename R2, typename Trait2>
Vec &operator<<(const Vec<T2, R2, Trait2> &otherVec) {
using tD = decltype(otherVec.d);
if constexpr (IsVarPointer<tD> && IsVarPointer<DType>) {
d.resize(otherVec.size());
for (std::size_t i = 0; i < d.size(); i++) {
d.setVal(otherVec.d.AllVarsRef, i, tD::getVal(otherVec.d.AllVarsRef, i));
}
} else if constexpr (IsVarPointer<tD> && !IsVarPointer<DType>) {
d.resize(otherVec.size());
for (std::size_t i = 0; i < d.size(); i++) {
d[i] = otherVec[i];
}
} else if constexpr (!IsVarPointer<tD> && IsVarPointer<DType>) {
using tDRaw = std::remove_reference<decltype(otherVec)>::type;
using typeExpr = std::remove_reference<ExtractedTypeD<tDRaw>>::type;
/* printTAST<typeExpr>(); */
constexpr auto res = walkTD<typeExpr>();
/* printTAST<decltype(res)>(); */
d.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
d.AllVarsRef.varBuffer[d.I][i] = otherVec[i];
}

d.AllVarsRef.resizeDerivs(R::I, R::TIdx, otherVec.size());
for (std::size_t i = 0; i < res.getSize(d.AllVarsRef); i++) {
d.setDeriv(d.AllVarsRef, i, res.getDeriv(d.AllVarsRef, i));
}

} else if constexpr (!IsVarPointer<tD> && !IsVarPointer<DType>) {
d.resize(otherVec.size());
for (std::size_t i = 0; i < otherVec.size(); i++) {
d[i] = otherVec[i];
}
}
return *this;
}

#endif
14 changes: 14 additions & 0 deletions tests/Derivs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
library(dfdr)

f <- function(a) sin(a)
a <- c(3.141593, 6.283185, 9.424778, 12.566371)
f(a)
fd <- d(f, a)
fd(a)


f <- function(a) sin(a*a)
a <- c(3.141593, 6.283185, 9.424778, 12.566371)
f(a)
fd <- d(f, a)
fd(a)

0 comments on commit b27fb92

Please sign in to comment.