diff options
| -rw-r--r-- | mrbgems/mruby-random/src/random.c | 103 | ||||
| -rw-r--r-- | src/class.c | 23 | ||||
| -rw-r--r-- | src/hash.c | 9 | ||||
| -rw-r--r-- | test/t/hash.rb | 1 |
4 files changed, 117 insertions, 19 deletions
diff --git a/mrbgems/mruby-random/src/random.c b/mrbgems/mruby-random/src/random.c index 4ef009c54..8f983ea0f 100644 --- a/mrbgems/mruby-random/src/random.c +++ b/mrbgems/mruby-random/src/random.c @@ -90,21 +90,31 @@ get_opt(mrb_state* mrb) return arg; } +static mrb_value +get_random(mrb_state *mrb) { + return mrb_const_get(mrb, + mrb_obj_value(mrb_class_get(mrb, "Random")), + mrb_intern_lit(mrb, "DEFAULT")); +} + +static mt_state * +get_random_state(mrb_state *mrb) +{ + mrb_value random_val = get_random(mrb); + return DATA_GET_PTR(mrb, random_val, &mt_state_type, mt_state); +} + static mrb_value mrb_random_g_rand(mrb_state *mrb, mrb_value self) { - mrb_value random = mrb_const_get(mrb, - mrb_obj_value(mrb_class_get(mrb, "Random")), - mrb_intern_lit(mrb, "DEFAULT")); + mrb_value random = get_random(mrb); return mrb_random_rand(mrb, random); } static mrb_value mrb_random_g_srand(mrb_state *mrb, mrb_value self) { - mrb_value random = mrb_const_get(mrb, - mrb_obj_value(mrb_class_get(mrb, "Random")), - mrb_intern_lit(mrb, "DEFAULT")); + mrb_value random = get_random(mrb); return mrb_random_srand(mrb, random); } @@ -154,7 +164,7 @@ static mrb_value mrb_random_rand(mrb_state *mrb, mrb_value self) { mrb_value max; - mt_state *t = DATA_PTR(self); + mt_state *t = DATA_GET_PTR(mrb, self, &mt_state_type, mt_state); max = get_opt(mrb); mrb_random_rand_seed(mrb, t); @@ -166,7 +176,7 @@ mrb_random_srand(mrb_state *mrb, mrb_value self) { mrb_value seed; mrb_value old_seed; - mt_state *t = DATA_PTR(self); + mt_state *t = DATA_GET_PTR(mrb, self, &mt_state_type, mt_state); seed = get_opt(mrb); seed = mrb_random_mt_srand(mrb, t, seed); @@ -200,10 +210,7 @@ mrb_ary_shuffle_bang(mrb_state *mrb, mrb_value ary) mrb_get_args(mrb, "|d", &random, &mt_state_type); if (random == NULL) { - mrb_value random_val = mrb_const_get(mrb, - mrb_obj_value(mrb_class_get(mrb, "Random")), - mrb_intern_lit(mrb, "DEFAULT")); - random = (mt_state *)DATA_PTR(random_val); + random = get_random_state(mrb); } mrb_random_rand_seed(mrb, random); @@ -240,6 +247,77 @@ mrb_ary_shuffle(mrb_state *mrb, mrb_value ary) return new_ary; } +/* + * call-seq: + * ary.sample -> obj + * ary.sample(n) -> new_ary + * + * Choose a random element or +n+ random elements from the array. + * + * The elements are chosen by using random and unique indices into the array + * in order to ensure that an element doesn't repeat itself unless the array + * already contained duplicate elements. + * + * If the array is empty the first form returns +nil+ and the second form + * returns an empty array. + */ + +static mrb_value +mrb_ary_sample(mrb_state *mrb, mrb_value ary) +{ + mrb_int n = 0; + mrb_bool given; + mt_state *random = NULL; + mrb_int len = RARRAY_LEN(ary); + + mrb_get_args(mrb, "|i?d", &n, &given, &random, &mt_state_type); + if (random == NULL) { + random = get_random_state(mrb); + } + mrb_random_rand_seed(mrb, random); + mt_rand(random); + if (!given) { /* pick one element */ + switch (len) { + case 0: + return mrb_nil_value(); + case 1: + return RARRAY_PTR(ary)[0]; + default: + return RARRAY_PTR(ary)[mt_rand(random) % len]; + } + } + else { + mrb_value result; + mrb_int i, j; + + if (n < 0) mrb_raise(mrb, E_ARGUMENT_ERROR, "negative sample number"); + if (n > len) n = len; + result = mrb_ary_new_capa(mrb, n); + for (i=0; i<n; i++) { + mrb_int r; + + for (;;) { + retry: + r = mt_rand(random) % len; + + for (j=0; j<i; j++) { + if (mrb_fixnum(RARRAY_PTR(result)[j]) == r) { + goto retry; /* retry if duplicate */ + } + } + break; + } + RARRAY_PTR(result)[i] = mrb_fixnum_value(r); + RARRAY_LEN(result)++; + } + for (i=0; i<n; i++) { + RARRAY_PTR(result)[i] = RARRAY_PTR(ary)[mrb_fixnum(RARRAY_PTR(result)[i])]; + } + return result; + } +} + + void mrb_mruby_random_gem_init(mrb_state *mrb) { struct RClass *random; @@ -259,6 +337,7 @@ void mrb_mruby_random_gem_init(mrb_state *mrb) mrb_define_method(mrb, array, "shuffle", mrb_ary_shuffle, MRB_ARGS_OPT(1)); mrb_define_method(mrb, array, "shuffle!", mrb_ary_shuffle_bang, MRB_ARGS_OPT(1)); + mrb_define_method(mrb, array, "sample", mrb_ary_sample, MRB_ARGS_OPT(2)); mrb_const_set(mrb, mrb_obj_value(random), mrb_intern_lit(mrb, "DEFAULT"), mrb_obj_new(mrb, random, 0, NULL)); diff --git a/src/class.c b/src/class.c index 6f5a8ed19..d880e3627 100644 --- a/src/class.c +++ b/src/class.c @@ -411,6 +411,7 @@ to_hash(mrb_state *mrb, mrb_value val) &: Block [mrb_value] *: rest argument [mrb_value*,int] Receive the rest of the arguments as an array. |: optional Next argument of '|' and later are optional. + ?: optional given [mrb_bool] true if preceding argument (optional) is given. */ int mrb_get_args(mrb_state *mrb, const char *format, ...) @@ -420,7 +421,8 @@ mrb_get_args(mrb_state *mrb, const char *format, ...) mrb_value *sp = mrb->c->stack + 1; va_list ap; int argc = mrb->c->ci->argc; - int opt = 0; + mrb_bool opt = 0; + mrb_bool given = 1; va_start(ap, format); if (argc < 0) { @@ -431,11 +433,16 @@ mrb_get_args(mrb_state *mrb, const char *format, ...) } while ((c = *format++)) { switch (c) { - case '|': case '*': case '&': + case '|': case '*': case '&': case '?': break; default: - if (argc <= i && !opt) { - mrb_raise(mrb, E_ARGUMENT_ERROR, "wrong number of arguments"); + if (argc <= i) { + if (opt) { + given = 0; + } + else { + mrb_raise(mrb, E_ARGUMENT_ERROR, "wrong number of arguments"); + } } break; } @@ -692,6 +699,14 @@ mrb_get_args(mrb_state *mrb, const char *format, ...) case '|': opt = 1; break; + case '?': + { + mrb_bool *p; + + p = va_arg(ap, mrb_bool*); + *p = given; + } + break; case '*': { diff --git a/src/hash.c b/src/hash.c index af3571eaf..34cc15131 100644 --- a/src/hash.c +++ b/src/hash.c @@ -874,6 +874,7 @@ static mrb_value hash_equal(mrb_state *mrb, mrb_value hash1, mrb_value hash2, mrb_bool eql) { khash_t(ht) *h1, *h2; + mrb_bool eq; if (mrb_obj_equal(mrb, hash1, hash2)) return mrb_true_value(); if (!mrb_hash_p(hash2)) { @@ -881,8 +882,6 @@ hash_equal(mrb_state *mrb, mrb_value hash1, mrb_value hash2, mrb_bool eql) return mrb_false_value(); } else { - mrb_bool eq; - if (eql) { eq = mrb_eql(mrb, hash2, hash1); } @@ -908,7 +907,11 @@ hash_equal(mrb_state *mrb, mrb_value hash1, mrb_value hash2, mrb_bool eql) key = kh_key(h1,k1); k2 = kh_get(ht, mrb, h2, key); if (k2 != kh_end(h2)) { - if (mrb_eql(mrb, kh_value(h1,k1), kh_value(h2,k2))) { + if (eql) + eq = mrb_eql(mrb, kh_value(h1,k1), kh_value(h2,k2)); + else + eq = mrb_equal(mrb, kh_value(h1,k1), kh_value(h2,k2)); + if (eq) { continue; /* next key */ } } diff --git a/test/t/hash.rb b/test/t/hash.rb index 92bc223b6..e7d5e8f74 100644 --- a/test/t/hash.rb +++ b/test/t/hash.rb @@ -12,6 +12,7 @@ end assert('Hash#==', '15.2.13.4.1') do assert_true({ 'abc' => 'abc' } == { 'abc' => 'abc' }) assert_false({ 'abc' => 'abc' } == { 'cba' => 'cba' }) + assert_true({ :equal => 1 } == { :equal => 1.0 }) end assert('Hash#[]', '15.2.13.4.2') do |
