diff --git a/src/hotspot/share/c1/c1_RangeCheckElimination.cpp b/src/hotspot/share/c1/c1_RangeCheckElimination.cpp index 4af9f29f263..256f8190b50 100644 --- a/src/hotspot/share/c1/c1_RangeCheckElimination.cpp +++ b/src/hotspot/share/c1/c1_RangeCheckElimination.cpp @@ -421,8 +421,11 @@ void RangeCheckEliminator::add_access_indexed_info(InstructionList &indices, int aii->_max = idx; aii->_list = new AccessIndexedList(); } else if (idx >= aii->_min && idx <= aii->_max) { - remove_range_check(ai); - return; + // Guard against underflow/overflow (see 'range_cond' check in RangeCheckEliminator::in_block_motion) + if (aii->_max < 0 || (aii->_max + min_jint) <= aii->_min) { + remove_range_check(ai); + return; + } } aii->_min = MIN2(aii->_min, idx); aii->_max = MAX2(aii->_max, idx); @@ -465,9 +468,9 @@ void RangeCheckEliminator::in_block_motion(BlockBegin *block, AccessIndexedList } } } else { - int last_integer = 0; + jint last_integer = 0; Instruction *last_instruction = index; - int base = 0; + jint base = 0; ArithmeticOp *ao = index->as_ArithmeticOp(); while (ao != nullptr && (ao->x()->as_Constant() || ao->y()->as_Constant()) && (ao->op() == Bytecodes::_iadd || ao->op() == Bytecodes::_isub)) { @@ -479,12 +482,12 @@ void RangeCheckEliminator::in_block_motion(BlockBegin *block, AccessIndexedList } if (c) { - int value = c->type()->as_IntConstant()->value(); + jint value = c->type()->as_IntConstant()->value(); if (value != min_jint) { if (ao->op() == Bytecodes::_isub) { value = -value; } - base += value; + base = java_add(base, value); last_integer = base; last_instruction = other; } @@ -506,12 +509,12 @@ void RangeCheckEliminator::in_block_motion(BlockBegin *block, AccessIndexedList assert(info != nullptr, "Info must not be null"); // if idx < 0, max > 0, max + idx may fall between 0 and - // length-1 and if min < 0, min + idx may overflow and be >= + // length-1 and if min < 0, min + idx may underflow/overflow and be >= // 0. The predicate wouldn't trigger but some accesses could // be with a negative index. This test guarantees that for the // min and max value that are kept the predicate can't let // some incorrect accesses happen. - bool range_cond = (info->_max < 0 || info->_max + min_jint <= info->_min); + bool range_cond = (info->_max < 0 || (info->_max + min_jint) <= info->_min); // Generate code only if more than 2 range checks can be eliminated because of that. // 2 because at least 2 comparisons are done @@ -859,7 +862,7 @@ void RangeCheckEliminator::process_access_indexed(BlockBegin *loop_header, Block ); remove_range_check(ai); - } else if (_optimistic && loop_header) { + } else if (false && _optimistic && loop_header) { assert(ai->array(), "Array must not be null!"); assert(ai->index(), "Index must not be null!"); diff --git a/src/hotspot/share/classfile/verifier.cpp b/src/hotspot/share/classfile/verifier.cpp index 4aa95aff0e1..743bd9d06ba 100644 --- a/src/hotspot/share/classfile/verifier.cpp +++ b/src/hotspot/share/classfile/verifier.cpp @@ -2257,11 +2257,12 @@ void ClassVerifier::verify_switch( "low must be less than or equal to high in tableswitch"); return; } - keys = high - low + 1; - if (keys < 0) { + int64_t keys64 = ((int64_t)high - low) + 1; + if (keys64 > 65535) { // Max code length verify_error(ErrorContext::bad_code(bci), "too many keys in tableswitch"); return; } + keys = (int)keys64; delta = 1; } else { keys = (int)Bytes::get_Java_u4(aligned_bcp + jintSize); diff --git a/src/hotspot/share/interpreter/bytecodes.cpp b/src/hotspot/share/interpreter/bytecodes.cpp index af3c1a2dbef..c1304df6d34 100644 --- a/src/hotspot/share/interpreter/bytecodes.cpp +++ b/src/hotspot/share/interpreter/bytecodes.cpp @@ -385,13 +385,18 @@ int Bytecodes::special_length_at(Bytecodes::Code code, address bcp, address end) if (end != nullptr && aligned_bcp + 3*jintSize >= end) { return -1; // don't read past end of code buffer } - // Promote calculation to 64 bits to do range checks, used by the verifier. + // Promote calculation to signed 64 bits to do range checks, used by the verifier. int64_t lo = (int)Bytes::get_Java_u4(aligned_bcp + 1*jintSize); int64_t hi = (int)Bytes::get_Java_u4(aligned_bcp + 2*jintSize); int64_t len = (aligned_bcp - bcp) + (3 + hi - lo + 1)*jintSize; - // only return len if it can be represented as a positive int; - // return -1 otherwise - return (len > 0 && len == (int)len) ? (int)len : -1; + // Only return len if it can be represented as a positive int and lo <= hi. + // The caller checks for bytecode stream overflow. + if (lo <= hi && len == (int)len) { + assert(len > 0, "must be"); + return (int)len; + } else { + return -1; + } } case _lookupswitch: // fall through @@ -404,9 +409,13 @@ int Bytecodes::special_length_at(Bytecodes::Code code, address bcp, address end) // Promote calculation to 64 bits to do range checks, used by the verifier. int64_t npairs = (int)Bytes::get_Java_u4(aligned_bcp + jintSize); int64_t len = (aligned_bcp - bcp) + (2 + 2*npairs)*jintSize; - // only return len if it can be represented as a positive int; - // return -1 otherwise - return (len > 0 && len == (int)len) ? (int)len : -1; + // Only return len if it can be represented as a positive int and npairs >= 0. + if (npairs >= 0 && len == (int)len) { + assert(len > 0, "must be"); + return (int)len; + } else { + return -1; + } } default: // Note: Length functions must return <=0 for invalid bytecodes. diff --git a/src/hotspot/share/opto/ifnode.cpp b/src/hotspot/share/opto/ifnode.cpp index a931c4de1f4..6ae62b24b3c 100644 --- a/src/hotspot/share/opto/ifnode.cpp +++ b/src/hotspot/share/opto/ifnode.cpp @@ -1898,6 +1898,46 @@ Node* RangeCheckNode::Ideal(PhaseGVN *phase, bool can_reshape) { // then we are guaranteed to fail, so just start interpreting there. // We 'expand' the top 3 range checks to include all post-dominating // checks. + // + // Example: + // a[i+x] // (1) 1 < x < 6 + // a[i+3] // (2) + // a[i+4] // (3) + // a[i+6] // max = max of all constants + // a[i+2] + // a[i+1] // min = min of all constants + // + // If x < 3: + // (1) a[i+x]: Leave unchanged + // (2) a[i+3]: Replace with a[i+max] = a[i+6]: i+x < i+3 <= i+6 -> (2) is covered + // (3) a[i+4]: Replace with a[i+min] = a[i+1]: i+1 < i+4 <= i+6 -> (3) and all following checks are covered + // Remove all other a[i+c] checks + // + // If x >= 3: + // (1) a[i+x]: Leave unchanged + // (2) a[i+3]: Replace with a[i+min] = a[i+1]: i+1 < i+3 <= i+x -> (2) is covered + // (3) a[i+4]: Replace with a[i+max] = a[i+6]: i+1 < i+4 <= i+6 -> (3) and all following checks are covered + // Remove all other a[i+c] checks + // + // We only need the top 2 range checks if x is the min or max of all constants. + // + // This, however, only works if the interval [i+min,i+max] is not larger than max_int (i.e. abs(max - min) < max_int): + // The theoretical max size of an array is max_int with: + // - Valid index space: [0,max_int-1] + // - Invalid index space: [max_int,-1] // max_int, min_int, min_int - 1 ..., -1 + // + // The size of the consecutive valid index space is smaller than the size of the consecutive invalid index space. + // If we choose min and max in such a way that: + // - abs(max - min) < max_int + // - i+max and i+min are inside the valid index space + // then all indices [i+min,i+max] must be in the valid index space. Otherwise, the invalid index space must be + // smaller than the valid index space which is never the case for any array size. + // + // Choosing a smaller array size only makes the valid index space smaller and the invalid index space larger and + // the argument above still holds. + // + // Note that the same optimization with the same maximal accepted interval size can also be found in C1. + const jlong maximum_number_of_min_max_interval_indices = (jlong)max_jint; // The top 3 range checks seen const int NRC = 3; @@ -1932,13 +1972,18 @@ Node* RangeCheckNode::Ideal(PhaseGVN *phase, bool can_reshape) { found_immediate_dominator = true; break; } - // Gather expanded bounds - off_lo = MIN2(off_lo,offset2); - off_hi = MAX2(off_hi,offset2); - // Record top NRC range checks - prev_checks[nb_checks%NRC].ctl = prev_dom->as_IfProj(); - prev_checks[nb_checks%NRC].off = offset2; - nb_checks++; + + // "x - y" -> must add one to the difference for number of elements in [x,y] + const jlong diff = (jlong)MIN2(offset2, off_lo) - (jlong)MAX2(offset2, off_hi); + if (ABS(diff) < maximum_number_of_min_max_interval_indices) { + // Gather expanded bounds + off_lo = MIN2(off_lo, offset2); + off_hi = MAX2(off_hi, offset2); + // Record top NRC range checks + prev_checks[nb_checks % NRC].ctl = prev_dom->as_IfProj(); + prev_checks[nb_checks % NRC].off = offset2; + nb_checks++; + } } } prev_dom = dom; diff --git a/src/hotspot/share/opto/loopPredicate.cpp b/src/hotspot/share/opto/loopPredicate.cpp index 95496d73d56..fd0576832cf 100644 --- a/src/hotspot/share/opto/loopPredicate.cpp +++ b/src/hotspot/share/opto/loopPredicate.cpp @@ -850,9 +850,10 @@ BoolNode* PhaseIdealLoop::rc_predicate(IdealLoopTree* loop, Node* ctrl, int scal // Check if (scale * max_idx_expr) may overflow const TypeInt* scale_type = TypeInt::make(scale); MulINode* mul = new MulINode(max_idx_expr, con_scale); - idx_type = (TypeInt*)mul->mul_ring(idx_type, scale_type); - if (overflow || TypeInt::INT->higher_equal(idx_type)) { + + if (overflow || MulINode::does_overflow(idx_type, scale_type)) { // May overflow + idx_type = TypeInt::INT; mul->destruct(&_igvn); if (!overflow) { max_idx_expr = new ConvI2LNode(max_idx_expr); @@ -865,6 +866,7 @@ BoolNode* PhaseIdealLoop::rc_predicate(IdealLoopTree* loop, Node* ctrl, int scal } else { // No overflow possible max_idx_expr = mul; + idx_type = (TypeInt*)mul->mul_ring(idx_type, scale_type); } register_new_node(max_idx_expr, ctrl); } diff --git a/src/hotspot/share/opto/loopnode.cpp b/src/hotspot/share/opto/loopnode.cpp index e9655b329e1..ac822453356 100644 --- a/src/hotspot/share/opto/loopnode.cpp +++ b/src/hotspot/share/opto/loopnode.cpp @@ -491,19 +491,19 @@ PhiNode* PhaseIdealLoop::loop_iv_phi(Node* xphi, Node* phi_incr, Node* x, IdealL return phi; } -static int check_stride_overflow(jlong stride_con, const TypeInteger* limit_t, BasicType bt) { - if (stride_con > 0) { - if (limit_t->lo_as_long() > (max_signed_integer(bt) - stride_con)) { +static int check_stride_overflow(jlong final_correction, const TypeInteger* limit_t, BasicType bt) { + if (final_correction > 0) { + if (limit_t->lo_as_long() > (max_signed_integer(bt) - final_correction)) { return -1; } - if (limit_t->hi_as_long() > (max_signed_integer(bt) - stride_con)) { + if (limit_t->hi_as_long() > (max_signed_integer(bt) - final_correction)) { return 1; } } else { - if (limit_t->hi_as_long() < (min_signed_integer(bt) - stride_con)) { + if (limit_t->hi_as_long() < (min_signed_integer(bt) - final_correction)) { return -1; } - if (limit_t->lo_as_long() < (min_signed_integer(bt) - stride_con)) { + if (limit_t->lo_as_long() < (min_signed_integer(bt) - final_correction)) { return 1; } } @@ -1773,49 +1773,204 @@ bool PhaseIdealLoop::is_counted_loop(Node* x, IdealLoopTree*&loop, BasicType iv_ C->print_method(PHASE_BEFORE_CLOOPS, 3); // =================================================== - // Generate loop limit check to avoid integer overflow - // in cases like next (cyclic loops): + // We can only convert this loop to a counted loop if we can guarantee that the iv phi will never overflow at runtime. + // This is an implicit assumption taken by some loop optimizations. We therefore must ensure this property at all cost. + // At this point, we've already excluded some trivial cases where an overflow could have been proven statically. + // But even though we cannot prove that an overflow will *not* happen, we still want to speculatively convert this loop + // to a counted loop. This can be achieved by adding additional iv phi overflow checks before the loop. If they fail, + // we trap and resume execution before the loop without having executed any iteration of the loop, yet. // - // for (i=0; i <= max_jint; i++) {} - // for (i=0; i < max_jint; i+=2) {} + // These additional iv phi overflow checks can be inserted as Loop Limit Check Predicates above the Loop Limit Check + // Parse Predicate which captures a JVM state just before the entry of the loop. If there is no such Parse Predicate, + // we cannot generate a Loop Limit Check Predicate and thus cannot speculatively convert the loop to a counted loop. + // + // In the following, we only focus on int loops with stride > 0 to keep things simple. The argumentation and proof + // for stride < 0 is analogously. For long loops, we would replace max_int with max_long. // // - // Limit check predicate depends on the loop test: + // The loop to be converted does not always need to have the often used shape: // - // for(;i != limit; i++) --> limit <= (max_jint) - // for(;i < limit; i+=stride) --> limit <= (max_jint - stride + 1) - // for(;i <= limit; i+=stride) --> limit <= (max_jint - stride ) + // i = init + // i = init loop: + // do { ... + // // ... equivalent i+=stride + // i+=stride <==> if (i < limit) + // } while (i < limit); goto loop + // exit: + // ... // + // where the loop exit check uses the post-incremented iv phi and a '<'-operator. + // + // We could also have '<='-operator (or '>='-operator for negative strides) or use the pre-incremented iv phi value + // in the loop exit check: + // + // i = init + // loop: + // ... + // if (i <= limit) + // i+=stride + // goto loop + // exit: + // ... + // + // Let's define the following terms: + // - iv_pre_i: The pre-incremented iv phi before the i-th iteration. + // - iv_post_i: The post-incremented iv phi after the i-th iteration. + // + // The iv_pre_i and iv_post_i have the following relation: + // iv_pre_i + stride = iv_post_i + // + // When converting a loop to a counted loop, we want to have a canonicalized loop exit check of the form: + // iv_post_i < adjusted_limit + // + // If that is not the case, we need to canonicalize the loop exit check by using different values for adjusted_limit: + // (LE1) iv_post_i < limit: Already canonicalized. We can directly use limit as adjusted_limit. + // -> adjusted_limit = limit. + // (LE2) iv_post_i <= limit: + // iv_post_i < limit + 1 + // -> adjusted limit = limit + 1 + // (LE3) iv_pre_i < limit: + // iv_pre_i + stride < limit + stride + // iv_post_i < limit + stride + // -> adjusted_limit = limit + stride + // (LE4) iv_pre_i <= limit: + // iv_pre_i < limit + 1 + // iv_pre_i + stride < limit + stride + 1 + // iv_post_i < limit + stride + 1 + // -> adjusted_limit = limit + stride + 1 + // + // Note that: + // (AL) limit <= adjusted_limit. + // + // The following loop invariant has to hold for counted loops with n iterations (i.e. loop exit check true after n-th + // loop iteration) and a canonicalized loop exit check to guarantee that no iv_post_i over- or underflows: + // (INV) For i = 1..n, min_int <= iv_post_i <= max_int + // + // To prove (INV), we require the following two conditions/assumptions: + // (i): adjusted_limit - 1 + stride <= max_int + // (ii): init < limit + // + // If we can prove (INV), we know that there can be no over- or underflow of any iv phi value. We prove (INV) by + // induction by assuming (i) and (ii). + // + // Proof by Induction + // ------------------ + // > Base case (i = 1): We show that (INV) holds after the first iteration: + // min_int <= iv_post_1 = init + stride <= max_int + // Proof: + // First, we note that (ii) implies + // (iii) init <= limit - 1 + // max_int >= adjusted_limit - 1 + stride [using (i)] + // >= limit - 1 + stride [using (AL)] + // >= init + stride [using (iii)] + // >= min_int [using stride > 0, no underflow] + // Thus, no overflow happens after the first iteration and (INV) holds for i = 1. + // + // Note that to prove the base case we need (i) and (ii). + // + // > Induction Hypothesis (i = j, j > 1): Assume that (INV) holds after the j-th iteration: + // min_int <= iv_post_j <= max_int + // > Step case (i = j + 1): We show that (INV) also holds after the j+1-th iteration: + // min_int <= iv_post_{j+1} = iv_post_j + stride <= max_int + // Proof: + // If iv_post_j >= adjusted_limit: + // We exit the loop after the j-th iteration, and we don't execute the j+1-th iteration anymore. Thus, there is + // also no iv_{j+1}. Since (INV) holds for iv_j, there is nothing left to prove. + // If iv_post_j < adjusted_limit: + // First, we note that: + // (iv) iv_post_j <= adjusted_limit - 1 + // max_int >= adjusted_limit - 1 + stride [using (i)] + // >= iv_post_j + stride [using (iv)] + // >= min_int [using stride > 0, no underflow] + // + // Note that to prove the step case we only need (i). + // + // Thus, by assuming (i) and (ii), we proved (INV). + // + // + // It is therefore enough to add the following two Loop Limit Check Predicates to check assumptions (i) and (ii): + // + // (1) Loop Limit Check Predicate for (i): + // Using (i): adjusted_limit - 1 + stride <= max_int + // + // This condition is now restated to use limit instead of adjusted_limit: + // + // To prevent an overflow of adjusted_limit -1 + stride itself, we rewrite this check to + // max_int - stride + 1 >= adjusted_limit + // We can merge the two constants into + // canonicalized_correction = stride - 1 + // which gives us + // max_int - canonicalized_correction >= adjusted_limit + // + // To directly use limit instead of adjusted_limit in the predicate condition, we split adjusted_limit into: + // adjusted_limit = limit + limit_correction + // Since stride > 0 and limit_correction <= stride + 1, we can restate this with no over- or underflow into: + // max_int - canonicalized_correction - limit_correction >= limit + // Since canonicalized_correction and limit_correction are both constants, we can replace them with a new constant: + // final_correction = canonicalized_correction + limit_correction + // which gives us: + // + // Final predicate condition: + // max_int - final_correction >= limit + // + // (2) Loop Limit Check Predicate for (ii): + // Using (ii): init < limit + // + // This Loop Limit Check Predicate is not required if we can prove at compile time that either: + // (2.1) type(init) < type(limit) + // In this case, we know: + // all possible values of init < all possible values of limit + // and we can skip the predicate. + // + // (2.2) init < limit is already checked before (i.e. found as a dominating check) + // In this case, we do not need to re-check the condition and can skip the predicate. + // This is often found for while- and for-loops which have the following shape: + // + // if (init < limit) { // Dominating test. Do not need the Loop Limit Check Predicate below. + // i = init; + // if (init >= limit) { trap(); } // Here we would insert the Loop Limit Check Predicate + // do { + // i += stride; + // } while (i < limit); + // } + // + // (2.3) init + stride <= max_int + // In this case, there is no overflow of the iv phi after the first loop iteration. + // In the proof of the base case above we showed that init + stride <= max_int by using assumption (ii): + // init < limit + // In the proof of the step case above, we did not need (ii) anymore. Therefore, if we already know at + // compile time that init + stride <= max_int then we have trivially proven the base case and that + // there is no overflow of the iv phi after the first iteration. In this case, we don't need to check (ii) + // again and can skip the predicate. - // Check if limit is excluded to do more precise int overflow check. - bool incl_limit = (bt == BoolTest::le || bt == BoolTest::ge); - jlong stride_m = stride_con - (incl_limit ? 0 : (stride_con > 0 ? 1 : -1)); - // If compare points directly to the phi we need to adjust - // the compare so that it points to the incr. Limit have - // to be adjusted to keep trip count the same and the - // adjusted limit should be checked for int overflow. - Node* adjusted_limit = limit; - if (phi_incr != nullptr) { - stride_m += stride_con; - } + // Accounting for (LE3) and (LE4) where we use pre-incremented phis in the loop exit check. + const jlong limit_correction_for_pre_iv_exit_check = (phi_incr != nullptr) ? stride_con : 0; - Node *init_control = x->in(LoopNode::EntryControl); + // Accounting for (LE2) and (LE4) where we use <= or >= in the loop exit check. + const bool includes_limit = (bt == BoolTest::le || bt == BoolTest::ge); + const jlong limit_correction_for_le_ge_exit_check = (includes_limit ? (stride_con > 0 ? 1 : -1) : 0); + + const jlong limit_correction = limit_correction_for_pre_iv_exit_check + limit_correction_for_le_ge_exit_check; + const jlong canonicalized_correction = stride_con + (stride_con > 0 ? -1 : 1); + const jlong final_correction = canonicalized_correction + limit_correction; + + int sov = check_stride_overflow(final_correction, limit_t, iv_bt); + Node* init_control = x->in(LoopNode::EntryControl); - int sov = check_stride_overflow(stride_m, limit_t, iv_bt); // If sov==0, limit's type always satisfies the condition, for // example, when it is an array length. if (sov != 0) { if (sov < 0) { return false; // Bailout: integer overflow is certain. } + // (1) Loop Limit Check Predicate is required because we could not statically prove that + // limit + final_correction = adjusted_limit - 1 + stride <= max_int assert(!x->as_Loop()->is_loop_nest_inner_loop(), "loop was transformed"); - // Generate loop's limit check. - // Loop limit check predicate should be near the loop. const Predicates predicates(init_control); const PredicateBlock* loop_limit_check_predicate_block = predicates.loop_limit_check_predicate_block(); if (!loop_limit_check_predicate_block->has_parse_predicate()) { - // The limit check predicate is not generated if this method trapped here before. + // The Loop Limit Check Parse Predicate is not generated if this method trapped here before. #ifdef ASSERT if (TraceLoopLimitCheck) { tty->print("Missing Loop Limit Check Parse Predicate:"); @@ -1835,67 +1990,81 @@ bool PhaseIdealLoop::is_counted_loop(Node* x, IdealLoopTree*&loop, BasicType iv_ Node* bol; if (stride_con > 0) { - cmp_limit = CmpNode::make(limit, _igvn.integercon(max_signed_integer(iv_bt) - stride_m, iv_bt), iv_bt); + cmp_limit = CmpNode::make(limit, _igvn.integercon(max_signed_integer(iv_bt) - final_correction, iv_bt), iv_bt); bol = new BoolNode(cmp_limit, BoolTest::le); } else { - cmp_limit = CmpNode::make(limit, _igvn.integercon(min_signed_integer(iv_bt) - stride_m, iv_bt), iv_bt); + cmp_limit = CmpNode::make(limit, _igvn.integercon(min_signed_integer(iv_bt) - final_correction, iv_bt), iv_bt); bol = new BoolNode(cmp_limit, BoolTest::ge); } insert_loop_limit_check_predicate(init_control->as_IfTrue(), cmp_limit, bol); } - // Now we need to canonicalize loop condition. - if (bt == BoolTest::ne) { - assert(stride_con == 1 || stride_con == -1, "simple increment only"); - if (stride_con > 0 && init_t->hi_as_long() < limit_t->lo_as_long()) { - // 'ne' can be replaced with 'lt' only when init < limit. - bt = BoolTest::lt; - } else if (stride_con < 0 && init_t->lo_as_long() > limit_t->hi_as_long()) { - // 'ne' can be replaced with 'gt' only when init > limit. - bt = BoolTest::gt; - } else { - const Predicates predicates(init_control); - const PredicateBlock* loop_limit_check_predicate_block = predicates.loop_limit_check_predicate_block(); - if (!loop_limit_check_predicate_block->has_parse_predicate()) { - // The limit check predicate is not generated if this method trapped here before. + // (2.3) + const bool init_plus_stride_could_overflow = + (stride_con > 0 && init_t->hi_as_long() > max_signed_integer(iv_bt) - stride_con) || + (stride_con < 0 && init_t->lo_as_long() < min_signed_integer(iv_bt) - stride_con); + // (2.1) + const bool init_gte_limit = (stride_con > 0 && init_t->hi_as_long() >= limit_t->lo_as_long()) || + (stride_con < 0 && init_t->lo_as_long() <= limit_t->hi_as_long()); + + if (init_gte_limit && // (2.1) + ((bt == BoolTest::ne || init_plus_stride_could_overflow) && // (2.3) + !has_dominating_loop_limit_check(init_trip, limit, stride_con, iv_bt, init_control))) { // (2.2) + // (2) Iteration Loop Limit Check Predicate is required because neither (2.1), (2.2), nor (2.3) holds. + // We use the following condition: + // - stride > 0: init < limit + // - stride < 0: init > limit + // + // This predicate is always required if we have a non-equal-operator in the loop exit check (where stride = 1 is + // a requirement). We transform the loop exit check by using a less-than-operator. By doing so, we must always + // check that init < limit. Otherwise, we could have a different number of iterations at runtime. + + const Predicates predicates(init_control); + const PredicateBlock* loop_limit_check_predicate_block = predicates.loop_limit_check_predicate_block(); + if (!loop_limit_check_predicate_block->has_parse_predicate()) { + // The Loop Limit Check Parse Predicate is not generated if this method trapped here before. #ifdef ASSERT - if (TraceLoopLimitCheck) { - tty->print("Missing Loop Limit Check Parse Predicate:"); - loop->dump_head(); - x->dump(1); - } + if (TraceLoopLimitCheck) { + tty->print("Missing Loop Limit Check Parse Predicate:"); + loop->dump_head(); + x->dump(1); + } #endif - return false; - } + return false; + } - ParsePredicateNode* loop_limit_check_parse_predicate = loop_limit_check_predicate_block->parse_predicate(); - Node* parse_predicate_entry = loop_limit_check_parse_predicate->in(0); - if (!is_dominator(get_ctrl(limit), parse_predicate_entry) || - !is_dominator(get_ctrl(init_trip), parse_predicate_entry)) { - return false; - } + ParsePredicateNode* loop_limit_check_parse_predicate = loop_limit_check_predicate_block->parse_predicate(); + Node* parse_predicate_entry = loop_limit_check_parse_predicate->in(0); + if (!is_dominator(get_ctrl(limit), parse_predicate_entry) || + !is_dominator(get_ctrl(init_trip), parse_predicate_entry)) { + return false; + } - Node* cmp_limit; - Node* bol; + Node* cmp_limit; + Node* bol; - if (stride_con > 0) { - cmp_limit = CmpNode::make(init_trip, limit, iv_bt); - bol = new BoolNode(cmp_limit, BoolTest::lt); - } else { - cmp_limit = CmpNode::make(init_trip, limit, iv_bt); - bol = new BoolNode(cmp_limit, BoolTest::gt); - } + if (stride_con > 0) { + cmp_limit = CmpNode::make(init_trip, limit, iv_bt); + bol = new BoolNode(cmp_limit, BoolTest::lt); + } else { + cmp_limit = CmpNode::make(init_trip, limit, iv_bt); + bol = new BoolNode(cmp_limit, BoolTest::gt); + } - insert_loop_limit_check_predicate(init_control->as_IfTrue(), cmp_limit, bol); + insert_loop_limit_check_predicate(init_control->as_IfTrue(), cmp_limit, bol); + } - if (stride_con > 0) { - // 'ne' can be replaced with 'lt' only when init < limit. - bt = BoolTest::lt; - } else if (stride_con < 0) { - // 'ne' can be replaced with 'gt' only when init > limit. - bt = BoolTest::gt; - } + if (bt == BoolTest::ne) { + // Now we need to canonicalize the loop condition if it is 'ne'. + assert(stride_con == 1 || stride_con == -1, "simple increment only - checked before"); + if (stride_con > 0) { + // 'ne' can be replaced with 'lt' only when init < limit. This is ensured by the inserted predicate above. + bt = BoolTest::lt; + } else { + assert(stride_con < 0, "must be"); + // 'ne' can be replaced with 'gt' only when init > limit. This is ensured by the inserted predicate above. + bt = BoolTest::gt; } } @@ -1940,6 +2109,7 @@ bool PhaseIdealLoop::is_counted_loop(Node* x, IdealLoopTree*&loop, BasicType iv_ } #endif + Node* adjusted_limit = limit; if (phi_incr != nullptr) { // If compare points directly to the phi we need to adjust // the compare so that it points to the incr. Limit have @@ -1953,7 +2123,7 @@ bool PhaseIdealLoop::is_counted_loop(Node* x, IdealLoopTree*&loop, BasicType iv_ adjusted_limit = gvn->transform(AddNode::make(limit, stride, iv_bt)); } - if (incl_limit) { + if (includes_limit) { // The limit check guaranties that 'limit <= (max_jint - stride)' so // we can convert 'i <= limit' to 'i < limit+1' since stride != 0. // @@ -2134,6 +2304,37 @@ bool PhaseIdealLoop::is_counted_loop(Node* x, IdealLoopTree*&loop, BasicType iv_ return true; } +// Check if there is a dominating loop limit check of the form 'init < limit' starting at the loop entry. +// If there is one, then we do not need to create an additional Loop Limit Check Predicate. +bool PhaseIdealLoop::has_dominating_loop_limit_check(Node* init_trip, Node* limit, const jlong stride_con, + const BasicType iv_bt, Node* loop_entry) { + // Eagerly call transform() on the Cmp and Bool node to common them up if possible. This is required in order to + // successfully find a dominated test with the If node below. + Node* cmp_limit; + Node* bol; + if (stride_con > 0) { + cmp_limit = _igvn.transform(CmpNode::make(init_trip, limit, iv_bt)); + bol = _igvn.transform(new BoolNode(cmp_limit, BoolTest::lt)); + } else { + cmp_limit = _igvn.transform(CmpNode::make(init_trip, limit, iv_bt)); + bol = _igvn.transform(new BoolNode(cmp_limit, BoolTest::gt)); + } + + // Check if there is already a dominating init < limit check. If so, we do not need a Loop Limit Check Predicate. + IfNode* iff = new IfNode(loop_entry, bol, PROB_MIN, COUNT_UNKNOWN); + // Also add fake IfProj nodes in order to call transform() on the newly created IfNode. + IfFalseNode* if_false = new IfFalseNode(iff); + IfTrueNode* if_true = new IfTrueNode(iff); + Node* dominated_iff = _igvn.transform(iff); + // ConI node? Found dominating test (IfNode::dominated_by() returns a ConI node). + const bool found_dominating_test = dominated_iff != nullptr && dominated_iff->is_ConI(); + + // Kill the If with its projections again in the next IGVN round by cutting it off from the graph. + _igvn.replace_input_of(iff, 0, C->top()); + _igvn.replace_input_of(iff, 1, C->top()); + return found_dominating_test; +} + //----------------------exact_limit------------------------------------------- Node* PhaseIdealLoop::exact_limit( IdealLoopTree *loop ) { assert(loop->_head->is_CountedLoop(), ""); diff --git a/src/hotspot/share/opto/loopnode.hpp b/src/hotspot/share/opto/loopnode.hpp index 068ca07e129..51a67674f33 100644 --- a/src/hotspot/share/opto/loopnode.hpp +++ b/src/hotspot/share/opto/loopnode.hpp @@ -1346,6 +1346,8 @@ public: void rewire_cloned_nodes_to_ctrl(const ProjNode* old_ctrl, Node* new_ctrl, const Node_List& nodes_with_same_ctrl, const Dict& old_new_mapping); void rewire_inputs_of_clones_to_clones(Node* new_ctrl, Node* clone, const Dict& old_new_mapping, const Node* next); + bool has_dominating_loop_limit_check(Node* init_trip, Node* limit, jlong stride_con, BasicType iv_bt, + Node* loop_entry); public: void register_control(Node* n, IdealLoopTree *loop, Node* pred, bool update_body = true); diff --git a/src/hotspot/share/opto/mulnode.cpp b/src/hotspot/share/opto/mulnode.cpp index 8b7fa22af55..1f22c608323 100644 --- a/src/hotspot/share/opto/mulnode.cpp +++ b/src/hotspot/share/opto/mulnode.cpp @@ -281,45 +281,86 @@ Node *MulINode::Ideal(PhaseGVN *phase, bool can_reshape) { return res; // Return final result } -// Classes to perform mul_ring() for MulI/MulLNode. +// This template class performs type multiplication for MulI/MulLNode. NativeType is either jint or jlong. +// In this class, the inputs of the MulNodes are named left and right with types [left_lo,left_hi] and [right_lo,right_hi]. // -// This class checks if all cross products of the left and right input of a multiplication have the same "overflow value". -// Without overflow/underflow: -// Product is positive? High signed multiplication result: 0 -// Product is negative? High signed multiplication result: -1 +// In general, the multiplication of two x-bit values could produce a result that consumes up to 2x bits if there is +// enough space to hold them all. We can therefore distinguish the following two cases for the product: +// - no overflow (i.e. product fits into x bits) +// - overflow (i.e. product does not fit into x bits) // -// We normalize these values (see normalize_overflow_value()) such that we get the same "overflow value" by adding 1 if -// the product is negative. This allows us to compare all the cross product "overflow values". If one is different, -// compared to the others, then we know that this multiplication has a different number of over- or underflows compared -// to the others. In this case, we need to use bottom type and cannot guarantee a better type. Otherwise, we can take -// the min und max of all computed cross products as type of this Mul node. -template -class IntegerMulRing { - using NativeType = std::conditional_t::value, jint, jlong>; +// When multiplying the two x-bit inputs 'left' and 'right' with their x-bit types [left_lo,left_hi] and [right_lo,right_hi] +// we need to find the minimum and maximum of all possible products to define a new type. To do that, we compute the +// cross product of [left_lo,left_hi] and [right_lo,right_hi] in 2x-bit space where no over- or underflow can happen. +// The cross product consists of the following four multiplications with 2x-bit results: +// (1) left_lo * right_lo +// (2) left_lo * right_hi +// (3) left_hi * right_lo +// (4) left_hi * right_hi +// +// Let's define the following two functions: +// - Lx(i): Returns the lower x bits of the 2x-bit number i. +// - Ux(i): Returns the upper x bits of the 2x-bit number i. +// +// Let's first assume all products are positive where only overflows are possible but no underflows. If there is no +// overflow for a product p, then the upper x bits of the 2x-bit result p are all zero: +// Ux(p) = 0 +// Lx(p) = p +// +// If none of the multiplications (1)-(4) overflow, we can truncate the upper x bits and use the following result type +// with x bits: +// [result_lo,result_hi] = [MIN(Lx(1),Lx(2),Lx(3),Lx(4)),MAX(Lx(1),Lx(2),Lx(3),Lx(4))] +// +// If any of these multiplications overflows, we could pessimistically take the bottom type for the x bit result +// (i.e. all values in the x-bit space could be possible): +// [result_lo,result_hi] = [NativeType_min,NativeType_max] +// +// However, in case of any overflow, we can do better by analyzing the upper x bits of all multiplications (1)-(4) with +// 2x-bit results. The upper x bits tell us something about how many times a multiplication has overflown the lower +// x bits. If the upper x bits of (1)-(4) are all equal, then we know that all of these multiplications overflowed +// the lower x bits the same number of times: +// Ux((1)) = Ux((2)) = Ux((3)) = Ux((4)) +// +// If all upper x bits are equal, we can conclude: +// Lx(MIN((1),(2),(3),(4))) = MIN(Lx(1),Lx(2),Lx(3),Lx(4))) +// Lx(MAX((1),(2),(3),(4))) = MAX(Lx(1),Lx(2),Lx(3),Lx(4))) +// +// Therefore, we can use the same precise x-bit result type as for the no-overflow case: +// [result_lo,result_hi] = [(MIN(Lx(1),Lx(2),Lx(3),Lx(4))),MAX(Lx(1),Lx(2),Lx(3),Lx(4)))] +// +// +// Now let's assume that (1)-(4) are signed multiplications where over- and underflow could occur: +// Negative numbers are all sign extend with ones. Therefore, if a negative product does not underflow, then the +// upper x bits of the 2x-bit result are all set to ones which is minus one in two's complement. If there is an underflow, +// the upper x bits are decremented by the number of times an underflow occurred. The smallest possible negative product +// is NativeType_min*NativeType_max, where the upper x bits are set to NativeType_min / 2 (b11...0). It is therefore +// impossible to underflow the upper x bits. Thus, when having all ones (i.e. minus one) in the upper x bits, we know +// that there is no underflow. +// +// To be able to compare the number of over-/underflows of positive and negative products, respectively, we normalize +// the upper x bits of negative 2x-bit products by adding one. This way a product has no over- or underflow if the +// normalized upper x bits are zero. Now we can use the same improved type as for strictly positive products because we +// can compare the upper x bits in a unified way with N() being the normalization function: +// N(Ux((1))) = N(Ux((2))) = N(Ux((3)) = N(Ux((4))) +template +class IntegerTypeMultiplication { NativeType _lo_left; NativeType _lo_right; NativeType _hi_left; NativeType _hi_right; - NativeType _lo_lo_product; - NativeType _lo_hi_product; - NativeType _hi_lo_product; - NativeType _hi_hi_product; short _widen_left; short _widen_right; static const Type* overflow_type(); - static NativeType multiply_high_signed_overflow_value(NativeType x, NativeType y); + static NativeType multiply_high(NativeType x, NativeType y); + const Type* create_type(NativeType lo, NativeType hi) const; - // Pre-compute cross products which are used at several places - void compute_cross_products() { - _lo_lo_product = java_multiply(_lo_left, _lo_right); - _lo_hi_product = java_multiply(_lo_left, _hi_right); - _hi_lo_product = java_multiply(_hi_left, _lo_right); - _hi_hi_product = java_multiply(_hi_left, _hi_right); + static NativeType multiply_high_signed_overflow_value(NativeType x, NativeType y) { + return normalize_overflow_value(x, y, multiply_high(x, y)); } - bool cross_products_not_same_overflow() const { + bool cross_product_not_same_overflow_value() const { const NativeType lo_lo_high_product = multiply_high_signed_overflow_value(_lo_left, _lo_right); const NativeType lo_hi_high_product = multiply_high_signed_overflow_value(_lo_left, _hi_right); const NativeType hi_lo_high_product = multiply_high_signed_overflow_value(_hi_left, _lo_right); @@ -329,66 +370,95 @@ class IntegerMulRing { hi_lo_high_product != hi_hi_high_product; } + bool does_product_overflow(NativeType x, NativeType y) const { + return multiply_high_signed_overflow_value(x, y) != 0; + } + static NativeType normalize_overflow_value(const NativeType x, const NativeType y, NativeType result) { return java_multiply(x, y) < 0 ? result + 1 : result; } public: - IntegerMulRing(const IntegerType* left, const IntegerType* right) : _lo_left(left->_lo), _lo_right(right->_lo), - _hi_left(left->_hi), _hi_right(right->_hi), _widen_left(left->_widen), _widen_right(right->_widen) { - compute_cross_products(); - } + template + IntegerTypeMultiplication(const IntegerType* left, const IntegerType* right) + : _lo_left(left->_lo), _lo_right(right->_lo), + _hi_left(left->_hi), _hi_right(right->_hi), + _widen_left(left->_widen), _widen_right(right->_widen) {} // Compute the product type by multiplying the two input type ranges. We take the minimum and maximum of all possible // values (requires 4 multiplications of all possible combinations of the two range boundary values). If any of these // multiplications overflows/underflows, we need to make sure that they all have the same number of overflows/underflows // If that is not the case, we return the bottom type to cover all values due to the inconsistent overflows/underflows). const Type* compute() const { - if (cross_products_not_same_overflow()) { + if (cross_product_not_same_overflow_value()) { return overflow_type(); } - const NativeType min = MIN4(_lo_lo_product, _lo_hi_product, _hi_lo_product, _hi_hi_product); - const NativeType max = MAX4(_lo_lo_product, _lo_hi_product, _hi_lo_product, _hi_hi_product); - return IntegerType::make(min, max, MAX2(_widen_left, _widen_right)); + + NativeType lo_lo_product = java_multiply(_lo_left, _lo_right); + NativeType lo_hi_product = java_multiply(_lo_left, _hi_right); + NativeType hi_lo_product = java_multiply(_hi_left, _lo_right); + NativeType hi_hi_product = java_multiply(_hi_left, _hi_right); + const NativeType min = MIN4(lo_lo_product, lo_hi_product, hi_lo_product, hi_hi_product); + const NativeType max = MAX4(lo_lo_product, lo_hi_product, hi_lo_product, hi_hi_product); + return create_type(min, max); + } + + bool does_overflow() const { + return does_product_overflow(_lo_left, _lo_right) || + does_product_overflow(_lo_left, _hi_right) || + does_product_overflow(_hi_left, _lo_right) || + does_product_overflow(_hi_left, _hi_right); } }; - template <> -const Type* IntegerMulRing::overflow_type() { +const Type* IntegerTypeMultiplication::overflow_type() { return TypeInt::INT; } template <> -jint IntegerMulRing::multiply_high_signed_overflow_value(const jint x, const jint y) { +jint IntegerTypeMultiplication::multiply_high(const jint x, const jint y) { const jlong x_64 = x; const jlong y_64 = y; const jlong product = x_64 * y_64; - const jint result = (jint)((uint64_t)product >> 32u); - return normalize_overflow_value(x, y, result); + return (jint)((uint64_t)product >> 32u); } template <> -const Type* IntegerMulRing::overflow_type() { +const Type* IntegerTypeMultiplication::create_type(jint lo, jint hi) const { + return TypeInt::make(lo, hi, MAX2(_widen_left, _widen_right)); +} + +template <> +const Type* IntegerTypeMultiplication::overflow_type() { return TypeLong::LONG; } template <> -jlong IntegerMulRing::multiply_high_signed_overflow_value(const jlong x, const jlong y) { - const jlong result = multiply_high_signed(x, y); - return normalize_overflow_value(x, y, result); +jlong IntegerTypeMultiplication::multiply_high(const jlong x, const jlong y) { + return multiply_high_signed(x, y); +} + +template <> +const Type* IntegerTypeMultiplication::create_type(jlong lo, jlong hi) const { + return TypeLong::make(lo, hi, MAX2(_widen_left, _widen_right)); } // Compute the product type of two integer ranges into this node. const Type* MulINode::mul_ring(const Type* type_left, const Type* type_right) const { - const IntegerMulRing integer_mul_ring(type_left->is_int(), type_right->is_int()); - return integer_mul_ring.compute(); + const IntegerTypeMultiplication integer_multiplication(type_left->is_int(), type_right->is_int()); + return integer_multiplication.compute(); +} + +bool MulINode::does_overflow(const TypeInt* type_left, const TypeInt* type_right) { + const IntegerTypeMultiplication integer_multiplication(type_left, type_right); + return integer_multiplication.does_overflow(); } // Compute the product type of two long ranges into this node. const Type* MulLNode::mul_ring(const Type* type_left, const Type* type_right) const { - const IntegerMulRing integer_mul_ring(type_left->is_long(), type_right->is_long()); - return integer_mul_ring.compute(); + const IntegerTypeMultiplication integer_multiplication(type_left->is_long(), type_right->is_long()); + return integer_multiplication.compute(); } //============================================================================= diff --git a/src/hotspot/share/opto/mulnode.hpp b/src/hotspot/share/opto/mulnode.hpp index d04648ee61a..10ef442299d 100644 --- a/src/hotspot/share/opto/mulnode.hpp +++ b/src/hotspot/share/opto/mulnode.hpp @@ -95,6 +95,7 @@ public: virtual int Opcode() const; virtual Node *Ideal(PhaseGVN *phase, bool can_reshape); virtual const Type *mul_ring( const Type *, const Type * ) const; + static bool does_overflow(const TypeInt* type_left, const TypeInt* type_right); const Type *mul_id() const { return TypeInt::ONE; } const Type *add_id() const { return TypeInt::ZERO; } int add_opcode() const { return Op_AddI; } diff --git a/src/java.base/share/classes/com/sun/crypto/provider/RSACipher.java b/src/java.base/share/classes/com/sun/crypto/provider/RSACipher.java index 4278003c465..df76f78bfb5 100644 --- a/src/java.base/share/classes/com/sun/crypto/provider/RSACipher.java +++ b/src/java.base/share/classes/com/sun/crypto/provider/RSACipher.java @@ -98,6 +98,7 @@ public final class RSACipher extends CipherSpi { // cipher parameter for OAEP padding and TLS RSA premaster secret private AlgorithmParameterSpec spec = null; + private boolean forTlsPremasterSecret = false; // buffer for the data private byte[] buffer; @@ -286,6 +287,7 @@ public final class RSACipher extends CipherSpi { } spec = params; + forTlsPremasterSecret = true; this.random = random; // for TLS RSA premaster secret } int blockType = (mode <= MODE_DECRYPT) ? RSAPadding.PAD_BLOCKTYPE_2 @@ -377,7 +379,7 @@ public final class RSACipher extends CipherSpi { byte[] decryptBuffer = RSACore.convert(buffer, 0, bufOfs); paddingCopy = RSACore.rsa(decryptBuffer, privateKey, false); result = padding.unpad(paddingCopy); - if (result == null) { + if (result == null && !forTlsPremasterSecret) { throw new BadPaddingException ("Padding error in decryption"); } @@ -466,26 +468,22 @@ public final class RSACipher extends CipherSpi { boolean isTlsRsaPremasterSecret = algorithm.equals("TlsRsaPremasterSecret"); - Exception failover = null; byte[] encoded = null; update(wrappedKey, 0, wrappedKey.length); try { encoded = doFinal(); - } catch (BadPaddingException e) { - if (isTlsRsaPremasterSecret) { - failover = e; - } else { - throw new InvalidKeyException("Unwrapping failed", e); - } - } catch (IllegalBlockSizeException e) { - // should not occur, handled with length check above + } catch (BadPaddingException | IllegalBlockSizeException e) { + // BadPaddingException cannot happen for TLS RSA unwrap. + // In that case, padding error is indicated by returning null. + // IllegalBlockSizeException cannot happen in any case, + // because of the length check above. throw new InvalidKeyException("Unwrapping failed", e); } try { if (isTlsRsaPremasterSecret) { - if (!(spec instanceof TlsRsaPremasterSecretParameterSpec)) { + if (!forTlsPremasterSecret) { throw new IllegalStateException( "No TlsRsaPremasterSecretParameterSpec specified"); } @@ -494,7 +492,7 @@ public final class RSACipher extends CipherSpi { encoded = KeyUtil.checkTlsPreMasterSecretKey( ((TlsRsaPremasterSecretParameterSpec) spec).getClientVersion(), ((TlsRsaPremasterSecretParameterSpec) spec).getServerVersion(), - random, encoded, (failover != null)); + random, encoded, encoded == null); } return ConstructKeys.constructKey(encoded, algorithm, type); diff --git a/src/java.base/share/classes/sun/security/provider/certpath/ForwardBuilder.java b/src/java.base/share/classes/sun/security/provider/certpath/ForwardBuilder.java index e2dbd446f26..52a073d3b78 100644 --- a/src/java.base/share/classes/sun/security/provider/certpath/ForwardBuilder.java +++ b/src/java.base/share/classes/sun/security/provider/certpath/ForwardBuilder.java @@ -41,6 +41,7 @@ import java.security.cert.X509CertSelector; import java.util.*; import javax.security.auth.x500.X500Principal; +import jdk.internal.misc.ThreadTracker; import sun.security.provider.certpath.PKIX.BuilderParams; import sun.security.util.Debug; import sun.security.x509.AccessDescription; @@ -71,6 +72,10 @@ final class ForwardBuilder extends Builder { TrustAnchor trustAnchor; private final boolean searchAllCertStores; + private static class ThreadTrackerHolder { + static final ThreadTracker AIA_TRACKER = new ThreadTracker(); + } + /** * Initialize the builder with the input parameters. * @@ -336,7 +341,7 @@ final class ForwardBuilder extends Builder { } /** - * Download Certificates from the given AIA and add them to the + * Download certificates from the given AIA and add them to the * specified Collection. */ // cs.getCertificates(caSelector) returns a collection of X509Certificate's @@ -348,32 +353,47 @@ final class ForwardBuilder extends Builder { if (!Builder.USE_AIA) { return false; } + List adList = aiaExt.getAccessDescriptions(); if (adList == null || adList.isEmpty()) { return false; } - boolean add = false; - for (AccessDescription ad : adList) { - CertStore cs = URICertStore.getInstance(ad); - if (cs != null) { - try { - if (certs.addAll((Collection) - cs.getCertificates(caSelector))) { - add = true; - if (!searchAllCertStores) { - return true; + Object key = ThreadTrackerHolder.AIA_TRACKER.tryBegin(); + if (key == null) { + // Avoid recursive fetching of certificates + if (debug != null) { + debug.println("Recursive fetching of certs via the AIA " + + "extension detected"); + } + return false; + } + + try { + boolean add = false; + for (AccessDescription ad : adList) { + CertStore cs = URICertStore.getInstance(ad); + if (cs != null) { + try { + if (certs.addAll((Collection) + cs.getCertificates(caSelector))) { + add = true; + if (!searchAllCertStores) { + return true; + } + } + } catch (CertStoreException cse) { + if (debug != null) { + debug.println("exception getting certs from CertStore:"); + cse.printStackTrace(); } - } - } catch (CertStoreException cse) { - if (debug != null) { - debug.println("exception getting certs from CertStore:"); - cse.printStackTrace(); } } } + return add; + } finally { + ThreadTrackerHolder.AIA_TRACKER.end(key); } - return add; } /** diff --git a/src/java.base/share/classes/sun/security/util/KeyUtil.java b/src/java.base/share/classes/sun/security/util/KeyUtil.java index 059467779b9..c38889ed494 100644 --- a/src/java.base/share/classes/sun/security/util/KeyUtil.java +++ b/src/java.base/share/classes/sun/security/util/KeyUtil.java @@ -291,13 +291,14 @@ public final class KeyUtil { * contains the lower of that suggested by the client in the client * hello and the highest supported by the server. * @param encoded the encoded key in its "RAW" encoding format - * @param isFailOver whether the previous decryption of the - * encrypted PreMasterSecret message run into problem + * @param failure true if encoded is incorrect according to previous checks * @return the polished PreMasterSecret key in its "RAW" encoding format */ public static byte[] checkTlsPreMasterSecretKey( int clientVersion, int serverVersion, SecureRandom random, - byte[] encoded, boolean isFailOver) { + byte[] encoded, boolean failure) { + + byte[] tmp; if (random == null) { random = JCAUtil.getSecureRandom(); @@ -305,30 +306,38 @@ public final class KeyUtil { byte[] replacer = new byte[48]; random.nextBytes(replacer); - if (!isFailOver && (encoded != null)) { - // check the length - if (encoded.length != 48) { - // private, don't need to clone the byte array. - return replacer; - } - - int encodedVersion = - ((encoded[0] & 0xFF) << 8) | (encoded[1] & 0xFF); - if (clientVersion != encodedVersion) { - if (clientVersion > 0x0301 || // 0x0301: TLSv1 - serverVersion != encodedVersion) { - encoded = replacer; - } // Otherwise, For compatibility, we maintain the behavior - // that the version in pre_master_secret can be the - // negotiated version for TLS v1.0 and SSL v3.0. - } - - // private, don't need to clone the byte array. - return encoded; + if (failure) { + tmp = replacer; + } else { + tmp = encoded; } - // private, don't need to clone the byte array. - return replacer; + if (tmp == null) { + encoded = replacer; + } else { + encoded = tmp; + } + // check the length + if (encoded.length != 48) { + // private, don't need to clone the byte array. + tmp = replacer; + } else { + tmp = encoded; + } + + int encodedVersion = + ((tmp[0] & 0xFF) << 8) | (tmp[1] & 0xFF); + int check1 = 0; + int check2 = 0; + int check3 = 0; + if (clientVersion != encodedVersion) check1 = 1; + if (clientVersion > 0x0301) check2 = 1; + if (serverVersion != encodedVersion) check3 = 1; + if ((check1 & (check2 | check3)) == 1) { + return replacer; + } else { + return tmp; + } } /** diff --git a/src/java.base/share/native/libverify/check_code.c b/src/java.base/share/native/libverify/check_code.c index a0c427a7e4b..d1ebd3d5b94 100644 --- a/src/java.base/share/native/libverify/check_code.c +++ b/src/java.base/share/native/libverify/check_code.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 1994, 2022, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 1994, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -81,6 +81,7 @@ #include #include #include +#include #include "jni.h" #include "jni_util.h" @@ -1195,7 +1196,7 @@ verify_opcode_operands(context_type *context, unsigned int inumber, int offset) } } if (opcode == JVM_OPC_tableswitch) { - keys = _ck_ntohl(lpc[2]) - _ck_ntohl(lpc[1]) + 1; + keys = _ck_ntohl(lpc[2]) - _ck_ntohl(lpc[1]) + 1; delta = 1; } else { keys = _ck_ntohl(lpc[1]); /* number of pairs */ @@ -1677,11 +1678,14 @@ static int instruction_length(unsigned char *iptr, unsigned char *end) switch (instruction) { case JVM_OPC_tableswitch: { int *lpc = (int *)UCALIGN(iptr + 1); - int index; + int64_t low, high, index; if (lpc + 2 >= (int *)end) { return -1; /* do not read pass the end */ } - index = _ck_ntohl(lpc[2]) - _ck_ntohl(lpc[1]); + low = _ck_ntohl(lpc[1]); + high = _ck_ntohl(lpc[2]); + index = high - low; + // The value of low must be less than or equal to high - i.e. index >= 0 if ((index < 0) || (index > 65535)) { return -1; /* illegal */ } else { diff --git a/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/CRSACipher.java b/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/CRSACipher.java index ef6fa63e0a6..7b2fb631023 100644 --- a/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/CRSACipher.java +++ b/src/jdk.crypto.mscapi/windows/classes/sun/security/mscapi/CRSACipher.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2005, 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2005, 2023, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -30,6 +30,7 @@ import java.security.*; import java.security.Key; import java.security.interfaces.*; import java.security.spec.*; +import java.util.Arrays; import javax.crypto.*; import javax.crypto.spec.*; @@ -61,6 +62,9 @@ import sun.security.util.KeyUtil; */ public final class CRSACipher extends CipherSpi { + private static final int ERROR_INVALID_PARAMETER = 0x57; + private static final int NTE_INVALID_PARAMETER = 0x80090027; + // constant for an empty byte array private static final byte[] B0 = new byte[0]; @@ -101,6 +105,8 @@ public final class CRSACipher extends CipherSpi { // cipher parameter for TLS RSA premaster secret private AlgorithmParameterSpec spec = null; + private boolean forTlsPremasterSecret = false; + // the source of randomness private SecureRandom random; @@ -171,6 +177,9 @@ public final class CRSACipher extends CipherSpi { } spec = params; this.random = random; // for TLS RSA premaster secret + this.forTlsPremasterSecret = true; + } else { + this.forTlsPremasterSecret = false; } init(opmode, key); } @@ -278,8 +287,7 @@ public final class CRSACipher extends CipherSpi { } // internal doFinal() method. Here we perform the actual RSA operation - private byte[] doFinal() throws BadPaddingException, - IllegalBlockSizeException { + private byte[] doFinal() throws IllegalBlockSizeException { if (bufOfs > buffer.length) { throw new IllegalBlockSizeException("Data must not be longer " + "than " + (buffer.length - paddingLength) + " bytes"); @@ -308,7 +316,7 @@ public final class CRSACipher extends CipherSpi { throw new AssertionError("Internal error"); } - } catch (KeyException e) { + } catch (KeyException | BadPaddingException e) { throw new ProviderException(e); } finally { @@ -331,14 +339,14 @@ public final class CRSACipher extends CipherSpi { // see JCE spec protected byte[] engineDoFinal(byte[] in, int inOfs, int inLen) - throws BadPaddingException, IllegalBlockSizeException { + throws IllegalBlockSizeException { update(in, inOfs, inLen); return doFinal(); } // see JCE spec protected int engineDoFinal(byte[] in, int inOfs, int inLen, byte[] out, - int outOfs) throws ShortBufferException, BadPaddingException, + int outOfs) throws ShortBufferException, IllegalBlockSizeException { if (outputSize > out.length - outOfs) { throw new ShortBufferException @@ -354,6 +362,7 @@ public final class CRSACipher extends CipherSpi { // see JCE spec protected byte[] engineWrap(Key key) throws InvalidKeyException, IllegalBlockSizeException { + byte[] encoded = key.getEncoded(); // TODO - unextractable key if ((encoded == null) || (encoded.length == 0)) { throw new InvalidKeyException("Could not obtain encoded key"); @@ -362,12 +371,7 @@ public final class CRSACipher extends CipherSpi { throw new InvalidKeyException("Key is too long for wrapping"); } update(encoded, 0, encoded.length); - try { - return doFinal(); - } catch (BadPaddingException e) { - // should not occur - throw new InvalidKeyException("Wrapping failed", e); - } + return doFinal(); } // see JCE spec @@ -388,31 +392,31 @@ public final class CRSACipher extends CipherSpi { update(wrappedKey, 0, wrappedKey.length); try { encoded = doFinal(); - } catch (BadPaddingException e) { - if (isTlsRsaPremasterSecret) { - failover = e; - } else { - throw new InvalidKeyException("Unwrapping failed", e); - } } catch (IllegalBlockSizeException e) { // should not occur, handled with length check above throw new InvalidKeyException("Unwrapping failed", e); } - if (isTlsRsaPremasterSecret) { - if (!(spec instanceof TlsRsaPremasterSecretParameterSpec)) { - throw new IllegalStateException( - "No TlsRsaPremasterSecretParameterSpec specified"); + try { + if (isTlsRsaPremasterSecret) { + if (!forTlsPremasterSecret) { + throw new IllegalStateException( + "No TlsRsaPremasterSecretParameterSpec specified"); + } + + // polish the TLS premaster secret + encoded = KeyUtil.checkTlsPreMasterSecretKey( + ((TlsRsaPremasterSecretParameterSpec) spec).getClientVersion(), + ((TlsRsaPremasterSecretParameterSpec) spec).getServerVersion(), + random, encoded, encoded == null); } - // polish the TLS premaster secret - encoded = KeyUtil.checkTlsPreMasterSecretKey( - ((TlsRsaPremasterSecretParameterSpec)spec).getClientVersion(), - ((TlsRsaPremasterSecretParameterSpec)spec).getServerVersion(), - random, encoded, (failover != null)); + return constructKey(encoded, algorithm, type); + } finally { + if (encoded != null) { + Arrays.fill(encoded, (byte) 0); + } } - - return constructKey(encoded, algorithm, type); } // see JCE spec @@ -496,17 +500,30 @@ public final class CRSACipher extends CipherSpi { * Encrypt/decrypt a data buffer using Microsoft Crypto API or CNG. * It expects and returns ciphertext data in big-endian form. */ - private static byte[] encryptDecrypt(byte[] data, int dataSize, - CKey key, boolean doEncrypt) throws KeyException { + private byte[] encryptDecrypt(byte[] data, int dataSize, + CKey key, boolean doEncrypt) throws KeyException, BadPaddingException { + int[] returnStatus = new int[1]; + byte[] result; if (key.getHCryptKey() != 0) { - return encryptDecrypt(data, dataSize, key.getHCryptKey(), doEncrypt); + result = encryptDecrypt(returnStatus, data, dataSize, key.getHCryptKey(), doEncrypt); } else { - return cngEncryptDecrypt(data, dataSize, key.getHCryptProvider(), doEncrypt); + result = cngEncryptDecrypt(returnStatus, data, dataSize, key.getHCryptProvider(), doEncrypt); } + if ((returnStatus[0] == ERROR_INVALID_PARAMETER) || (returnStatus[0] == NTE_INVALID_PARAMETER)) { + if (forTlsPremasterSecret) { + result = null; + } else { + throw new BadPaddingException("Error " + returnStatus[0] + " returned by MSCAPI"); + } + } else if (returnStatus[0] != 0) { + throw new KeyException("Error " + returnStatus[0] + " returned by MSCAPI"); + } + + return result; } - private static native byte[] encryptDecrypt(byte[] data, int dataSize, + private static native byte[] encryptDecrypt(int[] returnStatus, byte[] data, int dataSize, long key, boolean doEncrypt) throws KeyException; - private static native byte[] cngEncryptDecrypt(byte[] data, int dataSize, + private static native byte[] cngEncryptDecrypt(int[] returnStatus, byte[] data, int dataSize, long key, boolean doEncrypt) throws KeyException; } diff --git a/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp b/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp index f2b6bfd83f6..4787708779d 100644 --- a/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp +++ b/src/jdk.crypto.mscapi/windows/native/libsunmscapi/security.cpp @@ -1905,18 +1905,25 @@ JNIEXPORT void JNICALL Java_sun_security_mscapi_CKeyStore_destroyKeyContainer /* * Class: sun_security_mscapi_CRSACipher * Method: encryptDecrypt - * Signature: ([BIJZ)[B + * Signature: ([I[BIJZ)[B */ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_encryptDecrypt - (JNIEnv *env, jclass clazz, jbyteArray jData, jint jDataSize, jlong hKey, + (JNIEnv *env, jclass clazz, jintArray jResultStatus, jbyteArray jData, jint jDataSize, jlong hKey, jboolean doEncrypt) { jbyteArray result = NULL; jbyte* pData = NULL; + jbyte* resultData = NULL; DWORD dwDataLen = jDataSize; DWORD dwBufLen = env->GetArrayLength(jData); DWORD i; BYTE tmp; + BOOL success; + DWORD ss = ERROR_SUCCESS; + DWORD lastError = ERROR_SUCCESS; + DWORD resultLen = 0; + DWORD pmsLen = 48; + jbyte pmsArr[48] = {0}; __try { @@ -1943,6 +1950,8 @@ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_encryptDecrypt pData[i] = pData[dwBufLen - i -1]; pData[dwBufLen - i - 1] = tmp; } + resultData = pData; + resultLen = dwBufLen; } else { // convert to little-endian for (i = 0; i < dwBufLen / 2; i++) { @@ -1952,21 +1961,28 @@ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_encryptDecrypt } // decrypt - if (! ::CryptDecrypt((HCRYPTKEY) hKey, 0, TRUE, 0, (BYTE *)pData, //deprecated - &dwBufLen)) { - - ThrowException(env, KEY_EXCEPTION, GetLastError()); - __leave; + success = ::CryptDecrypt((HCRYPTKEY) hKey, 0, TRUE, 0, (BYTE *)pData, //deprecated + &dwBufLen); + lastError = GetLastError(); + if (success) { + ss = ERROR_SUCCESS; + resultData = pData; + resultLen = dwBufLen; + } else { + ss = lastError; + resultData = pmsArr; + resultLen = pmsLen; } + env->SetIntArrayRegion(jResultStatus, 0, 1, (jint*) &ss); } - // Create new byte array - if ((result = env->NewByteArray(dwBufLen)) == NULL) { + // Create new byte array + if ((result = env->NewByteArray(resultLen)) == NULL) { __leave; } // Copy data from native buffer to Java buffer - env->SetByteArrayRegion(result, 0, dwBufLen, (jbyte*) pData); + env->SetByteArrayRegion(result, 0, resultLen, (jbyte*) resultData); } __finally { @@ -1980,17 +1996,22 @@ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_encryptDecrypt /* * Class: sun_security_mscapi_CRSACipher * Method: cngEncryptDecrypt - * Signature: ([BIJZ)[B + * Signature: ([I[BIJZ)[B */ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_cngEncryptDecrypt - (JNIEnv *env, jclass clazz, jbyteArray jData, jint jDataSize, jlong hKey, + (JNIEnv *env, jclass clazz, jintArray jResultStatus, jbyteArray jData, jint jDataSize, jlong hKey, jboolean doEncrypt) { SECURITY_STATUS ss; jbyteArray result = NULL; jbyte* pData = NULL; + jbyte* resultData = NULL; DWORD dwDataLen = jDataSize; DWORD dwBufLen = env->GetArrayLength(jData); + DWORD resultLen = 0; + DWORD pmsLen = 48; + jbyte pmsArr[48] = {0}; + __try { // Copy data from Java buffer to native buffer @@ -2010,6 +2031,9 @@ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_cngEncryptDecry if (ss != ERROR_SUCCESS) { ThrowException(env, KEY_EXCEPTION, ss); __leave; + } else { + resultLen = dwBufLen; + resultData = pData; } } else { // decrypt @@ -2018,18 +2042,22 @@ JNIEXPORT jbyteArray JNICALL Java_sun_security_mscapi_CRSACipher_cngEncryptDecry 0, (PBYTE)pData, dwBufLen, &dwBufLen, NCRYPT_PAD_PKCS1_FLAG); - if (ss != ERROR_SUCCESS) { - ThrowException(env, KEY_EXCEPTION, ss); - __leave; + env->SetIntArrayRegion(jResultStatus, 0, 1, (jint*) &ss); + if (ss == ERROR_SUCCESS) { + resultLen = dwBufLen; + resultData = pData; + } else { + resultLen = pmsLen; + resultData = pmsArr; } - } + } // Create new byte array - if ((result = env->NewByteArray(dwBufLen)) == NULL) { + if ((result = env->NewByteArray(resultLen)) == NULL) { __leave; } // Copy data from native buffer to Java buffer - env->SetByteArrayRegion(result, 0, dwBufLen, (jbyte*) pData); + env->SetByteArrayRegion(result, 0, resultLen, (jbyte*) resultData); } __finally { if (pData) { diff --git a/test/hotspot/jtreg/ProblemList.txt b/test/hotspot/jtreg/ProblemList.txt index 89d6dde8e54..a9094902cf2 100644 --- a/test/hotspot/jtreg/ProblemList.txt +++ b/test/hotspot/jtreg/ProblemList.txt @@ -64,6 +64,7 @@ compiler/rtm/locking/TestUseRTMXendForLockBusy.java 8183263 generic-x64,generic- compiler/rtm/print/TestPrintPreciseRTMLockingStatistics.java 8183263 generic-x64,generic-i586 compiler/c2/Test8004741.java 8235801 generic-all +compiler/c2/irTests/TestDuplicateBackedge.java 8318904 generic-all compiler/codecache/jmx/PoolsIndependenceTest.java 8264632 macosx-all