diff --git a/ext/stringio/stringio.c b/ext/stringio/stringio.c index f581cd5..1ce9003 100644 --- a/ext/stringio/stringio.c +++ b/ext/stringio/stringio.c @@ -48,7 +48,7 @@ static long strio_write(VALUE self, VALUE str); #define IS_STRIO(obj) (rb_typeddata_is_kind_of((obj), &strio_data_type)) #define error_inval(msg) (rb_syserr_fail(EINVAL, msg)) -#define get_enc(ptr) ((ptr)->enc ? (ptr)->enc : rb_enc_get((ptr)->string)) +#define get_enc(ptr) ((ptr)->enc ? (ptr)->enc : !NIL_P((ptr)->string) ? rb_enc_get((ptr)->string) : NULL) static struct StringIO * strio_alloc(void) @@ -281,13 +281,13 @@ strio_init(int argc, VALUE *argv, struct StringIO *ptr, VALUE self) argc = rb_scan_args(argc, argv, "02:", &string, &vmode, &opt); rb_io_extract_modeenc(&vmode, 0, opt, &oflags, &ptr->flags, &convconfig); - if (argc) { + if (!NIL_P(string)) { StringValue(string); } - else { + else if (!argc) { string = rb_enc_str_new("", 0, rb_default_external_encoding()); } - if (OBJ_FROZEN_RAW(string)) { + if (!NIL_P(string) && OBJ_FROZEN_RAW(string)) { if (ptr->flags & FMODE_WRITABLE) { rb_syserr_fail(EACCES, 0); } @@ -297,11 +297,11 @@ strio_init(int argc, VALUE *argv, struct StringIO *ptr, VALUE self) ptr->flags |= FMODE_WRITABLE; } } - if (ptr->flags & FMODE_TRUNC) { + if (!NIL_P(string) && (ptr->flags & FMODE_TRUNC)) { rb_str_resize(string, 0); } RB_OBJ_WRITE(self, &ptr->string, string); - if (argc == 1) { + if (argc == 1 && !NIL_P(string)) { ptr->enc = rb_enc_get(string); } else { @@ -595,6 +595,7 @@ static struct StringIO * strio_to_read(VALUE self) { struct StringIO *ptr = readable(self); + if (NIL_P(ptr->string)) return NULL; if (ptr->pos < RSTRING_LEN(ptr->string)) return ptr; return NULL; } @@ -872,7 +873,7 @@ strio_getc(VALUE self) int len; char *p; - if (pos >= RSTRING_LEN(str)) { + if (NIL_P(str) || pos >= RSTRING_LEN(str)) { return Qnil; } p = RSTRING_PTR(str)+pos; @@ -893,7 +894,7 @@ strio_getbyte(VALUE self) { struct StringIO *ptr = readable(self); int c; - if (ptr->pos >= RSTRING_LEN(ptr->string)) { + if (NIL_P(ptr->string) || ptr->pos >= RSTRING_LEN(ptr->string)) { return Qnil; } c = RSTRING_PTR(ptr->string)[ptr->pos++]; @@ -931,6 +932,7 @@ strio_ungetc(VALUE self, VALUE c) rb_encoding *enc, *enc2; check_modifiable(ptr); + if (NIL_P(ptr->string)) return Qnil; if (NIL_P(c)) return Qnil; if (RB_INTEGER_TYPE_P(c)) { int len, cc = NUM2INT(c); @@ -968,6 +970,7 @@ strio_ungetbyte(VALUE self, VALUE c) struct StringIO *ptr = readable(self); check_modifiable(ptr); + if (NIL_P(ptr->string)) return Qnil; if (NIL_P(c)) return Qnil; if (RB_INTEGER_TYPE_P(c)) { /* rb_int_and() not visible from exts */ @@ -1171,7 +1174,7 @@ prepare_getline_args(struct StringIO *ptr, struct getline_arg *arg, int argc, VA if (!NIL_P(lim)) limit = NUM2LONG(lim); break; } - if (!NIL_P(rs)) { + if (!NIL_P(ptr->string) && !NIL_P(rs)) { rb_encoding *enc_rs, *enc_io; enc_rs = rb_enc_get(rs); enc_io = get_enc(ptr); @@ -1226,7 +1229,7 @@ strio_getline(struct getline_arg *arg, struct StringIO *ptr) long w = 0; rb_encoding *enc = get_enc(ptr); - if (ptr->pos >= (n = RSTRING_LEN(ptr->string))) { + if (NIL_P(ptr->string) || ptr->pos >= (n = RSTRING_LEN(ptr->string))) { return Qnil; } s = RSTRING_PTR(ptr->string); @@ -1323,6 +1326,7 @@ strio_gets(int argc, VALUE *argv, VALUE self) VALUE str; if (prepare_getline_args(ptr, &arg, argc, argv)->limit == 0) { + if (NIL_P(ptr->string)) return Qnil; return rb_enc_str_new(0, 0, get_enc(ptr)); } @@ -1437,6 +1441,7 @@ strio_write(VALUE self, VALUE str) if (!RB_TYPE_P(str, T_STRING)) str = rb_obj_as_string(str); enc = get_enc(ptr); + if (!enc) return 0; enc2 = rb_enc_get(str); if (enc != enc2 && enc != ascii8bit && enc != (usascii = rb_usascii_encoding())) { VALUE converted = rb_str_conv_enc(str, enc2, enc); @@ -1509,10 +1514,12 @@ strio_putc(VALUE self, VALUE ch) check_modifiable(ptr); if (RB_TYPE_P(ch, T_STRING)) { + if (NIL_P(ptr->string)) return ch; str = rb_str_substr(ch, 0, 1); } else { char c = NUM2CHR(ch); + if (NIL_P(ptr->string)) return ch; str = rb_str_new(&c, 1); } strio_write(self, str); @@ -1555,7 +1562,8 @@ strio_read(int argc, VALUE *argv, VALUE self) if (len < 0) { rb_raise(rb_eArgError, "negative length %ld given", len); } - if (len > 0 && ptr->pos >= RSTRING_LEN(ptr->string)) { + if (len > 0 && + (NIL_P(ptr->string) || ptr->pos >= RSTRING_LEN(ptr->string))) { if (!NIL_P(str)) rb_str_resize(str, 0); return Qnil; } @@ -1564,6 +1572,7 @@ strio_read(int argc, VALUE *argv, VALUE self) } /* fall through */ case 0: + if (NIL_P(ptr->string)) return Qnil; len = RSTRING_LEN(ptr->string); if (len <= ptr->pos) { rb_encoding *enc = get_enc(ptr); @@ -1733,7 +1742,7 @@ strio_size(VALUE self) { VALUE string = StringIO(self)->string; if (NIL_P(string)) { - rb_raise(rb_eIOError, "not opened"); + return INT2FIX(0); } return ULONG2NUM(RSTRING_LEN(string)); } @@ -1750,10 +1759,12 @@ strio_truncate(VALUE self, VALUE len) { VALUE string = writable(self)->string; long l = NUM2LONG(len); - long plen = RSTRING_LEN(string); + long plen; if (l < 0) { error_inval("negative length"); } + if (NIL_P(string)) return 0; + plen = RSTRING_LEN(string); rb_str_resize(string, l); if (plen < l) { MEMZERO(RSTRING_PTR(string) + plen, char, l - plen); @@ -1824,7 +1835,7 @@ strio_set_encoding(int argc, VALUE *argv, VALUE self) } } ptr->enc = enc; - if (WRITABLE(self)) { + if (!NIL_P(ptr->string) && WRITABLE(self)) { rb_enc_associate(ptr->string, enc); } diff --git a/test/stringio/test_stringio.rb b/test/stringio/test_stringio.rb index b0eff57..2af6923 100644 --- a/test/stringio/test_stringio.rb +++ b/test/stringio/test_stringio.rb @@ -22,10 +22,11 @@ def test_initialize assert_kind_of StringIO, StringIO.new assert_kind_of StringIO, StringIO.new('str') assert_kind_of StringIO, StringIO.new('str', 'r+') + assert_kind_of StringIO, StringIO.new(nil) assert_raise(ArgumentError) { StringIO.new('', 'x') } assert_raise(ArgumentError) { StringIO.new('', 'rx') } assert_raise(ArgumentError) { StringIO.new('', 'rbt') } - assert_raise(TypeError) { StringIO.new(nil) } + assert_raise(TypeError) { StringIO.new(Object) } o = Object.new def o.to_str @@ -40,6 +41,13 @@ def o.to_str assert_kind_of StringIO, StringIO.new(o) end + def test_null + io = StringIO.new(nil) + assert_nil io.gets + io.puts "abc" + assert_nil io.string + end + def test_truncate io = StringIO.new("") io.puts "abc"