Skip to content

Commit f4029d3

Browse files
committed
Add impls for new IO traits
1 parent 6f5a048 commit f4029d3

File tree

1 file changed

+166
-17
lines changed

1 file changed

+166
-17
lines changed

src/lib.rs

Lines changed: 166 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use std::fmt;
4141
use std::i32;
4242
use std::num::FromPrimitive;
4343
use std::old_io::{self, IoResult, IoError, IoErrorKind, SeekStyle};
44+
use std::io;
4445
use std::slice::bytes;
4546

4647
use postgres::{Oid, Error, Result, Transaction, GenericConnection};
@@ -115,7 +116,7 @@ impl<'conn> LargeObjectTransactionExt for Transaction<'conn> {
115116
}
116117
}
117118

118-
macro_rules! try_io {
119+
macro_rules! try_old_io {
119120
($e:expr) => {
120121
match $e {
121122
Ok(ok) => ok,
@@ -128,6 +129,17 @@ macro_rules! try_io {
128129
}
129130
}
130131

132+
macro_rules! try_io {
133+
($e:expr) => {
134+
match $e {
135+
Ok(ok) => ok,
136+
Err(e) => return Err(io::Error::new(io::ErrorKind::Other,
137+
"error communicating with server",
138+
Some(format!("{}", e))))
139+
}
140+
}
141+
}
142+
131143
/// Represents an open large object.
132144
pub struct LargeObject<'a> {
133145
trans: &'a Transaction<'a>,
@@ -193,9 +205,9 @@ impl<'a> LargeObject<'a> {
193205

194206
impl<'a> Reader for LargeObject<'a> {
195207
fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
196-
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.loread($1, $2)"));
208+
let stmt = try_old_io!(self.trans.prepare_cached("SELECT pg_catalog.loread($1, $2)"));
197209
let cap = cmp::min(buf.len(), i32::MAX as usize) as i32;
198-
let out: Vec<u8> = try_io!(stmt.query(&[&self.fd, &cap])).next().unwrap().get(0);
210+
let out: Vec<u8> = try_old_io!(stmt.query(&[&self.fd, &cap])).next().unwrap().get(0);
199211

200212
if !buf.is_empty() && out.is_empty() {
201213
return Err(old_io::standard_error(IoErrorKind::EndOfFile));
@@ -206,28 +218,52 @@ impl<'a> Reader for LargeObject<'a> {
206218
}
207219
}
208220

221+
impl<'a> io::Read for LargeObject<'a> {
222+
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
223+
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.loread($1, $2)"));
224+
let cap = cmp::min(buf.len(), i32::MAX as usize) as i32;
225+
let out: Vec<u8> = try_io!(stmt.query(&[&self.fd, &cap])).next().unwrap().get(0);
226+
227+
bytes::copy_memory(buf, &out);
228+
Ok(out.len())
229+
}
230+
}
231+
209232
impl<'a> Writer for LargeObject<'a> {
210233
fn write_all(&mut self, mut buf: &[u8]) -> IoResult<()> {
211-
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lowrite($1, $2)"));
234+
let stmt = try_old_io!(self.trans.prepare_cached("SELECT pg_catalog.lowrite($1, $2)"));
212235

213236
while !buf.is_empty() {
214237
let cap = cmp::min(buf.len(), i32::MAX as usize);
215-
try_io!(stmt.execute(&[&self.fd, &&buf[..cap]]));
238+
try_old_io!(stmt.execute(&[&self.fd, &&buf[..cap]]));
216239
buf = &buf[cap..];
217240
}
218241

219242
Ok(())
220243
}
221244
}
222245

246+
impl<'a> io::Write for LargeObject<'a> {
247+
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
248+
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lowrite($1, $2)"));
249+
let cap = cmp::min(buf.len(), i32::MAX as usize);
250+
try_io!(stmt.execute(&[&self.fd, &&buf[..cap]]));
251+
Ok(cap)
252+
}
253+
254+
fn flush(&mut self) -> io::Result<()> {
255+
Ok(())
256+
}
257+
}
258+
223259
impl<'a> Seek for LargeObject<'a> {
224260
fn tell(&self) -> IoResult<u64> {
225261
if self.has_64 {
226-
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_tell64($1)"));
227-
Ok(try_io!(stmt.query(&[&self.fd])).next().unwrap().get::<_, i64>(0) as u64)
262+
let stmt = try_old_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_tell64($1)"));
263+
Ok(try_old_io!(stmt.query(&[&self.fd])).next().unwrap().get::<_, i64>(0) as u64)
228264
} else {
229-
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_tell($1)"));
230-
Ok(try_io!(stmt.query(&[&self.fd])).next().unwrap().get::<_, i32>(0) as u64)
265+
let stmt = try_old_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_tell($1)"));
266+
Ok(try_old_io!(stmt.query(&[&self.fd])).next().unwrap().get::<_, i32>(0) as u64)
231267
}
232268
}
233269

@@ -239,8 +275,8 @@ impl<'a> Seek for LargeObject<'a> {
239275
};
240276

241277
if self.has_64 {
242-
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_lseek64($1, $2, $3)"));
243-
try_io!(stmt.execute(&[&self.fd, &pos, &kind]));
278+
let stmt = try_old_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_lseek64($1, $2, $3)"));
279+
try_old_io!(stmt.execute(&[&self.fd, &pos, &kind]));
244280
} else {
245281
let pos: i32 = match FromPrimitive::from_i64(pos) {
246282
Some(pos) => pos,
@@ -250,17 +286,50 @@ impl<'a> Seek for LargeObject<'a> {
250286
detail: None,
251287
}),
252288
};
253-
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_lseek($1, $2, $3)"));
254-
try_io!(stmt.execute(&[&self.fd, &pos, &kind]));
289+
let stmt = try_old_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_lseek($1, $2, $3)"));
290+
try_old_io!(stmt.execute(&[&self.fd, &pos, &kind]));
255291
}
256292

257293
Ok(())
258294
}
259295
}
260296

297+
impl<'a> io::Seek for LargeObject<'a> {
298+
fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
299+
let (kind, pos) = match pos {
300+
io::SeekFrom::Start(pos) => {
301+
let pos = match FromPrimitive::from_u64(pos) {
302+
Some(pos) => pos,
303+
None => return Err(io::Error::new(io::ErrorKind::InvalidInput,
304+
"cannot seek more than 2^63 bytes",
305+
None)),
306+
};
307+
(0, pos)
308+
}
309+
io::SeekFrom::Current(pos) => (1, pos),
310+
io::SeekFrom::End(pos) => (2, pos),
311+
};
312+
313+
if self.has_64 {
314+
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_lseek64($1, $2, $3)"));
315+
Ok(try_io!(stmt.query(&[&self.fd, &pos, &kind])).next().unwrap().get::<_, i64>(0) as u64)
316+
} else {
317+
let pos: i32 = match FromPrimitive::from_i64(pos) {
318+
Some(pos) => pos,
319+
None => return Err(io::Error::new(io::ErrorKind::InvalidInput,
320+
"cannot seek more than 2^31 bytes",
321+
None)),
322+
};
323+
let stmt = try_io!(self.trans.prepare_cached("SELECT pg_catalog.lo_lseek($1, $2, $3)"));
324+
Ok(try_io!(stmt.query(&[&self.fd, &pos, &kind])).next().unwrap().get::<_, i32>(0) as u64)
325+
}
326+
}
327+
}
328+
261329
#[cfg(test)]
330+
#[no_implicit_prelude]
262331
mod test {
263-
use std::old_io::SeekStyle;
332+
use std::result::Result::{Ok, Err};
264333
use postgres::{Connection, SslMode, SqlState, Error};
265334

266335
use {LargeObjectExt, LargeObjectTransactionExt, Mode};
@@ -303,7 +372,9 @@ mod test {
303372
}
304373

305374
#[test]
306-
fn test_write_read() {
375+
fn test_write_read_old_io() {
376+
use std::old_io::{Writer, Reader};
377+
307378
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
308379
let trans = conn.transaction().unwrap();
309380
let oid = trans.create_large_object().unwrap();
@@ -314,7 +385,24 @@ mod test {
314385
}
315386

316387
#[test]
317-
fn test_seek_tell() {
388+
fn test_write_read() {
389+
use std::io::{Write, Read};
390+
391+
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
392+
let trans = conn.transaction().unwrap();
393+
let oid = trans.create_large_object().unwrap();
394+
let mut lo = trans.open_large_object(oid, Mode::Write).unwrap();
395+
lo.write_all(b"hello world!!!").unwrap();
396+
let mut lo = trans.open_large_object(oid, Mode::Read).unwrap();
397+
let mut out = vec![];
398+
lo.read_to_end(&mut out).unwrap();
399+
assert_eq!(b"hello world!!!", out);
400+
}
401+
402+
#[test]
403+
fn test_seek_tell_old_io() {
404+
use std::old_io::{Writer, Reader, Seek, SeekStyle};
405+
318406
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
319407
let trans = conn.transaction().unwrap();
320408
let oid = trans.create_large_object().unwrap();
@@ -334,8 +422,45 @@ mod test {
334422
assert_eq!(b'r', lo.read_u8().unwrap());
335423
}
336424

425+
#[test]
426+
fn test_seek_tell() {
427+
use std::io::{Write, Read, Seek, SeekFrom};
428+
429+
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
430+
let trans = conn.transaction().unwrap();
431+
let oid = trans.create_large_object().unwrap();
432+
let mut lo = trans.open_large_object(oid, Mode::Write).unwrap();
433+
lo.write_all(b"hello world!!!").unwrap();
434+
435+
assert_eq!(14, lo.seek(SeekFrom::Current(0)).unwrap());
436+
assert_eq!(1, lo.seek(SeekFrom::Start(1)).unwrap());
437+
let mut buf = [0];
438+
assert_eq!(1, lo.read(&mut buf).unwrap());
439+
assert_eq!(b'e', buf[0]);
440+
assert_eq!(2, lo.seek(SeekFrom::Current(0)).unwrap());
441+
assert_eq!(10, lo.seek(SeekFrom::End(-4)).unwrap());
442+
assert_eq!(1, lo.read(&mut buf).unwrap());
443+
assert_eq!(b'd', buf[0]);
444+
assert_eq!(8, lo.seek(SeekFrom::Current(-3)).unwrap());
445+
assert_eq!(1, lo.read(&mut buf).unwrap());
446+
assert_eq!(b'r', buf[0]);
447+
}
448+
449+
#[test]
450+
fn test_write_with_read_fd_old_io() {
451+
use std::old_io::Writer;
452+
453+
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
454+
let trans = conn.transaction().unwrap();
455+
let oid = trans.create_large_object().unwrap();
456+
let mut lo = trans.open_large_object(oid, Mode::Read).unwrap();
457+
assert!(lo.write_all(b"hello world!!!").is_err());
458+
}
459+
337460
#[test]
338461
fn test_write_with_read_fd() {
462+
use std::io::Write;
463+
339464
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
340465
let trans = conn.transaction().unwrap();
341466
let oid = trans.create_large_object().unwrap();
@@ -344,7 +469,9 @@ mod test {
344469
}
345470

346471
#[test]
347-
fn test_truncate() {
472+
fn test_truncate_old_io() {
473+
use std::old_io::{Seek, SeekStyle, Writer, Reader};
474+
348475
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
349476
let trans = conn.transaction().unwrap();
350477
let oid = trans.create_large_object().unwrap();
@@ -358,4 +485,26 @@ mod test {
358485
lo.seek(0, SeekStyle::SeekSet).unwrap();
359486
assert_eq!(b"hello\0\0\0\0\0", lo.read_to_end().unwrap());
360487
}
488+
489+
#[test]
490+
fn test_truncate() {
491+
use std::io::{Seek, SeekFrom, Write, Read};
492+
493+
let conn = Connection::connect("postgres://postgres@localhost", &SslMode::None).unwrap();
494+
let trans = conn.transaction().unwrap();
495+
let oid = trans.create_large_object().unwrap();
496+
let mut lo = trans.open_large_object(oid, Mode::Write).unwrap();
497+
lo.write_all(b"hello world!!!").unwrap();
498+
499+
lo.truncate(5).unwrap();
500+
lo.seek(SeekFrom::Start(0)).unwrap();
501+
let mut buf = vec![];
502+
lo.read_to_end(&mut buf).unwrap();
503+
assert_eq!(b"hello", buf);
504+
lo.truncate(10).unwrap();
505+
lo.seek(SeekFrom::Start(0)).unwrap();
506+
buf.clear();
507+
lo.read_to_end(&mut buf).unwrap();
508+
assert_eq!(b"hello\0\0\0\0\0", buf);
509+
}
361510
}

0 commit comments

Comments
 (0)