Skip to content

Commit

Permalink
V1.1.49.dev (#131)
Browse files Browse the repository at this point in the history
* Add IndexType to cox and update version

* Add type robustness

* Missed one
  • Loading branch information
JamesYang007 authored Oct 22, 2024
1 parent 4fb9109 commit f9c0a19
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion adelie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.1.48"
__version__ = "1.1.49"

# Set environment flags before loading adelie_core
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ ADELIE_CORE_CONSTRAINT_LINEAR::solve(
Qmu_resid = Qv - _ATmu;
const value_t loss = 0.5 * Qmu_resid.square().sum();
internal_matrix_t _X(*_A); // _X == _A^T
optimization::StateNNLS<internal_matrix_t> state_nnls(
optimization::StateNNLS<internal_matrix_t, value_t, index_t> state_nnls(
_X, v_norm * v_norm, _A_vars, std::min<size_t>(m, d),
_nnls_max_iters, _nnls_tol,
_mu_active.size(),
Expand Down Expand Up @@ -447,7 +447,7 @@ ADELIE_CORE_CONSTRAINT_LINEAR::solve(
for (Eigen::Index ii = 0; ii < static_cast<Eigen::Index>(active_size); ++ii) screen_invariance(ii);
}

optimization::StatePinball<A_t> state_pinball(
optimization::StatePinball<A_t, value_t, index_t> state_pinball(
*_A, var, hess, _l, _u,
std::min(m, d), _pinball_max_iters, _pinball_tol,
_mu_active.size(),
Expand Down Expand Up @@ -585,7 +585,7 @@ ADELIE_CORE_CONSTRAINT_LINEAR::solve_zero(
return (ui <= 0) ? Configs::max_solver_value : 0;
});
internal_matrix_t _X(*_A); // _X == _A^T
optimization::StateNNLS<internal_matrix_t> state_nnls(
optimization::StateNNLS<internal_matrix_t, value_t, index_t> state_nnls(
_X, v.square().sum(), _A_vars, std::min<size_t>(m, d),
_nnls_max_iters, _nnls_tol,
_mu_active.size(),
Expand Down
18 changes: 9 additions & 9 deletions adelie/src/include/adelie_core/glm/glm_cox.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,34 @@

#ifndef ADELIE_CORE_GLM_COX_PACK_TP
#define ADELIE_CORE_GLM_COX_PACK_TP \
template <class ValueType>
template <class ValueType, class IndexType>
#endif
#ifndef ADELIE_CORE_GLM_COX_PACK
#define ADELIE_CORE_GLM_COX_PACK \
GlmCoxPack<ValueType>
GlmCoxPack<ValueType, IndexType>
#endif

#ifndef ADELIE_CORE_GLM_COX_TP
#define ADELIE_CORE_GLM_COX_TP \
template <class ValueType>
template <class ValueType, class IndexType>
#endif
#ifndef ADELIE_CORE_GLM_COX
#define ADELIE_CORE_GLM_COX \
GlmCox<ValueType>
GlmCox<ValueType, IndexType>
#endif

namespace adelie_core {
namespace glm {

template <class ValueType>
template <class ValueType, class IndexType=Eigen::Index>
class GlmCoxPack
{
public:
using index_t = IndexType;
using value_t = ValueType;
using vec_index_t = util::rowvec_type<index_t>;
using vec_value_t = util::rowvec_type<value_t>;
using map_cvec_value_t = Eigen::Map<const vec_value_t>;
using index_t = Eigen::Index;
using vec_index_t = util::rowvec_type<index_t>;

const util::tie_method_type tie_method;

Expand Down Expand Up @@ -113,15 +113,15 @@ class GlmCoxPack
inline value_t loss_full();
};

template <class ValueType>
template <class ValueType, class IndexType=Eigen::Index>
class GlmCox: public GlmBase<ValueType>
{
public:
using base_t = GlmBase<ValueType>;
using typename base_t::value_t;
using typename base_t::vec_value_t;
using typename base_t::map_cvec_value_t;
using pack_t = GlmCoxPack<value_t>;
using pack_t = GlmCoxPack<value_t, IndexType>;
using index_t = typename pack_t::index_t;
using vec_index_t = typename pack_t::vec_index_t;

Expand Down

0 comments on commit f9c0a19

Please sign in to comment.