@@ -200,11 +200,19 @@ class ftp : public protocol<uint8_t, uint8_t>
200200
201201 protected:
202202
203+ int m_dc_port;
204+
205+ std::string m_dc_host;
206+
203207 std::string m_feat;
204208
205209 std::string m_syst;
206210
207- tls m_dc_prot = tls::no;
211+ tls m_dc_tls = tls::no;
212+
213+ uint8_t m_triggerFlags;
214+
215+ Direction m_currentDirection;
208216
209217 SPSubject m_data_channel = nullptr;
210218
@@ -363,6 +371,8 @@ class ftp : public protocol<uint8_t, uint8_t>
363371
364372 virtual void SendCommand(const std::string& c, const std::string& arg = "")
365373 {
374+ SetCurrentDirection(c);
375+
366376 auto cmd = c + " " + arg + "\r\n";
367377
368378 LOG << "Command : " << cmd;
@@ -372,6 +382,11 @@ class ftp : public protocol<uint8_t, uint8_t>
372382 if (target)
373383 {
374384 target->Write((uint8_t *)cmd.c_str(), cmd.size(), 0);
385+
386+ if (IsTransferCommand(cmd))
387+ {
388+ OpenDataChannel();
389+ }
375390 }
376391 }
377392
@@ -387,9 +402,8 @@ class ftp : public protocol<uint8_t, uint8_t>
387402
388403 m_queue.push_back({"PROT", level,
389404 [this, P](const std::string& res){
390- if (res[0] == '2')
391- {
392- m_dc_prot = P;
405+ if (res[0] == '2') {
406+ m_dc_tls = P;
393407 }
394408 }, nullptr});
395409 }
@@ -452,38 +466,24 @@ class ftp : public protocol<uint8_t, uint8_t>
452466 }
453467 }
454468
455- auto GetTransferDirection (const std::string& cmd)
469+ void SetCurrentDirection (const std::string& cmd)
456470 {
457471 if (cmd == "LIST" || cmd == "MLSD") {
458- return ftp::download;
472+ m_currentDirection = ftp::download;
459473 } else if (cmd == "RETR") {
460- return ftp::download;
474+ m_currentDirection = ftp::download;
461475 } else if (cmd == "STOR") {
462- return ftp::upload;
476+ m_currentDirection = ftp::upload;
463477 } else {
464- return ftp::none;
478+ m_currentDirection = ftp::none;
465479 }
466480 }
467481
468482 virtual void ProcessDataCmdResponse(char code)
469483 {
470484 if (IsPositivePreliminaryReply(code))
471485 {
472- if (m_dc_prot == tls::yes)
473- {
474- auto cc = std::static_pointer_cast<socket_device>(m_target.lock());
475-
476- std::static_pointer_cast<socket_device>
477- (m_data_channel)->InitializeSSL(
478- cc->GetSslObject(),
479- [this] () {
480- TriggerDataTransfer();
481- });
482- }
483- else
484- {
485- TriggerDataTransfer();
486- }
486+ TriggerDataTransfer(3);
487487 }
488488 else if (IsPositiveCompletionReply(code))
489489 {
@@ -502,19 +502,10 @@ class ftp : public protocol<uint8_t, uint8_t>
502502 m_data_channel.reset();
503503 }
504504
505- auto cmd = m_queue.front().c_name;
506-
507- if (GetTransferDirection(cmd) == ftp::upload)
505+ if (m_currentDirection == ftp::upload)
508506 {
509507 NotifyUploadChannelReady();
510508 }
511-
512- if (iCurrentState == EStateXYZ)
513- {
514- m_queue.pop_front();
515- iCurrentState = EStateREADY;
516- TriggerNextCommand();
517- }
518509 }
519510 }
520511
@@ -535,25 +526,27 @@ class ftp : public protocol<uint8_t, uint8_t>
535526 LOG << "Faled to parse PASV response";
536527 }
537528
538- auto host = std::to_string(h1) + "." +
529+ m_dc_host = std::to_string(h1) + "." +
539530 std::to_string(h2) + "." +
540531 std::to_string(h3) + "." +
541532 std::to_string(h4);
542533
543- auto port = (p1 << 8) + p2;
544-
545- OpenDataChannel(host, port);
534+ m_dc_port = (p1 << 8) + p2;
546535 }
547536
548- virtual void OpenDataChannel(const std::string& host, int port )
537+ virtual void OpenDataChannel()
549538 {
550539 m_data_channel = std::make_shared<socket_device>("sock-dc");
551540
552541 GetDispatcher()->AddEventListener(m_data_channel);
553542
554543 auto dc = std::static_pointer_cast<socket_device>(m_data_channel);
555544
556- dc->SetHostAndPort(host, port);
545+ dc->SetHostAndPort(m_dc_host, m_dc_port);
546+
547+ AttachDataChannelObserver();
548+
549+ m_triggerFlags = 0;
557550
558551 dc->StartSocketClient();
559552 }
@@ -572,6 +565,9 @@ class ftp : public protocol<uint8_t, uint8_t>
572565 if (!m_data_channel->IsStopped())
573566 OnDataChannelIoCompletion(b, n, ftp::upload);
574567 },
568+ [this](){
569+ OnDataChannelDisconnect();
570+ },
575571 [this](){
576572 OnDataChannelDisconnect();
577573 });
@@ -581,13 +577,12 @@ class ftp : public protocol<uint8_t, uint8_t>
581577
582578 virtual void OnDataChannelConnect(void)
583579 {
580+ TriggerDataTransfer(2);
584581 }
585582
586583 virtual void OnDataChannelIoCompletion(const uint8_t *b, size_t n, Direction direction)
587584 {
588- auto cmd = m_queue.front().c_name;
589-
590- if (direction == GetTransferDirection(cmd))
585+ if (direction == m_currentDirection)
591586 {
592587 auto& transferCallback = m_queue.front().c_tcbk;
593588
@@ -621,20 +616,37 @@ class ftp : public protocol<uint8_t, uint8_t>
621616 }
622617
623618 ProcessDataCmdResponse('0');
619+
620+ if (iCurrentState == EStateXYZ)
621+ {
622+ m_queue.pop_front();
623+ iCurrentState = EStateREADY;
624+ TriggerNextCommand();
625+ }
624626 }
625627
626- virtual void TriggerDataTransfer(void )
628+ virtual void TriggerDataTransfer(int source )
627629 {
628- AttachDataChannelObserver();
629-
630- auto cmd = m_queue.front().c_name;
630+ osl::set_bit(m_triggerFlags, source);
631631
632- if (GetTransferDirection(cmd) == ftp::upload)
633- {
634- if (m_queue.front().c_tcbk)
635- {
636- m_queue.front().c_tcbk((char *)0xABCDEF, 0);
632+ if (!osl::is_bit_set(m_triggerFlags, 1) &&
633+ osl::is_bit_set(m_triggerFlags, 2) &&
634+ osl::is_bit_set(m_triggerFlags, 3)) {
635+ if (m_dc_tls == tls::yes) {
636+ auto cc = std::static_pointer_cast<socket_device>(m_target.lock());
637+ std::static_pointer_cast<socket_device>
638+ (m_data_channel)->InitializeSSL(
639+ cc->GetSslObject(),
640+ [this](){
641+ if (m_currentDirection == ftp::upload)
642+ m_queue.front().c_tcbk((char *)0xABCDEF, 0);
643+ });
644+ }
645+ else {
646+ if (m_currentDirection == ftp::upload)
647+ m_queue.front().c_tcbk((char *)0xABCDEF, 0);
637648 }
649+ osl::set_bit(m_triggerFlags, 1);
638650 }
639651 }
640652
@@ -668,12 +680,14 @@ class ftp : public protocol<uint8_t, uint8_t>
668680 }
669681 }
670682
671- virtual bool IsTransferCommand(const std::string& cmd )
683+ virtual bool IsTransferCommand(const std::string& command )
672684 {
673- return (cmd == "RETR" ||
674- cmd == "LIST" ||
675- cmd == "MLSD" ||
676- cmd == "STOR");
685+ std::string_view view(command.c_str(), 4);
686+
687+ return (view == "RETR" ||
688+ view == "STOR" ||
689+ view == "MLSD" ||
690+ view == "LIST");
677691 }
678692
679693 virtual bool IsPositiveCompletionReply(char c)
0 commit comments