summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--mrbgems/mruby-random/src/random.c103
-rw-r--r--src/class.c23
-rw-r--r--src/hash.c9
-rw-r--r--test/t/hash.rb1
4 files changed, 117 insertions, 19 deletions
diff --git a/mrbgems/mruby-random/src/random.c b/mrbgems/mruby-random/src/random.c
index 4ef009c54..8f983ea0f 100644
--- a/mrbgems/mruby-random/src/random.c
+++ b/mrbgems/mruby-random/src/random.c
@@ -90,21 +90,31 @@ get_opt(mrb_state* mrb)
return arg;
}
+static mrb_value
+get_random(mrb_state *mrb) {
+ return mrb_const_get(mrb,
+ mrb_obj_value(mrb_class_get(mrb, "Random")),
+ mrb_intern_lit(mrb, "DEFAULT"));
+}
+
+static mt_state *
+get_random_state(mrb_state *mrb)
+{
+ mrb_value random_val = get_random(mrb);
+ return DATA_GET_PTR(mrb, random_val, &mt_state_type, mt_state);
+}
+
static mrb_value
mrb_random_g_rand(mrb_state *mrb, mrb_value self)
{
- mrb_value random = mrb_const_get(mrb,
- mrb_obj_value(mrb_class_get(mrb, "Random")),
- mrb_intern_lit(mrb, "DEFAULT"));
+ mrb_value random = get_random(mrb);
return mrb_random_rand(mrb, random);
}
static mrb_value
mrb_random_g_srand(mrb_state *mrb, mrb_value self)
{
- mrb_value random = mrb_const_get(mrb,
- mrb_obj_value(mrb_class_get(mrb, "Random")),
- mrb_intern_lit(mrb, "DEFAULT"));
+ mrb_value random = get_random(mrb);
return mrb_random_srand(mrb, random);
}
@@ -154,7 +164,7 @@ static mrb_value
mrb_random_rand(mrb_state *mrb, mrb_value self)
{
mrb_value max;
- mt_state *t = DATA_PTR(self);
+ mt_state *t = DATA_GET_PTR(mrb, self, &mt_state_type, mt_state);
max = get_opt(mrb);
mrb_random_rand_seed(mrb, t);
@@ -166,7 +176,7 @@ mrb_random_srand(mrb_state *mrb, mrb_value self)
{
mrb_value seed;
mrb_value old_seed;
- mt_state *t = DATA_PTR(self);
+ mt_state *t = DATA_GET_PTR(mrb, self, &mt_state_type, mt_state);
seed = get_opt(mrb);
seed = mrb_random_mt_srand(mrb, t, seed);
@@ -200,10 +210,7 @@ mrb_ary_shuffle_bang(mrb_state *mrb, mrb_value ary)
mrb_get_args(mrb, "|d", &random, &mt_state_type);
if (random == NULL) {
- mrb_value random_val = mrb_const_get(mrb,
- mrb_obj_value(mrb_class_get(mrb, "Random")),
- mrb_intern_lit(mrb, "DEFAULT"));
- random = (mt_state *)DATA_PTR(random_val);
+ random = get_random_state(mrb);
}
mrb_random_rand_seed(mrb, random);
@@ -240,6 +247,77 @@ mrb_ary_shuffle(mrb_state *mrb, mrb_value ary)
return new_ary;
}
+/*
+ * call-seq:
+ * ary.sample -> obj
+ * ary.sample(n) -> new_ary
+ *
+ * Choose a random element or +n+ random elements from the array.
+ *
+ * The elements are chosen by using random and unique indices into the array
+ * in order to ensure that an element doesn't repeat itself unless the array
+ * already contained duplicate elements.
+ *
+ * If the array is empty the first form returns +nil+ and the second form
+ * returns an empty array.
+ */
+
+static mrb_value
+mrb_ary_sample(mrb_state *mrb, mrb_value ary)
+{
+ mrb_int n = 0;
+ mrb_bool given;
+ mt_state *random = NULL;
+ mrb_int len = RARRAY_LEN(ary);
+
+ mrb_get_args(mrb, "|i?d", &n, &given, &random, &mt_state_type);
+ if (random == NULL) {
+ random = get_random_state(mrb);
+ }
+ mrb_random_rand_seed(mrb, random);
+ mt_rand(random);
+ if (!given) { /* pick one element */
+ switch (len) {
+ case 0:
+ return mrb_nil_value();
+ case 1:
+ return RARRAY_PTR(ary)[0];
+ default:
+ return RARRAY_PTR(ary)[mt_rand(random) % len];
+ }
+ }
+ else {
+ mrb_value result;
+ mrb_int i, j;
+
+ if (n < 0) mrb_raise(mrb, E_ARGUMENT_ERROR, "negative sample number");
+ if (n > len) n = len;
+ result = mrb_ary_new_capa(mrb, n);
+ for (i=0; i<n; i++) {
+ mrb_int r;
+
+ for (;;) {
+ retry:
+ r = mt_rand(random) % len;
+
+ for (j=0; j<i; j++) {
+ if (mrb_fixnum(RARRAY_PTR(result)[j]) == r) {
+ goto retry; /* retry if duplicate */
+ }
+ }
+ break;
+ }
+ RARRAY_PTR(result)[i] = mrb_fixnum_value(r);
+ RARRAY_LEN(result)++;
+ }
+ for (i=0; i<n; i++) {
+ RARRAY_PTR(result)[i] = RARRAY_PTR(ary)[mrb_fixnum(RARRAY_PTR(result)[i])];
+ }
+ return result;
+ }
+}
+
+
void mrb_mruby_random_gem_init(mrb_state *mrb)
{
struct RClass *random;
@@ -259,6 +337,7 @@ void mrb_mruby_random_gem_init(mrb_state *mrb)
mrb_define_method(mrb, array, "shuffle", mrb_ary_shuffle, MRB_ARGS_OPT(1));
mrb_define_method(mrb, array, "shuffle!", mrb_ary_shuffle_bang, MRB_ARGS_OPT(1));
+ mrb_define_method(mrb, array, "sample", mrb_ary_sample, MRB_ARGS_OPT(2));
mrb_const_set(mrb, mrb_obj_value(random), mrb_intern_lit(mrb, "DEFAULT"),
mrb_obj_new(mrb, random, 0, NULL));
diff --git a/src/class.c b/src/class.c
index 6f5a8ed19..d880e3627 100644
--- a/src/class.c
+++ b/src/class.c
@@ -411,6 +411,7 @@ to_hash(mrb_state *mrb, mrb_value val)
&: Block [mrb_value]
*: rest argument [mrb_value*,int] Receive the rest of the arguments as an array.
|: optional Next argument of '|' and later are optional.
+ ?: optional given [mrb_bool] true if preceding argument (optional) is given.
*/
int
mrb_get_args(mrb_state *mrb, const char *format, ...)
@@ -420,7 +421,8 @@ mrb_get_args(mrb_state *mrb, const char *format, ...)
mrb_value *sp = mrb->c->stack + 1;
va_list ap;
int argc = mrb->c->ci->argc;
- int opt = 0;
+ mrb_bool opt = 0;
+ mrb_bool given = 1;
va_start(ap, format);
if (argc < 0) {
@@ -431,11 +433,16 @@ mrb_get_args(mrb_state *mrb, const char *format, ...)
}
while ((c = *format++)) {
switch (c) {
- case '|': case '*': case '&':
+ case '|': case '*': case '&': case '?':
break;
default:
- if (argc <= i && !opt) {
- mrb_raise(mrb, E_ARGUMENT_ERROR, "wrong number of arguments");
+ if (argc <= i) {
+ if (opt) {
+ given = 0;
+ }
+ else {
+ mrb_raise(mrb, E_ARGUMENT_ERROR, "wrong number of arguments");
+ }
}
break;
}
@@ -692,6 +699,14 @@ mrb_get_args(mrb_state *mrb, const char *format, ...)
case '|':
opt = 1;
break;
+ case '?':
+ {
+ mrb_bool *p;
+
+ p = va_arg(ap, mrb_bool*);
+ *p = given;
+ }
+ break;
case '*':
{
diff --git a/src/hash.c b/src/hash.c
index af3571eaf..34cc15131 100644
--- a/src/hash.c
+++ b/src/hash.c
@@ -874,6 +874,7 @@ static mrb_value
hash_equal(mrb_state *mrb, mrb_value hash1, mrb_value hash2, mrb_bool eql)
{
khash_t(ht) *h1, *h2;
+ mrb_bool eq;
if (mrb_obj_equal(mrb, hash1, hash2)) return mrb_true_value();
if (!mrb_hash_p(hash2)) {
@@ -881,8 +882,6 @@ hash_equal(mrb_state *mrb, mrb_value hash1, mrb_value hash2, mrb_bool eql)
return mrb_false_value();
}
else {
- mrb_bool eq;
-
if (eql) {
eq = mrb_eql(mrb, hash2, hash1);
}
@@ -908,7 +907,11 @@ hash_equal(mrb_state *mrb, mrb_value hash1, mrb_value hash2, mrb_bool eql)
key = kh_key(h1,k1);
k2 = kh_get(ht, mrb, h2, key);
if (k2 != kh_end(h2)) {
- if (mrb_eql(mrb, kh_value(h1,k1), kh_value(h2,k2))) {
+ if (eql)
+ eq = mrb_eql(mrb, kh_value(h1,k1), kh_value(h2,k2));
+ else
+ eq = mrb_equal(mrb, kh_value(h1,k1), kh_value(h2,k2));
+ if (eq) {
continue; /* next key */
}
}
diff --git a/test/t/hash.rb b/test/t/hash.rb
index 92bc223b6..e7d5e8f74 100644
--- a/test/t/hash.rb
+++ b/test/t/hash.rb
@@ -12,6 +12,7 @@ end
assert('Hash#==', '15.2.13.4.1') do
assert_true({ 'abc' => 'abc' } == { 'abc' => 'abc' })
assert_false({ 'abc' => 'abc' } == { 'cba' => 'cba' })
+ assert_true({ :equal => 1 } == { :equal => 1.0 })
end
assert('Hash#[]', '15.2.13.4.2') do