Skip to content

Commit eb2c094

Browse files
committed
correct RAJA implementation
1 parent a1d8a15 commit eb2c094

File tree

2 files changed

+23
-16
lines changed

2 files changed

+23
-16
lines changed

Diff for: src/LinAlg/hiopVectorRajaImpl.hpp

+16-15
Original file line numberDiff line numberDiff line change
@@ -2390,10 +2390,10 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
23902390
size_type n_in = dl_vec.get_local_size();
23912391
size_type n_cons = n_eq + n_in;
23922392

2393-
hiopVectorInt* idx_eq_cumsum = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
2394-
hiopVectorInt* idx_in_cumsum = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
2395-
index_type* find_eq = idx_eq_cumsum->local_data();
2396-
index_type* find_in = idx_in_cumsum->local_data();
2393+
hiopVectorInt* find_eq = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
2394+
hiopVectorInt* find_in = LinearAlgebraFactory::create_vector_int(mem_space_, n_cons);
2395+
index_type* idx_eq_cumsum = find_eq->local_data();
2396+
index_type* idx_in_cumsum = find_in->local_data();
23972397

23982398
RAJA::ReduceSum< hiop_raja_reduce, int > sum_n_bnds_low(0);
23992399
RAJA::ReduceSum< hiop_raja_reduce, int > sum_n_bnds_upp(0);
@@ -2405,11 +2405,11 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
24052405
RAJA_LAMBDA(RAJA::Index_type i)
24062406
{
24072407
if(gl[i] == gu[i]) {
2408-
find_eq[i] = 1;
2409-
find_in[i] = 0;
2408+
idx_eq_cumsum[i] = 1;
2409+
idx_in_cumsum[i] = 0;
24102410
} else {
2411-
find_eq[i] = 0;
2412-
find_in[i] = 1;
2411+
idx_eq_cumsum[i] = 0;
2412+
idx_in_cumsum[i] = 1;
24132413
}
24142414
}
24152415
);
@@ -2421,8 +2421,6 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
24212421
// (0,1,1) -- (1,1,2) after scan
24222422
// map [1] [0,2]
24232423

2424-
index_type* nnz_cumsum = idx_cumsum_->local_data();
2425-
index_type v_n_local = v.n_local_;
24262424
RAJA::forall<hiop_raja_exec>(
24272425
RAJA::RangeSegment(0, n_cons),
24282426
RAJA_LAMBDA(RAJA::Index_type i)
@@ -2464,9 +2462,10 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
24642462
} else {
24652463
assert(idx_in_cumsum[i] == idx_in_cumsum[i-1] + 1);
24662464
int in_idx = idx_in_cumsum[i] - 1;
2467-
incon_map[in_idx] = cons_type[i];
2465+
incon_type[in_idx] = cons_type[i];
24682466
dl[in_idx] = gl[i];
24692467
du[in_idx] = gu[i];
2468+
incon_map[in_idx] = i;
24702469

24712470
if(gl[i]>-1e20) {
24722471
idl[in_idx] = 1.;
@@ -2488,11 +2487,13 @@ void hiopVectorRaja<MEM, POL>::process_constraints_local(const hiopVector& gl_ve
24882487
}
24892488
);
24902489

2491-
n_bnds_low = sum_n_bnds_low.get();
2492-
n_bnds_upp = sum_n_bnds_upp.get();
2493-
n_bnds_lu = sum_n_bnds_lu.get();
2490+
n_ineq_low = sum_n_bnds_low.get();
2491+
n_ineq_upp = sum_n_bnds_upp.get();
2492+
n_ineq_lu = sum_n_bnds_lu.get();
2493+
2494+
delete find_eq;
2495+
delete find_in;
24942496

2495-
return true;
24962497
}
24972498

24982499

Diff for: src/Optimization/hiopNlpFormulation.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,13 @@ bool hiopNlpFormulation::process_constraints()
524524
cons_ineq_type_,
525525
cons_type);
526526

527-
527+
528+
hiopVectorIntSeq cons_eq_mapping_host(n_cons_eq_);
529+
hiopVectorIntSeq cons_ineq_mapping_host(n_cons_ineq_);
530+
531+
cons_eq_mapping_->copy_to_vectorseq(cons_eq_mapping_host);
532+
cons_ineq_mapping_->copy_to_vectorseq(cons_ineq_mapping_host);
533+
528534
/* delete the temporary buffers */
529535
delete gl;
530536
delete gu;

0 commit comments

Comments
 (0)